From 2ef5813e6ae9dcb8a0700c6384cd5f6200806815 Mon Sep 17 00:00:00 2001 From: sai-g Date: Wed, 5 Jun 2024 20:27:46 -0400 Subject: [PATCH] Not to change method signature --- alpaca/rest.go | 36 ++++++++++++++++++++++-------------- alpaca/rest_test.go | 22 +++++----------------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/alpaca/rest.go b/alpaca/rest.go index 7c086f8..0e57d49 100644 --- a/alpaca/rest.go +++ b/alpaca/rest.go @@ -3,6 +3,7 @@ package alpaca import ( "bytes" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -314,10 +315,12 @@ type CloseAllPositionsRequest struct { } // CloseAllPositions liquidates all open positions at market price. -func (c *Client) CloseAllPositions(req CloseAllPositionsRequest) (map[string]*Order, map[string]*APIError, error) { +// It returns the list of orders that were submitted to close the positions. +// If an error occurs while closing a position, the error will be returned +func (c *Client) CloseAllPositions(req CloseAllPositionsRequest) ([]Order, error) { u, err := url.Parse(fmt.Sprintf("%s/%s/positions", c.opts.BaseURL, apiVersion)) if err != nil { - return nil, nil, err + return nil, err } q := u.Query() @@ -326,35 +329,40 @@ func (c *Client) CloseAllPositions(req CloseAllPositionsRequest) (map[string]*Or resp, err := c.delete(u) if err != nil { - return nil, nil, err + return nil, err } var closeAllPositions closeAllPositionsSlice if err = unmarshal(resp, &closeAllPositions); err != nil { - return nil, nil, err + return nil, err } var ( - orderEntityMap = make(map[string]*Order, len(closeAllPositions)) - apiErrorMap = make(map[string]*APIError, len(closeAllPositions)) + orders = make([]Order, 0, len(closeAllPositions)) + errs = make([]error, 0, len(closeAllPositions)) ) for _, capr := range closeAllPositions { if capr.Status == http.StatusOK { - var orderEntity Order - if err := easyjson.Unmarshal(capr.Body, &orderEntity); err != nil { - return nil, nil, err + var order Order + if err := easyjson.Unmarshal(capr.Body, &order); err != nil { + return nil, err } - orderEntityMap[capr.Symbol] = &orderEntity + orders = append(orders, order) continue } var apiErr APIError if err := easyjson.Unmarshal(capr.Body, &apiErr); err != nil { - return nil, nil, err + return nil, err } - apiErrorMap[capr.Symbol] = &apiErr + apiErr.StatusCode = capr.Status + errs = append(errs, &apiErr) } - return orderEntityMap, apiErrorMap, nil + if len(errs) > 0 { + return orders, errors.Join(errs...) + } + + return orders, nil } type ClosePositionRequest struct { @@ -932,7 +940,7 @@ func GetPosition(symbol string) (*Position, error) { } // CloseAllPositions liquidates all open positions at market price. -func CloseAllPositions(req CloseAllPositionsRequest) (map[string]*Order, map[string]*APIError, error) { +func CloseAllPositions(req CloseAllPositionsRequest) ([]Order, error) { return DefaultClient.CloseAllPositions(req) } diff --git a/alpaca/rest_test.go b/alpaca/rest_test.go index 55501fc..f9b7f60 100644 --- a/alpaca/rest_test.go +++ b/alpaca/rest_test.go @@ -182,7 +182,7 @@ func TestCancelAllPositions(t *testing.T) { c := DefaultClient closeAllPositionsResponse := []CloseAllPositionsResponse{ - {Symbol: "AAPL", Status: 200, Body: json.RawMessage(`{"id":"0571ce61-bf65-4f0c-b3de-6f42ce628422"}`)}, + {Symbol: "AAPL", Status: 200, Body: json.RawMessage(`{"id":"0571ce61-bf65-4f0c-b3de-6f42ce628422", "symbol": "AAPL"}`)}, {Symbol: "TSLA", Status: 422, Body: json.RawMessage(`{"code": 42210000, "message": "error"}`)}, } c.do = func(c *Client, req *http.Request) (*http.Response, error) { @@ -193,24 +193,12 @@ func TestCancelAllPositions(t *testing.T) { Body: genBody(closeAllPositionsResponse), }, nil } - gotOrders, gotApiErrors, err := c.CloseAllPositions(CloseAllPositionsRequest{ + gotOrders, err := c.CloseAllPositions(CloseAllPositionsRequest{ CancelOrders: true, }) - require.NoError(t, err) - require.Len(t, gotOrders, 1) - require.Len(t, gotApiErrors, 1) - - aaplOrder, ok := gotOrders["AAPL"] - require.True(t, ok) - require.NotNil(t, aaplOrder) - - tslaOrder, ok := gotOrders["TSLA"] - require.False(t, ok) - require.Nil(t, tslaOrder) - - apiErr, ok := gotApiErrors["TSLA"] - require.True(t, ok) - require.NotNil(t, apiErr) + require.Error(t, err) + assert.Len(t, gotOrders, 1) + assert.Equal(t, "AAPL", gotOrders[0].Symbol) } func TestGetClock(t *testing.T) {