diff --git a/docs/usage/dynamics/k-matrix.ipynb b/docs/usage/dynamics/k-matrix.ipynb index b401194e2..24dd7bd89 100644 --- a/docs/usage/dynamics/k-matrix.ipynb +++ b/docs/usage/dynamics/k-matrix.ipynb @@ -642,9 +642,9 @@ "outputs": [], "source": [ "# reformulate terms\n", - "denominator, nominator = k_matrix.args\n", - "term1 = nominator.args[0] * denominator\n", - "term2 = nominator.args[1] * denominator\n", + "*rest, denominator, nominator = k_matrix.args\n", + "term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n", + "term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n", "k_matrix = term1 + term2\n", "k_matrix" ] @@ -934,9 +934,9 @@ " sp.sqrt(rho): 1,\n", " sp.conjugate(sp.sqrt(rho)): 1,\n", "})\n", - "denominator, nominator = rel_k_matrix_2r.args\n", - "term1 = nominator.args[0] * denominator\n", - "term2 = nominator.args[1] * denominator\n", + "*rest, denominator, nominator = rel_k_matrix_2r.args\n", + "term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n", + "term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n", "rel_k_matrix_2r = term1 + term2\n", "rel_k_matrix_2r" ] @@ -1081,9 +1081,9 @@ }, "outputs": [], "source": [ - "denominator, nominator = f_vector.args\n", - "term1 = nominator.args[0] * denominator\n", - "term2 = nominator.args[1] * denominator\n", + "*rest, denominator, nominator = f_vector.args\n", + "term1 = nominator.args[0] * denominator * sp.Mul(*rest)\n", + "term2 = nominator.args[1] * denominator * sp.Mul(*rest)\n", "f_vector = term1 + term2\n", "f_vector" ] diff --git a/tests/dynamics/test_kmatrix.py b/tests/dynamics/test_kmatrix.py index 0d1f1ddf2..bd37c3baa 100644 --- a/tests/dynamics/test_kmatrix.py +++ b/tests/dynamics/test_kmatrix.py @@ -1,16 +1,13 @@ from __future__ import annotations import re -from typing import TYPE_CHECKING import pytest +import sympy as sp from ampform.dynamics.kmatrix import NonRelativisticKMatrix from symplot import rename_symbols, substitute_indexed_symbols -if TYPE_CHECKING: - import sympy as sp - class TestNonRelativisticKMatrix: @pytest.mark.parametrize( @@ -35,9 +32,9 @@ def test_interference_single_channel(self): expr = substitute_indexed_symbols(expr) expr = _remove_residue_constants(expr) expr = _rename_widths(expr) - denominator, nominator = expr.args - term1 = nominator.args[0] * denominator - term2 = nominator.args[1] * denominator + *rest, denominator, nominator = expr.args + term1 = nominator.args[0] * denominator * sp.Mul(*rest) + term2 = nominator.args[1] * denominator * sp.Mul(*rest) assert str(term1 / term2) == R"m1*w1*(m2**2 - s)/(m2*w2*(m1**2 - s))"