diff --git a/cmd/win-sshproxy/main.go b/cmd/win-sshproxy/main.go index 73d1fdd75..d4af1e3ef 100644 --- a/cmd/win-sshproxy/main.go +++ b/cmd/win-sshproxy/main.go @@ -11,10 +11,12 @@ import ( "path/filepath" "strings" "syscall" + "time" "unsafe" "github.com/containers/gvisor-tap-vsock/pkg/sshclient" "github.com/containers/gvisor-tap-vsock/pkg/types" + "github.com/containers/gvisor-tap-vsock/pkg/utils" "github.com/containers/winquit/pkg/winquit" "github.com/sirupsen/logrus" "golang.org/x/sync/errgroup" @@ -173,11 +175,31 @@ func saveThreadId() (uint32, error) { return 0, err } defer file.Close() - tid := winquit.GetCurrentMessageLoopThreadId() + + tid, err := getThreadId() + if err != nil { + return 0, err + } + fmt.Fprintf(file, "%d:%d\n", os.Getpid(), tid) return tid, nil } +func getThreadId() (uint32, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + getTid := func() (uint32, error) { + tid := winquit.GetCurrentMessageLoopThreadId() + if tid != 0 { + return tid, nil + } + return 0, fmt.Errorf("failed to get thread ID") + } + + return utils.Retry(ctx, getTid, "Waiting for message loop thread id") +} + // Creates an "error" style pop-up window func alert(caption string) int { // Error box style diff --git a/pkg/sshclient/ssh_forwarder.go b/pkg/sshclient/ssh_forwarder.go index b994f24d9..9133a2e70 100644 --- a/pkg/sshclient/ssh_forwarder.go +++ b/pkg/sshclient/ssh_forwarder.go @@ -13,6 +13,7 @@ import ( "time" "github.com/containers/gvisor-tap-vsock/pkg/fs" + "github.com/containers/gvisor-tap-vsock/pkg/utils" "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -98,13 +99,13 @@ func connectForward(ctx context.Context, bastion *Bastion) (CloseWriteConn, erro if err == nil { break } - if bastionRetries > 2 || !sleep(ctx, 200*time.Millisecond) { + if bastionRetries > 2 || !utils.Sleep(ctx, 200*time.Millisecond) { return nil, errors.Wrapf(err, "Couldn't reestablish ssh connection: %s", bastion.Host) } } } - if !sleep(ctx, 200*time.Millisecond) { + if !utils.Sleep(ctx, 200*time.Millisecond) { retries = 3 } } @@ -173,7 +174,7 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity } return CreateBastion(dest, passphrase, identity, conn, connectFunc) } - bastion, err := retry(ctx, createBastion, "Waiting for sshd") + bastion, err := utils.Retry(ctx, createBastion, "Waiting for sshd") if err != nil { return &SSHForward{}, fmt.Errorf("setupProxy failed: %w", err) } @@ -183,37 +184,6 @@ func setupProxy(ctx context.Context, socketURI *url.URL, dest *url.URL, identity return &SSHForward{listener, bastion, socketURI}, nil } -const maxRetries = 60 -const initialBackoff = 100 * time.Millisecond - -func retry[T comparable](ctx context.Context, retryFunc func() (T, error), retryMsg string) (T, error) { - var ( - returnVal T - err error - ) - - backoff := initialBackoff - -loop: - for i := 0; i < maxRetries; i++ { - select { - case <-ctx.Done(): - break loop - default: - // proceed - } - - returnVal, err = retryFunc() - if err == nil { - return returnVal, nil - } - logrus.Debugf("%s (%s)", retryMsg, backoff) - sleep(ctx, backoff) - backoff = backOff(backoff) - } - return returnVal, fmt.Errorf("timeout: %w", err) -} - func acceptConnection(ctx context.Context, listener net.Listener, bastion *Bastion, socketURI *url.URL) error { con, err := listener.Accept() if err != nil { @@ -256,24 +226,3 @@ func forward(src io.ReadCloser, dest CloseWriteStream, complete *sync.WaitGroup) // Trigger an EOF on the other end _ = dest.CloseWrite() } - -func backOff(delay time.Duration) time.Duration { - if delay == 0 { - delay = 5 * time.Millisecond - } else { - delay *= 2 - } - if delay > time.Second { - delay = time.Second - } - return delay -} - -func sleep(ctx context.Context, wait time.Duration) bool { - select { - case <-ctx.Done(): - return false - case <-time.After(wait): - return true - } -} diff --git a/pkg/utils/retry.go b/pkg/utils/retry.go new file mode 100644 index 000000000..6422a1d6f --- /dev/null +++ b/pkg/utils/retry.go @@ -0,0 +1,61 @@ +package utils + +import ( + "context" + "fmt" + "time" + + "github.com/sirupsen/logrus" +) + +const maxRetries = 60 +const initialBackoff = 100 * time.Millisecond + +func Retry[T comparable](ctx context.Context, retryFunc func() (T, error), retryMsg string) (T, error) { + var ( + returnVal T + err error + ) + + backoff := initialBackoff + +loop: + for i := 0; i < maxRetries; i++ { + select { + case <-ctx.Done(): + break loop + default: + // proceed + } + + returnVal, err = retryFunc() + if err == nil { + return returnVal, nil + } + logrus.Debugf("%s (%s)", retryMsg, backoff) + Sleep(ctx, backoff) + backoff = backOff(backoff) + } + return returnVal, fmt.Errorf("timeout: %w", err) +} + +func backOff(delay time.Duration) time.Duration { + if delay == 0 { + delay = 5 * time.Millisecond + } else { + delay *= 2 + } + if delay > time.Second { + delay = time.Second + } + return delay +} + +func Sleep(ctx context.Context, wait time.Duration) bool { + select { + case <-ctx.Done(): + return false + case <-time.After(wait): + return true + } +} diff --git a/test-win-sshproxy/basic_test.go b/test-win-sshproxy/basic_test.go index e9541c58d..894bd3753 100644 --- a/test-win-sshproxy/basic_test.go +++ b/test-win-sshproxy/basic_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package e2e @@ -25,15 +26,16 @@ var _ = Describe("connectivity", func() { err := startProxy() Expect(err).ShouldNot(HaveOccurred()) - var pid uint32 + var pid, tid uint32 for i := 0; i < 20; i++ { - pid, _, err = readTid() - if err == nil { + pid, tid, err = readTid() + if err == nil && tid != 0 { break } time.Sleep(100 * time.Millisecond) } + Expect(tid).ShouldNot(Equal(0)) Expect(err).ShouldNot(HaveOccurred()) proc, err := os.FindProcess(int(pid)) Expect(err).ShouldNot(HaveOccurred())