Skip to content

Commit

Permalink
Fix BaseXYPlot map_screen (#802)
Browse files Browse the repository at this point in the history
* verify / adjust as needed input shape to map_screen. Also, return an empty array of correct shape instead of empty list in case of empty input

* fix use of map_screen in LinePlot to fix test failures

* make similar changes to subclasses

* add regression test

* similar updates for segment plot along with test

* flake8

* ErrorBarPlot is also a BaseXYPlot descendent

* fix hittest_tool.py with recent changes
  • Loading branch information
aaronayres35 authored Jul 7, 2021
1 parent eba69e1 commit 63f3bd9
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 9 deletions.
13 changes: 10 additions & 3 deletions chaco/base_xy_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
""" Defines the base class for XY plots.
"""
from math import sqrt
from numpy import around, array, isnan, transpose
from numpy import around, array, empty, isnan, transpose

# Enthought library imports
from enable.api import black_color_trait
Expand Down Expand Up @@ -353,9 +353,16 @@ def map_screen(self, data_array):
Implements the AbstractPlotRenderer interface.
"""
# data_array is Nx2 array
# ensure data_array is an N1 x ... Nk x 2 ndarray for some k >= 1
data_array = array(data_array)

if data_array.ndim == 1:
data_array = data_array.reshape(-1, 2)
if data_array.shape[-1] != 2:
raise ValueError("Input to map_screen must have shape (..., 2)")

if len(data_array) == 0:
return []
return empty(shape=(0, 2))

x_ary, y_ary = transpose(data_array)

Expand Down
2 changes: 1 addition & 1 deletion chaco/examples/demo/basic/hittest_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def normal_mouse_move(self, event):
else:
x, y = self.component.map_data((y, x))

x, y = self.line_plot.map_screen((x, y))
x, y = self.line_plot.map_screen((x, y))[0]
self.pt = self.line_plot.hittest((x, y), threshold=self.threshold)
self.request_redraw()

Expand Down
6 changes: 5 additions & 1 deletion chaco/plots/barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
array,
compress,
column_stack,
empty,
invert,
isnan,
transpose,
Expand Down Expand Up @@ -176,9 +177,12 @@ def map_screen(self, data_array):
Implements the AbstractPlotRenderer interface.
"""
# ensure data_array is an Nx2 ndarray
data_array = array(data_array)
data_array = data_array.reshape(-1,2)
# data_array is Nx2 array
if len(data_array) == 0:
return []
return empty(shape=(0,2))
x_ary, y_ary = transpose(data_array)
sx = self.index_mapper.map_screen(x_ary)
sy = self.value_mapper.map_screen(y_ary)
Expand Down
4 changes: 2 additions & 2 deletions chaco/plots/errorbar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Thanks for using Enthought open source!

# Major library imports
from numpy import column_stack, compress, invert, isnan, transpose
from numpy import column_stack, compress, empty, invert, isnan, transpose
import logging

# Enthought library imports
Expand Down Expand Up @@ -49,7 +49,7 @@ def map_screen(self, data_array):
or (y, xlow, xhigh) depending on self.orientation.
"""
if len(data_array) == 0:
return []
return empty(shape=(0, 2))
elif data_array.shape[1] == 2:
return LinePlot.map_screen(self, data_array)
else:
Expand Down
2 changes: 1 addition & 1 deletion chaco/plots/lineplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def hittest(self, screen_pt, threshold=7.0, return_distance=False):
# screen_pt is one of the points in the lineplot
data_pt = (self.index.get_data()[ndx], self.value.get_data()[ndx])
if return_distance:
scrn_pt = self.map_screen(data_pt)
scrn_pt = self.map_screen(data_pt)[0]
dist = sqrt(
(screen_pt[0] - scrn_pt[0]) ** 2
+ (screen_pt[1] - scrn_pt[1]) ** 2
Expand Down
3 changes: 2 additions & 1 deletion chaco/plots/scatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
array,
asarray,
column_stack,
empty,
isfinite,
isnan,
nanargmin,
Expand Down Expand Up @@ -299,7 +300,7 @@ def map_screen(self, data_array):
"""
# data_array is Nx2 array
if len(data_array) == 0:
return []
return empty(shape=(0,2))

data_array = asarray(data_array)
if len(data_array.shape) == 1:
Expand Down
21 changes: 21 additions & 0 deletions chaco/plots/segment_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,27 @@ def hittest(self, *args, **kwargs):
def map_index(self, *args, **kwargs):
raise NotImplementedError()

def map_screen(self, data_array):
"""Maps an Nx2x2 array of data points into screen space and returns it
as an array.
Implements the AbstractPlotRenderer interface.
"""
# ensure data_array is an Nx2x2 ndarray
data_array = np.asarray(data_array)
data_array = data_array.reshape(-1, 2, 2)

if len(data_array) == 0:
return np.empty(shape=(0, 2, 2))

x_ary, y_ary = np.transpose(data_array)

sx = self.index_mapper.map_screen(x_ary)
sy = self.value_mapper.map_screen(y_ary)
if self.orientation == "h":
return np.transpose(np.array((sx, sy)))
else:
return np.transpose(np.array((sy, sx)))

def _gather_points(self):
"""Collects the data points that are within the bounds of the plot and
caches them.
Expand Down
37 changes: 37 additions & 0 deletions chaco/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import unittest

import numpy as np
from numpy import alltrue, arange, array

from enable.api import ComponentEditor
Expand Down Expand Up @@ -95,6 +96,18 @@ def test_segment_plot_color_width(self):
actual = gc.bmp_array[:, :, :]
self.assertFalse(alltrue(actual == 255))

def test_segment_plot_map_screen(self):
x = arange(10)
y = arange(1, 11)
data = ArrayPlotData(x=x, y=y)
plot = Plot(data)
plot_renderer = plot.plot(("x", "y"), "segment")[0]

screen_point = plot_renderer.map_screen([(0, 1), (1, 2)])

self.assertEqual(type(screen_point), np.ndarray)
self.assertEqual(screen_point.shape, (1, 2, 2))

def test_text_plot(self):
x = arange(10)
y = arange(1, 11)
Expand All @@ -109,6 +122,30 @@ def test_text_plot(self):
actual = gc.bmp_array[:, :, :]
self.assertFalse(alltrue(actual == 255))

def check_map_screen(self, renderer):
arr = arange(10)
data = ArrayPlotData(x=arr, y=arr)
plot = Plot(data)
plot_renderer = plot.add_xy_plot(
'x', 'y', plot.renderer_map[renderer]
)[0]

screen_point = plot_renderer.map_screen((-1, 1))

self.assertEqual(type(screen_point), np.ndarray)
self.assertEqual(screen_point.shape, (1, 2))

screen_point = plot_renderer.map_screen([])

self.assertEqual(type(screen_point), np.ndarray)
self.assertEqual(screen_point.shape, (0, 2))

# serves as a regression test for enthought/chaco#272
def test_xy_plot_map_screen(self):
renderers = ["line", "scatter", "bar", "polygon"]
for renderer in renderers:
self.check_map_screen(renderer)


class EmptyLinePlot(HasTraits):
plot = Instance(Plot)
Expand Down

0 comments on commit 63f3bd9

Please sign in to comment.