From 63f3bd92434257a48206bbd1459cbd7d3a97ec38 Mon Sep 17 00:00:00 2001 From: aaronayres35 <36972686+aaronayres35@users.noreply.github.com> Date: Wed, 7 Jul 2021 06:33:28 -0700 Subject: [PATCH] Fix BaseXYPlot map_screen (#802) * verify / adjust as needed input shape to map_screen. Also, return an empty array of correct shape instead of empty list in case of empty input * fix use of map_screen in LinePlot to fix test failures * make similar changes to subclasses * add regression test * similar updates for segment plot along with test * flake8 * ErrorBarPlot is also a BaseXYPlot descendent * fix hittest_tool.py with recent changes --- chaco/base_xy_plot.py | 13 ++++++-- chaco/examples/demo/basic/hittest_tool.py | 2 +- chaco/plots/barplot.py | 6 +++- chaco/plots/errorbar_plot.py | 4 +-- chaco/plots/lineplot.py | 2 +- chaco/plots/scatterplot.py | 3 +- chaco/plots/segment_plot.py | 21 +++++++++++++ chaco/tests/test_plot.py | 37 +++++++++++++++++++++++ 8 files changed, 79 insertions(+), 9 deletions(-) diff --git a/chaco/base_xy_plot.py b/chaco/base_xy_plot.py index 509f47fa8..8a4e03b04 100644 --- a/chaco/base_xy_plot.py +++ b/chaco/base_xy_plot.py @@ -11,7 +11,7 @@ """ Defines the base class for XY plots. """ from math import sqrt -from numpy import around, array, isnan, transpose +from numpy import around, array, empty, isnan, transpose # Enthought library imports from enable.api import black_color_trait @@ -353,9 +353,16 @@ def map_screen(self, data_array): Implements the AbstractPlotRenderer interface. """ - # data_array is Nx2 array + # ensure data_array is an N1 x ... Nk x 2 ndarray for some k >= 1 + data_array = array(data_array) + + if data_array.ndim == 1: + data_array = data_array.reshape(-1, 2) + if data_array.shape[-1] != 2: + raise ValueError("Input to map_screen must have shape (..., 2)") + if len(data_array) == 0: - return [] + return empty(shape=(0, 2)) x_ary, y_ary = transpose(data_array) diff --git a/chaco/examples/demo/basic/hittest_tool.py b/chaco/examples/demo/basic/hittest_tool.py index 4f45db75c..c1d549646 100644 --- a/chaco/examples/demo/basic/hittest_tool.py +++ b/chaco/examples/demo/basic/hittest_tool.py @@ -46,7 +46,7 @@ def normal_mouse_move(self, event): else: x, y = self.component.map_data((y, x)) - x, y = self.line_plot.map_screen((x, y)) + x, y = self.line_plot.map_screen((x, y))[0] self.pt = self.line_plot.hittest((x, y), threshold=self.threshold) self.request_redraw() diff --git a/chaco/plots/barplot.py b/chaco/plots/barplot.py index 54f97c1c7..4e8915654 100644 --- a/chaco/plots/barplot.py +++ b/chaco/plots/barplot.py @@ -16,6 +16,7 @@ array, compress, column_stack, + empty, invert, isnan, transpose, @@ -176,9 +177,12 @@ def map_screen(self, data_array): Implements the AbstractPlotRenderer interface. """ + # ensure data_array is an Nx2 ndarray + data_array = array(data_array) + data_array = data_array.reshape(-1,2) # data_array is Nx2 array if len(data_array) == 0: - return [] + return empty(shape=(0,2)) x_ary, y_ary = transpose(data_array) sx = self.index_mapper.map_screen(x_ary) sy = self.value_mapper.map_screen(y_ary) diff --git a/chaco/plots/errorbar_plot.py b/chaco/plots/errorbar_plot.py index 7e2e9e26b..d837ad723 100644 --- a/chaco/plots/errorbar_plot.py +++ b/chaco/plots/errorbar_plot.py @@ -9,7 +9,7 @@ # Thanks for using Enthought open source! # Major library imports -from numpy import column_stack, compress, invert, isnan, transpose +from numpy import column_stack, compress, empty, invert, isnan, transpose import logging # Enthought library imports @@ -49,7 +49,7 @@ def map_screen(self, data_array): or (y, xlow, xhigh) depending on self.orientation. """ if len(data_array) == 0: - return [] + return empty(shape=(0, 2)) elif data_array.shape[1] == 2: return LinePlot.map_screen(self, data_array) else: diff --git a/chaco/plots/lineplot.py b/chaco/plots/lineplot.py index 9faa2f565..e26bfdc09 100644 --- a/chaco/plots/lineplot.py +++ b/chaco/plots/lineplot.py @@ -122,7 +122,7 @@ def hittest(self, screen_pt, threshold=7.0, return_distance=False): # screen_pt is one of the points in the lineplot data_pt = (self.index.get_data()[ndx], self.value.get_data()[ndx]) if return_distance: - scrn_pt = self.map_screen(data_pt) + scrn_pt = self.map_screen(data_pt)[0] dist = sqrt( (screen_pt[0] - scrn_pt[0]) ** 2 + (screen_pt[1] - scrn_pt[1]) ** 2 diff --git a/chaco/plots/scatterplot.py b/chaco/plots/scatterplot.py index 7c91e6468..b5c10dc76 100644 --- a/chaco/plots/scatterplot.py +++ b/chaco/plots/scatterplot.py @@ -21,6 +21,7 @@ array, asarray, column_stack, + empty, isfinite, isnan, nanargmin, @@ -299,7 +300,7 @@ def map_screen(self, data_array): """ # data_array is Nx2 array if len(data_array) == 0: - return [] + return empty(shape=(0,2)) data_array = asarray(data_array) if len(data_array.shape) == 1: diff --git a/chaco/plots/segment_plot.py b/chaco/plots/segment_plot.py index 4d11bd46b..11e600f90 100644 --- a/chaco/plots/segment_plot.py +++ b/chaco/plots/segment_plot.py @@ -142,6 +142,27 @@ def hittest(self, *args, **kwargs): def map_index(self, *args, **kwargs): raise NotImplementedError() + def map_screen(self, data_array): + """Maps an Nx2x2 array of data points into screen space and returns it + as an array. + Implements the AbstractPlotRenderer interface. + """ + # ensure data_array is an Nx2x2 ndarray + data_array = np.asarray(data_array) + data_array = data_array.reshape(-1, 2, 2) + + if len(data_array) == 0: + return np.empty(shape=(0, 2, 2)) + + x_ary, y_ary = np.transpose(data_array) + + sx = self.index_mapper.map_screen(x_ary) + sy = self.value_mapper.map_screen(y_ary) + if self.orientation == "h": + return np.transpose(np.array((sx, sy))) + else: + return np.transpose(np.array((sy, sx))) + def _gather_points(self): """Collects the data points that are within the bounds of the plot and caches them. diff --git a/chaco/tests/test_plot.py b/chaco/tests/test_plot.py index c14015221..9e458d5da 100644 --- a/chaco/tests/test_plot.py +++ b/chaco/tests/test_plot.py @@ -10,6 +10,7 @@ import unittest +import numpy as np from numpy import alltrue, arange, array from enable.api import ComponentEditor @@ -95,6 +96,18 @@ def test_segment_plot_color_width(self): actual = gc.bmp_array[:, :, :] self.assertFalse(alltrue(actual == 255)) + def test_segment_plot_map_screen(self): + x = arange(10) + y = arange(1, 11) + data = ArrayPlotData(x=x, y=y) + plot = Plot(data) + plot_renderer = plot.plot(("x", "y"), "segment")[0] + + screen_point = plot_renderer.map_screen([(0, 1), (1, 2)]) + + self.assertEqual(type(screen_point), np.ndarray) + self.assertEqual(screen_point.shape, (1, 2, 2)) + def test_text_plot(self): x = arange(10) y = arange(1, 11) @@ -109,6 +122,30 @@ def test_text_plot(self): actual = gc.bmp_array[:, :, :] self.assertFalse(alltrue(actual == 255)) + def check_map_screen(self, renderer): + arr = arange(10) + data = ArrayPlotData(x=arr, y=arr) + plot = Plot(data) + plot_renderer = plot.add_xy_plot( + 'x', 'y', plot.renderer_map[renderer] + )[0] + + screen_point = plot_renderer.map_screen((-1, 1)) + + self.assertEqual(type(screen_point), np.ndarray) + self.assertEqual(screen_point.shape, (1, 2)) + + screen_point = plot_renderer.map_screen([]) + + self.assertEqual(type(screen_point), np.ndarray) + self.assertEqual(screen_point.shape, (0, 2)) + + # serves as a regression test for enthought/chaco#272 + def test_xy_plot_map_screen(self): + renderers = ["line", "scatter", "bar", "polygon"] + for renderer in renderers: + self.check_map_screen(renderer) + class EmptyLinePlot(HasTraits): plot = Instance(Plot)