From c680bac85066e347c48c430c11cdd1c42ee13da5 Mon Sep 17 00:00:00 2001 From: hiohiohio Date: Tue, 26 Mar 2024 16:51:53 +0900 Subject: [PATCH] fix: Remove sub changes request when timeout happens (#274) * fix: remove sub changes request when timeout happens * fix: add unit test case for subscribe twice during disconnected * fix: remove sub changes request when timeout happens * fix: remove unnecessary init value * fix: remove commented code --- marketdata/stream/client_test.go | 154 ++++++++++++++++++++++++++++++ marketdata/stream/subscription.go | 7 ++ 2 files changed, 161 insertions(+) diff --git a/marketdata/stream/client_test.go b/marketdata/stream/client_test.go index 3741bcc..d76d1ab 100644 --- a/marketdata/stream/client_test.go +++ b/marketdata/stream/client_test.go @@ -3,6 +3,7 @@ package stream import ( "context" "errors" + "fmt" "net/url" "testing" "time" @@ -600,6 +601,159 @@ func TestSubscriptionAcrossConnectionIssues(t *testing.T) { require.ElementsMatch(t, []string{"PACA"}, c.sub.trades) } +func TestSubscriptionTwiceAcrossConnectionIssues(t *testing.T) { + mockTimeAfterCh := make(chan time.Time) + timeAfter = func(d time.Duration) <-chan time.Time { + return mockTimeAfterCh + } + defer func() { + timeAfter = time.After + }() + + conn1 := newMockConn() + writeInitialFlowMessagesToConn(t, conn1, subscriptions{}) + + connected := make(chan struct{}) + connectCallback := func() { + t.Log("connected") + connected <- struct{}{} + } + + disconnected := make(chan struct{}) + disconnectCallback := func() { + t.Log("disconnected") + disconnected <- struct{}{} + } + + key := "testkey" + secret := "testsecret" + c := NewStocksClient(marketdata.IEX, + WithCredentials(key, secret), + withConnCreator(func(ctx context.Context, u url.URL) (conn, error) { + return conn1, nil + }), + WithReconnectSettings(0, 150*time.Millisecond), + WithConnectCallback(connectCallback), + WithDisconnectCallback(disconnectCallback), + ) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // connect + err := c.Connect(ctx) + require.NoError(t, err) + // wait connect callback + <-connected + checkInitialMessagesSentByClient(t, conn1, key, secret, subscriptions{}) + + // subscribing to something + trades1 := []string{"AL", "PACA"} + subRes := make(chan error) + subFunc := func() { + subRes <- c.SubscribeToTrades(func(trade Trade) {}, "AL", "PACA") + } + go subFunc() + sub := expectWrite(t, conn1) + require.Equal(t, "subscribe", sub["action"]) + require.ElementsMatch(t, trades1, sub["trades"]) + // server accepts subscription + conn1.readCh <- serializeToMsgpack(t, []subWithT{ + { + Type: "subscription", + Trades: trades1, + }, + }) + err = <-subRes + require.NoError(t, err) + + // shutting down the first connection + c.connCreator = func(ctx context.Context, u url.URL) (conn, error) { + return nil, fmt.Errorf("connection failed") + } + conn1.close() + // wait disconnect callback + <-disconnected + + // request subscribe will be timed out during disconnection + go subFunc() + + mockTimeAfterCh <- time.Now() + err = <-subRes + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrSubscriptionChangeTimeout), "actual: %s", err) + + // after a timeout we should be able to get timed out again + go subFunc() + + mockTimeAfterCh <- time.Now() + err = <-subRes + assert.Error(t, err) + assert.True(t, errors.Is(err, ErrSubscriptionChangeTimeout), "actual: %s", err) + + // establish 2nd connection + conn2 := newMockConn() + writeInitialFlowMessagesToConn(t, conn2, subscriptions{trades: trades1}) + c.connCreator = func(ctx context.Context, u url.URL) (conn, error) { + return conn2, nil + } + // wait connect callback + <-connected + + // checking whether the client sent what we wanted it to (auth,sub1,sub2) + checkInitialMessagesSentByClient(t, conn2, key, secret, subscriptions{trades: trades1}) + + go subFunc() + sub = expectWrite(t, conn2) + require.Equal(t, "subscribe", sub["action"]) + require.ElementsMatch(t, trades1, sub["trades"]) + + // responding to the subscription request + conn2.readCh <- serializeToMsgpack(t, []subWithT{ + { + Type: "subscription", + Trades: trades1, + Quotes: []string{}, + Bars: []string{}, + }, + }) + require.NoError(t, <-subRes) + require.ElementsMatch(t, trades1, c.sub.trades) + + // the connection is shut down and the new one isn't established for a while + conn3 := newMockConn() + defer conn3.close() + c.connCreator = func(ctx context.Context, u url.URL) (conn, error) { + time.Sleep(100 * time.Millisecond) + writeInitialFlowMessagesToConn(t, conn3, subscriptions{trades: trades1}) + return conn3, nil + } + conn2.close() + + // call an unsubscribe with the connection being down + unsubRes := make(chan error) + go func() { unsubRes <- c.UnsubscribeFromTrades("AL") }() + + // connection starts up, proper messages (auth,sub,unsub) + checkInitialMessagesSentByClient(t, conn3, key, secret, subscriptions{trades: trades1}) + unsub := expectWrite(t, conn3) + require.Equal(t, "unsubscribe", unsub["action"]) + require.ElementsMatch(t, []string{"AL"}, unsub["trades"]) + + // responding to the unsub request + conn3.readCh <- serializeToMsgpack(t, []subWithT{ + { + Type: "subscription", + Trades: []string{"PACA"}, + Quotes: []string{}, + Bars: []string{}, + }, + }) + + // make sure the sub has returned by now (client changed) + require.NoError(t, <-unsubRes) + require.ElementsMatch(t, []string{"PACA"}, c.sub.trades) +} + func TestSubscribeFailsDueToError(t *testing.T) { connection := newMockConn() defer connection.close() diff --git a/marketdata/stream/subscription.go b/marketdata/stream/subscription.go index 5b42634..57b019a 100644 --- a/marketdata/stream/subscription.go +++ b/marketdata/stream/subscription.go @@ -262,6 +262,13 @@ func (c *client) handleSubChange(subscribe bool, changes subscriptions) error { c.pendingSubChangeMutex.Lock() defer c.pendingSubChangeMutex.Unlock() c.pendingSubChange = nil + // Drain the c.subChanges channel to avoid waiting size 1 channel when connection is lost. + // Please consider using connect/disconnect callbacks to avoid requesting sub change during disconnection. + select { + case <-c.subChanges: + c.logger.Warnf("datav2stream: removed sub changes request due to timeout") + default: + } } return ErrSubscriptionChangeTimeout