Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-7847: Add mid-level drivers for measurement algorithms #1020

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
379 changes: 379 additions & 0 deletions python/lsst/pipe/tasks/measurementDriver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,379 @@
# This file is part of pipe_tasks.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

__all__ = ["MeasurementDriverConfig", "MeasurementDriverTask"]

import logging

import lsst.afw.image as afwImage
import lsst.afw.table as afwTable
import lsst.meas.algorithms as measAlgorithms
import lsst.meas.base as measBase
import lsst.meas.deblender as measDeblender
import lsst.meas.extensions.scarlet as scarlet
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
import numpy as np

logging.basicConfig(level=logging.INFO)


class MeasurementDriverConfig(pexConfig.Config):
"""Configuration parameters for `MeasurementDriverTask`."""

# To generate catalog ids consistently across subtasks.
id_generator = measBase.DetectorVisitIdGeneratorConfig.make_field()

detection = pexConfig.ConfigurableField(
target=measAlgorithms.SourceDetectionTask,
doc="Task to detect sources to return in the output catalog.",
)

deblender = pexConfig.ChoiceField[str](
doc="The deblender to use.",
default="meas_deblender",
allowed={"meas_deblender": "Deblend using meas_deblender", "scarlet": "Deblend using scarlet"},
)
Comment on lines +50 to +54
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that this option is necessary. You already set the deblender in deblend, so having a config option that could accidentally be misaligned with the target deblend seems like an unnecessary option that could lead to user error.

This will also allow you to remove the other methods implemented in this class below.


deblend = pexConfig.ConfigurableField(
target=measDeblender.SourceDeblendTask, doc="Split blended sources into their components."
)

measurement = pexConfig.ConfigurableField(
target=measBase.SingleFrameMeasurementTask,
doc="Task to measure sources to return in the output catalog.",
)

def __setattr__(self, key, value):
"""Intercept attribute setting to trigger setDefaults when relevant
fields change.
"""
super().__setattr__(key, value)

# This is to ensure the deblend target is set correctly whenever the
# deblender is changed. This is required because `setDefaults` is not
# automatically invoked during reconfiguration.
if key == "deblender":
self.setDefaults()

def validate(self):
super().validate()

# Ensure the deblend target aligns with the selected deblender.
if self.deblender == "scarlet":
assert self.deblend.target == scarlet.ScarletDeblendTask
elif self.deblender == "meas_deblender":
assert self.deblend.target == measDeblender.SourceDeblendTask
elif self.deblender is not None:
raise ValueError(f"Invalid deblender value: {self.deblender}")

def setDefaults(self):
super().setDefaults()
if self.deblender == "scarlet":
self.deblend.retarget(scarlet.ScarletDeblendTask)
elif self.deblender == "meas_deblender":
self.deblend.retarget(measDeblender.SourceDeblendTask)
Comment on lines +88 to +93
Copy link
Contributor Author

@enourbakhsh enourbakhsh Jan 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to rename this method to setDeblender() and just call it in the setDefaults() method? This way our __setattr__ method becomes:

def __setattr__(self, key, value):
    super().__setattr__(key, value)
    if key == "deblender":
        self.setDeblender()

It just makes more sense to me if I add a setDeblender() method because we might want to put other things in setDefaults() that don’t need to be set every time the deblender is changed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my previous comment



class MeasurementDriverTask(pipeBase.Task):
"""A mid-level driver for running detection, deblending (optional), and
measurement algorithms in one go.

This driver simplifies the process of applying a small set of measurement
algorithms to images by abstracting away schema and table boilerplate. It
is particularly suited for simple use cases, such as processing images
without neighbor-noise-replacement or extensive configuration.

Designed to streamline the measurement framework, this class integrates
detection, deblending (if enabled), and measurement into a single workflow.

Parameters
----------
schema : `~lsst.afw.table.Schema`
Schema used to create the output `~lsst.afw.table.SourceCatalog`,
modified in place with fields that will be written by this task.
**kwargs : `dict`
Additional kwargs to pass to lsst.pipe.base.Task.__init__()

Examples
--------
Here is an example of how to use this class to run detection, deblending,
and measurement on a given exposure:
>>> from lsst.pipe.tasks.measurementDriver import MeasurementDriverTask
>>> import lsst.meas.extensions.shapeHSM # To register its plugins
>>> config = MeasurementDriverTask().ConfigClass()
>>> config.detection.thresholdValue = 5.5
>>> config.deblender = "meas_deblender"
>>> config.deblend.tinyFootprintSize = 3
>>> config.measurement.plugins.names |= [
... "base_SdssCentroid",
... "base_SdssShape",
... "ext_shapeHSM_HsmSourceMoments",
... ]
>>> config.measurement.slots.psfFlux = None
>>> config.measurement.doReplaceWithNoise = False
>>> exposure = butler.get("deepCoadd", dataId=...)
>>> driver = MeasurementDriverTask(config=config)
>>> catalog = driver.run(exposure)
>>> catalog.writeFits("meas_catalog.fits")
"""

