diff --git a/docs/ADD_NEW_EXCHANGE.md b/docs/ADD_NEW_EXCHANGE.md index 41e5712d192..166f30dd897 100644 --- a/docs/ADD_NEW_EXCHANGE.md +++ b/docs/ADD_NEW_EXCHANGE.md @@ -1176,7 +1176,7 @@ Please test all `pair` commands to disable and enable different assets types to - `get` to ensure correct enabled and disabled pairs for a supported asset type. - `disableasset` to ensure disabling of entire asset class and associated unsubscriptions. - `enableasset` to ensure correct enabling of entire asset class and associated subscriptions. -- `disable` to ensure correct disabling of pair(s) and and associated unsubscriptions. +- `disable` to ensure correct disabling of pair(s) and associated unsubscriptions. - `enable` to ensure correct enabling of pair(s) and associated subscriptions. - `enableall` to ensure correct enabling of all pairs for an asset type and associated subscriptions. - `disableall` to ensure correct disabling of all pairs for an asset type and associated unsubscriptions. diff --git a/engine/engine.go b/engine/engine.go index 8dda10706ba..27e11219a25 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -791,17 +791,10 @@ func (bot *Engine) LoadExchange(name string) error { localWG.Wait() if !bot.Settings.EnableExchangeHTTPRateLimiter { - gctlog.Warnf(gctlog.ExchangeSys, - "Loaded exchange %s rate limiting has been turned off.\n", - exch.GetName(), - ) + gctlog.Warnf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting has been turned off.\n", exch.GetName()) err = exch.DisableRateLimiter() if err != nil { - gctlog.Errorf(gctlog.ExchangeSys, - "Loaded exchange %s rate limiting cannot be turned off: %s.\n", - exch.GetName(), - err, - ) + gctlog.Errorf(gctlog.ExchangeSys, "Loaded exchange %s rate limiting cannot be turned off: %s.\n", exch.GetName(), err) } } @@ -820,29 +813,18 @@ func (bot *Engine) LoadExchange(name string) error { return err } - base := exch.GetBase() - if base.API.AuthenticatedSupport || - base.API.AuthenticatedWebsocketSupport { - assetTypes := base.GetAssetTypes(false) - var useAsset asset.Item - for a := range assetTypes { - err = base.CurrencyPairs.IsAssetEnabled(assetTypes[a]) - if err != nil { - continue - } - useAsset = assetTypes[a] - break - } - err = exch.ValidateAPICredentials(context.TODO(), useAsset) + b := exch.GetBase() + if b.API.AuthenticatedSupport || b.API.AuthenticatedWebsocketSupport { + err = exch.ValidateAPICredentials(context.TODO(), asset.Spot) if err != nil { - gctlog.Warnf(gctlog.ExchangeSys, - "%s: Cannot validate credentials, authenticated support has been disabled, Error: %s\n", - base.Name, - err) - base.API.AuthenticatedSupport = false - base.API.AuthenticatedWebsocketSupport = false + gctlog.Warnf(gctlog.ExchangeSys, "%s: Cannot validate credentials, authenticated support has been disabled, Error: %s", b.Name, err) + b.API.AuthenticatedSupport = false + b.API.AuthenticatedWebsocketSupport = false exchCfg.API.AuthenticatedSupport = false exchCfg.API.AuthenticatedWebsocketSupport = false + if b.Websocket != nil { + b.Websocket.SetCanUseAuthenticatedEndpoints(false) + } } } @@ -854,10 +836,7 @@ func (bot *Engine) dryRunParamInteraction(param string) { return } - gctlog.Warnf(gctlog.Global, - "Command line argument '-%s' induces dry run mode."+ - " Set -dryrun=false if you wish to override this.", - param) + gctlog.Warnf(gctlog.Global, "Command line argument '-%s' induces dry run mode. Set -dryrun=false if you wish to override this.", param) if !bot.Settings.EnableDryRun { bot.Settings.EnableDryRun = true diff --git a/exchanges/bybit/bybit.go b/exchanges/bybit/bybit.go index 832af4ddba0..88515427b2b 100644 --- a/exchanges/bybit/bybit.go +++ b/exchanges/bybit/bybit.go @@ -27,10 +27,10 @@ import ( // Bybit is the overarching type across this package type Bybit struct { exchange.Base - // AccountType holds information about whether the account to which the api key belongs is a unified margin account or not. // 0: unified, and 1: for normal account AccountType uint8 + Counter common.Counter } const ( diff --git a/exchanges/bybit/bybit_inverse_websocket.go b/exchanges/bybit/bybit_inverse_websocket.go index 0a0346bde58..79d47b20c96 100644 --- a/exchanges/bybit/bybit_inverse_websocket.go +++ b/exchanges/bybit/bybit_inverse_websocket.go @@ -2,81 +2,43 @@ package bybit import ( "context" - "net/http" + "errors" - "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) -// WsInverseConnect connects to inverse websocket feed -func (by *Bybit) WsInverseConnect() error { - if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.CoinMarginedFutures) { - return stream.ErrWebsocketNotEnabled - } - by.Websocket.Conn.SetURL(inversePublic) - var dialer websocket.Dialer - err := by.Websocket.Conn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, - Message: []byte(`{"op": "ping"}`), - Delay: bybitWebsocketTimer, - }) - - by.Websocket.Wg.Add(1) - go by.wsReadData(asset.CoinMarginedFutures, by.Websocket.Conn) - return nil -} - // GenerateInverseDefaultSubscriptions generates default subscription func (by *Bybit) GenerateInverseDefaultSubscriptions() (subscription.List, error) { - var subscriptions subscription.List - var channels = []string{chanOrderbook, chanPublicTrade, chanPublicTicker} pairs, err := by.GetEnabledPairs(asset.CoinMarginedFutures) if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil + } return nil, err } + + var subscriptions subscription.List for z := range pairs { - for x := range channels { - subscriptions = append(subscriptions, - &subscription.Subscription{ - Channel: channels[x], - Pairs: currency.Pairs{pairs[z]}, - Asset: asset.CoinMarginedFutures, - }) + for _, channel := range []string{chanOrderbook, chanPublicTrade, chanPublicTicker} { + subscriptions = append(subscriptions, &subscription.Subscription{ + Channel: channel, + Pairs: currency.Pairs{pairs[z]}, + Asset: asset.CoinMarginedFutures, + }) } } return subscriptions, nil } // InverseSubscribe sends a subscription message to linear public channels. -func (by *Bybit) InverseSubscribe(channelSubscriptions subscription.List) error { - return by.handleInversePayloadSubscription("subscribe", channelSubscriptions) +func (by *Bybit) InverseSubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.CoinMarginedFutures, "subscribe", channelSubscriptions) } // InverseUnsubscribe sends an unsubscription messages through linear public channels. -func (by *Bybit) InverseUnsubscribe(channelSubscriptions subscription.List) error { - return by.handleInversePayloadSubscription("unsubscribe", channelSubscriptions) -} - -func (by *Bybit) handleInversePayloadSubscription(operation string, channelSubscriptions subscription.List) error { - payloads, err := by.handleSubscriptions(operation, channelSubscriptions) - if err != nil { - return err - } - for a := range payloads { - // The options connection does not send the subscription request id back with the subscription notification payload - // therefore the code doesn't wait for the response to check whether the subscription is successful or not. - err = by.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, payloads[a]) - if err != nil { - return err - } - } - return nil +func (by *Bybit) InverseUnsubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.CoinMarginedFutures, "unsubscribe", channelSubscriptions) } diff --git a/exchanges/bybit/bybit_linear_websocket.go b/exchanges/bybit/bybit_linear_websocket.go index 0edb1a322b3..91e83127812 100644 --- a/exchanges/bybit/bybit_linear_websocket.go +++ b/exchanges/bybit/bybit_linear_websocket.go @@ -2,99 +2,43 @@ package bybit import ( "context" - "net/http" + "errors" - "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) -// WsLinearConnect connects to linear a websocket feed -func (by *Bybit) WsLinearConnect() error { - if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.LinearContract) { - return stream.ErrWebsocketNotEnabled - } - by.Websocket.Conn.SetURL(linearPublic) - var dialer websocket.Dialer - err := by.Websocket.Conn.Dial(&dialer, http.Header{}) +// GenerateLinearDefaultSubscriptions generates default subscription +func (by *Bybit) GenerateLinearDefaultSubscriptions(a asset.Item) (subscription.List, error) { + pairs, err := by.GetEnabledPairs(a) if err != nil { - return err - } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, - Message: []byte(`{"op": "ping"}`), - Delay: bybitWebsocketTimer, - }) - - by.Websocket.Wg.Add(1) - go by.wsReadData(asset.LinearContract, by.Websocket.Conn) - if by.IsWebsocketAuthenticationSupported() { - err = by.WsAuth(context.TODO()) - if err != nil { - by.Websocket.DataHandler <- err - by.Websocket.SetCanUseAuthenticatedEndpoints(false) + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil } + return nil, err } - return nil -} -// GenerateLinearDefaultSubscriptions generates default subscription -func (by *Bybit) GenerateLinearDefaultSubscriptions() (subscription.List, error) { var subscriptions subscription.List - var channels = []string{chanOrderbook, chanPublicTrade, chanPublicTicker} - pairs, err := by.GetEnabledPairs(asset.USDTMarginedFutures) - if err != nil { - return nil, err - } - linearPairMap := map[asset.Item]currency.Pairs{ - asset.USDTMarginedFutures: pairs, - } - usdcPairs, err := by.GetEnabledPairs(asset.USDCMarginedFutures) - if err != nil { - return nil, err - } - linearPairMap[asset.USDCMarginedFutures] = usdcPairs - pairs = append(pairs, usdcPairs...) - for a := range linearPairMap { - for p := range linearPairMap[a] { - for x := range channels { - subscriptions = append(subscriptions, - &subscription.Subscription{ - Channel: channels[x], - Pairs: currency.Pairs{pairs[p]}, - Asset: a, - }) - } + for _, pair := range pairs { + for _, channel := range []string{chanOrderbook, chanPublicTrade, chanPublicTicker} { + subscriptions = append(subscriptions, &subscription.Subscription{ + Channel: channel, + Pairs: currency.Pairs{pair}, + Asset: a, + }) } } return subscriptions, nil } // LinearSubscribe sends a subscription message to linear public channels. -func (by *Bybit) LinearSubscribe(channelSubscriptions subscription.List) error { - return by.handleLinearPayloadSubscription("subscribe", channelSubscriptions) +func (by *Bybit) LinearSubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.USDTMarginedFutures, "subscribe", channelSubscriptions) } // LinearUnsubscribe sends an unsubscription messages through linear public channels. -func (by *Bybit) LinearUnsubscribe(channelSubscriptions subscription.List) error { - return by.handleLinearPayloadSubscription("unsubscribe", channelSubscriptions) -} - -func (by *Bybit) handleLinearPayloadSubscription(operation string, channelSubscriptions subscription.List) error { - payloads, err := by.handleSubscriptions(operation, channelSubscriptions) - if err != nil { - return err - } - for a := range payloads { - // The options connection does not send the subscription request id back with the subscription notification payload - // therefore the code doesn't wait for the response to check whether the subscription is successful or not. - err = by.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, payloads[a]) - if err != nil { - return err - } - } - return nil +func (by *Bybit) LinearUnsubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.USDTMarginedFutures, "unsubscribe", channelSubscriptions) } diff --git a/exchanges/bybit/bybit_options_websocket.go b/exchanges/bybit/bybit_options_websocket.go index e2e5836346b..06895875061 100644 --- a/exchanges/bybit/bybit_options_websocket.go +++ b/exchanges/bybit/bybit_options_websocket.go @@ -2,88 +2,43 @@ package bybit import ( "context" - "encoding/json" - "net/http" - "strconv" + "errors" - "github.com/gorilla/websocket" "github.com/thrasher-corp/gocryptotrader/currency" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" - "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" ) -// WsOptionsConnect connects to options a websocket feed -func (by *Bybit) WsOptionsConnect() error { - if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Options) { - return stream.ErrWebsocketNotEnabled - } - by.Websocket.Conn.SetURL(optionPublic) - var dialer websocket.Dialer - err := by.Websocket.Conn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - pingMessage := PingMessage{Operation: "ping", RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10)} - pingData, err := json.Marshal(pingMessage) - if err != nil { - return err - } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, - Message: pingData, - Delay: bybitWebsocketTimer, - }) - - by.Websocket.Wg.Add(1) - go by.wsReadData(asset.Options, by.Websocket.Conn) - return nil -} - // GenerateOptionsDefaultSubscriptions generates default subscription func (by *Bybit) GenerateOptionsDefaultSubscriptions() (subscription.List, error) { - var subscriptions subscription.List - var channels = []string{chanOrderbook, chanPublicTrade, chanPublicTicker} pairs, err := by.GetEnabledPairs(asset.Options) if err != nil { + if errors.Is(err, asset.ErrNotEnabled) { + return nil, nil + } return nil, err } + + var subscriptions subscription.List for z := range pairs { - for x := range channels { - subscriptions = append(subscriptions, - &subscription.Subscription{ - Channel: channels[x], - Pairs: currency.Pairs{pairs[z]}, - Asset: asset.Options, - }) + for _, channel := range []string{chanOrderbook, chanPublicTrade, chanPublicTicker} { + subscriptions = append(subscriptions, &subscription.Subscription{ + Channel: channel, + Pairs: currency.Pairs{pairs[z]}, + Asset: asset.Options, + }) } } return subscriptions, nil } // OptionSubscribe sends a subscription message to options public channels. -func (by *Bybit) OptionSubscribe(channelSubscriptions subscription.List) error { - return by.handleOptionsPayloadSubscription("subscribe", channelSubscriptions) +func (by *Bybit) OptionSubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.Options, "subscribe", channelSubscriptions) } // OptionUnsubscribe sends an unsubscription messages through options public channels. -func (by *Bybit) OptionUnsubscribe(channelSubscriptions subscription.List) error { - return by.handleOptionsPayloadSubscription("unsubscribe", channelSubscriptions) -} - -func (by *Bybit) handleOptionsPayloadSubscription(operation string, channelSubscriptions subscription.List) error { - payloads, err := by.handleSubscriptions(operation, channelSubscriptions) - if err != nil { - return err - } - for a := range payloads { - // The options connection does not send the subscription request id back with the subscription notification payload - // therefore the code doesn't wait for the response to check whether the subscription is successful or not. - err = by.Websocket.Conn.SendJSONMessage(context.TODO(), request.Unset, payloads[a]) - if err != nil { - return err - } - } - return nil +func (by *Bybit) OptionUnsubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.Options, "unsubscribe", channelSubscriptions) } diff --git a/exchanges/bybit/bybit_test.go b/exchanges/bybit/bybit_test.go index 60d5d4ecef2..7ee6dc727b8 100644 --- a/exchanges/bybit/bybit_test.go +++ b/exchanges/bybit/bybit_test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "maps" + "net/http" "slices" "testing" "time" @@ -18,19 +19,20 @@ import ( "github.com/thrasher-corp/gocryptotrader/common/key" "github.com/thrasher-corp/gocryptotrader/currency" exchange "github.com/thrasher-corp/gocryptotrader/exchanges" + "github.com/thrasher-corp/gocryptotrader/exchanges/account" "github.com/thrasher-corp/gocryptotrader/exchanges/asset" "github.com/thrasher-corp/gocryptotrader/exchanges/fundingrate" "github.com/thrasher-corp/gocryptotrader/exchanges/futures" "github.com/thrasher-corp/gocryptotrader/exchanges/kline" "github.com/thrasher-corp/gocryptotrader/exchanges/margin" "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" testsubs "github.com/thrasher-corp/gocryptotrader/internal/testing/subscriptions" - testws "github.com/thrasher-corp/gocryptotrader/internal/testing/websocket" "github.com/thrasher-corp/gocryptotrader/portfolio/withdraw" "github.com/thrasher-corp/gocryptotrader/types" ) @@ -3137,45 +3139,21 @@ func TestCancelBatchOrders(t *testing.T) { } } -func TestWsConnect(t *testing.T) { - t.Parallel() - if mockTests { - t.Skip(skippingWebsocketFunctionsForMockTesting) - } - err := b.WsConnect() - if err != nil { - t.Error(err) - } -} -func TestWsLinearConnect(t *testing.T) { - t.Parallel() - if mockTests { - t.Skip(skippingWebsocketFunctionsForMockTesting) - } - err := b.WsLinearConnect() - if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { - t.Error(err) - } +type DummyConnection struct{ stream.Connection } + +func (d *DummyConnection) GenerateMessageID(bool) int64 { return 1337 } +func (d *DummyConnection) SetupPingHandler(request.EndpointLimit, stream.PingHandler) {} +func (d *DummyConnection) DialContext(context.Context, *websocket.Dialer, http.Header) error { + return nil } -func TestWsInverseConnect(t *testing.T) { - t.Parallel() - if mockTests { - t.Skip(skippingWebsocketFunctionsForMockTesting) - } - err := b.WsInverseConnect() - if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { - t.Error(err) - } +func (d *DummyConnection) SendMessageReturnResponse(context.Context, request.EndpointLimit, any, any) ([]byte, error) { + return []byte(`{"success":true,"ret_msg":"subscribe","conn_id":"5758770c-8152-4545-a84f-dae089e56499","req_id":"1","op":"subscribe"}`), nil } -func TestWsOptionsConnect(t *testing.T) { + +func TestWsConnect(t *testing.T) { t.Parallel() - if mockTests { - t.Skip(skippingWebsocketFunctionsForMockTesting) - } - err := b.WsOptionsConnect() - if err != nil && !errors.Is(err, stream.ErrWebsocketNotEnabled) { - t.Error(err) - } + err := b.WsConnect(context.Background(), &DummyConnection{}) + require.NoError(t, err) } var pushDataMap = map[string]string{ @@ -3201,7 +3179,7 @@ func TestPushData(t *testing.T) { slices.Sort(keys) for x := range keys { - err := b.wsHandleData(asset.Spot, []byte(pushDataMap[keys[x]])) + err := b.wsHandleData(context.Background(), []byte(pushDataMap[keys[x]]), asset.Spot) assert.NoError(t, err, "wsHandleData should not error") } } @@ -3216,7 +3194,7 @@ func TestWsTicker(t *testing.T) { require.NoError(t, testexch.Setup(b), "Test instance Setup must not error") testexch.FixtureToDataHandler(t, "testdata/wsTicker.json", func(r []byte) error { defer slices.Delete(assetRouting, 0, 1) - return b.wsHandleData(assetRouting[0], r) + return b.wsHandleData(context.Background(), r, assetRouting[0]) }) close(b.Websocket.DataHandler) expected := 8 @@ -3478,7 +3456,7 @@ func TestFetchTradablePairs(t *testing.T) { func TestDeltaUpdateOrderbook(t *testing.T) { t.Parallel() data := `{"topic":"orderbook.50.WEMIXUSDT","ts":1697573183768,"type":"snapshot","data":{"s":"WEMIXUSDT","b":[["0.9511","260.703"],["0.9677","0"]],"a":[],"u":3119516,"seq":14126848493},"cts":1728966699481}` - err := b.wsHandleData(asset.Spot, []byte(data)) + err := b.wsHandleData(context.Background(), []byte(data), asset.Spot) if err != nil { t.Fatal(err) } @@ -3746,38 +3724,39 @@ func TestSubscribe(t *testing.T) { require.NoError(t, err, "ExpandTemplates must not error") b.Features.Subscriptions = subscription.List{} testexch.SetupWs(t, b) - err = b.Subscribe(subs) + err = b.Subscribe(context.Background(), &DummyConnection{}, subs) require.NoError(t, err, "Subscribe must not error") } func TestAuthSubscribe(t *testing.T) { t.Parallel() + b := new(Bybit) - require.NoError(t, testexch.Setup(b), "Test instance Setup must not error") + require.NoError(t, testexch.Setup(b)) + require.NoError(t, b.authSubscribe(context.Background(), &DummyConnection{}, subscription.List{})) + + authsubs, err := b.generateAuthSubscriptions() + require.NoError(t, err) + require.Empty(t, authsubs) + b.Websocket.SetCanUseAuthenticatedEndpoints(true) - subs, err := b.Features.Subscriptions.ExpandTemplates(b) - require.NoError(t, err, "ExpandTemplates must not error") - b.Features.Subscriptions = subscription.List{} - success := true - mock := func(tb testing.TB, msg []byte, w *websocket.Conn) error { - tb.Helper() - var req SubscriptionArgument - require.NoError(tb, json.Unmarshal(msg, &req), "Unmarshal must not error") - require.Equal(tb, "subscribe", req.Operation) - msg, err = json.Marshal(SubscriptionResponse{ - Success: success, - RetMsg: "Mock Resp Error", - RequestID: req.RequestID, - Operation: req.Operation, - }) - require.NoError(tb, err, "Marshal must not error") - return w.WriteMessage(websocket.TextMessage, msg) - } - b = testexch.MockWsInstance[Bybit](t, testws.CurryWsMockUpgrader(t, mock)) - b.Websocket.AuthConn = b.Websocket.Conn - err = b.Subscribe(subs) - require.NoError(t, err, "Subscribe must not error") - success = false - err = b.Subscribe(subs) - assert.ErrorContains(t, err, "Mock Resp Error", "Subscribe should error containing the returned RetMsg") + authsubs, err = b.generateAuthSubscriptions() + require.NoError(t, err) + require.NotEmpty(t, authsubs) + + require.NoError(t, b.authSubscribe(context.Background(), &DummyConnection{}, authsubs)) + require.NoError(t, b.authUnsubscribe(context.Background(), &DummyConnection{}, authsubs)) +} + +func TestWebsocketAuthenticateConnection(t *testing.T) { + t.Parallel() + + b := new(Bybit) + require.NoError(t, testexch.Setup(b)) + b.API.AuthenticatedSupport = true + b.API.AuthenticatedWebsocketSupport = true + b.Websocket.SetCanUseAuthenticatedEndpoints(true) + ctx := account.DeployCredentialsToContext(context.Background(), &account.Credentials{Key: "dummy", Secret: "dummy"}) + err := b.WebsocketAuthenticateConnection(ctx, &DummyConnection{}) + require.NoError(t, err) } diff --git a/exchanges/bybit/bybit_types.go b/exchanges/bybit/bybit_types.go index 656a4d2ffa7..9cff81fe438 100644 --- a/exchanges/bybit/bybit_types.go +++ b/exchanges/bybit/bybit_types.go @@ -161,11 +161,23 @@ func constructOrderbook(o *orderbookResponse) (*Orderbook, error) { // TickerData represents a list of ticker detailed information. type TickerData struct { Category string `json:"category"` - List []TickerItem `json:"list"` + List []TickerREST `json:"list"` } -// TickerItem represents a ticker item detail -type TickerItem struct { +// TickerREST for REST API +type TickerREST struct { + TickerCommon + DeliveryTime types.Time `json:"deliveryTime"` +} + +// TickerWebsocket for websocket API +type TickerWebsocket struct { + TickerCommon + DeliveryTime time.Time `json:"deliveryTime"` // "2025-03-28T08:00:00Z" +} + +// TickerCommon common ticker fields +type TickerCommon struct { Symbol string `json:"symbol"` TickDirection string `json:"tickDirection"` LastPrice types.Number `json:"lastPrice"` diff --git a/exchanges/bybit/bybit_websocket.go b/exchanges/bybit/bybit_websocket.go index 9d1b84268c1..9cdc3909245 100644 --- a/exchanges/bybit/bybit_websocket.go +++ b/exchanges/bybit/bybit_websocket.go @@ -55,6 +55,8 @@ const ( // Main-net private websocketPrivate = "wss://stream.bybit.com/v5/private" + + privateConnection = "private" ) var defaultSubscriptions = subscription.List{ @@ -62,93 +64,58 @@ var defaultSubscriptions = subscription.List{ {Enabled: true, Asset: asset.Spot, Channel: subscription.OrderbookChannel, Levels: 50}, {Enabled: true, Asset: asset.Spot, Channel: subscription.AllTradesChannel}, {Enabled: true, Asset: asset.Spot, Channel: subscription.CandlesChannel, Interval: kline.OneHour}, - {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyOrdersChannel}, - {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyWalletChannel}, - {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyTradesChannel}, - {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: chanPositions}, + // {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyOrdersChannel}, + // {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyWalletChannel}, + // {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: subscription.MyTradesChannel}, + // {Enabled: true, Asset: asset.Spot, Authenticated: true, Channel: chanPositions}, } var subscriptionNames = map[string]string{ subscription.TickerChannel: chanPublicTicker, subscription.OrderbookChannel: chanOrderbook, subscription.AllTradesChannel: chanPublicTrade, - subscription.MyOrdersChannel: chanOrder, - subscription.MyTradesChannel: chanExecution, - subscription.MyWalletChannel: chanWallet, - subscription.CandlesChannel: chanKline, + // subscription.MyOrdersChannel: chanOrder, + // subscription.MyTradesChannel: chanExecution, + // subscription.MyWalletChannel: chanWallet, + subscription.CandlesChannel: chanKline, } // WsConnect connects to a websocket feed -func (by *Bybit) WsConnect() error { - if !by.Websocket.IsEnabled() || !by.IsEnabled() || !by.IsAssetWebsocketSupported(asset.Spot) { - return stream.ErrWebsocketNotEnabled - } - var dialer websocket.Dialer - err := by.Websocket.Conn.Dial(&dialer, http.Header{}) - if err != nil { +func (by *Bybit) WsConnect(ctx context.Context, conn stream.Connection) error { + if err := conn.DialContext(ctx, &websocket.Dialer{}, http.Header{}); err != nil { return err } - by.Websocket.Conn.SetupPingHandler(request.Unset, stream.PingHandler{ + conn.SetupPingHandler(request.Unset, stream.PingHandler{ MessageType: websocket.TextMessage, Message: []byte(`{"op": "ping"}`), Delay: bybitWebsocketTimer, }) - - by.Websocket.Wg.Add(1) - go by.wsReadData(asset.Spot, by.Websocket.Conn) - if by.Websocket.CanUseAuthenticatedEndpoints() { - err = by.WsAuth(context.TODO()) - if err != nil { - by.Websocket.DataHandler <- err - by.Websocket.SetCanUseAuthenticatedEndpoints(false) - } - } return nil } -// WsAuth sends an authentication message to receive auth data -func (by *Bybit) WsAuth(ctx context.Context) error { - var dialer websocket.Dialer - err := by.Websocket.AuthConn.Dial(&dialer, http.Header{}) - if err != nil { - return err - } - - by.Websocket.AuthConn.SetupPingHandler(request.Unset, stream.PingHandler{ - MessageType: websocket.TextMessage, - Message: []byte(`{"op":"ping"}`), - Delay: bybitWebsocketTimer, - }) - - by.Websocket.Wg.Add(1) - go by.wsReadData(asset.Spot, by.Websocket.AuthConn) +// WebsocketAuthenticateConnection sends an authentication message to the websocket +func (by *Bybit) WebsocketAuthenticateConnection(ctx context.Context, conn stream.Connection) error { creds, err := by.GetCredentials(ctx) if err != nil { return err } intNonce := time.Now().Add(time.Hour * 6).UnixMilli() strNonce := strconv.FormatInt(intNonce, 10) - hmac, err := crypto.GetHMAC( - crypto.HashSHA256, - []byte("GET/realtime"+strNonce), - []byte(creds.Secret), - ) + hmac, err := crypto.GetHMAC(crypto.HashSHA256, []byte("GET/realtime"+strNonce), []byte(creds.Secret)) if err != nil { return err } - sign := crypto.HexEncodeToString(hmac) req := Authenticate{ - RequestID: strconv.FormatInt(by.Websocket.AuthConn.GenerateMessageID(false), 10), + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), Operation: "auth", - Args: []interface{}{creds.Key, intNonce, sign}, + Args: []interface{}{creds.Key, intNonce, crypto.HexEncodeToString(hmac)}, } - resp, err := by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.Unset, req.RequestID, req) + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.RequestID, req) if err != nil { return err } var response SubscriptionResponse - err = json.Unmarshal(resp, &response) - if err != nil { + if err := json.Unmarshal(resp, &response); err != nil { return err } if !response.Success { @@ -158,13 +125,14 @@ func (by *Bybit) WsAuth(ctx context.Context) error { } // Subscribe sends a websocket message to receive data from the channel -func (by *Bybit) Subscribe(channelsToSubscribe subscription.List) error { - return by.handleSpotSubscription("subscribe", channelsToSubscribe) +func (by *Bybit) Subscribe(ctx context.Context, conn stream.Connection, channelsToSubscribe subscription.List) error { + return by.handleSpotSubscription(ctx, conn, "subscribe", channelsToSubscribe) } -func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) (args []SubscriptionArgument, err error) { +func (by *Bybit) handleSubscriptions(conn stream.Connection, operation string, subs subscription.List) (args []SubscriptionArgument, err error) { subs, err = subs.ExpandTemplates(by) if err != nil { + fmt.Println("expandy silly", conn.GetURL()) return } chans := []string{} @@ -179,7 +147,7 @@ func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) ( for _, b := range common.Batch(chans, 10) { args = append(args, SubscriptionArgument{ Operation: operation, - RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), Arguments: b, }) } @@ -187,7 +155,7 @@ func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) ( args = append(args, SubscriptionArgument{ auth: true, Operation: operation, - RequestID: strconv.FormatInt(by.Websocket.Conn.GenerateMessageID(false), 10), + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), Arguments: authChans, }) } @@ -195,27 +163,20 @@ func (by *Bybit) handleSubscriptions(operation string, subs subscription.List) ( } // Unsubscribe sends a websocket message to stop receiving data from the channel -func (by *Bybit) Unsubscribe(channelsToUnsubscribe subscription.List) error { - return by.handleSpotSubscription("unsubscribe", channelsToUnsubscribe) +func (by *Bybit) Unsubscribe(ctx context.Context, conn stream.Connection, channelsToUnsubscribe subscription.List) error { + return by.handleSpotSubscription(ctx, conn, "unsubscribe", channelsToUnsubscribe) } -func (by *Bybit) handleSpotSubscription(operation string, channelsToSubscribe subscription.List) error { - payloads, err := by.handleSubscriptions(operation, channelsToSubscribe) +func (by *Bybit) handleSpotSubscription(ctx context.Context, conn stream.Connection, operation string, channelsToSubscribe subscription.List) error { + payloads, err := by.handleSubscriptions(conn, operation, channelsToSubscribe) if err != nil { return err } - for a := range payloads { + for _, payload := range payloads { var response []byte - if payloads[a].auth { - response, err = by.Websocket.AuthConn.SendMessageReturnResponse(context.TODO(), request.Unset, payloads[a].RequestID, payloads[a]) - if err != nil { - return err - } - } else { - response, err = by.Websocket.Conn.SendMessageReturnResponse(context.TODO(), request.Unset, payloads[a].RequestID, payloads[a]) - if err != nil { - return err - } + response, err = conn.SendMessageReturnResponse(ctx, request.Unset, payload.RequestID, payload) + if err != nil { + return err } var resp SubscriptionResponse err = json.Unmarshal(response, &resp) @@ -243,30 +204,9 @@ func (by *Bybit) GetSubscriptionTemplate(_ *subscription.Subscription) (*templat }).Parse(subTplText) } -// wsReadData receives and passes on websocket messages for processing -func (by *Bybit) wsReadData(assetType asset.Item, ws stream.Connection) { - defer by.Websocket.Wg.Done() - for { - select { - case <-by.Websocket.ShutdownC: - return - default: - resp := ws.ReadMessage() - if resp.Raw == nil { - return - } - err := by.wsHandleData(assetType, resp.Raw) - if err != nil { - by.Websocket.DataHandler <- err - } - } - } -} - -func (by *Bybit) wsHandleData(assetType asset.Item, respRaw []byte) error { +func (by *Bybit) wsHandleData(_ context.Context, respRaw []byte, assetType asset.Item) error { var result WebsocketResponse - err := json.Unmarshal(respRaw, &result) - if err != nil { + if err := json.Unmarshal(respRaw, &result); err != nil { return err } if result.Topic == "" { @@ -307,14 +247,16 @@ func (by *Bybit) wsHandleData(assetType asset.Item, respRaw []byte) error { return by.wsProcessLeverageTokenTicker(assetType, &result) case chanLeverageTokenNav: return by.wsLeverageTokenNav(&result) + // TODO: The following cases are coming from the dedicated authenticated websocket connection, this is asset + // agnostic and will need an update in a future PR to handle asset specific data. case chanPositions: return by.wsProcessPosition(&result) case chanExecution: - return by.wsProcessExecution(asset.Spot, &result) + return by.wsProcessExecution(assetType, &result) case chanOrder: - return by.wsProcessOrder(asset.Spot, &result) + return by.wsProcessOrder(assetType, &result) case chanWallet: - return by.wsProcessWalletPushData(asset.Spot, respRaw) + return by.wsProcessWalletPushData(assetType, respRaw) case chanGreeks: return by.wsProcessGreeks(respRaw) case chanDCP: @@ -450,7 +392,7 @@ func (by *Bybit) wsLeverageTokenNav(resp *WebsocketResponse) error { } func (by *Bybit) wsProcessLeverageTokenTicker(assetType asset.Item, resp *WebsocketResponse) error { - var result TickerItem + var result TickerWebsocket err := json.Unmarshal(resp.Data, &result) if err != nil { return err @@ -551,8 +493,8 @@ func (by *Bybit) wsProcessKline(assetType asset.Item, resp *WebsocketResponse, t } func (by *Bybit) wsProcessPublicTicker(assetType asset.Item, resp *WebsocketResponse) error { - tickResp := new(TickerItem) - if err := json.Unmarshal(resp.Data, tickResp); err != nil { + var tickResp TickerWebsocket + if err := json.Unmarshal(resp.Data, &tickResp); err != nil { return err } @@ -560,38 +502,24 @@ func (by *Bybit) wsProcessPublicTicker(assetType asset.Item, resp *WebsocketResp if err != nil { return err } - pFmt, err := by.GetPairFormat(assetType, false) - if err != nil { - return err - } - p = p.Format(pFmt) - var tick *ticker.Price - if resp.Type == "snapshot" { - tick = &ticker.Price{ - Pair: p, - ExchangeName: by.Name, - AssetType: assetType, - } - } else { + tick := &ticker.Price{Pair: p, ExchangeName: by.Name, AssetType: assetType} + snapshot, err := ticker.GetTicker(by.Name, p, assetType) + if err == nil && resp.Type != "snapshot" { // ticker updates may be partial, so we need to update the current ticker - tick, err = ticker.GetTicker(by.Name, p, assetType) - if err != nil { - return err - } + tick = snapshot } - updateTicker(tick, tickResp) + updateTicker(tick, &tickResp) tick.LastUpdated = resp.PushTimestamp.Time() if err = ticker.ProcessTicker(tick); err == nil { by.Websocket.DataHandler <- tick } - return err } -func updateTicker(tick *ticker.Price, resp *TickerItem) { +func updateTicker(tick *ticker.Price, resp *TickerWebsocket) { if resp.LastPrice.Float64() != 0 { tick.Last = resp.LastPrice.Float64() } @@ -777,3 +705,126 @@ const subTplText = ` func hasPotentialDelimiter(a asset.Item) bool { return a == asset.Options || a == asset.USDCMarginedFutures } + +// TODO: Remove this function when template expansion is across all assets +func (by *Bybit) handleSubscriptionNonTemplate(ctx context.Context, conn stream.Connection, a asset.Item, operation string, channelsToSubscribe subscription.List) error { + payloads, err := by.handleSubscriptionsNonTemplate(conn, a, operation, channelsToSubscribe) + if err != nil { + return err + } + for _, payload := range payloads { + if a == asset.Options { + // The options connection does not send the subscription request id back with the subscription notification payload + // therefore the code doesn't wait for the response to check whether the subscription is successful or not. + err = conn.SendJSONMessage(ctx, request.Unset, payload) + if err != nil { + return err + } + continue + } + var response []byte + response, err = conn.SendMessageReturnResponse(ctx, request.Unset, payload.RequestID, payload) + if err != nil { + return err + } + var resp SubscriptionResponse + err = json.Unmarshal(response, &resp) + if err != nil { + return err + } + if !resp.Success { + return fmt.Errorf("%s with request ID %s msg: %s", resp.Operation, resp.RequestID, resp.RetMsg) + } + } + return nil +} + +// TODO: Remove this function when template expansion is across all assets +func (by *Bybit) handleSubscriptionsNonTemplate(conn stream.Connection, assetType asset.Item, operation string, channelsToSubscribe subscription.List) ([]SubscriptionArgument, error) { + var args []SubscriptionArgument + arg := SubscriptionArgument{ + Operation: operation, + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + Arguments: []string{}, + } + authArg := SubscriptionArgument{ + auth: true, + Operation: operation, + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + Arguments: []string{}, + } + + chanMap := map[string]bool{} + pairFormat, err := by.GetPairFormat(assetType, true) + if err != nil { + return nil, err + } + for i := range channelsToSubscribe { + if len(channelsToSubscribe[i].Pairs) != 1 { + return nil, subscription.ErrNotSinglePair + } + pair := channelsToSubscribe[i].Pairs[0] + switch channelsToSubscribe[i].Channel { + case chanOrderbook: + arg.Arguments = append(arg.Arguments, fmt.Sprintf("%s.%d.%s", channelsToSubscribe[i].Channel, 50, pair.Format(pairFormat).String())) + case chanPublicTrade, chanPublicTicker, chanLiquidation, chanLeverageTokenTicker, chanLeverageTokenNav: + arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+pair.Format(pairFormat).String()) + case chanKline, chanLeverageTokenKline: + interval, err := intervalToString(kline.FiveMin) + if err != nil { + return nil, err + } + arg.Arguments = append(arg.Arguments, channelsToSubscribe[i].Channel+"."+interval+"."+pair.Format(pairFormat).String()) + case chanPositions, chanExecution, chanOrder, chanWallet, chanGreeks, chanDCP: + if chanMap[channelsToSubscribe[i].Channel] { + continue + } + authArg.Arguments = append(authArg.Arguments, channelsToSubscribe[i].Channel) + // adding the channel to selected channels so that we will not visit it again. + chanMap[channelsToSubscribe[i].Channel] = true + } + if len(arg.Arguments) >= 10 { + args = append(args, arg) + arg = SubscriptionArgument{ + Operation: operation, + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + Arguments: []string{}, + } + } + } + if len(arg.Arguments) != 0 { + args = append(args, arg) + } + if len(authArg.Arguments) != 0 { + args = append(args, authArg) + } + return args, nil +} + +// generateAuthSubscriptions generates default subscription for the dedicated auth websocket connection. These are +// agnostic to the asset type and pair as all account level data will be routed through this connection. +// TODO: Remove this function when template expansion is across all assets +func (by *Bybit) generateAuthSubscriptions() (subscription.List, error) { + if !by.Websocket.CanUseAuthenticatedEndpoints() { + return nil, nil + } + var subscriptions subscription.List + for _, channel := range []string{chanPositions, chanExecution, chanOrder, chanWallet} { + subscriptions = append(subscriptions, &subscription.Subscription{ + Channel: channel, + Pairs: currency.Pairs{currency.EMPTYPAIR}, // This is a placeholder, the actual pair is not required for these channels + Asset: asset.All, + }) + } + return subscriptions, nil +} + +// LinearSubscribe sends a subscription message to linear public channels. +func (by *Bybit) authSubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.Spot, "subscribe", channelSubscriptions) +} + +// LinearUnsubscribe sends an unsubscription messages through linear public channels. +func (by *Bybit) authUnsubscribe(ctx context.Context, conn stream.Connection, channelSubscriptions subscription.List) error { + return by.handleSubscriptionNonTemplate(ctx, conn, asset.Spot, "unsubscribe", channelSubscriptions) +} diff --git a/exchanges/bybit/bybit_wrapper.go b/exchanges/bybit/bybit_wrapper.go index 662f1e900b1..8442195f881 100644 --- a/exchanges/bybit/bybit_wrapper.go +++ b/exchanges/bybit/bybit_wrapper.go @@ -27,6 +27,7 @@ import ( "github.com/thrasher-corp/gocryptotrader/exchanges/request" "github.com/thrasher-corp/gocryptotrader/exchanges/stream" "github.com/thrasher-corp/gocryptotrader/exchanges/stream/buffer" + "github.com/thrasher-corp/gocryptotrader/exchanges/subscription" "github.com/thrasher-corp/gocryptotrader/exchanges/ticker" "github.com/thrasher-corp/gocryptotrader/exchanges/trade" "github.com/thrasher-corp/gocryptotrader/log" @@ -67,12 +68,6 @@ func (by *Bybit) SetDefaults() { } } - for _, a := range []asset.Item{asset.CoinMarginedFutures, asset.USDTMarginedFutures, asset.USDCMarginedFutures, asset.Options} { - if err := by.DisableAssetWebsocketSupport(a); err != nil { - log.Errorln(log.ExchangeSys, err) - } - } - by.Features = exchange.Features{ CurrencyTranslations: currency.NewTranslations( map[currency.Code]currency.Code{ @@ -214,65 +209,145 @@ func (by *Bybit) SetDefaults() { // Setup takes in the supplied exchange configuration details and sets params func (by *Bybit) Setup(exch *config.Exchange) error { - err := exch.Validate() - if err != nil { + if err := exch.Validate(); err != nil { return err } + if !exch.Enabled { by.SetEnabled(false) return nil } - err = by.SetupDefaults(exch) - if err != nil { + if err := by.SetupDefaults(exch); err != nil { return err } - wsRunningEndpoint, err := by.API.Endpoints.GetURL(exchange.WebsocketSpot) - if err != nil { + if err := by.Websocket.Setup(&stream.WebsocketSetup{ + ExchangeConfig: exch, + RunningURLAuth: websocketPrivate, + Features: &by.Features.Supports.WebsocketCapabilities, + OrderbookBufferConfig: buffer.Config{SortBuffer: true, SortBufferByUpdateIDs: true}, + TradeFeed: by.Features.Enabled.TradeFeed, + UseMultiConnectionManagement: true, + }); err != nil { return err } - err = by.Websocket.Setup( - &stream.WebsocketSetup{ - ExchangeConfig: exch, - DefaultURL: spotPublic, - RunningURL: wsRunningEndpoint, - RunningURLAuth: websocketPrivate, - Connector: by.WsConnect, - Subscriber: by.Subscribe, - Unsubscriber: by.Unsubscribe, - GenerateSubscriptions: by.generateSubscriptions, - Features: &by.Features.Supports.WebsocketCapabilities, - OrderbookBufferConfig: buffer.Config{ - SortBuffer: true, - SortBufferByUpdateIDs: true, - }, - TradeFeed: by.Features.Enabled.TradeFeed, - }) - if err != nil { + // Spot + if err := by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: spotPublic, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + RateLimit: request.NewWeightedRateLimitByDuration(time.Microsecond), + Connector: by.WsConnect, + GenerateSubscriptions: func() (subscription.List, error) { return by.generateSubscriptions() }, + Subscriber: by.Subscribe, + Unsubscriber: by.Unsubscribe, + Handler: func(ctx context.Context, resp []byte) error { return by.wsHandleData(ctx, resp, asset.Spot) }, + BespokeGenerateMessageID: by.bespokeWebsocketRequestID, + }); err != nil { return err } - err = by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: by.Websocket.GetWebsocketURL(), + + // Options + if err := by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: optionPublic, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + RateLimit: request.NewWeightedRateLimitByDuration(time.Microsecond), + Connector: by.WsConnect, + GenerateSubscriptions: by.GenerateOptionsDefaultSubscriptions, + Subscriber: by.OptionSubscribe, + Unsubscriber: by.OptionUnsubscribe, + Handler: func(ctx context.Context, resp []byte) error { return by.wsHandleData(ctx, resp, asset.Options) }, + BespokeGenerateMessageID: by.bespokeWebsocketRequestID, + }); err != nil { + return err + } + + // Linear - USDT margined futures. + if err := by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: linearPublic, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: bybitWebsocketTimer, - }) - if err != nil { + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + RateLimit: request.NewWeightedRateLimitByDuration(time.Microsecond), + Connector: by.WsConnect, + GenerateSubscriptions: func() (subscription.List, error) { + return by.GenerateLinearDefaultSubscriptions(asset.USDTMarginedFutures) + }, + Subscriber: by.LinearSubscribe, + Unsubscriber: by.LinearUnsubscribe, + Handler: func(ctx context.Context, resp []byte) error { + return by.wsHandleData(ctx, resp, asset.USDTMarginedFutures) + }, + BespokeGenerateMessageID: by.bespokeWebsocketRequestID, + WrapperDefinedConnectionSignature: asset.USDTMarginedFutures, // Unused but it allows us to differentiate between the two linear futures types. + }); err != nil { return err } - return by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: websocketPrivate, + // Linear - USDC margined futures. + if err := by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: linearPublic, ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Authenticated: true, + RateLimit: request.NewWeightedRateLimitByDuration(time.Microsecond), + Connector: by.WsConnect, + GenerateSubscriptions: func() (subscription.List, error) { + return by.GenerateLinearDefaultSubscriptions(asset.USDCMarginedFutures) + }, + Subscriber: by.LinearSubscribe, + Unsubscriber: by.LinearUnsubscribe, + Handler: func(ctx context.Context, resp []byte) error { + return by.wsHandleData(ctx, resp, asset.USDCMarginedFutures) + }, + BespokeGenerateMessageID: by.bespokeWebsocketRequestID, + WrapperDefinedConnectionSignature: asset.USDCMarginedFutures, // Unused but it allows us to differentiate between the two linear futures types. + }); err != nil { + return err + } + + // Inverse - Coin margined futures. + if err := by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: inversePublic, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + RateLimit: request.NewWeightedRateLimitByDuration(time.Microsecond), + Connector: by.WsConnect, + GenerateSubscriptions: by.GenerateInverseDefaultSubscriptions, + Subscriber: by.InverseSubscribe, + Unsubscriber: by.InverseUnsubscribe, + Handler: func(ctx context.Context, resp []byte) error { + return by.wsHandleData(ctx, resp, asset.CoinMarginedFutures) + }, + BespokeGenerateMessageID: by.bespokeWebsocketRequestID, + }); err != nil { + return err + } + + // Private + return by.Websocket.SetupNewConnection(&stream.ConnectionSetup{ + URL: websocketPrivate, + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + RateLimit: request.NewWeightedRateLimitByDuration(time.Microsecond), + Authenticated: true, + Connector: by.WsConnect, + GenerateSubscriptions: by.generateAuthSubscriptions, + Subscriber: by.authSubscribe, + Unsubscriber: by.authUnsubscribe, + // Private websocket data is handled by the same function as the public data. Intentionally set asset as asset.All. + // As all asset type order execution, wallet and other data is centralised through the private websocket connection. + // TODO: Handle private websocket data to be asset specific. + Handler: func(ctx context.Context, resp []byte) error { return by.wsHandleData(ctx, resp, asset.All) }, + BespokeGenerateMessageID: by.bespokeWebsocketRequestID, + Authenticate: by.WebsocketAuthenticateConnection, }) } -// AuthenticateWebsocket sends an authentication message to the websocket -func (by *Bybit) AuthenticateWebsocket(ctx context.Context) error { - return by.WsAuth(ctx) +// bespokeWebsocketRequestID generates a unique ID for websocket requests, this is just a simple counter. +func (by *Bybit) bespokeWebsocketRequestID(bool) int64 { + return by.Counter.IncrementAndGet() } // FetchTradablePairs returns a list of the exchanges tradable pairs diff --git a/exchanges/exchange.go b/exchanges/exchange.go index 938923e2332..ec7bd55daf8 100644 --- a/exchanges/exchange.go +++ b/exchanges/exchange.go @@ -975,8 +975,7 @@ func (b *Base) SupportsAsset(a asset.Item) bool { // PrintEnabledPairs prints the exchanges enabled asset pairs func (b *Base) PrintEnabledPairs() { for k, v := range b.CurrencyPairs.Pairs { - log.Infof(log.ExchangeSys, "%s Asset type %v:\n\t Enabled pairs: %v", - b.Name, strings.ToUpper(k.String()), v.Enabled) + log.Infof(log.ExchangeSys, "%s Asset type %v:\n\t Enabled pairs: %v", b.Name, strings.ToUpper(k.String()), v.Enabled) } } @@ -987,10 +986,7 @@ func (b *Base) GetBase() *Base { return b } // for validation of API credentials func (b *Base) CheckTransientError(err error) error { if _, ok := err.(net.Error); ok { - log.Warnf(log.ExchangeSys, - "%s net error captured, will not disable authentication %s", - b.Name, - err) + log.Warnf(log.ExchangeSys, "%s net error captured, will not disable authentication %s", b.Name, err) return nil } return err @@ -1947,3 +1943,8 @@ func (b *Base) GetTradingRequirements() protocol.TradingRequirements { } return b.Features.TradingRequirements } + +// WebsocketSubmitOrder submits an order to the exchange via a websocket connection +func (*Base) WebsocketSubmitOrder(context.Context, *order.Submit) (*order.SubmitResponse, error) { + return nil, common.ErrFunctionNotSupported +} diff --git a/exchanges/exchange_test.go b/exchanges/exchange_test.go index a5db589d047..339aafbcbbe 100644 --- a/exchanges/exchange_test.go +++ b/exchanges/exchange_test.go @@ -3076,3 +3076,8 @@ func TestGetTradingRequirements(t *testing.T) { requirements = (&Base{Features: Features{TradingRequirements: protocol.TradingRequirements{ClientOrderID: true}}}).GetTradingRequirements() require.NotEmpty(t, requirements) } + +func TestWebsocketSubmitOrder(t *testing.T) { + _, err := (&Base{}).WebsocketSubmitOrder(context.Background(), nil) + require.ErrorIs(t, err, common.ErrFunctionNotSupported) +} diff --git a/exchanges/gateio/gateio_test.go b/exchanges/gateio/gateio_test.go index 477e0d1616f..d15a6f87ab4 100644 --- a/exchanges/gateio/gateio_test.go +++ b/exchanges/gateio/gateio_test.go @@ -367,7 +367,7 @@ func TestCreateSpotOrder(t *testing.T) { func TestGetSpotOrders(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - if _, err := g.GetSpotOrders(context.Background(), currency.Pair{Base: currency.BTC, Quote: currency.USDT, Delimiter: currency.UnderscoreDelimiter}, "open", 0, 0); err != nil { + if _, err := g.GetSpotOrders(context.Background(), currency.Pair{Base: currency.BTC, Quote: currency.USDT, Delimiter: currency.UnderscoreDelimiter}, statusOpen, 0, 0); err != nil { t.Errorf("%s GetSpotOrders() error %v", g.Name, err) } } @@ -489,7 +489,7 @@ func TestCreatePriceTriggeredOrder(t *testing.T) { func TestGetPriceTriggeredOrderList(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - if _, err := g.GetPriceTriggeredOrderList(context.Background(), "open", currency.EMPTYPAIR, asset.Empty, 0, 0); err != nil { + if _, err := g.GetPriceTriggeredOrderList(context.Background(), statusOpen, currency.EMPTYPAIR, asset.Empty, 0, 0); err != nil { t.Errorf("%s GetPriceTriggeredOrderList() error %v", g.Name, err) } } @@ -563,7 +563,7 @@ func TestMarginLoan(t *testing.T) { func TestGetMarginAllLoans(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - if _, err := g.GetMarginAllLoans(context.Background(), "open", "lend", "", currency.BTC, currency.Pair{Base: currency.BTC, Delimiter: currency.UnderscoreDelimiter, Quote: currency.USDT}, false, 0, 0); err != nil { + if _, err := g.GetMarginAllLoans(context.Background(), statusOpen, "lend", "", currency.BTC, currency.Pair{Base: currency.BTC, Delimiter: currency.UnderscoreDelimiter, Quote: currency.USDT}, false, 0, 0); err != nil { t.Errorf("%s GetMarginAllLoans() error %v", g.Name, err) } } @@ -1126,7 +1126,7 @@ func TestGetDeliveryOrders(t *testing.T) { sharedtestvalues.SkipTestIfCredentialsUnset(t, g) settle, err := getSettlementFromCurrency(getPair(t, asset.DeliveryFutures)) require.NoError(t, err, "getSettlementFromCurrency must not error") - _, err = g.GetDeliveryOrders(context.Background(), getPair(t, asset.DeliveryFutures), "open", settle, "", 0, 0, 1) + _, err = g.GetDeliveryOrders(context.Background(), getPair(t, asset.DeliveryFutures), statusOpen, settle, "", 0, 0, 1) assert.NoError(t, err, "GetDeliveryOrders should not error") } @@ -1208,7 +1208,7 @@ func TestGetDeliveryPriceTriggeredOrder(t *testing.T) { func TestGetDeliveryAllAutoOrder(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - _, err := g.GetDeliveryAllAutoOrder(context.Background(), "open", currency.USDT, getPair(t, asset.DeliveryFutures), 0, 1) + _, err := g.GetDeliveryAllAutoOrder(context.Background(), statusOpen, currency.USDT, getPair(t, asset.DeliveryFutures), 0, 1) assert.NoError(t, err, "GetDeliveryAllAutoOrder should not error") } @@ -1297,7 +1297,7 @@ func TestPlaceFuturesOrder(t *testing.T) { func TestGetFuturesOrders(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - _, err := g.GetFuturesOrders(context.Background(), currency.NewPair(currency.BTC, currency.USD), "open", "", currency.BTC, 0, 0, 1) + _, err := g.GetFuturesOrders(context.Background(), currency.NewPair(currency.BTC, currency.USD), statusOpen, "", currency.BTC, 0, 0, 1) assert.NoError(t, err, "GetFuturesOrders should not error") } @@ -1434,7 +1434,7 @@ func TestCreatePriceTriggeredFuturesOrder(t *testing.T) { func TestListAllFuturesAutoOrders(t *testing.T) { t.Parallel() sharedtestvalues.SkipTestIfCredentialsUnset(t, g) - _, err := g.ListAllFuturesAutoOrders(context.Background(), "open", currency.BTC, currency.EMPTYPAIR, 0, 0) + _, err := g.ListAllFuturesAutoOrders(context.Background(), statusOpen, currency.BTC, currency.EMPTYPAIR, 0, 0) assert.NoError(t, err, "ListAllFuturesAutoOrders should not error") } @@ -2601,7 +2601,7 @@ func TestFuturesPositionsNotification(t *testing.T) { } } -const wsFuturesAutoOrdersPushDataJSON = `{"time": 1596798126,"channel": "futures.autoorders", "event": "update", "error": null, "result": [ { "user": 123456, "trigger": { "strategy_type": 0, "price_type": 0, "price": "10000", "rule": 2, "expiration": 86400 }, "initial": { "contract": "BTC_USDT", "size": 10, "price": "10000", "tif": "gtc", "text": "web", "iceberg": 0, "is_close": false, "is_reduce_only": false }, "id": 9256, "trade_id": 0, "status": "open", "reason": "", "create_time": 1596798126, "name": "price_autoorders", "is_stop_order": false, "stop_trigger": { "rule": 0, "trigger_price": "", "order_price": "" } } ]}` +const wsFuturesAutoOrdersPushDataJSON = `{"time": 1596798126,"channel": "futures.autoorders", "event": "update", "error": null, "result": [ { "user": 123456, "trigger": { "strategy_type": 0, "price_type": 0, "price": "10000", "rule": 2, "expiration": 86400 }, "initial": { "contract": "BTC_USDT", "size": 10, "price": "10000", "tif": "gtc", "text": "web", "iceberg": 0, "is_close": false, "is_reduce_only": false }, "id": 9256, "trade_id": 0, "status": "OPEN", "reason": "", "create_time": 1596798126, "name": "price_autoorders", "is_stop_order": false, "stop_trigger": { "rule": 0, "trigger_price": "", "order_price": "" } } ]}` func TestFuturesAutoOrderPushData(t *testing.T) { t.Parallel() @@ -3011,11 +3011,11 @@ func TestGetSettlementFromCurrency(t *testing.T) { for _, assetType := range []asset.Item{asset.Futures, asset.DeliveryFutures, asset.Options} { availPairs, err := g.GetAvailablePairs(assetType) require.NoErrorf(t, err, "GetAvailablePairs for asset %s must not error", assetType) - for x := range availPairs { - t.Run(strconv.Itoa(x), func(t *testing.T) { + for x, pair := range availPairs { + t.Run(strconv.Itoa(x)+":"+assetType.String(), func(t *testing.T) { t.Parallel() - _, err = getSettlementFromCurrency(availPairs[x]) - assert.NoErrorf(t, err, "getSettlementFromCurrency should not error for pair %s and asset %s", availPairs[x], assetType) + _, err := getSettlementFromCurrency(pair) + assert.NoErrorf(t, err, "getSettlementFromCurrency should not error for pair %s and asset %s", pair, assetType) }) } } @@ -3439,7 +3439,7 @@ func TestProcessFuturesOrdersPushData(t *testing.T) { incoming string status order.Status }{ - {`{"channel":"futures.orders","event":"update","time":1541505434,"time_ms":1541505434123,"result":[{"contract":"BTC_USD","create_time":1628736847,"create_time_ms":1628736847325,"fill_price":40000.4,"finish_as":"","finish_time":1628736848,"finish_time_ms":1628736848321,"iceberg":0,"id":4872460,"is_close":false,"is_liq":false,"is_reduce_only":false,"left":0,"mkfr":-0.00025,"price":40000.4,"refr":0,"refu":0,"size":1,"status":"open","text":"-","tif":"gtc","tkfr":0.0005,"user":"110xxxxx"}]}`, order.Open}, + {`{"channel":"futures.orders","event":"update","time":1541505434,"time_ms":1541505434123,"result":[{"contract":"BTC_USD","create_time":1628736847,"create_time_ms":1628736847325,"fill_price":40000.4,"finish_as":"","finish_time":1628736848,"finish_time_ms":1628736848321,"iceberg":0,"id":4872460,"is_close":false,"is_liq":false,"is_reduce_only":false,"left":0,"mkfr":-0.00025,"price":40000.4,"refr":0,"refu":0,"size":1,"status":"OPEN","text":"-","tif":"gtc","tkfr":0.0005,"user":"110xxxxx"}]}`, order.Open}, {`{"channel":"futures.orders","event":"update","time":1541505434,"time_ms":1541505434123,"result":[{"contract":"BTC_USD","create_time":1628736847,"create_time_ms":1628736847325,"fill_price":40000.4,"finish_as":"filled","finish_time":1628736848,"finish_time_ms":1628736848321,"iceberg":0,"id":4872460,"is_close":false,"is_liq":false,"is_reduce_only":false,"left":0,"mkfr":-0.00025,"price":40000.4,"refr":0,"refu":0,"size":1,"status":"finished","text":"-","tif":"gtc","tkfr":0.0005,"user":"110xxxxx"}]}`, order.Filled}, {`{"channel":"futures.orders","event":"update","time":1541505434,"time_ms":1541505434123,"result":[{"contract":"BTC_USD","create_time":1628736847,"create_time_ms":1628736847325,"fill_price":40000.4,"finish_as":"cancelled","finish_time":1628736848,"finish_time_ms":1628736848321,"iceberg":0,"id":4872460,"is_close":false,"is_liq":false,"is_reduce_only":false,"left":0,"mkfr":-0.00025,"price":40000.4,"refr":0,"refu":0,"size":1,"status":"finished","text":"-","tif":"gtc","tkfr":0.0005,"user":"110xxxxx"}]}`, order.Cancelled}, {`{"channel":"futures.orders","event":"update","time":1541505434,"time_ms":1541505434123,"result":[{"contract":"BTC_USD","create_time":1628736847,"create_time_ms":1628736847325,"fill_price":40000.4,"finish_as":"liquidated","finish_time":1628736848,"finish_time_ms":1628736848321,"iceberg":0,"id":4872460,"is_close":false,"is_liq":false,"is_reduce_only":false,"left":0,"mkfr":-0.00025,"price":40000.4,"refr":0,"refu":0,"size":1,"status":"finished","text":"-","tif":"gtc","tkfr":0.0005,"user":"110xxxxx"}]}`, order.Liquidated}, diff --git a/exchanges/gateio/gateio_types.go b/exchanges/gateio/gateio_types.go index c821677829e..dedc8bed42a 100644 --- a/exchanges/gateio/gateio_types.go +++ b/exchanges/gateio/gateio_types.go @@ -1384,6 +1384,8 @@ type CreateOrderRequestData struct { Price types.Number `json:"price,omitempty"` TimeInForce string `json:"time_in_force,omitempty"` AutoBorrow bool `json:"auto_borrow,omitempty"` + AutoRepay bool `json:"auto_repay,omitempty"` + StpAct string `json:"stp_act,omitempty"` } // SpotOrder represents create order response. @@ -1814,15 +1816,16 @@ type DualModeResponse struct { // OrderCreateParams represents future order creation parameters type OrderCreateParams struct { Contract currency.Pair `json:"contract"` - Size float64 `json:"size"` - Iceberg int64 `json:"iceberg"` - Price string `json:"price"` // NOTE: Market orders require string "0" + Size float64 `json:"size"` // positive long, negative short + Iceberg int64 `json:"iceberg"` // required; can be zero + Price string `json:"price"` // NOTE: Market orders require string "0" TimeInForce string `json:"tif"` Text string `json:"text,omitempty"` // Omitempty required as payload sent as `text:""` will return error message: Text content not starting with `t-`" ClosePosition bool `json:"close,omitempty"` // Size needs to be zero if true ReduceOnly bool `json:"reduce_only,omitempty"` - AutoSize string `json:"auto_size,omitempty"` - Settle currency.Code `json:"-"` // Used in URL. + AutoSize string `json:"auto_size,omitempty"` // either close_long or close_short, size needs to be zero. + Settle currency.Code `json:"-"` // Used in URL. REST Calls only. + StpAct string `json:"stp_act,omitempty"` } // Order represents future order response @@ -2008,12 +2011,13 @@ type WsEventResponse struct { // WsResponse represents generalized websocket push data from the server. type WsResponse struct { - ID int64 `json:"id"` - Time types.Time `json:"time"` - TimeMs types.Time `json:"time_ms"` - Channel string `json:"channel"` - Event string `json:"event"` - Result json.RawMessage `json:"result"` + ID int64 `json:"id"` + Time types.Time `json:"time"` + TimeMs types.Time `json:"time_ms"` + Channel string `json:"channel"` + Event string `json:"event"` + Result json.RawMessage `json:"result"` + RequestID string `json:"request_id"` } // WsTicker websocket ticker information. diff --git a/exchanges/gateio/gateio_websocket.go b/exchanges/gateio/gateio_websocket.go index d00f03b451f..69c5760fc45 100644 --- a/exchanges/gateio/gateio_websocket.go +++ b/exchanges/gateio/gateio_websocket.go @@ -109,11 +109,17 @@ func (g *Gateio) generateWsSignature(secret, event, channel string, t int64) (st // WsHandleSpotData handles spot data func (g *Gateio) WsHandleSpotData(_ context.Context, respRaw []byte) error { var push WsResponse - err := json.Unmarshal(respRaw, &push) - if err != nil { + if err := json.Unmarshal(respRaw, &push); err != nil { return err } + if push.RequestID != "" { + if !g.Websocket.Match.IncomingWithData(push.RequestID, respRaw) { + return fmt.Errorf("gateio_websocket.go error - unable to match requestID %v", push.RequestID) + } + return nil + } + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) diff --git a/exchanges/gateio/gateio_websocket_request_futures.go b/exchanges/gateio/gateio_websocket_request_futures.go new file mode 100644 index 00000000000..26fbc68ba25 --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_futures.go @@ -0,0 +1,205 @@ +package gateio + +import ( + "context" + "errors" + "fmt" + + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" +) + +var ( + errInvalidAutoSize = errors.New("invalid auto size") + errSettlementCurrencyConflict = errors.New("settlement currency conflict") + errInvalidSide = errors.New("invalid side") + errStatusNotSet = errors.New("status not set") +) + +// AuthenticateFutures sends an authentication message to the websocket connection +func (g *Gateio) authenticateFutures(ctx context.Context, conn stream.Connection) error { + _, err := g.websocketLogin(ctx, conn, "futures.login") + return err +} + +// WebsocketOrderPlaceFutures places an order via the websocket connection. You can +// send multiple orders in a single request. NOTE: When sending multiple orders +// the response will be an array of responses and a succeeded bool will be +// returned in the response. +func (g *Gateio) WebsocketOrderPlaceFutures(ctx context.Context, batch []OrderCreateParams) ([]WebsocketFuturesOrderResponse, error) { + if len(batch) == 0 { + return nil, errBatchSliceEmpty + } + + var a asset.Item + for i := range batch { + if batch[i].Contract.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if batch[i].Price == "" && batch[i].TimeInForce != "ioc" { + return nil, fmt.Errorf("%w: cannot be zero when time in force is not IOC", errInvalidPrice) + } + + if batch[i].Size == 0 && batch[i].AutoSize == "" { + return nil, fmt.Errorf("%w: size cannot be zero", errInvalidAmount) + } + + if batch[i].AutoSize != "" { + if batch[i].AutoSize != "close_long" && batch[i].AutoSize != "close_short" { + return nil, fmt.Errorf("%w: %s", errInvalidAutoSize, batch[i].AutoSize) + } + if batch[i].Size != 0 { + return nil, fmt.Errorf("%w: size needs to be zero when auto size is set", errInvalidAmount) + } + } + + switch { + case batch[i].Contract.Quote.Equal(currency.USDT): + if a != asset.Empty && a != asset.USDTMarginedFutures { + return nil, fmt.Errorf("%w: either btc or usdt margined can only be batched as they are using different connections", errSettlementCurrencyConflict) + } + a = asset.USDTMarginedFutures + case batch[i].Contract.Quote.Equal(currency.USD): + if a != asset.Empty && a != asset.CoinMarginedFutures { + return nil, fmt.Errorf("%w: either btc or usdt margined can only be batched as they are using different connections", errSettlementCurrencyConflict) + } + a = asset.CoinMarginedFutures + } + } + + if len(batch) == 1 { + var singleResponse WebsocketFuturesOrderResponse + err := g.SendWebsocketRequest(ctx, "futures.order_place", a, batch[0], &singleResponse, 2) + return []WebsocketFuturesOrderResponse{singleResponse}, err + } + + var resp []WebsocketFuturesOrderResponse + return resp, g.SendWebsocketRequest(ctx, "futures.order_batch_place", a, batch, &resp, 2) +} + +// WebsocketOrderCancelFutures cancels an order via the websocket connection. +// Contract is used for routing the request internally to the correct connection. +func (g *Gateio) WebsocketOrderCancelFutures(ctx context.Context, orderID string, contract currency.Pair) (*WebsocketFuturesOrderResponse, error) { + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + + if contract.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + a := asset.USDTMarginedFutures + if contract.Quote.Equal(currency.USD) { + a = asset.CoinMarginedFutures + } + + params := &struct { + OrderID string `json:"order_id"` + }{OrderID: orderID} + + var resp WebsocketFuturesOrderResponse + return &resp, g.SendWebsocketRequest(ctx, "futures.order_cancel", a, params, &resp, 1) +} + +// WebsocketOrderCancelAllOpenFuturesOrdersMatched cancels multiple orders via +// the websocket. +func (g *Gateio) WebsocketOrderCancelAllOpenFuturesOrdersMatched(ctx context.Context, contract currency.Pair, side string) ([]WebsocketFuturesOrderResponse, error) { + if contract.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if side != "" && side != "ask" && side != "bid" { + return nil, fmt.Errorf("%w: %s", errInvalidSide, side) + } + + params := struct { + Contract currency.Pair `json:"contract"` + Side string `json:"side,omitempty"` + }{Contract: contract, Side: side} + + a := asset.USDTMarginedFutures + if contract.Quote.Equal(currency.USD) { + a = asset.CoinMarginedFutures + } + + var resp []WebsocketFuturesOrderResponse + return resp, g.SendWebsocketRequest(ctx, "futures.order_cancel_cp", a, params, &resp, 2) +} + +// WebsocketOrderAmendFutures amends an order via the websocket connection +func (g *Gateio) WebsocketOrderAmendFutures(ctx context.Context, amend *WebsocketFuturesAmendOrder) (*WebsocketFuturesOrderResponse, error) { + if amend == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, amend) + } + + if amend.OrderID == "" { + return nil, order.ErrOrderIDNotSet + } + + if amend.Contract.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if amend.Size == 0 && amend.Price == "" { + return nil, fmt.Errorf("%w: size or price must be set", errInvalidAmount) + } + + a := asset.USDTMarginedFutures + if amend.Contract.Quote.Equal(currency.USD) { + a = asset.CoinMarginedFutures + } + + var resp WebsocketFuturesOrderResponse + return &resp, g.SendWebsocketRequest(ctx, "futures.order_amend", a, amend, &resp, 1) +} + +// WebsocketOrderListFutures fetches a list of orders via the websocket connection +func (g *Gateio) WebsocketOrderListFutures(ctx context.Context, list *WebsocketFutureOrdersList) ([]WebsocketFuturesOrderResponse, error) { + if list == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, list) + } + + if list.Contract.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if list.Status == "" { + return nil, errStatusNotSet + } + + a := asset.USDTMarginedFutures + if list.Contract.Quote.Equal(currency.USD) { + a = asset.CoinMarginedFutures + } + + var resp []WebsocketFuturesOrderResponse + return resp, g.SendWebsocketRequest(ctx, "futures.order_list", a, list, &resp, 1) +} + +// WebsocketGetOrderStatusFutures gets the status of an order via the websocket +// connection. +func (g *Gateio) WebsocketGetOrderStatusFutures(ctx context.Context, contract currency.Pair, orderID string) (*WebsocketFuturesOrderResponse, error) { + if contract.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + + params := &struct { + OrderID string `json:"order_id"` + }{OrderID: orderID} + + a := asset.USDTMarginedFutures + if contract.Quote.Equal(currency.USD) { + a = asset.CoinMarginedFutures + } + + var resp WebsocketFuturesOrderResponse + return &resp, g.SendWebsocketRequest(ctx, "futures.order_status", a, params, &resp, 1) +} diff --git a/exchanges/gateio/gateio_websocket_request_futures_test.go b/exchanges/gateio/gateio_websocket_request_futures_test.go new file mode 100644 index 00000000000..9cfba680fa2 --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_futures_test.go @@ -0,0 +1,197 @@ +package gateio + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" +) + +func TestWebsocketOrderPlaceFutures(t *testing.T) { + t.Parallel() + _, err := g.WebsocketOrderPlaceFutures(context.Background(), nil) + require.ErrorIs(t, err, errBatchSliceEmpty) + _, err = g.WebsocketOrderPlaceFutures(context.Background(), make([]OrderCreateParams, 1)) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + out := OrderCreateParams{} + _, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out}) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + out.Contract, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out}) + require.ErrorIs(t, err, errInvalidPrice) + + out.Price = "40000" + _, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out}) + require.ErrorIs(t, err, errInvalidAmount) + + out.Size = 1 // 1 lovely long contract + out.AutoSize = "silly_billies" + _, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out}) + require.ErrorIs(t, err, errInvalidAutoSize) + + out.AutoSize = "close_long" + _, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out}) + require.ErrorIs(t, err, errInvalidAmount) + + out.AutoSize = "" + outBad := out + outBad.Contract, err = currency.NewPairFromString("BTC_USD") + require.NoError(t, err) + + _, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out, outBad}) + require.ErrorIs(t, err, errSettlementCurrencyConflict) + + outBad.Contract, out.Contract = out.Contract, outBad.Contract // swapsies + _, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out, outBad}) + require.ErrorIs(t, err, errSettlementCurrencyConflict) + + outBad.Contract, out.Contract = out.Contract, outBad.Contract // swapsies back + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + // test single order + got, err := g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out}) + require.NoError(t, err) + require.NotEmpty(t, got) + + // test batch orders + got, err = g.WebsocketOrderPlaceFutures(context.Background(), []OrderCreateParams{out, out}) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderCancelFutures(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketOrderCancelFutures(context.Background(), "", currency.EMPTYPAIR) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + _, err = g.WebsocketOrderCancelFutures(context.Background(), "42069", currency.EMPTYPAIR) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + pair, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + got, err := g.WebsocketOrderCancelFutures(context.Background(), "513160761072", pair) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderCancelAllOpenFuturesOrdersMatched(t *testing.T) { + t.Parallel() + _, err := g.WebsocketOrderCancelAllOpenFuturesOrdersMatched(context.Background(), currency.EMPTYPAIR, "") + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + pair, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + _, err = g.WebsocketOrderCancelAllOpenFuturesOrdersMatched(context.Background(), pair, "bruh") + require.ErrorIs(t, err, errInvalidSide) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketOrderCancelAllOpenFuturesOrdersMatched(context.Background(), pair, "bid") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderAmendFutures(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketOrderAmendFutures(context.Background(), nil) + require.ErrorIs(t, err, common.ErrNilPointer) + + amend := &WebsocketFuturesAmendOrder{} + _, err = g.WebsocketOrderAmendFutures(context.Background(), amend) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + amend.OrderID = "1337" + _, err = g.WebsocketOrderAmendFutures(context.Background(), amend) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + amend.Contract, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderAmendFutures(context.Background(), amend) + require.ErrorIs(t, err, errInvalidAmount) + + amend.Size = 2 + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + amend.OrderID = "513170215869" + got, err := g.WebsocketOrderAmendFutures(context.Background(), amend) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderListFutures(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketOrderListFutures(context.Background(), nil) + require.ErrorIs(t, err, common.ErrNilPointer) + + list := &WebsocketFutureOrdersList{} + _, err = g.WebsocketOrderListFutures(context.Background(), list) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + list.Contract, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderListFutures(context.Background(), list) + require.ErrorIs(t, err, errStatusNotSet) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + list.Status = statusOpen + got, err := g.WebsocketOrderListFutures(context.Background(), list) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketGetOrderStatusFutures(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketGetOrderStatusFutures(context.Background(), currency.EMPTYPAIR, "") + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + pair, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketGetOrderStatusFutures(context.Background(), pair, "") + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketGetOrderStatusFutures(context.Background(), pair, "513170215869") + require.NoError(t, err) + require.NotEmpty(t, got) +} diff --git a/exchanges/gateio/gateio_websocket_request_spot.go b/exchanges/gateio/gateio_websocket_request_spot.go new file mode 100644 index 00000000000..11d3fe802ba --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_spot.go @@ -0,0 +1,290 @@ +package gateio + +import ( + "context" + "crypto/hmac" + "crypto/sha512" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "time" + + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/request" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" +) + +var ( + errBatchSliceEmpty = errors.New("batch cannot be empty") + errNoOrdersToCancel = errors.New("no orders to cancel") + errEdgeCaseIssue = errors.New("edge case issue") + errChannelEmpty = errors.New("channel cannot be empty") +) + +// authenticateSpot sends an authentication message to the websocket connection +func (g *Gateio) authenticateSpot(ctx context.Context, conn stream.Connection) error { + _, err := g.websocketLogin(ctx, conn, "spot.login") + return err +} + +// websocketLogin authenticates the websocket connection +func (g *Gateio) websocketLogin(ctx context.Context, conn stream.Connection, channel string) (*WebsocketLoginResponse, error) { + if conn == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, conn) + } + + if channel == "" { + return nil, errChannelEmpty + } + + creds, err := g.GetCredentials(ctx) + if err != nil { + return nil, err + } + + tn := time.Now().Unix() + msg := "api\n" + channel + "\n" + "\n" + strconv.FormatInt(tn, 10) + mac := hmac.New(sha512.New, []byte(creds.Secret)) + if _, err = mac.Write([]byte(msg)); err != nil { + return nil, err + } + signature := hex.EncodeToString(mac.Sum(nil)) + + payload := WebsocketPayload{ + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + APIKey: creds.Key, + Signature: signature, + Timestamp: strconv.FormatInt(tn, 10), + } + + req := WebsocketRequest{Time: tn, Channel: channel, Event: "api", Payload: payload} + + resp, err := conn.SendMessageReturnResponse(ctx, request.Unset, req.Payload.RequestID, req) + if err != nil { + return nil, err + } + + var inbound WebsocketAPIResponse + if err := json.Unmarshal(resp, &inbound); err != nil { + return nil, err + } + + if inbound.Header.Status != "200" { + var wsErr WebsocketErrors + if err := json.Unmarshal(inbound.Data, &wsErr.Errors); err != nil { + return nil, err + } + return nil, fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) + } + + var result WebsocketLoginResponse + return &result, json.Unmarshal(inbound.Data, &result) +} + +// WebsocketOrderPlaceSpot places an order via the websocket connection. You can +// send multiple orders in a single request. But only for one asset route. +// So this can only batch spot orders or futures orders, not both. +func (g *Gateio) WebsocketOrderPlaceSpot(ctx context.Context, batch []CreateOrderRequestData) ([]WebsocketOrderResponse, error) { + if len(batch) == 0 { + return nil, errBatchSliceEmpty + } + + for i := range batch { + if batch[i].Text == "" { + // For some reason the API requires a text field, or it will be + // rejected in the second response. This is a workaround. + batch[i].Text = "t-" + strconv.FormatInt(g.Counter.IncrementAndGet(), 10) + } + if batch[i].CurrencyPair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + if batch[i].Side == "" { + return nil, order.ErrSideIsInvalid + } + if batch[i].Amount == 0 { + return nil, errInvalidAmount + } + if batch[i].Type == "limit" && batch[i].Price == 0 { + return nil, errInvalidPrice + } + } + + if len(batch) == 1 { + var singleResponse WebsocketOrderResponse + err := g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch[0], &singleResponse, 2) + return []WebsocketOrderResponse{singleResponse}, err + } + + var resp []WebsocketOrderResponse + err := g.SendWebsocketRequest(ctx, "spot.order_place", asset.Spot, batch, &resp, 2) + return resp, err +} + +// WebsocketOrderCancelSpot cancels an order via the websocket connection +func (g *Gateio) WebsocketOrderCancelSpot(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + if pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} + + var resp WebsocketOrderResponse + err := g.SendWebsocketRequest(ctx, "spot.order_cancel", asset.Spot, params, &resp, 1) + return &resp, err +} + +// WebsocketOrderCancelAllByIDsSpot cancels multiple orders via the websocket +func (g *Gateio) WebsocketOrderCancelAllByIDsSpot(ctx context.Context, o []WebsocketOrderBatchRequest) ([]WebsocketCancellAllResponse, error) { + if len(o) == 0 { + return nil, errNoOrdersToCancel + } + + for i := range o { + if o[i].OrderID == "" { + return nil, order.ErrOrderIDNotSet + } + if o[i].Pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + } + + var resp []WebsocketCancellAllResponse + err := g.SendWebsocketRequest(ctx, "spot.order_cancel_ids", asset.Spot, o, &resp, 2) + return resp, err +} + +// WebsocketOrderCancelAllByPairSpot cancels all orders for a specific pair +func (g *Gateio) WebsocketOrderCancelAllByPairSpot(ctx context.Context, pair currency.Pair, side order.Side, account string) ([]WebsocketOrderResponse, error) { + if !pair.IsEmpty() && side == order.UnknownSide { + return nil, fmt.Errorf("%w: side cannot be unknown when pair is set as this will purge *ALL* open orders", errEdgeCaseIssue) + } + + sideStr := "" + if side != order.UnknownSide { + sideStr = side.Lower() + } + + params := &WebsocketCancelParam{ + Pair: pair, + Side: sideStr, + Account: account, + } + + var resp []WebsocketOrderResponse + return resp, g.SendWebsocketRequest(ctx, "spot.order_cancel_cp", asset.Spot, params, &resp, 1) +} + +// WebsocketOrderAmendSpot amends an order via the websocket connection +func (g *Gateio) WebsocketOrderAmendSpot(ctx context.Context, amend *WebsocketAmendOrder) (*WebsocketOrderResponse, error) { + if amend == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, amend) + } + + if amend.OrderID == "" { + return nil, order.ErrOrderIDNotSet + } + + if amend.Pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + if amend.Amount == "" && amend.Price == "" { + return nil, fmt.Errorf("%w: amount or price must be set", errInvalidAmount) + } + + var resp WebsocketOrderResponse + return &resp, g.SendWebsocketRequest(ctx, "spot.order_amend", asset.Spot, amend, &resp, 1) +} + +// WebsocketGetOrderStatusSpot gets the status of an order via the websocket connection +func (g *Gateio) WebsocketGetOrderStatusSpot(ctx context.Context, orderID string, pair currency.Pair, account string) (*WebsocketOrderResponse, error) { + if orderID == "" { + return nil, order.ErrOrderIDNotSet + } + if pair.IsEmpty() { + return nil, currency.ErrCurrencyPairEmpty + } + + params := &WebsocketOrderRequest{OrderID: orderID, Pair: pair.String(), Account: account} + + var resp WebsocketOrderResponse + return &resp, g.SendWebsocketRequest(ctx, "spot.order_status", asset.Spot, params, &resp, 1) +} + +// SendWebsocketRequest sends a websocket request to the exchange +func (g *Gateio) SendWebsocketRequest(ctx context.Context, channel string, connSignature, params, result any, expectedResponses int) error { + paramPayload, err := json.Marshal(params) + if err != nil { + return err + } + + conn, err := g.Websocket.GetConnection(connSignature) + if err != nil { + return err + } + + tn := time.Now().Unix() + req := &WebsocketRequest{ + Time: tn, + Channel: channel, + Event: "api", + Payload: WebsocketPayload{ + // This request ID associated with the payload is the match to the + // response. + RequestID: strconv.FormatInt(conn.GenerateMessageID(false), 10), + RequestParam: paramPayload, + Timestamp: strconv.FormatInt(tn, 10), + }, + } + + responses, err := conn.SendMessageReturnResponses(ctx, request.Unset, req.Payload.RequestID, req, expectedResponses, InspectPayloadForAck) + if err != nil { + return err + } + + if len(responses) == 0 { + return errors.New("no responses received") + } + + var inbound WebsocketAPIResponse + // The last response is the one we want to unmarshal, the other is just + // an ack. If the request fails on the ACK then we can unmarshal the error + // from that as the next response won't come anyway. + endResponse := responses[len(responses)-1] + + if err := json.Unmarshal(endResponse, &inbound); err != nil { + return err + } + + if inbound.Header.Status != "200" { + var wsErr WebsocketErrors + if err := json.Unmarshal(inbound.Data, &wsErr); err != nil { + return err + } + return fmt.Errorf("%s: %s", wsErr.Errors.Label, wsErr.Errors.Message) + } + + to := struct { + Result any `json:"result"` + }{ + Result: result, + } + + return json.Unmarshal(inbound.Data, &to) +} + +// InspectPayloadForAck checks the payload for an ack, it returns true if the +// payload does not contain an ack. This will force the cancellation of further +// waiting for responses. +func InspectPayloadForAck(data []byte) bool { + return !strings.Contains(string(data), "ack") +} diff --git a/exchanges/gateio/gateio_websocket_request_spot_test.go b/exchanges/gateio/gateio_websocket_request_spot_test.go new file mode 100644 index 00000000000..0021baa3a15 --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_spot_test.go @@ -0,0 +1,239 @@ +package gateio + +import ( + "context" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/thrasher-corp/gocryptotrader/common" + "github.com/thrasher-corp/gocryptotrader/config" + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/exchanges/asset" + "github.com/thrasher-corp/gocryptotrader/exchanges/order" + "github.com/thrasher-corp/gocryptotrader/exchanges/sharedtestvalues" + "github.com/thrasher-corp/gocryptotrader/exchanges/stream" + testexch "github.com/thrasher-corp/gocryptotrader/internal/testing/exchange" +) + +func TestWebsocketLogin(t *testing.T) { + t.Parallel() + _, err := g.websocketLogin(context.Background(), nil, "") + require.ErrorIs(t, err, common.ErrNilPointer) + + _, err = g.websocketLogin(context.Background(), &stream.WebsocketConnection{}, "") + require.ErrorIs(t, err, errChannelEmpty) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + demonstrationConn, err := g.Websocket.GetConnection(asset.Spot) + require.NoError(t, err) + + got, err := g.websocketLogin(context.Background(), demonstrationConn, "spot.login") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderPlaceSpot(t *testing.T) { + t.Parallel() + _, err := g.WebsocketOrderPlaceSpot(context.Background(), nil) + require.ErrorIs(t, err, errBatchSliceEmpty) + _, err = g.WebsocketOrderPlaceSpot(context.Background(), make([]CreateOrderRequestData, 1)) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + pair, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + out := CreateOrderRequestData{CurrencyPair: pair} + _, err = g.WebsocketOrderPlaceSpot(context.Background(), []CreateOrderRequestData{out}) + require.ErrorIs(t, err, order.ErrSideIsInvalid) + out.Side = strings.ToLower(order.Buy.String()) + _, err = g.WebsocketOrderPlaceSpot(context.Background(), []CreateOrderRequestData{out}) + require.ErrorIs(t, err, errInvalidAmount) + out.Amount = 0.0003 + out.Type = "limit" + _, err = g.WebsocketOrderPlaceSpot(context.Background(), []CreateOrderRequestData{out}) + require.ErrorIs(t, err, errInvalidPrice) + out.Price = 20000 + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + // test single order + got, err := g.WebsocketOrderPlaceSpot(context.Background(), []CreateOrderRequestData{out}) + require.NoError(t, err) + require.NotEmpty(t, got) + + // test batch orders + got, err = g.WebsocketOrderPlaceSpot(context.Background(), []CreateOrderRequestData{out, out}) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderCancelSpot(t *testing.T) { + t.Parallel() + _, err := g.WebsocketOrderCancelSpot(context.Background(), "", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + _, err = g.WebsocketOrderCancelSpot(context.Background(), "1337", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + btcusdt, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketOrderCancelSpot(context.Background(), "644913098758", btcusdt, "") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderCancelAllByIDsSpot(t *testing.T) { + t.Parallel() + out := WebsocketOrderBatchRequest{} + _, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + out.OrderID = "1337" + _, err = g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + out.Pair, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + out.OrderID = "644913101755" + got, err := g.WebsocketOrderCancelAllByIDsSpot(context.Background(), []WebsocketOrderBatchRequest{out}) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderCancelAllByPairSpot(t *testing.T) { + t.Parallel() + pair, err := currency.NewPairFromString("LTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderCancelAllByPairSpot(context.Background(), pair, 0, "") + require.ErrorIs(t, err, errEdgeCaseIssue) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + got, err := g.WebsocketOrderCancelAllByPairSpot(context.Background(), currency.EMPTYPAIR, order.Buy, "") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketOrderAmendSpot(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketOrderAmendSpot(context.Background(), nil) + require.ErrorIs(t, err, common.ErrNilPointer) + + amend := &WebsocketAmendOrder{} + _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + amend.OrderID = "1337" + _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + amend.Pair, err = currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + _, err = g.WebsocketOrderAmendSpot(context.Background(), amend) + require.ErrorIs(t, err, errInvalidAmount) + + amend.Amount = "0.0004" + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + amend.OrderID = "645029162673" + got, err := g.WebsocketOrderAmendSpot(context.Background(), amend) + require.NoError(t, err) + require.NotEmpty(t, got) +} + +func TestWebsocketGetOrderStatusSpot(t *testing.T) { + t.Parallel() + + _, err := g.WebsocketGetOrderStatusSpot(context.Background(), "", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, order.ErrOrderIDNotSet) + + _, err = g.WebsocketGetOrderStatusSpot(context.Background(), "1337", currency.EMPTYPAIR, "") + require.ErrorIs(t, err, currency.ErrCurrencyPairEmpty) + + sharedtestvalues.SkipTestIfCredentialsUnset(t, g, canManipulateRealOrders) + + testexch.UpdatePairsOnce(t, g) + g := getWebsocketInstance(t, g) //nolint:govet // Intentional shadow to avoid future copy/paste mistakes + + pair, err := currency.NewPairFromString("BTC_USDT") + require.NoError(t, err) + + got, err := g.WebsocketGetOrderStatusSpot(context.Background(), "644999650452", pair, "") + require.NoError(t, err) + require.NotEmpty(t, got) +} + +// getWebsocketInstance returns a websocket instance copy for testing. +// This restricts the pairs to a single pair per asset type to reduce test time. +func getWebsocketInstance(t *testing.T, g *Gateio) *Gateio { + t.Helper() + + cpy := new(Gateio) + cpy.SetDefaults() + gConf, err := config.GetConfig().GetExchangeConfig("GateIO") + require.NoError(t, err) + gConf.API.AuthenticatedSupport = true + gConf.API.AuthenticatedWebsocketSupport = true + gConf.API.Credentials.Key = apiKey + gConf.API.Credentials.Secret = apiSecret + + require.NoError(t, cpy.Setup(gConf), "Test instance Setup must not error") + cpy.CurrencyPairs.Load(&g.CurrencyPairs) + +assetLoader: + for _, a := range cpy.GetAssetTypes(true) { + var avail currency.Pairs + switch a { + case asset.Spot: + avail, err = cpy.GetAvailablePairs(a) + require.NoError(t, err) + if len(avail) > 1 { // reduce pairs to 1 to speed up tests + avail = avail[:1] + } + case asset.Futures: + avail, err = cpy.GetAvailablePairs(a) + require.NoError(t, err) + usdtPairs, err := avail.GetPairsByQuote(currency.USDT) // Get USDT margin pairs + require.NoError(t, err) + btcPairs, err := avail.GetPairsByQuote(currency.USD) // Get BTC margin pairs + require.NoError(t, err) + // below makes sure there is both a USDT and BTC pair available + // so that allows two connections to be made. + avail[0] = usdtPairs[0] + avail[1] = btcPairs[0] + avail = avail[:2] + default: + require.NoError(t, cpy.CurrencyPairs.SetAssetEnabled(a, false)) + continue assetLoader + } + require.NoError(t, cpy.SetPairs(avail, a, true)) + } + require.NoError(t, cpy.Websocket.Connect()) + return cpy +} diff --git a/exchanges/gateio/gateio_websocket_request_types.go b/exchanges/gateio/gateio_websocket_request_types.go new file mode 100644 index 00000000000..da71950fc83 --- /dev/null +++ b/exchanges/gateio/gateio_websocket_request_types.go @@ -0,0 +1,176 @@ +package gateio + +import ( + "encoding/json" + + "github.com/thrasher-corp/gocryptotrader/currency" + "github.com/thrasher-corp/gocryptotrader/types" +) + +// WebsocketAPIResponse defines a general websocket response for api calls +type WebsocketAPIResponse struct { + Header Header `json:"header"` + Data json.RawMessage `json:"data"` +} + +// Header defines a websocket header +type Header struct { + ResponseTime types.Time `json:"response_time"` + Status string `json:"status"` + Channel string `json:"channel"` + Event string `json:"event"` + ClientID string `json:"client_id"` + ConnectionID string `json:"conn_id"` + TraceID string `json:"trace_id"` +} + +// WebsocketRequest defines a websocket request +type WebsocketRequest struct { + Time int64 `json:"time,omitempty"` + ID int64 `json:"id,omitempty"` + Channel string `json:"channel"` + Event string `json:"event"` + Payload WebsocketPayload `json:"payload"` +} + +// WebsocketPayload defines an individualised websocket payload +type WebsocketPayload struct { + RequestID string `json:"req_id,omitempty"` + // APIKey and signature are only required in the initial login request + // which is done when the connection is established. + APIKey string `json:"api_key,omitempty"` + Timestamp string `json:"timestamp,omitempty"` + Signature string `json:"signature,omitempty"` + RequestParam json.RawMessage `json:"req_param,omitempty"` +} + +// WebsocketErrors defines a websocket error +type WebsocketErrors struct { + Errors struct { + Label string `json:"label"` + Message string `json:"message"` + } `json:"errs"` +} + +// WebsocketLoginResponse defines a websocket login response when authenticating +// the connection. +type WebsocketLoginResponse struct { + Result struct { + APIKey string `json:"api_key"` + UID string `json:"uid"` + } `json:"result"` +} + +// WebsocketOrderResponse defines a websocket order response +type WebsocketOrderResponse struct { + Left types.Number `json:"left"` + UpdateTime types.Time `json:"update_time"` + Amount types.Number `json:"amount"` + CreateTime types.Time `json:"create_time"` + Price types.Number `json:"price"` + FinishAs string `json:"finish_as"` + TimeInForce string `json:"time_in_force"` + CurrencyPair currency.Pair `json:"currency_pair"` + Type string `json:"type"` + Account string `json:"account"` + Side string `json:"side"` + AmendText string `json:"amend_text"` + Text string `json:"text"` + Status string `json:"status"` + Iceberg types.Number `json:"iceberg"` + FilledTotal types.Number `json:"filled_total"` + ID string `json:"id"` + FillPrice types.Number `json:"fill_price"` + UpdateTimeMs types.Time `json:"update_time_ms"` + CreateTimeMs types.Time `json:"create_time_ms"` + Fee types.Number `json:"fee"` + FeeCurrency currency.Code `json:"fee_currency"` + PointFee types.Number `json:"point_fee"` + GTFee types.Number `json:"gt_fee"` + GTMakerFee types.Number `json:"gt_maker_fee"` + GTTakerFee types.Number `json:"gt_taker_fee"` + GTDiscount bool `json:"gt_discount"` + RebatedFee types.Number `json:"rebated_fee"` + RebatedFeeCurrency currency.Code `json:"rebated_fee_currency"` + STPID int `json:"stp_id"` + STPAct string `json:"stp_act"` +} + +// WebsocketFuturesOrderResponse defines a websocket futures order response +type WebsocketFuturesOrderResponse struct { + Text string `json:"text"` + Price types.Number `json:"price"` + BizInfo string `json:"biz_info"` + TimeInForce string `json:"tif"` + AmendText string `json:"amend_text"` + Status string `json:"status"` + Contract currency.Pair `json:"contract"` + STPAct string `json:"stp_act"` + FinishAs string `json:"finish_as"` + FillPrice types.Number `json:"fill_price"` + ID int64 `json:"id"` + CreateTime types.Time `json:"create_time"` + UpdateTime types.Time `json:"update_time"` + FinishTime types.Time `json:"finish_time"` + Size int64 `json:"size"` + Left int64 `json:"left"` + User int64 `json:"user"` + Succeeded *bool `json:"succeeded"` // Nil if not present in returned response. +} + +// WebsocketOrderBatchRequest defines a websocket order batch request +type WebsocketOrderBatchRequest struct { + OrderID string `json:"id"` // This requires id tag not order_id + Pair currency.Pair `json:"currency_pair"` + Account string `json:"account,omitempty"` +} + +// WebsocketOrderRequest defines a websocket order request +type WebsocketOrderRequest struct { + OrderID string `json:"order_id"` // This requires order_id tag + Pair string `json:"pair"` + Account string `json:"account,omitempty"` +} + +// WebsocketCancellAllResponse defines a websocket order cancel response +type WebsocketCancellAllResponse struct { + Pair currency.Pair `json:"currency_pair"` + Label string `json:"label"` + Message string `json:"message"` + Succeeded bool `json:"succeeded"` +} + +// WebsocketCancelParam is a struct to hold the parameters for cancelling orders +type WebsocketCancelParam struct { + Pair currency.Pair `json:"pair"` + Side string `json:"side"` + Account string `json:"account,omitempty"` +} + +// WebsocketAmendOrder defines a websocket amend order +type WebsocketAmendOrder struct { + OrderID string `json:"order_id"` + Pair currency.Pair `json:"currency_pair"` + Account string `json:"account,omitempty"` + AmendText string `json:"amend_text,omitempty"` + Price string `json:"price,omitempty"` + Amount string `json:"amount,omitempty"` +} + +// WebsocketFuturesAmendOrder defines a websocket amend order +type WebsocketFuturesAmendOrder struct { + OrderID string `json:"order_id"` + Contract currency.Pair `json:"-"` // This is not required in the payload, it is used to determine the asset type. + AmendText string `json:"amend_text,omitempty"` + Price string `json:"price,omitempty"` + Size int64 `json:"size,omitempty"` +} + +// WebsocketFutureOrdersList defines a websocket future orders list +type WebsocketFutureOrdersList struct { + Contract currency.Pair `json:"contract,omitempty"` + Status string `json:"status"` + Limit int64 `json:"limit,omitempty"` + Offset int64 `json:"offset,omitempty"` + LastID string `json:"last_id,omitempty"` +} diff --git a/exchanges/gateio/gateio_wrapper.go b/exchanges/gateio/gateio_wrapper.go index 02d2b760fa6..6d7f0c1b385 100644 --- a/exchanges/gateio/gateio_wrapper.go +++ b/exchanges/gateio/gateio_wrapper.go @@ -40,6 +40,8 @@ import ( // this error. const unfundedFuturesAccount = `please transfer funds first to create futures account` +var errNoResponseReceived = errors.New("no response received") + // SetDefaults sets default values for the exchange func (g *Gateio) SetDefaults() { g.Name = "GateIO" @@ -209,16 +211,18 @@ func (g *Gateio) Setup(exch *config.Exchange) error { } // Spot connection err = g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: gateioWebsocketEndpoint, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleSpotData, - Subscriber: g.Subscribe, - Unsubscriber: g.Unsubscribe, - GenerateSubscriptions: g.generateSubscriptionsSpot, - Connector: g.WsConnectSpot, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: gateioWebsocketEndpoint, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleSpotData, + Subscriber: g.Subscribe, + Unsubscriber: g.Unsubscribe, + GenerateSubscriptions: g.generateSubscriptionsSpot, + Connector: g.WsConnectSpot, + Authenticate: g.authenticateSpot, + WrapperDefinedConnectionSignature: asset.Spot, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -232,11 +236,13 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, - Connector: g.WsFuturesConnect, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.USDT) }, + Connector: g.WsFuturesConnect, + Authenticate: g.authenticateFutures, + WrapperDefinedConnectionSignature: asset.USDTMarginedFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -251,11 +257,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.Futures) }, - Subscriber: g.FuturesSubscribe, - Unsubscriber: g.FuturesUnsubscribe, - GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, - Connector: g.WsFuturesConnect, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.FuturesSubscribe, + Unsubscriber: g.FuturesUnsubscribe, + GenerateSubscriptions: func() (subscription.List, error) { return g.GenerateFuturesDefaultSubscriptions(currency.BTC) }, + Connector: g.WsFuturesConnect, + WrapperDefinedConnectionSignature: asset.CoinMarginedFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -271,11 +278,12 @@ func (g *Gateio) Setup(exch *config.Exchange) error { Handler: func(ctx context.Context, incoming []byte) error { return g.WsHandleFuturesData(ctx, incoming, asset.DeliveryFutures) }, - Subscriber: g.DeliveryFuturesSubscribe, - Unsubscriber: g.DeliveryFuturesUnsubscribe, - GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, - Connector: g.WsDeliveryFuturesConnect, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + Subscriber: g.DeliveryFuturesSubscribe, + Unsubscriber: g.DeliveryFuturesUnsubscribe, + GenerateSubscriptions: g.GenerateDeliveryFuturesDefaultSubscriptions, + Connector: g.WsDeliveryFuturesConnect, + WrapperDefinedConnectionSignature: asset.DeliveryFutures, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) if err != nil { return err @@ -283,16 +291,17 @@ func (g *Gateio) Setup(exch *config.Exchange) error { // Futures connection - Options return g.Websocket.SetupNewConnection(&stream.ConnectionSetup{ - URL: optionsWebsocketURL, - RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), - ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, - ResponseMaxLimit: exch.WebsocketResponseMaxLimit, - Handler: g.WsHandleOptionsData, - Subscriber: g.OptionsSubscribe, - Unsubscriber: g.OptionsUnsubscribe, - GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, - Connector: g.WsOptionsConnect, - BespokeGenerateMessageID: g.GenerateWebsocketMessageID, + URL: optionsWebsocketURL, + RateLimit: request.NewWeightedRateLimitByDuration(gateioWebsocketRateLimit), + ResponseCheckTimeout: exch.WebsocketResponseCheckTimeout, + ResponseMaxLimit: exch.WebsocketResponseMaxLimit, + Handler: g.WsHandleOptionsData, + Subscriber: g.OptionsSubscribe, + Unsubscriber: g.OptionsUnsubscribe, + GenerateSubscriptions: g.GenerateOptionsDefaultSubscriptions, + Connector: g.WsOptionsConnect, + WrapperDefinedConnectionSignature: asset.Options, + BespokeGenerateMessageID: g.GenerateWebsocketMessageID, }) } @@ -1057,34 +1066,14 @@ func (g *Gateio) SubmitOrder(ctx context.Context, s *order.Submit) (*order.Submi return nil, err } s.Pair = s.Pair.Upper() + switch s.AssetType { case asset.Spot, asset.Margin, asset.CrossMargin: - switch { - case s.Side.IsLong(): - s.Side = order.Buy - case s.Side.IsShort(): - s.Side = order.Sell - default: - return nil, errInvalidOrderSide - } - timeInForce, err := getTimeInForce(s) + req, err := g.getSpotOrderRequest(s) if err != nil { return nil, err } - - sOrder, err := g.PlaceSpotOrder(ctx, &CreateOrderRequestData{ - Side: s.Side.Lower(), - Type: s.Type.Lower(), - Account: g.assetTypeToString(s.AssetType), - // When doing spot market orders when purchasing base currency, the - // quote currency amount is used. When selling the base currency the - // base currency amount is used. - Amount: types.Number(s.GetTradeAmount(g.GetTradingRequirements())), - Price: types.Number(s.Price), - CurrencyPair: s.Pair, - Text: s.ClientOrderID, - TimeInForce: timeInForce, - }) + sOrder, err := g.PlaceSpotOrder(ctx, req) if err != nil { return nil, err } @@ -1145,7 +1134,7 @@ func (g *Gateio) SubmitOrder(ctx context.Context, s *order.Submit) (*order.Submi return nil, err } var status = order.Open - if fOrder.Status != "open" { + if fOrder.Status != statusOpen { status, err = order.StringToOrderStatus(fOrder.FinishAs) if err != nil { return nil, err @@ -1192,7 +1181,7 @@ func (g *Gateio) SubmitOrder(ctx context.Context, s *order.Submit) (*order.Submi return nil, err } var status = order.Open - if newOrder.Status != "open" { + if newOrder.Status != statusOpen { status, err = order.StringToOrderStatus(newOrder.FinishAs) if err != nil { return nil, err @@ -1510,7 +1499,7 @@ func (g *Gateio) GetOrderInfo(ctx context.Context, orderID string, pair currency return nil, err } orderStatus := order.Open - if fOrder.Status != "open" { + if fOrder.Status != statusOpen { orderStatus, err = order.StringToOrderStatus(fOrder.FinishAs) if err != nil { return nil, err @@ -1669,7 +1658,7 @@ func (g *Gateio) GetActiveOrders(ctx context.Context, req *order.MultiOrderReque return nil, err } for y := range spotOrders[x].Orders { - if spotOrders[x].Orders[y].Status != "open" { + if spotOrders[x].Orders[y].Status != statusOpen { continue } var side order.Side @@ -1727,9 +1716,9 @@ func (g *Gateio) GetActiveOrders(ctx context.Context, req *order.MultiOrderReque for settlement := range settlements { var futuresOrders []Order if req.AssetType == asset.Futures { - futuresOrders, err = g.GetFuturesOrders(ctx, currency.EMPTYPAIR, "open", "", settlement, 0, 0, 0) + futuresOrders, err = g.GetFuturesOrders(ctx, currency.EMPTYPAIR, statusOpen, "", settlement, 0, 0, 0) } else { - futuresOrders, err = g.GetDeliveryOrders(ctx, currency.EMPTYPAIR, "open", settlement, "", 0, 0, 0) + futuresOrders, err = g.GetDeliveryOrders(ctx, currency.EMPTYPAIR, statusOpen, settlement, "", 0, 0, 0) } if err != nil { if strings.Contains(err.Error(), unfundedFuturesAccount) { @@ -1745,7 +1734,7 @@ func (g *Gateio) GetActiveOrders(ctx context.Context, req *order.MultiOrderReque return nil, err } - if futuresOrders[x].Status != "open" || (len(req.Pairs) > 0 && !req.Pairs.Contains(pair, true)) { + if futuresOrders[x].Status != statusOpen || (len(req.Pairs) > 0 && !req.Pairs.Contains(pair, true)) { continue } @@ -1775,7 +1764,7 @@ func (g *Gateio) GetActiveOrders(ctx context.Context, req *order.MultiOrderReque } case asset.Options: var optionsOrders []OptionOrderResponse - optionsOrders, err = g.GetOptionFuturesOrders(ctx, currency.EMPTYPAIR, "", "open", 0, 0, req.StartTime, req.EndTime) + optionsOrders, err = g.GetOptionFuturesOrders(ctx, currency.EMPTYPAIR, "", statusOpen, 0, 0, req.StartTime, req.EndTime) if err != nil { return nil, err } @@ -2616,3 +2605,139 @@ func (g *Gateio) GetCurrencyTradeURL(_ context.Context, a asset.Item, cp currenc return "", fmt.Errorf("%w %v", asset.ErrNotSupported, a) } } + +// WebsocketSubmitOrder submits an order to the exchange through the websocket +// connection. +func (g *Gateio) WebsocketSubmitOrder(ctx context.Context, s *order.Submit) (*order.SubmitResponse, error) { + err := s.Validate(g.GetTradingRequirements()) + if err != nil { + return nil, err + } + + s.Pair, err = g.FormatExchangeCurrency(s.Pair, s.AssetType) + if err != nil { + return nil, err + } + s.Pair = s.Pair.Upper() + + switch s.AssetType { + case asset.Spot: + var req *CreateOrderRequestData + req, err = g.getSpotOrderRequest(s) + if err != nil { + return nil, err + } + + var got []WebsocketOrderResponse + got, err = g.WebsocketOrderPlaceSpot(ctx, []CreateOrderRequestData{*req}) + if err != nil { + return nil, err + } + + if len(got) == 0 { + return nil, errNoResponseReceived + } + + var resp *order.SubmitResponse + resp, err = s.DeriveSubmitResponse(got[0].ID) + if err != nil { + return nil, err + } + resp.Side, err = order.StringToOrderSide(got[0].Side) + if err != nil { + return nil, err + } + resp.Status, err = order.StringToOrderStatus(got[0].Status) + if err != nil { + return nil, err + } + resp.Pair = s.Pair + resp.Date = got[0].CreateTime.Time() + resp.ClientOrderID = got[0].Text + resp.Date = got[0].CreateTimeMs.Time() + resp.LastUpdated = got[0].UpdateTimeMs.Time() + return resp, nil + case asset.Futures: + var amountWithDirection float64 + amountWithDirection, err = getFutureOrderSize(s) + if err != nil { + return nil, err + } + + var timeInForce string + timeInForce, err = getTimeInForce(s) + if err != nil { + return nil, err + } + + out := OrderCreateParams{ + Contract: s.Pair, + Size: amountWithDirection, + Price: strconv.FormatFloat(s.Price, 'f', -1, 64), + ReduceOnly: s.ReduceOnly, + TimeInForce: timeInForce, + Text: s.ClientOrderID, + } + + var got []WebsocketFuturesOrderResponse + got, err = g.WebsocketOrderPlaceFutures(ctx, []OrderCreateParams{out}) + if err != nil { + return nil, err + } + + if len(got) == 0 { + return nil, errNoResponseReceived + } + + resp, err := s.DeriveSubmitResponse(strconv.FormatInt(got[0].ID, 10)) + if err != nil { + return nil, err + } + resp.Status = order.Open + if got[0].Status != statusOpen { + resp.Status, err = order.StringToOrderStatus(got[0].FinishAs) + if err != nil { + return nil, err + } + } + resp.Pair = s.Pair + resp.Date = got[0].CreateTime.Time() + resp.ClientOrderID = getClientOrderIDFromText(got[0].Text) + resp.ReduceOnly = s.ReduceOnly + resp.Amount = math.Abs(float64(got[0].Size)) + resp.Price = got[0].FillPrice.Float64() + resp.AverageExecutedPrice = got[0].FillPrice.Float64() + return resp, nil + default: + return nil, common.ErrNotYetImplemented + } +} + +func (g *Gateio) getSpotOrderRequest(s *order.Submit) (*CreateOrderRequestData, error) { + switch { + case s.Side.IsLong(): + s.Side = order.Buy + case s.Side.IsShort(): + s.Side = order.Sell + default: + return nil, errInvalidOrderSide + } + + timeInForce, err := getTimeInForce(s) + if err != nil { + return nil, err + } + + return &CreateOrderRequestData{ + Side: s.Side.Lower(), + Type: s.Type.Lower(), + Account: g.assetTypeToString(s.AssetType), + // When doing spot market orders when purchasing base currency, the quote currency amount is used. When selling + // the base currency the base currency amount is used. + Amount: types.Number(s.GetTradeAmount(g.GetTradingRequirements())), + Price: types.Number(s.Price), + CurrencyPair: s.Pair, + Text: s.ClientOrderID, + TimeInForce: timeInForce, + }, nil +} diff --git a/exchanges/gateio/gateio_ws_futures.go b/exchanges/gateio/gateio_ws_futures.go index 6e54cbf3c01..e20e468a9ec 100644 --- a/exchanges/gateio/gateio_ws_futures.go +++ b/exchanges/gateio/gateio_ws_futures.go @@ -165,11 +165,17 @@ func (g *Gateio) FuturesUnsubscribe(ctx context.Context, conn stream.Connection, // WsHandleFuturesData handles futures websocket data func (g *Gateio) WsHandleFuturesData(_ context.Context, respRaw []byte, a asset.Item) error { var push WsResponse - err := json.Unmarshal(respRaw, &push) - if err != nil { + if err := json.Unmarshal(respRaw, &push); err != nil { return err } + if push.RequestID != "" { + if !g.Websocket.Match.IncomingWithData(push.RequestID, respRaw) { + return fmt.Errorf("gateio_websocket.go error - unable to match requestID %v", push.RequestID) + } + return nil + } + if push.Event == subscribeEvent || push.Event == unsubscribeEvent { if !g.Websocket.Match.IncomingWithData(push.ID, respRaw) { return fmt.Errorf("couldn't match subscription message with ID: %d", push.ID) @@ -191,8 +197,7 @@ func (g *Gateio) WsHandleFuturesData(_ context.Context, respRaw []byte, a asset. case futuresCandlesticksChannel: return g.processFuturesCandlesticks(respRaw, a) case futuresOrdersChannel: - var processed []order.Detail - processed, err = g.processFuturesOrdersPushData(respRaw, a) + processed, err := g.processFuturesOrdersPushData(respRaw, a) if err != nil { return err } diff --git a/exchanges/interfaces.go b/exchanges/interfaces.go index 0f6071a864d..8f30c563f70 100644 --- a/exchanges/interfaces.go +++ b/exchanges/interfaces.go @@ -133,6 +133,9 @@ type OrderManagement interface { GetOrderInfo(ctx context.Context, orderID string, pair currency.Pair, assetType asset.Item) (*order.Detail, error) GetActiveOrders(ctx context.Context, getOrdersRequest *order.MultiOrderRequest) (order.FilteredOrders, error) GetOrderHistory(ctx context.Context, getOrdersRequest *order.MultiOrderRequest) (order.FilteredOrders, error) + + // WebsocketSubmitOrder submits an order via the websocket connection + WebsocketSubmitOrder(ctx context.Context, s *order.Submit) (*order.SubmitResponse, error) } // CurrencyStateManagement defines functionality for currency state management diff --git a/exchanges/request/request.go b/exchanges/request/request.go index 8bc1690ce83..25a6cbba0a5 100644 --- a/exchanges/request/request.go +++ b/exchanges/request/request.go @@ -383,10 +383,8 @@ func WithVerbose(ctx context.Context) context.Context { // IsVerbose checks main verbosity first then checks context verbose values // for specific request verbosity. func IsVerbose(ctx context.Context, verbose bool) bool { - if verbose { - return true + if !verbose { + verbose, _ = ctx.Value(contextVerboseFlag).(bool) } - - isCtxVerbose, _ := ctx.Value(contextVerboseFlag).(bool) - return isCtxVerbose + return verbose } diff --git a/exchanges/request/request_test.go b/exchanges/request/request_test.go index 5793ef66324..ef2bbc583df 100644 --- a/exchanges/request/request_test.go +++ b/exchanges/request/request_test.go @@ -698,7 +698,7 @@ func TestGetHTTPClientUserAgent(t *testing.T) { } } -func TestContextVerbosity(t *testing.T) { +func TestIsVerbose(t *testing.T) { t.Parallel() require.False(t, IsVerbose(context.Background(), false)) require.True(t, IsVerbose(context.Background(), true)) diff --git a/exchanges/stream/stream_types.go b/exchanges/stream/stream_types.go index 2cbf0a2fe04..630982f195c 100644 --- a/exchanges/stream/stream_types.go +++ b/exchanges/stream/stream_types.go @@ -26,7 +26,7 @@ type Connection interface { // SendMessageReturnResponse will send a WS message to the connection and wait for response SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature any, request any) ([]byte, error) // SendMessageReturnResponses will send a WS message to the connection and wait for N responses - SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int) ([][]byte, error) + SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature any, request any, expected int, messageInspector ...Inspector) ([][]byte, error) // SendRawMessage sends a message over the connection without JSON encoding it SendRawMessage(ctx context.Context, epl request.EndpointLimit, messageType int, message []byte) error // SendJSONMessage sends a JSON encoded message over the connection @@ -37,6 +37,9 @@ type Connection interface { Shutdown() error } +// Inspector is a hook that allows for custom message inspection +type Inspector func([]byte) bool + // Response defines generalised data from the stream connection type Response struct { Type int @@ -76,6 +79,15 @@ type ConnectionSetup struct { // This is useful for when an exchange connection requires a unique or // structured message ID for each message sent. BespokeGenerateMessageID func(highPrecision bool) int64 + // Authenticate is a function that will be called to authenticate the + // connection to the exchange's websocket server. This function should + // handle the authentication process and return an error if the + // authentication fails. + Authenticate func(ctx context.Context, conn Connection) error + // WrapperDefinedConnectionSignature is any type that will match to a specific connection. This could be an asset + // type `asset.Spot`, a string type denoting the individual URL, an authenticated or unauthenticated string or a + // mixture of these. + WrapperDefinedConnectionSignature any } // ConnectionWrapper contains the connection setup details to be used when diff --git a/exchanges/stream/websocket.go b/exchanges/stream/websocket.go index 309db9a79d4..f90e4b56521 100644 --- a/exchanges/stream/websocket.go +++ b/exchanges/stream/websocket.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/url" + "reflect" "slices" "sync" "time" @@ -29,41 +30,47 @@ var ( ErrNotConnected = errors.New("websocket is not connected") ErrNoMessageListener = errors.New("websocket listener not found for message") ErrSignatureTimeout = errors.New("websocket timeout waiting for response with signature") + ErrRequestRouteNotFound = errors.New("request route not found") + ErrSignatureNotSet = errors.New("signature not set") + ErrRequestPayloadNotSet = errors.New("request payload not set") ) // Private websocket errors var ( - errExchangeConfigIsNil = errors.New("exchange config is nil") - errWebsocketIsNil = errors.New("websocket is nil") - errWebsocketSetupIsNil = errors.New("websocket setup is nil") - errWebsocketAlreadyInitialised = errors.New("websocket already initialised") - errWebsocketAlreadyEnabled = errors.New("websocket already enabled") - errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") - errConfigFeaturesIsNil = errors.New("exchange config features is nil") - errDefaultURLIsEmpty = errors.New("default url is empty") - errRunningURLIsEmpty = errors.New("running url cannot be empty") - errInvalidWebsocketURL = errors.New("invalid websocket url") - errExchangeConfigNameEmpty = errors.New("exchange config name empty") - errInvalidTrafficTimeout = errors.New("invalid traffic timeout") - errTrafficAlertNil = errors.New("traffic alert is nil") - errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") - errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") - errWebsocketConnectorUnset = errors.New("websocket connector function not set") - errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") - errReadMessageErrorsNil = errors.New("read message errors is nil") - errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") - errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") - errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") - errSameProxyAddress = errors.New("cannot set proxy address to the same address") - errNoConnectFunc = errors.New("websocket connect func not set") - errAlreadyConnected = errors.New("websocket already connected") - errCannotShutdown = errors.New("websocket cannot shutdown") - errAlreadyReconnecting = errors.New("websocket in the process of reconnection") - errConnSetup = errors.New("error in connection setup") - errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") - errConnectionWrapperDuplication = errors.New("connection wrapper duplication") - errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") - errExchangeConfigEmpty = errors.New("exchange config is empty") + errExchangeConfigIsNil = errors.New("exchange config is nil") + errWebsocketIsNil = errors.New("websocket is nil") + errWebsocketSetupIsNil = errors.New("websocket setup is nil") + errWebsocketAlreadyInitialised = errors.New("websocket already initialised") + errWebsocketAlreadyEnabled = errors.New("websocket already enabled") + errWebsocketFeaturesIsUnset = errors.New("websocket features is unset") + errConfigFeaturesIsNil = errors.New("exchange config features is nil") + errDefaultURLIsEmpty = errors.New("default url is empty") + errRunningURLIsEmpty = errors.New("running url cannot be empty") + errInvalidWebsocketURL = errors.New("invalid websocket url") + errExchangeConfigNameEmpty = errors.New("exchange config name empty") + errInvalidTrafficTimeout = errors.New("invalid traffic timeout") + errTrafficAlertNil = errors.New("traffic alert is nil") + errWebsocketSubscriberUnset = errors.New("websocket subscriber function needs to be set") + errWebsocketUnsubscriberUnset = errors.New("websocket unsubscriber functionality allowed but unsubscriber function not set") + errWebsocketConnectorUnset = errors.New("websocket connector function not set") + errWebsocketDataHandlerUnset = errors.New("websocket data handler not set") + errReadMessageErrorsNil = errors.New("read message errors is nil") + errWebsocketSubscriptionsGeneratorUnset = errors.New("websocket subscriptions generator function needs to be set") + errSubscriptionsExceedsLimit = errors.New("subscriptions exceeds limit") + errInvalidMaxSubscriptions = errors.New("max subscriptions cannot be less than 0") + errSameProxyAddress = errors.New("cannot set proxy address to the same address") + errNoConnectFunc = errors.New("websocket connect func not set") + errAlreadyConnected = errors.New("websocket already connected") + errCannotShutdown = errors.New("websocket cannot shutdown") + errAlreadyReconnecting = errors.New("websocket in the process of reconnection") + errConnSetup = errors.New("error in connection setup") + errNoPendingConnections = errors.New("no pending connections, call SetupNewConnection first") + errConnectionWrapperDuplication = errors.New("connection wrapper duplication") + errCannotChangeConnectionURL = errors.New("cannot change connection URL when using multi connection management") + errExchangeConfigEmpty = errors.New("exchange config is empty") + errCannotObtainOutboundConnection = errors.New("cannot obtain outbound connection") + errConnectionSignatureNotSet = errors.New("connection signature not set") + errWrapperDefinedConnectionSignatureNotComparable = errors.New("wrapper defined connection signature is not comparable") ) var globalReporter Reporter @@ -85,12 +92,12 @@ func NewWebsocket() *Websocket { // after subscriptions are made but before the connectionMonitor has // started. This allows the error to be read and handled in the // connectionMonitor and start a connection cycle again. - ReadMessageErrors: make(chan error, 1), - Match: NewMatch(), - subscriptions: subscription.NewStore(), - features: &protocol.Features{}, - Orderbook: buffer.Orderbook{}, - connections: make(map[Connection]*ConnectionWrapper), + ReadMessageErrors: make(chan error, 1), + Match: NewMatch(), + subscriptions: subscription.NewStore(), + features: &protocol.Features{}, + Orderbook: buffer.Orderbook{}, + connectionToWrapper: make(map[Connection]*ConnectionWrapper), } } @@ -259,13 +266,19 @@ func (w *Websocket) SetupNewConnection(c *ConnectionSetup) error { return fmt.Errorf("%w: %w", errConnSetup, errWebsocketDataHandlerUnset) } + if c.WrapperDefinedConnectionSignature != nil && !reflect.TypeOf(c.WrapperDefinedConnectionSignature).Comparable() { + return errWrapperDefinedConnectionSignatureNotComparable + } + for x := range w.connectionManager { - if w.connectionManager[x].Setup.URL == c.URL { + // Below allows for multiple connections to the same URL with different outbound request signatures. This + // allows for easier determination of inbound and outbound messages. e.g. Gateio cross_margin, margin on + // a spot connection. + if w.connectionManager[x].Setup.URL == c.URL && c.WrapperDefinedConnectionSignature == w.connectionManager[x].Setup.WrapperDefinedConnectionSignature { return fmt.Errorf("%w: %w", errConnSetup, errConnectionWrapperDuplication) } } - - w.connectionManager = append(w.connectionManager, ConnectionWrapper{ + w.connectionManager = append(w.connectionManager, &ConnectionWrapper{ Setup: c, Subscriptions: subscription.NewStore(), }) @@ -422,12 +435,21 @@ func (w *Websocket) connect() error { break } - w.connections[conn] = &w.connectionManager[i] + w.connectionToWrapper[conn] = w.connectionManager[i] w.connectionManager[i].Connection = conn w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[i].Setup.Handler) + if w.connectionManager[i].Setup.Authenticate != nil && w.CanUseAuthenticatedEndpoints() { + err = w.connectionManager[i].Setup.Authenticate(context.TODO(), conn) + if err != nil { + // Opted to not fail entirely here for POC. This should be + // revisited and handled more gracefully. + log.Errorf(log.WebsocketMgr, "%s websocket: [conn:%d] [URL:%s] failed to authenticate %v", w.exchangeName, i+1, conn.URL, err) + } + } + err = w.connectionManager[i].Setup.Subscriber(context.TODO(), conn, subs) if err != nil { multiConnectFatalError = fmt.Errorf("%v Error subscribing %w", w.exchangeName, err) @@ -454,7 +476,7 @@ func (w *Websocket) connect() error { } w.connectionManager[x].Subscriptions.Clear() } - clear(w.connections) + clear(w.connectionToWrapper) w.setState(disconnectedState) // Flip from connecting to disconnected. // Drain residual error in the single buffered channel, this mitigates @@ -542,7 +564,7 @@ func (w *Websocket) shutdown() error { } } // Clean map of old connections - clear(w.connections) + clear(w.connectionToWrapper) if w.Conn != nil { if err := w.Conn.Shutdown(); err != nil { @@ -633,7 +655,7 @@ func (w *Websocket) FlushChannels() error { } w.Wg.Add(1) go w.Reader(context.TODO(), conn, w.connectionManager[x].Setup.Handler) - w.connections[conn] = &w.connectionManager[x] + w.connectionToWrapper[conn] = w.connectionManager[x] w.connectionManager[x].Connection = conn } @@ -652,7 +674,7 @@ func (w *Websocket) FlushChannels() error { // If there are no subscriptions to subscribe to, close the connection as it is no longer needed. if w.connectionManager[x].Subscriptions.Len() == 0 { - delete(w.connections, w.connectionManager[x].Connection) // Remove from lookup map + delete(w.connectionToWrapper, w.connectionManager[x].Connection) // Remove from lookup map if err := w.connectionManager[x].Connection.Shutdown(); err != nil { log.Warnf(log.WebsocketMgr, "%v websocket: failed to shutdown connection: %v", w.exchangeName, err) } @@ -813,7 +835,7 @@ func (w *Websocket) GetName() string { // and the new subscription list when pairs are disabled or enabled. func (w *Websocket) GetChannelDifference(conn Connection, newSubs subscription.List) (sub, unsub subscription.List) { var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -829,7 +851,7 @@ func (w *Websocket) UnsubscribeChannels(conn Connection, channels subscription.L if len(channels) == 0 { return nil // No channels to unsubscribe from is not an error } - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { return w.unsubscribe(wrapper.Subscriptions, channels, func(channels subscription.List) error { return wrapper.Setup.Unsubscriber(context.TODO(), conn, channels) }) @@ -875,7 +897,7 @@ func (w *Websocket) SubscribeToChannels(conn Connection, subs subscription.List) return err } - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { return wrapper.Setup.Subscriber(context.TODO(), conn, subs) } @@ -896,7 +918,7 @@ func (w *Websocket) AddSubscriptions(conn Connection, subs ...*subscription.Subs return fmt.Errorf("%w: AddSubscriptions called on nil Websocket", common.ErrNilPointer) } var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -926,7 +948,7 @@ func (w *Websocket) AddSuccessfulSubscriptions(conn Connection, subs ...*subscri } var subscriptionStore **subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = &wrapper.Subscriptions } else { subscriptionStore = &w.subscriptions @@ -955,7 +977,7 @@ func (w *Websocket) RemoveSubscriptions(conn Connection, subs ...*subscription.S } var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions @@ -1042,7 +1064,7 @@ func checkWebsocketURL(s string) error { // The subscription state is not considered when counting existing subscriptions func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) error { var subscriptionStore *subscription.Store - if wrapper, ok := w.connections[conn]; ok && conn != nil { + if wrapper, ok := w.connectionToWrapper[conn]; ok && conn != nil { subscriptionStore = wrapper.Subscriptions } else { subscriptionStore = w.subscriptions @@ -1064,7 +1086,7 @@ func (w *Websocket) checkSubscriptions(conn Connection, subs subscription.List) if s.State() == subscription.ResubscribingState { continue } - if found := w.subscriptions.Get(s); found != nil { + if found := subscriptionStore.Get(s); found != nil { return fmt.Errorf("%w: %s", subscription.ErrDuplicate, s) } } @@ -1081,7 +1103,7 @@ func (w *Websocket) Reader(ctx context.Context, conn Connection, handler func(ct return // Connection has been closed } if err := handler(ctx, resp.Raw); err != nil { - w.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w", conn.GetURL(), err) + w.DataHandler <- fmt.Errorf("connection URL:[%v] error: %w for %s", conn.GetURL(), err, resp.Raw) } } } @@ -1241,3 +1263,38 @@ func signalReceived(ch chan struct{}) bool { return false } } + +// GetConnection returns a connection by connection signature (defined in wrapper setup) for request and response +// handling in a multi connection context. +func (w *Websocket) GetConnection(connSignature any) (Connection, error) { + if w == nil { + return nil, fmt.Errorf("%w: %T", common.ErrNilPointer, w) + } + + if connSignature == nil { + return nil, errConnectionSignatureNotSet + } + + w.m.Lock() + defer w.m.Unlock() + + if !w.IsConnected() { + return nil, ErrNotConnected + } + + if !w.useMultiConnectionManagement { + return nil, fmt.Errorf("%s: multi connection management not enabled %w please use exported Conn and AuthConn fields", w.exchangeName, errCannotObtainOutboundConnection) + } + + // Opted to range and not have a map, as connection level wrappers will be limited. + for _, wrapper := range w.connectionManager { + if wrapper.Setup.WrapperDefinedConnectionSignature == connSignature { + if wrapper.Connection == nil { + return nil, fmt.Errorf("%s: %s %w: %v", w.exchangeName, wrapper.Setup.URL, ErrNotConnected, connSignature) + } + return wrapper.Connection, nil + } + } + + return nil, fmt.Errorf("%s: %w: %v", w.exchangeName, ErrRequestRouteNotFound, connSignature) +} diff --git a/exchanges/stream/websocket_connection.go b/exchanges/stream/websocket_connection.go index 1f6f5e6019a..f052ee74f41 100644 --- a/exchanges/stream/websocket_connection.go +++ b/exchanges/stream/websocket_connection.go @@ -293,8 +293,8 @@ func (w *WebsocketConnection) GetURL() string { } // SendMessageReturnResponse will send a WS message to the connection and wait for response -func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, request any) ([]byte, error) { - resps, err := w.SendMessageReturnResponses(ctx, epl, signature, request, 1) +func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl request.EndpointLimit, signature, payload any) ([]byte, error) { + resps, err := w.SendMessageReturnResponses(ctx, epl, signature, payload, 1) if err != nil { return nil, err } @@ -303,7 +303,7 @@ func (w *WebsocketConnection) SendMessageReturnResponse(ctx context.Context, epl // SendMessageReturnResponses will send a WS message to the connection and wait for N responses // An error of ErrSignatureTimeout can be ignored if individual responses are being otherwise tracked -func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int) ([][]byte, error) { +func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, epl request.EndpointLimit, signature, payload any, expected int, messageInspector ...Inspector) ([][]byte, error) { outbound, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("error marshaling json for %s: %w", signature, err) @@ -320,28 +320,42 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep return nil, err } + resps, err := w.waitForResponses(ctx, signature, ch, expected, messageInspector...) + if err != nil { + return nil, err + } + + if w.Reporter != nil { + w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) + } + + return resps, err +} + +// waitForResponses waits for N responses from a channel +func (w *WebsocketConnection) waitForResponses(ctx context.Context, signature any, ch <-chan []byte, expected int, messageInspector ...Inspector) ([][]byte, error) { timeout := time.NewTimer(w.ResponseMaxLimit * time.Duration(expected)) + defer timeout.Stop() resps := make([][]byte, 0, expected) - for err == nil && len(resps) < expected { + for range expected { select { case resp := <-ch: resps = append(resps, resp) + // Checks recently received message to determine if this is in fact the final message in a sequence of messages. + if len(messageInspector) == 1 && messageInspector[0](resp) { + w.Match.RemoveSignature(signature) + return resps, nil + } case <-timeout.C: w.Match.RemoveSignature(signature) - err = fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) + return nil, fmt.Errorf("%s %w %v", w.ExchangeName, ErrSignatureTimeout, signature) case <-ctx.Done(): w.Match.RemoveSignature(signature) - err = ctx.Err() + return nil, ctx.Err() } } - timeout.Stop() - - if err == nil && w.Reporter != nil { - w.Reporter.Latency(w.ExchangeName, outbound, time.Since(start)) - } - // Only check context verbosity. If the exchange is verbose, it will log the responses in the ReadMessage() call. if request.IsVerbose(ctx, false) { for i := range resps { @@ -349,7 +363,7 @@ func (w *WebsocketConnection) SendMessageReturnResponses(ctx context.Context, ep } } - return resps, err + return resps, nil } func removeURLQueryString(url string) string { diff --git a/exchanges/stream/websocket_test.go b/exchanges/stream/websocket_test.go index 2904bccadee..290899b8d9d 100644 --- a/exchanges/stream/websocket_test.go +++ b/exchanges/stream/websocket_test.go @@ -223,13 +223,17 @@ func TestConnectionMessageErrors(t *testing.T) { assert.ErrorIs(t, err, errNoPendingConnections, "Connect should error correctly") ws.useMultiConnectionManagement = true + ws.SetCanUseAuthenticatedEndpoints(true) + ws.verbose = true // NOTE: Intentional mock := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockws.WsMockUpgrader(t, w, r, mockws.EchoHandler) })) defer mock.Close() - ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} + ws.connectionManager = []*ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws" + mock.URL[len("http"):] + "/ws"}}} err = ws.Connect() require.ErrorIs(t, err, errWebsocketSubscriptionsGeneratorUnset) + ws.connectionManager[0].Setup.Authenticate = func(context.Context, Connection) error { return errDastardlyReason } + ws.connectionManager[0].Setup.GenerateSubscriptions = func() (subscription.List, error) { return nil, errDastardlyReason } @@ -371,7 +375,7 @@ func TestWebsocket(t *testing.T) { ws.useMultiConnectionManagement = true - ws.connectionManager = []ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, Connection: &WebsocketConnection{}}} + ws.connectionManager = []*ConnectionWrapper{{Setup: &ConnectionSetup{URL: "ws://demos.kaazing.com/echo"}, Connection: &WebsocketConnection{}}} err = ws.SetProxyAddress("https://192.168.0.1:1337") require.NoError(t, err) } @@ -463,8 +467,8 @@ func TestSubscribeUnsubscribe(t *testing.T) { require.NoError(t, multi.SetupNewConnection(amazingCandidate)) amazingConn := multi.getConnectionFromSetup(amazingCandidate) - multi.connections = map[Connection]*ConnectionWrapper{ - amazingConn: &multi.connectionManager[0], + multi.connectionToWrapper = map[Connection]*ConnectionWrapper{ + amazingConn: multi.connectionManager[0], } subs, err = amazingCandidate.GenerateSubscriptions() @@ -975,7 +979,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - w.connections = map[Connection]*ConnectionWrapper{ + w.connectionToWrapper = map[Connection]*ConnectionWrapper{ sweetConn: {Setup: &ConnectionSetup{URL: "ws://localhost:8080/ws"}}, } @@ -988,7 +992,7 @@ func TestGetChannelDifference(t *testing.T) { require.Equal(t, 1, len(subs)) require.Empty(t, unsubs, "Should get no unsubs") - err := w.connections[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) + err := w.connectionToWrapper[sweetConn].Subscriptions.Add(&subscription.Subscription{Channel: subscription.CandlesChannel}) require.NoError(t, err) subs, unsubs = w.GetChannelDifference(sweetConn, subscription.List{{Channel: subscription.CandlesChannel}}) @@ -1229,6 +1233,11 @@ func TestSetupNewConnection(t *testing.T) { require.ErrorIs(t, err, errWebsocketDataHandlerUnset) connSetup.Handler = func(context.Context, []byte) error { return nil } + connSetup.WrapperDefinedConnectionSignature = []string{"slices are super naughty and not comparable"} + err = multi.SetupNewConnection(connSetup) + require.ErrorIs(t, err, errWrapperDefinedConnectionSignatureNotComparable) + + connSetup.WrapperDefinedConnectionSignature = "comparable string signature" err = multi.SetupNewConnection(connSetup) require.NoError(t, err) @@ -1484,3 +1493,40 @@ func TestMonitorTraffic(t *testing.T) { ws.TrafficAlert <- struct{}{} require.False(t, innerShell()) } + +func TestGetConnection(t *testing.T) { + t.Parallel() + var ws *Websocket + _, err := ws.GetConnection(nil) + require.ErrorIs(t, err, common.ErrNilPointer) + + ws = &Websocket{} + + _, err = ws.GetConnection(nil) + require.ErrorIs(t, err, errConnectionSignatureNotSet) + + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, ErrNotConnected) + + ws.setState(connectedState) + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, errCannotObtainOutboundConnection) + + ws.useMultiConnectionManagement = true + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, ErrRequestRouteNotFound) + + ws.connectionManager = []*ConnectionWrapper{{ + Setup: &ConnectionSetup{WrapperDefinedConnectionSignature: "testURL", URL: "testURL"}, + }} + + _, err = ws.GetConnection("testURL") + require.ErrorIs(t, err, ErrNotConnected) + + expected := &WebsocketConnection{} + ws.connectionManager[0].Connection = expected + + conn, err := ws.GetConnection("testURL") + require.NoError(t, err) + assert.Same(t, expected, conn) +} diff --git a/exchanges/stream/websocket_types.go b/exchanges/stream/websocket_types.go index 27a5c81963f..5152ff2c342 100644 --- a/exchanges/stream/websocket_types.go +++ b/exchanges/stream/websocket_types.go @@ -54,9 +54,9 @@ type Websocket struct { // For example, separate connections can be used for Spot, Margin, and Futures trading. This structure is especially useful // for exchanges that differentiate between trading pairs by using different connection endpoints or protocols for various asset classes. // If an exchange does not require such differentiation, all connections may be managed under a single ConnectionWrapper. - connectionManager []ConnectionWrapper - // connections holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder - connections map[Connection]*ConnectionWrapper + connectionManager []*ConnectionWrapper + // connectionToWrapper holds a look up table for all connections to their corresponding ConnectionWrapper and subscription holder + connectionToWrapper map[Connection]*ConnectionWrapper subscriptions *subscription.Store diff --git a/types/time.go b/types/time.go index f42aeca95e5..f10668b1073 100644 --- a/types/time.go +++ b/types/time.go @@ -28,7 +28,20 @@ func (t *Time) UnmarshalJSON(data []byte) error { s = s[1 : len(s)-1] } - if target := strings.Index(s, "."); target != -1 { + badSyntax := false + target := strings.IndexFunc(s, func(r rune) bool { + if r == '.' { + return true + } + // As a mistake this type can be used instead of time.Time and parse int below will not fail. + badSyntax = r < '0' || r > '9' + return badSyntax // exit early + }) + + if target != -1 { + if badSyntax { + return strconv.ErrSyntax + } s = s[:target] + s[target+1:] } diff --git a/types/time_test.go b/types/time_test.go index 97e529294d4..a7c95459f12 100644 --- a/types/time_test.go +++ b/types/time_test.go @@ -56,17 +56,19 @@ func TestUnmarshalJSON(t *testing.T) { require.ErrorIs(t, json.Unmarshal([]byte(`"blurp"`), &testTime), strconv.ErrSyntax) require.Error(t, json.Unmarshal([]byte(`"123456"`), &testTime)) + + // Captures bad syntax when type should be time.Time (RFC3339) + require.ErrorIs(t, json.Unmarshal([]byte(`"2025-03-28T08:00:00Z"`), &testTime), strconv.ErrSyntax) + require.Error(t, json.Unmarshal([]byte(`"123456"`), &testTime)) } -// 5046307 216.0 ns/op 168 B/op 2 allocs/op (current) +// 5030734 240.1 ns/op 168 B/op 2 allocs/op (current) // 2716176 441.9 ns/op 352 B/op 6 allocs/op (previous) func BenchmarkUnmarshalJSON(b *testing.B) { var testTime Time for i := 0; i < b.N; i++ { err := json.Unmarshal([]byte(`"1691122380942.173000"`), &testTime) - if err != nil { - b.Fatal(err) - } + require.NoError(b, err) } }