Skip to content

Commit

Permalink
BRAIN-15566 - Access model state with get_state() function
Browse files Browse the repository at this point in the history
 ### Changes

* Implemented get_state() function in BaseMab abstract class
  • Loading branch information
adarmiento committed Oct 24, 2023
1 parent 5c48947 commit 61144ed
Show file tree
Hide file tree
Showing 5 changed files with 242 additions and 10 deletions.
30 changes: 22 additions & 8 deletions pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ def _get_valid_actions(self, forbidden_actions: Optional[Set[ActionId]]) -> Set[
return valid_actions

def _check_update_params(
self,
actions: List[ActionId],
rewards: List[Union[NonNegativeInt, List[NonNegativeInt]]],
self,
actions: List[ActionId],
rewards: List[Union[NonNegativeInt, List[NonNegativeInt]]],
):
"""
Verify that the given list of action IDs is a subset of the currently defined actions.
Expand All @@ -152,11 +152,11 @@ def _check_update_params(
@abstractmethod
@validate_arguments
def update(
self,
actions: List[ActionId],
rewards: List[Union[BinaryReward, List[BinaryReward]]],
*args,
**kwargs,
self,
actions: List[ActionId],
rewards: List[Union[BinaryReward, List[BinaryReward]]],
*args,
**kwargs,
):
"""
Update the stochastic multi-armed bandit model.
Expand Down Expand Up @@ -187,3 +187,17 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None):
probs: List[Dict[ActionId, float]] of shape (n_samples,)
The probabilities of getting a positive reward for each action.
"""

def get_state(self) -> (str, dict):
    """
    Export what is needed to rebuild an exact copy of this model.

    Returns
    -------
    model_class_name: str
        The name of the concrete class of this instance.
    model_state: dict
        The internal state of the model, as produced by ``self.dict()``.
    """
    return self.__class__.__name__, self.dict()
5 changes: 5 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def update(
def predict():
    """Do nothing; placeholder implementation that returns None."""

def get_state(self) -> (str, dict):
    """Return the instance's class name paired with a dict holding only its ``actions``."""
    return self.__class__.__name__, {"actions": self.actions}


def test_base_mab_raise_on_less_than_2_actions():
with pytest.raises(ValidationError):
Expand Down
88 changes: 88 additions & 0 deletions tests/test_cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json
from typing import Optional
import numpy as np
import pandas as pd
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from pydantic import ValidationError

from pybandits.base import Float01
from pybandits.cmab import (
CmabBernoulli,
CmabBernoulliBAI,
Expand All @@ -40,12 +43,14 @@
StudentT,
create_bayesian_logistic_regression_cc_cold_start,
create_bayesian_logistic_regression_cold_start,
BayesianLogisticRegressionCC
)
from pybandits.strategy import (
BestActionIdentification,
ClassicBandit,
CostControlBandit,
)
from tests.test_utils import is_serializable

########################################################################################################################

Expand Down Expand Up @@ -311,6 +316,33 @@ def run_predict(mab):
assert mab != create_cmab_bernoulli_cold_start(action_ids=["a1", "a2", "a3", "a4", "a5"], n_features=n_features)
run_predict(mab=mab)

@pytest.mark.parametrize(
    "action_dict, action_ids",
    [
        (
            {
                "a1": BayesianLogisticRegression(
                    alpha=StudentT(mu=1, sigma=2),
                    betas=[StudentT(), StudentT(), StudentT()],
                ),
                "a2": create_bayesian_logistic_regression_cold_start(n_betas=3),
            },
            None,
        ),
        (None, {"a0", "a1", "a2"}),
    ],
)
def test_cmab_get_state(action_dict: Optional[dict], action_ids: Optional[set]):
    """
    Check the class name and state reported by ``CmabBernoulli.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids`` (model built with the cold-start factory); the other argument is None.
    """
    if action_dict:
        cmab = CmabBernoulli(actions=action_dict)
        raw_actions = action_dict
        randomly = False
    else:
        cmab = create_cmab_bernoulli_cold_start(action_ids=action_ids, n_features=3)
        raw_actions = {name: create_bayesian_logistic_regression_cold_start(n_betas=3) for name in action_ids}
        randomly = True

    # Round-trip through JSON so the expectation holds JSON-native types only;
    # default=dict is applied to every value the encoder cannot handle natively.
    expected_state = json.loads(
        json.dumps(
            {
                "actions": raw_actions,
                "strategy": {},
                "predict_with_proba": False,
                "predict_actions_randomly": randomly,
            },
            default=dict,
        )
    )

    reported = cmab.get_state()
    assert reported[0] == "CmabBernoulli"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"


########################################################################################################################

Expand Down Expand Up @@ -450,6 +482,33 @@ def test_cmab_bai_update(n_samples=100, n_features=3):
)
assert not mab.predict_actions_randomly

