Skip to content

Commit

Permalink
Not to change method signature
Browse files Browse the repository at this point in the history
  • Loading branch information
sai-g committed Jun 6, 2024
1 parent 6c11243 commit 2ef5813
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 31 deletions.
36 changes: 22 additions & 14 deletions alpaca/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package alpaca
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -314,10 +315,12 @@ type CloseAllPositionsRequest struct {
}

// CloseAllPositions liquidates all open positions at market price.
func (c *Client) CloseAllPositions(req CloseAllPositionsRequest) (map[string]*Order, map[string]*APIError, error) {
// It returns the list of orders that were submitted to close the positions.
// If an error occurs while closing a position, the error will be returned
func (c *Client) CloseAllPositions(req CloseAllPositionsRequest) ([]Order, error) {
u, err := url.Parse(fmt.Sprintf("%s/%s/positions", c.opts.BaseURL, apiVersion))
if err != nil {
return nil, nil, err
return nil, err
}

q := u.Query()
Expand All @@ -326,35 +329,40 @@ func (c *Client) CloseAllPositions(req CloseAllPositionsRequest) (map[string]*Or

resp, err := c.delete(u)
if err != nil {
return nil, nil, err
return nil, err
}

var closeAllPositions closeAllPositionsSlice
if err = unmarshal(resp, &closeAllPositions); err != nil {
return nil, nil, err
return nil, err
}

var (
orderEntityMap = make(map[string]*Order, len(closeAllPositions))
apiErrorMap = make(map[string]*APIError, len(closeAllPositions))
orders = make([]Order, 0, len(closeAllPositions))
errs = make([]error, 0, len(closeAllPositions))
)
for _, capr := range closeAllPositions {
if capr.Status == http.StatusOK {
var orderEntity Order
if err := easyjson.Unmarshal(capr.Body, &orderEntity); err != nil {
return nil, nil, err
var order Order
if err := easyjson.Unmarshal(capr.Body, &order); err != nil {
return nil, err
}
orderEntityMap[capr.Symbol] = &orderEntity
orders = append(orders, order)
continue
}
var apiErr APIError
if err := easyjson.Unmarshal(capr.Body, &apiErr); err != nil {
return nil, nil, err
return nil, err
}
apiErrorMap[capr.Symbol] = &apiErr
apiErr.StatusCode = capr.Status
errs = append(errs, &apiErr)
}

return orderEntityMap, apiErrorMap, nil
if len(errs) > 0 {
return orders, errors.Join(errs...)
}

return orders, nil
}

type ClosePositionRequest struct {
Expand Down Expand Up @@ -932,7 +940,7 @@ func GetPosition(symbol string) (*Position, error) {
}

// CloseAllPositions liquidates all open positions at market price.
func CloseAllPositions(req CloseAllPositionsRequest) (map[string]*Order, map[string]*APIError, error) {
func CloseAllPositions(req CloseAllPositionsRequest) ([]Order, error) {
return DefaultClient.CloseAllPositions(req)
}

Expand Down
22 changes: 5 additions & 17 deletions alpaca/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func TestCancelAllPositions(t *testing.T) {
c := DefaultClient

closeAllPositionsResponse := []CloseAllPositionsResponse{
{Symbol: "AAPL", Status: 200, Body: json.RawMessage(`{"id":"0571ce61-bf65-4f0c-b3de-6f42ce628422"}`)},
{Symbol: "AAPL", Status: 200, Body: json.RawMessage(`{"id":"0571ce61-bf65-4f0c-b3de-6f42ce628422", "symbol": "AAPL"}`)},
{Symbol: "TSLA", Status: 422, Body: json.RawMessage(`{"code": 42210000, "message": "error"}`)},
}
c.do = func(c *Client, req *http.Request) (*http.Response, error) {
Expand All @@ -193,24 +193,12 @@ func TestCancelAllPositions(t *testing.T) {
Body: genBody(closeAllPositionsResponse),
}, nil
}
gotOrders, gotApiErrors, err := c.CloseAllPositions(CloseAllPositionsRequest{
gotOrders, err := c.CloseAllPositions(CloseAllPositionsRequest{
CancelOrders: true,
})
require.NoError(t, err)
require.Len(t, gotOrders, 1)
require.Len(t, gotApiErrors, 1)

aaplOrder, ok := gotOrders["AAPL"]
require.True(t, ok)
require.NotNil(t, aaplOrder)

tslaOrder, ok := gotOrders["TSLA"]
require.False(t, ok)
require.Nil(t, tslaOrder)

apiErr, ok := gotApiErrors["TSLA"]
require.True(t, ok)
require.NotNil(t, apiErr)
require.Error(t, err)
assert.Len(t, gotOrders, 1)
assert.Equal(t, "AAPL", gotOrders[0].Symbol)
}

func TestGetClock(t *testing.T) {
Expand Down

0 comments on commit 2ef5813

Please sign in to comment.