Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Threshold issue #47

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 38 additions & 34 deletions bar_chart_race/_bar_chart_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
from ._common_chart import CommonChart
from ._utils import prepare_wide_data


class _BarChartRace(CommonChart):
def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_max,
steps_per_period, period_length, end_period_pause, interpolate_period,
period_label, period_template, period_summary_func, perpendicular_bar_func,
colors, title, bar_size, bar_textposition, bar_texttemplate, bar_label_font,
tick_label_font, tick_template, shared_fontdict, scale, fig, writer,

def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_max, fixed_min,
steps_per_period, period_length, end_period_pause, interpolate_period,
period_label, period_template, period_summary_func, perpendicular_bar_func,
colors, title, bar_size, bar_textposition, bar_texttemplate, bar_label_font,
tick_label_font, tick_template, shared_fontdict, scale, fig, writer,
bar_kwargs, fig_kwargs, filter_column_colors):
self.filename = filename
self.extension = self.get_extension()
Expand All @@ -25,6 +26,7 @@ def __init__(self, df, filename, orientation, sort, n_bars, fixed_order, fixed_m
self.n_bars = n_bars or df.shape[1]
self.fixed_order = fixed_order
self.fixed_max = fixed_max
self.fixed_min = fixed_min
self.steps_per_period = steps_per_period
self.period_length = period_length
self.end_period_pause = end_period_pause
Expand Down Expand Up @@ -64,7 +66,7 @@ def validate_params(self):
raise ValueError('`filename` must have an extension')
elif self.filename is not None:
raise TypeError('`filename` must be None or a string')

if self.sort not in ('asc', 'desc'):
raise ValueError('`sort` must be "asc" or "desc"')

Expand Down Expand Up @@ -143,15 +145,17 @@ def get_font(self, font, ticks=False):
return font

def prepare_data(self, df):

if self.fixed_order is True:
last_values = df.iloc[-1].sort_values(ascending=False)
cols = last_values.iloc[:self.n_bars].index
df = df[cols]
elif isinstance(self.fixed_order, list):
cols = self.fixed_order
print(cols)
df = df[cols]
self.n_bars = min(len(cols), self.n_bars)

compute_ranks = self.fixed_order is False
dfs = prepare_wide_data(df, self.orientation, self.sort, self.n_bars,
self.interpolate_period, self.steps_per_period, compute_ranks)
Expand All @@ -165,14 +169,14 @@ def prepare_data(self, df):
m = df_values.shape[0]
rank_row = np.arange(1, n)
if (self.sort == 'desc' and self.orientation == 'h') or \
(self.sort == 'asc' and self.orientation == 'v'):
(self.sort == 'asc' and self.orientation == 'v'):
rank_row = rank_row[::-1]

ranks_arr = np.repeat(rank_row.reshape(1, -1), m, axis=0)
df_ranks = pd.DataFrame(data=ranks_arr, columns=cols)

return df_values, df_ranks

def get_col_filt(self):
col_filt = pd.Series([True] * self.df_values.shape[1])
if self.n_bars < self.df_ranks.shape[1]:
Expand All @@ -188,7 +192,7 @@ def get_col_filt(self):
self.df_values = self.df_values.loc[:, col_filt]
self.df_ranks = self.df_ranks.loc[:, col_filt]
return col_filt

def get_bar_colors(self, colors):
if colors is None:
colors = 'dark12'
Expand All @@ -197,7 +201,7 @@ def get_bar_colors(self, colors):

if isinstance(colors, str):
from ._colormaps import colormaps

try:
bar_colors = colormaps[colors.lower()]
except KeyError:
Expand Down Expand Up @@ -230,7 +234,7 @@ def get_bar_colors(self, colors):
exp_ct = np.bincount(np.arange(num_cols) % n, minlength=n)
if (col_idx_ct > exp_ct).any():
warnings.warn("Some of your columns never make an appearance in the animation. "
"To reduce color repetition, set `filter_column_colors` to `True`")
"To reduce color repetition, set `filter_column_colors` to `True`")
return bar_colors

def get_max_plotted_value(self):
Expand Down Expand Up @@ -271,16 +275,16 @@ def get_subplots_adjust(self):
plot_func = ax.barh if self.orientation == 'h' else ax.bar
bar_location, bar_length, cols, _ = self.get_bar_info(-1)
plot_func(bar_location, bar_length, tick_label=cols)

self.prepare_axes(ax)
texts = self.add_bar_labels(ax, bar_location, bar_length)