@pytest.mark.parametrize(
    "action_dict, action_ids, exploit_p",
    [
        (
            {
                "a1": BayesianLogisticRegression(
                    alpha=StudentT(mu=1, sigma=2),
                    betas=[StudentT(), StudentT(), StudentT()],
                ),
                "a2": create_bayesian_logistic_regression_cold_start(n_betas=3),
            },
            None,
            0.5,
        ),
        (None, {"a0", "a1", "a2"}, 0.8),
    ],
)
def test_cmab_bai_get_state(action_dict: Optional[dict], action_ids: Optional[set], exploit_p: Float01):
    """
    Check the class name and state reported by ``CmabBernoulliBAI.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids`` (model built with the cold-start factory); ``exploit_p`` is expected to
    show up under the ``strategy`` key of the state.
    """
    if action_dict:
        cmab = CmabBernoulliBAI(actions=action_dict, exploit_p=exploit_p)
        raw_actions = action_dict
        randomly = False
    else:
        cmab = create_cmab_bernoulli_bai_cold_start(action_ids=action_ids, n_features=3, exploit_p=exploit_p)
        raw_actions = {name: create_bayesian_logistic_regression_cold_start(n_betas=3) for name in action_ids}
        randomly = True

    # Round-trip through JSON so the expectation holds JSON-native types only;
    # default=dict is applied to every value the encoder cannot handle natively.
    expected_state = json.loads(
        json.dumps(
            {
                "actions": raw_actions,
                "strategy": {"exploit_p": exploit_p},
                "predict_with_proba": False,
                "predict_actions_randomly": randomly,
            },
            default=dict,
        )
    )

    reported = cmab.get_state()
    assert reported[0] == "CmabBernoulliBAI"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"


########################################################################################################################

Expand Down Expand Up @@ -597,3 +656,32 @@ def test_cmab_cc_update(n_samples=100, n_features=3):
]
)
assert not mab.predict_actions_randomly


@pytest.mark.parametrize(
    "action_dict, action_ids_cost, subsidy_factor",
    [
        (
            {
                "a1": BayesianLogisticRegressionCC(
                    alpha=StudentT(mu=1, sigma=2),
                    betas=[StudentT(), StudentT()],
                    cost=0.1,
                ),
                "a2": create_bayesian_logistic_regression_cc_cold_start(n_betas=2, cost=0.2),
            },
            None,
            0.3,
        ),
        (None, {"a0": 0.1, "a1": 0.2, "a2": 0.3}, 0.5),
    ],
)
def test_cmab_cc_get_state(action_dict: Optional[dict], action_ids_cost: Optional[dict], subsidy_factor: Float01):
    """
    Check the class name and state reported by ``CmabBernoulliCC.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids_cost`` (action id -> cost, model built with the cold-start factory);
    ``subsidy_factor`` is expected to show up under the ``strategy`` key of the state.
    """
    if action_dict:
        cmab = CmabBernoulliCC(actions=action_dict, subsidy_factor=subsidy_factor)
        raw_actions = action_dict
        randomly = False
    else:
        cmab = create_cmab_bernoulli_cc_cold_start(
            action_ids_cost=action_ids_cost,
            n_features=2,
            subsidy_factor=subsidy_factor,
        )
        raw_actions = {
            name: create_bayesian_logistic_regression_cc_cold_start(n_betas=2, cost=cost)
            for name, cost in action_ids_cost.items()
        }
        randomly = True

    # Round-trip through JSON so the expectation holds JSON-native types only;
    # default=dict is applied to every value the encoder cannot handle natively.
    expected_state = json.loads(
        json.dumps(
            {
                "actions": raw_actions,
                "strategy": {"subsidy_factor": subsidy_factor},
                "predict_with_proba": True,
                "predict_actions_randomly": randomly,
            },
            default=dict,
        )
    )

    reported = cmab.get_state()
    assert reported[0] == "CmabBernoulliCC"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"
120 changes: 118 additions & 2 deletions tests/test_smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
# SOFTWARE.

from copy import deepcopy
from typing import List
from typing import List, Optional, Tuple

import pytest
from hypothesis import given
from hypothesis import strategies as st
from pydantic import ValidationError

from pybandits.base import BinaryReward
from pybandits.base import BinaryReward, Float01
from pybandits.model import Beta, BetaCC, BetaMO, BetaMOCC
from pybandits.smab import (
SmabBernoulli,
Expand All @@ -48,6 +48,7 @@
MultiObjectiveBandit,
MultiObjectiveCostControlBandit,
)
from tests.test_utils import is_serializable

########################################################################################################################

Expand Down Expand Up @@ -200,6 +201,26 @@ def test_smab_accepts_only_valid_actions(s):
SmabBernoulli(actions={s: Beta(), s + "_": Beta()})


