Skip to content

Commit

Permalink
WIP: add PipelineGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
TallJimbo committed Feb 28, 2023
1 parent b604813 commit b77c85d
Show file tree
Hide file tree
Showing 12 changed files with 1,467 additions and 58 deletions.
6 changes: 5 additions & 1 deletion python/lsst/pipe/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import connectionTypes, pipelineIR
from . import connectionTypes, pipeline_graph, pipelineIR
from ._dataset_handle import *
from ._instrument import *
from ._status import *
Expand All @@ -10,6 +10,10 @@
from .graph import *
from .graphBuilder import *
from .pipeline import *

# We import the main PipelineGraph types and the module (above), but we don't
# lift all symbols to package scope.
from .pipeline_graph import MutablePipelineGraph, ResolvedPipelineGraph
from .pipelineTask import *
from .struct import *
from .task import *
Expand Down
15 changes: 1 addition & 14 deletions python/lsst/pipe/base/pipeTools.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
# Imports for other modules --
# -----------------------------
from .connections import iterConnections
from .pipeline_graph import DuplicateOutputError, PipelineDataCycleError

if TYPE_CHECKING:
from .pipeline import Pipeline, TaskDef
Expand All @@ -57,20 +58,6 @@ class MissingTaskFactoryError(Exception):
pass


class DuplicateOutputError(Exception):
    """Raised when more than one task in a Pipeline produces the same
    output dataset type.
    """


class PipelineDataCycleError(Exception):
    """Raised when the data dependencies between tasks in a Pipeline
    form a cycle.
    """


def isPipelineOrdered(
pipeline: Union[Pipeline, Iterable[TaskDef]], taskFactory: Optional[TaskFactory] = None
) -> bool:
Expand Down
96 changes: 55 additions & 41 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,12 @@
from lsst.utils import doImportType
from lsst.utils.introspection import get_full_type_name

from . import pipelineIR, pipeTools
from ._task_metadata import TaskMetadata
from . import pipeline_graph, pipelineIR
from .config import PipelineTaskConfig
from .configOverrides import ConfigOverrides
from .connections import iterConnections
from .connections import PipelineTaskConnections, iterConnections
from .connectionTypes import Input
from .pipelineTask import PipelineTask
from .task import _TASK_METADATA_TYPE

if TYPE_CHECKING: # Imports needed only for type annotations; may be circular.
from lsst.obs.base import Instrument
Expand Down Expand Up @@ -134,6 +132,7 @@ class TaskDef:
Task label, usually a short string unique in a pipeline. If not
provided, ``taskClass`` must be, and ``taskClass._DefaultName`` will
be used.
TODO
"""

def __init__(
Expand All @@ -142,6 +141,7 @@ def __init__(
config: Optional[PipelineTaskConfig] = None,
taskClass: Optional[Type[PipelineTask]] = None,
label: Optional[str] = None,
connections: PipelineTaskConnections | None = None,
):
if taskName is None:
if taskClass is None:
Expand All @@ -158,29 +158,33 @@ def __init__(
raise ValueError("`taskClass` must be provided if `label` is not.")
label = taskClass._DefaultName
self.taskName = taskName
try:
config.validate()
except Exception:
_LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
raise
config.freeze()
if connections is None:
# If we don't have connections yet, assume the config hasn't been
# validated yet.
try:
config.validate()
except Exception:
_LOG.error("Configuration validation failed for task %s (%s)", label, taskName)
raise
config.freeze()
connections = config.connections.ConnectionsClass(config=config)
self.config = config
self.taskClass = taskClass
self.label = label
self.connections = config.connections.ConnectionsClass(config=config)
self.connections = connections

@property
def configDatasetName(self) -> str:
"""Name of a dataset type for configuration of this task (`str`)"""
return self.label + "_config"
return pipeline_graph.WriteEdge.CONFIG_INIT_OUTPUT_TEMPLATE.format(label=self.label)

@property
def metadataDatasetName(self) -> Optional[str]:
"""Name of a dataset type for metadata of this task, `None` if
metadata is not to be saved (`str`)
"""
if self.config.saveMetadata:
return self.makeMetadataDatasetName(self.label)
return self.makeMetadataDatasetName(label=self.label)
else:
return None

Expand All @@ -198,15 +202,15 @@ def makeMetadataDatasetName(cls, label: str) -> str:
name : `str`
Name of the task's metadata dataset type.
"""
return f"{label}_metadata"
return pipeline_graph.WriteEdge.METADATA_OUTPUT_TEMPLATE.format(label=label)

@property
def logOutputDatasetName(self) -> Optional[str]:
"""Name of a dataset type for log output from this task, `None` if
logs are not to be saved (`str`)
"""
if cast(PipelineTaskConfig, self.config).saveLogOutput:
return self.label + "_log"
return pipeline_graph.WriteEdge.LOG_OUTPUT_TEMPLATE.format(label=self.label)
else:
return None

Expand Down Expand Up @@ -731,6 +735,24 @@ def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""
self._pipelineIR.write_to_uri(uri)