fig.canvas.print_figure(io.BytesIO(), format='png')
xmin = min(label.get_window_extent().x0 for label in ax.get_yticklabels())
xmin = min(label.get_window_extent().x0 for label in ax.get_yticklabels())
xmin /= (fig.dpi * fig.get_figwidth())
left = ax.get_position().x0 - xmin + .01

ymin = min(label.get_window_extent().y0 for label in ax.get_xticklabels())
ymin = min(label.get_window_extent().y0 for label in ax.get_xticklabels())
ymin /= (fig.dpi * fig.get_figheight())
bottom = ax.get_position().y0 - ymin + .01

Expand All @@ -298,7 +302,7 @@ def get_subplots_adjust(self):
else:
max_bar_pixels = ax.transData.transform((0, max_bar))[1]
max_text = max(text.get_window_extent().y1 for text in texts)

self.extra_pixels = max_text - max_bar_pixels + 10

if self.fixed_max:
Expand Down Expand Up @@ -348,7 +352,7 @@ def set_major_formatter(self, ax):
def plot_bars(self, ax, i):
bar_location, bar_length, cols, colors = self.get_bar_info(i)
if self.orientation == 'h':
ax.barh(bar_location, bar_length, tick_label=cols,
ax.barh(bar_location, bar_length, tick_label=cols,
color=colors, **self.bar_kwargs)
ax.set_yticklabels(ax.get_yticklabels(), **self.tick_label_font)
if not self.fixed_max and self.bar_textposition == 'outside':
Expand All @@ -357,7 +361,7 @@ def plot_bars(self, ax, i):
new_xmax = ax.transData.inverted().transform((new_max_pixels, 0))[0]
ax.set_xlim(ax.get_xlim()[0], new_xmax)
else:
ax.bar(bar_location, bar_length, tick_label=cols,
ax.bar(bar_location, bar_length, tick_label=cols,
color=colors, **self.bar_kwargs)
ax.set_xticklabels(ax.get_xticklabels(), **self.tick_label_font)
if not self.fixed_max and self.bar_textposition == 'outside':
Expand Down Expand Up @@ -397,7 +401,7 @@ def add_period_summary(self, ax, i):
if 'x' not in text_dict or 'y' not in text_dict or 's' not in text_dict:
name = self.period_summary_func.__name__
raise ValueError(f'The dictionary returned from `{name}` must contain '
'"x", "y", and "s"')
'"x", "y", and "s"')
ax.text(transform=ax.transAxes, **text_dict)

def add_bar_labels(self, ax, bar_location, bar_length):
Expand Down Expand Up @@ -450,7 +454,7 @@ def add_perpendicular_bar(self, ax, bar_length, i):
line.set_xdata([val] * 2)
else:
line.set_ydata([val] * 2)

def anim_func(self, i):
if i is None:
return
Expand All @@ -461,7 +465,7 @@ def anim_func(self, i):
for text in ax.texts[start:]:
text.remove()
self.plot_bars(ax, i)

def make_animation(self):
def init_func():
ax = self.fig.axes[0]
Expand All @@ -478,7 +482,7 @@ def frame_generator(n):
for _ in range(pause):
frames.append(None)
return frames

frames = frame_generator(len(self.df_values))
anim = FuncAnimation(self.fig, self.anim_func, frames, init_func, interval=interval)

Expand All @@ -498,8 +502,8 @@ def frame_generator(n):
fc = self.fig.get_facecolor()
if fc == (1, 1, 1, 0):
fc = 'white'
ret_val = anim.save(self.filename, fps=self.fps, writer=self.writer,
savefig_kwargs=savefig_kwargs)
ret_val = anim.save(self.filename, fps=self.fps, writer=self.writer,
savefig_kwargs=savefig_kwargs)
except Exception as e:
message = str(e)
raise Exception(message)
Expand All @@ -509,15 +513,15 @@ def frame_generator(n):
return ret_val


def bar_chart_race(df, filename=None, orientation='h', sort='desc', n_bars=None,
fixed_order=False, fixed_max=False, steps_per_period=10,
period_length=500, end_period_pause=0, interpolate_period=False,
def bar_chart_race(df, filename=None, orientation='h', sort='desc', n_bars=None,
fixed_order=False, fixed_max=False, steps_per_period=10,
period_length=500, end_period_pause=0, interpolate_period=False,
period_label=True, period_template=None, period_summary_func=None,
perpendicular_bar_func=None, colors=None, title=None, bar_size=.95,
bar_textposition='outside', bar_texttemplate='{x:,.0f}',
bar_label_font=None, tick_label_font=None, tick_template='{x:,.0f}',
shared_fontdict=None, scale='linear', fig=None, writer=None,
bar_kwargs=None, fig_kwargs=None, filter_column_colors=False):
shared_fontdict=None, scale='linear', fig=None, writer=None,
bar_kwargs=None, fig_kwargs=None, filter_column_colors=False):
'''
Create an animated bar chart race using matplotlib. Data must be in
'wide' format where each row represents a single time period and each
Expand Down Expand Up @@ -868,9 +872,9 @@ def func(val):
These sizes are relative to plt.rcParams['font.size'].
'''
bcr = _BarChartRace(df, filename, orientation, sort, n_bars, fixed_order, fixed_max,
steps_per_period, period_length, end_period_pause, interpolate_period,
steps_per_period, period_length, end_period_pause, interpolate_period,
period_label, period_template, period_summary_func, perpendicular_bar_func,
colors, title, bar_size, bar_textposition, bar_texttemplate,
bar_label_font, tick_label_font, tick_template, shared_fontdict, scale,
colors, title, bar_size, bar_textposition, bar_texttemplate,
bar_label_font, tick_label_font, tick_template, shared_fontdict, scale,
fig, writer, bar_kwargs, fig_kwargs, filter_column_colors)
return bcr.make_animation()
7 changes: 6 additions & 1 deletion bar_chart_race/_colormaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -39110,5 +39110,10 @@
"#311339",
"#301338",
"#301437"
]
],
"rgby": ['red',
'green',
'blue',
'yellow'
]
}
36 changes: 27 additions & 9 deletions bar_chart_race/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from matplotlib import image as mimage


