diff --git a/cmd/win-sshproxy/main.go b/cmd/win-sshproxy/main.go index 73d1fdd75..afdae6211 100644 --- a/cmd/win-sshproxy/main.go +++ b/cmd/win-sshproxy/main.go @@ -11,6 +11,7 @@ import ( "path/filepath" "strings" "syscall" + "time" "unsafe" "github.com/containers/gvisor-tap-vsock/pkg/sshclient" @@ -173,11 +174,35 @@ func saveThreadId() (uint32, error) { return 0, err } defer file.Close() - tid := winquit.GetCurrentMessageLoopThreadId() + + tid, err := getThreadId() + if err != nil { + return 0, err + } + fmt.Fprintf(file, "%d:%d\n", os.Getpid(), tid) return tid, nil } +func getThreadId() (uint32, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var tid uint32 + for { + select { + case <-ctx.Done(): + return 0, fmt.Errorf("failed to get thread ID") + default: + tid = winquit.GetCurrentMessageLoopThreadId() + if tid != 0 { + return tid, nil + } + time.Sleep(100 * time.Millisecond) + } + } +} + // Creates an "error" style pop-up window func alert(caption string) int { // Error box style diff --git a/test-win-sshproxy/basic_test.go b/test-win-sshproxy/basic_test.go index e9541c58d..894bd3753 100644 --- a/test-win-sshproxy/basic_test.go +++ b/test-win-sshproxy/basic_test.go @@ -1,3 +1,4 @@ +//go:build windows // +build windows package e2e @@ -25,15 +26,16 @@ var _ = Describe("connectivity", func() { err := startProxy() Expect(err).ShouldNot(HaveOccurred()) - var pid uint32 + var pid, tid uint32 for i := 0; i < 20; i++ { - pid, _, err = readTid() - if err == nil { + pid, tid, err = readTid() + if err == nil && tid != 0 { break } time.Sleep(100 * time.Millisecond) } + Expect(tid).ShouldNot(Equal(0)) Expect(err).ShouldNot(HaveOccurred()) proc, err := os.FindProcess(int(pid)) Expect(err).ShouldNot(HaveOccurred())