def to_graph(self) -> pipeline_graph.MutablePipelineGraph:
    """Construct a pipeline graph with one node per task in this pipeline.

    Returns
    -------
    graph : `pipeline_graph.MutablePipelineGraph`
        Graph holding every task in the pipeline, sorted before return.

    Raises
    ------
    pipelineIR.ContractError
        Raised if any configuration contract in the pipeline evaluates
        to a falsy value.
    """
    graph = pipeline_graph.MutablePipelineGraph()
    # Add every task declared in the pipeline IR to the graph; the helper
    # imports the task class and applies config overrides.
    for label in self._pipelineIR.tasks:
        self._add_task_to_graph(label, graph)
    if self._pipelineIR.contracts is not None:
        # Contracts are boolean expressions evaluated against task configs,
        # keyed by task label.
        label_to_config = {x.label: x.config for x in graph.tasks.values()}
        for contract in self._pipelineIR.contracts:
            # execute this in its own line so it can raise a good error
            # message if there were problems with the eval
            # NOTE(review): `eval` of contract strings assumes pipeline
            # definitions are trusted input — confirm this is acceptable.
            success = eval(contract.contract, None, label_to_config)
            if not success:
                extra_info = f": {contract.msg}" if contract.msg is not None else ""
                raise pipelineIR.ContractError(
                    f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
                )
    graph.sort()
    return graph

def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
"""Returns a generator of TaskDefs which can be used to create quantum
graphs.
Expand All @@ -747,31 +769,12 @@ def toExpandedPipeline(self) -> Generator[TaskDef, None, None]:
If a dataId is supplied in a config block. This is in place for
future use
"""
taskDefs = []
for label in self._pipelineIR.tasks:
taskDefs.append(self._buildTaskDef(label))
yield from self.to_graph()._iter_task_defs()

# lets evaluate the contracts
if self._pipelineIR.contracts is not None:
label_to_config = {x.label: x.config for x in taskDefs}
for contract in self._pipelineIR.contracts:
# execute this in its own line so it can raise a good error
# message if there were problems with the eval
success = eval(contract.contract, None, label_to_config)
if not success:
extra_info = f": {contract.msg}" if contract.msg is not None else ""
raise pipelineIR.ContractError(
f"Contract(s) '{contract.contract}' were not satisfied{extra_info}"
)

taskDefs = sorted(taskDefs, key=lambda x: x.label)
yield from pipeTools.orderPipeline(taskDefs)

def _buildTaskDef(self, label: str) -> TaskDef:
def _add_task_to_graph(self, label: str, graph: pipeline_graph.MutablePipelineGraph) -> None:
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
raise NameError(f"Label {label} does not appear in this pipeline")
taskClass: Type[PipelineTask] = doImportType(taskIR.klass)
taskName = get_full_type_name(taskClass)
config = taskClass.ConfigClass()
overrides = ConfigOverrides()
if self._pipelineIR.instrument is not None:
Expand All @@ -793,13 +796,16 @@ def _buildTaskDef(self, label: str) -> TaskDef:
for key, value in configIR.rest.items():
overrides.addValueOverride(key, value)
overrides.applyTo(config)
return TaskDef(taskName=taskName, config=config, taskClass=taskClass, label=label)
graph.add_task(label, taskClass, config)

def __iter__(self) -> Generator[TaskDef, None, None]:
    """Iterate over the expanded pipeline, yielding one `TaskDef` per task."""
    return self.toExpandedPipeline()

def __getitem__(self, item: str) -> TaskDef:
return self._buildTaskDef(item)
graph = pipeline_graph.MutablePipelineGraph()
self._add_task_to_graph(item, graph)
(result,) = graph._iter_task_defs()
return result

def __len__(self) -> int:
    """Return the number of tasks declared in this pipeline's IR."""
    return len(self._pipelineIR.tasks)
Expand Down Expand Up @@ -1071,7 +1077,7 @@ def makeDatasetTypesSet(
DatasetType(
taskDef.configDatasetName,
registry.dimensions.empty,
storageClass="Config",
storageClass=pipeline_graph.WriteEdge.CONFIG_INIT_OUTPUT_STORAGE_CLASS,
)
)
initOutputs.freeze()
Expand All @@ -1089,15 +1095,23 @@ def makeDatasetTypesSet(
current = registry.getDatasetType(taskDef.metadataDatasetName)
except KeyError:
# No previous definition so use the default.
storageClass = "TaskMetadata" if _TASK_METADATA_TYPE is TaskMetadata else "PropertySet"
storageClass = pipeline_graph.WriteEdge.METADATA_OUTPUT_STORAGE_CLASS
else:
storageClass = current.storageClass.name

outputs.update({DatasetType(taskDef.metadataDatasetName, dimensions, storageClass)})
if taskDef.logOutputDatasetName is not None:
# Log output dimensions correspond to a task quantum.
dimensions = registry.dimensions.extract(taskDef.connections.dimensions)
outputs.update({DatasetType(taskDef.logOutputDatasetName, dimensions, "ButlerLogRecords")})
outputs.update(
{
DatasetType(
taskDef.logOutputDatasetName,
dimensions,
pipeline_graph.WriteEdge.LOG_OUTPUT_STORAGE_CLASS,
)
}
)

outputs.freeze()

Expand Down
4 changes: 2 additions & 2 deletions python/lsst/pipe/base/pipelineIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import re
import warnings
from collections import Counter
from collections.abc import Iterable as abcIterable
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import Any, Dict, Generator, Hashable, List, Literal, MutableMapping, Optional, Set, Union

Expand Down Expand Up @@ -162,7 +162,7 @@ def from_primitives(label: str, value: Union[List[str], dict]) -> LabeledSubset:
"If a labeled subset is specified as a mapping, it must contain the key 'subset'"
)
description = value.pop("description", None)
elif isinstance(value, abcIterable):
elif isinstance(value, Iterable):
subset = value
description = None
else:
Expand Down
29 changes: 29 additions & 0 deletions python/lsst/pipe/base/pipeline_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations

from ._abcs import *
from ._dataset_types import *
from ._edges import *
from ._exceptions import *
from ._pipeline_graph import *
from ._task_subsets import *
from ._tasks import *
Loading

0 comments on commit b77c85d

Please sign in to comment.