From f17ebb0b61b41bd563abf3bc88f5bb92d89fcef6 Mon Sep 17 00:00:00 2001 From: shranet Date: Sat, 15 Aug 2020 00:59:37 +0500 Subject: [PATCH] Fix #22 #23 #24 #25 --- .gitignore | 3 ++- tcplistener.go | 36 +++++++++++++++++++++--------------- test/client/test_client.go | 8 ++++++++ test/server/test_server.go | 6 +++++- 4 files changed, 36 insertions(+), 17 deletions(-) diff --git a/.gitignore b/.gitignore index 7a9917a..9f8c8bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ log* test/test test/server/server -test/client/client \ No newline at end of file +test/client/client +/.idea \ No newline at end of file diff --git a/tcplistener.go b/tcplistener.go index 9aed9e0..be4b634 100644 --- a/tcplistener.go +++ b/tcplistener.go @@ -11,7 +11,7 @@ import ( // size header, meaning you can directly serialize the raw slice. You would then perform your // custom logic for interpretting the message, before returning. You can optionally // return an error, which in turn will be logged if EnableLogging is set to true. -type ListenCallback func([]byte) error +type ListenCallback func(*TCPConn, []byte) error // TCPListener represents the abstraction over a raw TCP socket for reading streaming // protocolbuffer data without having to write a ton of boilerplate @@ -76,12 +76,6 @@ func (t *TCPListener) blockListen() error { for { // Wait for someone to connect c, err := t.socket.AcceptTCP() - conn, err := newTCPConn(t.connConfig) - if err != nil { - return err - } - // Don't dial out, wrap the underlying conn in one of ours - conn.socket = c if err != nil { if t.enableLogging { log.Printf("Error attempting to accept connection: %s", err) @@ -95,10 +89,14 @@ func (t *TCPListener) blockListen() error { default: // Nothing, continue to the top of the loop } - } else { - // Hand this off and immediately listen for more - go t.readLoop(conn) + + continue } + + conn, _ := newTCPConn(t.connConfig) + // Don't dial out, wrap the underlying conn in one of ours + conn.socket = c + go t.readLoop(conn) } } @@ -152,13 +150,21 @@ func (t *TCPListener) readLoop(conn *TCPConn) { // dataBuffer will hold the message from each read dataBuffer := make([]byte, conn.maxMessageSize) + closed := make(chan struct{}) + + defer close(closed) + // Start an asyncrhonous call that will wait on the shutdown channel, and then close // the connection. This will let us respond to the shutdown but also not incur // a cost for checking the channel on each run of the loop - go func(c *TCPConn, s <-chan struct{}) { - <-s - c.Close() - }(conn, t.shutdownChannel) + go func(c *TCPConn, s <-chan struct{}, closed <-chan struct{}) { + select { + case <-closed: + return + case <-s: + c.Close() + } + }(conn, t.shutdownChannel, closed) // Begin the read loop // If there is any error, close the connection officially and break out of the listen-loop. @@ -176,7 +182,7 @@ func (t *TCPListener) readLoop(conn *TCPConn) { } // We take action on the actual message data - but only up to the amount of bytes read, // since we re-use the cache - if err = t.callback(dataBuffer[:msgLen]); err != nil && t.enableLogging { + if err = t.callback(conn, dataBuffer[:msgLen]); err != nil && t.enableLogging { log.Printf("Error in Callback: %s", err.Error()) // TODO if it's a protobuffs error, it means we likely had an issue and can't // deserialize data? Should we kill the connection and have the client start over? diff --git a/test/client/test_client.go b/test/client/test_client.go index 1d2d8fd..d80e1b7 100644 --- a/test/client/test_client.go +++ b/test/client/test_client.go @@ -18,6 +18,7 @@ func main() { MaxMessageSize: 2048, Address: buffstreams.FormatAddress("127.0.0.1", strconv.Itoa(5031)), } + ansBytes := make([]byte, 1024) name := "Stabby" date := time.Now().UnixNano() data := "This is an intenntionally long and rambling sentence to pad out the size of the message." @@ -38,6 +39,7 @@ func main() { if err != nil { log.Print("There was an error") log.Print(err) + break } count = count + 1 if lastTime.Second() != currentTime.Second() { @@ -46,5 +48,11 @@ func main() { count = 0 } currentTime = time.Now() + if n, err := btw.Read(ansBytes); err != nil { + log.Println("Read error:", err) + break + } else { + log.Println("From server:", string(ansBytes[0:n])) + } } } diff --git a/test/server/test_server.go b/test/server/test_server.go index bb19091..2c689fd 100644 --- a/test/server/test_server.go +++ b/test/server/test_server.go @@ -12,11 +12,15 @@ import ( // TestCallback is a simple server for test purposes. It has a single callback, // which is to unmarshall some data and log it. -func (t *testController) TestCallback(bts []byte) error { +func (t *testController) TestCallback(conn *buffstreams.TCPConn, bts []byte) error { msg := &message.Note{} err := proto.Unmarshal(bts, msg) if t.enableLogging { log.Print(msg.GetComment()) + _, err := conn.Write([]byte(time.Now().String())) + if err != nil { + log.Println("Write error:", err) + } } return err }