diff --git a/doc/content.rst b/doc/content.rst
index 14503f43..26678212 100644
--- a/doc/content.rst
+++ b/doc/content.rst
@@ -104,6 +104,7 @@ This module calculates the wavefront error by solving the TIE.
* **CentroidDefault**: Default centroid class.
* **CentroidRandomWalk**: CentroidDefault child class to get the centroid of donut by the random walk model.
* **CentroidOtsu**: CentroidDefault child class to get the centroid of donut by the Otsu's method.
+* **CentroidConvolveTemplate**: CentroidDefault child class to get the centroids of one or more donuts in an image by convolution with a template donut.
* **BaseCwfsTestCase**: Base class for CWFS tests.
.. _lsst.ts.wep-modules_wep_deblend:
diff --git a/doc/uml/cwfsClass.uml b/doc/uml/cwfsClass.uml
index 20b4c2b8..7d7ddb84 100644
--- a/doc/uml/cwfsClass.uml
+++ b/doc/uml/cwfsClass.uml
@@ -5,8 +5,11 @@ Algorithm -- CompensableImage
CompensableImage ..> Instrument
CentroidDefault <|-- CentroidRandomWalk
CentroidDefault <|-- CentroidOtsu
+CentroidDefault <|-- CentroidConvolveTemplate
CentroidFindFactory ..> CentroidRandomWalk
CentroidFindFactory ..> CentroidOtsu
+CentroidFindFactory ..> CentroidConvolveTemplate
+CentroidConvolveTemplate *-- CentroidRandomWalk
Image ..> CentroidFindFactory
Image *-- CentroidDefault
BaseCwfsTestCase ..> CompensableImage
diff --git a/doc/versionHistory.rst b/doc/versionHistory.rst
index 99cbfa7d..3c39b820 100644
--- a/doc/versionHistory.rst
+++ b/doc/versionHistory.rst
@@ -6,6 +6,14 @@
Version History
##################
+.. _lsst.ts.wep-1.5.0:
+
+-------------
+1.5.0
+-------------
+
+* Add ``CentroidConvolveTemplate`` as a new centroid finding method.
+
.. _lsst.ts.wep-1.4.9:
-------------
diff --git a/policy/default.yaml b/policy/default.yaml
index e956a954..02c2513d 100644
--- a/policy/default.yaml
+++ b/policy/default.yaml
@@ -47,7 +47,7 @@ defocalDistInMm: 1.5
# Donut image size in pixel (default value at 1.5 mm)
donutImgSizeInPixel: 160
-# Centroid find algorithm. It can be "randomWalk" or "otsu"
+# Centroid find algorithm. It can be "randomWalk", "otsu", or "convolveTemplate"
centroidFindAlgo: randomWalk
# Camera mapper for the data butler to use
diff --git a/python/lsst/ts/wep/Utility.py b/python/lsst/ts/wep/Utility.py
index 9ce6b57c..143d0e17 100644
--- a/python/lsst/ts/wep/Utility.py
+++ b/python/lsst/ts/wep/Utility.py
@@ -62,6 +62,7 @@ class ImageType(IntEnum):
class CentroidFindType(IntEnum):
RandomWalk = 1
Otsu = auto()
+ ConvolveTemplate = auto()
class DeblendDonutType(IntEnum):
@@ -359,7 +360,7 @@ def getCentroidFindType(centroidFindType):
Parameters
----------
centroidFindType : str
- Centroid find algorithm to use (randomWalk or otsu).
+ Centroid find algorithm to use (randomWalk, otsu, or convolveTemplate).
Returns
-------
@@ -376,6 +377,8 @@ def getCentroidFindType(centroidFindType):
return CentroidFindType.RandomWalk
elif centroidFindType == "otsu":
return CentroidFindType.Otsu
+ elif centroidFindType == "convolveTemplate":
+ return CentroidFindType.ConvolveTemplate
else:
raise ValueError("The %s is not supported." % centroidFindType)
diff --git a/python/lsst/ts/wep/cwfs/CentroidConvolveTemplate.py b/python/lsst/ts/wep/cwfs/CentroidConvolveTemplate.py
new file mode 100644
index 00000000..5df4e1ea
--- /dev/null
+++ b/python/lsst/ts/wep/cwfs/CentroidConvolveTemplate.py
@@ -0,0 +1,206 @@
+# This file is part of ts_wep.
+#
+# Developed for the LSST Telescope and Site Systems.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import numpy as np
+from copy import copy
+from lsst.ts.wep.cwfs.CentroidDefault import CentroidDefault
+from lsst.ts.wep.cwfs.CentroidRandomWalk import CentroidRandomWalk
+from scipy.signal import correlate
+from sklearn.cluster import KMeans
+
+
+class CentroidConvolveTemplate(CentroidDefault):
+ def __init__(self):
+ """CentroidDefault child class to get the centroid of donut by
+ convolution with a template donut image."""
+
+ super(CentroidConvolveTemplate, self).__init__()
+ self._centRandomWalk = CentroidRandomWalk()
+
+ def getImgBinary(self, imgDonut):
+ """Get the binary image.
+
+ Parameters
+ ----------
+ imgDonut : numpy.ndarray
+ Donut image to do the analysis.
+
+ Returns
+ -------
+ numpy.ndarray [int]
+ Binary image of donut.
+ """
+
+ return self._centRandomWalk.getImgBinary(imgDonut)
+
+ def getCenterAndR(self, imgDonut, templateDonut=None, peakThreshold=0.95):
+ """Get the centroid data and effective weighting radius.
+
+ Parameters
+ ----------
+ imgDonut : numpy.ndarray
+ Donut image.
+ templateDonut : None or numpy.ndarray, optional
+ Template image for a single donut. If set to None
+ then the image will be convolved with itself. (The Default is None)
+ peakThreshold : float, optional
+ This value is a specifies a number between 0 and 1 that is
+ the fraction of the highest pixel value in the convolved image.
+ The code then sets all pixels with a value below this to 0 before
+ running the K-means algorithm to find peaks that represent possible
+ donut locations. (The default is 0.95)
+
+ Returns
+ -------
+ float
+ Centroid x.
+ float
+ Centroid y.
+ float
+ Effective weighting radius.
+ """
+
+ imgBinary = self.getImgBinary(imgDonut)
+
+ if templateDonut is None:
+ templateBinary = copy(imgBinary)
+ else:
+ templateBinary = self.getImgBinary(templateDonut)
+
+ return self.getCenterAndRfromImgBinary(
+ imgBinary, templateBinary=templateBinary, peakThreshold=peakThreshold,
+ )
+
+ def getCenterAndRfromImgBinary(
+ self, imgBinary, templateBinary=None, peakThreshold=0.95
+ ):
+ """Get the centroid data and effective weighting radius.
+
+ Parameters
+ ----------
+ imgBinary : numpy.ndarray
+ Binary image of donut.
+ templateBinary : None or numpy.ndarray, optional
+ Binary image of template for a single donut. If set to None
+ then the image will be convolved with itself. (The Default is None)
+ peakThreshold : float, optional
+ This value is a specifies a number between 0 and 1 that is
+ the fraction of the highest pixel value in the convolved image.
+ The code then sets all pixels with a value below this to 0 before
+ running the K-means algorithm to find peaks that represent possible
+ donut locations. (The default is 0.95)
+
+ Returns
+ -------
+ float
+ Centroid x.
+ float
+ Centroid y.
+ float
+ Effective weighting radius.
+ """
+
+ x, y, radius = self.getCenterAndRfromTemplateConv(
+ imgBinary,
+ templateImgBinary=templateBinary,
+ nDonuts=1,
+ peakThreshold=peakThreshold,
+ )
+
+ return x[0], y[0], radius
+
+ def getCenterAndRfromTemplateConv(
+ self, imageBinary, templateImgBinary=None, nDonuts=1, peakThreshold=0.95
+ ):
+ """
+ Get the centers of the donuts by convolving a binary template image
+ with the binary image of the donut or donuts.
+
+ Peaks will appear as bright spots in the convolved image. Since we
+ use binary images the brightness of the stars does not matter and
+ the peaks of any stars in the image should have about the same
+ brightness if the template is correct.
+
+ Parameters
+ ----------
+ imageBinary: numpy.ndarray
+ Binary image of postage stamp.
+ templateImgBinary: None or numpy.ndarray, optional
+ Binary image of template donut. If set to None then the image
+ is convolved with itself. (The default is None)
+ nDonuts: int, optional
+ Number of donuts there should be in the binary image. Needs to
+ be >= 1. (The default is 1)
+ peakThreshold: float, optional
+ This value is a specifies a number between 0 and 1 that is
+ the fraction of the highest pixel value in the convolved image.
+ The code then sets all pixels with a value below this to 0 before
+ running the K-means algorithm to find peaks that represent possible
+ donut locations. (The default is 0.95)
+
+ Returns
+ -------
+ list
+ X pixel coordinates for donut centroid.
+ list
+ Y pixel coordinates for donut centroid.
+ float
+ Effective weighting radius calculated using the template image.
+ """
+
+ if templateImgBinary is None:
+ templateImgBinary = copy(imageBinary)
+
+ nDonutsAssertStr = "nDonuts must be an integer >= 1"
+ assert (nDonuts >= 1) & (type(nDonuts) is int), nDonutsAssertStr
+
+ # We set the mode to be "same" because we need to return the same
+ # size image to the code.
+ tempConvolve = correlate(imageBinary, templateImgBinary, mode="same")
+
+ # Then we rank the pixel values keeping only those above
+ # some fraction of the highest value.
+ rankedConvolve = np.argsort(tempConvolve.flatten())[::-1]
+ cutoff = len(
+ np.where(tempConvolve.flatten() > peakThreshold * np.max(tempConvolve))[0]
+ )
+ rankedConvolveCutoff = rankedConvolve[:cutoff]
+ nx, ny = np.unravel_index(rankedConvolveCutoff, np.shape(imageBinary))
+
+ # Then to find peaks in the image we use K-Means with the
+ # specified number of donuts
+ kmeans = KMeans(n_clusters=nDonuts)
+ labels = kmeans.fit_predict(np.array([nx, ny]).T)
+
+ # Then in each cluster we take the brightest pixel as the centroid
+ centX = []
+ centY = []
+ for labelNum in range(nDonuts):
+ nxLabel, nyLabel = np.unravel_index(
+ rankedConvolveCutoff[labels == labelNum][0], np.shape(imageBinary)
+ )
+ centX.append(nxLabel)
+ centY.append(nyLabel)
+
+ # Get the radius of the donut from the template image
+ radius = np.sqrt(np.sum(templateImgBinary) / np.pi)
+
+ return centX, centY, radius
diff --git a/python/lsst/ts/wep/cwfs/CentroidDefault.py b/python/lsst/ts/wep/cwfs/CentroidDefault.py
index 1757476f..9493a4ee 100644
--- a/python/lsst/ts/wep/cwfs/CentroidDefault.py
+++ b/python/lsst/ts/wep/cwfs/CentroidDefault.py
@@ -26,13 +26,15 @@
class CentroidDefault(object):
"""Default Centroid class."""
- def getCenterAndR(self, imgDonut):
+ def getCenterAndR(self, imgDonut, **kwargs):
"""Get the centroid data and effective weighting radius.
Parameters
----------
imgDonut : numpy.ndarray
Donut image.
+ **kwargs : dict[str, any]
+ Dictionary of input argument: new value for that input argument.
Returns
-------
@@ -48,7 +50,7 @@ def getCenterAndR(self, imgDonut):
return self.getCenterAndRfromImgBinary(imgBinary)
- def getCenterAndRfromImgBinary(self, imgBinary):
+ def getCenterAndRfromImgBinary(self, imgBinary, **kwargs):
"""Get the centroid data and effective weighting radius from the binary
image.
@@ -56,6 +58,8 @@ def getCenterAndRfromImgBinary(self, imgBinary):
----------
imgBinary : numpy.ndarray [int]
Binary image of donut.
+ **kwargs : dict[str, any]
+ Dictionary of input argument: new value for that input argument.
Returns
-------
diff --git a/python/lsst/ts/wep/cwfs/CentroidFindFactory.py b/python/lsst/ts/wep/cwfs/CentroidFindFactory.py
index af83fc23..520c5c98 100644
--- a/python/lsst/ts/wep/cwfs/CentroidFindFactory.py
+++ b/python/lsst/ts/wep/cwfs/CentroidFindFactory.py
@@ -22,6 +22,7 @@
from lsst.ts.wep.Utility import CentroidFindType
from lsst.ts.wep.cwfs.CentroidRandomWalk import CentroidRandomWalk
from lsst.ts.wep.cwfs.CentroidOtsu import CentroidOtsu
+from lsst.ts.wep.cwfs.CentroidConvolveTemplate import CentroidConvolveTemplate
class CentroidFindFactory(object):
@@ -39,7 +40,7 @@ def createCentroidFind(centroidFindType):
Returns
-------
- CentroidRandomWalk, CentroidOtsu
+ Child class of centroidDefault
Centroid find object.
Raises
@@ -52,5 +53,7 @@ def createCentroidFind(centroidFindType):
return CentroidRandomWalk()
elif centroidFindType == CentroidFindType.Otsu:
return CentroidOtsu()
+ elif centroidFindType == CentroidFindType.ConvolveTemplate:
+ return CentroidConvolveTemplate()
else:
raise ValueError("The %s is not supported." % centroidFindType)
diff --git a/tests/cwfs/test_centroidConvolveTemplate.py b/tests/cwfs/test_centroidConvolveTemplate.py
new file mode 100644
index 00000000..46996268
--- /dev/null
+++ b/tests/cwfs/test_centroidConvolveTemplate.py
@@ -0,0 +1,192 @@
+# This file is part of ts_wep.
+#
+# Developed for the LSST Telescope and Site Systems.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import unittest
+import numpy as np
+
+from lsst.ts.wep.cwfs.CentroidConvolveTemplate import CentroidConvolveTemplate
+
+
+class TestCentroidConvolveTemplate(unittest.TestCase):
+ """Test the CentroidConvolveTemplate class."""
+
+ def setUp(self):
+
+ self.centroidConv = CentroidConvolveTemplate()
+
+ def _createData(self, radiusInner, radiusOuter, imageSize, addNoise=False):
+
+ # Create two images. One with a single donut and one with two donuts.
+ singleDonut = np.zeros((imageSize, imageSize))
+ doubleDonut = np.zeros((imageSize, imageSize))
+
+ for x in range(imageSize):
+ for y in range(imageSize):
+ # For single donut put the donut at the center of the image
+ if (
+ np.sqrt((imageSize / 2 - x) ** 2 + (imageSize / 2 - y) ** 2)
+ <= radiusOuter
+ ):
+ singleDonut[x, y] += 1
+ if (
+ np.sqrt((imageSize / 2 - x) ** 2 + (imageSize / 2 - y) ** 2)
+ <= radiusInner
+ ):
+ singleDonut[x, y] -= 1
+ # For double donut put the two donuts along same line
+ # halfway down the image and provide 10 pixels between
+ # image edge and outer edge of donut on either side of image
+ if (
+ np.sqrt(((radiusOuter + 10) - x) ** 2 + (imageSize / 2 - y) ** 2)
+ <= radiusOuter
+ ):
+ doubleDonut[x, y] += 1
+ if (
+ np.sqrt(((radiusOuter + 10) - x) ** 2 + (imageSize / 2 - y) ** 2)
+ <= radiusInner
+ ):
+ doubleDonut[x, y] -= 1
+ if (
+ np.sqrt(
+ (imageSize - (radiusOuter + 10) - x) ** 2
+ + (imageSize / 2 - y) ** 2
+ )
+ <= radiusOuter
+ ):
+ doubleDonut[x, y] += 1
+ if (
+ np.sqrt(
+ (imageSize - (radiusOuter + 10) - x) ** 2
+ + (imageSize / 2 - y) ** 2
+ )
+ <= radiusInner
+ ):
+ doubleDonut[x, y] -= 1
+ # Make binary image
+ doubleDonut[doubleDonut > 0.5] = 1
+
+ if addNoise is True:
+ # Add noise so the images are not binary
+ randState = np.random.RandomState(42)
+ singleDonut += randState.normal(scale=0.01, size=np.shape(singleDonut))
+ doubleDonut += randState.normal(scale=0.01, size=np.shape(doubleDonut))
+
+ eff_radius = np.sqrt(radiusOuter ** 2 - radiusInner ** 2)
+
+ return singleDonut, doubleDonut, eff_radius
+
+ def testGetImgBinary(self):
+
+ singleDonut, doubleDonut, eff_radius = self._createData(
+ 20, 40, 160, addNoise=False
+ )
+
+ noisySingle, noisyDouble, eff_radius = self._createData(
+ 20, 40, 160, addNoise=True
+ )
+
+ binarySingle = self.centroidConv.getImgBinary(noisySingle)
+
+ np.testing.assert_array_equal(singleDonut, binarySingle)
+
+ def testGetCenterAndRWithoutTemplate(self):
+
+ singleDonut, doubleDonut, eff_radius = self._createData(
+ 20, 40, 160, addNoise=True
+ )
+
+ # Test recovery with defaults
+ centX, centY, rad = self.centroidConv.getCenterAndR(singleDonut)
+
+ self.assertEqual(centX, 80.0)
+ self.assertEqual(centY, 80.0)
+ self.assertAlmostEqual(rad, eff_radius, delta=0.1)
+
+ def testGetCenterAndRWithTemplate(self):
+
+ singleDonut, doubleDonut, eff_radius = self._createData(
+ 20, 40, 160, addNoise=True
+ )
+
+ # Test recovery with defaults
+ centX, centY, rad = self.centroidConv.getCenterAndR(
+ singleDonut, templateDonut=singleDonut
+ )
+
+ self.assertEqual(centX, 80.0)
+ self.assertEqual(centY, 80.0)
+ self.assertAlmostEqual(rad, eff_radius, delta=0.1)
+
+ def testGetCenterAndRFromImgBinary(self):
+
+ singleDonut, doubleDonut, eff_radius = self._createData(20, 40, 160)
+
+ # Test recovery with defaults
+ centX, centY, rad = self.centroidConv.getCenterAndRfromImgBinary(singleDonut)
+
+ self.assertEqual(centX, 80.0)
+ self.assertEqual(centY, 80.0)
+ self.assertAlmostEqual(rad, eff_radius, delta=0.1)
+
+ def testNDonutsAssertion(self):
+
+ singleDonut, doubleDonut, eff_radius = self._createData(20, 40, 160)
+
+ nDonutsAssertMsg = "nDonuts must be an integer >= 1"
+ with self.assertRaises(AssertionError, msg=nDonutsAssertMsg):
+ cX, cY, rad = self.centroidConv.getCenterAndRfromTemplateConv(
+ singleDonut, nDonuts=0
+ )
+
+ with self.assertRaises(AssertionError, msg=nDonutsAssertMsg):
+ cX, cY, rad = self.centroidConv.getCenterAndRfromTemplateConv(
+ singleDonut, nDonuts=-1
+ )
+
+ with self.assertRaises(AssertionError, msg=nDonutsAssertMsg):
+ cX, cY, rad = self.centroidConv.getCenterAndRfromTemplateConv(
+ singleDonut, nDonuts=1.5
+ )
+
+ def testGetCenterAndRFromTemplateConv(self):
+
+ singleDonut, doubleDonut, eff_radius = self._createData(20, 40, 160)
+
+ # Test recovery of single donut
+ singleCX, singleCY, rad = self.centroidConv.getCenterAndRfromTemplateConv(
+ singleDonut
+ )
+ self.assertEqual(singleCX, [80.0])
+ self.assertEqual(singleCY, [80.0])
+ self.assertAlmostEqual(rad, eff_radius, delta=0.1)
+
+ # Test recovery of two donuts at once
+ doubleCX, doubleCY, rad = self.centroidConv.getCenterAndRfromTemplateConv(
+ doubleDonut, templateImgBinary=singleDonut, nDonuts=2
+ )
+ self.assertCountEqual(doubleCX, [50.0, 110.0])
+ self.assertEqual(doubleCY, [80.0, 80.0])
+ self.assertAlmostEqual(rad, eff_radius, delta=0.1)
+
+
+if __name__ == "__main__":
+
+ unittest.main()
diff --git a/tests/cwfs/test_centroidFindFactory.py b/tests/cwfs/test_centroidFindFactory.py
index a9ccfb6b..199156f4 100644
--- a/tests/cwfs/test_centroidFindFactory.py
+++ b/tests/cwfs/test_centroidFindFactory.py
@@ -25,6 +25,7 @@
from lsst.ts.wep.cwfs.CentroidFindFactory import CentroidFindFactory
from lsst.ts.wep.cwfs.CentroidRandomWalk import CentroidRandomWalk
from lsst.ts.wep.cwfs.CentroidOtsu import CentroidOtsu
+from lsst.ts.wep.cwfs.CentroidConvolveTemplate import CentroidConvolveTemplate
class TestCentroidFindFactory(unittest.TestCase):
@@ -42,6 +43,13 @@ def testCreateCentroidFindOtsu(self):
centroidFind = CentroidFindFactory.createCentroidFind(CentroidFindType.Otsu)
self.assertTrue(isinstance(centroidFind, CentroidOtsu))
+ def testCreateCentroidFindConvolveTemplate(self):
+
+ centroidFind = CentroidFindFactory.createCentroidFind(
+ CentroidFindType.ConvolveTemplate
+ )
+ self.assertTrue(isinstance(centroidFind, CentroidConvolveTemplate))
+
def testCreateCentroidFindWrongType(self):
self.assertRaises(