Skip to content

Commit

Permalink
Added startWithIntrinsic and returnWfDev flags.
Browse files Browse the repository at this point in the history
  • Loading branch information
jfcrenshaw committed Jan 22, 2024
1 parent 0fb47fc commit 4828543
Show file tree
Hide file tree
Showing 7 changed files with 551 additions and 149 deletions.
1 change: 0 additions & 1 deletion policy/estimation/tie.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@ centerTol: 10.0e-9
centerBinary: True
convergeTol: 10.0e-9
maskKwargs: null
saveHistory: False
4 changes: 4 additions & 0 deletions policy/estimation/wfEstimator.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ algoName: tie
algoConfig: null
instConfig: policy/instruments/LsstCam.yaml
jmax: 22
startWithIntrinsic: true
returnWfDev: false
return4Up: true
units: um
saveHistory: false
113 changes: 79 additions & 34 deletions python/lsst/ts/wep/estimation/tie.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,6 @@ class TieAlgorithm(WfAlgorithm):
Dictionary of mask keyword arguments to pass to mask creation.
To see possibilities, see the docstring for
lsst.ts.wep.imageMapper.ImageMapper.createPupilMask().
saveHistory : bool, optional
Whether to save the algorithm history in the self.history attribute.
If True, then self.history contains information about the most recent
time the algorithm was run.
"""

def __init__(
Expand All @@ -97,7 +93,6 @@ def __init__(
centerBinary: Optional[bool] = None,
convergeTol: Optional[float] = None,
maskKwargs: Optional[dict] = None,
saveHistory: Optional[bool] = None,
) -> None:
super().__init__(
configFile=configFile,
Expand All @@ -110,7 +105,6 @@ def __init__(
centerBinary=centerBinary,
convergeTol=convergeTol,
maskKwargs=maskKwargs,
saveHistory=saveHistory,
)

@property
Expand Down Expand Up @@ -374,23 +368,38 @@ def history(self) -> dict:
"""The algorithm history.
The history is a dictionary saving the intermediate products from
each iteration of the TIE solver. The first iteration is saved as
history[0].
The entry for each iteration is itself a dictionary containing
each iteration of the TIE solver.
The initial products before the iteration begins are stored
in history[0], which is a dictionary with the keys:
- "intraInit" - the initial intrafocal image
- "extraInit" - the initial extrafocal image
- "zkStartIntra" - the starting intrafocal Zernikes
- "zkStartExtra" - the starting extrafocal Zernikes
- "zkStartMean" - the mean of the starting Zernikes. Note these
Zernikes are added to zkBest to estimate the
full OPD.
Each iteration of the solver is then stored under indices >= 1.
The entry for each iteration is also a dictionary, containing
the following keys:
- "intraComp" - the compensated intrafocal image
- "extraComp" - the compensated extrafocal image
- "I0" - the estimate of the beam intensity on the pupil
- "dIdz" - estimate of z-derivative of intensity across the pupil
- "zkComp" - the Zernikes used for image compensation
- "zkCompIntra" - Zernikes for compensating the intrafocal image
- "zkCompExtra" - Zernikes for compensating the extrafocal image
- "zkResid" - the estimated residual Zernikes
- "zkBest" - the best estimate of the Zernikes after this iteration
- "zkBest" - the best cumulative estimate the wavefront residual.
- "zkSum" - the sum of zkBest and zkStartMean from history[0].
This is the best estimate of the OPD at the end of
this iteration.
- "converged" - flag indicating if Zernike estimation has converged
- "caustic" - flag indicating if a caustic has been hit
Note the units for all Zernikes are in meters, and the z-derivative
in dIdz is also in meters.
in dIdz is also in meters. Furthermore, all Zernikes start with Noll
index 4.
"""
return super().history

Expand Down Expand Up @@ -472,12 +481,15 @@ def _fftSolve(
# TODO: Implement the fft solver
raise NotImplementedError("The fft solver is not yet implemented.")

def estimateZk(
def _estimateZk(
self,
I1: Image,
I2: Image, # type: ignore[override]
jmax: int = 22,
instrument: Instrument = Instrument(),
zkStartI1: np.ndarray,
zkStartI2: np.ndarray,
jmax: int,
instrument: Instrument,
saveHistory: bool,
) -> np.ndarray:
"""Return the wavefront Zernike coefficients in meters.
Expand All @@ -487,26 +499,36 @@ def estimateZk(
An Image object containing an intra- or extra-focal donut image.
I2 : Image
A second image, on the opposite side of focus from I1.
jmax : int, optional
zkStartI1 : np.ndarray
The starting Zernikes for I1.
zkStartI2 : np.ndarray or None
The starting Zernikes for I2.
jmax : int
The maximum Zernike Noll index to estimate.
(the default is 22)
instrument : Instrument, optional
instrument : Instrument
The Instrument object associated with the DonutStamps.
(the default is the default Instrument)
saveHistory : bool
Whether to save the algorithm history in the self.history
attribute. If True, then self.history contains information
about the most recent time the algorithm was run.
Returns
-------
np.ndarray
Zernike coefficients (for Noll indices >= 4) estimated from
the images, in meters.
Raises
------
RuntimeError
If the solver is not supported
"""
# Validate the inputs
# Make sure we have been provided with two images
if I1 is None or I2 is None:
raise ValueError(
"TIEAlgorithm requires a pair of intrafocal and extrafocal "
"donuts to estimate Zernikes. Please provide both I1 and I2."
)
self._validateInputs(I1, I2, jmax, instrument)

# Create the ImageMapper for centering and image compensation
imageMapper = ImageMapper(
Expand All @@ -515,19 +537,35 @@ def estimateZk(
opticalModel=self.opticalModel,
)

# Get the initial intrafocal and extrafocal stamps
# Re-assign I1/I2 to intra/extra
intra = I1.copy() if I1.defocalType == DefocalType.Intra else I2.copy()
zkStartIntra = (
zkStartI1.copy()
if I1.defocalType == DefocalType.Intra
else zkStartI2.copy()
)
extra = I1.copy() if I1.defocalType == DefocalType.Extra else I2.copy()
zkStartExtra = (
zkStartI1.copy()
if I1.defocalType == DefocalType.Extra
else zkStartI2.copy()
)

# Calculate the mean starting Zernikes
zkStartMean = np.mean([zkStartIntra, zkStartExtra], axis=0)

# Initialize the variables for intra and extrafocal pupil masks
intraPupilMask = None
extraPupilMask = None

if self.saveHistory:
# Save the initial images in the history
if saveHistory:
# Save the initial images and intrinsic Zernikes in the history
self._history[0] = {
"intraInit": intra.image.copy(),
"extraInit": extra.image.copy(),
"zkStartIntra": zkStartIntra.copy(),
"zkStartExtra": zkStartExtra.copy(),
"zkStartMean": zkStartMean.copy(),
}

# Initialize Zernike arrays at zero
Expand Down Expand Up @@ -562,27 +600,27 @@ def estimateZk(
zkCenter = zkComp.copy()
intraCent = imageMapper.centerOnProjection(
intra,
zkCenter,
zkCenter + zkStartIntra,
binary=self.centerBinary,
**self.maskKwargs,
)
extraCent = imageMapper.centerOnProjection(
extra,
zkCenter,
zkCenter + zkStartExtra,
binary=self.centerBinary,
**self.maskKwargs,
)

# Compensate images using the Zernikes
intraComp = imageMapper.mapImageToPupil(
intraCent,
zkComp,
zkComp + zkStartIntra,
mask=intraPupilMask,
**self.maskKwargs,
)
extraComp = imageMapper.mapImageToPupil(
extraCent,
zkComp,
zkComp + zkStartExtra,
mask=extraPupilMask,
**self.maskKwargs,
)
Expand Down Expand Up @@ -625,6 +663,8 @@ def estimateZk(
zkResid = self._expSolve(I0, dIdz, jmax, instrument)
elif self.solver == "fft":
zkResid = self._fftSolve(I0, dIdz, jmax, instrument)
else:
raise RuntimeError(f"Solver {self.solver} is unsupported")

# Check for convergence
# (1) The max absolute difference with the previous iteration
Expand All @@ -635,12 +675,15 @@ def estimateZk(
np.max(np.abs(newBest - zkBest)) < self.convergeTol
)

# Set the new best estimate
# Set the new best cumulative estimate of the residuals
zkBest = newBest

# Add the starting Zernikes to the best residuals
zkSum = zkBest + zkStartMean

# Time to wrap up this iteration!
# Should we save intermediate products in the algorithm history?
if self.saveHistory:
if saveHistory:
# Save the images and Zernikes from this iteration
self._history[i + 1] = {
"recenter": bool(recenter),
Expand All @@ -651,20 +694,22 @@ def estimateZk(
"mask": mask.copy(), # type: ignore
"I0": I0.copy(),
"dIdz": dIdz.copy(),
"zkComp": zkComp.copy(),
"zkCompIntra": zkComp + zkStartIntra,
"zkCompExtra": zkComp + zkStartExtra,
"zkResid": zkResid.copy(),
"zkBest": zkBest.copy(),
"zkSum": zkSum.copy(),
"converged": bool(converged),
"caustic": bool(caustic),
}

# If we are using the FFT solver, save the inner loop as well
if self.solver == "fft":
# TODO: After implementing fft, add inner loop here
self._history[i]["innerLoop"] = None
self._history[i + 1]["innerLoop"] = None

# If we've hit a caustic or converged, we will stop early
if caustic or converged:
break

return zkBest
return zkSum
Loading

0 comments on commit 4828543

Please sign in to comment.