Skip to content

Commit

Permalink
feat: implement Particle ordering (#72)
Browse files Browse the repository at this point in the history
* feat: implement Particle.__gt__
* feat: implement Particle.name_root
* feat: implement Spin.__gt__
* feat: implement total ordering for Particle
* fix: allow comparison between Parity and None
* refactor: overwrite Particle.__eq__
* refactor: sort ParticleCollection.names by Particle ordering
  • Loading branch information
redeboer authored Jun 18, 2021
1 parent de71c75 commit c4cd302
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 13 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
"arange",
"asdict",
"asdot",
"astuple",
"cano",
"celltoolbar",
"codacy",
Expand Down
40 changes: 36 additions & 4 deletions src/qrules/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@
import re
from collections import abc
from difflib import get_close_matches
from functools import total_ordering
from math import copysign
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
SupportsFloat,
Tuple,
Union,
Expand Down Expand Up @@ -50,6 +51,7 @@ def _to_float(value: SupportsFloat) -> float:
return float_value


@total_ordering
@attr.s(frozen=True, eq=False, hash=True)
class Spin:
"""Safe, immutable data container for spin **with projection**."""
Expand Down Expand Up @@ -89,6 +91,11 @@ def __eq__(self, other: object) -> bool:
def __float__(self) -> float:
return self.magnitude

def __gt__(self, other: Any) -> bool:
if isinstance(other, Spin):
return attr.astuple(self) > attr.astuple(other)
return self.magnitude > other

def __neg__(self) -> "Spin":
return Spin(self.magnitude, -self.projection)

Expand All @@ -112,7 +119,8 @@ def _to_spin(value: Union[Spin, Tuple[float, float]]) -> Spin:
return value


@attr.s(frozen=True, repr=True, kw_only=True)
@total_ordering
@attr.s(frozen=True, order=False, repr=True, kw_only=True)
class Particle: # pylint: disable=too-many-instance-attributes
"""Immutable container of data defining a physical particle.
Expand Down Expand Up @@ -194,6 +202,30 @@ def __attrs_post_init__(self) -> None:
")"
)

@property
def name_root(self) -> str:
name_root = self.name
name_root = re.sub(r"\(.+\)", "", name_root)
name_root = re.sub(r"[\*\+\-~\d']", "", name_root)
return name_root

def __gt__(self, other: Any) -> bool:
if isinstance(other, Particle):

def sorting_key(particle: Particle) -> tuple:
name_root = particle.name_root
return (
name_root[0].lower(),
name_root,
particle.mass,
particle.charge,
)

return sorting_key(self) > sorting_key(other)
raise NotImplementedError(
f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}"
)

def __neg__(self) -> "Particle":
return create_antiparticle(self)

Expand Down Expand Up @@ -386,8 +418,8 @@ def update(self, other: Iterable[Particle]) -> None:
self.add(particle)

@property
def names(self) -> Set[str]:
return set(self.__particles)
def names(self) -> List[str]:
return [p.name for p in sorted(self)]


def create_particle( # pylint: disable=too-many-arguments,too-many-locals
Expand Down
2 changes: 2 additions & 0 deletions src/qrules/quantum_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def __eq__(self, other: object) -> bool:
return self.value == other

def __gt__(self, other: Any) -> bool:
if other is None:
return True
return self.value > int(other)

def __int__(self) -> int:
Expand Down
6 changes: 3 additions & 3 deletions tests/channels/test_d0_to_ks_kp_km.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ def test_script():
number_of_threads=1,
)
assert len(result.transitions) == 5
assert result.get_intermediate_particles().names == {
"a(0)(980)+",
assert result.get_intermediate_particles().names == [
"a(0)(980)-",
"a(0)(980)0",
"a(0)(980)+",
"a(2)(1320)-",
"phi(1020)",
}
]
7 changes: 4 additions & 3 deletions tests/channels/test_jpsi_to_gamma_pi0_pi0.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
(
[
"f(0)(980)",
"f(0)(1500)",
"f(2)(1270)",
"f(0)(1500)",
"f(2)(1950)",
"omega(782)",
],
Expand All @@ -37,8 +37,9 @@ def test_number_of_solutions(
formalism="helicity",
)
assert len(result.transitions) == number_of_solutions
assert result.get_intermediate_particles().names == set(
allowed_intermediate_particles
assert (
result.get_intermediate_particles().names
== allowed_intermediate_particles
)


Expand Down
88 changes: 86 additions & 2 deletions tests/unit/test_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,84 @@ def test_eq(self):
assert particle.name != different_labels.name
assert particle.pid != different_labels.pid

@pytest.mark.parametrize(
("name1", "name2"),
[
# by name
("pi0", "a(0)(980)-"),
# by mass
("pi+", "pi-"),
("pi-", "pi0"),
("pi+", "pi0"),
("K0", "K+"),
# by charge
("a(0)(980)+", "a(0)(980)-"),
("a(0)(980)+", "a(0)(980)0"),
("a(0)(980)0", "a(0)(980)-"),
],
)
def test_gt(self, name1, name2, particle_database: ParticleCollection):
pdg = particle_database
assert pdg[name1] > pdg[name2]

def test_name_root(self, particle_database: ParticleCollection):
name_roots = {p.name_root for p in particle_database}
assert name_roots == {
"a",
"B",
"b",
"chi",
"D",
"Delta",
"e",
"eta",
"f",
"g",
"gamma",
"h",
"J/psi",
"K",
"Lambda",
"mu",
"N",
"n",
"nu",
"Omega",
"omega",
"p",
"phi",
"pi",
"psi",
"rho",
"Sigma",
"tau",
"Upsilon",
"W",
"Xi",
"Y",
"Z",
}

def test_neg(self, particle_database: ParticleCollection):
pip = particle_database.find(211)
pim = particle_database.find(-211)
assert pip == -pim

def test_total_ordering(self, particle_database: ParticleCollection):
pdg = particle_database
assert [
particle.name
for particle in sorted(
pdg.filter(lambda p: p.name.startswith("f(0)"))
)
] == [
"f(0)(500)",
"f(0)(980)",
"f(0)(1370)",
"f(0)(1500)",
"f(0)(1710)",
]


class TestParticleCollection:
def test_init(self, particle_database: ParticleCollection):
Expand Down Expand Up @@ -200,10 +273,10 @@ def test_filter(self, particle_database: ParticleCollection):
and p.spin == 2
and p.strangeness == 1
)
assert filtered_result.names == {
assert filtered_result.names == [
"K(2)(1820)0",
"K(2)(1820)+",
}
]

def test_find(self, particle_database: ParticleCollection):
f2_1950 = particle_database.find(9050225)
Expand Down Expand Up @@ -280,6 +353,17 @@ def test_hash(self):
spin2,
}

@pytest.mark.parametrize(
("spin1", "spin2"),
[
(Spin(1, 0), Spin(0, 0)),
(Spin(1, 1), Spin(1, 0)),
(Spin(1, +1), Spin(1, -1)),
],
)
def test_gt(self, spin1: Spin, spin2: Spin):
assert spin1 > spin2

def test_neg(self):
isospin = Spin(1.5, -0.5)
flipped_spin = -isospin
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_quantum_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def test_init_and_eq(self):
parity = Parity(+1)
assert parity == +1
assert int(parity) == +1
assert parity > None

@typing.no_type_check # https://github.com/python/mypy/issues/4610
def test_comparison(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
class TestResult:
def test_get_intermediate_state_names(self, result: Result):
intermediate_particles = result.get_intermediate_particles()
assert intermediate_particles.names == {"f(0)(1500)", "f(0)(980)"}
assert intermediate_particles.names == ["f(0)(980)", "f(0)(1500)"]

0 comments on commit c4cd302

Please sign in to comment.