From 3ed6259325455f266be3ec6de2a206065553c7b2 Mon Sep 17 00:00:00 2001 From: Ammar Bandukwala Date: Thu, 7 Nov 2024 11:07:29 -0600 Subject: [PATCH] Add support for API key authentication --- .gitignore | 2 + auth.go | 5 ++- client.go | 99 +++++++++++++++++++++++++++++++++++++++++++----- client_test.go | 22 ++++++----- exchange_test.go | 4 +- 5 files changed, 111 insertions(+), 21 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5a23fa1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +secrets/** +secrets.env diff --git a/auth.go b/auth.go index 549ea91..8877700 100644 --- a/auth.go +++ b/auth.go @@ -1,6 +1,8 @@ package kalshi -import "context" +import ( + "context" +) // LoginRequest is described here: // https://trading-api.readme.io/reference/login. @@ -32,6 +34,7 @@ func (c *Client) Login(ctx context.Context, req LoginRequest) (*LoginResponse, e if err != nil { return nil, err } + return &resp, nil } diff --git a/client.go b/client.go index e1681a7..bab7db0 100644 --- a/client.go +++ b/client.go @@ -3,7 +3,13 @@ package kalshi import ( "bytes" "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" "encoding/json" + "encoding/pem" "fmt" "io" "net/http" @@ -30,6 +36,30 @@ func (c Cents) String() string { return fmt.Sprintf("$%.2f", dollars) } +type APIKey struct { + ID string + Key *rsa.PrivateKey +} + +func LoadAPIKey(apiKeyID, path string) (*APIKey, error) { + key, err := os.ReadFile(path) + if err != nil { + return nil, err + } + // Parse PEM encoded RSA private key + block, _ := pem.Decode(key) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return &APIKey{ID: apiKeyID, Key: rsaKey}, nil +} + // Client must be instantiated via New. type Client struct { // BaseURL is one of APIDemoURL or APIProdURL. @@ -39,6 +69,11 @@ type Client struct { WriteRatelimit *rate.Limiter ReadRateLimit *rate.Limiter + // APIKey is optional if you use the Login method. + // As of 2024-11-07, Login-based auth is not working and returning + // 403 Forbidden? + APIKey *APIKey + httpClient *http.Client } @@ -60,27 +95,72 @@ type request struct { JSONResponse any } -func jsonRequestHeaders( +func (c *Client) signRequest(req *http.Request) error { + if c.APIKey == nil { + return nil + } + + timestamp := time.Now().UnixMilli() + payload := fmt.Sprintf("%d%s%s", timestamp, req.Method, req.URL.Path) + + hashed := crypto.SHA256.New() + hashed.Write([]byte(payload)) + signature, err := rsa.SignPSS( + rand.Reader, + c.APIKey.Key, + crypto.SHA256, + hashed.Sum(nil), + &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, + }) + if err != nil { + return fmt.Errorf("failed to sign request: %w", err) + } + + req.Header.Set("KALSHI-ACCESS-KEY", c.APIKey.ID) + req.Header.Set("KALSHI-ACCESS-TIMESTAMP", strconv.FormatInt(timestamp, 10)) + req.Header.Set("KALSHI-ACCESS-SIGNATURE", base64.StdEncoding.EncodeToString(signature)) + + return nil +} + +func (c *Client) jsonRequestHeaders( ctx context.Context, client *http.Client, headers http.Header, method string, reqURL string, jsonReq any, jsonResp any, ) error { - reqBodyByt, err := json.Marshal(jsonReq) - if err != nil { - return err + var ( + reqBodyReader io.Reader + reqBodyBytes []byte + ) + if jsonReq != nil { + var err error + reqBodyBytes, err = json.Marshal(jsonReq) + if err != nil { + return err + } + reqBodyReader = bytes.NewReader(reqBodyBytes) } - req, err := http.NewRequest(method, reqURL, bytes.NewReader(reqBodyByt)) + req, err := http.NewRequestWithContext(ctx, method, reqURL, reqBodyReader) if err != nil { return err } if headers != nil { req.Header = headers } - req.Header.Set("Content-Type", "application/json") + + if err := c.signRequest(req); err != nil { + return fmt.Errorf("sign request: %w", err) + } + + if req.Method == "POST" || req.Method == "PUT" { + req.Header.Set("Content-Type", "application/json") + } req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", "ammario/kalshi-go") resp, err := client.Do(req) if err != nil { @@ -102,9 +182,10 @@ func jsonRequestHeaders( if err != nil { return fmt.Errorf("dump: %w", err) } - dumpErr := fmt.Sprintf("Request\n%s%s\nResponse\n%s%s", + dumpErr := fmt.Sprintf("Request %s\n%s%s\nResponse\n%s%s", + reqURL, reqDump, - reqBodyByt, + reqBodyBytes, respDump, respBodyByt, ) @@ -164,7 +245,7 @@ func (c *Client) request( } } - return jsonRequestHeaders( + return c.jsonRequestHeaders( ctx, c.httpClient, nil, diff --git a/client_test.go b/client_test.go index 16d8381..f312d6b 100644 --- a/client_test.go +++ b/client_test.go @@ -14,27 +14,29 @@ var rateLimit = rate.NewLimiter(rate.Every(time.Second), 10-1) func testClient(t *testing.T) *Client { const ( - emailEnv = "KALSHI_EMAIL" - passEnv = "KALSHI_PASSWORD" + apiKeyIDEnv = "KALSHI_API_KEY_ID" + apiKeyPathEnv = "KALSHI_API_KEY_PATH" ) ctx := context.Background() - email, ok := os.LookupEnv(emailEnv) + apiKeyID, ok := os.LookupEnv(apiKeyIDEnv) if !ok { - t.Fatalf("no $%s provided", emailEnv) + t.Fatalf("no $%s provided", apiKeyIDEnv) } - password, ok := os.LookupEnv(passEnv) + + apiKeyPath, ok := os.LookupEnv(apiKeyPathEnv) if !ok { - t.Fatalf("no $%s provided", passEnv) + t.Fatalf("no $%s provided", apiKeyPathEnv) } + apiKey, err := LoadAPIKey(apiKeyID, apiKeyPath) + require.NoError(t, err) + c := New(APIDemoURL) c.WriteRatelimit = rateLimit - _, err := c.Login(ctx, LoginRequest{ - Email: email, - Password: password, - }) + c.APIKey = apiKey + require.NoError(t, err) t.Cleanup(func() { // Logout will fail during the Logout test. diff --git a/exchange_test.go b/exchange_test.go index 2b24263..ce9eacb 100644 --- a/exchange_test.go +++ b/exchange_test.go @@ -12,6 +12,9 @@ func TestExchangeStatus(t *testing.T) { client := testClient(t) + // ExchangeStatus is not authenticated + client.APIKey = nil + s, err := client.ExchangeStatus(context.Background()) require.NoError(t, err) // The Demo API never sleeps. @@ -27,5 +30,4 @@ func TestExchangeSchedule(t *testing.T) { _, err := client.ExchangeSchedule(context.Background()) require.NoError(t, err) - }