From 06b404b1ff2cc75393f25153ed392a943903c7ce Mon Sep 17 00:00:00 2001 From: "Travis N. Vaught" Date: Mon, 12 Sep 2011 23:50:52 -0500 Subject: [PATCH 1/2] - stacked bar example first checkin -- very nascent. --- examples/demo/basic/stacked_bar.py | 95 ++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 examples/demo/basic/stacked_bar.py diff --git a/examples/demo/basic/stacked_bar.py b/examples/demo/basic/stacked_bar.py new file mode 100644 index 000000000..1d4ada03d --- /dev/null +++ b/examples/demo/basic/stacked_bar.py @@ -0,0 +1,95 @@ +""" +Demonstrates stacked bar plots. + + - Left-drag pans the plot. + - Mousewheel up and down zooms the plot in and out. + - Pressing "z" brings up the Zoom Box, and you can click-drag a rectangular + region to zoom. If you use a sequence of zoom boxes, pressing alt-left-arrow + and alt-right-arrow moves you forwards and backwards through the "zoom + history". +""" + +# Major library imports +from numpy import abs, arange, cumprod, ones, random, vstack + +# Enthought library imports +from enable.api import Component, ComponentEditor +from traits.api import HasTraits, Instance +from traitsui.api import Item, Group, View + +# Chaco imports +from chaco.api import ArrayPlotData, Plot +from chaco.tools.api import PanTool, ZoomTool + +#=============================================================================== +# # Create the Chaco plot. +#=============================================================================== + +def _create_data(numpoints): + index = arange(numpoints) + + returns1 = random.random(numpoints)/4.0 + returns2 = random.random(numpoints)/4.0 + returns3 = random.random(numpoints)/4.0 + returns4 = 1.0 - (returns1 + returns2 + returns3) + vals = vstack((returns1, returns2, returns3, returns4)) + #vals.sort(0) + return index, vals + +def _create_plot_component(): + + # Create some data + index, vals = _create_data(20) + + # Create a plot data object and give it this data + pd = ArrayPlotData(index = index, + values = vals) + + # Create the plot + plot = Plot(pd) + plot.stacked_bar_plot(("index", "values"), + color = ["red", "yellow", "green", "blue"], + outline_color = "lightgray",) + + # Tweak some of the plot properties + plot.title = "Stacked Bar Plot" + plot.line_width = 0.5 + plot.padding = 50 + + # Attach some tools to the plot + plot.tools.append(PanTool(plot, constrain_key="shift")) + zoom = ZoomTool(component=plot, tool_mode="box", always_on=False) + plot.overlays.append(zoom) + + return plot + +#=============================================================================== +# Attributes to use for the plot view. +size = (650, 650) +title = "Stacked Bar Plot" +bg_color="transparent" + +#=============================================================================== +# # Demo class that is used by the demo.py application. +#=============================================================================== +class Demo(HasTraits): + plot = Instance(Component) + + traits_view = View( + Group( + Item('plot', editor=ComponentEditor(size=size, + bgcolor=bg_color), + show_label=False), + orientation = "vertical"), + resizable=True, title=title + ) + + def _plot_default(self): + return _create_plot_component() + +demo = Demo() + +if __name__ == "__main__": + demo.configure_traits() + +#--EOF--- From 79c1a44ef4b3c9589cad5ccb85ffc6524bea84cf Mon Sep 17 00:00:00 2001 From: "Travis N. Vaught" Date: Tue, 13 Sep 2011 00:00:10 -0500 Subject: [PATCH 2/2] First cut at stacked_bar_plot renderer. Still missing some things like: - axis labels from text lists - screen-space bar widths - better color designation (currently a list of colors, or ["auto"]) - I'm sure there's a lot more... --- chaco/base_stacked_bar_plot.py | 134 ++++++++++++++++++++++++++++ chaco/plot.py | 93 +++++++++++++++++++- chaco/stacked_bar_plot.py | 155 +++++++++++++++++++++++++++++++++ 3 files changed, 380 insertions(+), 2 deletions(-) create mode 100644 chaco/base_stacked_bar_plot.py create mode 100644 chaco/stacked_bar_plot.py diff --git a/chaco/base_stacked_bar_plot.py b/chaco/base_stacked_bar_plot.py new file mode 100644 index 000000000..285522101 --- /dev/null +++ b/chaco/base_stacked_bar_plot.py @@ -0,0 +1,134 @@ + +from __future__ import with_statement +import warnings + +# Major library imports +from numpy import array, column_stack, zeros_like + +# Enthought library imports +from enable.api import ColorTrait +from traits.api import Any, Bool, Float, Int, List, Property, Trait + +# Chaco imports +from base_xy_plot import BaseXYPlot + +def Alias(name): + return Property(lambda obj: getattr(obj, name), + lambda obj, val: setattr(obj, name, val)) + +class BaseStackedBarPlot(BaseXYPlot): + """ Represents the base class for candle- and bar-type plots that are + multi-valued at each index point, and optionally have an extent in the + index dimension. + + Implements the rendering logic and centralizes a lot of the visual + attributes for these sorts of plots. The gather and culling and + clipping of data is up to individual subclasses. + """ + + #------------------------------------------------------------------------ + # Appearance traits + #------------------------------------------------------------------------ + + # The fill color of the marker. + # TODO: this is a hack...need to see how to handle auto colors correctly + color = List(Any) + + # The fill color of the bar + fill_color = Alias("color") + + # The color of the rectangular box forming the bar. + outline_color = ColorTrait("black") + + # The thickness, in pixels, of the outline to draw around the bar. If + # this is 0, no outline is drawn. + line_width = Float(1.0) + + + # List of colors to cycle through when auto-coloring is requested. Picked + # and ordered to be red-green color-blind friendly, though should not + # be an issue for blue-yellow. + auto_colors = List(["green", "lightgreen", "blue", "lightblue", "red", + "pink", "darkgray", "silver"]) + + + #------------------------------------------------------------------------ + # Private traits + #------------------------------------------------------------------------ + + # currently used color -- basically using ColorTrait magic for color + # designations. + _current_color = ColorTrait("lightgray") + + # Override the base class definition of this because we store a list of + # arrays and not a single array. + _cached_data_pts = List() + + #------------------------------------------------------------------------ + # BaseXYPlot interface + #------------------------------------------------------------------------ + + def get_screen_points(self): + # Override the BaseXYPlot implementation so that this is just + # a pass-through, in case anyone calls it. + pass + + #------------------------------------------------------------------------ + # Protected methods (subclasses should be able to use these directly + # or wrap them) + #------------------------------------------------------------------------ + + def _render(self, gc, left, right, bar_maxes): + stack = column_stack + + bottom = zeros_like(left) + bottom = self.value_mapper.map_screen(bottom) + + with gc: + widths = right - left + + if len(self.color) list of renderers created in response to this call. + """ + if len(data) == 0: + return + self.value_scale = value_scale + + if name is None: + name = self._make_new_plot_name() + if origin is None: + origin = self.default_origin + + # Create the datasources + if len(data) == 2: + + index, bar_maxes = map(self._get_or_create_datasource, data) + self.index_range.add(index) + bar_maxes_data = bar_maxes.get_data() + # Accumulate data totals for stacking + prev = zeros_like(bar_maxes_data[0]) + for i in range(len(bar_maxes_data)): + bar_maxes_data[i] = prev + bar_maxes_data[i] + prev = bar_maxes_data[i] + + self.value_range.add(bar_maxes) + + if self.index_scale == "linear": + imap = LinearMapper(range=self.index_range, + stretch_data=self.index_mapper.stretch_data) + else: + imap = LogMapper(range=self.index_range, + stretch_data=self.index_mapper.stretch_data) + if self.value_scale == "linear": + vmap = LinearMapper(range=self.value_range, + stretch_data=self.value_mapper.stretch_data) + else: + vmap = LogMapper(range=self.value_range, + stretch_data=self.value_mapper.stretch_data) + + cls = self.renderer_map["stacked_bar"] + plot = cls(index = index, + bar_maxes = bar_maxes, + index_mapper = imap, + value_mapper = vmap, + orientation = self.orientation, + origin = self.origin, + **styles) + self.add(plot) + self.plots[name] = [plot] + return [plot] + def candle_plot(self, data, name=None, value_scale="linear", origin=None, **styles): """ Adds a new sub-plot using the given data and plot style. diff --git a/chaco/stacked_bar_plot.py b/chaco/stacked_bar_plot.py new file mode 100644 index 000000000..c39ec43a9 --- /dev/null +++ b/chaco/stacked_bar_plot.py @@ -0,0 +1,155 @@ + +from __future__ import with_statement + +# Major library imports +from numpy import array, compress, concatenate, searchsorted + +# Enthought library imports +from traits.api import Instance, List, Property + +# Chaco imports +from abstract_data_source import AbstractDataSource +from base_stacked_bar_plot import BaseStackedBarPlot + +def broaden(mask): + """ Takes a 1D boolean mask array and returns a copy with all the non-zero + runs widened by 1. + """ + if len(mask) < 2: + return mask + # Note: the order in which these operations are performed is important. + # Modifying newmask in-place with the |= operator only works for if + # newmask[:-1] is the L-value. + newmask = concatenate(([False], mask[1:] | mask[:-1])) + newmask[:-1] |= mask[1:] + return newmask + + +class StackedBarPlot(BaseStackedBarPlot): + """ A plot consisting of a filled bar(s) + + The values in the **index** datasource indicate the centers of the bins; + the widths of the bins are *not* specified in data space, and are + determined by the minimum space between adjacent index values. + """ + + #------------------------------------------------------------------------ + # Data-related traits + #------------------------------------------------------------------------ + + # The "upper" extent of the "bar(s)" + bar_maxes = Instance(AbstractDataSource) + + value = Property + + def map_data(self, screen_pt, all_values=True): + """ Maps a screen space point into the "index" space of the plot. + + Overrides the BaseXYPlot implementation, and always returns an + array of (index, value) tuples. + """ + x, y = screen_pt + if self.orientation == 'v': + x, y = y, x + return array((self.index_mapper.map_data(x), + self.value_mapper.map_data(y))) + + def map_index(self, screen_pt, threshold=0.0, outside_returns_none=True, + index_only = True): + if not index_only: + raise NotImplementedError("Bar Plots only support index_only map_index()") + if len(screen_pt) == 0: + return None + + # Find the closest index point using numpy + index_data = self.index.get_data() + if len(index_data) == 0: + return None + + target_data = self.index_mapper.map_data(screen_pt[0]) + + index = searchsorted(index_data, [target_data])[0] + if index == len(index_data): + index -= 1 + # Bracket index and map those points to screen space, then + # compute the distance + if index > 0: + lower = index_data[index-1] + upper = index_data[index] + screen_low, screen_high = self.index_mapper.map_screen(array([lower, upper])) + # Find the closest index + low_dist = abs(screen_pt[0] - screen_low) + high_dist = abs(screen_pt[0] - screen_high) + if low_dist < high_dist: + index = index - 1 + dist = low_dist + else: + dist = high_dist + # Determine if we need to check the threshold + if threshold > 0 and dist >= threshold: + return None + else: + return index + else: + screen = self.index_mapper.map_screen(index_data[0]) + if threshold > 0 and abs(screen - screen_pt[0]) >= threshold: + return None + else: + return index + + def _gather_points(self): + index = self.index.get_data() + mask = broaden(self.index_range.mask_data(index)) + + if not mask.any(): + self._cached_data_pts = [] + self._cache_valid = True + return + + data_pts = [compress(mask, index)] + + for v in self.bar_maxes.get_data(): + if v is None or len(v) == 0: + data_pts.append(None) + else: + data_pts.append(compress(mask, v)) + + self._cached_data_pts = data_pts + self._cache_valid = True + + def _draw_plot(self, gc, view_bounds=None, mode="normal"): + self._gather_points() + + if len(self._cached_data_pts) == 0: + return + + index = self.index_mapper.map_screen(self._cached_data_pts[0]) + if len(index) == 0: + return + + bar_maxes = [] + for v in self._cached_data_pts[1:]: + if v is None: + bar_maxes.append(None) + else: + bar_maxes.append(self.value_mapper.map_screen(v)) + + # Compute lefts and rights from self.index, which represents bin + # centers. + if len(index) == 1: + width = 5.0 + else: + width = (index[1:] - index[:-1]).min() / 2.5 + left = index - width + right = index + width + + with gc: + gc.clip_to_rect(self.x, self.y, self.width, self.height) + self._render(gc, left, right, bar_maxes) + + def _get_value(self): + # TODO: Check where this is used to make sure it returns what we want. + if self.bar_maxes is not None: + return self.bar_maxes + +