ConfigClass = MeasurementDriverConfig
_DefaultName = "measurementDriver"

def __init__(self, schema=None, **kwargs):
super().__init__(**kwargs)

if schema is None:
# Create a minimal schema that will be extended by tasks.
self.schema = afwTable.SourceTable.makeMinimalSchema()
else:
self.schema = schema

# Add coordinate error fields to the schema (this is to avoid errors
# such as: "Field with name 'coord_raErr' not found with type 'F'").
afwTable.CoordKey.addErrorFields(self.schema)

self.subtasks = ["detection", "deblend", "measurement"]

def make_subtasks(self):
"""Create subtasks based on the current configuration."""
for name in self.subtasks:
self.makeSubtask(name, schema=self.schema)

def run(
self,
image,
bands=None,
band=None,
mask=None,
variance=None,
psf=None,
wcs=None,
photo_calib=None,
id_generator=None,
):
Comment on lines +162 to +173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would recommend using type hints in your function signature, then you can remove the types from the docstring of Parameters. This prevents bitrotting and will also help people using an IDE with a mypy-like linter.

"""Run detection, optional deblending, and measurement on a given
image.

Parameters
----------
image: `~lsst.afw.image.Exposure` or `~lsst.afw.image.MaskedImage` or
`~lsst.afw.image.Image` or `np.ndarray` or
`~lsst.afw.image.MultibandExposure` or
`list` of `~lsst.afw.image.Exposure`
The image on which to detect, deblend and measure sources. If
provided as a multiband exposure, or a list of `Exposure` objects,
it can be taken advantage of by the 'scarlet' deblender. When using
a list of `Exposure` objects, the ``bands`` parameter must also be
provided.
bands: `str` or `list` of `str`, optional
The bands of the input image. Required if ``image`` is provided as
a list of `Exposure` objects. Example: ["g", "r", "i", "z", "y"]
or "grizy".
band: `str`, optional
The target band of the image to use for detection and measurement.
Required when ``image`` is provided as a `MultibandExposure`, or a
list of `Exposure` objects.
mask: `~lsst.afw.image.Mask`, optional
The mask for the input image. Only used if ``image`` is provided
as an afw `Image` or a numpy `ndarray`.
variance: `~lsst.afw.image.Image`, optional
The variance image for the input image. Only used if ``image`` is
provided as an afw `Image` or a numpy `ndarray`.
psf: `~lsst.afw.detection.Psf`, optional
The PSF model for the input image. Will be ignored if ``image`` is
provided as an `Exposure`, `MultibandExposure`, or a list of
`Exposure` objects.
wcs: `~lsst.afw.image.Wcs`, optional
The World Coordinate System (WCS) model for the input image. Will
be ignored if ``image`` is provided as an `Exposure`,
`MultibandExposure`, or a list of `Exposure` objects.
photo_calib : `~lsst.afw.image.PhotoCalib`, optional
Photometric calibration model for the input image. Will be ignored
if ``image`` is provided as an `Exposure`, `MultibandExposure`, or
a list of `Exposure` objects.
Comment on lines +179 to +213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this API is confusing. Instead of having three optional sets of parameters in run, one for Exposures, and one for Images (which requires a lot of additional parameters that are ignored if the image is an Exposure), and one for MultibandExposures, I recommend making run take an Exposure and create a new method, runFromImage that takes the parameters required if an image is passed and creates an exposure from those images to call the normal run. I think that this will make the API easier for users and simplify the codebase. See my other comments about single-band vs multiband

id_generator : `~lsst.meas.base.IdGenerator`, optional
Object that generates source IDs and provides random seeds.

Returns
-------
catalog : `~lsst.afw.table.SourceCatalog`
The source catalog with all requested measurements.
"""

# Only make the `deblend` subtask if it is enabled.
if self.config.deblender is None:
self.subtasks.remove("deblend")

# Validate the configuration before running the task.
self.config.validate()

# This guarantees the `run` method picks up the current subtask config.
self.make_subtasks()
# N.B. subtasks must be created here to handle reconfigurations, such
# as retargeting the `deblend` subtask, because the `makeSubtask`
# method locks in its config just before creating the subtask. If the
# subtask was already made in __init__ using the initial config, it
# cannot be retargeted now because retargeting happens at the config
# level, not the subtask level.

if id_generator is None:
id_generator = measBase.IdGenerator()

