Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: remove aliases from state #63

Merged
merged 8 commits into from
Dec 1, 2023
190 changes: 6 additions & 184 deletions src/autora/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,47 +212,6 @@ class State:
We can define aliases which can transform between different potential field
names.

>>> @dataclass(frozen=True)
... class FieldAliasState(State):
... things: List[str] = field(
... default_factory=list,
... metadata={"delta": "extend",
... "aliases": {"thing": lambda m: [m]}}
... )

In the "normal" case, the Delta object is expected to include a list of data in the
correct format which is used to extend the object:
>>> FieldAliasState(things=["0"]) + Delta(things=["1", "2"])
FieldAliasState(things=['0', '1', '2'])

However, say the standard return from a step in AER is a single `thing`, rather than a
sequence of them:
>>> FieldAliasState(things=["0"]) + Delta(thing="1")
FieldAliasState(things=['0', '1'])


If a cycle function relies on the existence of the `s.thing` as a property of your state
`s`, rather than accessing `s.things[-1]`, then you could additionally define a `property`:

>>> class FieldAliasStateWithProperty(FieldAliasState): # inherit from FieldAliasState
... @property
... def thing(self):
... return self.things[-1]

Now you can access both `s.things` and `s.thing` as required by your code. The State only
shows `things` in the string representation...
>>> u = FieldAliasStateWithProperty(things=["0"]) + Delta(thing="1")
>>> u
FieldAliasStateWithProperty(things=['0', '1'])

... and exposes `things` as an attribute:
>>> u.things
['0', '1']

... but also exposes `thing`, always returning the last value.
>>> u.thing
'1'

"""

def __add__(self, other: Union[Delta, Mapping]):
Expand Down Expand Up @@ -322,12 +281,7 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> from dataclasses import field, dataclass, fields
>>> @dataclass
... class Example:
... a: int = field() # base case
... b: List[int] = field(metadata={"aliases": {"ba": lambda b: [b]}}) # Single alias
... c: List[int] = field(metadata={"aliases": {
... "ca": lambda x: x, # pass the value unchanged
... "cb": lambda x: [x] # wrap the value in a list
... }}) # Multiple alias
... a: int = field()

For a field with no aliases, we retrieve values with the base name:
>>> f_a = fields(Example)[0]
Expand All @@ -342,104 +296,15 @@ def _get_value(f, other: Union[Delta, Mapping]):
>>> _get_value(f_a, Delta(b=2, a=1))
(1, 'a')

For fields with an alias, we retrieve values with the base name:
>>> f_b = fields(Example)[1]
>>> _get_value(f_b, Delta(b=[2]))
([2], 'b')

... or for the alias name, transformed by the alias lambda function:
>>> _get_value(f_b, Delta(ba=21))
([21], 'ba')

We preferentially get the base name, and then any aliases:
>>> _get_value(f_b, Delta(b=2, ba=21))
(2, 'b')

... , regardless of their order in the `Delta` object:
>>> _get_value(f_b, Delta(ba=21, b=2))
(2, 'b')

Other names are ignored:
>>> _get_value(f_b, Delta(a=1))
(None, None)

and the order of other names is unimportant:
>>> _get_value(f_b, Delta(a=1, b=2))
(2, 'b')

For fields with multiple aliases, we retrieve values with the base name:
>>> f_c = fields(Example)[2]
>>> _get_value(f_c, Delta(c=[3]))
([3], 'c')

... for any alias:
>>> _get_value(f_c, Delta(ca=31))
(31, 'ca')

... transformed by the alias lambda function :
>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')

... and ignoring any other names:
>>> print(_get_value(f_c, Delta(a=1)))
(None, None)

... preferentially in the order base name, 1st alias, 2nd alias, ... nth alias:
>>> _get_value(f_c, Delta(c=3, ca=31, cb=32))
(3, 'c')

>>> _get_value(f_c, Delta(ca=31, cb=32))
(31, 'ca')

>>> _get_value(f_c, Delta(cb=32))
([32], 'cb')

>>> print(_get_value(f_c, Delta()))
(None, None)

This works with dict objects:
>>> _get_value(f_a, dict(a=13))
(13, 'a')

... with multiple keys:
>>> _get_value(f_b, dict(a=13, b=24, c=35))
(24, 'b')

... and with aliases:
>>> _get_value(f_b, dict(ba=222))
([222], 'ba')

This works with UserDicts:
>>> class MyDelta(UserDict):
... pass

>>> _get_value(f_a, MyDelta(a=14))
(14, 'a')

... with multiple keys:
>>> _get_value(f_b, MyDelta(a=1, b=4, c=9))
(4, 'b')

... and with aliases:
>>> _get_value(f_b, MyDelta(ba=234))
([234], 'ba')

"""

