diff --git a/changelog.md b/changelog.md index e53025b..8dc214a 100644 --- a/changelog.md +++ b/changelog.md @@ -17,6 +17,7 @@ - Added get_info method - Added interval (1d, 1w, 1m) when getting historical prices +- Added WealthSimple refresh_tokens method #### - :nail_care: [Polish] diff --git a/examples/wealthsimple.py b/examples/wealthsimple.py index a62bf68..c4fa7a8 100644 --- a/examples/wealthsimple.py +++ b/examples/wealthsimple.py @@ -1,36 +1,45 @@ +from loguru import logger +from stockai import WealthSimple import sys sys.path.append('../stockai') -from stockai import WealthSimple - email: str = '' password: str = '' + def prepare_credentails(): - global email, password - print('Prepare credentials') - while not email: - email = str(input("Enter email: \n>>> ")) - while not password: - password = str(input("Enter password: \n>>> ")) + global email, password + + logger.debug('Prepare credentials') + + while not email: + email = str(input("Enter email: \n>>> ")) + while not password: + password = str(input("Enter password: \n>>> ")) + def init(): - global username, password + global username, password + + ws = WealthSimple(email, password) + + logger.debug('Get me...') + me = ws.get_me() + print(me) - ws = WealthSimple(email, password) + logger.debug('Get accounts') + accounts = ws.get_accounts() + print(accounts) - print('Get me...\n') - me = ws.get_me() - print(me) + logger.debug('Refresh tokens\n) + ws.refresh_tokens() - print('Get accounts\n') - accounts = ws.get_accounts() - print(accounts) + print(ws.session.headers['Authorization']) + + logger.debug('Get security') + TSLA = ws.get_security('TSLA') + print(TSLA) - print('Get security \n') - TSLA = ws.get_security('TSLA') - print(TSLA) prepare_credentails() init() - diff --git a/stockai/wealthsimple/requests.py b/stockai/wealthsimple/requests.py index 64e5c33..132e5ee 100644 --- a/stockai/wealthsimple/requests.py +++ b/stockai/wealthsimple/requests.py @@ -1,30 +1,31 @@ import requests + class WSAPIRequest: """ Handle API requests for WealthSimple """ - def __init__(self, session, WS_URL): + def __init__(self, session, ws_url): self.session = session - self.WS_URL = WS_URL + self.WS_URL = ws_url def request(self, method, endpoint, params=None): url = self.WS_URL + endpoint if method == 'POST': - return self.post(url, params) + return self.__post(url, params) elif method == 'GET': - return self.get(url, params) + return self.__get(url, params) else: raise Exception('Invalid request method: {method}') - def post(self, url, params=None): + def __post(self, url, params=None): try: return self.session.post(url, params) except Exception as error: print(error) - def get(self, url, payload=None): - auth = self.session.headers['Authorization'] - return requests.get(url, headers = { 'Authorization': auth }) + def __get(self, url, payload=None): + auth = self.session.headers['Authorization'] + return requests.get(url, headers={'Authorization': auth}) diff --git a/stockai/wealthsimple/stock.py b/stockai/wealthsimple/stock.py index 2d2654c..01f15e3 100644 --- a/stockai/wealthsimple/stock.py +++ b/stockai/wealthsimple/stock.py @@ -2,67 +2,93 @@ from .requests import WSAPIRequest from loguru import logger + class WealthSimple(): - BASE_DOMAIN = 'https://trade-service.wealthsimple.com/' - - def __init__(self, email: str, password: str): - self.session = Session() - self.WSAPI = WSAPIRequest(self.session, self.BASE_DOMAIN) - self.login(email, password) - - def login (self, email: str = None, password: str = None) -> None: - if email and password: - payload = { "email": email, "password": password, "timeoutMs": 2e4 } - response = self.WSAPI.request('POST', 'auth/login', payload) - - # Check of OTP - if "x-wealthsimple-otp" in response.headers: - TFACode = '' - while not TFACode: - # Obtain user input and ensure it is not empty - TFACode = input('Enter 2FA code: ') - payload['otp'] = TFACode - response = self.WSAPI.request('POST', 'auth/login', payload) - - if response.status_code == 401: - raise Exception('Invalid Login') - - self.session.headers.update( - {"Authorization": response.headers["X-Access-Token"]} - ) - self.session.headers.update( - {"refresh_token": response.headers["X-Refresh-Token"]} - ) - else: - raise Exception('Missing login credentials') - - def get_me(self): - logger.debug('get_me') - response = self.WSAPI.request('GET', 'me') - logger.debug(f'get_me {response.status_code}') - - if response.status_code == 401: - raise Exception('Invalid Access Token') - else: - return response.json() - - def get_accounts(self) -> list: - """ - Get Wealthsimple Trade Accounts - """ - response = self.WSAPI.request('GET', 'account/list') - response = response.json() - return response['results'] - - def get_security(self, id: str) -> dict: - """ - Get Security Info - """ - logger.debug('get_security') - response = self.WSAPI.request('GET', f'securities/{id}') - logger.debug(f"get_security {response.status_code}") - - if response.status_code == 401: + """WealthSimple class for API interaction""" + + BASE_DOMAIN = 'https://trade-service.wealthsimple.com/' + + def __init__(self, email: str, password: str): + self.session = Session() + self.WSAPI = WSAPIRequest(self.session, self.BASE_DOMAIN) + self.login(email, password) + + def login(self, email: str = None, password: str = None) -> None: + if email and password: + payload = {"email": email, "password": password, "timeoutMs": 2e4} + response = self.WSAPI.request('POST', 'auth/login', payload) + + # Check of OTP + if "x-wealthsimple-otp" in response.headers: + TFACode = '' + while not TFACode: + # Obtain user input and ensure it is not empty + TFACode = input('Enter 2FA code: ') + payload['otp'] = TFACode + response = self.WSAPI.request('POST', 'auth/login', payload) + + if response.status_code == 401: + raise Exception('Invalid Login') + + self.__update_tokens(response) + else: + raise Exception('Missing login credentials') + + def get_me(self): + """Return owner information""" + + logger.debug('get_me') + response = self.WSAPI.request('GET', 'me') + logger.debug(f'get_me {response.status_code}') + + if response.status_code == 401: + raise Exception('Invalid Access Token') + else: + return response.json() + + def get_accounts(self) -> list: + """ + Get Wealthsimple Trade Accounts + """ + + response = self.WSAPI.request('GET', 'account/list') + response = response.json() + return response['results'] + + def get_security(self, id: str) -> dict: + """ + Get Security Info + """ + + logger.debug('get_security') + response = self.WSAPI.request('GET', f'securities/{id}') + logger.debug(f"get_security {response.status_code}") + logger.debug(f"get_security {response}") + + if response.status_code == 401: raise Exception(f'Cannot get security {id}') - else: - return response.json() + else: + return response.json() + + def refresh_tokens(self): + """ + Genereates new tokens + """ + + logger.debug('Refresh tokens') + response = self.WSAPI.request('POST', 'auth/refresh') + + if response.status_code == 401: + logger.error('Current refresh token is expired') + raise Exception('Refresh token is expired') + else: + self.__update_tokens(response) + + def __update_tokens(self, response): + """Update session tokens""" + self.session.headers.update( + {"Authorization": response.headers["X-Access-Token"]} + ) + self.session.headers.update( + {"refresh_token": response.headers["X-Refresh-Token"]} + ) diff --git a/stockai/yahoo/base.py b/stockai/yahoo/base.py index b22205b..a5b5588 100644 --- a/stockai/yahoo/base.py +++ b/stockai/yahoo/base.py @@ -5,7 +5,7 @@ class Base(object): def __init__(self, symbol): self.symbol = symbol - def _prepare_request(self, region='US', lang='en-US', includePrePost='false', interval='2m', range='1d'): + def __prepare_request(self, region='US', lang='en-US', includePrePost='false', interval='2m', range='1d'): """ Basic Yahoo Rquest URL """ @@ -19,8 +19,8 @@ def _prepare_request(self, region='US', lang='en-US', includePrePost='false', in ) return url - def _request(self): - url = self._prepare_request() + def __request(self): + url = self.__prepare_request() data = get(url) if data.json()['quoteSummary']['error'] is not None: @@ -50,7 +50,7 @@ def refresh(self): """ Refresh stock data """ - self.data_set = self._request() + self.data_set = self.__request() def __process_historical_result(self, data): """