def load_dataset(name='covid19'):
def load_dataset(name='covid19', threshold=0):
'''
Return a pandas DataFrame suitable for immediate use in `bar_chart_race`.
Must be connected to the internet
Expand All @@ -18,6 +18,8 @@ def load_dataset(name='covid19'):
* 'covid19_tutorial'
* 'urban_pop'
* 'baseball'
threshold : int, default 0
Lowest value that will be shown on the bar_chart_race

Returns
-------
Expand All @@ -31,9 +33,24 @@ def load_dataset(name='covid19'):
'baseball': None}
index_col = index_dict[name]
parse_dates = [index_col] if index_col else None
return pd.read_csv(url, index_col=index_col, parse_dates=parse_dates)
df = pd.read_csv(url, index_col=index_col, parse_dates=parse_dates)
new_df = filter_threshold(df, threshold)

def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate_period=False,
return new_df


def load_custom_data(filename, index_col, parse_dates):
df = pd.read_csv(filename, index_col=index_col, parse_dates=parse_dates)
df.fillna(0, inplace=True)
return df



def filter_threshold(df, thresh):
return df.loc[(df.hr > thresh)]


def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate_period=False,
steps_per_period=10, compute_ranks=True):
'''
Prepares 'wide' data for bar chart animation.
Expand Down Expand Up @@ -109,21 +126,22 @@ def prepare_wide_data(df, orientation='h', sort='desc', n_bars=None, interpolate
df_values.iloc[:, 0] = df_values.iloc[:, 0].interpolate()
else:
df_values.iloc[:, 0] = df_values.iloc[:, 0].fillna(method='ffill')

df_values = df_values.set_index(df_values.columns[0])
if compute_ranks:
df_ranks = df_values.rank(axis=1, method='first', ascending=False).clip(upper=n_bars + 1)
if (sort == 'desc' and orientation == 'h') or (sort == 'asc' and orientation == 'v'):
df_ranks = n_bars + 1 - df_ranks
df_ranks = df_ranks.interpolate()

df_values = df_values.interpolate()
if compute_ranks:
return df_values, df_ranks
return df_values

def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h',
sort='desc', n_bars=None, interpolate_period=False,

def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h',
sort='desc', n_bars=None, interpolate_period=False,
steps_per_period=10, compute_ranks=True):
'''
Prepares 'long' data for bar chart animation.
Expand Down Expand Up @@ -201,7 +219,7 @@ def prepare_long_data(df, index, columns, values, aggfunc='sum', orientation='h'
df_values, df_ranks = bcr.prepare_long_data(df)
bcr.bar_chart_race(df_values, steps_per_period=1, period_length=50)
'''
df_wide = df.pivot_table(index=index, columns=columns, values=values,
df_wide = df.pivot_table(index=index, columns=columns, values=values,
aggfunc=aggfunc).fillna(method='ffill')
return prepare_wide_data(df_wide, orientation, sort, n_bars, interpolate_period,
steps_per_period, compute_ranks)
Expand All @@ -222,4 +240,4 @@ def read_images(filename, columns):
else:
final_url = url_path.format(code=code)
image_dict[col] = mimage.imread(final_url)
return image_dict
return image_dict
Loading