@pytest.mark.parametrize(
    "action_dict, action_ids",
    [
        ({"a0": Beta(), "a1": Beta(), "a2": Beta()}, None),
        (
            {
                "a0": Beta(),
                "a1": Beta(n_successes=5, n_failures=5),
                "a2": Beta(n_successes=10, n_failures=1),
                "a3": Beta(n_successes=10, n_failures=5),
                "a4": Beta(n_successes=100, n_failures=4),
                "a5": Beta(),
            },
            None,
        ),
        (None, {"a0", "a1", "a2"}),
    ],
)
def test_smab_get_state(action_dict: Optional[dict], action_ids: Optional[set]):
    """
    Check the class name and state reported by ``SmabBernoulli.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids`` (model built with the cold-start factory); the other argument is None.
    """
    if action_dict:
        smab = SmabBernoulli(actions=action_dict)
        expected_actions = action_dict
    else:
        smab = create_smab_bernoulli_cold_start(action_ids=action_ids)
        expected_actions = {name: Beta() for name in action_ids}
    expected_state = {"actions": expected_actions, "strategy": {}}

    reported = smab.get_state()
    assert reported[0] == "SmabBernoulli"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -265,6 +286,27 @@ def test_smabbai_with_betacc():
)


@pytest.mark.parametrize(
    "action_dict, action_ids, exploit_p",
    [
        ({"a0": Beta(), "a1": Beta(), "a2": Beta()}, None, 0.3),
        (
            {
                "a0": Beta(),
                "a1": Beta(n_successes=5, n_failures=5),
                "a2": Beta(n_successes=10, n_failures=1),
                "a3": Beta(n_successes=10, n_failures=5),
                "a4": Beta(n_successes=100, n_failures=4),
                "a5": Beta(),
            },
            None,
            0.8,
        ),
        (None, {"a0", "a1", "a2"}, 0.5),
    ],
)
def test_smab_bai_get_state(action_dict: Optional[dict], action_ids: Optional[set], exploit_p: Float01):
    """
    Check the class name and state reported by ``SmabBernoulliBAI.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids`` (model built with the cold-start factory); ``exploit_p`` is expected to
    show up under the ``strategy`` key of the state.
    """
    if action_dict:
        smab = SmabBernoulliBAI(actions=action_dict, exploit_p=exploit_p)
        expected_actions = action_dict
    else:
        smab = create_smab_bernoulli_bai_cold_start(action_ids=action_ids, exploit_p=exploit_p)
        expected_actions = {name: Beta() for name in action_ids}
    expected_state = {"actions": expected_actions, "strategy": {"exploit_p": exploit_p}}

    reported = smab.get_state()
    assert reported[0] == "SmabBernoulliBAI"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -327,6 +369,28 @@ def test_smabcc_update():
s.update(actions=["a1", "a1"], rewards=[1, 0])


@pytest.mark.parametrize(
    "action_dict, action_ids_cost, subsidy_factor",
    [
        ({"a0": BetaCC(cost=0.1), "a1": BetaCC(cost=0.2), "a2": BetaCC(cost=0.3)}, None, 0.3),
        (
            {
                "a0": BetaCC(cost=0.1),
                "a1": BetaCC(n_successes=5, n_failures=5, cost=0.2),
                "a2": BetaCC(n_successes=10, n_failures=1, cost=0.3),
                "a3": BetaCC(n_successes=10, n_failures=5, cost=0.4),
                "a4": BetaCC(n_successes=100, n_failures=4, cost=0.5),
                "a5": BetaCC(cost=0.6),
            },
            None,
            0.8,
        ),
        (None, {"a0": 0.1, "a1": 0.2, "a2": 0.3}, 0.5),
    ],
)
def test_smab_cc_get_state(action_dict: Optional[dict], action_ids_cost: Optional[dict], subsidy_factor: Float01):
    """
    Check the class name and state reported by ``SmabBernoulliCC.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids_cost`` (action id -> cost, model built with the cold-start factory);
    ``subsidy_factor`` is expected to show up under the ``strategy`` key of the state.
    """
    if action_dict:
        smab = SmabBernoulliCC(actions=action_dict, subsidy_factor=subsidy_factor)
        expected_actions = action_dict
    else:
        smab = create_smab_bernoulli_cc_cold_start(action_ids_cost=action_ids_cost, subsidy_factor=subsidy_factor)
        expected_actions = {name: BetaCC(cost=cost) for name, cost in action_ids_cost.items()}
    expected_state = {"actions": expected_actions, "strategy": {"subsidy_factor": subsidy_factor}}

    reported = smab.get_state()
    assert reported[0] == "SmabBernoulliCC"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -414,6 +478,31 @@ def test_smab_mo_update():
