Skip to content

Commit

Permalink
Merge pull request #2453 from astrofrog/custom-stretches-instances
Browse files Browse the repository at this point in the history
Make stretches be customizable on a layer by layer basis
  • Loading branch information
astrofrog authored Nov 30, 2023
2 parents 9caaa9c + 83e84d6 commit 127837a
Show file tree
Hide file tree
Showing 11 changed files with 154 additions and 30 deletions.
8 changes: 4 additions & 4 deletions glue/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,10 +836,10 @@ def __iter__(self):
from astropy.visualization import (LinearStretch, SqrtStretch, AsinhStretch,
LogStretch)
stretches = StretchRegistry()
stretches.add('linear', LinearStretch(), display='Linear')
stretches.add('sqrt', SqrtStretch(), display='Square Root')
stretches.add('arcsinh', AsinhStretch(), display='Arcsinh')
stretches.add('log', LogStretch(), display='Logarithmic')
stretches.add('linear', LinearStretch, display='Linear')
stretches.add('sqrt', SqrtStretch, display='Square Root')
stretches.add('arcsinh', AsinhStretch, display='Arcsinh')
stretches.add('log', LogStretch, display='Logarithmic')

# Backward-compatibility
qglue_parser = cli_parser
Expand Down
13 changes: 12 additions & 1 deletion glue/viewers/common/layer_artist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from echo import keep_in_sync, CallbackProperty
from echo import keep_in_sync, CallbackProperty, CallbackDict
from glue.core.layer_artist import LayerArtistBase
from glue.viewers.common.state import LayerState
from glue.core.message import LayerArtistVisibilityMessage
Expand Down Expand Up @@ -78,4 +78,15 @@ def pop_changed_properties(self):
self._last_viewer_state.update(self._viewer_state.as_dict())
self._last_layer_state.update(self.state.as_dict())

# If any of the items are CallbackDict, we make a copy otherwise both
# the 'last' and new values will remain the same.

for key, value in self._last_viewer_state.items():
if isinstance(value, CallbackDict):
self._last_viewer_state[key] = dict(value)

for key, value in self._last_layer_state.items():
if isinstance(value, CallbackDict):
self._last_layer_state[key] = dict(value)

return changed
47 changes: 47 additions & 0 deletions glue/viewers/common/stretch_state_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from glue.config import stretches
from glue.viewers.matplotlib.state import (
DeferredDrawDictCallbackProperty as DDDCProperty,
DeferredDrawSelectionCallbackProperty as DDSCProperty,
)

__all__ = ["StretchStateMixin"]


class StretchStateMixin:
stretch = DDSCProperty(
docstring="The stretch used to render the layer, "
"which should be one of ``linear``, "
"``sqrt``, ``log``, or ``arcsinh``"
)
stretch_parameters = DDDCProperty(
docstring="Keyword arguments to pass to the stretch"
)

_stretch_set_up = False

def setup_stretch_callback(self):
type(self).stretch.set_choices(self, list(stretches.members))
type(self).stretch.set_display_func(self, stretches.display_func)
self._reset_stretch()
self.add_callback("stretch", self._reset_stretch)
self.add_callback("stretch_parameters", self._sync_stretch_parameters)
self._stretch_set_up = True

@property
def stretch_object(self):
if not self._stretch_set_up:
raise Exception("setup_stretch_callback has not been called")
return self._stretch_object

def _sync_stretch_parameters(self, *args):
for key, value in self.stretch_parameters.items():
if hasattr(self._stretch_object, key):
setattr(self._stretch_object, key, value)
else:
raise ValueError(
f"Stretch object {self._stretch_object.__class__.__name__} has no attribute {key}"
)

def _reset_stretch(self, *args):
self._stretch_object = stretches.members[self.stretch]()
self.stretch_parameters.clear()
56 changes: 56 additions & 0 deletions glue/viewers/common/tests/test_stretch_state_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

from astropy.visualization import LinearStretch, LogStretch

from glue.core.state_objects import State
from glue.viewers.common.stretch_state_mixin import StretchStateMixin


class ExampleStateWithStretch(State, StretchStateMixin):
pass


def test_not_set_up():
state = ExampleStateWithStretch()
with pytest.raises(Exception, match="setup_stretch_callback has not been called"):
state.stretch_object


class TestStretchStateMixin:
def setup_method(self, method):
self.state = ExampleStateWithStretch()
self.state.setup_stretch_callback()

def test_defaults(self):
assert self.state.stretch == "linear"
assert len(self.state.stretch_parameters) == 0
assert isinstance(self.state.stretch_object, LinearStretch)

def test_change_stretch(self):
self.state.stretch = "log"
assert self.state.stretch == "log"
assert len(self.state.stretch_parameters) == 0
assert isinstance(self.state.stretch_object, LogStretch)

def test_invalid_parameter(self):
with pytest.raises(
ValueError, match="Stretch object LinearStretch has no attribute foo"
):
self.state.stretch_parameters["foo"] = 1

def test_set_parameter(self):
self.state.stretch = "log"

