Skip to content

Commit

Permalink
feat: make wrappers work on added properties
Browse files Browse the repository at this point in the history
  • Loading branch information
younesStrittmatter committed Aug 9, 2024
1 parent 312ed85 commit c3f88ba
Showing 1 changed file with 56 additions and 7 deletions.
63 changes: 56 additions & 7 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

_logger = logging.getLogger(__name__)


T = TypeVar("T")
C = TypeVar("C", covariant=True)

Expand Down Expand Up @@ -309,6 +308,38 @@ def _get_value(f, other: Union[Delta, Mapping]):
return value, used_key


def _get_field_names_and_properties(s: State):
"""
Get a list of field names and their aliases from a State object
Args:
s: a State object
Returns: a list of field names and their aliases on `s`
Examples:
>>> from dataclasses import field
>>> @dataclass(frozen=True)
... class SomeState(State):
... l: List = field(default_factory=list)
... m: List = field(default_factory=list)
... @property
... def both(self):
... return self.l + self.m
>>> _get_field_names_and_properties(SomeState())
['both', 'l', 'm']
"""
result = _get_field_names_and_aliases(s)
property_names = [
attr
for attr in dir(s)
if isinstance(getattr(type(s), attr, None), property)
and attr not in dir(object)
and attr not in result
]
return property_names + result


def _get_field_names_and_aliases(s: State):
"""
Get a list of field names and their aliases from a State object
Expand Down Expand Up @@ -645,19 +676,21 @@ def inputs_from_state(f, input_mapping: Dict = {}):
)

@wraps(f)
def _f(state_: S, /, **kwargs) -> S:
def _f(state_: State, /, **kwargs) -> State:
# Get the parameters needed which are available from the state_.
# All others must be provided as kwargs or default values on f.
assert is_dataclass(state_) or isinstance(state_, UserDict)
if is_dataclass(state_):
from_state = parameters_.intersection({i.name for i in fields(state_)})
from_state = parameters_.intersection(
_get_field_names_and_properties(state_)
)
arguments_from_state = {k: getattr(state_, k) for k in from_state}
from_state_input_mapping = {
reversed_mapping.get(field.name, field.name): getattr(
state_, field.name
reversed_mapping.get(field_name, field_name): getattr(
state_, field_name
)
for field in fields(state_)
if reversed_mapping.get(field.name, field.name) in parameters_
for field_name in _get_field_names_and_properties(state_)
if reversed_mapping.get(field_name, field_name) in parameters_
}
arguments_from_state.update(from_state_input_mapping)
elif isinstance(state_, UserDict):
Expand Down Expand Up @@ -1172,6 +1205,22 @@ class StandardState(State):
>>> (s + dm1 + dm2).models
[DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)]
>>> x_v = Variable('x')
>>> y_v = Variable('y')
>>> variables = VariableCollection(independent_variables=[x_v], dependent_variables=[y_v])
>>> e_data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 6]})
>>> s = StandardState(variables=variables, experiment_data=e_data)
>>> @inputs_from_state
... def add_six(X):
... return X
>>> add_six(s)
x
0 1
1 2
2 3
"""

variables: Optional[VariableCollection] = field(
Expand Down

0 comments on commit c3f88ba

Please sign in to comment.