Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: set i/o types for function implementations #522

Merged
merged 2 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions benchmarks/ampform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from ampform.helicity import HelicityModel
from qrules.combinatorics import StateDefinition

from tensorwaves.function import ParametrizedBackendFunction
from tensorwaves.interface import (
DataSample,
FitResult,
Function,
ParameterValue,
ParametrizedFunction,
)
Expand Down Expand Up @@ -55,7 +57,7 @@ def formulate_amplitude_model(

def create_function(
model: HelicityModel, backend: str, max_complexity: int | None = None
) -> ParametrizedFunction:
) -> ParametrizedBackendFunction:
return create_parametrized_function(
expression=model.expression.doit(),
parameters=model.parameter_defaults,
Expand All @@ -66,7 +68,7 @@ def create_function(

def generate_data(
model: HelicityModel,
function: ParametrizedFunction,
function: Function[DataSample, np.ndarray],
data_sample_size: int,
phsp_sample_size: int,
backend: str,
Expand Down Expand Up @@ -103,7 +105,7 @@ def generate_data(
def fit(
data: DataSample,
phsp: DataSample,
function: ParametrizedFunction,
function: ParametrizedFunction[DataSample, np.ndarray],
initial_parameters: Mapping[str, ParameterValue],
backend: str,
) -> FitResult:
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _generate_domain(

def _generate_data(
size: int,
function: Function,
function: Function[DataSample, np.ndarray],
rng: np.random.Generator,
bunch_size: int = 10_000,
) -> DataSample:
Expand Down
2 changes: 1 addition & 1 deletion src/tensorwaves/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class IntensityDistributionGenerator(DataGenerator):
def __init__(
self,
domain_generator: DataGenerator,
function: Function,
function: Function[DataSample, np.ndarray],
domain_transformer: DataTransformer | None = None,
bunch_size: int = 50_000,
) -> None:
Expand Down
7 changes: 5 additions & 2 deletions src/tensorwaves/data/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ._attrs import to_tuple

if TYPE_CHECKING: # pragma: no cover
import numpy as np
import sympy as sp


Expand Down Expand Up @@ -55,7 +56,9 @@ def __call__(self, data: DataSample) -> DataSample:
class SympyDataTransformer(DataTransformer):
"""Implementation of a `.DataTransformer`."""

def __init__(self, functions: Mapping[str, Function]) -> None:
def __init__(
self, functions: Mapping[str, Function[DataSample, np.ndarray]]
) -> None:
if any(not isinstance(f, Function) for f in functions.values()):
msg = (
f"Not all values in the mapping are an instance of {Function.__name__}"
Expand All @@ -64,7 +67,7 @@ def __init__(self, functions: Mapping[str, Function]) -> None:
self.__functions = dict(functions)

@property
def functions(self) -> dict[str, Function]:
def functions(self) -> dict[str, Function[DataSample, np.ndarray]]:
"""Read-only access to the internal mapping of functions."""
return dict(self.__functions)

Expand Down
6 changes: 3 additions & 3 deletions src/tensorwaves/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def create_cached_function(
backend: str,
free_parameters: Iterable[sp.Symbol],
use_cse: bool = True,
) -> tuple[ParametrizedFunction, DataTransformer]:
) -> tuple[ParametrizedFunction[DataSample, np.ndarray], DataTransformer]:
"""Create a function and data transformer for cached computations.

Once it is known which parameters in an expression are to be optimized, this
Expand Down Expand Up @@ -118,7 +118,7 @@ class ChiSquared(Estimator):

def __init__( # noqa: PLR0913
self,
function: ParametrizedFunction,
function: ParametrizedFunction[DataSample, np.ndarray],
domain: DataSample,
observed_values: np.ndarray,
weights: np.ndarray | None = None,
Expand Down Expand Up @@ -185,7 +185,7 @@ class UnbinnedNLL(Estimator):

def __init__( # noqa: PLR0913
self,
function: ParametrizedFunction,
function: ParametrizedFunction[DataSample, np.ndarray],
data: DataSample,
phsp: DataSample,
phsp_volume: float = 1.0,
Expand Down
10 changes: 4 additions & 6 deletions src/tensorwaves/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Callable, Iterable, Mapping
from typing import Callable, Iterable, Mapping

import attrs
import numpy as np
from attrs import field, frozen

from tensorwaves.interface import (
Expand All @@ -15,9 +16,6 @@
ParametrizedFunction,
)

if TYPE_CHECKING:
import numpy as np


def _all_str(
_: PositionalArgumentFunction, __: attrs.Attribute, value: Iterable[str]
Expand Down Expand Up @@ -66,7 +64,7 @@ def _to_tuple(argument_order: Iterable[str]) -> tuple[str, ...]:


@frozen
class PositionalArgumentFunction(Function):
class PositionalArgumentFunction(Function[DataSample, np.ndarray]):
"""Wrapper around a function with positional arguments.

This class provides a :meth:`~.Function.__call__` that can take a `.DataSample` for
Expand All @@ -90,7 +88,7 @@ def __call__(self, data: DataSample) -> np.ndarray:
return self.function(*args)


class ParametrizedBackendFunction(ParametrizedFunction):
class ParametrizedBackendFunction(ParametrizedFunction[DataSample, np.ndarray]):
"""Implements `.ParametrizedFunction` for a specific computational back-end.

.. seealso:: :func:`.create_parametrized_function`
Expand Down
2 changes: 1 addition & 1 deletion src/tensorwaves/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __call__(self, data: InputType) -> OutputType: ...
"""Allowed types for parameter values."""


class ParametrizedFunction(Function[DataSample, np.ndarray]):
class ParametrizedFunction(Function[InputType, OutputType]):
"""Interface of a callable function.

A `ParametrizedFunction` identifies certain variables in a mathematical expression
Expand Down
2 changes: 1 addition & 1 deletion tests/optimizer/test_fit_simple_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def generate_domain(
def generate_data(
size: int,
boundaries: dict[str, tuple[float, float]],
function: Function,
function: Function[DataSample, np.ndarray],
rng: np.random.Generator,
bunch_size: int = 10_000,
) -> DataSample:
Expand Down
Loading