Skip to content
This repository has been archived by the owner on Aug 27, 2024. It is now read-only.

Commit

Permalink
functions clarification
Browse files Browse the repository at this point in the history
  • Loading branch information
fou3fou3 committed Jan 27, 2024
1 parent 60e1777 commit 8c4caf5
Show file tree
Hide file tree
Showing 12 changed files with 29 additions and 34 deletions.
6 changes: 3 additions & 3 deletions src/algorithms/bollinger_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ class Algorithm:
def __init__(self, window_size=20, standard_deviations=2):
self.window_size, self.standard_deviations = window_size, standard_deviations

def algorithm(self, prices):
def algorithm(self, prices: list[float]) -> tuple[float]:
return BBANDS(prices,
timeperiod=self.window_size,
nbdevup=self.standard_deviations,
nbdevdn=self.standard_deviations)

def signal(self, prices, data):
def signal(self, prices: list[float], data: tuple[float]):
upper_bands, _, lower_bands = data
if prices[-1] > upper_bands[-1]:
return 'sell', 1
Expand All @@ -22,7 +22,7 @@ def signal(self, prices, data):

return 'no_action', 0

def plot(self, prices, timestamps, **kwargs):
def plot(self, prices: list[float], timestamps: list[float], **kwargs):
upper_bands, middle_bands, lower_bands = self.algorithm(prices, **kwargs)

