diff --git a/pipe.go b/pipe.go index a2da6639..116509ae 100644 --- a/pipe.go +++ b/pipe.go @@ -574,9 +574,11 @@ func (l *win32PipeListener) Accept() (net.Conn, error) { func (l *win32PipeListener) Close() error { select { - case l.closeCh <- 1: - <-l.doneCh case <-l.doneCh: + case <-l.closeCh: + default: + close(l.closeCh) + <-l.doneCh } return nil } diff --git a/pipe_test.go b/pipe_test.go index 8b1d5b94..5a9dcd3e 100644 --- a/pipe_test.go +++ b/pipe_test.go @@ -646,3 +646,43 @@ func TestListenConnectRace(t *testing.T) { wg.Wait() } } + +func TestCloseRace(t *testing.T) { + for i := 0; i < 200 && !t.Failed(); i++ { + l, err := ListenPipe(testPipeName, &PipeConfig{MessageMode: true}) + if err != nil { + t.Fatal(err) + } + go func() { + for { + c, err := l.Accept() + if err != nil { + return + } + b, err := io.ReadAll(c) + if err != nil { + t.Error(err) + return + } + _, _ = c.Write(b) + _ = c.Close() + } + }() + + c, err := DialPipe(testPipeName, nil) + if err != nil { + t.Fatal(err) + } + if _, err = c.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + if err := c.(CloseWriter).CloseWrite(); err != nil { + t.Fatal(err) + } + if _, err := io.ReadAll(c); err != nil { + t.Fatal(err) + } + _ = c.Close() + _ = l.Close() + } +}