key = f.name
aliases = f.metadata.get("aliases", {})

value, used_key = None, None

if key in other.keys():
value = other[key]
used_key = key
elif aliases: # ... is not an empty dict
for alias_key, wrapping_function in aliases.items():
if alias_key in other:
value = wrapping_function(other[alias_key])
used_key = alias_key
break # we only evaluate the first match

return value, used_key

Expand All @@ -462,23 +327,8 @@ def _get_field_names_and_aliases(s: State):
>>> _get_field_names_and_aliases(SomeState())
['l', 'm']

>>> @dataclass(frozen=True)
... class SomeStateWithAliases(State):
... l: List = field(default_factory=list, metadata={"aliases": {"l1": None, "l2": None}})
... m: List = field(default_factory=list, metadata={"aliases": {"m1": None}})
>>> _get_field_names_and_aliases(SomeStateWithAliases())
['l', 'l1', 'l2', 'm', 'm1']

"""
result = []

for f in fields(s):
name = f.name
result.append(name)

aliases = f.metadata.get("aliases", {})
result.extend(aliases)

result = [f.name for f in fields(s)]
return result


Expand Down Expand Up @@ -1322,27 +1172,6 @@ class StandardState(State):
>>> (s + dm1 + dm2).models
[DummyClassifier(constant=1), DummyClassifier(constant=2), DummyClassifier(constant=3)]

The last model is available under the `model` property:
>>> (s + dm1 + dm2).model
DummyClassifier(constant=3)

If there is no model, `None` is returned:
>>> print(s.model)
None

`models` can also be updated using a Delta with a single `model`:
>>> dm3 = Delta(model=DummyClassifier(constant=4))
>>> (s + dm1 + dm3).model
DummyClassifier(constant=4)

As before, the `models` list is extended:
>>> (s + dm1 + dm3).models
[DummyClassifier(constant=1), DummyClassifier(constant=4)]

No coercion or validation occurs with `models` or `model`:
>>> (s + dm1 + Delta(model="not a model")).models
[DummyClassifier(constant=1), 'not a model']

"""

variables: Optional[VariableCollection] = field(
Expand All @@ -1356,17 +1185,9 @@ class StandardState(State):
)
models: List[BaseEstimator] = field(
default_factory=list,
metadata={"delta": "extend", "aliases": {"model": lambda model: [model]}},
metadata={"delta": "extend"},
)

@property
def model(self):
"""Alias for the last model in the `models`."""
try:
return self.models[-1]
except IndexError:
return None


X = TypeVar("X")
Y = TypeVar("Y")
Expand Down Expand Up @@ -1396,8 +1217,9 @@ def estimator_on_state(estimator: BaseEstimator) -> StateFunction:
... experiment_data=pd.DataFrame({"x": [1,2,3], "y":[3,6,9]})
... )

Run the function, which fits the model and adds the result to the `StandardState`
>>> state_fn(s).model.coef_
Run the function, which fits the model and adds the result to the `StandardState` as the
last entry in the .models list.
>>> state_fn(s).models[-1].coef_
array([[3.]])

"""
Expand Down
6 changes: 3 additions & 3 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def validate_model(state: Optional[StandardState]):
assert state.experiment_data is not None
assert len(state.experiment_data) == 100

assert state.model is not None
assert np.allclose(state.model.coef_, [[2.0]])
assert np.allclose(state.model.intercept_, [[0.5]])
assert state.models[-1] is not None
assert np.allclose(state.models[-1].coef_, [[2.0]])
assert np.allclose(state.models[-1].intercept_, [[0.5]])


@given(example_workflow_library_module)
Expand Down