Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: simplify numpy code of matrix expressions #232

Merged
merged 21 commits into from
Feb 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
e6af65b
fix: run qrules single-threaded in tests
redeboer Feb 7, 2022
5994c9a
test: check lambdify src code RotationZMatrix
redeboer Feb 7, 2022
dd8cd20
test: add runtime tests RotationZMatrix
redeboer Feb 7, 2022
4a0a0d4
test: perform same tests on RotationYMatrix
redeboer Feb 7, 2022
de462d3
refactor: wrap zeros length in NumPyPrintable
redeboer Feb 7, 2022
52418dd
fix: render _numpycode() output as python block
redeboer Feb 7, 2022
757a2a5
refactor: remove angle property from Rotation classes
redeboer Feb 7, 2022
e491a7a
refactor: remove beta property from BoostZMatrix
redeboer Feb 7, 2022
d13f820
docs: automatically widen numpycode code-blocks
redeboer Feb 7, 2022
ed443ee
refactor: extract _BoostZMatrixImplementation
redeboer Feb 7, 2022
10f7d40
refactor!: add n_events argument to Matrix expr classes"
redeboer Feb 7, 2022
634a824
feat: implmenent cse_all_symbols()
redeboer Feb 7, 2022
0768b95
fix: make cse_all_symbols() suitable for lambdify()
redeboer Feb 7, 2022
fc07dd6
docs: extend codeautolink_global_preface
redeboer Feb 7, 2022
09a159e
refactor: extend number of args in Matrix implementation classes
redeboer Feb 7, 2022
88ee324
fix: render n_events as len(arg) in API
redeboer Feb 7, 2022
7ddc864
test: check numpycode with cse in nested expression
redeboer Feb 7, 2022
c1616cd
refactor: remove cse_all_symbols()
redeboer Feb 7, 2022
a9abb53
fix: remove redundant transpose import
redeboer Feb 7, 2022
f867408
docs: link to BoostZMatrix from RotationZMatrix
redeboer Feb 7, 2022
d47adbb
refactor: make n_events argument optional
redeboer Feb 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@
"kmatrix",
"kutschke",
"kwargs",
"lambdifygenerated",
"lambdifying",
"linestyle",
"linewidth",
"linkcheck",
Expand Down
109 changes: 91 additions & 18 deletions docs/_extend_docstrings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import sympy as sp
from sympy.printing.numpy import NumPyPrinter

from ampform.kinematics import FourMomentumSymbol
from ampform.kinematics import FourMomentumSymbol, _ArraySize
from ampform.sympy import NumPyPrintable
from ampform.sympy._array_expressions import ArrayMultiplication

logging.getLogger().setLevel(logging.ERROR)

Expand Down Expand Up @@ -60,8 +62,8 @@ def extend_BlattWeisskopfSquared() -> None:
def extend_BoostZMatrix() -> None:
from ampform.kinematics import BoostZMatrix

beta = sp.Symbol("beta")
expr = BoostZMatrix(beta)
beta, n_events = sp.symbols("beta n")
expr = BoostZMatrix(beta, n_events)
_append_to_docstring(
BoostZMatrix,
f"""\n
Expand All @@ -81,7 +83,45 @@ def extend_BoostZMatrix() -> None:
""",
)
b = sp.Symbol("b")
_append_code_rendering(BoostZMatrix(b))
_append_code_rendering(
BoostZMatrix(b).doit(),
use_cse=True,
docstring_class=BoostZMatrix,
)

from ampform.kinematics import RotationYMatrix, RotationZMatrix