assert self.state.stretch_object.exp == 1000

# Setting the stretch parameter 'exp' is synced with the stretch object attribute
self.state.stretch_parameters["exp"] = 200
assert self.state.stretch_object.exp == 200

# Changing stretch resets the stretch parameter dictionary
self.state.stretch = "linear"
assert len(self.state.stretch_parameters) == 0

# And there is no memory of previous parameters
self.state.stretch = "log"
assert self.state.stretch_object.exp == 1000
5 changes: 4 additions & 1 deletion glue/viewers/image/composite_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def __call__(self, bounds=None):

interval = ManualInterval(*layer['clim'])
contrast_bias = ContrastBiasStretch(layer['contrast'], layer['bias'])
stretch = stretches.members[layer['stretch']]
if isinstance(layer['stretch'], str):
stretch = stretches.members[layer['stretch']]()
else:
stretch = layer['stretch']

if callable(layer['array']):
array = layer['array'](bounds=bounds)
Expand Down
4 changes: 2 additions & 2 deletions glue/viewers/image/layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _update_visual_attributes(self):
contrast=self.state.contrast,
bias=self.state.bias,
alpha=self.state.alpha,
stretch=self.state.stretch)
stretch=self.state.stretch_object)

self.composite_image.invalidate_cache()

Expand All @@ -193,7 +193,7 @@ def _update_image(self, force=False, **kwargs):
if force or any(prop in changed for prop in ('v_min', 'v_max', 'contrast',
'bias', 'alpha', 'color_mode',
'cmap', 'color', 'zorder',
'visible', 'stretch')):
'visible', 'stretch', 'stretch_parameters')):
self._update_visual_attributes()

@defer_draw
Expand Down
11 changes: 4 additions & 7 deletions glue/viewers/image/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict

