Skip to content

Commit

Permalink
BREAK: require n_events argument in lorentz classes (#381)
Browse files Browse the repository at this point in the history
* FEAT: make `ArraySize` publically available
  • Loading branch information
redeboer authored Dec 21, 2023
1 parent e341327 commit 2722118
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 54 deletions.
12 changes: 6 additions & 6 deletions docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from sympy.printing.numpy import NumPyPrinter

from ampform.io import aslatex
from ampform.kinematics.lorentz import FourMomentumSymbol, _ArraySize
from ampform.kinematics.lorentz import ArraySize, FourMomentumSymbol
from ampform.sympy._array_expressions import ArrayMultiplication

if sys.version_info < (3, 8):
Expand Down Expand Up @@ -127,7 +127,7 @@ def extend_BoostZMatrix() -> None:
)
b = sp.Symbol("b")
_append_code_rendering(
BoostZMatrix(b).doit(),
BoostZMatrix(b, n_events=ArraySize(b)).doit(),
use_cse=True,
docstring_class=BoostZMatrix,
)
Expand All @@ -147,9 +147,9 @@ def extend_BoostZMatrix() -> None:
)
p, beta, phi, theta = sp.symbols("p beta phi theta")
multiplication = ArrayMultiplication(
BoostZMatrix(beta, n_events=_ArraySize(p)),
RotationYMatrix(theta, n_events=_ArraySize(p)),
RotationZMatrix(phi, n_events=_ArraySize(p)),
BoostZMatrix(beta, n_events=ArraySize(p)),
RotationYMatrix(theta, n_events=ArraySize(p)),
RotationZMatrix(phi, n_events=ArraySize(p)),
p,
)
_append_to_docstring(
Expand Down Expand Up @@ -460,7 +460,7 @@ def extend_RotationZMatrix() -> None:
)
a = sp.Symbol("a")
_append_code_rendering(
RotationZMatrix(a).doit(),
RotationZMatrix(a, n_events=ArraySize(a)).doit(),
use_cse=True,
docstring_class=RotationZMatrix,
)
Expand Down
3 changes: 2 additions & 1 deletion docs/usage/kinematics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
"source": [
"from ampform.kinematics.lorentz import (\n",
" ArrayMultiplication,\n",
" ArraySize,\n",
" BoostZMatrix,\n",
" Energy,\n",
" FourMomentumSymbol,\n",
Expand All @@ -117,7 +118,7 @@
"p = FourMomentumSymbol(\"p\", shape=[])\n",
"q = FourMomentumSymbol(\"q\", shape=[])\n",
"beta = three_momentum_norm(p) / Energy(p)\n",
"Bz = BoostZMatrix(beta)\n",
"Bz = BoostZMatrix(beta, n_events=ArraySize(beta))\n",
"Bz_expr = ArrayMultiplication(Bz, q)\n",
"Bz_expr"
]
Expand Down
6 changes: 3 additions & 3 deletions src/ampform/kinematics/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from ampform.helicity.naming import get_helicity_angle_symbols, get_helicity_suffix
from ampform.kinematics.lorentz import (
ArraySize,
BoostMatrix,
BoostZMatrix,
Energy,
Expand All @@ -23,7 +24,6 @@
NegativeMomentum,
RotationYMatrix,
RotationZMatrix,
_ArraySize,
compute_boost_chain,
three_momentum_norm,
)
Expand Down Expand Up @@ -183,9 +183,9 @@ def __recursive_helicity_angles(
return __recursive_helicity_angles(four_momenta, initial_state_edge.ending_node_id)


def _get_number_of_events(four_momenta: Mapping[int, sp.Expr]) -> _ArraySize:
def _get_number_of_events(four_momenta: Mapping[int, sp.Expr]) -> ArraySize:
sorted_momentum_symbols = sorted(four_momenta.values(), key=str)
return _ArraySize(sorted_momentum_symbols[0])
return ArraySize(sorted_momentum_symbols[0])


def compute_wigner_angles(
Expand Down
24 changes: 7 additions & 17 deletions src/ampform/kinematics/lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,7 @@ class BoostZMatrix(UnevaluatedExpression):
:math:`n\times4\times4`. Defaults to the `len` of :code:`beta`.
"""

def __new__(
cls, beta: sp.Basic, n_events: sp.Expr | None = None, **kwargs
) -> BoostZMatrix:
if n_events is None:
n_events = _ArraySize(beta)
def __new__(cls, beta: sp.Basic, n_events: sp.Basic, **kwargs) -> BoostZMatrix:
return create_expression(cls, beta, n_events, **kwargs)

def as_explicit(self) -> sp.MutableDenseMatrix:
Expand Down Expand Up @@ -484,11 +480,7 @@ class RotationYMatrix(UnevaluatedExpression):
:math:`n\times4\times4`. Defaults to the `len` of :code:`angle`.
"""

def __new__(
cls, angle: sp.Basic, n_events: sp.Expr | None = None, **hints
) -> RotationYMatrix:
if n_events is None:
n_events = _ArraySize(angle)
def __new__(cls, angle: sp.Basic, n_events: sp.Basic, **hints) -> RotationYMatrix:
return create_expression(cls, angle, n_events, **hints)

def as_explicit(self) -> sp.MutableDenseMatrix:
Expand Down Expand Up @@ -555,11 +547,7 @@ class RotationZMatrix(UnevaluatedExpression):
:math:`n\times4\times4`. Defaults to the `len` of :code:`angle`.
"""

def __new__(
cls, angle: sp.Basic, n_events: sp.Expr | None = None, **hints
) -> RotationZMatrix:
if n_events is None:
n_events = _ArraySize(angle)
def __new__(cls, angle: sp.Basic, n_events: sp.Basic, **hints) -> RotationZMatrix:
return create_expression(cls, angle, n_events, **hints)

def as_explicit(self) -> sp.MutableDenseMatrix:
Expand Down Expand Up @@ -636,8 +624,10 @@ def _numpycode(self, printer: NumPyPrinter, *args) -> str:
return f"zeros({shape})"


class _ArraySize(NumPyPrintable):
def __new__(cls, array: sp.Basic, **kwargs) -> _ArraySize:
class ArraySize(NumPyPrintable):
"""Symbolic expression for getting the size of a numerical array."""

def __new__(cls, array: sp.Basic, **kwargs) -> ArraySize:
return create_expression(cls, array, **kwargs)

def _numpycode(self, printer: NumPyPrinter, *args) -> str:
Expand Down
52 changes: 25 additions & 27 deletions tests/kinematics/test_lorentz.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sympy.printing.numpy import NumPyPrinter

from ampform.kinematics.lorentz import (
ArraySize,
BoostMatrix,
BoostZMatrix,
Energy,
Expand All @@ -24,7 +25,6 @@
RotationYMatrix,
RotationZMatrix,
ThreeMomentum,
_ArraySize,
_OnesArray,
_ZerosArray,
compute_boost_chain,
Expand Down Expand Up @@ -58,7 +58,7 @@ def test_boost_in_z_direction_reduces_to_z_boost(self):
])

beta = three_momentum_norm(p) / Energy(p)
z_expr = BoostZMatrix(beta)
z_expr = BoostZMatrix(beta, n_events=ArraySize(p))
z_func = sp.lambdify(p, z_expr.doit(), cse=True)
z_matrix = z_func(p_array)[0]
assert pytest.approx(matrix) == z_matrix
Expand Down Expand Up @@ -104,7 +104,7 @@ def test_boosting_back_gives_original_momentum(
class TestBoostZMatrix:
def test_boost_into_own_rest_frame_gives_mass(self):
p = FourMomentumSymbol("p", shape=[])
n_events = _ArraySize(p)
n_events = ArraySize(p)
beta = three_momentum_norm(p) / Energy(p)
expr = BoostZMatrix(beta, n_events)
func = sp.lambdify(p, expr.doit(), cse=True)
Expand All @@ -122,9 +122,9 @@ def test_boost_into_own_rest_frame_gives_mass(self):
def test_numpycode_cse_in_expression_tree(self):
p, beta, phi, theta = sp.symbols("p beta phi theta")
expr = ArrayMultiplication(
BoostZMatrix(beta, n_events=_ArraySize(p)),
RotationYMatrix(theta, n_events=_ArraySize(p)),
RotationZMatrix(phi, n_events=_ArraySize(p)),
BoostZMatrix(beta, n_events=ArraySize(p)),
RotationYMatrix(theta, n_events=ArraySize(p)),
RotationZMatrix(phi, n_events=ArraySize(p)),
p,
)
func = sp.lambdify([], expr.doit(), cse=True)
Expand Down Expand Up @@ -246,27 +246,26 @@ def test_same_as_inverse(self, data_sample: dict[int, np.ndarray]):
class TestRotationYMatrix:
@pytest.fixture(scope="session")
def rotation_expr(self):
angle, n_events = sp.symbols("a n")
return RotationYMatrix(angle, n_events)
angle = sp.Symbol("a")
return RotationYMatrix(angle, n_events=ArraySize(angle))

@pytest.fixture(scope="session")
def rotation_func(self, rotation_expr):
def rotation_func(self, rotation_expr: RotationYMatrix):
angle = sp.Symbol("a")
rotation_expr = rotation_expr.doit()
rotation_expr = rotation_expr.subs(sp.Symbol("n"), _ArraySize(angle))
return sp.lambdify(angle, rotation_expr, cse=True)
return sp.lambdify(angle, rotation_expr.doit(), cse=True)

def test_numpycode_cse(self, rotation_expr: RotationYMatrix):
func = sp.lambdify([], rotation_expr.doit(), cse=True)
src = inspect.getsource(func)
expected_src = """
def _lambdifygenerated():
x0 = len(a)
return (array(
[
[ones(n), zeros(n), zeros(n), zeros(n)],
[zeros(n), cos(a), zeros(n), sin(a)],
[zeros(n), zeros(n), ones(n), zeros(n)],
[zeros(n), -sin(a), zeros(n), cos(a)],
[ones(x0), zeros(x0), zeros(x0), zeros(x0)],
[zeros(x0), cos(a), zeros(x0), sin(a)],
[zeros(x0), zeros(x0), ones(x0), zeros(x0)],
[zeros(x0), -sin(a), zeros(x0), cos(a)],
]
).transpose((2, 0, 1)))
"""
Expand All @@ -285,27 +284,26 @@ def test_rotation_over_pi_flips_xz(self, rotation_func):
class TestRotationZMatrix:
@pytest.fixture(scope="session")
def rotation_expr(self):
angle, n_events = sp.symbols("a n")
return RotationZMatrix(angle, n_events)
angle = sp.Symbol("a")
return RotationZMatrix(angle, n_events=ArraySize(angle))

@pytest.fixture(scope="session")
def rotation_func(self, rotation_expr):
def rotation_func(self, rotation_expr: RotationZMatrix):
angle = sp.Symbol("a")
rotation_expr = rotation_expr.doit()
rotation_expr = rotation_expr.subs(sp.Symbol("n"), _ArraySize(angle))
return sp.lambdify(angle, rotation_expr, cse=True)
return sp.lambdify(angle, rotation_expr.doit(), cse=True)

def test_numpycode_cse(self, rotation_expr: RotationZMatrix):
func = sp.lambdify([], rotation_expr.doit(), cse=True)
src = inspect.getsource(func)
expected_src = """
def _lambdifygenerated():
x0 = len(a)
return (array(
[
[ones(n), zeros(n), zeros(n), zeros(n)],
[zeros(n), cos(a), -sin(a), zeros(n)],
[zeros(n), sin(a), cos(a), zeros(n)],
[zeros(n), zeros(n), zeros(n), ones(n)],
[ones(x0), zeros(x0), zeros(x0), zeros(x0)],
[zeros(x0), cos(a), -sin(a), zeros(x0)],
[zeros(x0), sin(a), cos(a), zeros(x0)],
[zeros(x0), zeros(x0), zeros(x0), ones(x0)],
]
).transpose((2, 0, 1)))
"""
Expand All @@ -331,7 +329,7 @@ def test_rotation_latex_repr_is_identical_with_doit(rotation):
@pytest.mark.parametrize("rotation", [RotationYMatrix, RotationZMatrix])
def test_rotation_over_multiple_two_pi_is_identity(rotation):
angle = sp.Symbol("a")
expr = rotation(angle)
expr = rotation(angle, n_events=ArraySize(angle))
func = sp.lambdify(angle, expr.doit(), cse=True)
angle_array = np.arange(-2, 4, 1) * 2 * np.pi
rotation_matrices = func(angle_array)
Expand Down

0 comments on commit 2722118

Please sign in to comment.