diff --git a/pybandits/cmab.py b/pybandits/cmab.py index da3fccb..1268a11 100644 --- a/pybandits/cmab.py +++ b/pybandits/cmab.py @@ -215,6 +215,10 @@ class CmabBernoulli(BaseCmabBernoulli): def __init__(self, actions: Dict[ActionId, BaseBayesianLogisticRegression]): super().__init__(actions=actions, strategy=ClassicBandit()) + @classmethod + def from_state(cls, state: dict) -> "CmabBernoulli": + return cls(actions=state["actions"]) + @validate_arguments(config=dict(arbitrary_types_allowed=True)) def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]): super().update(context=context, actions=actions, rewards=rewards) @@ -249,6 +253,10 @@ def __init__(self, actions: Dict[ActionId, BayesianLogisticRegression], exploit_ strategy = BestActionIdentification() if exploit_p is None else BestActionIdentification(exploit_p=exploit_p) super().__init__(actions=actions, strategy=strategy) + @classmethod + def from_state(cls, state: dict) -> "CmabBernoulliBAI": + return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None)) + @validate_arguments(config=dict(arbitrary_types_allowed=True)) def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]): super().update(context=context, actions=actions, rewards=rewards) @@ -292,6 +300,10 @@ def __init__(self, actions: Dict[ActionId, BayesianLogisticRegressionCC], subsid strategy = CostControlBandit() if subsidy_factor is None else CostControlBandit(subsidy_factor=subsidy_factor) super().__init__(actions=actions, strategy=strategy) + @classmethod + def from_state(cls, state: dict) -> "CmabBernoulliCC": + return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None)) + @validate_arguments(config=dict(arbitrary_types_allowed=True)) def update(self, context: ArrayLike, actions: List[ActionId], rewards: List[BinaryReward]): super().update(context=context, actions=actions, rewards=rewards) diff --git a/pybandits/smab.py b/pybandits/smab.py index 8f1c669..b72a319 100644 --- a/pybandits/smab.py +++ b/pybandits/smab.py @@ -147,6 +147,10 @@ class SmabBernoulli(BaseSmabBernoulli): def __init__(self, actions: Dict[ActionId, Beta]): super().__init__(actions=actions, strategy=ClassicBandit()) + @classmethod + def from_state(cls, state: dict) -> "SmabBernoulli": + return cls(actions=state["actions"]) + @validate_arguments def update(self, actions: List[ActionId], rewards: List[BinaryReward]): super().update(actions=actions, rewards=rewards) @@ -174,6 +178,10 @@ def __init__(self, actions: Dict[ActionId, Beta], exploit_p: Optional[Float01] = strategy = BestActionIdentification() if exploit_p is None else BestActionIdentification(exploit_p=exploit_p) super().__init__(actions=actions, strategy=strategy) + @classmethod + def from_state(cls, state: dict) -> "SmabBernoulliBAI": + return cls(actions=state["actions"], exploit_p=state["strategy"].get("exploit_p", None)) + @validate_arguments def update(self, actions: List[ActionId], rewards: List[BinaryReward]): super().update(actions=actions, rewards=rewards) @@ -209,6 +217,10 @@ def __init__(self, actions: Dict[ActionId, BetaCC], subsidy_factor: Optional[Flo strategy = CostControlBandit() if subsidy_factor is None else CostControlBandit(subsidy_factor=subsidy_factor) super().__init__(actions=actions, strategy=strategy) + @classmethod + def from_state(cls, state: dict) -> "SmabBernoulliCC": + return cls(actions=state["actions"], subsidy_factor=state["strategy"].get("subsidy_factor", None)) + @validate_arguments def update(self, actions: List[ActionId], rewards: List[BinaryReward]): super().update(actions=actions, rewards=rewards) @@ -269,6 +281,10 @@ class SmabBernoulliMO(BaseSmabBernoulliMO): def __init__(self, actions: Dict[ActionId, Beta]): super().__init__(actions=actions, strategy=MultiObjectiveBandit()) + @classmethod + def from_state(cls, state: dict) -> "SmabBernoulliMO": + return cls(actions=state["actions"]) + class SmabBernoulliMOCC(BaseSmabBernoulliMO): """ @@ -292,6 +308,10 @@ class SmabBernoulliMOCC(BaseSmabBernoulliMO): def __init__(self, actions: Dict[ActionId, Beta]): super().__init__(actions=actions, strategy=MultiObjectiveCostControlBandit()) + @classmethod + def from_state(cls, state: dict) -> "SmabBernoulliMOCC": + return cls(actions=state["actions"]) + @validate_arguments def create_smab_bernoulli_cold_start(action_ids: Set[ActionId]) -> SmabBernoulli: diff --git a/tests/test_cmab.py b/tests/test_cmab.py index 4620901..d2d25f4 100644 --- a/tests/test_cmab.py +++ b/tests/test_cmab.py @@ -340,6 +340,53 @@ def test_cmab_get_state(mu, sigma, n_features): assert is_serializable(cmab_state), "Internal state is not serializable" +@settings(deadline=500) +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "alpha": st.fixed_dictionaries( + { + "mu": st.floats(min_value=-100, max_value=100), + "nu": st.floats(min_value=0, max_value=100), + "sigma": st.floats(min_value=0, max_value=100), + } + ), + "betas": st.lists( + st.fixed_dictionaries( + { + "mu": st.floats(min_value=-100, max_value=100), + "nu": st.floats(min_value=0, max_value=100), + "sigma": st.floats(min_value=0, max_value=100), + } + ), + min_size=3, + max_size=3, + ), + }, + ), + min_size=2, + ), + "strategy": st.fixed_dictionaries({}), + } + ) +) +def test_cmab_from_state(state): + cmab = CmabBernoulli.from_state(state) + assert isinstance(cmab, CmabBernoulli) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + + # Ensure get_state and from_state compatibility + new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1]) + assert new_cmab == cmab + + ######################################################################################################################## @@ -512,6 +559,62 @@ def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01): assert is_serializable(cmab_state), "Internal state is not serializable" +@settings(deadline=500) +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "alpha": st.fixed_dictionaries( + { + "mu": st.floats(min_value=-100, max_value=100), + "nu": st.floats(min_value=0, max_value=100), + "sigma": st.floats(min_value=0, max_value=100), + } + ), + "betas": st.lists( + st.fixed_dictionaries( + { + "mu": st.floats(min_value=-100, max_value=100), + "nu": st.floats(min_value=0, max_value=100), + "sigma": st.floats(min_value=0, max_value=100), + } + ), + min_size=3, + max_size=3, + ), + }, + ), + min_size=2, + ), + "strategy": st.one_of( + st.just({}), + st.just({"exploit_p": None}), + st.builds(lambda x: {"exploit_p": x}, st.floats(min_value=0, max_value=1)), + ), + } + ) +) +def test_cmab_bai_from_state(state): + cmab = CmabBernoulliBAI.from_state(state) + assert isinstance(cmab, CmabBernoulliBAI) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + expected_exploit_p = ( + state["strategy"].get("exploit_p", 0.5) if state["strategy"].get("exploit_p") is not None else 0.5 + ) # Covers both not existing and existing + None + actual_exploit_p = cmab.strategy.exploit_p + assert expected_exploit_p == actual_exploit_p + + # Ensure get_state and from_state compatibility + new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1]) + assert new_cmab == cmab + + ######################################################################################################################## @@ -697,3 +800,60 @@ def test_cmab_cc_get_state( assert cmab_state == expected_state assert is_serializable(cmab_state), "Internal state is not serializable" + + +@settings(deadline=500) +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "alpha": st.fixed_dictionaries( + { + "mu": st.floats(min_value=-100, max_value=100), + "nu": st.floats(min_value=0, max_value=100), + "sigma": st.floats(min_value=0, max_value=100), + } + ), + "betas": st.lists( + st.fixed_dictionaries( + { + "mu": st.floats(min_value=-100, max_value=100), + "nu": st.floats(min_value=0, max_value=100), + "sigma": st.floats(min_value=0, max_value=100), + } + ), + min_size=3, + max_size=3, + ), + "cost": st.floats(min_value=0), + }, + ), + min_size=2, + ), + "strategy": st.one_of( + st.just({}), + st.just({"subsidy_factor": None}), + st.builds(lambda x: {"subsidy_factor": x}, st.floats(min_value=0, max_value=1)), + ), + } + ) +) +def test_cmab_cc_from_state(state): + cmab = CmabBernoulliCC.from_state(state) + assert isinstance(cmab, CmabBernoulliCC) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(cmab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + expected_subsidy_factor = ( + state["strategy"].get("subsidy_factor", 0.5) if state["strategy"].get("subsidy_factor") is not None else 0.5 + ) # Covers both not existing and existing + None + actual_subsidy_factor = cmab.strategy.subsidy_factor + assert expected_subsidy_factor == actual_subsidy_factor + + # Ensure get_state and from_state compatibility + new_cmab = globals()[cmab.get_state()[0]].from_state(state=cmab.get_state()[1]) + assert new_cmab == cmab diff --git a/tests/test_smab.py b/tests/test_smab.py index b136790..52347de 100644 --- a/tests/test_smab.py +++ b/tests/test_smab.py @@ -19,7 +19,7 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. - +import json from copy import deepcopy from typing import List @@ -207,13 +207,42 @@ def test_smab_get_state(a, b, c, d): smab = SmabBernoulli(actions=actions) expected_state = {"actions": actions, "strategy": {}} - smab_state = smab.get_state() class_name, smab_state = smab.get_state() assert class_name == "SmabBernoulli" assert smab_state == expected_state +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "n_successes": st.integers(min_value=1, max_value=100), + "n_failures": st.integers(min_value=1, max_value=100), + }, + ), + min_size=2, + ), + "strategy": st.fixed_dictionaries({}), + } + ) +) +def test_smab_from_state(state): + smab = SmabBernoulli.from_state(state) + assert isinstance(smab, SmabBernoulli) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(smab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + + # Ensure get_state and from_state compatibility + new_smab = globals()[smab.get_state()[0]].from_state(state=smab.get_state()[1]) + assert new_smab == smab + + ######################################################################################################################## @@ -298,6 +327,45 @@ def test_smab_bai_get_state(a, b, c, d, exploit_p: Float01): assert is_serializable(smab_state), "Internal state is not serializable" +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "n_successes": st.integers(min_value=1, max_value=100), + "n_failures": st.integers(min_value=1, max_value=100), + }, + ), + min_size=2, + ), + "strategy": st.one_of( + st.just({}), + st.just({"exploit_p": None}), + st.builds(lambda x: {"exploit_p": x}, st.floats(min_value=0, max_value=1)), + ), + } + ) +) +def test_smab_bai_from_state(state): + smab = SmabBernoulliBAI.from_state(state) + assert isinstance(smab, SmabBernoulliBAI) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(smab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + expected_exploit_p = ( + state["strategy"].get("exploit_p", 0.5) if state["strategy"].get("exploit_p") is not None else 0.5 + ) # Covers both not existing and existing + None + actual_exploit_p = smab.strategy.exploit_p + assert expected_exploit_p == actual_exploit_p + + # Ensure get_state and from_state compatibility + new_smab = globals()[smab.get_state()[0]].from_state(state=smab.get_state()[1]) + assert new_smab == smab + + ######################################################################################################################## @@ -384,6 +452,46 @@ def test_smab_cc_get_state(a, b, c, d, cost1: NonNegativeFloat, cost2: NonNegati assert is_serializable(smab_state), "Internal state is not serializable" +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "n_successes": st.integers(min_value=1, max_value=100), + "n_failures": st.integers(min_value=1, max_value=100), + "cost": st.floats(min_value=0), + }, + ), + min_size=2, + ), + "strategy": st.one_of( + st.just({}), + st.just({"subsidy_factor": None}), + st.builds(lambda x: {"subsidy_factor": x}, st.floats(min_value=0, max_value=1)), + ), + } + ) +) +def test_smab_cc_from_state(state): + smab = SmabBernoulliCC.from_state(state) + assert isinstance(smab, SmabBernoulliCC) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(smab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + expected_subsidy_factor = ( + state["strategy"].get("subsidy_factor", 0.5) if state["strategy"].get("subsidy_factor") is not None else 0.5 + ) # Covers both not existing and existing + None + actual_subsidy_factor = smab.strategy.subsidy_factor + assert expected_subsidy_factor == actual_subsidy_factor + + # Ensure get_state and from_state compatibility + new_smab = globals()[smab.get_state()[0]].from_state(state=smab.get_state()[1]) + assert new_smab == smab + + ######################################################################################################################## @@ -501,6 +609,44 @@ def test_smab_mo_get_state(a_list): assert is_serializable(smab_state), "Internal state is not serializable" +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "counters": st.lists( + st.fixed_dictionaries( + { + "n_successes": st.integers(min_value=1, max_value=100), + "n_failures": st.integers(min_value=1, max_value=100), + }, + ), + min_size=3, + max_size=3, + ) + } + ), + min_size=2, + ), + "strategy": st.fixed_dictionaries({}), + } + ) +) +def test_smab_mo_from_state(state): + smab = SmabBernoulliMO.from_state(state) + assert isinstance(smab, SmabBernoulliMO) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(smab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + + # Ensure get_state and from_state compatibility + new_smab = globals()[smab.get_state()[0]].from_state(state=smab.get_state()[1]) + assert new_smab == smab + + ######################################################################################################################## @@ -617,3 +763,42 @@ def test_smab_mocc_get_state(a_list): assert smab_state == expected_state assert is_serializable(smab_state), "Internal state is not serializable" + + +@given( + state=st.fixed_dictionaries( + { + "actions": st.dictionaries( + keys=st.text(min_size=1, max_size=10), + values=st.fixed_dictionaries( + { + "counters": st.lists( + st.fixed_dictionaries( + { + "n_successes": st.integers(min_value=1, max_value=100), + "n_failures": st.integers(min_value=1, max_value=100), + }, + ), + min_size=3, + max_size=3, + ), + "cost": st.floats(min_value=0), + } + ), + min_size=2, + ), + "strategy": st.fixed_dictionaries({}), + } + ) +) +def test_smab_mo_cc_from_state(state): + smab = SmabBernoulliMOCC.from_state(state) + assert isinstance(smab, SmabBernoulliMOCC) + + expected_actions = state["actions"] + actual_actions = json.loads(json.dumps(smab.actions, default=dict)) # Normalize the dict + assert expected_actions == actual_actions + + # Ensure get_state and from_state compatibility + new_smab = globals()[smab.get_state()[0]].from_state(state=smab.get_state()[1]) + assert new_smab == smab