Skip to content

Commit

Permalink
test: parametrize fixtures (#86)
Browse files Browse the repository at this point in the history
* test: add type hint pytest Config
* test: improve use of Input class
* test: improve qrules imports
* test: parametrize fixture for HelicityModel
* test: parametrize fixture for ReactionInfo
* test: rename count_parameters to n_parameters
  • Loading branch information
redeboer authored Jun 21, 2021
1 parent 0e0d3be commit d28271f
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 104 deletions.
44 changes: 12 additions & 32 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# pylint: disable=redefined-outer-name
import logging
from typing import Dict
from typing import Dict, Tuple

import numpy as np
import pytest
import qrules
from _pytest.config import Config
from _pytest.fixtures import SubRequest
from qrules import ParticleCollection, ReactionInfo, load_default_particles

from ampform import get_builder
Expand All @@ -21,53 +23,31 @@ def particle_database() -> ParticleCollection:


@pytest.fixture(scope="session")
def output_dir(pytestconfig) -> str:
def output_dir(pytestconfig: Config) -> str:
return f"{pytestconfig.rootpath}/tests/output/"


@pytest.fixture(scope="session")
def jpsi_to_gamma_pi_pi_canonical_solutions() -> ReactionInfo:
return qrules.generate_transitions(
initial_state=[("J/psi(1S)", [-1, 1])],
final_state=["gamma", "pi0", "pi0"],
allowed_intermediate_particles=["f(0)(980)", "f(0)(1500)"],
allowed_interaction_types="strong only",
formalism="canonical-helicity",
)


@pytest.fixture(scope="session")
def jpsi_to_gamma_pi_pi_helicity_solutions() -> ReactionInfo:
@pytest.fixture(scope="session", params=["canonical-helicity", "helicity"])
def reaction(request: SubRequest) -> ReactionInfo:
formalism: str = request.param
return qrules.generate_transitions(
initial_state=[("J/psi(1S)", [-1, 1])],
final_state=["gamma", "pi0", "pi0"],
allowed_intermediate_particles=["f(0)(980)", "f(0)(1500)"],
allowed_interaction_types="strong only",
formalism="helicity",
allowed_interaction_types="strong",
formalism=formalism,
)


@pytest.fixture(scope="session")
def jpsi_to_gamma_pi_pi_canonical_amplitude_model(
jpsi_to_gamma_pi_pi_canonical_solutions: ReactionInfo,
) -> HelicityModel:
return __create_model(jpsi_to_gamma_pi_pi_canonical_solutions)


@pytest.fixture(scope="session")
def jpsi_to_gamma_pi_pi_helicity_amplitude_model(
jpsi_to_gamma_pi_pi_helicity_solutions: ReactionInfo,
) -> HelicityModel:
return __create_model(jpsi_to_gamma_pi_pi_helicity_solutions)


def __create_model(reaction: ReactionInfo) -> HelicityModel:
def amplitude_model(reaction: ReactionInfo) -> Tuple[str, HelicityModel]:
model_builder = get_builder(reaction)
for name in reaction.get_intermediate_particles().names:
model_builder.set_dynamics(
name, create_relativistic_breit_wigner_with_ff
)
return model_builder.generate()
model = model_builder.generate()
return reaction.formalism, model


# https://github.com/ComPWA/tensorwaves/blob/3d0ec44/tests/physics/helicity_formalism/test_helicity_angles.py#L61-L98
Expand Down
7 changes: 4 additions & 3 deletions tests/test_angular_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import qrules
import sympy as sp
from qrules import ParticleCollection
from qrules.particle import Particle

from ampform import get_builder

Expand Down Expand Up @@ -57,13 +58,13 @@ def normalize(
class TestEpemToDmD0Pip:
@pytest.fixture(scope="class")
def sympy_model(self, particle_database: ParticleCollection) -> sp.Expr:
epem = qrules.particle.Particle(
epem = Particle(
name="EpEm",
pid=12345678,
mass=4.36,
spin=1.0,
parity=qrules.particle.Parity(-1),
c_parity=qrules.particle.Parity(-1),
parity=-1,
c_parity=-1,
)
particles = ParticleCollection(particle_database)
particles.add(epem)
Expand Down
36 changes: 13 additions & 23 deletions tests/test_dynamics.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,27 @@
# pylint: disable=no-self-use, too-many-arguments
from typing import Tuple

import numpy as np
import pytest
import qrules
import sympy as sp
from qrules import ParticleCollection
from sympy import preorder_traversal

from ampform.dynamics import ComplexSqrt
from ampform.helicity import HelicityModel


@pytest.mark.parametrize(
("formalism", "n_amplitudes", "n_parameters"),
[
("canonical", 16, 10),
("helicity", 8, 8),
],
)
def test_generate(
formalism: str,
n_amplitudes: int,
n_parameters: int,
jpsi_to_gamma_pi_pi_canonical_amplitude_model: HelicityModel,
jpsi_to_gamma_pi_pi_helicity_amplitude_model: HelicityModel,
particle_database: qrules.ParticleCollection,
amplitude_model: Tuple[str, HelicityModel],
particle_database: ParticleCollection,
):
if formalism == "canonical":
model = jpsi_to_gamma_pi_pi_canonical_amplitude_model
elif formalism == "helicity":
model = jpsi_to_gamma_pi_pi_helicity_amplitude_model
formalism, model = amplitude_model
if formalism == "canonical-helicity":
n_amplitudes = 16
n_parameters = 10
else:
raise NotImplementedError
n_amplitudes = 8
n_parameters = 8
assert len(model.parameter_defaults) == n_parameters
assert len(model.components) == 4 + n_amplitudes
assert len(model.expression.free_symbols) == 7 + n_parameters
Expand Down Expand Up @@ -81,14 +73,12 @@ def test_generate(
expression = sp.piecewise_fold(expression)
assert isinstance(expression, sp.Add)
a1, a2 = tuple(map(str, expression.args))
if formalism == "canonical":
if formalism == "canonical-helicity":
assert a1 == "0.08/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)"
assert a2 == "0.23/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)"
elif formalism == "helicity":
else:
assert a1 == "0.17/(-m**2 - 0.17*I*sqrt(m**2 - 0.07)/Abs(m) + 2.27)"
assert a2 == "0.06/(-m**2 - 0.06*I*sqrt(m**2 - 0.07)/Abs(m) + 0.98)"
else:
raise NotImplementedError


def round_nested(expression: sp.Expr, n_decimals: int) -> sp.Expr:
Expand Down
34 changes: 10 additions & 24 deletions tests/test_helicity.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,18 @@
import pytest
import sympy as sp
from qrules import ReactionInfo
from sympy import cos, sin, sqrt

from ampform import get_builder


@pytest.mark.parametrize(
("formalism", "n_amplitudes", "n_parameters"),
[
("canonical", 16, 4),
("helicity", 8, 2),
],
)
def test_generate(
formalism: str,
n_amplitudes: int,
n_parameters: int,
jpsi_to_gamma_pi_pi_canonical_solutions: ReactionInfo,
jpsi_to_gamma_pi_pi_helicity_solutions: ReactionInfo,
):
if formalism == "canonical":
reaction = jpsi_to_gamma_pi_pi_canonical_solutions
elif formalism == "helicity":
reaction = jpsi_to_gamma_pi_pi_helicity_solutions
def test_generate(reaction: ReactionInfo):
if reaction.formalism == "canonical-helicity":
n_amplitudes = 16
n_parameters = 4
else:
raise NotImplementedError
n_amplitudes = 8
n_parameters = 2

model = get_builder(reaction).generate()
assert len(model.parameter_defaults) == n_parameters
assert len(model.components) == 4 + n_amplitudes
Expand All @@ -39,15 +26,14 @@ def test_generate(
theta = sp.Symbol("theta", real=True)
no_dynamics = no_dynamics.subs({existing_theta: theta})
no_dynamics = no_dynamics.trigsimp()
if formalism == "canonical":

if reaction.formalism == "canonical-helicity":
assert (
no_dynamics
== 0.8 * sqrt(10) * cos(theta) ** 2
+ 4.4 * cos(theta) ** 2
+ 0.8 * sqrt(10)
+ 4.4
)
elif formalism == "helicity":
assert no_dynamics == 8.0 - 4.0 * sin(theta) ** 2
else:
raise NotImplementedError
assert no_dynamics == 8.0 - 4.0 * sin(theta) ** 2
41 changes: 19 additions & 22 deletions tests/test_parity_prefactor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import NamedTuple

import pytest
import qrules
from qrules import StateTransitionManager

from ampform import get_builder

Expand All @@ -14,44 +14,41 @@ class Input(NamedTuple):


@pytest.mark.parametrize(
("test_input", "parameter_count"),
("test_input", "n_parameters"),
[
(
Input(
[("Lambda(c)+", [0.5])],
["p", "K-", "pi+"],
["Lambda(1405)"],
[],
initial_state=[("Lambda(c)+", [0.5])],
final_state=["p", "K-", "pi+"],
intermediate_states=["Lambda(1405)"],
final_state_grouping=[],
),
2,
),
(
Input(
[("Lambda(c)+", [0.5])],
["p", "K-", "pi+"],
["Delta(1232)++"],
[],
initial_state=[("Lambda(c)+", [0.5])],
final_state=["p", "K-", "pi+"],
intermediate_states=["Delta(1232)++"],
final_state_grouping=[],
),
2,
),
(
Input(
[("Lambda(c)+", [0.5])],
["p", "K-", "pi+"],
["K*(892)0"],
[],
initial_state=[("Lambda(c)+", [0.5])],
final_state=["p", "K-", "pi+"],
intermediate_states=["K*(892)0"],
final_state_grouping=[],
),
4,
),
],
)
def test_parity_amplitude_coupling(
test_input: Input,
parameter_count: int,
) -> None:
stm = qrules.StateTransitionManager(
test_input.initial_state,
test_input.final_state,
def test_parity_amplitude_coupling(test_input: Input, n_parameters: int):
stm = StateTransitionManager(
initial_state=test_input.initial_state,
final_state=test_input.final_state,
allowed_intermediate_particles=test_input.intermediate_states,
number_of_threads=1,
)
Expand All @@ -60,4 +57,4 @@ def test_parity_amplitude_coupling(

model_builder = get_builder(reaction)
amplitude_model = model_builder.generate()
assert len(amplitude_model.parameter_defaults) == parameter_count
assert len(amplitude_model.parameter_defaults) == n_parameters

0 comments on commit d28271f

Please sign in to comment.