Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New: stacked barplot #152

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions chaco/base_stacked_bar_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@

from __future__ import with_statement
import warnings

# Major library imports
from numpy import array, column_stack, zeros_like

# Enthought library imports
from enable.api import ColorTrait
from traits.api import Any, Bool, Float, Int, List, Property, Trait

# Chaco imports
from base_xy_plot import BaseXYPlot

def Alias(name):
return Property(lambda obj: getattr(obj, name),
lambda obj, val: setattr(obj, name, val))

class BaseStackedBarPlot(BaseXYPlot):
""" Represents the base class for candle- and bar-type plots that are
multi-valued at each index point, and optionally have an extent in the
index dimension.

Implements the rendering logic and centralizes a lot of the visual
attributes for these sorts of plots. The gather and culling and
clipping of data is up to individual subclasses.
"""

#------------------------------------------------------------------------
# Appearance traits
#------------------------------------------------------------------------

# The fill color of the marker.
# TODO: this is a hack...need to see how to handle auto colors correctly
color = List(Any)

# The fill color of the bar
fill_color = Alias("color")

# The color of the rectangular box forming the bar.
outline_color = ColorTrait("black")

# The thickness, in pixels, of the outline to draw around the bar. If
# this is 0, no outline is drawn.
line_width = Float(1.0)


# List of colors to cycle through when auto-coloring is requested. Picked
# and ordered to be red-green color-blind friendly, though should not
# be an issue for blue-yellow.
auto_colors = List(["green", "lightgreen", "blue", "lightblue", "red",
"pink", "darkgray", "silver"])


#------------------------------------------------------------------------
# Private traits
#------------------------------------------------------------------------

# currently used color -- basically using ColorTrait magic for color
# designations.
_current_color = ColorTrait("lightgray")

# Override the base class definition of this because we store a list of
# arrays and not a single array.
_cached_data_pts = List()

#------------------------------------------------------------------------
# BaseXYPlot interface
#------------------------------------------------------------------------

def get_screen_points(self):
# Override the BaseXYPlot implementation so that this is just
# a pass-through, in case anyone calls it.
pass

#------------------------------------------------------------------------
# Protected methods (subclasses should be able to use these directly
# or wrap them)
#------------------------------------------------------------------------

def _render(self, gc, left, right, bar_maxes):
stack = column_stack

bottom = zeros_like(left)
bottom = self.value_mapper.map_screen(bottom)

with gc:
widths = right - left

if len(self.color)<len(bar_maxes):

warnings.warn("Color count does not match data series count.")
for i in range(len(bar_maxes)):
self.color.append(self.color[-1])

idx = 0

for bar_max in bar_maxes:

top = bar_max

if self.color[0] == "auto":
self.color = self.auto_colors[:len(left)]

self._current_color = self.color[idx]

# Draw the bars
bars = stack((left, bottom, widths, top - bottom))

gc.set_antialias(False)
gc.set_stroke_color(self.outline_color_)
gc.set_line_width(self.line_width)
gc.rects(bars)
if self.color in ("none", "transparent", "clear"):
gc.stroke_path()
else:
gc.set_fill_color(self._current_color_)
gc.draw_path()
bottom = top

idx += 1


def _render_icon(self, gc, x, y, width, height):
min = array([y + 1])
max = array([y + height - 1])
bar_min = array([y + height / 3])
bar_max = array([y + height - (height / 3)])
center = array([y + (height / 2)])
self._render(gc, array([x+width/4]), array([x+3*width/4]), bar_maxes)




124 changes: 105 additions & 19 deletions chaco/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Major library imports
import itertools
import warnings
from numpy import arange, array, ndarray, linspace
from numpy import arange, array, ndarray, linspace, sum, zeros_like
from types import FunctionType

# Enthought library imports
Expand Down Expand Up @@ -36,6 +36,7 @@
from plot_label import PlotLabel
from polygon_plot import PolygonPlot
from scatterplot import ScatterPlot
from stacked_bar_plot import StackedBarPlot
from filled_line_plot import FilledLinePlot
from quiverplot import QuiverPlot

Expand Down Expand Up @@ -117,6 +118,7 @@ class Plot(DataView):
contour_line_plot = ContourLinePlot,
contour_poly_plot = ContourPolyPlot,
candle = CandlePlot,
stacked_bar = StackedBarPlot,
quiver = QuiverPlot,))

