Skip to content
This repository has been archived by the owner on Aug 27, 2024. It is now read-only.

Commit

Permalink
add interactive plots
Browse files Browse the repository at this point in the history
  • Loading branch information
CelestialCrafter committed Dec 11, 2023
1 parent 33e1790 commit 93da258
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions views/plot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import io
import io, price
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.dates as md
import numpy as np
import plots.colors as colors
from mpld3 import fig_to_html
from flask import request
from importlib import import_module
from price import get_cached_prices, get_prices, is_cached_interval, is_supported_interval, get_default_interval
from utils import get_algorithms

mpl.use('Agg')
Expand All @@ -16,13 +16,14 @@

def plot(algorithm):
interval = int(request.args.get('interval'))
interactive = bool(request.args.get('interactive'))

if interval and is_cached_interval(interval):
prices, timestamps, _ = get_cached_prices(interval=interval)
elif interval and is_supported_interval(interval):
prices, timestamps, _ = get_prices(interval=interval)
if interval and price.is_cached_interval(interval):
prices, timestamps, _ = price.get_cached_prices(interval=interval)
elif interval and price.is_supported_interval(interval):
prices, timestamps, _ = price.get_prices(interval=interval)
elif not interval:
prices, timestamps, _ = get_cached_prices()
prices, timestamps, _ = price.get_cached_prices()
else:
return 'Unsupported Interval', 400

Expand All @@ -31,15 +32,18 @@ def plot(algorithm):

# Even out timestamps so plotting algos works
timestamps = timestamps.astype('datetime64[s]')
interval_timedelta = np.timedelta64(get_default_interval(), 'm')
interval_timedelta = np.timedelta64(price.get_default_interval(), 'm')
timestamps = np.arange(timestamps[-1] - interval_timedelta * timestamps.shape[0], timestamps[-1], interval_timedelta)

figure = plt.figure()

try:
import_module(f'algorithms.{algorithm}').plot(prices, timestamps)
except Exception as error:
return str(error), 400

axes = plt.gcf().get_axes()
axes = figure.get_axes()

for axis in axes:
axis.tick_params(color=colors.outline(), labelcolor=colors.outline())
for spine in axis.spines.values():
Expand All @@ -59,11 +63,15 @@ def plot(algorithm):

plt.tight_layout()

# Save plot into buffer instead of the FS
svg_buffer = io.StringIO()
plt.savefig(svg_buffer, format='svg', transparent=True)
svg_plot = svg_buffer.getvalue()
plt.close() # Solved plots overwriting each other
svg_buffer.close()
if interactive:
# @TODO change d3 and mpld3 urls to local ones
plot_data = fig_to_html(figure)
else:
# Save plot into buffer instead of the FS
svg_buffer = io.StringIO()
plt.savefig(svg_buffer, format='svg', transparent=True)
plot_data = svg_buffer.getvalue()
plt.close() # Solved plots overwriting each other
svg_buffer.close()

return svg_plot
return plot_data

0 comments on commit 93da258

Please sign in to comment.