Skip to content

Commit

Permalink
BRAIN-15567 - Implement from_state() factory method for all models (#26)
Browse files Browse the repository at this point in the history
### Changes

* Smab and Cmabs expose a factoy method to be created with a state object generated by a `get_state()` of an instance of the same class.
  • Loading branch information
adarmiento authored Nov 13, 2023
1 parent b0a9555 commit 54e504e
Show file tree
Hide file tree
Showing 4 changed files with 379 additions and 2 deletions.
12 changes: 12 additions & 0 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ class CmabBernoulli(BaseCmabBernoulli):
def __init__(self, actions: Dict[ActionId, BaseBayesianLogisticRegression]):
super().__init__(actions=actions, strategy=ClassicBandit())

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulli":
return cls(actions=state["actions"])

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(context=context, actions=actions, rewards=rewards)
Expand Down Expand Up @@ -249,6 +253,10 @@ def __init__(self, actions: Dict[ActionId, BayesianLogisticRegression], exploit_
strategy = BestActionIdentification() if exploit_p is None else BestActionIdentification(exploit_p=exploit_p)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulliBAI":
return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None))

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(context=context, actions=actions, rewards=rewards)
Expand Down Expand Up @@ -292,6 +300,10 @@ def __init__(self, actions: Dict[ActionId, BayesianLogisticRegressionCC], subsid
strategy = CostControlBandit() if subsidy_factor is None else CostControlBandit(subsidy_factor=subsidy_factor)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulliCC":
return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None))

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(context=context, actions=actions, rewards=rewards)
Expand Down
20 changes: 20 additions & 0 deletions pybandits/smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ class SmabBernoulli(BaseSmabBernoulli):
def __init__(self, actions: Dict[ActionId, Beta]):
super().__init__(actions=actions, strategy=ClassicBandit())

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulli":
return cls(actions=state["actions"])

@validate_arguments
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(actions=actions, rewards=rewards)
Expand Down Expand Up @@ -174,6 +178,10 @@ def __init__(self, actions: Dict[ActionId, Beta], exploit_p: Optional[Float01] =
strategy = BestActionIdentification() if exploit_p is None else BestActionIdentification(exploit_p=exploit_p)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliBAI":
return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None))

@validate_arguments
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(actions=actions, rewards=rewards)
Expand Down Expand Up @@ -209,6 +217,10 @@ def __init__(self, actions: Dict[ActionId, BetaCC], subsidy_factor: Optional[Flo
strategy = CostControlBandit() if subsidy_factor is None else CostControlBandit(subsidy_factor=subsidy_factor)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliCC":
return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None))

@validate_arguments
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(actions=actions, rewards=rewards)
Expand Down Expand Up @@ -269,6 +281,10 @@ class SmabBernoulliMO(BaseSmabBernoulliMO):
def __init__(self, actions: Dict[ActionId, Beta]):
super().__init__(actions=actions, strategy=MultiObjectiveBandit())

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliMO":
return cls(actions=state["actions"])


class SmabBernoulliMOCC(BaseSmabBernoulliMO):
"""
Expand All @@ -292,6 +308,10 @@ class SmabBernoulliMOCC(BaseSmabBernoulliMO):
def __init__(self, actions: Dict[ActionId, Beta]):
super().__init__(actions=actions, strategy=MultiObjectiveCostControlBandit())

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliMOCC":
return cls(actions=state["actions"])


@validate_arguments
def create_smab_bernoulli_cold_start(action_ids: Set[ActionId]) -> SmabBernoulli:
Expand Down
160 changes: 160 additions & 0 deletions tests/test_cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,53 @@ def test_cmab_get_state(mu, sigma, n_features):
assert is_serializable(cmab_state), "Internal state is not serializable"


@settings(deadline=500)
@given(
state=st.fixed_dictionaries(
{
"actions": st.dictionaries(
keys=st.text(min_size=1, max_size=10),
values=st.fixed_dictionaries(
{
"alpha": st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
"betas": st.lists(
st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
min_size=3,
max_size=3,
),
},
),
min_size=2,
),
"strategy": st.fixed_dictionaries({}),
}
)
)
def test_cmab_from_state(state):
cmab = CmabBernoulli.from_state(state)
assert isinstance(cmab, CmabBernoulli)

expected_actions = state["actions"]
actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict
assert expected_actions == actual_actions

# Ensure get_state and from_state compatibility
new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1])
assert new_cmab == cmab


########################################################################################################################


Expand Down Expand Up @@ -512,6 +559,62 @@ def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01):
assert is_serializable(cmab_state), "Internal state is not serializable"


@settings(deadline=500)
@given(
state=st.fixed_dictionaries(
{
"actions": st.dictionaries(
keys=st.text(min_size=1, max_size=10),
values=st.fixed_dictionaries(
{
"alpha": st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
"betas": st.lists(
st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
min_size=3,
max_size=3,
),
},
),
min_size=2,
),
"strategy": st.one_of(
st.just({}),
st.just({"exploit_p": None}),
st.builds(lambda x: {"exploit_p": x}, st.floats(min_value=0, max_value=1)),
),
}
)
)
def test_cmab_bai_from_state(state):
cmab = CmabBernoulliBAI.from_state(state)
assert isinstance(cmab, CmabBernoulliBAI)

expected_actions = state["actions"]
actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict
assert expected_actions == actual_actions
expected_exploit_p = (
state["strategy"].get("exploit_p", 0.5) if state["strategy"].get("exploit_p") is not None else 0.5
) # Covers both not existing and existing + None
actual_exploit_p = cmab.strategy.exploit_p
assert expected_exploit_p == actual_exploit_p

# Ensure get_state and from_state compatibility
new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1])
assert new_cmab == cmab


########################################################################################################################


Expand Down Expand Up @@ -697,3 +800,60 @@ def test_cmab_cc_get_state(
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"


@settings(deadline=500)
@given(
state=st.fixed_dictionaries(
{
"actions": st.dictionaries(
keys=st.text(min_size=1, max_size=10),
values=st.fixed_dictionaries(
{
"alpha": st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
"betas": st.lists(
st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
min_size=3,
max_size=3,
),
"cost": st.floats(min_value=0),
},
),
min_size=2,
),
"strategy": st.one_of(
st.just({}),
st.just({"subsidy_factor": None}),
st.builds(lambda x: {"subsidy_factor": x}, st.floats(min_value=0, max_value=1)),
),
}
)
)
def test_cmab_cc_from_state(state):
cmab = CmabBernoulliCC.from_state(state)
assert isinstance(cmab, CmabBernoulliCC)

expected_actions = state["actions"]
actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict
assert expected_actions == actual_actions
expected_subsidy_factor = (
state["strategy"].get("subsidy_factor", 0.5) if state["strategy"].get("subsidy_factor") is not None else 0.5
) # Covers both not existing and existing + None
actual_subsidy_factor = cmab.strategy.subsidy_factor
assert expected_subsidy_factor == actual_subsidy_factor

# Ensure get_state and from_state compatibility
new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1])
assert new_cmab == cmab
Loading

0 comments on commit 54e504e

Please sign in to comment.