from glue.core import BaseData
from glue.config import colormaps, stretches
from glue.config import colormaps
from glue.viewers.matplotlib.state import (MatplotlibDataViewerState,
MatplotlibLayerState,
DeferredDrawCallbackProperty as DDCProperty,
Expand All @@ -12,6 +12,7 @@
from echo import delay_callback
from glue.core.data_combo_helper import ManualDataComboHelper, ComponentIDComboHelper
from glue.core.exceptions import IncompatibleDataException
from glue.viewers.common.stretch_state_mixin import StretchStateMixin

__all__ = ['ImageViewerState', 'ImageLayerState', 'ImageSubsetLayerState', 'AggregateSlice']

Expand Down Expand Up @@ -481,7 +482,7 @@ def slice_to_bound(slc, size):
return image


class ImageLayerState(BaseImageLayerState):
class ImageLayerState(BaseImageLayerState, StretchStateMixin):
"""
A state class that includes all the attributes for data layers in an image plot.
"""
Expand All @@ -495,9 +496,6 @@ class ImageLayerState(BaseImageLayerState):
bias = DDCProperty(0.5, docstring='A constant value that is added to the '
'layer before rendering')
cmap = DDCProperty(docstring='The colormap used to render the layer')
stretch = DDSCProperty(docstring='The stretch used to render the layer, '
'which should be one of ``linear``, '
'``sqrt``, ``log``, or ``arcsinh``')
global_sync = DDCProperty(False, docstring='Whether the color and transparency '
'should be synced with the global '
'color and transparency for the data')
Expand Down Expand Up @@ -525,8 +523,7 @@ def __init__(self, layer=None, viewer_state=None, **kwargs):
ImageLayerState.percentile.set_choices(self, [100, 99.5, 99, 95, 90, 'Custom'])
ImageLayerState.percentile.set_display_func(self, percentile_display.get)

ImageLayerState.stretch.set_choices(self, list(stretches.members))
ImageLayerState.stretch.set_display_func(self, stretches.display_func)
self.setup_stretch_callback()

self.add_callback('global_sync', self._update_syncing)
self.add_callback('layer', self._update_attribute)
Expand Down
16 changes: 15 additions & 1 deletion glue/viewers/matplotlib/state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from echo import CallbackProperty, SelectionCallbackProperty, keep_in_sync, delay_callback
from echo import (CallbackProperty,
SelectionCallbackProperty,
DictCallbackProperty,
keep_in_sync, delay_callback)

from matplotlib.colors import to_rgba

Expand Down Expand Up @@ -35,6 +38,17 @@ def notify(self, *args, **kwargs):
super(DeferredDrawSelectionCallbackProperty, self).notify(*args, **kwargs)


class DeferredDrawDictCallbackProperty(DictCallbackProperty):
"""
A callback property where drawing is deferred until
after notify has called all callback functions.
"""

@defer_draw
def notify(self, *args, **kwargs):
super(DeferredDrawDictCallbackProperty, self).notify(*args, **kwargs)


VALID_WEIGHTS = ['light', 'normal', 'medium', 'semibold', 'bold', 'heavy', 'black']


Expand Down
8 changes: 4 additions & 4 deletions glue/viewers/scatter/layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@

# We keep the following so that scripts exported with previous versions of glue
# continue to work, as they imported STRETCHES from here.
STRETCHES = stretches.members
STRETCHES = {key: value() for key, value in stretches.members.items()}

CMAP_PROPERTIES = set(['cmap_mode', 'cmap_att', 'cmap_vmin', 'cmap_vmax', 'cmap'])
MARKER_PROPERTIES = set(['size_mode', 'size_att', 'size_vmin', 'size_vmax', 'size_scaling', 'size', 'fill'])
LINE_PROPERTIES = set(['linewidth', 'linestyle'])
DENSITY_PROPERTIES = set(['dpi', 'stretch', 'density_contrast'])
DENSITY_PROPERTIES = set(['dpi', 'stretch', 'stretch_parameters', 'density_contrast'])
VISUAL_PROPERTIES = (CMAP_PROPERTIES | MARKER_PROPERTIES | DENSITY_PROPERTIES |
LINE_PROPERTIES | set(['color', 'alpha', 'zorder', 'visible']))

Expand Down Expand Up @@ -371,8 +371,8 @@ def _update_visual_attributes(self, changed, force=False):
c = ensure_numerical(self.layer[self.state.cmap_att].ravel())
set_mpl_artist_cmap(self.density_artist, c, self.state)

if force or 'stretch' in changed:
self.density_artist.set_norm(ImageNormalize(stretch=stretches.members[self.state.stretch]))
if force or 'stretch' in changed or 'stretch_parameters' in changed:
self.density_artist.set_norm(ImageNormalize(stretch=self.state.stretch_object))

if force or 'dpi' in changed:
self.density_artist.set_dpi(self._viewer_state.dpi)
Expand Down
5 changes: 2 additions & 3 deletions glue/viewers/scatter/python_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def python_export_scatter_layer(layer, *args):
if layer.state.density_map:

imports += ["from mpl_scatter_density import ScatterDensityArtist"]
imports += ["from glue.config import stretches"]
imports += ["from glue.viewers.scatter.layer_artist import DensityMapLimits"]
imports += ["from glue.viewers.scatter.layer_artist import DensityMapLimits, STRETCHES"]
imports += ["from astropy.visualization import ImageNormalize"]

script += "density_limits = DensityMapLimits()\n"
Expand All @@ -92,7 +91,7 @@ def python_export_scatter_layer(layer, *args):
options['color'] = layer.state.color
options['vmin'] = code('density_limits.min')
options['vmax'] = code('density_limits.max')
options['norm'] = code("ImageNormalize(stretch=stretches.members['{0}'])".format(layer.state.stretch))
options['norm'] = code("ImageNormalize(stretch=STRETCHES['{0}'])".format(layer.state.stretch))
else:
options['c'] = code("layer_data['{0}']".format(layer.state.cmap_att.label))
options['cmap'] = code("plt.cm.{0}".format(layer.state.cmap.name))
Expand Down
11 changes: 4 additions & 7 deletions glue/viewers/scatter/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from glue.core import BaseData, Subset

from glue.config import colormaps, stretches
from glue.config import colormaps
from glue.viewers.matplotlib.state import (MatplotlibDataViewerState,
MatplotlibLayerState,
DeferredDrawCallbackProperty as DDCProperty,
Expand All @@ -13,6 +13,7 @@
from echo import keep_in_sync, delay_callback
from glue.core.data_combo_helper import ComponentIDComboHelper, ComboHelper
from glue.core.exceptions import IncompatibleAttribute
from glue.viewers.common.stretch_state_mixin import StretchStateMixin

from matplotlib.projections import get_projection_names

Expand Down Expand Up @@ -204,7 +205,7 @@ def display_func_slow(x):
return x


class ScatterLayerState(MatplotlibLayerState):
class ScatterLayerState(MatplotlibLayerState, StretchStateMixin):
"""
A state class that includes all the attributes for layers in a scatter plot.
"""
Expand Down Expand Up @@ -235,9 +236,6 @@ class ScatterLayerState(MatplotlibLayerState):
# Density map

density_map = DDCProperty(False, docstring="Whether to show the points as a density map")
stretch = DDSCProperty(default='log', docstring='The stretch used to render the layer, '
'which should be one of ``linear``, '
'``sqrt``, ``log``, or ``arcsinh``')
density_contrast = DDCProperty(1, docstring="The dynamic range of the density map")

# Note that we keep the dpi in the viewer state since we want it to always
Expand Down Expand Up @@ -330,8 +328,7 @@ def __init__(self, viewer_state=None, layer=None, **kwargs):
ScatterLayerState.vector_origin.set_choices(self, ['tail', 'middle', 'tip'])
ScatterLayerState.vector_origin.set_display_func(self, vector_origin_display.get)

ScatterLayerState.stretch.set_choices(self, ['linear', 'sqrt', 'arcsinh', 'log'])
ScatterLayerState.stretch.set_display_func(self, stretches.display_func)
self.setup_stretch_callback()

if self.viewer_state is not None:
self.viewer_state.add_callback('x_att', self._on_xy_change, priority=10000)
Expand Down

0 comments on commit 127837a

Please sign in to comment.