Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BRAIN-15567 - Implement from_state() factory method for all models #26

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pybandits/cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ class CmabBernoulli(BaseCmabBernoulli):
def __init__(self, actions: Dict[ActionId, BaseBayesianLogisticRegression]):
super().__init__(actions=actions, strategy=ClassicBandit())

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulli":
return cls(actions=state["actions"])

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(context=context, actions=actions, rewards=rewards)
Expand Down Expand Up @@ -249,6 +253,10 @@ def __init__(self, actions: Dict[ActionId, BayesianLogisticRegression], exploit_
strategy = BestActionIdentification() if exploit_p is None else BestActionIdentification(exploit_p=exploit_p)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulliBAI":
return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None))

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(context=context, actions=actions, rewards=rewards)
Expand Down Expand Up @@ -292,6 +300,10 @@ def __init__(self, actions: Dict[ActionId, BayesianLogisticRegressionCC], subsid
strategy = CostControlBandit() if subsidy_factor is None else CostControlBandit(subsidy_factor=subsidy_factor)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "CmabBernoulliCC":
return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None))

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(context=context, actions=actions, rewards=rewards)
Expand Down
20 changes: 20 additions & 0 deletions pybandits/smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ class SmabBernoulli(BaseSmabBernoulli):
def __init__(self, actions: Dict[ActionId, Beta]):
super().__init__(actions=actions, strategy=ClassicBandit())

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulli":
return cls(actions=state["actions"])

@validate_arguments
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(actions=actions, rewards=rewards)
Expand Down Expand Up @@ -174,6 +178,10 @@ def __init__(self, actions: Dict[ActionId, Beta], exploit_p: Optional[Float01] =
strategy = BestActionIdentification() if exploit_p is None else BestActionIdentification(exploit_p=exploit_p)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliBAI":
return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None))

@validate_arguments
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(actions=actions, rewards=rewards)
Expand Down Expand Up @@ -209,6 +217,10 @@ def __init__(self, actions: Dict[ActionId, BetaCC], subsidy_factor: Optional[Flo
strategy = CostControlBandit() if subsidy_factor is None else CostControlBandit(subsidy_factor=subsidy_factor)
super().__init__(actions=actions, strategy=strategy)

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliCC":
return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None))

@validate_arguments
def update(self, actions: List[ActionId], rewards: List[BinaryReward]):
super().update(actions=actions, rewards=rewards)
Expand Down Expand Up @@ -269,6 +281,10 @@ class SmabBernoulliMO(BaseSmabBernoulliMO):
def __init__(self, actions: Dict[ActionId, Beta]):
super().__init__(actions=actions, strategy=MultiObjectiveBandit())

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliMO":
return cls(actions=state["actions"])


class SmabBernoulliMOCC(BaseSmabBernoulliMO):
"""
Expand All @@ -292,6 +308,10 @@ class SmabBernoulliMOCC(BaseSmabBernoulliMO):
def __init__(self, actions: Dict[ActionId, Beta]):
super().__init__(actions=actions, strategy=MultiObjectiveCostControlBandit())

@classmethod
def from_state(cls, state: dict) -> "SmabBernoulliMOCC":
return cls(actions=state["actions"])


@validate_arguments
def create_smab_bernoulli_cold_start(action_ids: Set[ActionId]) -> SmabBernoulli:
Expand Down
160 changes: 160 additions & 0 deletions tests/test_cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,53 @@ def test_cmab_get_state(mu, sigma, n_features):
assert is_serializable(cmab_state), "Internal state is not serializable"


@settings(deadline=500)
@given(
state=st.fixed_dictionaries(
{
"actions": st.dictionaries(
keys=st.text(min_size=1, max_size=10),
values=st.fixed_dictionaries(
{
"alpha": st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
"betas": st.lists(
st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
min_size=3,
max_size=3,
),
},
),
min_size=2,
),
"strategy": st.fixed_dictionaries({}),
}
)
)
def test_cmab_from_state(state):
cmab = CmabBernoulli.from_state(state)
assert isinstance(cmab, CmabBernoulli)

expected_actions = state["actions"]
actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict
assert expected_actions == actual_actions

# Ensure get_state and from_state compatibility
new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1])
assert new_cmab == cmab


########################################################################################################################


Expand Down Expand Up @@ -512,6 +559,62 @@ def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01):
assert is_serializable(cmab_state), "Internal state is not serializable"


@settings(deadline=500)
@given(
state=st.fixed_dictionaries(
{
"actions": st.dictionaries(
keys=st.text(min_size=1, max_size=10),
values=st.fixed_dictionaries(
{
"alpha": st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
"betas": st.lists(
st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
min_size=3,
max_size=3,
),
},
),
min_size=2,
),
"strategy": st.one_of(
st.just({}),
st.just({"exploit_p": None}),
st.builds(lambda x: {"exploit_p": x}, st.floats(min_value=0, max_value=1)),
),
}
)
)
def test_cmab_bai_from_state(state):
cmab = CmabBernoulliBAI.from_state(state)
assert isinstance(cmab, CmabBernoulliBAI)

expected_actions = state["actions"]
actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict
assert expected_actions == actual_actions
expected_exploit_p = (
state["strategy"].get("exploit_p", 0.5) if state["strategy"].get("exploit_p") is not None else 0.5
) # Covers both not existing and existing + None
actual_exploit_p = cmab.strategy.exploit_p
assert expected_exploit_p == actual_exploit_p

# Ensure get_state and from_state compatibility
new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1])
assert new_cmab == cmab


########################################################################################################################


Expand Down Expand Up @@ -697,3 +800,60 @@ def test_cmab_cc_get_state(
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"


@settings(deadline=500)
@given(
state=st.fixed_dictionaries(
{
"actions": st.dictionaries(
keys=st.text(min_size=1, max_size=10),
values=st.fixed_dictionaries(
{
"alpha": st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
"betas": st.lists(
st.fixed_dictionaries(
{
"mu": st.floats(min_value=-100, max_value=100),
"nu": st.floats(min_value=0, max_value=100),
"sigma": st.floats(min_value=0, max_value=100),
}
),
min_size=3,
max_size=3,
),
"cost": st.floats(min_value=0),
},
),
min_size=2,
),
"strategy": st.one_of(
st.just({}),
st.just({"subsidy_factor": None}),
st.builds(lambda x: {"subsidy_factor": x}, st.floats(min_value=0, max_value=1)),
),
}
)
)
def test_cmab_cc_from_state(state):
cmab = CmabBernoulliCC.from_state(state)
assert isinstance(cmab, CmabBernoulliCC)

expected_actions = state["actions"]
actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict
assert expected_actions == actual_actions
expected_subsidy_factor = (
state["strategy"].get("subsidy_factor", 0.5) if state["strategy"].get("subsidy_factor") is not None else 0.5
) # Covers both not existing and existing + None
actual_subsidy_factor = cmab.strategy.subsidy_factor
assert expected_subsidy_factor == actual_subsidy_factor

# Ensure get_state and from_state compatibility
new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1])
assert new_cmab == cmab
Loading
Loading