diff --git a/bar_chart_race/_bar_chart_race.py b/bar_chart_race/_bar_chart_race.py
index 8b31d34..91453a1 100644
--- a/bar_chart_race/_bar_chart_race.py
+++ b/bar_chart_race/_bar_chart_race.py
@@ -10,13 +10,14 @@
from ._common_chart import CommonChart
from ._utils import prepare_wide_data
+
class _BarChartRace(CommonChart):
-
- def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_max,
- steps_per_period, period_length, end_period_pause, interpolate_period,
- period_label, period_template, period_summary_func, perpendicular_bar_func,
- colors, title, bar_size, bar_textposition, bar_texttemplate, bar_label_font,
- tick_label_font, tick_template, shared_fontdict, scale, fig, writer,
+
+ def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_max, fixed_min,
+ steps_per_period, period_length, end_period_pause, interpolate_period,
+ period_label, period_template, period_summary_func, perpendicular_bar_func,
+ colors, title, bar_size, bar_textposition, bar_texttemplate, bar_label_font,
+ tick_label_font, tick_template, shared_fontdict, scale, fig, writer,
bar_kwargs, fig_kwargs, filter_column_colors):
self.filename = filename
self.extension = self.get_extension()
@@ -25,6 +26,7 @@ def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_m
self.n_bars = n_bars or df.shape[1]
self.fixed_order = fixed_order
self.fixed_max = fixed_max
+ self.fixed_min = fixed_min
self.steps_per_period = steps_per_period
self.period_length = period_length
self.end_period_pause = end_period_pause
@@ -64,7 +66,7 @@ def validate_params(self):
raise ValueError('`filename` must have an extension')
elif self.filename is not None:
raise TypeError('`filename` must be None or a string')
-
+
if self.sort not in ('asc', 'desc'):
raise ValueError('`sort` must be "asc" or "desc"')
@@ -143,15 +145,17 @@ def get_font(self, font, ticks=False):
return font
def prepare_data(self, df):
+
if self.fixed_order is True:
last_values = df.iloc[-1].sort_values(ascending=False)
cols = last_values.iloc[:self.n_bars].index
df = df[cols]
elif isinstance(self.fixed_order, list):
cols = self.fixed_order
+ print(cols)
df = df[cols]
self.n_bars = min(len(cols), self.n_bars)
-
+
compute_ranks = self.fixed_order is False
dfs = prepare_wide_data(df, self.orientation, self.sort, self.n_bars,
self.interpolate_period, self.steps_per_period, compute_ranks)
@@ -165,14 +169,14 @@ def prepare_data(self, df):
m = df_values.shape[0]
rank_row = np.arange(1, n)
if (self.sort == 'desc' and self.orientation == 'h') or \
- (self.sort == 'asc' and self.orientation == 'v'):
+ (self.sort == 'asc' and self.orientation == 'v'):
rank_row = rank_row[::-1]
-
+
ranks_arr = np.repeat(rank_row.reshape(1, -1), m, axis=0)
df_ranks = pd.DataFrame(data=ranks_arr, columns=cols)
return df_values, df_ranks
-
+
def get_col_filt(self):
col_filt = pd.Series([True] * self.df_values.shape[1])
if self.n_bars < self.df_ranks.shape[1]:
@@ -188,7 +192,7 @@ def get_col_filt(self):
self.df_values = self.df_values.loc[:, col_filt]
self.df_ranks = self.df_ranks.loc[:, col_filt]
return col_filt
-
+
def get_bar_colors(self, colors):
if colors is None:
colors = 'dark12'
@@ -197,7 +201,7 @@ def get_bar_colors(self, colors):
if isinstance(colors, str):
from ._colormaps import colormaps
-
+
try:
bar_colors = colormaps[colors.lower()]
except KeyError:
@@ -230,7 +234,7 @@ def get_bar_colors(self, colors):
exp_ct = np.bincount(np.arange(num_cols) % n, minlength=n)
if (col_idx_ct > exp_ct).any():
warnings.warn("Some of your columns never make an appearance in the animation. "
- "To reduce color repetition, set `filter_column_colors` to `True`")
+ "To reduce color repetition, set `filter_column_colors` to `True`")
return bar_colors
def get_max_plotted_value(self):
@@ -271,16 +275,16 @@ def get_subplots_adjust(self):
plot_func = ax.barh if self.orientation == 'h' else ax.bar
bar_location, bar_length, cols, _ = self.get_bar_info(-1)
plot_func(bar_location, bar_length, tick_label=cols)
-
+
self.prepare_axes(ax)
texts = self.add_bar_labels(ax, bar_location, bar_length)
fig.canvas.print_figure(io.BytesIO(), format='png')
- xmin = min(label.get_window_extent().x0 for label in ax.get_yticklabels())
+ xmin = min(label.get_window_extent().x0 for label in ax.get_yticklabels())
xmin /= (fig.dpi * fig.get_figwidth())
left = ax.get_position().x0 - xmin + .01
- ymin = min(label.get_window_extent().y0 for label in ax.get_xticklabels())
+ ymin = min(label.get_window_extent().y0 for label in ax.get_xticklabels())
ymin /= (fig.dpi * fig.get_figheight())
bottom = ax.get_position().y0 - ymin + .01
@@ -298,7 +302,7 @@ def get_subplots_adjust(self):
else:
max_bar_pixels = ax.transData.transform((0, max_bar))[1]
max_text = max(text.get_window_extent().y1 for text in texts)
-
+
self.extra_pixels = max_text - max_bar_pixels + 10
if self.fixed_max:
@@ -348,7 +352,7 @@ def set_major_formatter(self, ax):
def plot_bars(self, ax, i):
bar_location, bar_length, cols, colors = self.get_bar_info(i)
if self.orientation == 'h':
- ax.barh(bar_location, bar_length, tick_label=cols,
+ ax.barh(bar_location, bar_length, tick_label=cols,
color=colors, **self.bar_kwargs)
ax.set_yticklabels(ax.get_yticklabels(), **self.tick_label_font)
if not self.fixed_max and self.bar_textposition == 'outside':
@@ -357,7 +361,7 @@ def plot_bars(self, ax, i):
new_xmax = ax.transData.inverted().transform((new_max_pixels, 0))[0]
ax.set_xlim(ax.get_xlim()[0], new_xmax)
else:
- ax.bar(bar_location, bar_length, tick_label=cols,
+ ax.bar(bar_location, bar_length, tick_label=cols,
color=colors, **self.bar_kwargs)
ax.set_xticklabels(ax.get_xticklabels(), **self.tick_label_font)
if not self.fixed_max and self.bar_textposition == 'outside':
@@ -397,7 +401,7 @@ def add_period_summary(self, ax, i):
if 'x' not in text_dict or 'y' not in text_dict or 's' not in text_dict:
name = self.period_summary_func.__name__
raise ValueError(f'The dictionary returned from `{name}` must contain '
- '"x", "y", and "s"')
+ '"x", "y", and "s"')
ax.text(transform=ax.transAxes, **text_dict)
def add_bar_labels(self, ax, bar_location, bar_length):
@@ -450,7 +454,7 @@ def add_perpendicular_bar(self, ax, bar_length, i):
line.set_xdata([val] * 2)
else:
line.set_ydata([val] * 2)
-
+
def anim_func(self, i):
if i is None:
return
@@ -461,7 +465,7 @@ def anim_func(self, i):
for text in ax.texts[start:]:
text.remove()
self.plot_bars(ax, i)
-
+
def make_animation(self):
def init_func():
ax = self.fig.axes[0]
@@ -478,7 +482,7 @@ def frame_generator(n):
for _ in range(pause):
frames.append(None)
return frames
-
+
frames = frame_generator(len(self.df_values))
anim = FuncAnimation(self.fig, self.anim_func, frames, init_func, interval=interval)
@@ -498,8 +502,8 @@ def frame_generator(n):
fc = self.fig.get_facecolor()
if fc == (1, 1, 1, 0):
fc = 'white'
- ret_val = anim.save(self.filename, fps=self.fps, writer=self.writer,
- savefig_kwargs=savefig_kwargs)
+ ret_val = anim.save(self.filename, fps=self.fps, writer=self.writer,
+ savefig_kwargs=savefig_kwargs)
except Exception as e:
message = str(e)
raise Exception(message)
@@ -509,15 +513,15 @@ def frame_generator(n):
return ret_val
-def bar_chart_race(df, filename=None, orientation='h', sort='desc', n_bars=None,
- fixed_order=False, fixed_max=False, steps_per_period=10,
- period_length=500, end_period_pause=0, interpolate_period=False,
+def bar_chart_race(df, filename=None, orientation='h', sort='desc', n_bars=None,
+ fixed_order=False, fixed_max=False, steps_per_period=10,
+ period_length=500, end_period_pause=0, interpolate_period=False,
period_label=True, period_template=None, period_summary_func=None,
perpendicular_bar_func=None, colors=None, title=None, bar_size=.95,
bar_textposition='outside', bar_texttemplate='{x:,.0f}',
bar_label_font=None, tick_label_font=None, tick_template='{x:,.0f}',
- shared_fontdict=None, scale='linear', fig=None, writer=None,
- bar_kwargs=None, fig_kwargs=None, filter_column_colors=False):
+ shared_fontdict=None, scale='linear', fig=None, writer=None,
+ bar_kwargs=None, fig_kwargs=None, filter_column_colors=False):
'''
Create an animated bar chart race using matplotlib. Data must be in
'wide' format where each row represents a single time period and each
@@ -868,9 +872,9 @@ def func(val):
These sizes are relative to plt.rcParams['font.size'].
'''
bcr = _BarChartRace(df, filename, orientation, sort, n_bars, fixed_order, fixed_max,
- steps_per_period, period_length, end_period_pause, interpolate_period,
+ steps_per_period, period_length, end_period_pause, interpolate_period,
period_label, period_template, period_summary_func, perpendicular_bar_func,
- colors, title, bar_size, bar_textposition, bar_texttemplate,
- bar_label_font, tick_label_font, tick_template, shared_fontdict, scale,
+ colors, title, bar_size, bar_textposition, bar_texttemplate,
+ bar_label_font, tick_label_font, tick_template, shared_fontdict, scale,
fig, writer, bar_kwargs, fig_kwargs, filter_column_colors)
return bcr.make_animation()
diff --git a/bar_chart_race/_colormaps.py b/bar_chart_race/_colormaps.py
index 59dc289..2d11bdd 100644
--- a/bar_chart_race/_colormaps.py
+++ b/bar_chart_race/_colormaps.py
@@ -39110,5 +39110,10 @@
"#311339",
"#301338",
"#301437"
- ]
+ ],
+ "rgby": ['red',
+ 'green',
+ 'blue',
+ 'yellow'
+ ]
}
\ No newline at end of file
diff --git a/bar_chart_race/_utils.py b/bar_chart_race/_utils.py
index 89de10d..368b621 100644
--- a/bar_chart_race/_utils.py
+++ b/bar_chart_race/_utils.py
@@ -4,7 +4,7 @@
from matplotlib import image as mimage
-def load_dataset(name='covid19'):
+def load_dataset(name='covid19', threshold=0):
'''
Return a pandas DataFrame suitable for immediate use in `bar_chart_race`.
Must be connected to the internet
@@ -18,6 +18,8 @@ def load_dataset(name='covid19'):
* 'covid19_tutorial'
* 'urban_pop'
* 'baseball'
+ threshold : int, default 0
+ Lowest value that will be shown on the bar_chart_race
Returns
-------
@@ -31,9 +33,24 @@ def load_dataset(name='covid19'):
'baseball': None}
index_col = index_dict[name]
parse_dates = [index_col] if index_col else None
- return pd.read_csv(url, index_col=index_col, parse_dates=parse_dates)
+ df = pd.read_csv(url, index_col=index_col, parse_dates=parse_dates)
+ new_df = filter_threshold(df, threshold)
-def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate_period=False,
+ return new_df
+
+
+def load_custom_data(filename, index_col, parse_dates):
+ df = pd.read_csv(filename, index_col=index_col, parse_dates=parse_dates)
+ df.fillna(0, inplace=True)
+ return df
+
+
+
+def filter_threshold(df, thresh):
+ return df.loc[(df.hr > thresh)]
+
+
+def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate_period=False,
steps_per_period=10, compute_ranks=True):
'''
Prepares 'wide' data for bar chart animation.
@@ -109,21 +126,22 @@ def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate
df_values.iloc[:, 0] = df_values.iloc[:, 0].interpolate()
else:
df_values.iloc[:, 0] = df_values.iloc[:, 0].fillna(method='ffill')
-
+
df_values = df_values.set_index(df_values.columns[0])
if compute_ranks:
df_ranks = df_values.rank(axis=1, method='first', ascending=False).clip(upper=n_bars + 1)
if (sort == 'desc' and orientation == 'h') or (sort == 'asc' and orientation == 'v'):
df_ranks = n_bars + 1 - df_ranks
df_ranks = df_ranks.interpolate()
-
+
df_values = df_values.interpolate()
if compute_ranks:
return df_values, df_ranks
return df_values
-def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h',
- sort='desc', n_bars=None, interpolate_period=False,
+
+def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h',
+ sort='desc', n_bars=None, interpolate_period=False,
steps_per_period=10, compute_ranks=True):
'''
Prepares 'long' data for bar chart animation.
@@ -201,7 +219,7 @@ def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h'
df_values, df_ranks = bcr.prepare_long_data(df)
bcr.bar_chart_race(df_values, steps_per_period=1, period_length=50)
'''
- df_wide = df.pivot_table(index=index, columns=columns, values=values,
+ df_wide = df.pivot_table(index=index, columns=columns, values=values,
aggfunc=aggfunc).fillna(method='ffill')
return prepare_wide_data(df_wide, orientation, sort, n_bars, interpolate_period,
steps_per_period, compute_ranks)
@@ -222,4 +240,4 @@ def read_images(filename, columns):
else:
final_url = url_path.format(code=code)
image_dict[col] = mimage.imread(final_url)
- return image_dict
\ No newline at end of file
+ return image_dict
diff --git a/bar_chart_race/demo.ipynb b/bar_chart_race/demo.ipynb
new file mode 100644
index 0000000..f717bce
--- /dev/null
+++ b/bar_chart_race/demo.ipynb
@@ -0,0 +1,6391 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ ""
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import bar_chart_race as bcr\n",
+ "df_baseball = bcr.load_dataset('baseball',50).pivot(index='year',\n",
+ " columns='name',\n",
+ " values='hr')\n",
+ "#bcr.bar_chart_race(df=df_baseball)\n",
+ "path = r\"C:\\Users\\rcame\\PycharmProjects\\bar_chart_race_\\data\\weather_data.csv\"\n",
+ "df_weather = bcr._utils.load_custom_data(path,'timestamp', None)\n",
+ "bcr.bar_chart_race(df_weather, colors='rgby')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
\ No newline at end of file
diff --git a/data/weather_data.csv b/data/weather_data.csv
new file mode 100644
index 0000000..e8f2a83
--- /dev/null
+++ b/data/weather_data.csv
@@ -0,0 +1,101 @@
+timestamp,Boston,Fairbanks,Yakutsk
+20210318T0000,40.345,11.120939,11.300156
+20210318T0100,39.751,11.246939,10.814154
+20210318T0200,39.463,11.4989395,9.338156
+20210318T0300,39.391003,12.05694,8.744156
+20210318T0400,39.535,2.8769379,9.428156
+20210318T0500,39.697002,2.7329369,10.868155
+20210318T0600,39.769,3.2009392,11.570154
+20210318T0700,39.715,4.33494,9.122154
+20210318T0800,37.141003,5.12694,2.0301552
+20210318T0900,40.183002,6.386938,-2.6678429
+20210318T1000,42.487,7.808939,-2.8838425
+20210318T1100,44.467003,9.266939,-1.1378441
+20210318T1200,45.907,10.70694,0.4281578
+20210318T1300,47.401,11.912939,0.73415756
+20210318T1400,48.355003,12.86694,0.4281578
+20210318T1500,48.049004,13.83894,-0.54384613
+20210318T1600,47.347,14.666939,-1.7498436
+20210318T1700,46.645,15.11694,-3.3338432
+20210318T1800,45.853,15.638939,-7.0958443
+20210318T1900,44.863,15.692938,-13.737843
+20210318T2000,44.287003,15.296938,-17.607841
+20210318T2100,43.495003,12.632938,-13.3958435
+20210318T2200,42.631,12.200939,-16.077843
+20210318T2300,40.921,12.05694,-17.949844
+20210319T0000,34.495003,11.786938,-18.813843
+20210319T0100,31.399002,11.174938,-19.245842
+20210319T0200,27.763,10.202938,-19.31784
+20210319T0300,25.333002,9.284941,-18.813843
+20210319T0400,22.813002,7.3049393,-17.98584
+20210319T0500,22.237001,5.558939,-17.031845
+20210319T0600,23.029001,4.33494,-16.79784
+20210319T0700,22.723001,3.452938,-14.295841
+20210319T0800,25.711002,2.2649364,-7.113842
+20210319T0900,26.755001,3.794939,4.2261543
+20210319T1000,28.015001,7.75494,12.740156
+20210319T1100,29.815,8.94294,18.644154
+20210319T1200,31.723001,10.202938,22.460155
+20210319T1300,33.901,12.146938,24.314156
+20210319T1400,35.593002,13.946938,25.412155
+20210319T1500,36.709,14.792938,26.042154
+20210319T1600,37.087,14.252939,26.114155
+20210319T1700,36.781002,13.100939,25.862156
+20210319T1800,35.827,12.488939,25.268154
+20210319T1900,31.921001,12.05694,23.864155
+20210319T2000,30.769001,11.102938,22.874153
+20210319T2100,29.977001,10.292938,14.774155
+20210319T2200,29.527,4.3709393,17.096155
+20210319T2300,29.329002,0.8609352,19.382154
+20210320T0000,29.311,-2.2530632,21.344154
+20210320T0100,29.329002,-3.8910637,22.100155
+20210320T0200,29.221,-6.267063,22.388153
+20210320T0300,28.879002,-9.021065,22.280155
+20210320T0400,28.483002,-18.993061,22.226154
+20210320T0500,28.069002,-21.49506,22.064156
+20210320T0600,27.673,-23.331062,21.992155
+20210320T0700,27.529001,-24.789062,22.568153
+20210320T0800,31.417002,-25.941063,24.962154
+20210320T0900,36.637,-21.657059,26.294155
+20210320T1000,41.335,-11.883064,27.410154
+20210320T1100,44.989002,-1.8390656,27.752155
+20210320T1200,47.833,4.676939,27.842155
+20210320T1300,50.191,10.058939,27.626156
+20210320T1400,51.847,13.6049385,27.788155
+20210320T1500,52.621002,16.10694,26.870155
+20210320T1600,52.819,17.438938,25.952154
+20210320T1700,52.117,17.888939,24.998154
+20210320T1800,50.227,18.26694,23.810154
+20210320T1900,46.771,17.618938,21.776154
+20210320T2000,45.763,9.62694,17.852154
+20210320T2100,44.809002,0.1409359,15.008154
+20210320T2200,43.135002,-6.7350616,11.084156
+20210320T2300,41.443,-11.559063,7.466154
+20210321T0000,39.787003,-14.439064,4.532154
+20210321T0100,38.419003,-15.825062,3.2721558
+20210321T0200,37.591,-16.887062,3.4341545
+20210321T0300,36.979,-17.805061,4.082155
+20210321T0400,36.457,-20.847061,9.050156
+20210321T0500,35.989002,-21.315063,7.6461563
+20210321T0600,35.701,-21.819061,3.9741554
+20210321T0700,35.557003,-22.323063,8.438154
+20210321T0800,37.357002,-22.575062,12.866156
+20210321T0900,42.595,-19.371063,17.906155
+20210321T1000,47.023,-10.065063,21.416155
+20210321T1100,50.947,-3.09906,23.918156
+20210321T1200,53.611,1.940937,25.268154
+20210321T1300,55.465,6.008938,23.792156
+20210321T1400,56.977,9.014938,17.042154
+20210321T1500,57.967003,11.408939,8.042154
+20210321T1600,58.273003,13.316938,4.244156
+20210321T1700,57.805,14.324938,3.3801537
+20210321T1800,56.347,14.882938,-0.11184311
+20210321T1900,53.143,14.81094,-7.31184
+20210321T2000,50.839,13.946938,-11.793842
+20210321T2100,48.427002,12.974937,-13.773842
+20210321T2200,46.177002,7.232939,-5.439842
+20210321T2300,44.467003,5.252939,-8.247841
+20210322T0000,42.811,2.678936,-17.571842
+20210322T0100,41.263,-0.2910614,-20.937843
+20210322T0200,39.805,-3.2610626,-21.783844
+20210322T0300,38.725002,-6.087063,-28.71384
diff --git a/tests/test_bar_charts.py b/tests/test_bar_charts.py
index 574fae3..3067a5e 100644
--- a/tests/test_bar_charts.py
+++ b/tests/test_bar_charts.py
@@ -106,4 +106,6 @@ def test_fig(self):
def test_bar_kwargs(self):
bar_chart_race(df, n_bars=6, bar_kwargs={'alpha': .2, 'ec': 'black', 'lw': 3})
-
\ No newline at end of file
+
+ def test_threshold(self):
+ bar_chart_race(df)
\ No newline at end of file
diff --git a/tests/test_util.py b/tests/test_util.py
new file mode 100644
index 0000000..c95e778
--- /dev/null
+++ b/tests/test_util.py
@@ -0,0 +1,22 @@
+import pytest
+import bar_chart_race._utils as utils
+from bar_chart_race import load_dataset, bar_chart_race
+
+df = load_dataset('baseball')
+df = df.iloc[-20:-16]
+
+
+def test_threshold():
+ filtered_df = utils.filter_threshold(df, 60)
+ assert len(filtered_df) == 1
+
+ filtered_df = utils.filter_threshold(df, 0)
+ assert len(filtered_df) == 4
+
+ filtered_df = utils.filter_threshold(df, 50)
+ assert filtered_df.iloc[0]['hr'] > 50
+
+
+def test_custom_data():
+ path = r"C:\Users\rcame\PycharmProjects\bar_chart_race_\data\weather_data.csv"
+ utils.load_custom_data(path, 'timestamp', None)
\ No newline at end of file