diff --git a/src/autora/state.py b/src/autora/state.py index 75e11838..f959f92f 100644 --- a/src/autora/state.py +++ b/src/autora/state.py @@ -32,7 +32,6 @@ _logger = logging.getLogger(__name__) - T = TypeVar("T") C = TypeVar("C", covariant=True) @@ -356,6 +355,38 @@ def _get_value(f, other: Union[Delta, Mapping]): return value, used_key +def _get_field_names_and_properties(s: State): + """ + Get a list of field names and their aliases from a State object + + Args: + s: a State object + + Returns: a list of field names and their aliases on `s` + + Examples: + >>> from dataclasses import field + >>> @dataclass(frozen=True) + ... class SomeState(State): + ... l: List = field(default_factory=list) + ... m: List = field(default_factory=list) + ... @property + ... def both(self): + ... return self.l + self.m + >>> _get_field_names_and_properties(SomeState()) + ['both', 'l', 'm'] + """ + result = _get_field_names_and_aliases(s) + property_names = [ + attr + for attr in dir(s) + if isinstance(getattr(type(s), attr, None), property) + and attr not in dir(object) + and attr not in result + ] + return property_names + result + + def _get_field_names_and_aliases(s: State): """ Get a list of field names and their aliases from a State object @@ -692,19 +723,21 @@ def inputs_from_state(f, input_mapping: Dict = {}): ) @wraps(f) - def _f(state_: S, /, **kwargs) -> S: + def _f(state_: State, /, **kwargs) -> State: # Get the parameters needed which are available from the state_. # All others must be provided as kwargs or default values on f. assert is_dataclass(state_) or isinstance(state_, UserDict) if is_dataclass(state_): - from_state = parameters_.intersection({i.name for i in fields(state_)}) + from_state = parameters_.intersection( + _get_field_names_and_properties(state_) + ) arguments_from_state = {k: getattr(state_, k) for k in from_state} from_state_input_mapping = { - reversed_mapping.get(field.name, field.name): getattr( - state_, field.name + reversed_mapping.get(field_name, field_name): getattr( + state_, field_name ) - for field in fields(state_) - if reversed_mapping.get(field.name, field.name) in parameters_ + for field_name in _get_field_names_and_properties(state_) + if reversed_mapping.get(field_name, field_name) in parameters_ } arguments_from_state.update(from_state_input_mapping) elif isinstance(state_, UserDict): @@ -1219,6 +1252,36 @@ class StandardState(State): >>> (s + dm1 + dm2).models [DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)] + We can use properties X, y, iv_names and dv_names as 'getters' ... + >>> x_v = Variable('x') + >>> y_v = Variable('y') + >>> variables = VariableCollection(independent_variables=[x_v], dependent_variables=[y_v]) + >>> e_data = pd.DataFrame({'x': [1, 2, 3], 'y': [2, 4, 6]}) + >>> s = StandardState(variables=variables, experiment_data=e_data) + >>> @inputs_from_state + ... def show_X(X): + ... return X + >>> show_X(s) + x + 0 1 + 1 2 + 2 3 + + ... but nothing happens if we use them as `setters`: + >>> @on_state + ... def add_to_X(X): + ... res = X.copy() + ... res['x'] += 1 + ... return Delta(X=res) + >>> s = add_to_X(s) + >>> s.X + x + 0 1 + 1 2 + 2 3 + + + """ variables: Optional[VariableCollection] = field(