Skip to content

Commit

Permalink
Refactor struct passing
Browse files Browse the repository at this point in the history
  • Loading branch information
Gerenjie committed Dec 17, 2024
1 parent c88f82c commit 4ff53f7
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 37 deletions.
24 changes: 18 additions & 6 deletions python/lsst/ap/association/diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ class DiaPipelineConnections(
storageClass="DataFrame",
dimensions=("instrument", "visit", "detector"),
)
unassociatedSsObjects = connTypes.Output(
doc="Expected locations of an ssObject with no source",
name="ssUnassociatedObjects",
storageClass="ArrowAstropy",
dimensions=("instrument", "visit", "detector"),
)

diaForcedSources = connTypes.Output(
doc="Optional output storing the forced sources computed at the diaObject positions.",
name="{fakesType}{coaddName}Diff_diaForcedSrc",
Expand All @@ -156,6 +163,7 @@ def __init__(self, *, config=None):
self.inputs.remove("solarSystemObjectTable")
if (not config.doWriteAssociatedSources) or (not config.doSolarSystemAssociation):
self.outputs.remove("associatedSsSources")
self.outputs.remove("unassociatedSsObjects")

def adjustQuantum(self, inputs, outputs, label, dataId):
"""Override to make adjustments to `lsst.daf.butler.DatasetRef` objects
Expand Down Expand Up @@ -473,9 +481,10 @@ def run(self,
diaObjects = preloadedDiaObjects

# Associate DiaSources with DiaObjects
associatedDiaSources, newDiaObjects, associatedSsSources = self.associateDiaSources(
diaSourceTable, solarSystemObjectTable, diffIm, diaObjects
)

(
associatedDiaSources, newDiaObjects, associatedSsSources, unassociatedSsObjects
) = self.associateDiaSources(diaSourceTable, solarSystemObjectTable, diffIm, diaObjects)

# Merge associated diaSources
mergedDiaSourceHistory, mergedDiaObjects, updatedDiaObjectIds = self.mergeAssociatedCatalogs(
Expand Down Expand Up @@ -556,7 +565,8 @@ def run(self,
associatedDiaSources=associatedDiaSources,
diaForcedSources=diaForcedSources,
diaObjects=diaCalResult.diaObjectCat,
associatedSsSources=associatedSsSources
associatedSsSources=associatedSsSources,
unassociatedSsObjects=unassociatedSsObjects
)

def createNewDiaObjects(self, unAssocDiaSources):
Expand Down Expand Up @@ -637,13 +647,15 @@ def associateDiaSources(self, diaSourceTable, solarSystemObjectTable, diffIm, di
toAssociate.append(ssoAssocResult.ssoAssocDiaSources)
nTotalSsObjects = ssoAssocResult.nTotalSsObjects
nAssociatedSsObjects = ssoAssocResult.nAssociatedSsObjects
associatedSsSources = ssoAssocResult.ssSourceData
associatedSsSources = ssoAssocResult.associatedSsSources
unassociatedSsObjects = ssoAssocResult.unassociatedSsObjects
else:
# Create new DiaObjects from unassociated diaSources.
createResults = self.createNewDiaObjects(assocResults.unAssocDiaSources)
nTotalSsObjects = 0
nAssociatedSsObjects = 0
associatedSsSources = None
unassociatedSsObjects = None
if len(assocResults.matchedDiaSources) > 0:
toAssociate.append(assocResults.matchedDiaSources)
toAssociate.append(createResults.diaSources)
Expand All @@ -659,7 +671,7 @@ def associateDiaSources(self, diaSourceTable, solarSystemObjectTable, diffIm, di
assocResults.nUnassociatedDiaObjects,
createResults.nNewDiaObjects,
)
return (associatedDiaSources, createResults.newDiaObjects, associatedSsSources)
return (associatedDiaSources, createResults.newDiaObjects, associatedSsSources, unassociatedSsObjects)

@timeMethod
def mergeAssociatedCatalogs(self, preloadedDiaSources, associatedDiaSources, diaObjects, newDiaObjects,
Expand Down
30 changes: 4 additions & 26 deletions python/lsst/ap/association/ssSingleFrameAssociation.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class SsSingleFrameAssociationConnections(
storageClass="ArrowAstropy",
dimensions=("instrument", "visit", "detector"),
)
unassociatedObjects = connTypes.Output(
unassociatedSsObjects = connTypes.Output(
doc="Expected locations of an ssObject with no source",
name="ssSingleFrameUnassociatedObjects",
storageClass="ArrowAstropy",
Expand Down Expand Up @@ -150,32 +150,10 @@ def run(self,
"""
if solarSystemObjectTable is None:
raise pipeBase.NoWorkFound("No ephemerides to associate. Skipping ssSingleFrameAssociation.")
else:
# Associate DiaSources with DiaObjects
associatedSsSources, unassociatedObjects = self.associateSources(sourceTable,
solarSystemObjectTable, exposure)
return pipeBase.Struct(associatedSsSources=associatedSsSources,
unassociatedObjects=unassociatedObjects)

@timeMethod
def associateSources(self, sourceTable, solarSystemObjectTable, exposure):
"""Associate single-image sources with ssObjects.
Parameters
----------
sourceTable : `pandas.DataFrame`
Newly detected sources.
solarSystemObjectTable : `pandas.DataFrame`
Preloaded Solar System objects expected to be visible in the image.
Returns
-------
associatedSsSources : `astropy.table.Table`
Table of new ssSources after association.
"""
# Associate DiaSources with DiaObjects
sourceTable = sourceTable.asAstropy()
sourceTable['ra'] = sourceTable['coord_ra'].to(deg).value
sourceTable['dec'] = sourceTable['coord_dec'].to(deg).value
ssoAssocResult = self.solarSystemAssociator.run(sourceTable.to_pandas(),
solarSystemObjectTable, exposure)
return ssoAssocResult.ssSourceData, ssoAssocResult.unAssocSsObjects
return self.solarSystemAssociator.run(sourceTable.to_pandas(),
solarSystemObjectTable, exposure)
10 changes: 6 additions & 4 deletions python/lsst/ap/association/ssoAssociation.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ def run(self, diaSourceCatalog, solarSystemObjects, exposure):
associated.append(False)

self.log.info("Successfully associated %d / %d SolarSystemObjects.", nFound, nSolarSystemObjects)
self.metadata['nAssociatedSsObjects'] = nFound
self.metadata['nExpectedSsObjects'] = nSolarSystemObjects
maskedObjects['associated'] = associated
assocMask = diaSourceCatalog["ssObjectId"] != 0
ssSourceData = pd.DataFrame(ssSourceData, columns=["ssObjectId", "obs_position_x", "obs_position_y",
Expand All @@ -183,8 +185,8 @@ def run(self, diaSourceCatalog, solarSystemObjects, exposure):
unAssocDiaSources=diaSourceCatalog[~assocMask].reset_index(drop=True),
nTotalSsObjects=nSolarSystemObjects,
nAssociatedSsObjects=nFound,
ssSourceData=Table.from_pandas(ssSourceData),
unAssocSsObjects=maskedObjects[~maskedObjects['associated']])
associatedSsSources=Table.from_pandas(ssSourceData),
unassociatedSsObjects=maskedObjects[~maskedObjects['associated']])

def _maskToCcdRegion(self, solarSystemObjects, exposure, marginArcsec):
"""Mask the input SolarSystemObjects to only those in the exposure
Expand Down Expand Up @@ -254,7 +256,7 @@ def _return_empty(self, diaSourceCatalog, emptySolarSystemObjects):
unAssocDiaSources=diaSourceCatalog,
nTotalSsObjects=0,
nAssociatedSsObjects=0,
ssSourceData=Table(names=["ssObjectId", "ra", "dec", "obs_position", "obj_position",
associatedSsSources=Table(names=["ssObjectId", "ra", "dec", "obs_position", "obj_position",
"residual_ras", "residual_decs"]),
unAssocSsObjects=Table(names=emptySolarSystemObjects.columns)
unassociatedSsObjects=Table(names=emptySolarSystemObjects.columns)
)
3 changes: 2 additions & 1 deletion tests/test_diaPipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def solarSystemAssociator_run(unAssocDiaSources, solarSystemObjectTable, diffIm)
nAssociatedSsObjects=30,
ssoAssocDiaSources=_makeMockDataFrame(),
unAssocDiaSources=_makeMockDataFrame(),
ssSourceData=_makeMockDataFrame())
associatedSsSources=_makeMockDataFrame(),
unassociatedSsObjects=_makeMockDataFrame())

def associator_run(table, diaObjects):
return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
Expand Down

0 comments on commit 4ff53f7

Please sign in to comment.