Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix BaseXYPlot map_screen #802

Merged
merged 8 commits into from
Jul 7, 2021
Merged
13 changes: 10 additions & 3 deletions chaco/base_xy_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
""" Defines the base class for XY plots.
"""
from math import sqrt
from numpy import around, array, isnan, transpose
from numpy import around, array, empty, isnan, transpose

# Enthought library imports
from enable.api import black_color_trait
Expand Down Expand Up @@ -353,9 +353,16 @@ def map_screen(self, data_array):

Implements the AbstractPlotRenderer interface.
"""
# data_array is Nx2 array
# ensure data_array is an N1 x ... Nk x 2 ndarray for some k >= 1
data_array = array(data_array)

if data_array.ndim == 1:
data_array = data_array.reshape(-1, 2)
if data_array.shape[-1] != 2:
raise ValueError("Input to map_screen must have shape (..., 2)")

if len(data_array) == 0:
return []
return empty(shape=(0, 2))

x_ary, y_ary = transpose(data_array)

Expand Down
2 changes: 1 addition & 1 deletion chaco/examples/demo/basic/hittest_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def normal_mouse_move(self, event):
else:
x, y = self.component.map_data((y, x))

x, y = self.line_plot.map_screen((x, y))
x, y = self.line_plot.map_screen((x, y))[0]
self.pt = self.line_plot.hittest((x, y), threshold=self.threshold)
self.request_redraw()

Expand Down
6 changes: 5 additions & 1 deletion chaco/plots/barplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
array,
compress,
column_stack,
empty,
invert,
isnan,
transpose,
Expand Down Expand Up @@ -176,9 +177,12 @@ def map_screen(self, data_array):

Implements the AbstractPlotRenderer interface.
"""
# ensure data_array is an Nx2 ndarray
data_array = array(data_array)
data_array = data_array.reshape(-1,2)
# data_array is Nx2 array
if len(data_array) == 0:
return []
return empty(shape=(0,2))
x_ary, y_ary = transpose(data_array)
sx = self.index_mapper.map_screen(x_ary)
sy = self.value_mapper.map_screen(y_ary)
Expand Down
4 changes: 2 additions & 2 deletions chaco/plots/errorbar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Thanks for using Enthought open source!

# Major library imports
from numpy import column_stack, compress, invert, isnan, transpose
from numpy import column_stack, compress, empty, invert, isnan, transpose
import logging

# Enthought library imports
Expand Down Expand Up @@ -49,7 +49,7 @@ def map_screen(self, data_array):
or (y, xlow, xhigh) depending on self.orientation.
"""
if len(data_array) == 0:
return []
return empty(shape=(0, 2))
elif data_array.shape[1] == 2:
return LinePlot.map_screen(self, data_array)
else:
Expand Down
2 changes: 1 addition & 1 deletion chaco/plots/lineplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def hittest(self, screen_pt, threshold=7.0, return_distance=False):
# screen_pt is one of the points in the lineplot
data_pt = (self.index.get_data()[ndx], self.value.get_data()[ndx])
if return_distance:
scrn_pt = self.map_screen(data_pt)
scrn_pt = self.map_screen(data_pt)[0]
dist = sqrt(
(screen_pt[0] - scrn_pt[0]) ** 2
+ (screen_pt[1] - scrn_pt[1]) ** 2
Expand Down
3 changes: 2 additions & 1 deletion chaco/plots/scatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
array,
asarray,
column_stack,
empty,
isfinite,
isnan,
nanargmin,
Expand Down Expand Up @@ -299,7 +300,7 @@ def map_screen(self, data_array):
"""
# data_array is Nx2 array
if len(data_array) == 0:
return []
return empty(shape=(0,2))

data_array = asarray(data_array)
if len(data_array.shape) == 1:
Expand Down
21 changes: 21 additions & 0 deletions chaco/plots/segment_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,27 @@ def hittest(self, *args, **kwargs):
def map_index(self, *args, **kwargs):
raise NotImplementedError()

def map_screen(self, data_array):
"""Maps an Nx2x2 array of data points into screen space and returns it
as an array.
Implements the AbstractPlotRenderer interface.
"""
# ensure data_array is an Nx2x2 ndarray
data_array = np.asarray(data_array)
data_array = data_array.reshape(-1, 2, 2)

if len(data_array) == 0:
return np.empty(shape=(0, 2, 2))

x_ary, y_ary = np.transpose(data_array)

sx = self.index_mapper.map_screen(x_ary)
sy = self.value_mapper.map_screen(y_ary)
if self.orientation == "h":
return np.transpose(np.array((sx, sy)))
else:
return np.transpose(np.array((sy, sx)))

def _gather_points(self):
"""Collects the data points that are within the bounds of the plot and
caches them.
Expand Down
37 changes: 37 additions & 0 deletions chaco/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import unittest

import numpy as np
from numpy import alltrue, arange, array

from enable.api import ComponentEditor
Expand Down Expand Up @@ -95,6 +96,18 @@ def test_segment_plot_color_width(self):
actual = gc.bmp_array[:, :, :]
self.assertFalse(alltrue(actual == 255))

def test_segment_plot_map_screen(self):
x = arange(10)
y = arange(1, 11)
data = ArrayPlotData(x=x, y=y)
plot = Plot(data)
plot_renderer = plot.plot(("x", "y"), "segment")[0]

screen_point = plot_renderer.map_screen([(0, 1), (1, 2)])

self.assertEqual(type(screen_point), np.ndarray)
self.assertEqual(screen_point.shape, (1, 2, 2))

def test_text_plot(self):
x = arange(10)
y = arange(1, 11)
Expand All @@ -109,6 +122,30 @@ def test_text_plot(self):
actual = gc.bmp_array[:, :, :]
self.assertFalse(alltrue(actual == 255))

def check_map_screen(self, renderer):
arr = arange(10)
data = ArrayPlotData(x=arr, y=arr)
plot = Plot(data)
plot_renderer = plot.add_xy_plot(
'x', 'y', plot.renderer_map[renderer]
)[0]

screen_point = plot_renderer.map_screen((-1, 1))

self.assertEqual(type(screen_point), np.ndarray)
self.assertEqual(screen_point.shape, (1, 2))

screen_point = plot_renderer.map_screen([])

self.assertEqual(type(screen_point), np.ndarray)
self.assertEqual(screen_point.shape, (0, 2))

# serves as a regression test for enthought/chaco#272
def test_xy_plot_map_screen(self):
renderers = ["line", "scatter", "bar", "polygon"]
for renderer in renderers:
self.check_map_screen(renderer)


class EmptyLinePlot(HasTraits):
plot = Instance(Plot)
Expand Down