diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index a484bd27e..297b9a165 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -148,7 +148,7 @@ python-lsp-server==1.10.0 pytoolconfig==1.3.1 pyyaml==6.0.1 pyzmq==25.1.2 -qrules==0.9.8 +qrules==0.10.1 referencing==0.33.0 requests==2.31.0 rfc3339-validator==0.1.4 diff --git a/.constraints/py3.11.txt b/.constraints/py3.11.txt index 13d1daeea..d7bdd9d2c 100644 --- a/.constraints/py3.11.txt +++ b/.constraints/py3.11.txt @@ -147,7 +147,7 @@ python-lsp-server==1.10.0 pytoolconfig==1.3.1 pyyaml==6.0.1 pyzmq==25.1.2 -qrules==0.9.8 +qrules==0.10.1 referencing==0.33.0 requests==2.31.0 rfc3339-validator==0.1.4 diff --git a/.constraints/py3.12.txt b/.constraints/py3.12.txt index d6225401d..2d3518278 100644 --- a/.constraints/py3.12.txt +++ b/.constraints/py3.12.txt @@ -147,7 +147,7 @@ python-lsp-server==1.10.0 pytoolconfig==1.3.1 pyyaml==6.0.1 pyzmq==25.1.2 -qrules==0.9.8 +qrules==0.10.1 referencing==0.33.0 requests==2.31.0 rfc3339-validator==0.1.4 diff --git a/.constraints/py3.7.txt b/.constraints/py3.7.txt index d2f6d5491..9981c4ae7 100644 --- a/.constraints/py3.7.txt +++ b/.constraints/py3.7.txt @@ -153,7 +153,7 @@ pytoolconfig==1.3.0 pytz==2024.1 pyyaml==6.0.1 pyzmq==24.0.1 -qrules==0.9.8 +qrules==0.10.1 requests==2.31.0 rfc3339-validator==0.1.4 rfc3986-validator==0.1.1 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index 25d61ddd1..5fd5e6f22 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -153,7 +153,7 @@ pytoolconfig==1.3.1 pytz==2024.1 pyyaml==6.0.1 pyzmq==25.1.2 -qrules==0.9.8 +qrules==0.10.1 referencing==0.33.0 requests==2.31.0 rfc3339-validator==0.1.4 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index d56ad3f2c..40fb5b455 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -149,7 +149,7 @@ python-lsp-server==1.10.0 pytoolconfig==1.3.1 pyyaml==6.0.1 pyzmq==25.1.2 -qrules==0.9.8 +qrules==0.10.1 referencing==0.33.0 requests==2.31.0 rfc3339-validator==0.1.4 diff --git a/.github/workflows/ci-qrules-v0.9.yml b/.github/workflows/ci-qrules-v0.9.yml new file mode 100644 index 000000000..b416eba2b --- /dev/null +++ b/.github/workflows/ci-qrules-v0.9.yml @@ -0,0 +1,36 @@ +name: Test with QRules v0.9 + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + PYTHONHASHSEED: "0" + +on: + push: + branches: + - main + - epic/* + - "[0-9]+.[0-9]+.x" + pull_request: + branches: + - main + - epic/* + - "[0-9]+.[0-9]+.x" + workflow_dispatch: + +jobs: + pytest: + name: Run unit tests + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - uses: ComPWA/actions/pip-install@v1 + with: + additional-packages: tox + editable: "yes" + extras: test + python-version: "3.9" + specific-packages: qrules==0.9.* + - run: pytest -n auto diff --git a/docs/conf.py b/docs/conf.py index 4f9f9f781..a54b72725 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,7 +59,7 @@ "ReactionInfo": "qrules.transition.ReactionInfo", "Slider": ("obj", "symplot.Slider"), "State": "qrules.transition.State", - "StateTransition": "qrules.transition.StateTransition", + "StateTransition": "qrules.topology.Transition", "T": "TypeVar", "Topology": "qrules.topology.Topology", "WignerD": "sympy.physics.quantum.spin.WignerD", @@ -238,7 +238,7 @@ "numpy": (f"https://numpy.org/doc/{pin_minor('numpy')}", None), "pwa": ("https://pwa.readthedocs.io", None), "python": ("https://docs.python.org/3", None), - "qrules": (f"https://qrules.readthedocs.io/en/{pin('qrules')}", None), + "qrules": (f"https://qrules.readthedocs.io/{pin('qrules')}", None), "sympy": ("https://docs.sympy.org/latest", None), } linkcheck_anchors = False diff --git a/docs/usage/amplitude.ipynb b/docs/usage/amplitude.ipynb index 46c599eb4..00f4a15df 100644 --- a/docs/usage/amplitude.ipynb +++ b/docs/usage/amplitude.ipynb @@ -99,7 +99,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In {doc}`qrules:usage/reaction`, we used {func}`~qrules.generate_transitions` to create a list of allowed {class}`~qrules.transition.StateTransition`s for a specific decay channel:" + "In {doc}`qrules:usage/reaction`, we used {func}`~qrules.generate_transitions` to create a list of allowed {class}`~qrules.topology.Transition`s for a specific decay channel:" ] }, { diff --git a/docs/usage/dynamics/custom.ipynb b/docs/usage/dynamics/custom.ipynb index b10325601..4c676d46e 100644 --- a/docs/usage/dynamics/custom.ipynb +++ b/docs/usage/dynamics/custom.ipynb @@ -176,7 +176,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "A function that behaves like a {class}`.ResonanceDynamicsBuilder` should return a {class}`tuple` of some {class}`~sympy.core.expr.Expr` (which formulates your lineshape) and a {class}`dict` of {class}`~sympy.core.symbol.Symbol`s to some suggested initial values. This signature is required so the builder knows how to extract the correct symbol names and their suggested initial values from a {class}`~qrules.transition.StateTransition`." + "A function that behaves like a {class}`.ResonanceDynamicsBuilder` should return a {class}`tuple` of some {class}`~sympy.core.expr.Expr` (which formulates your lineshape) and a {class}`dict` of {class}`~sympy.core.symbol.Symbol`s to some suggested initial values. This signature is required so the builder knows how to extract the correct symbol names and their suggested initial values from a {class}`~qrules.topology.Transition`." ] }, { diff --git a/docs/usage/helicity/formalism.ipynb b/docs/usage/helicity/formalism.ipynb index 32468c981..e9003a404 100644 --- a/docs/usage/helicity/formalism.ipynb +++ b/docs/usage/helicity/formalism.ipynb @@ -332,9 +332,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "See {func}`.formulate_isobar_wigner_d` and {func}`.formulate_isobar_cg_coefficients` for how these Wigner-$D$ functions and Clebsch-Gordan coefficients are computed for each node on a {class}`~qrules.transition.StateTransition`.\n", + "See {func}`.formulate_isobar_wigner_d` and {func}`.formulate_isobar_cg_coefficients` for how these Wigner-$D$ functions and Clebsch-Gordan coefficients are computed for each node on a {class}`~qrules.topology.Transition`.\n", "\n", - "We can see this also from the original {class}`~qrules.transition.ReactionInfo` objects. Let's select only the {attr}`~qrules.transition.ReactionInfo.transitions` where the $a_1(1260)^+$ resonance has spin projection $-1$ (taken to be helicity $-1$ in the helicity formalism). We then see just one {class}`~qrules.transition.StateTransition` in the helicity basis and three transitions in the canonical basis:" + "We can see this also from the original {class}`~qrules.transition.ReactionInfo` objects. Let's select only the {attr}`~qrules.transition.ReactionInfo.transitions` where the $a_1(1260)^+$ resonance has spin projection $-1$ (taken to be helicity $-1$ in the helicity formalism). We then see just one {class}`~qrules.topology.Transition` in the helicity basis and three transitions in the canonical basis:" ] }, { diff --git a/pyproject.toml b/pyproject.toml index c38ed85a1..e8f38db92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,9 @@ classifiers = [ ] dependencies = [ "attrs >=20.1.0", # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen - "qrules ==0.9.*, >=0.9.6", # https://github.com/ComPWA/qrules/pull/145 + "qrules >=0.9.6", "sympy >=1.10", + 'importlib-metadata; python_version <"3.8.0"', 'singledispatchmethod; python_version <"3.8.0"', 'typing-extensions; python_version <"3.8.0"', ] @@ -208,6 +209,7 @@ reportPrivateUsage = false reportReturnType = false reportUnboundVariable = false reportUnknownArgumentType = false +reportUnknownLambdaType = false reportUnknownMemberType = false reportUnknownParameterType = false reportUnknownVariableType = false diff --git a/src/ampform/__init__.py b/src/ampform/__init__.py index 030e25d0b..466095cbc 100644 --- a/src/ampform/__init__.py +++ b/src/ampform/__init__.py @@ -1,7 +1,7 @@ """Build amplitude models with different PWA formalisms. AmpForm formalizes formalisms from :doc:`Partial Wave Analysis `. It provides -tools to convert `~qrules.transition.StateTransition` solutions that the `.qrules` +tools to convert `~qrules.topology.Transition` solutions that the `.qrules` package found into an `.HelicityModel`. The output `.HelicityModel` can then be used by external fitter packages to generate a data set (toy Monte Carlo) for this specific reaction process, or to optimize ('fit') its parameters so that they resemble the data diff --git a/src/ampform/_qrules.py b/src/ampform/_qrules.py new file mode 100644 index 000000000..a325d474b --- /dev/null +++ b/src/ampform/_qrules.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import sys +from functools import lru_cache + +if sys.version_info < (3, 8): + from importlib_metadata import version +else: + from importlib.metadata import version + + +@lru_cache(maxsize=1) +def get_qrules_version() -> tuple[int, ...]: + """Get the version of qrules as a tuple of integers. + + >>> get_qrules_version() >= (0, 10) + True + >>> import pytest + >>> from ampform._qrules import get_qrules_version + >>> if get_qrules_version() < (0, 10): + ... pytest.skip("Doctest only works for qrules>=0.10") + """ + v = version("qrules") + return tuple(int(i) for i in v.split(".") if i.strip().isdigit()) diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index 5adf236bf..05e46ae62 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -31,8 +31,14 @@ from attrs.validators import deep_iterable, instance_of, optional from qrules.combinatorics import perform_external_edge_identical_particle_combinatorics from qrules.particle import Particle -from qrules.transition import ReactionInfo, StateTransition +from qrules.transition import ( + InteractionProperties, + ReactionInfo, + State, + StateTransition, +) +from ampform._qrules import get_qrules_version from ampform.dynamics.builder import ( ResonanceDynamicsBuilder, TwoBodyKinematicVariableSet, @@ -75,6 +81,7 @@ from typing import override if TYPE_CHECKING: from IPython.lib.pretty import PrettyPrinter + from qrules.topology import MutableTransition _LOGGER = logging.getLogger(__name__) @@ -453,11 +460,9 @@ def __formulate_topology_amplitude( ) -> sp.Expr: sequential_expressions: list[sp.Expr] = [] for transition in transitions: - sequential_graphs = perform_external_edge_identical_particle_combinatorics( - transition.to_graph() - ) + sequential_graphs = _perform_combinatorics(transition) for graph in sequential_graphs: - first_transition = StateTransition.from_graph(graph) + first_transition = _freeze(graph) expression = self.__formulate_sequential_decay(first_transition) sequential_expressions.append(expression) @@ -561,6 +566,24 @@ def __generate_amplitude_prefactor( return None +def _perform_combinatorics( + transition: StateTransition, +) -> list[MutableTransition[State, InteractionProperties]]: + if get_qrules_version() < (0, 10): + return perform_external_edge_identical_particle_combinatorics( + transition.to_graph() # type: ignore[attr-defined] + ) + graph = transition.convert(lambda s: (s.particle, s.spin_projection)).unfreeze() + combinations = perform_external_edge_identical_particle_combinatorics(graph) + return [g.freeze().convert(lambda s: State(*s)).unfreeze() for g in combinations] + + +def _freeze(graph: MutableTransition[State, InteractionProperties]) -> StateTransition: + if get_qrules_version() < (0, 10): + return StateTransition.from_graph(graph) # type: ignore[attr-defined] + return graph.freeze() + + class CanonicalAmplitudeBuilder(HelicityAmplitudeBuilder): r"""Amplitude model generator for the canonical helicity formalism. @@ -656,7 +679,7 @@ def assign( # noqa: PLR6301 - `str`: Select transition nodes by the name of the `~.TwoBodyDecay.parent` `~qrules.particle.Particle`. - - `.TwoBodyDecay` or `tuple` of a `~qrules.transition.StateTransition` with a + - `.TwoBodyDecay` or `tuple` of a `~qrules.topology.Transition` with a node ID: set dynamics for one specific transition node. """ msg = ( diff --git a/src/ampform/helicity/align/axisangle.py b/src/ampform/helicity/align/axisangle.py index 409845f72..90cfd31ff 100644 --- a/src/ampform/helicity/align/axisangle.py +++ b/src/ampform/helicity/align/axisangle.py @@ -197,7 +197,7 @@ def formulate_wigner_rotation( :cite:`marangottoHelicityAmplitudesGeneric2020`, p.6, especially Eq.(36). Args: - transition: The `~qrules.transition.StateTransition` in which you + transition: The `~qrules.topology.Transition` in which you want to rotate one of the spin states. rotated_state_id: The state ID of a spin `~qrules.transition.State` that you want to rotate. diff --git a/src/ampform/helicity/align/dpd.py b/src/ampform/helicity/align/dpd.py index 7ca6a4bbe..1b93fc92c 100644 --- a/src/ampform/helicity/align/dpd.py +++ b/src/ampform/helicity/align/dpd.py @@ -14,9 +14,10 @@ from attrs import define, field from attrs.validators import in_ from qrules.topology import Topology -from qrules.transition import ReactionInfo, StateTransition, StateTransitionCollection +from qrules.transition import ReactionInfo, StateTransition from sympy.physics.quantum.spin import Rotation as Wigner +from ampform._qrules import get_qrules_version from ampform.helicity.align import SpinAlignment from ampform.helicity.decay import ( get_outer_state_ids, @@ -35,6 +36,11 @@ if TYPE_CHECKING: from sympy.physics.quantum.spin import WignerD +if get_qrules_version() < (0, 10): + from qrules.transition import ( # type: ignore[attr-defined] + StateTransitionCollection, + ) + @define class DalitzPlotDecomposition(SpinAlignment): @@ -112,33 +118,47 @@ def __call__( return Wigner.d(j, m, m_prime, zeta) -T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology) -"""Allowed types for :func:`relabel_edge_ids`.""" +if get_qrules_version() < (0, 10): + T = TypeVar("T", ReactionInfo, StateTransition, StateTransitionCollection, Topology) + """Allowed types for :func:`relabel_edge_ids`.""" +else: + T = TypeVar( # type: ignore[misc] # pyright: ignore[reportConstantRedefinition] + "T", ReactionInfo, StateTransition, Topology + ) + """Allowed types for :func:`relabel_edge_ids`.""" @singledispatch -def relabel_edge_ids(obj: T) -> T: +def relabel_edge_ids(obj: T) -> T: # type: ignore[reportInvalidTypeForm] msg = f"Cannot relabel edge IDs of a {type(obj).__name__}" raise NotImplementedError(msg) @relabel_edge_ids.register(ReactionInfo) def _(obj: ReactionInfo) -> ReactionInfo: # type: ignore[misc] - return ReactionInfo( # no attrs.evolve() in order to call __attrs_post_init__() - transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], + if get_qrules_version() < (0, 10): + return ReactionInfo( # type: ignore[call-arg] + transition_groups=[relabel_edge_ids(g) for g in obj.transition_groups], # type: ignore[attr-defined] + formalism=obj.formalism, + ) + return ReactionInfo( + # no attrs.evolve() in order to call __attrs_post_init__() + transitions=[relabel_edge_ids(g) for g in obj.transitions], formalism=obj.formalism, ) -@relabel_edge_ids.register(StateTransitionCollection) -def _(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc] - return StateTransitionCollection([ # no attrs.evolve() for __attrs_post_init__() - relabel_edge_ids(transition) for transition in obj.transitions - ]) +if get_qrules_version() < (0, 10): + + def __relabel_stc(obj: StateTransitionCollection) -> StateTransitionCollection: # type: ignore[misc] + return StateTransitionCollection([ + relabel_edge_ids(transition) for transition in obj.transitions + ]) + relabel_edge_ids.register(StateTransitionCollection)(__relabel_stc) -@relabel_edge_ids.register(StateTransition) -def _(obj: StateTransition) -> StateTransition: # type: ignore[misc] + +def __relabel_st(obj: StateTransition) -> StateTransition: # type: ignore[misc] mapping = __get_default_relabel_mapping() return attrs.evolve( obj, @@ -147,6 +167,14 @@ def _(obj: StateTransition) -> StateTransition: # type: ignore[misc] ) +if get_qrules_version() < (0, 10): + relabel_edge_ids.register(StateTransition)(__relabel_st) +else: + from qrules.topology import FrozenTransition + + relabel_edge_ids.register(FrozenTransition)(__relabel_st) + + @relabel_edge_ids.register(Topology) def _(obj: Topology) -> Topology: # type: ignore[misc] mapping = __get_default_relabel_mapping() diff --git a/src/ampform/helicity/decay.py b/src/ampform/helicity/decay.py index 9b10bfd33..f65e8cfdc 100644 --- a/src/ampform/helicity/decay.py +++ b/src/ampform/helicity/decay.py @@ -1,4 +1,4 @@ -"""Extract two-body decay info from a `~qrules.transition.StateTransition`.""" +"""Extract two-body decay info from a `~qrules.topology.Transition`.""" from __future__ import annotations @@ -8,16 +8,22 @@ from typing import TYPE_CHECKING, Iterable from attrs import frozen +from qrules.quantum_numbers import InteractionProperties from qrules.transition import ReactionInfo, State, StateTransition +from ampform._qrules import get_qrules_version + if TYPE_CHECKING: - from qrules.quantum_numbers import InteractionProperties from qrules.topology import Topology if sys.version_info < (3, 8): from typing_extensions import Literal else: from typing import Literal +if sys.version_info < (3, 10): + from typing_extensions import TypeGuard +else: + from typing import TypeGuard @frozen @@ -38,11 +44,11 @@ def from_transition(cls, transition: StateTransition, state_id: int) -> StateWit @frozen class TwoBodyDecay: - """Two-body sub-decay in a `~qrules.transition.StateTransition`. + """Two-body sub-decay in a `~qrules.topology.Transition`. This container class ensures that: - 1. a selected node in a `~qrules.transition.StateTransition` is indeed a 1-to-2 body + 1. a selected node in a `~qrules.topology.Transition` is indeed a 1-to-2 body decay 2. its two `.children` are sorted by whether they decay further or not (see @@ -104,12 +110,30 @@ def _(obj: TwoBodyDecay) -> TwoBodyDecay: def _(obj: tuple) -> TwoBodyDecay: if len(obj) == 2: # noqa: PLR2004 transition, node_id = obj - if isinstance(transition, StateTransition) and isinstance(node_id, int): - return TwoBodyDecay.from_transition(*obj) + if _is_qrules_state_transition(transition) and isinstance(node_id, int): + return TwoBodyDecay.from_transition(transition, node_id) msg = f"Cannot create a {TwoBodyDecay.__name__} from {obj}" raise NotImplementedError(msg) +def _is_qrules_state_transition(obj) -> TypeGuard[StateTransition]: + if get_qrules_version() >= (0, 10): + from qrules.topology import FrozenTransition # noqa: PLC0415 + + if isinstance(obj, FrozenTransition): + if any(not isinstance(s, State) for s in obj.states.values()): + return False + if any( + not isinstance(i, InteractionProperties) + for i in obj.interactions.values() + ): + return False + return True + if get_qrules_version() < (0, 10) and isinstance(obj, StateTransition): # type: ignore[misc] + return True + return False + + @lru_cache(maxsize=None) def is_opposite_helicity_state(topology: Topology, state_id: int) -> bool: """Determine if an edge is an "opposite helicity" state. @@ -327,8 +351,12 @@ def determine_attached_final_state(topology: Topology, state_id: int) -> list[in >>> from qrules.topology import create_isobar_topologies >>> topologies = create_isobar_topologies(5) - >>> determine_attached_final_state(topologies[0], state_id=5) + >>> determine_attached_final_state(topologies[3], state_id=5) [0, 3, 4] + >>> import pytest + >>> from ampform._qrules import get_qrules_version + >>> if get_qrules_version() < (0, 10): + ... pytest.skip("Doctest only works for qrules>=0.10") """ edge = topology.edges[state_id] if edge.ending_node_id is None: @@ -342,13 +370,20 @@ def get_outer_state_ids(obj: ReactionInfo | StateTransition) -> list[int]: raise NotImplementedError(msg) -@get_outer_state_ids.register(StateTransition) -def _(transition: StateTransition) -> list[int]: +def __convert_state_transition(transition: StateTransition) -> list[int]: outer_state_ids = list(transition.initial_states) outer_state_ids += sorted(transition.final_states) return outer_state_ids +if get_qrules_version() < (0, 10): + get_outer_state_ids.register(StateTransition)(__convert_state_transition) +else: + from qrules.topology import FrozenTransition + + get_outer_state_ids.register(FrozenTransition)(__convert_state_transition) + + @get_outer_state_ids.register(ReactionInfo) def _(reaction: ReactionInfo) -> list[int]: return get_outer_state_ids(reaction.transitions[0]) @@ -372,7 +407,7 @@ def group_by_spin_projection( ) -> list[list[StateTransition]]: """Match final and initial states in groups. - Each `~qrules.transition.StateTransition` corresponds to a specific state transition + Each `~qrules.topology.Transition` corresponds to a specific state transition amplitude. This function groups together transitions, which have the same initial and final state (including spin). This is needed to determine the coherency of the individual amplitude parts. diff --git a/src/ampform/helicity/naming.py b/src/ampform/helicity/naming.py index 3c9ab1fc4..f1f78de4e 100644 --- a/src/ampform/helicity/naming.py +++ b/src/ampform/helicity/naming.py @@ -41,10 +41,10 @@ def generate_amplitude_name( ) -> str: """Generates a unique name for the amplitude corresponding. - That is, corresponging to the given `~qrules.transition.StateTransition`. If + That is, corresponging to the given `~qrules.topology.Transition`. If ``node_id`` is given, it generates a unique name for the partial amplitude corresponding to the interaction node of the given - `~qrules.transition.StateTransition`. + `~qrules.topology.Transition`. """ @abstractmethod @@ -360,8 +360,9 @@ def get_boost_chain_suffix(topology: Topology, state_id: int) -> str: the internal decay topology. >>> from qrules.topology import create_isobar_topologies + >>> from ampform._qrules import get_qrules_version >>> topologies = create_isobar_topologies(5) - >>> topology = topologies[0] + >>> topology = topologies[0 if get_qrules_version() < (0, 10) else 3] >>> for i in topology.intermediate_edge_ids | topology.outgoing_edge_ids: ... suffix = get_boost_chain_suffix(topology, i) ... print(f"{i}: 'phi{suffix}'") @@ -373,7 +374,7 @@ def get_boost_chain_suffix(topology: Topology, state_id: int) -> str: 5: 'phi_034' 6: 'phi_12' 7: 'phi_34^034' - >>> topology = topologies[1] + >>> topology = topologies[1 if get_qrules_version() < (0, 10) else 2] >>> for i in topology.intermediate_edge_ids | topology.outgoing_edge_ids: ... suffix = get_boost_chain_suffix(topology, i) ... print(f"{i}: 'phi{suffix}'") diff --git a/src/ampform/kinematics/__init__.py b/src/ampform/kinematics/__init__.py index b16089d44..c36efbada 100644 --- a/src/ampform/kinematics/__init__.py +++ b/src/ampform/kinematics/__init__.py @@ -17,6 +17,7 @@ from qrules.topology import Topology from qrules.transition import ReactionInfo, StateTransition +from ampform._qrules import get_qrules_version from ampform.helicity.decay import assert_isobar_topology from ampform.kinematics.angles import compute_helicity_angles from ampform.kinematics.lorentz import ( @@ -121,6 +122,13 @@ def _(obj: Topology) -> Topology: return obj -@_get_topology.register(StateTransition) -def _(obj: StateTransition) -> Topology: +def __get_state_transition(obj: StateTransition) -> Topology: return obj.topology + + +if get_qrules_version() < (0, 10): + _get_topology.register(StateTransition)(__get_state_transition) +else: + from qrules.topology import FrozenTransition + + _get_topology.register(FrozenTransition)(__get_state_transition) diff --git a/src/ampform/kinematics/lorentz.py b/src/ampform/kinematics/lorentz.py index a03a4da1a..6d4a14deb 100644 --- a/src/ampform/kinematics/lorentz.py +++ b/src/ampform/kinematics/lorentz.py @@ -601,8 +601,10 @@ def get_invariant_mass_symbol(topology: Topology, state_id: int) -> sp.Symbol: state :math:`5` is :math:`m_{034}`, because :math:`p_5=p_0+p_3+p_4`: >>> from qrules.topology import create_isobar_topologies + >>> from ampform._qrules import get_qrules_version >>> topologies = create_isobar_topologies(5) - >>> get_invariant_mass_symbol(topologies[0], state_id=5) + >>> topology = topologies[0 if get_qrules_version() < (0, 10) else 3] + >>> get_invariant_mass_symbol(topology, state_id=5) m_034 Naturally, the 'invariant' mass label for a final state is just the mass of the diff --git a/tests/helicity/test_decay.py b/tests/helicity/test_decay.py index 857a4fd4c..4a06b3b4d 100644 --- a/tests/helicity/test_decay.py +++ b/tests/helicity/test_decay.py @@ -6,6 +6,7 @@ import pytest from qrules.topology import Topology, create_isobar_topologies +from ampform._qrules import get_qrules_version from ampform.helicity.decay import ( determine_attached_final_state, get_sibling_state_id, @@ -24,10 +25,10 @@ def test_determine_attached_final_state(): topology.outgoing_edge_ids ) # intermediate states - topology = topologies[0] + topology = topologies[0 if get_qrules_version() < (0, 10) else 1] assert determine_attached_final_state(topology, state_id=4) == [0, 1] assert determine_attached_final_state(topology, state_id=5) == [2, 3] - topology = topologies[1] + topology = topologies[1 if get_qrules_version() < (0, 10) else 0] assert determine_attached_final_state(topology, state_id=4) == [1, 2, 3] assert determine_attached_final_state(topology, state_id=5) == [2, 3] diff --git a/tests/kinematics/conftest.py b/tests/kinematics/conftest.py index 0db56a3eb..505b80382 100644 --- a/tests/kinematics/conftest.py +++ b/tests/kinematics/conftest.py @@ -5,6 +5,7 @@ import pytest from qrules.topology import Topology, create_isobar_topologies +from ampform._qrules import get_qrules_version from ampform.kinematics.lorentz import FourMomenta, create_four_momentum_symbols if TYPE_CHECKING: @@ -18,6 +19,6 @@ def topology_and_momentum_symbols( n = len(data_sample) assert n == 4 topologies = create_isobar_topologies(n) - topology = topologies[1] + topology = topologies[1 if get_qrules_version() < (0, 10) else 0] momentum_symbols = create_four_momentum_symbols(topology) return topology, momentum_symbols