Skip to content

Commit

Permalink
Merge pull request #2456 from jfoster17/support-regions-in-imageviewer
Browse files Browse the repository at this point in the history
Support regions in imageviewer
  • Loading branch information
astrofrog authored Nov 10, 2023
2 parents 4de2cfd + f87062c commit 881345f
Show file tree
Hide file tree
Showing 4 changed files with 452 additions and 5 deletions.
16 changes: 13 additions & 3 deletions glue/viewers/image/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from glue.core.subset import roi_to_subset_state
from glue.core.coordinates import Coordinates, LegacyCoordinates
from glue.core.coordinate_helpers import dependent_axes
from glue.core.data_region import RegionData

from glue.viewers.scatter.layer_artist import ScatterLayerArtist
from glue.viewers.scatter.layer_artist import ScatterLayerArtist, ScatterRegionLayerArtist
from glue.viewers.image.layer_artist import ImageLayerArtist, ImageSubsetLayerArtist
from glue.viewers.image.compat import update_image_viewer_state

Expand Down Expand Up @@ -172,15 +173,24 @@ def _scatter_artist(self, axes, state, layer=None, layer_state=None):
raise Exception("Can only add a scatter plot overlay once an image is present")
return ScatterLayerArtist(axes, state, layer=layer, layer_state=None)

def _region_artist(self, axes, state, layer=None, layer_state=None):
if len(self._layer_artist_container) == 0:
raise Exception("Can only add a region plot overlay once an image is present")
return ScatterRegionLayerArtist(axes, state, layer=layer, layer_state=None)

def get_data_layer_artist(self, layer=None, layer_state=None):
if layer.ndim == 1:
if isinstance(layer, RegionData):
cls = self._region_artist
elif layer.ndim == 1:
cls = self._scatter_artist
else:
cls = ImageLayerArtist
return self.get_layer_artist(cls, layer=layer, layer_state=layer_state)

def get_subset_layer_artist(self, layer=None, layer_state=None):
if layer.ndim == 1:
if isinstance(layer.data, RegionData):
cls = self._region_artist
elif layer.ndim == 1:
cls = self._scatter_artist
else:
cls = ImageSubsetLayerArtist
Expand Down
178 changes: 177 additions & 1 deletion glue/viewers/scatter/layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@

from glue.config import stretches
from glue.utils import defer_draw, ensure_numerical, datetime64_to_mpl
from glue.viewers.scatter.state import ScatterLayerState
from glue.viewers.scatter.state import ScatterLayerState, ScatterRegionLayerState
from glue.viewers.scatter.python_export import python_export_scatter_layer
from glue.viewers.scatter.plot_polygons import (UpdateableRegionCollection,
get_geometry_type,
_sanitize_geoms,
_PolygonPatch,
transform_shapely)
from glue.viewers.matplotlib.layer_artist import MatplotlibLayerArtist
from glue.core.exceptions import IncompatibleAttribute
from glue.core.data import BaseData

from matplotlib.lines import Line2D
from shapely.ops import transform

# We keep the following so that scripts exported with previous versions of glue
# continue to work, as they imported STRETCHES from here.
Expand Down Expand Up @@ -594,3 +601,172 @@ def _use_plot_artist(self):
res = self.state.cmap_mode == 'Fixed' and self.state.size_mode == 'Fixed'
return res and (not hasattr(self._viewer_state, 'plot_mode') or
not self._viewer_state.plot_mode == 'polar')


class ScatterRegionLayerArtist(MatplotlibLayerArtist):

_layer_state_cls = ScatterRegionLayerState
# _python_exporter = python_export_scatter_layer # TODO: Update this to work with regions

def __init__(self, axes, viewer_state, layer_state=None, layer=None):

super().__init__(axes, viewer_state,
layer_state=layer_state, layer=layer)
self._viewer_state.add_global_callback(self._update_scatter_region)
self.state.add_global_callback(self._update_scatter_region)
self._set_axes(axes)

def _set_axes(self, axes):
self.axes = axes
self.region_collection = UpdateableRegionCollection([])
self.axes.add_collection(self.region_collection)
# This is a little unnecessary, but keeps code more parallel
self.mpl_artists = [self.region_collection]

@defer_draw
def _update_data(self):
# Layer artist has been cleared already
if len(self.mpl_artists) == 0:
return

if self.layer is not None:
if isinstance(self.layer, BaseData):
data = self.layer
else:
data = self.layer.data
region_att = data.extended_component_id

