From bb0f08b86b0e85ebaf7fc170cda39cbdd2f69727 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com> Date: Wed, 13 Dec 2023 12:02:30 +0100 Subject: [PATCH 1/5] fix: properly stop client on Stop --- client.go | 4 +++- client_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/client.go b/client.go index 4309993c..dd46cb5d 100644 --- a/client.go +++ b/client.go @@ -221,8 +221,10 @@ func (c *client) Start() { func (c *client) Stop() { if c.cancelFunc != nil { c.cancelFunc() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // in practice, it is faster than 5 seconds so this is just to avoid infinite block + defer cancel() + c.WaitForState(ctx, ClientClosed) } - c.setState(ClientClosed) } func (c *client) run() error { diff --git a/client_test.go b/client_test.go index e709b991..ab92613b 100644 --- a/client_test.go +++ b/client_test.go @@ -149,6 +149,32 @@ var _ = Describe("Client", func() { close(done) }, 1.0) }) + Context("Stop", func() { + It("should stop the client properly", func(done Done) { + // Create a simple server + server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}), + testLoggerOption(), + ChanReceiveTimeout(200*time.Millisecond), + StreamBufferCapacity(5)) + Expect(err).NotTo(HaveOccurred()) + Expect(server).NotTo(BeNil()) + // Create both ends of the connection + cliConn, srvConn := newClientServerConnections() + // Start the server + go func() { _ = server.Serve(srvConn) }() + // Create the Client + clientConn, err := NewClient(context.Background(), WithConnection(cliConn), testLoggerOption(), formatOption) + Expect(err).NotTo(HaveOccurred()) + Expect(clientConn).NotTo(BeNil()) + // Start it + clientConn.Start() + Expect(<-clientConn.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred()) + clientConn.Stop() + Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed)) + server.cancel() + close(done) + }) + }) Context("Invoke", func() { It("should invoke a server method and return the result", func(done Done) { _, client, _, cancelClient := getTestBed(&simpleReceiver{}, formatOption) From d02f2e6b259072386a967e02492e2587a0cd82eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com> Date: Fri, 5 Jan 2024 11:23:01 +0100 Subject: [PATCH 2/5] fix: add waitGroup to start go func --- client.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index dd46cb5d..f7b49d1e 100644 --- a/client.go +++ b/client.go @@ -163,12 +163,15 @@ type client struct { lastID int64 backoffFactory func() backoff.BackOff cancelFunc context.CancelFunc + wg sync.WaitGroup } func (c *client) Start() { c.setState(ClientConnecting) boff := c.backoffFactory() + c.wg.Add(1) go func() { + defer c.wg.Done() for { c.setErr(nil) // Listen for state change to ClientConnected and signal backoff Reset then. @@ -221,9 +224,8 @@ func (c *client) Start() { func (c *client) Stop() { if c.cancelFunc != nil { c.cancelFunc() - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // in practice, it is faster than 5 seconds so this is just to avoid infinite block - defer cancel() - c.WaitForState(ctx, ClientClosed) + c.wg.Wait() + c.setState(ClientClosed) } } From f4db639ceb16366e1d34627f9be7920eb8a22b69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com> Date: Fri, 5 Jan 2024 11:39:15 +0100 Subject: [PATCH 3/5] test: add a test to ensure no logs after stop --- client_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/client_test.go b/client_test.go index ab92613b..8a506578 100644 --- a/client_test.go +++ b/client_test.go @@ -120,6 +120,18 @@ func (s *simpleReceiver) OnCallback(result string) { s.ch <- result } +type noLogAfterStopLogger struct { + StructuredLogger + shouldPanic atomic.Bool +} + +func (n *noLogAfterStopLogger) Log(keyVals ...interface{}) error { + if n.shouldPanic.Load() { + panic("oh no") + } + return n.StructuredLogger.Log(keyVals) +} + var _ = Describe("Client", func() { formatOption := TransferFormat("Text") j := 1 @@ -174,6 +186,38 @@ var _ = Describe("Client", func() { server.cancel() close(done) }) + It("should not log after stop", func(done Done) { + // Create a simple server + server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}), + testLoggerOption(), + ChanReceiveTimeout(200*time.Millisecond), + StreamBufferCapacity(5)) + Expect(err).NotTo(HaveOccurred()) + Expect(server).NotTo(BeNil()) + // Create both ends of the connection + cliConn, srvConn := newClientServerConnections() + // Start the server + go func() { _ = server.Serve(srvConn) }() + // Create the Client + clientConn, err := NewClient(context.Background(), WithConnection(cliConn), testLoggerOption(), formatOption) + Expect(err).NotTo(HaveOccurred()) + Expect(clientConn).NotTo(BeNil()) + // Replace loggers with loggers that panic after stop + info, debug := clientConn.loggers() + panicableInfo, panicableDebug := &noLogAfterStopLogger{StructuredLogger: info}, &noLogAfterStopLogger{StructuredLogger: debug} + clientConn.setLoggers(panicableInfo, panicableDebug) + // Start it + clientConn.Start() + Expect(<-clientConn.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred()) + clientConn.Stop() + panicableInfo.shouldPanic.Store(true) + panicableDebug.shouldPanic.Store(true) + // Ensure that we really don't get any logs anymore + time.Sleep(1 * time.Second) + Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed)) + server.cancel() + close(done) + }) }) Context("Invoke", func() { It("should invoke a server method and return the result", func(done Done) { From 9f6cd07c9b1c12de9d28d5bd4da160137a76192c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com> Date: Fri, 5 Jan 2024 11:53:35 +0100 Subject: [PATCH 4/5] refactor: move wg to partybase --- client.go | 1 - client_test.go | 2 +- loop.go | 3 +++ party.go | 20 ++++++++++++++------ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/client.go b/client.go index f7b49d1e..ac55e348 100644 --- a/client.go +++ b/client.go @@ -163,7 +163,6 @@ type client struct { lastID int64 backoffFactory func() backoff.BackOff cancelFunc context.CancelFunc - wg sync.WaitGroup } func (c *client) Start() { diff --git a/client_test.go b/client_test.go index 8a506578..45891132 100644 --- a/client_test.go +++ b/client_test.go @@ -213,7 +213,7 @@ var _ = Describe("Client", func() { panicableInfo.shouldPanic.Store(true) panicableDebug.shouldPanic.Store(true) // Ensure that we really don't get any logs anymore - time.Sleep(1 * time.Second) + time.Sleep(500 * time.Millisecond) Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed)) server.cancel() close(done) diff --git a/loop.go b/loop.go index bfb128b8..02b16b5b 100644 --- a/loop.go +++ b/loop.go @@ -50,7 +50,10 @@ func (l *loop) Run(connected chan struct{}) (err error) { close(connected) // Process messages ch := make(chan receiveResult, 1) + wg := l.party.waitGroup() + wg.Add(1) go func() { + defer wg.Done() recv := l.hubConn.Receive() loop: for { diff --git a/party.go b/party.go index b2f43233..174b2c7b 100644 --- a/party.go +++ b/party.go @@ -2,6 +2,7 @@ package signalr import ( "context" + "sync" "time" "github.com/go-kit/log" @@ -29,7 +30,7 @@ type Party interface { insecureSkipVerify() bool setInsecureSkipVerify(skip bool) - originPatterns() [] string + originPatterns() []string setOriginPatterns(orgs []string) chanReceiveTimeout() time.Duration @@ -50,6 +51,8 @@ type Party interface { maximumReceiveMessageSize() uint setMaximumReceiveMessageSize(size uint) + + waitGroup() *sync.WaitGroup } func newPartyBase(parentContext context.Context, info log.Logger, dbg log.Logger) partyBase { @@ -81,10 +84,11 @@ type partyBase struct { _streamBufferCapacity uint _maximumReceiveMessageSize uint _enableDetailedErrors bool - _insecureSkipVerify bool - _originPatterns []string + _insecureSkipVerify bool + _originPatterns []string info StructuredLogger dbg StructuredLogger + wg sync.WaitGroup } func (p *partyBase) context() context.Context { @@ -120,16 +124,16 @@ func (p *partyBase) setKeepAliveInterval(interval time.Duration) { } func (p *partyBase) insecureSkipVerify() bool { - return p._insecureSkipVerify + return p._insecureSkipVerify } func (p *partyBase) setInsecureSkipVerify(skip bool) { p._insecureSkipVerify = skip } func (p *partyBase) originPatterns() []string { - return p._originPatterns + return p._originPatterns } -func (p *partyBase) setOriginPatterns(origins []string) { +func (p *partyBase) setOriginPatterns(origins []string) { p._originPatterns = origins } @@ -173,3 +177,7 @@ func (p *partyBase) setLoggers(info StructuredLogger, dbg StructuredLogger) { func (p *partyBase) loggers() (info StructuredLogger, debug StructuredLogger) { return p.info, p.dbg } + +func (p *partyBase) waitGroup() *sync.WaitGroup { + return &p.wg +} From 800a6c092a0ca93be8f65db6e1274ab605baeb9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com> Date: Thu, 18 Jan 2024 23:00:29 +0100 Subject: [PATCH 5/5] refactor: use waitGroup() instead of wg in client --- client.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/client.go b/client.go index ac55e348..edcfa00c 100644 --- a/client.go +++ b/client.go @@ -168,9 +168,9 @@ type client struct { func (c *client) Start() { c.setState(ClientConnecting) boff := c.backoffFactory() - c.wg.Add(1) + c.partyBase.waitGroup().Add(1) go func() { - defer c.wg.Done() + defer c.partyBase.waitGroup().Done() for { c.setErr(nil) // Listen for state change to ClientConnected and signal backoff Reset then. @@ -223,7 +223,7 @@ func (c *client) Start() { func (c *client) Stop() { if c.cancelFunc != nil { c.cancelFunc() - c.wg.Wait() + c.partyBase.waitGroup().Wait() c.setState(ClientClosed) } }