plt.fill_between(timestamps, upper_bands, lower_bands, color='grey', alpha=0.3)
Expand Down
6 changes: 3 additions & 3 deletions src/algorithms/custom_bollinger_rsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ def __init__(self, rsi_window_size=13, bollinger_bands_window_size=20, rsi_high=
self.bollinger_bands_window_size = bollinger_bands_window_size
self.window_size = max(rsi_window_size, bollinger_bands_window_size)

def algorithm(self, prices):
def algorithm(self, prices: list[float]):
Bollinger_Bands = BollingerBands(window_size=self.bollinger_bands_window_size)
bb_data = Bollinger_Bands.algorithm(prices)
rsi = RSI(window_size=self.rsi_window_size, high=self.rsi_high, low=self.rsi_low)
rsi_line = rsi.algorithm(prices)

return [*bb_data, rsi_line]

def signal(self, prices, data):
def signal(self, prices: list[float], data: list[list[float]]):
price = prices[-1]

upper_bands, _, lower_bands, rsi_line = data #middle bands not needed && corrected bollinger bands from upper, lowe, middle to current
Expand All @@ -33,7 +33,7 @@ def signal(self, prices, data):

return 'no_action', 0

def plot(self, prices, timestamps, **kwargs):
def plot(self, prices: list[float], timestamps: list[float], **kwargs):
gs = GridSpec(3, 1, figure=plt.gcf())

plt.subplot(gs[0, :])
Expand Down
6 changes: 3 additions & 3 deletions src/algorithms/macd.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ class Algorithm:
def __init__(self, fastperiod=12, slowperiod=26, signalperiod=9):
self.fastperiod, self.slowperiod, self.signalperiod = fastperiod, slowperiod, signalperiod

def algorithm(self, prices):
def algorithm(self, prices: list[float]) -> tuple[float]:
return MACD(prices, fastperiod=self.fastperiod, slowperiod=self.slowperiod, signalperiod=self.signalperiod)

def signal(self, _, data):
def signal(self, _, data: tuple[float]):
macds, signals, histogram = data
positive_histogram = np.abs(histogram)
histogram_max = np.max(np.nan_to_num(positive_histogram))
Expand All @@ -22,7 +22,7 @@ def signal(self, _, data):
return 'sell', 1
return 'no_action', 0

def plot(self, prices, timestamps, **kwargs):
def plot(self, prices: list[float], timestamps: list[float], **kwargs):
macd, signal, histogram = self.algorithm(prices, **kwargs)

buy_condition = np.insert((macd[1:] > signal[1:]) & (macd[:-1] < signal[:-1]), 0, False)
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/price.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ class Algorithm:
def __init__(self) -> None:
pass

def plot(self, prices, timestamps):
def plot(self, prices: list[float], timestamps: list[float]):
plt.plot(timestamps, prices, color=colors.primary())
6 changes: 3 additions & 3 deletions src/algorithms/rsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def __init__(self, window_size=14, high=70, low=30):
self.high = high
self.low = low

def algorithm(self, prices):
def algorithm(self, prices: list[float]) -> list[float]:
return RSI(prices, timeperiod=self.window_size)

def signal(self, _, data):
def signal(self, _, data: list[float]):
rsi = data

if rsi[-1] > self.high:
Expand All @@ -26,7 +26,7 @@ def signal(self, _, data):

return 'no_action', 0

def plot(self, prices, timestamps, custom_algorithm_plot=False, **kwargs):
def plot(self, prices: list[float], timestamps: list[float], custom_algorithm_plot=False, **kwargs):
rsi_line = self.algorithm(prices, **kwargs)
# Thresholds
upper = np.full(rsi_line.shape, self.high)
Expand Down
4 changes: 2 additions & 2 deletions src/backtest_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from plots import colors
import io

def backtest(algorithm, prices, timestamps, balance=200, strength_to_usd=200, plot=False):
def backtest(algorithm:str, prices:list[float], timestamps:list[float], balance=200, strength_to_usd=200, plot=False) -> dict:
transactions = []
start_balance = balance
shares = 0
Expand Down Expand Up @@ -52,7 +52,7 @@ def backtest(algorithm, prices, timestamps, balance=200, strength_to_usd=200, pl
'profit_percentage %': ((balance + shares * price) - start_balance) / start_balance
}

def plot(back_test_data):
def plot(back_test_data:dict):
gs = GridSpec(3, 1, figure=plt.gcf())

timestamps = [transaction['timestamp'] for transaction in back_test_data['transactions']]
Expand Down
2 changes: 1 addition & 1 deletion src/plots/worth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from datetime import datetime as dt

def plot(values, timestamps):
def plot(values: list[float], timestamps: list[float]):
values = np.array(values)
timestamps = [dt.fromtimestamp(timestamp) for timestamp in timestamps]
plt.plot(timestamps, values, color=colors.sell())
Expand Down
10 changes: 3 additions & 7 deletions src/price.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
import os
import numpy as np
from requests import get
from redis import from_url
from dotenv import load_dotenv
from flask import has_request_context, request

load_dotenv()
redis = from_url(os.environ['REDIS_URI'])

# 1 5 15 30 60 240 1440 10080 21600
# 12h 2d12h 1w12h 2w1d 1mo 4mo 2y 14y 30y

point_count = 720
default_interval = 240
price_api_interval = 5
supported_intervals = [5, 15, 30, 60, 240, 1440, 10080]

def is_supported_interval(interval):
def is_supported_interval(interval: int) -> bool:
global supported_intervals
return interval in supported_intervals

def set_default_interval(interval):
def set_default_interval(interval: int) -> int:
global default_interval
if not is_supported_interval(interval):
raise Exception('Unsupported Interval')
Expand All @@ -29,7 +25,7 @@ def set_default_interval(interval):
return default_interval

# Get the default interval
def get_default_interval():
def get_default_interval() -> int:
if has_request_context():
interval = request.args.get('interval')
if interval and is_supported_interval(int(interval)):
Expand Down
15 changes: 7 additions & 8 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import jwt, os, io, numpy as np
import jwt, os, io
import matplotlib.pyplot as plt
from pymongo.server_api import ServerApi
from datetime import datetime
from pymongo import MongoClient
from dotenv import load_dotenv
from importlib import import_module
Expand All @@ -11,36 +10,36 @@
client = MongoClient(os.environ['DB_URI'], server_api=ServerApi('1'))
algorithms = client['database']['algorithms']

def get_algorithms():
def get_algorithms() -> list[str]:
return [algorithm['name'] for algorithm in algorithms.find({'owner': {'$not': {'$type': 'object'}}})]

def authorize(encoded):
def authorize(encoded: str) -> str:
if encoded.startswith('Bearer'):
encoded = encoded[7:]

return jwt.decode(encoded, os.environ['JWT_SECRET'], algorithms=['HS256'])

def authorize_server(encoded):
def authorize_server(encoded: str) -> str:
decoded = authorize(encoded)

if not decoded['server']:
raise Exception('Client Token')

return decoded

def algorithm_output(algorithm_name, prices, backtest=False):
def algorithm_output(algorithm_name: str, prices: list[float], backtest=False) -> tuple[str, tuple[str, float]]:
module = import_module(f'algorithms.{algorithm_name}').Algorithm()
signal, strength = module.signal(prices, module.algorithm(prices))
if backtest:
return signal, strength

return algorithm_name, (signal, strength)

def svg_plot():
def svg_plot() -> str:
svg_buffer = io.StringIO()
plt.savefig(svg_buffer, format='svg', transparent=True)
plot_data = svg_buffer.getvalue()
plt.close()
svg_buffer.close()

return plot_data
return plot_data
2 changes: 1 addition & 1 deletion src/views/backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

mpl.use('Agg')

def backtest_view(algorithm_name):
def backtest_view(algorithm_name: str):
interval = int(request.args.get('interval') or get_default_interval())
plot_bool = bool(request.args.get('plot') or False)

Expand Down
2 changes: 1 addition & 1 deletion src/views/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
figure_size = mpl.rcParams['figure.figsize']
figure_size[0] = figure_size[0] * 1.5

def plot(algorithm_name):
def plot(algorithm_name: str):
interval = int(request.args.get('interval') or get_default_interval())
interactive = bool(request.args.get('interactive') or False)

Expand Down
2 changes: 1 addition & 1 deletion src/views/worth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
mpl.use('Agg')
bots = utils.client['database']['bots']

def worth(bot_id):
def worth(bot_id:str):
try:
bot = bots.find_one({'_id': ObjectId(bot_id)})
decoded = utils.authorize(request.headers.get('Authorization'))
Expand Down

0 comments on commit 8c4caf5

Please sign in to comment.