From e6c3b83466f3f39a0c67fef3c25fa83e3b683632 Mon Sep 17 00:00:00 2001 From: Dale Nguyen Date: Sun, 28 Feb 2021 11:31:16 -0500 Subject: [PATCH 1/3] Added get infor method & documentation --- changelog.md | 12 ++++++++++++ setup.py | 2 +- stockai/__init__.py | 2 +- stockai/yahoo/base.py | 6 +++--- stockai/yahoo/stock.py | 34 +++++++++++++++++++++++----------- tests/yahoo.py | 3 +++ 6 files changed, 43 insertions(+), 16 deletions(-) diff --git a/changelog.md b/changelog.md index bbd6fcb..03af699 100644 --- a/changelog.md +++ b/changelog.md @@ -11,6 +11,18 @@ --- +## [1.5.0] - 2021-02-28 + +#### - :rocket: [New Feature] + +- Added get_info method + +#### - :nail_care: [Polish] + +- Push date to the first when getting historical prices +- Removed meta from getting historical prices +- Added functions documentation + ## [1.4.0] - 2021-02-27 #### - :bug: [Bug Fix] diff --git a/setup.py b/setup.py index 47375ee..157189b 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ setup( name="stockai", - version="1.4.0", + version="1.5.0", author="Dale Nguyen", author_email="dale@dalenguyen.me", description="Get stock info from Yahoo! Finance", diff --git a/stockai/__init__.py b/stockai/__init__.py index 3242ce0..ce7f35c 100644 --- a/stockai/__init__.py +++ b/stockai/__init__.py @@ -2,7 +2,7 @@ Stockai - get stock information from Yahoo! Finance """ -__version__ = '1.4.0' +__version__ = '1.5.0' __author__ = 'Dale Nguyen' __name__ = 'stockai' diff --git a/stockai/yahoo/base.py b/stockai/yahoo/base.py index 7cca46b..dc5f225 100644 --- a/stockai/yahoo/base.py +++ b/stockai/yahoo/base.py @@ -57,11 +57,11 @@ def __process_historical_result(self, data): raise NameError(data.json()['chart']['error']['description']) result = data.json()['chart']['result'][0]['indicators']['quote'][0] - result['date'] = data.json()['chart']['result'][0]['timestamp'] + date_data = { 'date' : data.json()['chart']['result'][0]['timestamp']} result['adjclose'] = data.json()['chart']['result'][0]['indicators']['adjclose'][0]['adjclose'] - result['meta'] = data.json()['chart']['result'][0]['meta'] + # result['meta'] = data.json()['chart']['result'][0]['meta'] # for index, date in enumerate(result['date']): # result['date'][index] = timestamp_to_date(result['date'][index]) - return result + return {**date_data, **result} diff --git a/stockai/yahoo/stock.py b/stockai/yahoo/stock.py index cd13e97..26ddf42 100644 --- a/stockai/yahoo/stock.py +++ b/stockai/yahoo/stock.py @@ -6,25 +6,37 @@ def __init__(self, symbol): super(Stock, self).__init__(symbol) self.refresh() - # Summary Profile def get_summary_profile(self): - return self.data_set['summaryProfile'] + """Return summary profile of a security""" + return self.data_set['summaryProfile'] + + def get_info(self): + """Return detail information of a security""" + return self.data_set['price'] - # Financial Data def get_price(self): - return self.data_set['financialData']['currentPrice']['fmt'] + """Return price of a security""" + return self.data_set['financialData']['currentPrice']['fmt'] def get_currency(self): - return self.data_set['financialData']['financialCurrency'] + """Return currency of a security""" + return self.data_set['financialData']['financialCurrency'] - # Historical Prices def get_historical_prices(self, start_date, end_date): - if (date_to_timestamp(start_date) > date_to_timestamp(end_date)): - raise ValueError('Please check the order of start date and end date') - - return self.get_historical(date_to_timestamp(start_date), date_to_timestamp(end_date)) + """Return historical prices of a security + Parameters + ---------- + start_date: str + Start date + end_date: str + End date + """ + if (date_to_timestamp(start_date) > date_to_timestamp(end_date)): + raise ValueError('Please check the order of start date and end date') + + return self.get_historical(date_to_timestamp(start_date), date_to_timestamp(end_date)) - # Get all prices def get_all_prices(self): + """Return all historical prices starting from 01/01/2000""" return self.get_all_historical() diff --git a/tests/yahoo.py b/tests/yahoo.py index 84d187b..9a52d5d 100644 --- a/tests/yahoo.py +++ b/tests/yahoo.py @@ -17,6 +17,9 @@ def test_td_summary_profile(self): # print(self.td.get_summary_profile()) self.assertEqual(self.td.get_summary_profile()['city'], 'Toronto') + def test_td_get_info(self): + self.assertEqual(self.td.get_info()['symbol'], 'TD.TO') + def test_td_financial_data(self): float(self.td.get_price()) self.assertEqual(self.td.get_currency(), 'CAD') From 47aecb916687349417b55542acf2e48391dc7a47 Mon Sep 17 00:00:00 2001 From: Dale Nguyen Date: Sun, 28 Feb 2021 11:57:47 -0500 Subject: [PATCH 2/3] Added interval (1d, 1w, 1m) when getting historical prices --- changelog.md | 1 + stockai/yahoo/base.py | 30 +++++++++++++++++++++++++----- stockai/yahoo/stock.py | 17 ++++++++++++----- tests/yahoo.py | 4 ++-- 4 files changed, 40 insertions(+), 12 deletions(-) diff --git a/changelog.md b/changelog.md index 03af699..e53025b 100644 --- a/changelog.md +++ b/changelog.md @@ -16,6 +16,7 @@ #### - :rocket: [New Feature] - Added get_info method +- Added interval (1d, 1w, 1m) when getting historical prices #### - :nail_care: [Polish] diff --git a/stockai/yahoo/base.py b/stockai/yahoo/base.py index dc5f225..b22205b 100644 --- a/stockai/yahoo/base.py +++ b/stockai/yahoo/base.py @@ -28,20 +28,23 @@ def _request(self): return data.json()['quoteSummary']['result'][0] - def get_historical(self, start_date, end_date): - url = 'https://query1.finance.yahoo.com/v8/finance/chart/{symbol}?formatted=true&period1={start_date}&period2={end_date}&interval=1d&events=div%7Csplit&corsDomain=finance.yahoo.com'.format( + def get_historical(self, start_date, end_date, interval): + interval = self.__get_interval(interval) + + url = 'https://query1.finance.yahoo.com/v8/finance/chart/{symbol}?formatted=true&period1={start_date}&period2={end_date}&interval={interval}&events=div%7Csplit&corsDomain=finance.yahoo.com'.format( symbol = self.symbol, start_date = start_date, - end_date = end_date + end_date = end_date, + interval = interval ) data = get(url) return self.__process_historical_result(data) - def get_all_historical(self): + def get_all_historical(self, interval): # Start from 01/01/2000 start_date = 946702800 end_date = get_current_timestamp() - return self.get_historical(start_date, end_date) + return self.get_historical(start_date, end_date, interval) def refresh(self): """ @@ -65,3 +68,20 @@ def __process_historical_result(self, data): # result['date'][index] = timestamp_to_date(result['date'][index]) return {**date_data, **result} + + def __get_interval(self, interval): + """Return interval value before sending requests + + Parameters + ---------- + interval: 'daily' | 'weekly' | 'monthly' + + Returns + ---------- + '1d' | '1wk' | '1mo' + """ + return { + 'daily': '1d', + 'weekly': '1wk', + 'monthly': '1mo' + }[interval] diff --git a/stockai/yahoo/stock.py b/stockai/yahoo/stock.py index 26ddf42..fe39271 100644 --- a/stockai/yahoo/stock.py +++ b/stockai/yahoo/stock.py @@ -22,21 +22,28 @@ def get_currency(self): """Return currency of a security""" return self.data_set['financialData']['financialCurrency'] - def get_historical_prices(self, start_date, end_date): + def get_historical_prices(self, start_date, end_date, interval = 'daily'): """Return historical prices of a security + Parameters ---------- start_date: str Start date end_date: str End date + interval: 'daily' | 'weekly' | 'monthly' - Default is 'daily' """ if (date_to_timestamp(start_date) > date_to_timestamp(end_date)): raise ValueError('Please check the order of start date and end date') - return self.get_historical(date_to_timestamp(start_date), date_to_timestamp(end_date)) + return self.get_historical(date_to_timestamp(start_date), date_to_timestamp(end_date), interval) + + def get_all_prices(self, interval = 'monthly'): + """Return all historical prices starting from 01/01/2000 - def get_all_prices(self): - """Return all historical prices starting from 01/01/2000""" - return self.get_all_historical() + Parameters + ---------- + interval: str 'daily' | 'weekly' | 'monthly' - Default is 'monthly' + """ + return self.get_all_historical(interval) diff --git a/tests/yahoo.py b/tests/yahoo.py index 9a52d5d..74d5b08 100644 --- a/tests/yahoo.py +++ b/tests/yahoo.py @@ -25,11 +25,11 @@ def test_td_financial_data(self): self.assertEqual(self.td.get_currency(), 'CAD') def test_td_historical_prices(self): - # print(self.td.get_historical_prices('2019-01-01', '2019-01-05')) + # print(self.td.get_historical_prices('2019-01-01', '2019-01-30', interval='weekly')) dict(self.td.get_historical_prices('2019-01-01', '2019-01-05')) def test_td_all_historical_prices(self): - all = self.td.get_all_prices() + all = self.td.get_all_prices(interval='monthly') # print(all) dict(all) From c84fdb39dca72e78f912d0c6168fc149a7c71d65 Mon Sep 17 00:00:00 2001 From: Dale Nguyen Date: Sun, 28 Feb 2021 13:48:07 -0500 Subject: [PATCH 3/3] Added WealthSimple refresh_tokens method --- changelog.md | 1 + examples/wealthsimple.py | 49 +++++----- stockai/wealthsimple/requests.py | 17 ++-- stockai/wealthsimple/stock.py | 150 ++++++++++++++++++------------- stockai/yahoo/base.py | 8 +- 5 files changed, 131 insertions(+), 94 deletions(-) diff --git a/changelog.md b/changelog.md index e53025b..8dc214a 100644 --- a/changelog.md +++ b/changelog.md @@ -17,6 +17,7 @@ - Added get_info method - Added interval (1d, 1w, 1m) when getting historical prices +- Added WealthSimple refresh_tokens method #### - :nail_care: [Polish] diff --git a/examples/wealthsimple.py b/examples/wealthsimple.py index a62bf68..c4fa7a8 100644 --- a/examples/wealthsimple.py +++ b/examples/wealthsimple.py @@ -1,36 +1,45 @@ +from loguru import logger +from stockai import WealthSimple import sys sys.path.append('../stockai') -from stockai import WealthSimple - email: str = '' password: str = '' + def prepare_credentails(): - global email, password - print('Prepare credentials') - while not email: - email = str(input("Enter email: \n>>> ")) - while not password: - password = str(input("Enter password: \n>>> ")) + global email, password + + logger.debug('Prepare credentials') + + while not email: + email = str(input("Enter email: \n>>> ")) + while not password: + password = str(input("Enter password: \n>>> ")) + def init(): - global username, password + global username, password + + ws = WealthSimple(email, password) + + logger.debug('Get me...') + me = ws.get_me() + print(me) - ws = WealthSimple(email, password) + logger.debug('Get accounts') + accounts = ws.get_accounts() + print(accounts) - print('Get me...\n') - me = ws.get_me() - print(me) + logger.debug('Refresh tokens\n) + ws.refresh_tokens() - print('Get accounts\n') - accounts = ws.get_accounts() - print(accounts) + print(ws.session.headers['Authorization']) + + logger.debug('Get security') + TSLA = ws.get_security('TSLA') + print(TSLA) - print('Get security \n') - TSLA = ws.get_security('TSLA') - print(TSLA) prepare_credentails() init() - diff --git a/stockai/wealthsimple/requests.py b/stockai/wealthsimple/requests.py index 64e5c33..132e5ee 100644 --- a/stockai/wealthsimple/requests.py +++ b/stockai/wealthsimple/requests.py @@ -1,30 +1,31 @@ import requests + class WSAPIRequest: """ Handle API requests for WealthSimple """ - def __init__(self, session, WS_URL): + def __init__(self, session, ws_url): self.session = session - self.WS_URL = WS_URL + self.WS_URL = ws_url def request(self, method, endpoint, params=None): url = self.WS_URL + endpoint if method == 'POST': - return self.post(url, params) + return self.__post(url, params) elif method == 'GET': - return self.get(url, params) + return self.__get(url, params) else: raise Exception('Invalid request method: {method}') - def post(self, url, params=None): + def __post(self, url, params=None): try: return self.session.post(url, params) except Exception as error: print(error) - def get(self, url, payload=None): - auth = self.session.headers['Authorization'] - return requests.get(url, headers = { 'Authorization': auth }) + def __get(self, url, payload=None): + auth = self.session.headers['Authorization'] + return requests.get(url, headers={'Authorization': auth}) diff --git a/stockai/wealthsimple/stock.py b/stockai/wealthsimple/stock.py index 2d2654c..01f15e3 100644 --- a/stockai/wealthsimple/stock.py +++ b/stockai/wealthsimple/stock.py @@ -2,67 +2,93 @@ from .requests import WSAPIRequest from loguru import logger + class WealthSimple(): - BASE_DOMAIN = 'https://trade-service.wealthsimple.com/' - - def __init__(self, email: str, password: str): - self.session = Session() - self.WSAPI = WSAPIRequest(self.session, self.BASE_DOMAIN) - self.login(email, password) - - def login (self, email: str = None, password: str = None) -> None: - if email and password: - payload = { "email": email, "password": password, "timeoutMs": 2e4 } - response = self.WSAPI.request('POST', 'auth/login', payload) - - # Check of OTP - if "x-wealthsimple-otp" in response.headers: - TFACode = '' - while not TFACode: - # Obtain user input and ensure it is not empty - TFACode = input('Enter 2FA code: ') - payload['otp'] = TFACode - response = self.WSAPI.request('POST', 'auth/login', payload) - - if response.status_code == 401: - raise Exception('Invalid Login') - - self.session.headers.update( - {"Authorization": response.headers["X-Access-Token"]} - ) - self.session.headers.update( - {"refresh_token": response.headers["X-Refresh-Token"]} - ) - else: - raise Exception('Missing login credentials') - - def get_me(self): - logger.debug('get_me') - response = self.WSAPI.request('GET', 'me') - logger.debug(f'get_me {response.status_code}') - - if response.status_code == 401: - raise Exception('Invalid Access Token') - else: - return response.json() - - def get_accounts(self) -> list: - """ - Get Wealthsimple Trade Accounts - """ - response = self.WSAPI.request('GET', 'account/list') - response = response.json() - return response['results'] - - def get_security(self, id: str) -> dict: - """ - Get Security Info - """ - logger.debug('get_security') - response = self.WSAPI.request('GET', f'securities/{id}') - logger.debug(f"get_security {response.status_code}") - - if response.status_code == 401: + """WealthSimple class for API interaction""" + + BASE_DOMAIN = 'https://trade-service.wealthsimple.com/' + + def __init__(self, email: str, password: str): + self.session = Session() + self.WSAPI = WSAPIRequest(self.session, self.BASE_DOMAIN) + self.login(email, password) + + def login(self, email: str = None, password: str = None) -> None: + if email and password: + payload = {"email": email, "password": password, "timeoutMs": 2e4} + response = self.WSAPI.request('POST', 'auth/login', payload) + + # Check of OTP + if "x-wealthsimple-otp" in response.headers: + TFACode = '' + while not TFACode: + # Obtain user input and ensure it is not empty + TFACode = input('Enter 2FA code: ') + payload['otp'] = TFACode + response = self.WSAPI.request('POST', 'auth/login', payload) + + if response.status_code == 401: + raise Exception('Invalid Login') + + self.__update_tokens(response) + else: + raise Exception('Missing login credentials') + + def get_me(self): + """Return owner information""" + + logger.debug('get_me') + response = self.WSAPI.request('GET', 'me') + logger.debug(f'get_me {response.status_code}') + + if response.status_code == 401: + raise Exception('Invalid Access Token') + else: + return response.json() + + def get_accounts(self) -> list: + """ + Get Wealthsimple Trade Accounts + """ + + response = self.WSAPI.request('GET', 'account/list') + response = response.json() + return response['results'] + + def get_security(self, id: str) -> dict: + """ + Get Security Info + """ + + logger.debug('get_security') + response = self.WSAPI.request('GET', f'securities/{id}') + logger.debug(f"get_security {response.status_code}") + logger.debug(f"get_security {response}") + + if response.status_code == 401: raise Exception(f'Cannot get security {id}') - else: - return response.json() + else: + return response.json() + + def refresh_tokens(self): + """ + Genereates new tokens + """ + + logger.debug('Refresh tokens') + response = self.WSAPI.request('POST', 'auth/refresh') + + if response.status_code == 401: + logger.error('Current refresh token is expired') + raise Exception('Refresh token is expired') + else: + self.__update_tokens(response) + + def __update_tokens(self, response): + """Update session tokens""" + self.session.headers.update( + {"Authorization": response.headers["X-Access-Token"]} + ) + self.session.headers.update( + {"refresh_token": response.headers["X-Refresh-Token"]} + ) diff --git a/stockai/yahoo/base.py b/stockai/yahoo/base.py index b22205b..a5b5588 100644 --- a/stockai/yahoo/base.py +++ b/stockai/yahoo/base.py @@ -5,7 +5,7 @@ class Base(object): def __init__(self, symbol): self.symbol = symbol - def _prepare_request(self, region='US', lang='en-US', includePrePost='false', interval='2m', range='1d'): + def __prepare_request(self, region='US', lang='en-US', includePrePost='false', interval='2m', range='1d'): """ Basic Yahoo Rquest URL """ @@ -19,8 +19,8 @@ def _prepare_request(self, region='US', lang='en-US', includePrePost='false', in ) return url - def _request(self): - url = self._prepare_request() + def __request(self): + url = self.__prepare_request() data = get(url) if data.json()['quoteSummary']['error'] is not None: @@ -50,7 +50,7 @@ def refresh(self): """ Refresh stock data """ - self.data_set = self._request() + self.data_set = self.__request() def __process_historical_result(self, data): """