mab.update(actions=["a1", "a1"], rewards=[[1, 0, 1], [1, 1, 0]])


@pytest.mark.parametrize(
    "action_dict, action_ids_with_obj",
    [
        (
            {
                "a0": BetaMO(counters=[Beta(), Beta()]),
                "a1": BetaMO(counters=[Beta(), Beta()]),
                "a2": BetaMO(counters=[Beta(), Beta()]),
            },
            None,
        ),
        (
            {
                "a0": BetaMO(counters=[Beta(), Beta()]),
                "a1": BetaMO(counters=[Beta(n_successes=2, n_failures=3), Beta(n_successes=4, n_failures=5)]),
                "a2": BetaMO(counters=[Beta(n_successes=6, n_failures=7), Beta(n_successes=8, n_failures=9)]),
            },
            None,
        ),
        (None, ({"a0", "a1", "a2"}, 2)),
    ],
)
def test_smab_mo_get_state(action_dict: Optional[dict], action_ids_with_obj: Optional[Tuple[set, int]]):
    """
    Check the class name and state reported by ``SmabBernoulliMO.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids_with_obj``, a pair ``(action ids, number of objectives)`` fed to the
    cold-start factory.
    """
    if action_dict:
        smab = SmabBernoulliMO(actions=action_dict)
        expected_actions = action_dict
    else:
        ids, n_objectives = action_ids_with_obj
        smab = create_smab_bernoulli_mo_cold_start(action_ids=ids, n_objectives=n_objectives)
        expected_actions = {name: BetaMO(counters=[Beta()] * n_objectives) for name in ids}
    expected_state = {"actions": expected_actions, "strategy": {}}

    reported = smab.get_state()
    assert reported[0] == "SmabBernoulliMO"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -498,3 +587,30 @@ def test_smab_mo_cc_predict():
forbidden = ["a1", "a3"]
with pytest.raises(ValueError):
s.predict(n_samples=n_samples, forbidden_actions=forbidden)


@pytest.mark.parametrize(
    "action_dict, action_ids_cost_with_obj",
    [
        (
            {
                "a0": BetaMOCC(counters=[Beta(), Beta()], cost=0.1),
                "a1": BetaMOCC(counters=[Beta(), Beta()], cost=0.2),
                "a2": BetaMOCC(counters=[Beta(), Beta()], cost=0.3),
            },
            None,
        ),
        (
            {
                "a0": BetaMOCC(counters=[Beta(), Beta()], cost=0.1),
                "a1": BetaMOCC(
                    counters=[Beta(n_successes=2, n_failures=3), Beta(n_successes=4, n_failures=5)],
                    cost=0.2,
                ),
                "a2": BetaMOCC(
                    counters=[Beta(n_successes=6, n_failures=7), Beta(n_successes=8, n_failures=9)],
                    cost=0.3,
                ),
            },
            None,
        ),
        (None, ({"a0": 0.1, "a1": 0.2, "a2": 0.3}, 2)),
    ],
)
def test_smab_mo_cc_get_state(action_dict: Optional[dict], action_ids_cost_with_obj: Optional[Tuple[dict, int]]):
    """
    Check the class name and state reported by ``SmabBernoulliMOCC.get_state()``.

    Each case supplies either ``action_dict`` (model built directly from explicit actions) or
    ``action_ids_cost_with_obj``, a pair ``(action id -> cost, number of objectives)`` fed to
    the cold-start factory.
    """
    if action_dict:
        smab = SmabBernoulliMOCC(actions=action_dict)
        expected_actions = action_dict
    else:
        costs, n_objectives = action_ids_cost_with_obj
        smab = create_smab_bernoulli_mo_cc_cold_start(action_ids_cost=costs, n_objectives=n_objectives)
        expected_actions = {
            name: BetaMOCC(counters=[Beta()] * n_objectives, cost=cost) for name, cost in costs.items()
        }
    expected_state = {"actions": expected_actions, "strategy": {}}

    reported = smab.get_state()
    assert reported[0] == "SmabBernoulliMOCC"
    assert reported[1] == expected_state

    assert is_serializable(reported[1]), "Internal state is not serializable"
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import json


def is_serializable(something) -> bool:
    """
    Tell whether ``something`` can be encoded by ``json.dumps`` with its default settings.

    Parameters
    ----------
    something : Any
        The object to probe.

    Returns
    -------
    bool
        True if ``json.dumps(something)`` succeeds, False if the encoder rejects the input.
    """
    try:
        json.dumps(something)
    except (TypeError, ValueError, RecursionError):
        # TypeError: unsupported value or key type; ValueError: circular reference;
        # RecursionError: nesting too deep for the encoder.
        # Any other exception is unexpected and propagates rather than being
        # reported as "not serializable".
        return False
    return True

0 comments on commit 61144ed

Please sign in to comment.