diff --git a/src/qrules/conservation_rules.py b/src/qrules/conservation_rules.py index bb1cc356..b264f265 100644 --- a/src/qrules/conservation_rules.py +++ b/src/qrules/conservation_rules.py @@ -691,8 +691,8 @@ def spin_conservation( def spin_magnitude_conservation( - ingoing_spins: List[SpinEdgeInput], - outgoing_spins: List[SpinEdgeInput], + ingoing_spin_magnitudes: List[EdgeQN.spin_magnitude], + outgoing_spin_magnitudes: List[EdgeQN.spin_magnitude], interaction_qns: SpinMagnitudeNodeInput, ) -> bool: r"""Check for spin conservation. @@ -710,20 +710,20 @@ def spin_magnitude_conservation( # L and S can only be used if one side is a single state # and the other side contains of two states (isobar) # So do a full check if this is the case - if (len(ingoing_spins) == 1 and len(outgoing_spins) == 2) or ( - len(ingoing_spins) == 2 and len(outgoing_spins) == 1 + if (len(ingoing_spin_magnitudes) == 1 and len(outgoing_spin_magnitudes) == 2) or ( + len(ingoing_spin_magnitudes) == 2 and len(outgoing_spin_magnitudes) == 1 ): return _check_magnitude( - [x.spin_magnitude for x in ingoing_spins], - [x.spin_magnitude for x in outgoing_spins], + [float(x) for x in ingoing_spin_magnitudes], + [float(x) for x in outgoing_spin_magnitudes], interaction_qns, ) # otherwise don't use S and L and just check magnitude # are integral or non integral on both sides return ( - sum(float(x.spin_magnitude) for x in ingoing_spins).is_integer() # type: ignore[union-attr] - == sum(float(x.spin_magnitude) for x in outgoing_spins).is_integer() # type: ignore[union-attr] + sum(float(x) for x in ingoing_spin_magnitudes).is_integer() # type: ignore[union-attr] + == sum(float(x) for x in outgoing_spin_magnitudes).is_integer() # type: ignore[union-attr] ) diff --git a/tests/unit/conservation_rules/test_spin.py b/tests/unit/conservation_rules/test_spin.py index 3f370d3d..358ac1ad 100644 --- a/tests/unit/conservation_rules/test_spin.py +++ b/tests/unit/conservation_rules/test_spin.py @@ -11,8 +11,13 @@ spin_magnitude_conservation, ) from qrules.particle import Spin +from qrules.quantum_numbers import EdgeQuantumNumbers -_SpinRuleInputType = Tuple[List[SpinEdgeInput], List[SpinEdgeInput], SpinNodeInput] +_SpinRuleInputType = Tuple[ + List[EdgeQuantumNumbers.spin_magnitude], + List[EdgeQuantumNumbers.spin_magnitude], + SpinNodeInput, +] def __create_two_body_decay_spin_data(