diff --git a/bar_chart_race/_bar_chart_race.py b/bar_chart_race/_bar_chart_race.py index 8b31d34..91453a1 100644 --- a/bar_chart_race/_bar_chart_race.py +++ b/bar_chart_race/_bar_chart_race.py @@ -10,13 +10,14 @@ from ._common_chart import CommonChart from ._utils import prepare_wide_data + class _BarChartRace(CommonChart): - - def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_max, - steps_per_period, period_length, end_period_pause, interpolate_period, - period_label, period_template, period_summary_func, perpendicular_bar_func, - colors, title, bar_size, bar_textposition, bar_texttemplate, bar_label_font, - tick_label_font, tick_template, shared_fontdict, scale, fig, writer, + + def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_max, fixed_min, + steps_per_period, period_length, end_period_pause, interpolate_period, + period_label, period_template, period_summary_func, perpendicular_bar_func, + colors, title, bar_size, bar_textposition, bar_texttemplate, bar_label_font, + tick_label_font, tick_template, shared_fontdict, scale, fig, writer, bar_kwargs, fig_kwargs, filter_column_colors): self.filename = filename self.extension = self.get_extension() @@ -25,6 +26,7 @@ def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_m self.n_bars = n_bars or df.shape[1] self.fixed_order = fixed_order self.fixed_max = fixed_max + self.fixed_min = fixed_min self.steps_per_period = steps_per_period self.period_length = period_length self.end_period_pause = end_period_pause @@ -64,7 +66,7 @@ def validate_params(self): raise ValueError('`filename` must have an extension') elif self.filename is not None: raise TypeError('`filename` must be None or a string') - + if self.sort not in ('asc', 'desc'): raise ValueError('`sort` must be "asc" or "desc"') @@ -143,15 +145,17 @@ def get_font(self, font, ticks=False): return font def prepare_data(self, df): + if self.fixed_order is True: last_values = df.iloc[-1].sort_values(ascending=False) cols = last_values.iloc[:self.n_bars].index df = df[cols] elif isinstance(self.fixed_order, list): cols = self.fixed_order + print(cols) df = df[cols] self.n_bars = min(len(cols), self.n_bars) - + compute_ranks = self.fixed_order is False dfs = prepare_wide_data(df, self.orientation, self.sort, self.n_bars, self.interpolate_period, self.steps_per_period, compute_ranks) @@ -165,14 +169,14 @@ def prepare_data(self, df): m = df_values.shape[0] rank_row = np.arange(1, n) if (self.sort == 'desc' and self.orientation == 'h') or \ - (self.sort == 'asc' and self.orientation == 'v'): + (self.sort == 'asc' and self.orientation == 'v'): rank_row = rank_row[::-1] - + ranks_arr = np.repeat(rank_row.reshape(1, -1), m, axis=0) df_ranks = pd.DataFrame(data=ranks_arr, columns=cols) return df_values, df_ranks - + def get_col_filt(self): col_filt = pd.Series([True] * self.df_values.shape[1]) if self.n_bars < self.df_ranks.shape[1]: @@ -188,7 +192,7 @@ def get_col_filt(self): self.df_values = self.df_values.loc[:, col_filt] self.df_ranks = self.df_ranks.loc[:, col_filt] return col_filt - + def get_bar_colors(self, colors): if colors is None: colors = 'dark12' @@ -197,7 +201,7 @@ def get_bar_colors(self, colors): if isinstance(colors, str): from ._colormaps import colormaps - + try: bar_colors = colormaps[colors.lower()] except KeyError: @@ -230,7 +234,7 @@ def get_bar_colors(self, colors): exp_ct = np.bincount(np.arange(num_cols) % n, minlength=n) if (col_idx_ct > exp_ct).any(): warnings.warn("Some of your columns never make an appearance in the animation. " - "To reduce color repetition, set `filter_column_colors` to `True`") + "To reduce color repetition, set `filter_column_colors` to `True`") return bar_colors def get_max_plotted_value(self): @@ -271,16 +275,16 @@ def get_subplots_adjust(self): plot_func = ax.barh if self.orientation == 'h' else ax.bar bar_location, bar_length, cols, _ = self.get_bar_info(-1) plot_func(bar_location, bar_length, tick_label=cols) - + self.prepare_axes(ax) texts = self.add_bar_labels(ax, bar_location, bar_length) fig.canvas.print_figure(io.BytesIO(), format='png') - xmin = min(label.get_window_extent().x0 for label in ax.get_yticklabels()) + xmin = min(label.get_window_extent().x0 for label in ax.get_yticklabels()) xmin /= (fig.dpi * fig.get_figwidth()) left = ax.get_position().x0 - xmin + .01 - ymin = min(label.get_window_extent().y0 for label in ax.get_xticklabels()) + ymin = min(label.get_window_extent().y0 for label in ax.get_xticklabels()) ymin /= (fig.dpi * fig.get_figheight()) bottom = ax.get_position().y0 - ymin + .01 @@ -298,7 +302,7 @@ def get_subplots_adjust(self): else: max_bar_pixels = ax.transData.transform((0, max_bar))[1] max_text = max(text.get_window_extent().y1 for text in texts) - + self.extra_pixels = max_text - max_bar_pixels + 10 if self.fixed_max: @@ -348,7 +352,7 @@ def set_major_formatter(self, ax): def plot_bars(self, ax, i): bar_location, bar_length, cols, colors = self.get_bar_info(i) if self.orientation == 'h': - ax.barh(bar_location, bar_length, tick_label=cols, + ax.barh(bar_location, bar_length, tick_label=cols, color=colors, **self.bar_kwargs) ax.set_yticklabels(ax.get_yticklabels(), **self.tick_label_font) if not self.fixed_max and self.bar_textposition == 'outside': @@ -357,7 +361,7 @@ def plot_bars(self, ax, i): new_xmax = ax.transData.inverted().transform((new_max_pixels, 0))[0] ax.set_xlim(ax.get_xlim()[0], new_xmax) else: - ax.bar(bar_location, bar_length, tick_label=cols, + ax.bar(bar_location, bar_length, tick_label=cols, color=colors, **self.bar_kwargs) ax.set_xticklabels(ax.get_xticklabels(), **self.tick_label_font) if not self.fixed_max and self.bar_textposition == 'outside': @@ -397,7 +401,7 @@ def add_period_summary(self, ax, i): if 'x' not in text_dict or 'y' not in text_dict or 's' not in text_dict: name = self.period_summary_func.__name__ raise ValueError(f'The dictionary returned from `{name}` must contain ' - '"x", "y", and "s"') + '"x", "y", and "s"') ax.text(transform=ax.transAxes, **text_dict) def add_bar_labels(self, ax, bar_location, bar_length): @@ -450,7 +454,7 @@ def add_perpendicular_bar(self, ax, bar_length, i): line.set_xdata([val] * 2) else: line.set_ydata([val] * 2) - + def anim_func(self, i): if i is None: return @@ -461,7 +465,7 @@ def anim_func(self, i): for text in ax.texts[start:]: text.remove() self.plot_bars(ax, i) - + def make_animation(self): def init_func(): ax = self.fig.axes[0] @@ -478,7 +482,7 @@ def frame_generator(n): for _ in range(pause): frames.append(None) return frames - + frames = frame_generator(len(self.df_values)) anim = FuncAnimation(self.fig, self.anim_func, frames, init_func, interval=interval) @@ -498,8 +502,8 @@ def frame_generator(n): fc = self.fig.get_facecolor() if fc == (1, 1, 1, 0): fc = 'white' - ret_val = anim.save(self.filename, fps=self.fps, writer=self.writer, - savefig_kwargs=savefig_kwargs) + ret_val = anim.save(self.filename, fps=self.fps, writer=self.writer, + savefig_kwargs=savefig_kwargs) except Exception as e: message = str(e) raise Exception(message) @@ -509,15 +513,15 @@ def frame_generator(n): return ret_val -def bar_chart_race(df, filename=None, orientation='h', sort='desc', n_bars=None, - fixed_order=False, fixed_max=False, steps_per_period=10, - period_length=500, end_period_pause=0, interpolate_period=False, +def bar_chart_race(df, filename=None, orientation='h', sort='desc', n_bars=None, + fixed_order=False, fixed_max=False, steps_per_period=10, + period_length=500, end_period_pause=0, interpolate_period=False, period_label=True, period_template=None, period_summary_func=None, perpendicular_bar_func=None, colors=None, title=None, bar_size=.95, bar_textposition='outside', bar_texttemplate='{x:,.0f}', bar_label_font=None, tick_label_font=None, tick_template='{x:,.0f}', - shared_fontdict=None, scale='linear', fig=None, writer=None, - bar_kwargs=None, fig_kwargs=None, filter_column_colors=False): + shared_fontdict=None, scale='linear', fig=None, writer=None, + bar_kwargs=None, fig_kwargs=None, filter_column_colors=False): ''' Create an animated bar chart race using matplotlib. Data must be in 'wide' format where each row represents a single time period and each @@ -868,9 +872,9 @@ def func(val): These sizes are relative to plt.rcParams['font.size']. ''' bcr = _BarChartRace(df, filename, orientation, sort, n_bars, fixed_order, fixed_max, - steps_per_period, period_length, end_period_pause, interpolate_period, + steps_per_period, period_length, end_period_pause, interpolate_period, period_label, period_template, period_summary_func, perpendicular_bar_func, - colors, title, bar_size, bar_textposition, bar_texttemplate, - bar_label_font, tick_label_font, tick_template, shared_fontdict, scale, + colors, title, bar_size, bar_textposition, bar_texttemplate, + bar_label_font, tick_label_font, tick_template, shared_fontdict, scale, fig, writer, bar_kwargs, fig_kwargs, filter_column_colors) return bcr.make_animation() diff --git a/bar_chart_race/_colormaps.py b/bar_chart_race/_colormaps.py index 59dc289..2d11bdd 100644 --- a/bar_chart_race/_colormaps.py +++ b/bar_chart_race/_colormaps.py @@ -39110,5 +39110,10 @@ "#311339", "#301338", "#301437" - ] + ], + "rgby": ['red', + 'green', + 'blue', + 'yellow' + ] } \ No newline at end of file diff --git a/bar_chart_race/_utils.py b/bar_chart_race/_utils.py index 89de10d..368b621 100644 --- a/bar_chart_race/_utils.py +++ b/bar_chart_race/_utils.py @@ -4,7 +4,7 @@ from matplotlib import image as mimage -def load_dataset(name='covid19'): +def load_dataset(name='covid19', threshold=0): ''' Return a pandas DataFrame suitable for immediate use in `bar_chart_race`. Must be connected to the internet @@ -18,6 +18,8 @@ def load_dataset(name='covid19'): * 'covid19_tutorial' * 'urban_pop' * 'baseball' + threshold : int, default 0 + Lowest value that will be shown on the bar_chart_race Returns ------- @@ -31,9 +33,24 @@ def load_dataset(name='covid19'): 'baseball': None} index_col = index_dict[name] parse_dates = [index_col] if index_col else None - return pd.read_csv(url, index_col=index_col, parse_dates=parse_dates) + df = pd.read_csv(url, index_col=index_col, parse_dates=parse_dates) + new_df = filter_threshold(df, threshold) -def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate_period=False, + return new_df + + +def load_custom_data(filename, index_col, parse_dates): + df = pd.read_csv(filename, index_col=index_col, parse_dates=parse_dates) + df.fillna(0, inplace=True) + return df + + + +def filter_threshold(df, thresh): + return df.loc[(df.hr > thresh)] + + +def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate_period=False, steps_per_period=10, compute_ranks=True): ''' Prepares 'wide' data for bar chart animation. @@ -109,21 +126,22 @@ def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate df_values.iloc[:, 0] = df_values.iloc[:, 0].interpolate() else: df_values.iloc[:, 0] = df_values.iloc[:, 0].fillna(method='ffill') - + df_values = df_values.set_index(df_values.columns[0]) if compute_ranks: df_ranks = df_values.rank(axis=1, method='first', ascending=False).clip(upper=n_bars + 1) if (sort == 'desc' and orientation == 'h') or (sort == 'asc' and orientation == 'v'): df_ranks = n_bars + 1 - df_ranks df_ranks = df_ranks.interpolate() - + df_values = df_values.interpolate() if compute_ranks: return df_values, df_ranks return df_values -def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h', - sort='desc', n_bars=None, interpolate_period=False, + +def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h', + sort='desc', n_bars=None, interpolate_period=False, steps_per_period=10, compute_ranks=True): ''' Prepares 'long' data for bar chart animation. @@ -201,7 +219,7 @@ def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h' df_values, df_ranks = bcr.prepare_long_data(df) bcr.bar_chart_race(df_values, steps_per_period=1, period_length=50) ''' - df_wide = df.pivot_table(index=index, columns=columns, values=values, + df_wide = df.pivot_table(index=index, columns=columns, values=values, aggfunc=aggfunc).fillna(method='ffill') return prepare_wide_data(df_wide, orientation, sort, n_bars, interpolate_period, steps_per_period, compute_ranks) @@ -222,4 +240,4 @@ def read_images(filename, columns): else: final_url = url_path.format(code=code) image_dict[col] = mimage.imread(final_url) - return image_dict \ No newline at end of file + return image_dict diff --git a/bar_chart_race/demo.ipynb b/bar_chart_race/demo.ipynb new file mode 100644 index 0000000..f717bce --- /dev/null +++ b/bar_chart_race/demo.ipynb @@ -0,0 +1,6391 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import bar_chart_race as bcr\n", + "df_baseball = bcr.load_dataset('baseball',50).pivot(index='year',\n", + " columns='name',\n", + " values='hr')\n", + "#bcr.bar_chart_race(df=df_baseball)\n", + "path = r\"C:\\Users\\rcame\\PycharmProjects\\bar_chart_race_\\data\\weather_data.csv\"\n", + "df_weather = bcr._utils.load_custom_data(path,'timestamp', None)\n", + "bcr.bar_chart_race(df_weather, colors='rgby')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.8" + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} \ No newline at end of file diff --git a/data/weather_data.csv b/data/weather_data.csv new file mode 100644 index 0000000..e8f2a83 --- /dev/null +++ b/data/weather_data.csv @@ -0,0 +1,101 @@ +timestamp,Boston,Fairbanks,Yakutsk +20210318T0000,40.345,11.120939,11.300156 +20210318T0100,39.751,11.246939,10.814154 +20210318T0200,39.463,11.4989395,9.338156 +20210318T0300,39.391003,12.05694,8.744156 +20210318T0400,39.535,2.8769379,9.428156 +20210318T0500,39.697002,2.7329369,10.868155 +20210318T0600,39.769,3.2009392,11.570154 +20210318T0700,39.715,4.33494,9.122154 +20210318T0800,37.141003,5.12694,2.0301552 +20210318T0900,40.183002,6.386938,-2.6678429 +20210318T1000,42.487,7.808939,-2.8838425 +20210318T1100,44.467003,9.266939,-1.1378441 +20210318T1200,45.907,10.70694,0.4281578 +20210318T1300,47.401,11.912939,0.73415756 +20210318T1400,48.355003,12.86694,0.4281578 +20210318T1500,48.049004,13.83894,-0.54384613 +20210318T1600,47.347,14.666939,-1.7498436 +20210318T1700,46.645,15.11694,-3.3338432 +20210318T1800,45.853,15.638939,-7.0958443 +20210318T1900,44.863,15.692938,-13.737843 +20210318T2000,44.287003,15.296938,-17.607841 +20210318T2100,43.495003,12.632938,-13.3958435 +20210318T2200,42.631,12.200939,-16.077843 +20210318T2300,40.921,12.05694,-17.949844 +20210319T0000,34.495003,11.786938,-18.813843 +20210319T0100,31.399002,11.174938,-19.245842 +20210319T0200,27.763,10.202938,-19.31784 +20210319T0300,25.333002,9.284941,-18.813843 +20210319T0400,22.813002,7.3049393,-17.98584 +20210319T0500,22.237001,5.558939,-17.031845 +20210319T0600,23.029001,4.33494,-16.79784 +20210319T0700,22.723001,3.452938,-14.295841 +20210319T0800,25.711002,2.2649364,-7.113842 +20210319T0900,26.755001,3.794939,4.2261543 +20210319T1000,28.015001,7.75494,12.740156 +20210319T1100,29.815,8.94294,18.644154 +20210319T1200,31.723001,10.202938,22.460155 +20210319T1300,33.901,12.146938,24.314156 +20210319T1400,35.593002,13.946938,25.412155 +20210319T1500,36.709,14.792938,26.042154 +20210319T1600,37.087,14.252939,26.114155 +20210319T1700,36.781002,13.100939,25.862156 +20210319T1800,35.827,12.488939,25.268154 +20210319T1900,31.921001,12.05694,23.864155 +20210319T2000,30.769001,11.102938,22.874153 +20210319T2100,29.977001,10.292938,14.774155 +20210319T2200,29.527,4.3709393,17.096155 +20210319T2300,29.329002,0.8609352,19.382154 +20210320T0000,29.311,-2.2530632,21.344154 +20210320T0100,29.329002,-3.8910637,22.100155 +20210320T0200,29.221,-6.267063,22.388153 +20210320T0300,28.879002,-9.021065,22.280155 +20210320T0400,28.483002,-18.993061,22.226154 +20210320T0500,28.069002,-21.49506,22.064156 +20210320T0600,27.673,-23.331062,21.992155 +20210320T0700,27.529001,-24.789062,22.568153 +20210320T0800,31.417002,-25.941063,24.962154 +20210320T0900,36.637,-21.657059,26.294155 +20210320T1000,41.335,-11.883064,27.410154 +20210320T1100,44.989002,-1.8390656,27.752155 +20210320T1200,47.833,4.676939,27.842155 +20210320T1300,50.191,10.058939,27.626156 +20210320T1400,51.847,13.6049385,27.788155 +20210320T1500,52.621002,16.10694,26.870155 +20210320T1600,52.819,17.438938,25.952154 +20210320T1700,52.117,17.888939,24.998154 +20210320T1800,50.227,18.26694,23.810154 +20210320T1900,46.771,17.618938,21.776154 +20210320T2000,45.763,9.62694,17.852154 +20210320T2100,44.809002,0.1409359,15.008154 +20210320T2200,43.135002,-6.7350616,11.084156 +20210320T2300,41.443,-11.559063,7.466154 +20210321T0000,39.787003,-14.439064,4.532154 +20210321T0100,38.419003,-15.825062,3.2721558 +20210321T0200,37.591,-16.887062,3.4341545 +20210321T0300,36.979,-17.805061,4.082155 +20210321T0400,36.457,-20.847061,9.050156 +20210321T0500,35.989002,-21.315063,7.6461563 +20210321T0600,35.701,-21.819061,3.9741554 +20210321T0700,35.557003,-22.323063,8.438154 +20210321T0800,37.357002,-22.575062,12.866156 +20210321T0900,42.595,-19.371063,17.906155 +20210321T1000,47.023,-10.065063,21.416155 +20210321T1100,50.947,-3.09906,23.918156 +20210321T1200,53.611,1.940937,25.268154 +20210321T1300,55.465,6.008938,23.792156 +20210321T1400,56.977,9.014938,17.042154 +20210321T1500,57.967003,11.408939,8.042154 +20210321T1600,58.273003,13.316938,4.244156 +20210321T1700,57.805,14.324938,3.3801537 +20210321T1800,56.347,14.882938,-0.11184311 +20210321T1900,53.143,14.81094,-7.31184 +20210321T2000,50.839,13.946938,-11.793842 +20210321T2100,48.427002,12.974937,-13.773842 +20210321T2200,46.177002,7.232939,-5.439842 +20210321T2300,44.467003,5.252939,-8.247841 +20210322T0000,42.811,2.678936,-17.571842 +20210322T0100,41.263,-0.2910614,-20.937843 +20210322T0200,39.805,-3.2610626,-21.783844 +20210322T0300,38.725002,-6.087063,-28.71384 diff --git a/tests/test_bar_charts.py b/tests/test_bar_charts.py index 574fae3..3067a5e 100644 --- a/tests/test_bar_charts.py +++ b/tests/test_bar_charts.py @@ -106,4 +106,6 @@ def test_fig(self): def test_bar_kwargs(self): bar_chart_race(df, n_bars=6, bar_kwargs={'alpha': .2, 'ec': 'black', 'lw': 3}) - \ No newline at end of file + + def test_threshold(self): + bar_chart_race(df) \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..c95e778 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,22 @@ +import pytest +import bar_chart_race._utils as utils +from bar_chart_race import load_dataset, bar_chart_race + +df = load_dataset('baseball') +df = df.iloc[-20:-16] + + +def test_threshold(): + filtered_df = utils.filter_threshold(df, 60) + assert len(filtered_df) == 1 + + filtered_df = utils.filter_threshold(df, 0) + assert len(filtered_df) == 4 + + filtered_df = utils.filter_threshold(df, 50) + assert filtered_df.iloc[0]['hr'] > 50 + + +def test_custom_data(): + path = r"C:\Users\rcame\PycharmProjects\bar_chart_race_\data\weather_data.csv" + utils.load_custom_data(path, 'timestamp', None) \ No newline at end of file