From bb0f08b86b0e85ebaf7fc170cda39cbdd2f69727 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com>
Date: Wed, 13 Dec 2023 12:02:30 +0100
Subject: [PATCH 1/5] fix: properly stop client on Stop

---
 client.go      |  4 +++-
 client_test.go | 26 ++++++++++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/client.go b/client.go
index 4309993c..dd46cb5d 100644
--- a/client.go
+++ b/client.go
@@ -221,8 +221,10 @@ func (c *client) Start() {
 func (c *client) Stop() {
 	if c.cancelFunc != nil {
 		c.cancelFunc()
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // in practice, it is faster than 5 seconds so this is just to avoid infinite block
+		defer cancel()
+		c.WaitForState(ctx, ClientClosed)
 	}
-	c.setState(ClientClosed)
 }
 
 func (c *client) run() error {
diff --git a/client_test.go b/client_test.go
index e709b991..ab92613b 100644
--- a/client_test.go
+++ b/client_test.go
@@ -149,6 +149,32 @@ var _ = Describe("Client", func() {
 			close(done)
 		}, 1.0)
 	})
+	Context("Stop", func() {
+		It("should stop the client properly", func(done Done) {
+			// Create a simple server
+			server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}),
+				testLoggerOption(),
+				ChanReceiveTimeout(200*time.Millisecond),
+				StreamBufferCapacity(5))
+			Expect(err).NotTo(HaveOccurred())
+			Expect(server).NotTo(BeNil())
+			// Create both ends of the connection
+			cliConn, srvConn := newClientServerConnections()
+			// Start the server
+			go func() { _ = server.Serve(srvConn) }()
+			// Create the Client
+			clientConn, err := NewClient(context.Background(), WithConnection(cliConn), testLoggerOption(), formatOption)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(clientConn).NotTo(BeNil())
+			// Start it
+			clientConn.Start()
+			Expect(<-clientConn.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred())
+			clientConn.Stop()
+			Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed))
+			server.cancel()
+			close(done)
+		})
+	})
 	Context("Invoke", func() {
 		It("should invoke a server method and return the result", func(done Done) {
 			_, client, _, cancelClient := getTestBed(&simpleReceiver{}, formatOption)

From d02f2e6b259072386a967e02492e2587a0cd82eb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com>
Date: Fri, 5 Jan 2024 11:23:01 +0100
Subject: [PATCH 2/5] fix: add waitGroup to start go func

---
 client.go | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/client.go b/client.go
index dd46cb5d..f7b49d1e 100644
--- a/client.go
+++ b/client.go
@@ -163,12 +163,15 @@ type client struct {
 	lastID            int64
 	backoffFactory    func() backoff.BackOff
 	cancelFunc        context.CancelFunc
+	wg                sync.WaitGroup
 }
 
 func (c *client) Start() {
 	c.setState(ClientConnecting)
 	boff := c.backoffFactory()
+	c.wg.Add(1)
 	go func() {
+		defer c.wg.Done()
 		for {
 			c.setErr(nil)
 			// Listen for state change to ClientConnected and signal backoff Reset then.
@@ -221,9 +224,8 @@ func (c *client) Start() {
 func (c *client) Stop() {
 	if c.cancelFunc != nil {
 		c.cancelFunc()
-		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) // in practice, it is faster than 5 seconds so this is just to avoid infinite block
-		defer cancel()
-		c.WaitForState(ctx, ClientClosed)
+		c.wg.Wait()
+		c.setState(ClientClosed)
 	}
 }
 

From f4db639ceb16366e1d34627f9be7920eb8a22b69 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com>
Date: Fri, 5 Jan 2024 11:39:15 +0100
Subject: [PATCH 3/5] test: add a test to ensure no logs after stop

---
 client_test.go | 44 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/client_test.go b/client_test.go
index ab92613b..8a506578 100644
--- a/client_test.go
+++ b/client_test.go
@@ -120,6 +120,18 @@ func (s *simpleReceiver) OnCallback(result string) {
 	s.ch <- result
 }
 
+type noLogAfterStopLogger struct {
+	StructuredLogger
+	shouldPanic atomic.Bool
+}
+
+func (n *noLogAfterStopLogger) Log(keyVals ...interface{}) error {
+	if n.shouldPanic.Load() {
+		panic("oh no")
+	}
+	return n.StructuredLogger.Log(keyVals)
+}
+
 var _ = Describe("Client", func() {
 	formatOption := TransferFormat("Text")
 	j := 1
@@ -174,6 +186,38 @@ var _ = Describe("Client", func() {
 			server.cancel()
 			close(done)
 		})
+		It("should not log after stop", func(done Done) {
+			// Create a simple server
+			server, err := NewServer(context.TODO(), SimpleHubFactory(&simpleHub{}),
+				testLoggerOption(),
+				ChanReceiveTimeout(200*time.Millisecond),
+				StreamBufferCapacity(5))
+			Expect(err).NotTo(HaveOccurred())
+			Expect(server).NotTo(BeNil())
+			// Create both ends of the connection
+			cliConn, srvConn := newClientServerConnections()
+			// Start the server
+			go func() { _ = server.Serve(srvConn) }()
+			// Create the Client
+			clientConn, err := NewClient(context.Background(), WithConnection(cliConn), testLoggerOption(), formatOption)
+			Expect(err).NotTo(HaveOccurred())
+			Expect(clientConn).NotTo(BeNil())
+			// Replace loggers with loggers that panic after stop
+			info, debug := clientConn.loggers()
+			panicableInfo, panicableDebug := &noLogAfterStopLogger{StructuredLogger: info}, &noLogAfterStopLogger{StructuredLogger: debug}
+			clientConn.setLoggers(panicableInfo, panicableDebug)
+			// Start it
+			clientConn.Start()
+			Expect(<-clientConn.WaitForState(context.Background(), ClientConnected)).NotTo(HaveOccurred())
+			clientConn.Stop()
+			panicableInfo.shouldPanic.Store(true)
+			panicableDebug.shouldPanic.Store(true)
+			// Ensure that we really don't get any logs anymore
+			time.Sleep(1 * time.Second)
+			Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed))
+			server.cancel()
+			close(done)
+		})
 	})
 	Context("Invoke", func() {
 		It("should invoke a server method and return the result", func(done Done) {

From 9f6cd07c9b1c12de9d28d5bd4da160137a76192c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com>
Date: Fri, 5 Jan 2024 11:53:35 +0100
Subject: [PATCH 4/5] refactor: move wg to partybase

---
 client.go      |  1 -
 client_test.go |  2 +-
 loop.go        |  3 +++
 party.go       | 20 ++++++++++++++------
 4 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/client.go b/client.go
index f7b49d1e..ac55e348 100644
--- a/client.go
+++ b/client.go
@@ -163,7 +163,6 @@ type client struct {
 	lastID            int64
 	backoffFactory    func() backoff.BackOff
 	cancelFunc        context.CancelFunc
-	wg                sync.WaitGroup
 }
 
 func (c *client) Start() {
diff --git a/client_test.go b/client_test.go
index 8a506578..45891132 100644
--- a/client_test.go
+++ b/client_test.go
@@ -213,7 +213,7 @@ var _ = Describe("Client", func() {
 			panicableInfo.shouldPanic.Store(true)
 			panicableDebug.shouldPanic.Store(true)
 			// Ensure that we really don't get any logs anymore
-			time.Sleep(1 * time.Second)
+			time.Sleep(500 * time.Millisecond)
 			Expect(clientConn.State()).To(BeEquivalentTo(ClientClosed))
 			server.cancel()
 			close(done)
diff --git a/loop.go b/loop.go
index bfb128b8..02b16b5b 100644
--- a/loop.go
+++ b/loop.go
@@ -50,7 +50,10 @@ func (l *loop) Run(connected chan struct{}) (err error) {
 	close(connected)
 	// Process messages
 	ch := make(chan receiveResult, 1)
+	wg := l.party.waitGroup()
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
 		recv := l.hubConn.Receive()
 	loop:
 		for {
diff --git a/party.go b/party.go
index b2f43233..174b2c7b 100644
--- a/party.go
+++ b/party.go
@@ -2,6 +2,7 @@ package signalr
 
 import (
 	"context"
+	"sync"
 	"time"
 
 	"github.com/go-kit/log"
@@ -29,7 +30,7 @@ type Party interface {
 	insecureSkipVerify() bool
 	setInsecureSkipVerify(skip bool)
 
-	originPatterns()   [] string
+	originPatterns() []string
 	setOriginPatterns(orgs []string)
 
 	chanReceiveTimeout() time.Duration
@@ -50,6 +51,8 @@ type Party interface {
 
 	maximumReceiveMessageSize() uint
 	setMaximumReceiveMessageSize(size uint)
+
+	waitGroup() *sync.WaitGroup
 }
 
 func newPartyBase(parentContext context.Context, info log.Logger, dbg log.Logger) partyBase {
@@ -81,10 +84,11 @@ type partyBase struct {
 	_streamBufferCapacity      uint
 	_maximumReceiveMessageSize uint
 	_enableDetailedErrors      bool
-	_insecureSkipVerify		   bool
-	_originPatterns             []string
+	_insecureSkipVerify        bool
+	_originPatterns            []string
 	info                       StructuredLogger
 	dbg                        StructuredLogger
+	wg                         sync.WaitGroup
 }
 
 func (p *partyBase) context() context.Context {
@@ -120,16 +124,16 @@ func (p *partyBase) setKeepAliveInterval(interval time.Duration) {
 }
 
 func (p *partyBase) insecureSkipVerify() bool {
-	return  p._insecureSkipVerify
+	return p._insecureSkipVerify
 }
 func (p *partyBase) setInsecureSkipVerify(skip bool) {
 	p._insecureSkipVerify = skip
 }
 
 func (p *partyBase) originPatterns() []string {
-	return  p._originPatterns
+	return p._originPatterns
 }
-func (p *partyBase) setOriginPatterns(origins  []string) {
+func (p *partyBase) setOriginPatterns(origins []string) {
 	p._originPatterns = origins
 }
 
@@ -173,3 +177,7 @@ func (p *partyBase) setLoggers(info StructuredLogger, dbg StructuredLogger) {
 func (p *partyBase) loggers() (info StructuredLogger, debug StructuredLogger) {
 	return p.info, p.dbg
 }
+
+func (p *partyBase) waitGroup() *sync.WaitGroup {
+	return &p.wg
+}

From 800a6c092a0ca93be8f65db6e1274ab605baeb9b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Sebastian=20R=C3=BChl?= <sebastian@mapped.com>
Date: Thu, 18 Jan 2024 23:00:29 +0100
Subject: [PATCH 5/5] refactor: use waitGroup() instead of wg in client

---
 client.go | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/client.go b/client.go
index ac55e348..edcfa00c 100644
--- a/client.go
+++ b/client.go
@@ -168,9 +168,9 @@ type client struct {
 func (c *client) Start() {
 	c.setState(ClientConnecting)
 	boff := c.backoffFactory()
-	c.wg.Add(1)
+	c.partyBase.waitGroup().Add(1)
 	go func() {
-		defer c.wg.Done()
+		defer c.partyBase.waitGroup().Done()
 		for {
 			c.setErr(nil)
 			// Listen for state change to ClientConnected and signal backoff Reset then.
@@ -223,7 +223,7 @@ func (c *client) Start() {
 func (c *client) Stop() {
 	if c.cancelFunc != nil {
 		c.cancelFunc()
-		c.wg.Wait()
+		c.partyBase.waitGroup().Wait()
 		c.setState(ClientClosed)
 	}
 }