_append_to_docstring(
BoostZMatrix,
"""
Note that this code was generated with :func:`sympy.lambdify
<sympy.utilities.lambdify.lambdify>` with :code:`cse=True`. The repetition
of :func:`numpy.ones` is still bothersome, but these sub-nodes is also
extracted by :func:`sympy.cse <sympy.simplify.cse_main.cse>` if the
expression is nested further down in an :doc:`expression tree
<sympy:tutorial/manipulation>`, for instance when boosting a
`.FourMomentumSymbol` :math:`p` in the :math:`z`-direction:
""",
)
p, beta, phi, theta = sp.symbols("p beta phi theta")
expr = ArrayMultiplication(
BoostZMatrix(beta, n_events=_ArraySize(p)),
RotationYMatrix(theta, n_events=_ArraySize(p)),
RotationZMatrix(phi, n_events=_ArraySize(p)),
p,
)
_append_to_docstring(
BoostZMatrix,
f"""\n
.. math:: {sp.latex(expr)}
:label: boost-in-z-direction

which in :mod:`numpy` code becomes:
""",
)
_append_code_rendering(
expr.doit(), use_cse=True, docstring_class=BoostZMatrix
)


def extend_BreakupMomentumSquared() -> None:
Expand Down Expand Up @@ -235,8 +275,8 @@ def extend_Phi() -> None:
def extend_RotationYMatrix() -> None:
from ampform.kinematics import RotationYMatrix

angle = sp.Symbol("alpha")
expr = RotationYMatrix(angle)
angle, n_events = sp.symbols("alpha n")
expr = RotationYMatrix(angle, n_events)
_append_to_docstring(
RotationYMatrix,
f"""\n
Expand All @@ -254,8 +294,8 @@ def extend_RotationYMatrix() -> None:
def extend_RotationZMatrix() -> None:
from ampform.kinematics import RotationZMatrix

angle = sp.Symbol("alpha")
expr = RotationZMatrix(angle)
angle, n_events = sp.symbols("alpha n")
expr = RotationZMatrix(angle, n_events)
_append_to_docstring(
RotationZMatrix,
f"""\n
Expand All @@ -276,7 +316,17 @@ def extend_RotationZMatrix() -> None:
""",
)
a = sp.Symbol("a")
_append_code_rendering(RotationZMatrix(a))
_append_code_rendering(
RotationZMatrix(a).doit(),
use_cse=True,
docstring_class=RotationZMatrix,
)
_append_to_docstring(
RotationZMatrix,
"""
See also the note that comes with Equation :eq:`boost-in-z-direction`.
""",
)


def extend_Theta() -> None:
Expand Down Expand Up @@ -423,19 +473,42 @@ def extend_relativistic_breit_wigner_with_ff() -> None:
)


def _append_code_rendering(expr: sp.Expr) -> None:
def _append_code_rendering(
expr: NumPyPrintable,
use_cse: bool = False,
docstring_class: Optional[type] = None,
) -> None:
printer = NumPyPrinter()
numpy_code = expr._numpycode(printer)
if use_cse:
args = sorted(expr.free_symbols, key=str)
func = sp.lambdify(args, expr, cse=True, printer=printer)
numpy_code = inspect.getsource(func)
else:
numpy_code = expr._numpycode(printer)
import_statements = __print_imports(printer)
_append_to_docstring(
type(expr),
f"""\n
.. code::

if docstring_class is None:
docstring_class = type(expr)
numpy_code = textwrap.dedent(numpy_code)
numpy_code = textwrap.indent(numpy_code, prefix=8 * " ").strip()
options = ""
if (
max(__get_text_width(import_statements), __get_text_width(numpy_code))
> 90
):
options += ":class: full-width\n"
appended_text = f"""\n
.. code-block:: python
{options}
{import_statements}
{numpy_code}
""",
)
"""
_append_to_docstring(docstring_class, appended_text)


def __get_text_width(text: str) -> int:
lines = text.split("\n")
widths = map(len, lines)
return max(widths)


def _append_latex_doit_definition(
Expand Down
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def fetch_logo(url: str, output_path: str) -> None:
autodoc_typehints_format = "short"
codeautolink_concat_default = True
codeautolink_global_preface = """
import numpy
import numpy as np
import sympy as sp
from IPython.display import display
"""
AUTODOC_INSERT_SIGNATURE_LINEBREAKS = False
Expand Down
Loading