Make PipelineGraph importing lazy and move all state to instances.
TallJimbo committed Mar 7, 2023
1 parent d1610f3 commit 5b67ba0
Showing 6 changed files with 613 additions and 426 deletions.
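The diff below implements the second half of the commit message: state that previously lived in networkx attribute dictionaries (and was mutated in place) now lives on immutable node and edge instances, and resolving returns a new instance rather than mutating shared state. A minimal sketch of the two patterns, using a hypothetical stub class rather than code from this commit:

import networkx


class DatasetTypeStub:
    """Illustrative stand-in for a pipeline-graph dataset type node."""

    def __init__(self, name: str, dimensions: tuple[str, ...] | None = None):
        self.name = name
        self.dimensions = dimensions

    def _resolved(self) -> "DatasetTypeStub":
        # Return self if already resolved; otherwise build a resolved copy.
        if self.dimensions is not None:
            return self
        return DatasetTypeStub(self.name, ("visit", "detector"))


xgraph = networkx.DiGraph()

# Old pattern: mutable state kept directly in the networkx attribute dict.
xgraph.add_node("old_raw", dimensions=None, bipartite=1)
xgraph.nodes["old_raw"]["dimensions"] = ("visit", "detector")

# New pattern: an immutable instance under the "instance" key; resolving
# swaps in a new instance instead of mutating shared state.
xgraph.add_node("raw", instance=DatasetTypeStub("raw"), bipartite=1)
xgraph.nodes["raw"]["instance"] = xgraph.nodes["raw"]["instance"]._resolved()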
208 changes: 113 additions & 95 deletions python/lsst/pipe/base/pipeline_graph/_abcs.py
@@ -33,12 +33,13 @@
import itertools
from abc import ABC, abstractmethod
from collections.abc import Iterable, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, TypeVar, cast

import networkx
from lsst.daf.butler import DatasetRef, DatasetType, DimensionUniverse, Registry
from lsst.utils.classes import immutable

from ._exceptions import ConnectionTypeConsistencyError, IncompatibleDatasetTypeError
from ._exceptions import ConnectionTypeConsistencyError

if TYPE_CHECKING:
from ..connectionTypes import BaseConnection
@@ -56,6 +57,7 @@ class NodeType(enum.Enum):
DATASET_TYPE = 1


@immutable
class NodeKey(NamedTuple):
"""A special key type for nodes in networkx graphs.
@@ -86,44 +88,53 @@ def __str__(self) -> str:
return self.name


@immutable
class Node(ABC):
"""Base class for nodes in a pipeline graph."""
"""Base class for nodes in a pipeline graph.
Parameters
----------
key : `NodeKey`
The key for this node in networkx graphs.
"""

def __init__(self, key: NodeKey):
self.key = key

key: NodeKey
"""The key for this node in networkx graphs."""

@abstractmethod
def _resolve(self, state: dict[str, Any], xgraph: networkx.DiGraph, registry: Registry) -> None:
def _resolved(self, xgraph: networkx.DiGraph, registry: Registry) -> Node:
"""Resolve any dataset type and dimension names in this graph.
Parameters
----------
state : `dict`
The state dictionary that networkx associates with this node. This
`Node` instance will be the value associated with the "instance"
key, and on return that instance should be replaced with a resolved
version of the node (if ``self`` is not already resolved).
graph : `networkx.DiGraph`
xgraph : `networkx.DiGraph`
Directed bipartite graph representing the full pipeline. Should
not be modified.
registry : `lsst.daf.butler.Registry`
Registry that provides dimension and dataset type information.
Notes
-----
This should do nothing if the node is already resolved.
Returns
-------
node : `Node`
Resolved version of this node. May be self if the node is already
resolved.
"""
raise NotImplementedError()

@abstractmethod
def _unresolve(self, state: dict[str, Any]) -> None:
def _unresolved(self) -> Node:
"""Revert this node to a form that just holds names for dataset types
and dimensions, allowing `_resolved` to have an effect if called
again.
Notes
-----
This should do nothing if the node is already unresolved.
Returns
-------
node : `Node`
Unresolved version of this node. May be self if the node is already
unresolved.
"""
raise NotImplementedError()

Expand All @@ -138,29 +149,54 @@ def _serialize(self) -> dict[str, Any]:
raise NotImplementedError()


@immutable
class Edge(ABC):
"""Base class for edges in a pipeline graph.
This represents the link between a task node and an input or output dataset
type. Task-only and dataset-type-only views of the fully graph do not have
type. Task-only and dataset-type-only views of the full graph do not have
stateful edges.
Notes
-----
A ``state`` dictionary is passed in to all private methods here because an
`Edge` instance actually holds some of its logical state "outside itself"
in the same networkx dictionary that also holds the `Edge` instance. This
keeps information that is really a property of the graph-wide dataset type
out of the public edge object interfaces, while still letting us check it
(and provide good error messages) here.
Parameters
----------
task_key : `NodeKey`
Key for the task node this edge is connected to.
dataset_type_key : `NodeKey`
Key for the dataset type node this edge is connected to.
storage_class_name : `str`
Name of the dataset type's storage class as seen by the task.
connection_name : `str`
Internal name for the connection as seen by the task.
connection : `BaseConnection`
Post-configuration object to draw dataset type information from.
"""

def __init__(
self,
*,
task_key: NodeKey,
dataset_type_key: NodeKey,
storage_class_name: str,
connection_name: str,
):
self.task_key = task_key
self.dataset_type_key = dataset_type_key
self.storage_class_name = storage_class_name
self._connection_name = connection_name

task_key: NodeKey
"""Task part of the key for this edge in networkx graphs."""

dataset_type_key: NodeKey
"""Task part of the key for this edge in networkx graphs."""

storage_class_name: str
"""Storage class expected by this task.
If `component` is not `None`, this is the component storage class, not the
parent storage class.
"""

@property
def task_label(self) -> str:
"""Label of the task."""
@@ -186,6 +222,18 @@ def key(self) -> tuple[NodeKey, NodeKey]:
"""
raise NotImplementedError()

def __eq__(self, other: object) -> bool:
try:
return self.key == cast(Edge, other).key
except AttributeError:
return NotImplemented

def __hash__(self) -> int:
return hash(self.key)

def __repr__(self) -> str:
return f"{self.key[0]} -> {self.key[1]}"

@property
def dataset_type_name(self) -> str:
"""Dataset type name seen by the task.
@@ -214,29 +262,19 @@ def adapt_dataset_ref(self, ref: DatasetRef) -> DatasetRef:
def _from_connection(
cls,
task_label: str,
connection_name: str,
connection: BaseConnection,
edge_data: list[tuple[NodeKey, NodeKey, dict[str, Any]]],
*,
is_init: bool,
is_prerequisite: bool = False,
) -> Edge:
"""Construct an `Edge` instance from a `BaseConnection` object.
Parameters
----------
task_label : `str`
Label of the task.
connection_name : `str`
Internal name for the connection as seen by the task.
connection : `BaseConnection`
Post-configuration object to draw dataset type information from.
edge_data : `list`
List of networkx edge data 3-tuples to append to. The first two
items are the edge `key`, and the last is the state dictionary,
which should have the returned `Edge` instance as the value of its
``instance`` key.
is_init : `bool`
Whether this is an init-input or init-output edge.
is_prerequisite : `bool`
Whether this is a prerequisite input edge.
Returns
-------
@@ -248,39 +286,41 @@ def _from_connection(
@abstractmethod
def _check_dataset_type(
self,
state: dict[str, Any],
xgraph: networkx.DiGraph,
dataset_type_node: DatasetTypeNode,
is_init: bool,
is_prerequisite: bool,
) -> None:
"""Check the a potential graph-wide definition of a dataset type for
consistency with this edge.
Parameters
----------
state : `dict`
The networkx dictionary that holds this edge's state in the graph.
The value associated with the "instance" key is ``self``.
xgraph : `networkx.DiGraph`
Directed bipartite graph representing the full pipeline.
dataset_type_node : `DatasetTypeNode`
Dataset type node to be checked.
is_init : `bool`
Whether this is an init-input or init-output edge.
is_prerequisite : `bool`
Whether this is a prerequisite input edge.
Raises
------
ConnectionTypeConsistencyError
Raised if the dataset type node's ``is_init`` or ``is_prequisite``
flags are inconsistent with this edge.
Raised if the dataset type node's ``is_init`` or
``is_prerequisite`` flags are inconsistent with this edge.
IncompatibleDatasetTypeError
Raised if the dataset type itself is incompatible with this edge.
"""
if state["is_init"] != dataset_type_node.is_init:
if is_init != dataset_type_node.is_init:
referencing_tasks = list(
itertools.chain(
xgraph.predecessors(dataset_type_node.name),
xgraph.successors(dataset_type_node.name),
)
)
if state["is_init"]:
if is_init:
raise ConnectionTypeConsistencyError(
f"{dataset_type_node.name!r} is an init dataset in task {self.task_label!r}, "
f"but a run dataset in task(s) {referencing_tasks}."
@@ -290,9 +330,9 @@ def _check_dataset_type(
f"{dataset_type_node.name!r} is a run dataset in task {self.task_label!r}, "
f"but an init dataset in task(s) {referencing_tasks}."
)
if state["is_prerequisite"] != dataset_type_node.is_prerequisite:
if is_prerequisite != dataset_type_node.is_prerequisite:
referencing_tasks = list(xgraph.successors(dataset_type_node.name))
if state["is_prerequisite"]:
if is_prerequisite:
raise ConnectionTypeConsistencyError(
f"Dataset type {dataset_type_node.name!r} is a prerequisite input in "
f"task {self.task_label!r}, but it was not a prerequisite to "
@@ -304,51 +344,11 @@ def _check_dataset_type(
f"task {self.task_label!r}, but it was a prerequisite to "
f"{referencing_tasks}."
)
connection: BaseConnection = state["connection"]
if connection.isCalibration != dataset_type_node.is_calibration:
referencing_tasks = list(
itertools.chain(
xgraph.predecessors(dataset_type_node.name),
xgraph.successors(dataset_type_node.name),
)
)
if connection.isCalibration:
raise IncompatibleDatasetTypeError(
f"Dataset type {dataset_type_node.name!r} is a calibration in "
f"task {self.task_label}, but it was not in task(s) {referencing_tasks}."
)
else:
raise IncompatibleDatasetTypeError(
f"Dataset type {dataset_type_node.name!r} is not a calibration in "
f"task {self.task_label}, but it was in task(s) {referencing_tasks}."
)

@abstractmethod
def _make_dataset_type_state(self, state: dict[str, Any]) -> dict[str, Any]:
"""Make a networkx state dictionary for the dataset type node that is
one side of this edge.
Parameters
----------
state : `dict`
The networkx dictionary that holds this edge's state in the graph.
The value associated with the "instance" key is ``self``.
Returns
-------
node_state : `dict`
The networkx dictionary that holds the node's state in the graph.
Must have the following keys:
- instance: a `DatasetTypeNode` instance
- bipartite: integer set to ``DatasetTypeKey.node_type.value``
"""
raise NotImplementedError()

@abstractmethod
def _resolve_dataset_type(
self,
state: dict[str, Any],
connection: BaseConnection,
current: DatasetType | None,
is_initial_query_constraint: bool,
universe: DimensionUniverse,
@@ -358,9 +358,9 @@ def _resolve_dataset_type(
Parameters
----------
state : `dict`
The networkx dictionary that holds this edge's state in the graph.
The value associated with the "instance" key is ``self``.
connection : `.BaseConnection`
Object provided by the task to describe this edge, or `None` if the
edge was added by the framework.
current : `lsst.daf.butler.DatasetType` or `None`
The current graph-wide `DatasetType`, or `None`. This will always
be the registry's definition of the parent dataset type, if one
@@ -394,15 +394,14 @@
"""
raise NotImplementedError()

@abstractmethod
def _serialize(self) -> dict[str, Any]:
"""Serialize the content of this edge into a dictionary of built-in
objects suitable for JSON conversion.
This should not include the edge's parent dataset type and task label,
as it is always serialized in a context that already identifies those.
"""
raise NotImplementedError()
return {"storage_class_name": self.storage_class_name, "connection_name": self._connection_name}


_G = TypeVar("_G", bound=networkx.DiGraph, covariant=True)
@@ -475,10 +474,12 @@ def __init__(self, parent_xgraph: networkx.DiGraph) -> None:

@abstractmethod
def _make_node_key(self, arg: str) -> NodeKey:
"""Make a `NodeKey` instance from the given string key."""
raise NotImplementedError()

@abstractmethod
def _contains_node_key(self, key: NodeKey) -> bool:
"""Test whether a `NodeKey` belongs to this view."""
raise NotImplementedError()

def __contains__(self, key: object) -> bool:
@@ -509,11 +510,28 @@ def __str__(self) -> str:
return f"{{{', '.join(iter(self))}}}"

def _reorder(self, parent_keys: Sequence[NodeKey]) -> None:
"""Set this view's iteration order according to the given iterable of
parent keys.
Parameters
----------
parent_keys : `~collections.abc.Sequence` [ `NodeKey` ]
Superset of the keys in this view, in the new order.
"""
self._keys = self._make_keys(parent_keys)

def _reset(self) -> None:
# Docstring inherited.
super()._reset()
self._keys = None

def _make_keys(self, parent_keys: Iterable[NodeKey]) -> list[str]:
"""Make a sequence of keys for this view from an iterable of parent
keys.
Parameters
----------
parent_keys : `~collections.abc.Iterable` [ `NodeKey` ]
Superset of the keys in this view.
"""
return [str(k) for k in parent_keys if self._contains_node_key(k)]
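The "importing lazy" half of the commit message lands in the other changed files, which are not shown above. As a generic illustration only, a package can defer a heavy import until first attribute access via PEP 562's module-level __getattr__; the module path and class location below are assumptions, not necessarily what this commit's __init__.py does:

from typing import Any


def __getattr__(name: str) -> Any:
    if name == "PipelineGraph":
        # Import deferred until the first `pipeline_graph.PipelineGraph`
        # attribute access; "._pipeline_graph" is an assumed module name.
        from ._pipeline_graph import PipelineGraph

        return PipelineGraph
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")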