try:
# These must be special attributes that are linked to a region_att
if ((not data.linked_to_center_comp(self._viewer_state.x_att)) and
(not data.linked_to_center_comp(self._viewer_state.x_att_world))):
raise IncompatibleAttribute
x = ensure_numerical(self.layer[self._viewer_state.x_att].ravel())
except (IncompatibleAttribute, IndexError):
# The following includes a call to self.clear()
self.disable_invalid_attributes(self._viewer_state.x_att)
return
else:
self.enable()

try:
# These must be special attributes that are linked to a region_att
if ((not data.linked_to_center_comp(self._viewer_state.y_att)) and
(not data.linked_to_center_comp(self._viewer_state.y_att_world))):
raise IncompatibleAttribute
y = ensure_numerical(self.layer[self._viewer_state.y_att].ravel())
except (IncompatibleAttribute, IndexError):
# The following includes a call to self.clear()
self.disable_invalid_attributes(self._viewer_state.y_att)
return
else:
self.enable()

regions = self.layer[region_att]

# If we are using world coordinates (i.e. the regions are specified in world coordinates)
# we need to transform the geometries of the regions into pixel coordinates for display
# Note that this calls a custom version of the transform function from shapely
# to accomodate glue WCS objects
if self._viewer_state._display_world:
# First, convert to world coordinates
tfunc = data.get_transform_to_cids([self._viewer_state.x_att_world, self._viewer_state.y_att_world])
regions = np.array([transform(tfunc, g) for g in regions])

# Then convert to pixels for display
world2pix = self._viewer_state.reference_data.coords.world_to_pixel_values
regions = np.array([transform_shapely(world2pix, g) for g in regions])
else:
tfunc = data.get_transform_to_cids([self._viewer_state.x_att, self._viewer_state.y_att])
regions = np.array([transform(tfunc, g) for g in regions])

# decompose GeometryCollections
geoms, multiindex = _sanitize_geoms(regions, prefix="Geom")
self.multiindex_geometry = multiindex

geom_types = get_geometry_type(geoms)
poly_idx = np.asarray((geom_types == "Polygon") | (geom_types == "MultiPolygon"))
polys = geoms[poly_idx]

# decompose MultiPolygons
geoms, multiindex = _sanitize_geoms(polys, prefix="Multi")
self.region_collection.patches = [_PolygonPatch(poly) for poly in geoms]

self.geoms = geoms
self.multiindex = multiindex

@defer_draw
def _update_visual_attributes(self, changed, force=False):

if not self.enabled:
return

if self.state.cmap_mode == 'Fixed':
if force or 'color' in changed or 'cmap_mode' in changed or 'fill' in changed:
self.region_collection.set_array(None)
if self.state.fill:
self.region_collection.set_facecolors(self.state.color)
self.region_collection.set_edgecolors('none')
else:
self.region_collection.set_facecolors('none')
self.region_collection.set_edgecolors(self.state.color)
elif force or any(prop in changed for prop in CMAP_PROPERTIES) or 'fill' in changed:
self.region_collection.set_edgecolors(None)
self.region_collection.set_facecolors(None)
c = ensure_numerical(self.layer[self.state.cmap_att].ravel())
c_values = np.take(c, self.multiindex_geometry, axis=0) # Decompose Geoms
c_values = np.take(c_values, self.multiindex, axis=0) # Decompose MultiPolys
set_mpl_artist_cmap(self.region_collection, c_values, self.state)
if self.state.fill:
self.region_collection.set_edgecolors('none')
else:
self.region_collection.set_facecolors('none')

for artist in [self.region_collection]:

if artist is None:
continue

if force or 'alpha' in changed:
artist.set_alpha(self.state.alpha)

if force or 'zorder' in changed:
artist.set_zorder(self.state.zorder)

if force or 'visible' in changed:
artist.set_visible(self.state.visible)

self.redraw()

@defer_draw
def _update_scatter_region(self, force=False, **kwargs):

if (self._viewer_state.x_att is None or
self._viewer_state.y_att is None or
self.state.layer is None):
return

# NOTE: we need to evaluate this even if force=True so that the cache
# of updated properties is up to date after this method has been called.
changed = self.pop_changed_properties()

full_sphere = getattr(self._viewer_state, 'using_full_sphere', False)
change_from_limits = full_sphere and len(changed & LIMIT_PROPERTIES) > 0
if force or change_from_limits or len(changed & DATA_PROPERTIES) > 0:
self._update_data()
force = True

if force or len(changed & VISUAL_PROPERTIES) > 0:
self._update_visual_attributes(changed, force=force)

@defer_draw
def update(self):
self._update_scatter_region(force=True)
self.redraw()