#------------------------------------------------------------------------
Expand Down Expand Up @@ -305,7 +307,7 @@ def plot(self, data, type="line", name=None, index_scale="linear",
name = self._make_new_plot_name()
if origin is None:
origin = self.default_origin

if plot_type in ("line", "scatter", "polygon", "bar", "filled_line"):
# Tie data to the index range
if len(data) == 1:
Expand Down Expand Up @@ -378,44 +380,44 @@ def plot(self, data, type="line", name=None, index_scale="linear",
orientation=self.orientation,
origin = origin,
**styles)

self.add(plot)
new_plots.append(plot)

if plot_type == 'bar':
# For bar plots, compute the ranges from the data to make the
# plot look clean.
# For bar plots, compute the ranges from the data to make the
# plot look clean.

def custom_index_func(data_low, data_high, margin, tight_bounds):
""" Compute custom bounds of the plot along index (in
""" Compute custom bounds of the plot along index (in
data space).
"""
bar_width = styles.get('bar_width', cls().bar_width)
plot_low = data_low - bar_width
plot_high = data_high + bar_width
return plot_low, plot_high

if self.index_range.bounds_func is None:
self.index_range.bounds_func = custom_index_func

def custom_value_func(data_low, data_high, margin, tight_bounds):
""" Compute custom bounds of the plot along value (in
""" Compute custom bounds of the plot along value (in
data space).
"""
plot_low = data_low - (data_high-data_low)*0.1
plot_high = data_high + (data_high-data_low)*0.1
return plot_low, plot_high
if self.value_range.bounds_func is None:

if self.value_range.bounds_func is None:
self.value_range.bounds_func = custom_value_func

self.index_range.tight_bounds = False
self.value_range.tight_bounds = False
self.index_range.refresh()
self.value_range.refresh()

self.plots[name] = new_plots

elif plot_type == "cmap_scatter":
if len(data) != 3:
raise ValueError("Colormapped scatter plots require (index, value, color) data")
Expand Down Expand Up @@ -717,6 +719,92 @@ def _create_2d_plot(self, cls, name, origin, xbounds, ybounds, value_ds,
self.plots[name] = [plot]
return self.plots[name]

def stacked_bar_plot(self, data, name=None, value_scale="linear", origin=None,
**styles):
""" Adds a new sub-plot using the given data and plot style.

Parameters
==========
data : list(string), tuple(string)
The names of the data to be plotted in the ArrayDataSource. The
number of arguments determines how they are interpreted:

(index, bar_max)
filled or outline-only bar extending from index-axis to
**bar_max**

(index, bar_max1, bar_max2, bar_max3, ...)
filled or outline-only bar extending first from index-axis to
**bar_max1**, then another bar extending from **bar_max1**
to **bar_max2**, etc.

name : string
The name of the plot. If None, then a default one is created.

value_scale : string
The type of scale to use for the value axis. If not "linear",
then a log scale is used.

Styles
======
These are all optional keyword arguments.

color : List of strings, 3- or 4-tuples
The fill color of the bars; defaults to "auto".
outline_color : List of strings, 3- or 4-tuples
The color of the rectangular box forming the bars.

Returns
=======
[renderers] -> list of renderers created in response to this call.
"""
if len(data) == 0:
return
self.value_scale = value_scale

if name is None:
name = self._make_new_plot_name()
if origin is None:
origin = self.default_origin

# Create the datasources
if len(data) == 2:

index, bar_maxes = map(self._get_or_create_datasource, data)
self.index_range.add(index)
bar_maxes_data = bar_maxes.get_data()
# Accumulate data totals for stacking
prev = zeros_like(bar_maxes_data[0])
for i in range(len(bar_maxes_data)):
bar_maxes_data[i] = prev + bar_maxes_data[i]
prev = bar_maxes_data[i]

self.value_range.add(bar_maxes)

if self.index_scale == "linear":
imap = LinearMapper(range=self.index_range,
stretch_data=self.index_mapper.stretch_data)
else:
imap = LogMapper(range=self.index_range,
stretch_data=self.index_mapper.stretch_data)
if self.value_scale == "linear":
vmap = LinearMapper(range=self.value_range,
stretch_data=self.value_mapper.stretch_data)
else:
vmap = LogMapper(range=self.value_range,
stretch_data=self.value_mapper.stretch_data)

cls = self.renderer_map["stacked_bar"]
plot = cls(index = index,
bar_maxes = bar_maxes,
index_mapper = imap,
value_mapper = vmap,
orientation = self.orientation,
origin = self.origin,
**styles)
self.add(plot)
self.plots[name] = [plot]
return [plot]

def candle_plot(self, data, name=None, value_scale="linear", origin=None,
**styles):
Expand Down Expand Up @@ -908,7 +996,7 @@ def quiverplot(self, data, name=None, origin=None,
)
self.add(plot)
self.plots[name] = [plot]
return [plot]
return [plot]

def delplot(self, *names):
""" Removes the named sub-plots. """
Expand Down Expand Up @@ -1059,7 +1147,7 @@ def _data_update_handler(self, name, event):
if name in self.datasources:
source = self.datasources[name]
source.set_data(self.data.get_data(name))

def _plots_items_changed(self, event):
if self.legend:
self.legend.plots = self.plots
Expand Down Expand Up @@ -1185,5 +1273,3 @@ def _set_title_font(self, font):

def _get_title_font(self):
return self._title.font


Loading