if isinstance(image, afwImage.MultibandExposure) or isinstance(image, list):
if self.config.deblender != "scarlet":
self.log.debug(
"Supplied a multiband exposure, or a list of exposures, while the deblender is set to "
f"'{self.config.deblender}'. A single exposure corresponding to target `band` will be "
"used for everything."
)
if band is None:
raise ValueError(
"The target `band` must be provided when using multiband exposures or a list of "
"exposures."
)
if isinstance(image, list):
if not all(isinstance(im, afwImage.Exposure) for im in image):
raise ValueError("All elements in the `image` list must be `Exposure` objects.")
if bands is None:
raise ValueError(
"The `bands` parameter must be provided if `image` is a list of `Exposure` objects."
)
if not isinstance(bands, (str, list)) or (
isinstance(bands, list) and not all(isinstance(b, str) for b in bands)
):
raise TypeError(
"The `bands` parameter must be a string or a list of strings if provided."
)
if len(bands) != len(image):
raise ValueError(
"The number of bands must match the number of `Exposure` objects in the list."
)
else:
if band is None:
band = "N/A" # Just a placeholder for single-band deblending
else:
self.log.warn("The target `band` is not required when the input image is not multiband.")
if bands is not None:
self.log.warn(
"The `bands` parameter will be ignored because the input image is not multiband."
)

if self.config.deblender == "scarlet":
if not isinstance(image, (afwImage.MultibandExposure, list, afwImage.Exposure)):
raise ValueError(
"The `image` parameter must be a `MultibandExposure`, a list of `Exposure` "
"objects, or a single `Exposure` when the deblender is set to 'scarlet'."
)
if isinstance(image, afwImage.Exposure):
# N.B. scarlet is designed to leverage multiband information to
# differentiate overlapping sources based on their spectral and
# spatial profiles. However, it can also run on a single band
# and still give better results than 'meas_deblender'.
self.log.debug(
"Supplied a single-band exposure, while the deblender is set to 'scarlet'."
"Make sure it was intended."
)

# Start with some image conversions if needed.
if isinstance(image, np.ndarray):
image = afwImage.makeImageFromArray(image)
if isinstance(mask, np.ndarray):
mask = afwImage.makeMaskFromArray(mask)
if isinstance(variance, np.ndarray):
variance = afwImage.makeImageFromArray(variance)
if isinstance(image, afwImage.Image):
image = afwImage.makeMaskedImage(image, mask, variance)

# Avoid type checker errors by being explicit from here on.
exposure: afwImage.Exposure

# Make sure we have an `Exposure` object to work with (potentially
# along with a `MultiBandExposure` for scarlet deblending).
if isinstance(image, afwImage.Exposure):
exposure = image
elif isinstance(image, afwImage.MaskedImage):
exposure = afwImage.makeExposure(image, wcs)
if psf is not None:
exposure.setPsf(psf)
if photo_calib is not None:
exposure.setPhotoCalib(photo_calib)
elif isinstance(image, list):
# Construct a multiband exposure for scarlet deblending.
exposures = afwImage.MultibandExposure.fromExposures(bands, image)
# Select the exposure of the desired band, which will be used for
# detection and measurement.
exposure = exposures[band]
elif isinstance(image, afwImage.MultibandExposure):
exposures = image
exposure = exposures[band]
else:
raise TypeError(f"Unsupported image type: {type(image)}")

# Create a source table into which detections will be placed.
table = afwTable.SourceTable.make(self.schema, id_generator.make_table_id_factory())

# Detect sources and get a source catalog.
self.log.info(f"Running detection on a {exposure.width}x{exposure.height} pixel image")
detections = self.detection.run(table, exposure)
catalog = detections.sources

# Deblend sources into their components and update the catalog.
if self.config.deblender is None:
self.log.info("Deblending is disabled; skipping deblending")
else:
self.log.info(
f"Running deblending via '{self.config.deblender}' on {len(catalog)} detection footprints"
)
if self.config.deblender == "meas_deblender":
self.deblend.run(exposure=exposure, sources=catalog)
elif self.config.deblender == "scarlet":
if not isinstance(image, (afwImage.MultibandExposure, list)):
# We need to have a multiband exposure to satisfy scarlet
# function's signature, even when using a single band.
exposures = afwImage.MultibandExposure.fromExposures([band], [exposure])
catalog, model_data = self.deblend.run(mExposure=exposures, mergedSources=catalog)
# The footprints need to be updated for the subsequent
# measurement.
scarlet.io.updateCatalogFootprints(
modelData=model_data,
catalog=catalog,
band=band,
imageForRedistribution=exposure,
removeScarletData=True,
updateFluxColumns=True,
)

# The deblender may not produce a contiguous catalog; ensure contiguity
# for the subsequent task.
if not catalog.isContiguous():
self.log.info("Catalog is not contiguous; making it contiguous")
catalog = catalog.copy(deep=True)

# Measure requested quantities on sources.
self.measurement.run(catalog, exposure)
self.log.info(
f"Measured {len(catalog)} sources and stored them in the output "
f"catalog containing {catalog.schema.getFieldCount()} fields"
)

return catalog
Loading