@defer_draw
def _on_components_changed(self, components_changed):
for limit_helper in [self.state.cmap_lim_helper]:
if limit_helper.attribute in components_changed:
limit_helper.update_values('attribute')
self.redraw()
157 changes: 157 additions & 0 deletions glue/viewers/scatter/plot_polygons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""
Some lightly edited code from geopandas.plotting.py for efficiently
plotting multiple polygons in matplotlib.
"""
# Copyright (c) 2013-2022, GeoPandas developers.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of GeoPandas nor the names of its contributors may
# be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
import shapely
from matplotlib.collections import PatchCollection
from shapely.errors import GeometryTypeError


class UpdateableRegionCollection(PatchCollection):
"""
Allow paths in PatchCollection to be modified after creation.
"""

def __init__(self, patches, *args, **kwargs):
self.patches = patches
PatchCollection.__init__(self, patches, *args, **kwargs)

def get_paths(self):
self.set_paths(self.patches)
return self._paths


def _sanitize_geoms(geoms, prefix="Multi"):
"""
Returns Series like geoms and index, except that any Multi geometries
are split into their components and indices are repeated for all component
in the same Multi geometry. At the same time, empty or missing geometries are
filtered out. Maintains 1:1 matching of geometry to value.
Prefix specifies type of geometry to be flatten. 'Multi' for MultiPoint and similar,
"Geom" for GeometryCollection.
Returns
-------
components : list of geometry
component_index : index array
indices are repeated for all components in the same Multi geometry
"""
# TODO(shapely) look into simplifying this with
# shapely.get_parts(geoms, return_index=True) from shapely 2.0
components, component_index = [], []

geom_types = get_geometry_type(geoms).astype("str")

if (
not np.char.startswith(geom_types, prefix).any()
# and not geoms.is_empty.any()
# and not geoms.isna().any()
):
return geoms, np.arange(len(geoms))

for ix, (geom, geom_type) in enumerate(zip(geoms, geom_types)):
if geom is not None and geom_type.startswith(prefix):
for poly in geom.geoms:
components.append(poly)
component_index.append(ix)
elif geom is None:
continue
else:
components.append(geom)
component_index.append(ix)

return components, np.array(component_index)


def get_geometry_type(data):
_names = {
"MISSING": None,
"NAG": None,
"POINT": "Point",
"LINESTRING": "LineString",
"LINEARRING": "LinearRing",
"POLYGON": "Polygon",
"MULTIPOINT": "MultiPoint",
"MULTILINESTRING": "MultiLineString",
"MULTIPOLYGON": "MultiPolygon",
"GEOMETRYCOLLECTION": "GeometryCollection",
}

type_mapping = {p.value: _names[p.name] for p in shapely.GeometryType}
geometry_type_ids = list(type_mapping.keys())
geometry_type_values = np.array(list(type_mapping.values()), dtype=object)
res = shapely.get_type_id(data)
return geometry_type_values[np.searchsorted(geometry_type_ids, res)]


def transform_shapely(func, geom):
"""
A simplified/modified version of shapely.ops.transform where the func
call signature is tuned for the coordinate transform functions
coming from glue.
"""
if geom.is_empty:
return geom
if geom.geom_type in ("Point", "LineString", "LinearRing", "Polygon"):
if geom.geom_type in ("Point", "LineString", "LinearRing"):
return type(geom)(func(geom.coords))
elif geom.geom_type == "Polygon":
shell = type(geom.exterior)(func(geom.exterior.coords))
holes = list(
type(ring)(func(ring.coords))
for ring in geom.interiors
)
return type(geom)(shell, holes)

elif geom.geom_type.startswith("Multi") or geom.geom_type == "GeometryCollection":
return type(geom)([transform_shapely(func, part) for part in geom.geoms])
else:
raise GeometryTypeError(f"Type {geom.geom_type!r} not recognized")


def _PolygonPatch(polygon, **kwargs):
"""Constructs a matplotlib patch from a Polygon geometry
The `kwargs` are those supported by the matplotlib.patches.PathPatch class
constructor. Returns an instance of matplotlib.patches.PathPatch.
Example (using Shapely Point and a matplotlib axes)::
b = shapely.geometry.Point(0, 0).buffer(1.0)
patch = _PolygonPatch(b, fc='blue', ec='blue', alpha=0.5)
ax.add_patch(patch)
GeoPandas originally relied on the descartes package by Sean Gillies
(BSD license, https://pypi.org/project/descartes) for PolygonPatch, but
this dependency was removed in favor of the below matplotlib code.
"""
from matplotlib.patches import PathPatch
from matplotlib.path import Path

path = Path.make_compound_path(
Path(np.asarray(polygon.exterior.coords)[:, :2]),
*[Path(np.asarray(ring.coords)[:, :2]) for ring in polygon.interiors],
)
return PathPatch(path, **kwargs)
Loading

0 comments on commit 881345f

Please sign in to comment.