From e7c332762fc79b0e04fdfc7915537cf0c9c69fc7 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 13 Oct 2023 15:40:17 +0200 Subject: [PATCH] ENH: adapt implementation to QRules v0.10 --- src/ampform/helicity/__init__.py | 31 +++++++++++++++--- src/ampform/helicity/align/dpd.py | 52 +++++++++++++++++++++++------- src/ampform/helicity/decay.py | 48 +++++++++++++++++++++++---- src/ampform/kinematics/__init__.py | 12 +++++-- 4 files changed, 118 insertions(+), 25 deletions(-) diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index 47809a32c..c32c12142 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -30,8 +30,14 @@ from attrs.validators import deep_iterable, instance_of, optional from qrules.combinatorics import perform_external_edge_identical_particle_combinatorics from qrules.particle import Particle -from qrules.transition import ReactionInfo, StateTransition +from qrules.transition import ( + InteractionProperties, + ReactionInfo, + State, + StateTransition, +) +from ampform._qrules import get_qrules_version from ampform.dynamics.builder import ( ResonanceDynamicsBuilder, TwoBodyKinematicVariableSet, @@ -70,6 +76,7 @@ if TYPE_CHECKING: from IPython.lib.pretty import PrettyPrinter + from qrules.topology import MutableTransition _LOGGER = logging.getLogger(__name__) @@ -450,11 +457,9 @@ def __formulate_topology_amplitude( ) -> sp.Expr: sequential_expressions: list[sp.Expr] = [] for transition in transitions: - sequential_graphs = perform_external_edge_identical_particle_combinatorics( - transition.to_graph() - ) + sequential_graphs = _perform_combinatorics(transition) for graph in sequential_graphs: - first_transition = StateTransition.from_graph(graph) + first_transition = _freeze(graph) expression = self.__formulate_sequential_decay(first_transition) sequential_expressions.append(expression) @@ -558,6 +563,22 @@ def __generate_amplitude_prefactor( return None +def _perform_combinatorics( + transition: StateTransition, +) -> list[MutableTransition[State, InteractionProperties]]: + if get_qrules_version() < (0, 10): + return perform_external_edge_identical_particle_combinatorics( + transition.to_graph() # type: ignore[attr-defined] + ) + return perform_external_edge_identical_particle_combinatorics(transition.unfreeze()) + + +def _freeze(graph: MutableTransition[State, InteractionProperties]) -> StateTransition: + if get_qrules_version() < (0, 10): + return StateTransition.from_graph(graph) # type: ignore[attr-defined] + return graph.freeze() + + class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder): r"""Amplitude model generator for the canonical helicity formalism. diff --git a/src/ampform/helicity/align/dpd.py b/src/ampform/helicity/align/dpd.py index 79138bb03..09a5598eb 100644 --- a/src/ampform/helicity/align/dpd.py +++ b/src/ampform/helicity/align/dpd.py @@ -13,9 +13,10 @@ from attrs import define, field from attrs.validators import in_ from qrules.topology import Topology -from qrules.transition import ReactionInfo, StateTransition, StateTransitionCollection +from qrules.transition import ReactionInfo, StateTransition from sympy.physics.quantum.spin import Rotation as Wigner +from ampform._qrules import get_qrules_version from ampform.helicity.align import SpinAlignment from ampform.helicity.decay import ( get_outer_state_ids, @@ -34,6 +35,11 @@ if TYPE_CHECKING: from sympy.physics.quantum.spin import WignerD +if get_qrules_version() < (0, 10): + from qrules.transition import ( # type: ignore[attr-defined] + StateTransitionCollection, + ) + @define class DalitzPlotDecomposition(SpinAlignment): @@ -109,8 +115,14 @@ def __call__( return Wigner.d(j, m, m_prime, zeta) -T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology) -"""Allowed types for :func:`relabel_edge_ids`.""" +if get_qrules_version() < (0, 10): + T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology) + """Allowed types for :func:`relabel_edge_ids`.""" +else: + T = TypeVar( # type: ignore[misc] # pyright: ignore[reportConstantRedefinition] + "T", ReactionInfo, StateTransition, Topology + ) + """Allowed types for :func:`relabel_edge_ids`.""" @singledispatch @@ -121,21 +133,29 @@ def relabel_edge_ids(obj: T) -> T: @relabel_edge_ids.register(ReactionInfo) def _(obj: ReactionInfo) -> ReactionInfo: # type: ignore[misc] - return ReactionInfo( # no attrs.evolve() in order to call __attrs_post_init__() - transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], + if get_qrules_version() < (0, 10): + return ReactionInfo( # type: ignore[call-arg] + transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], # type: ignore[attr-defined] + formalism=obj.formalism, + ) + return ReactionInfo( + # no attrs.evolve() in order to call __attrs_post_init__() + transitions=[relabel_edge_ids(g) for g in obj.transitions], formalism=obj.formalism, ) -@relabel_edge_ids.register(StateTransitionCollection) -def _(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc] - return StateTransitionCollection( # no attrs.evolve() for __attrs_post_init__() - [relabel_edge_ids(transition) for transition in obj.transitions] - ) +if get_qrules_version() < (0, 10): + def __relabel_stc(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc] + return StateTransitionCollection( + [relabel_edge_ids(transition) for transition in obj.transitions] + ) -@relabel_edge_ids.register(StateTransition) -def _(obj: StateTransition) -> StateTransition: # type: ignore[misc] + relabel_edge_ids.register(StateTransitionCollection)(__relabel_stc) + + +def __relabel_st(obj: StateTransition) -> StateTransition: # type: ignore[misc] mapping = __get_default_relabel_mapping() return attrs.evolve( obj, @@ -144,6 +164,14 @@ def _(obj: StateTransition) -> StateTransition: # type: ignore[misc] ) +if get_qrules_version() < (0, 10): + relabel_edge_ids.register(StateTransition)(__relabel_st) +else: + from qrules.topology import FrozenTransition + + relabel_edge_ids.register(FrozenTransition)(__relabel_st) + + @relabel_edge_ids.register(Topology) def _(obj: Topology) -> Topology: # type: ignore[misc] mapping = __get_default_relabel_mapping() diff --git a/src/ampform/helicity/decay.py b/src/ampform/helicity/decay.py index c2e715583..a2e738b79 100644 --- a/src/ampform/helicity/decay.py +++ b/src/ampform/helicity/decay.py @@ -7,16 +7,22 @@ from typing import TYPE_CHECKING, Iterable from attrs import frozen +from qrules.quantum_numbers import InteractionProperties from qrules.transition import ReactionInfo, State, StateTransition +from ampform._qrules import get_qrules_version + if TYPE_CHECKING: - from qrules.quantum_numbers import InteractionProperties from qrules.topology import Topology if sys.version_info < (3, 8): from typing_extensions import Literal else: from typing import Literal +if sys.version_info < (3, 10): + from typing_extensions import TypeGuard +else: + from typing import TypeGuard @frozen @@ -103,12 +109,30 @@ def _(obj: TwoBodyDecay) -> TwoBodyDecay: def _(obj: tuple) -> TwoBodyDecay: if len(obj) == 2: # noqa: PLR2004 transition, node_id = obj - if isinstance(transition, StateTransition) and isinstance(node_id, int): - return TwoBodyDecay.from_transition(*obj) + if _is_qrules_state_transition(transition) and isinstance(node_id, int): + return TwoBodyDecay.from_transition(transition, node_id) msg = f"Cannot create a {TwoBodyDecay.__name__} from {obj}" raise NotImplementedError(msg) +def _is_qrules_state_transition(obj) -> TypeGuard[StateTransition]: + if get_qrules_version() >= (0, 10): + from qrules.topology import FrozenTransition + + if isinstance(obj, FrozenTransition): + if any(not isinstance(s, State) for s in obj.states.values()): + return False + if any( + not isinstance(i, InteractionProperties) + for i in obj.interactions.values() + ): + return False + return True + if get_qrules_version() < (0, 10) and isinstance(obj, StateTransition): # type: ignore[misc] + return True + return False + + @lru_cache(maxsize=None) def is_opposite_helicity_state(topology: Topology, state_id: int) -> bool: """Determine if an edge is an "opposite helicity" state. @@ -328,8 +352,13 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in >>> from qrules.topology import create_isobar_topologies >>> topologies = create_isobar_topologies(5) - >>> determine_attached_final_state(topologies[0], state_id=5) + >>> determine_attached_final_state(topologies[3], state_id=5) [0, 3, 4] + >>> import pytest + >>> from ampform._qrules import get_qrules_version + >>> if get_qrules_version() < (0, 10): + ... pytest.skip('Doctest only works for qrules>=0.10') + ... """ edge = topology.edges[state_id] if edge.ending_node_id is None: @@ -343,13 +372,20 @@ def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]: raise NotImplementedError(msg) -@get_outer_state_ids.register(StateTransition) -def _(transition: StateTransition) -> list[int]: +def __convert_state_transition(transition: StateTransition) -> list[int]: outer_state_ids = list(transition.initial_states) outer_state_ids += sorted(transition.final_states) return outer_state_ids +if get_qrules_version() < (0, 10): + get_outer_state_ids.register(StateTransition)(__convert_state_transition) +else: + from qrules.topology import FrozenTransition + + get_outer_state_ids.register(FrozenTransition)(__convert_state_transition) + + @get_outer_state_ids.register(ReactionInfo) def _(reaction: ReactionInfo) -> list[int]: return get_outer_state_ids(reaction.transitions[0]) diff --git a/src/ampform/kinematics/__init__.py b/src/ampform/kinematics/__init__.py index 3ae4cddcd..cac5816f7 100644 --- a/src/ampform/kinematics/__init__.py +++ b/src/ampform/kinematics/__init__.py @@ -16,6 +16,7 @@ from qrules.topology import Topology from qrules.transition import ReactionInfo, StateTransition +from ampform._qrules import get_qrules_version from ampform.helicity.decay import assert_isobar_topology from ampform.kinematics.angles import compute_helicity_angles from ampform.kinematics.lorentz import ( @@ -120,6 +121,13 @@ def _(obj: Topology) -> Topology: return obj -@_get_topology.register(StateTransition) -def _(obj: StateTransition) -> Topology: +def __get_state_transition(obj: StateTransition) -> Topology: return obj.topology + + +if get_qrules_version() < (0, 10): + _get_topology.register(StateTransition)(__get_state_transition) +else: + from qrules.topology import FrozenTransition + + _get_topology.register(FrozenTransition)(__get_state_transition)