Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use backoff to refresh tokens #31

Merged
merged 6 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,4 @@ cython_debug/
notebooks
*.log
data-logs/
*tokens.json
*tokens*.json
9 changes: 5 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ pip install CSchwabPy
```python

# save these lines in a file named like cschwab.py
from cschwabpy.SchwabAsyncClient import SchwabAsyncClient
# NOTE: should use SchwabClient to get tokens manually after version 0.1.3
from cschwabpy.SchwabClient import SchwabClient

app_client_key = "---your-app-client-key-here-"
app_secret = "app-secret"

schwab_client = SchwabAsyncClient(app_client_id=app_client_key, app_secret=app_secret)
schwab_client = SchwabClient(app_client_id=app_client_key, app_secret=app_secret)
schwab_client.get_tokens_manually()

# run in your Terminal, follow the prompt to complete authentication:
Expand All @@ -47,13 +48,13 @@ schwab_client.get_tokens_manually()
#----------------
ticker = '$SPX'
# get option expirations:
expiration_list = await schwab_client.get_option_expirations_async(underlying_symbol = ticker)
expiration_list = schwab_client.get_option_expirations(underlying_symbol = ticker)

# download SPX option chains
from_date = 2024-07-01
to_date = 2024-07-01

opt_chain_result = await schwab_client.download_option_chain_async(ticker, from_date, to_date)
opt_chain_result = schwab_client.download_option_chain(ticker, from_date, to_date)

# get call-put dataframe pairs by expiration
opt_df_pairs = opt_chain_result.to_dataframe_pairs_by_expiration()
Expand Down
84 changes: 10 additions & 74 deletions cschwabpy/SchwabAsyncClient.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from cschwabpy.models.token import Tokens, ITokenStore, LocalTokenStore
from cschwabpy.models.token import Tokens, IAsyncTokenStore, AsyncLocalTokenStore
from cschwabpy.models import (
OptionChainQueryFilter,
OptionContractType,
Expand Down Expand Up @@ -28,7 +28,7 @@
SCHWAB_AUTH_PATH,
SCHWAB_TOKEN_PATH,
)

import backoff
import httpx
import re
import base64
Expand All @@ -42,7 +42,7 @@ def __init__(
self,
app_client_id: str,
app_secret: str,
token_store: ITokenStore = LocalTokenStore(),
token_store: IAsyncTokenStore = AsyncLocalTokenStore(),
tokens: Optional[Tokens] = None,
http_client: Optional[httpx.AsyncClient] = None,
) -> None:
Expand All @@ -51,28 +51,23 @@ def __init__(
self.__token_store = token_store
self.__client = http_client
self.__keep_client_alive = http_client is not None
if (
tokens is not None
and tokens.is_access_token_valid
and tokens.is_refresh_token_valid
):
token_store.save_tokens(tokens)

self.__tokens = token_store.get_tokens()
self.__tokens = tokens

@property
def token_url(self) -> str:
return f"{SCHWAB_API_BASE_URL}/{SCHWAB_TOKEN_PATH}"

@backoff.on_exception(backoff.expo, Exception, max_tries=3, max_time=10)
async def _ensure_valid_access_token(self, force_refresh: bool = False) -> bool:
if self.__tokens is None:
self.__tokens = await self.__token_store.get_tokens()

if self.__tokens is None:
raise Exception(
"Tokens are not available. Please use get_tokens_manually() to get tokens first."
)

if self.__tokens.is_access_token_valid and not force_refresh:
return True

client = httpx.AsyncClient() if self.__client is None else self.__client
try:
key_sec_encoded = self.__encode_app_key_secret()
Expand All @@ -91,7 +86,7 @@ async def _ensure_valid_access_token(self, force_refresh: bool = False) -> bool:
if response.status_code == 200:
json_res = response.json()
tokens = Tokens(**json_res)
self.__token_store.save_tokens(tokens)
await self.__token_store.save_tokens(tokens)
return True
else:
raise Exception(
Expand Down Expand Up @@ -385,67 +380,8 @@ async def download_option_chain_async(
url=target_url, params={}, headers=self.__auth_header()
)
json_res = response.json()
print("json_res: ", json_res)
return OptionChain(**json_res)
finally:
if not self.__keep_client_alive:
await client.aclose()

def get_tokens_manually(
self,
) -> None:
"""Manual steps to get tokens from Charles Schwab API."""
from prompt_toolkit import prompt
import urllib.parse as url_parser

redirect_uri = prompt("Enter your redirect uri> ").strip()
complete_auth_url = f"{SCHWAB_API_BASE_URL}/{SCHWAB_AUTH_PATH}?response_type=code&client_id={self.__client_id}&redirect_uri={redirect_uri}"
print(
f"Copy and open the following URL in browser. Complete Login & Authorization:\n {complete_auth_url}"
)
auth_code_response_url = prompt(
"Paste the entire authorization response URL here> "
).strip()

auth_code = ""
try:
auth_code_pattern = re.compile(r"code=(.+)&?")
d = re.search(auth_code_pattern, auth_code_response_url)
if d:
auth_code = d.group(1)
auth_code = url_parser.unquote(auth_code.split("&")[0])
else:
raise Exception(
"authorization response url does not contain authorization code"
)

if len(auth_code) == 0:
raise Exception("authorization code is empty")
except Exception as ex:
raise Exception(
"Failed to get authorization code. Please try again. Exception: ", ex
)

key_sec_encoded = self.__encode_app_key_secret()
with httpx.Client() as client:
response = client.post(
url=self.token_url,
headers={
"Authorization": f"Basic {key_sec_encoded}",
"Content-Type": "application/x-www-form-urlencoded",
},
data={
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": redirect_uri,
},
)

if response.status_code == 200:
json_res = response.json()
tokens = Tokens(**json_res)
self.__token_store.save_tokens(tokens)
print(
f"Tokens saved successfully at path: {self.__token_store.token_file_path}"
)
else:
print("Failed to get tokens. Please try again.")
14 changes: 5 additions & 9 deletions cschwabpy/SchwabClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
InstrumentProjection,
)
import cschwabpy.util as util

import backoff
from datetime import datetime, timedelta
from typing import Optional, List, Mapping
from cschwabpy.costants import (
Expand Down Expand Up @@ -53,20 +53,16 @@ def __init__(
self.__token_store = token_store
self.__client = http_client
self.__keep_client_alive = http_client is not None
if (
tokens is not None
and tokens.is_access_token_valid
and tokens.is_refresh_token_valid
):
token_store.save_tokens(tokens)

self.__tokens = token_store.get_tokens()
self.__tokens = tokens

@property
def token_url(self) -> str:
return f"{SCHWAB_API_BASE_URL}/{SCHWAB_TOKEN_PATH}"

@backoff.on_exception(backoff.expo, Exception, max_tries=3, max_time=10)
def _ensure_valid_access_token(self, force_refresh: bool = False) -> bool:
if self.__tokens is None:
self.__tokens = self.__token_store.get_tokens()
if self.__tokens is None:
raise Exception(
"Tokens are not available. Please use get_tokens_manually() to get tokens first."
Expand Down
46 changes: 46 additions & 0 deletions cschwabpy/models/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import json
import time
import aiofiles as af
from pathlib import Path

REFRESH_TOKEN_VALIDITY_SECONDS = 7 * 24 * 60 * 60 # 7 days
Expand Down Expand Up @@ -77,3 +78,48 @@ def get_tokens(self) -> Optional[Tokens]:
def save_tokens(self, tokens: Tokens) -> None:
with open(self.token_file_path, "w") as token_file:
token_file.write(json.dumps(tokens.to_json(), indent=4))


class IAsyncTokenStore(Protocol):
@property
def token_output_path(self) -> str:
"""Path for outputting tokens."""
return ""

async def get_tokens(self) -> Optional[Tokens]:
pass

async def save_tokens(self, tokens: Tokens) -> None:
pass


class AsyncLocalTokenStore(IAsyncTokenStore):
def __init__(
self, json_file_name: str = "tokens.json", file_path: Optional[str] = None
):
self.file_name = json_file_name
self.token_file_path = file_path
if file_path is None:
self.token_file_path = Path(Path(__file__).parent, json_file_name)
else:
self.token_file_path = Path(file_path)

if not os.path.exists(self.token_file_path.parent):
os.makedirs(self.token_file_path.parent)

@property
def token_output_path(self) -> str:
return str(self.token_file_path)

async def get_tokens(self) -> Optional[Tokens]:
try:
async with af.open(self.token_file_path, mode="r") as token_file:
token_json_str = await token_file.read()
tokens_json = json.loads(token_json_str)
return Tokens(**tokens_json)
except:
return None

async def save_tokens(self, tokens: Tokens) -> None:
async with af.open(self.token_file_path, mode="w") as token_file:
await token_file.write(json.dumps(tokens.to_json(), indent=4))
Loading