From c4cd302b188e4643afaa911482dec5ef5bee6910 Mon Sep 17 00:00:00 2001 From: Remco de Boer Date: Fri, 18 Jun 2021 13:43:58 +0200 Subject: [PATCH] feat: implement Particle ordering (#72) * feat: implement Particle.__gt__ * feat: implement Particle.name_root * feat: implement Spin.__gt__ * feat: implement total ordering for Particle * fix: allow comparison between Parity and None * refactor: overwrite Particle.__eq__ * refactor: sort ParticleCollection.names by Particle ordering --- .cspell.json | 1 + src/qrules/particle.py | 40 ++++++++- src/qrules/quantum_numbers.py | 2 + tests/channels/test_d0_to_ks_kp_km.py | 6 +- tests/channels/test_jpsi_to_gamma_pi0_pi0.py | 7 +- tests/unit/test_particle.py | 88 +++++++++++++++++++- tests/unit/test_quantum_numbers.py | 1 + tests/unit/test_solving.py | 2 +- 8 files changed, 134 insertions(+), 13 deletions(-) diff --git a/.cspell.json b/.cspell.json index 60ca305e..4f6a23e5 100644 --- a/.cspell.json +++ b/.cspell.json @@ -111,6 +111,7 @@ "arange", "asdict", "asdot", + "astuple", "cano", "celltoolbar", "codacy", diff --git a/src/qrules/particle.py b/src/qrules/particle.py index 6b6a04f2..8ab68755 100644 --- a/src/qrules/particle.py +++ b/src/qrules/particle.py @@ -14,6 +14,7 @@ import re from collections import abc from difflib import get_close_matches +from functools import total_ordering from math import copysign from typing import ( Any, @@ -21,8 +22,8 @@ Dict, Iterable, Iterator, + List, Optional, - Set, SupportsFloat, Tuple, Union, @@ -50,6 +51,7 @@ def _to_float(value: SupportsFloat) -> float: return float_value +@total_ordering @attr.s(frozen=True, eq=False, hash=True) class Spin: """Safe, immutable data container for spin **with projection**.""" @@ -89,6 +91,11 @@ def __eq__(self, other: object) -> bool: def __float__(self) -> float: return self.magnitude + def __gt__(self, other: Any) -> bool: + if isinstance(other, Spin): + return attr.astuple(self) > attr.astuple(other) + return self.magnitude > other + def __neg__(self) -> "Spin": return Spin(self.magnitude, -self.projection) @@ -112,7 +119,8 @@ def _to_spin(value: Union[Spin, Tuple[float, float]]) -> Spin: return value -@attr.s(frozen=True, repr=True, kw_only=True) +@total_ordering +@attr.s(frozen=True, order=False, repr=True, kw_only=True) class Particle: # pylint: disable=too-many-instance-attributes """Immutable container of data defining a physical particle. @@ -194,6 +202,30 @@ def __attrs_post_init__(self) -> None: ")" ) + @property + def name_root(self) -> str: + name_root = self.name + name_root = re.sub(r"\(.+\)", "", name_root) + name_root = re.sub(r"[\*\+\-~\d']", "", name_root) + return name_root + + def __gt__(self, other: Any) -> bool: + if isinstance(other, Particle): + + def sorting_key(particle: Particle) -> tuple: + name_root = particle.name_root + return ( + name_root[0].lower(), + name_root, + particle.mass, + particle.charge, + ) + + return sorting_key(self) > sorting_key(other) + raise NotImplementedError( + f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) + def __neg__(self) -> "Particle": return create_antiparticle(self) @@ -386,8 +418,8 @@ def update(self, other: Iterable[Particle]) -> None: self.add(particle) @property - def names(self) -> Set[str]: - return set(self.__particles) + def names(self) -> List[str]: + return [p.name for p in sorted(self)] def create_particle( # pylint: disable=too-many-arguments,too-many-locals diff --git a/src/qrules/quantum_numbers.py b/src/qrules/quantum_numbers.py index 259909a4..29eddaca 100644 --- a/src/qrules/quantum_numbers.py +++ b/src/qrules/quantum_numbers.py @@ -34,6 +34,8 @@ def __eq__(self, other: object) -> bool: return self.value == other def __gt__(self, other: Any) -> bool: + if other is None: + return True return self.value > int(other) def __int__(self) -> int: diff --git a/tests/channels/test_d0_to_ks_kp_km.py b/tests/channels/test_d0_to_ks_kp_km.py index 10f5e2bf..10e133b2 100644 --- a/tests/channels/test_d0_to_ks_kp_km.py +++ b/tests/channels/test_d0_to_ks_kp_km.py @@ -13,10 +13,10 @@ def test_script(): number_of_threads=1, ) assert len(result.transitions) == 5 - assert result.get_intermediate_particles().names == { - "a(0)(980)+", + assert result.get_intermediate_particles().names == [ "a(0)(980)-", "a(0)(980)0", + "a(0)(980)+", "a(2)(1320)-", "phi(1020)", - } + ] diff --git a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py index 6217f6d8..c4d7fbc2 100644 --- a/tests/channels/test_jpsi_to_gamma_pi0_pi0.py +++ b/tests/channels/test_jpsi_to_gamma_pi0_pi0.py @@ -14,8 +14,8 @@ ( [ "f(0)(980)", - "f(0)(1500)", "f(2)(1270)", + "f(0)(1500)", "f(2)(1950)", "omega(782)", ], @@ -37,8 +37,9 @@ def test_number_of_solutions( formalism="helicity", ) assert len(result.transitions) == number_of_solutions - assert result.get_intermediate_particles().names == set( - allowed_intermediate_particles + assert ( + result.get_intermediate_particles().names + == allowed_intermediate_particles ) diff --git a/tests/unit/test_particle.py b/tests/unit/test_particle.py index f0d12943..66736ab9 100644 --- a/tests/unit/test_particle.py +++ b/tests/unit/test_particle.py @@ -98,11 +98,84 @@ def test_eq(self): assert particle.name != different_labels.name assert particle.pid != different_labels.pid + @pytest.mark.parametrize( + ("name1", "name2"), + [ + # by name + ("pi0", "a(0)(980)-"), + # by mass + ("pi+", "pi-"), + ("pi-", "pi0"), + ("pi+", "pi0"), + ("K0", "K+"), + # by charge + ("a(0)(980)+", "a(0)(980)-"), + ("a(0)(980)+", "a(0)(980)0"), + ("a(0)(980)0", "a(0)(980)-"), + ], + ) + def test_gt(self, name1, name2, particle_database: ParticleCollection): + pdg = particle_database + assert pdg[name1] > pdg[name2] + + def test_name_root(self, particle_database: ParticleCollection): + name_roots = {p.name_root for p in particle_database} + assert name_roots == { + "a", + "B", + "b", + "chi", + "D", + "Delta", + "e", + "eta", + "f", + "g", + "gamma", + "h", + "J/psi", + "K", + "Lambda", + "mu", + "N", + "n", + "nu", + "Omega", + "omega", + "p", + "phi", + "pi", + "psi", + "rho", + "Sigma", + "tau", + "Upsilon", + "W", + "Xi", + "Y", + "Z", + } + def test_neg(self, particle_database: ParticleCollection): pip = particle_database.find(211) pim = particle_database.find(-211) assert pip == -pim + def test_total_ordering(self, particle_database: ParticleCollection): + pdg = particle_database + assert [ + particle.name + for particle in sorted( + pdg.filter(lambda p: p.name.startswith("f(0)")) + ) + ] == [ + "f(0)(500)", + "f(0)(980)", + "f(0)(1370)", + "f(0)(1500)", + "f(0)(1710)", + ] + class TestParticleCollection: def test_init(self, particle_database: ParticleCollection): @@ -200,10 +273,10 @@ def test_filter(self, particle_database: ParticleCollection): and p.spin == 2 and p.strangeness == 1 ) - assert filtered_result.names == { + assert filtered_result.names == [ "K(2)(1820)0", "K(2)(1820)+", - } + ] def test_find(self, particle_database: ParticleCollection): f2_1950 = particle_database.find(9050225) @@ -280,6 +353,17 @@ def test_hash(self): spin2, } + @pytest.mark.parametrize( + ("spin1", "spin2"), + [ + (Spin(1, 0), Spin(0, 0)), + (Spin(1, 1), Spin(1, 0)), + (Spin(1, +1), Spin(1, -1)), + ], + ) + def test_gt(self, spin1: Spin, spin2: Spin): + assert spin1 > spin2 + def test_neg(self): isospin = Spin(1.5, -0.5) flipped_spin = -isospin diff --git a/tests/unit/test_quantum_numbers.py b/tests/unit/test_quantum_numbers.py index f6f58001..8a896cd5 100644 --- a/tests/unit/test_quantum_numbers.py +++ b/tests/unit/test_quantum_numbers.py @@ -12,6 +12,7 @@ def test_init_and_eq(self): parity = Parity(+1) assert parity == +1 assert int(parity) == +1 + assert parity > None @typing.no_type_check # https://github.com/python/mypy/issues/4610 def test_comparison(self): diff --git a/tests/unit/test_solving.py b/tests/unit/test_solving.py index 59bc098e..4ec05c8c 100644 --- a/tests/unit/test_solving.py +++ b/tests/unit/test_solving.py @@ -5,4 +5,4 @@ class TestResult: def test_get_intermediate_state_names(self, result: Result): intermediate_particles = result.get_intermediate_particles() - assert intermediate_particles.names == {"f(0)(1500)", "f(0)(980)"} + assert intermediate_particles.names == ["f(0)(980)", "f(0)(1500)"]