Skip to content

Commit

Permalink
BRAIN-15566 - Access model state with get_state() function
Browse files Browse the repository at this point in the history
 ### Changes

* Implemented get_state() function in BaseMab abstract class
  • Loading branch information
adarmiento committed Nov 1, 2023
1 parent 5c48947 commit a2d4d07
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 3 deletions.
14 changes: 14 additions & 0 deletions pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,17 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None):
probs: List[Dict[ActionId, float]] of shape (n_samples,)
The probabilities of getting a positive reward for each action.
"""

def get_state(self) -> (str, dict):
"""
Access the complete model internal state, enough to create an exact copy of the same model from it.
Returns
-------
model_class_name: str
The name of the class of the model.
model_state: dict
The internal state of the model (actions, scores, etc.).
"""
model_name = self.__class__.__name__
state: dict = self.dict()
return model_name, state
5 changes: 5 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def update(
def predict():
pass

def get_state(self) -> (str, dict):
model_name = self.__class__.__name__
state: dict = {"actions": self.actions}
return model_name, state


def test_base_mab_raise_on_less_than_2_actions():
with pytest.raises(ValidationError):
Expand Down
102 changes: 101 additions & 1 deletion tests/test_cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json

import numpy as np
import pandas as pd
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from pydantic import ValidationError
from pydantic import NonNegativeFloat, ValidationError

from pybandits.base import Float01
from pybandits.cmab import (
CmabBernoulli,
CmabBernoulliBAI,
Expand All @@ -37,6 +40,7 @@
)
from pybandits.model import (
BayesianLogisticRegression,
BayesianLogisticRegressionCC,
StudentT,
create_bayesian_logistic_regression_cc_cold_start,
create_bayesian_logistic_regression_cold_start,
Expand All @@ -46,6 +50,7 @@
ClassicBandit,
CostControlBandit,
)
from tests.test_utils import is_serializable

########################################################################################################################

Expand Down Expand Up @@ -312,6 +317,29 @@ def run_predict(mab):
run_predict(mab=mab)


@settings(deadline=500)
@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=2, max_value=100))
def test_cmab_get_state(mu, sigma, n_features):
actions: dict = {
"a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]),
"a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features),
}

cmab = CmabBernoulli(actions=actions)
expected_state = json.loads(
json.dumps(
{"actions": actions, "strategy": {}, "predict_with_proba": False, "predict_actions_randomly": False},
default=dict,
)
)

class_name, cmab_state = cmab.get_state()
assert class_name == "CmabBernoulli"
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -451,6 +479,39 @@ def test_cmab_bai_update(n_samples=100, n_features=3):
assert not mab.predict_actions_randomly


@settings(deadline=500)
@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=2, max_value=100),
st.floats(min_value=0, max_value=1),
)
def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01):
actions: dict = {
"a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]),
"a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features),
}

cmab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p)
expected_state = json.loads(
json.dumps(
{
"actions": actions,
"strategy": {"exploit_p": exploit_p},
"predict_with_proba": False,
"predict_actions_randomly": False,
},
default=dict,
)
)

class_name, cmab_state = cmab.get_state()
assert class_name == "CmabBernoulliBAI"
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -597,3 +658,42 @@ def test_cmab_cc_update(n_samples=100, n_features=3):
]
)
assert not mab.predict_actions_randomly


@settings(deadline=500)
@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=2, max_value=100),
st.floats(min_value=0),
st.floats(min_value=0),
st.floats(min_value=0, max_value=1),
)
def test_cmab_cc_get_state(
mu, sigma, n_features, cost_1: NonNegativeFloat, cost_2: NonNegativeFloat, subsidy_factor: Float01
):
actions: dict = {
"a1": BayesianLogisticRegressionCC(
alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()], cost=cost_1
),
"a2": create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=cost_2),
}

cmab = CmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor)
expected_state = json.loads(
json.dumps(
{
"actions": actions,
"strategy": {"subsidy_factor": subsidy_factor},
"predict_with_proba": True,
"predict_actions_randomly": False,
},
default=dict,
)
)

class_name, cmab_state = cmab.get_state()
assert class_name == "CmabBernoulliCC"
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"
123 changes: 121 additions & 2 deletions tests/test_smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import pytest
from hypothesis import given
from hypothesis import strategies as st
from pydantic import ValidationError
from pydantic import NonNegativeFloat, ValidationError

from pybandits.base import BinaryReward
from pybandits.base import BinaryReward, Float01
from pybandits.model import Beta, BetaCC, BetaMO, BetaMOCC
from pybandits.smab import (
SmabBernoulli,
Expand All @@ -48,6 +48,7 @@
MultiObjectiveBandit,
MultiObjectiveCostControlBandit,
)
from tests.test_utils import is_serializable

########################################################################################################################

Expand Down Expand Up @@ -200,6 +201,19 @@ def test_smab_accepts_only_valid_actions(s):
SmabBernoulli(actions={s: Beta(), s + "_": Beta()})


@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=1))
def test_smab_get_state(a, b, c, d):
actions = {"action1": Beta(n_successes=a, n_failures=b), "action2": Beta(n_successes=c, n_failures=d)}
smab = SmabBernoulli(actions=actions)

expected_state = {"actions": actions, "strategy": {}}
smab_state = smab.get_state()

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulli"
assert smab_state == expected_state


########################################################################################################################


Expand Down Expand Up @@ -265,6 +279,25 @@ def test_smabbai_with_betacc():
)


@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.floats(min_value=0, max_value=1),
)
def test_smab_bai_get_state(a, b, c, d, exploit_p: Float01):
actions = {"action1": Beta(n_successes=a, n_failures=b), "action2": Beta(n_successes=c, n_failures=d)}
smab = SmabBernoulliBAI(actions=actions, exploit_p=exploit_p)
expected_state = {"actions": actions, "strategy": {"exploit_p": exploit_p}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliBAI"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -327,6 +360,30 @@ def test_smabcc_update():
s.update(actions=["a1", "a1"], rewards=[1, 0])


@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.floats(min_value=0),
st.floats(min_value=0),
st.floats(min_value=0, max_value=1),
)
def test_smab_cc_get_state(a, b, c, d, cost1: NonNegativeFloat, cost2: NonNegativeFloat, subsidy_factor: Float01):
actions = {
"action1": BetaCC(n_successes=a, n_failures=b, cost=cost1),
"action2": BetaCC(n_successes=c, n_failures=d, cost=cost2),
}
smab = SmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor)
expected_state = {"actions": actions, "strategy": {"subsidy_factor": subsidy_factor}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliCC"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -414,6 +471,36 @@ def test_smab_mo_update():
mab.update(actions=["a1", "a1"], rewards=[[1, 0, 1], [1, 1, 0]])


@given(st.lists(st.integers(min_value=1), min_size=6, max_size=6))
def test_smab_mo_get_state(a_list):
a, b, c, d, e, f = a_list

actions = {
"a1": BetaMO(
counters=[
Beta(n_successes=a, n_failures=b),
Beta(n_successes=c, n_failures=d),
Beta(n_successes=e, n_failures=f),
]
),
"a2": BetaMO(
counters=[
Beta(n_successes=d, n_failures=a),
Beta(n_successes=e, n_failures=b),
Beta(n_successes=f, n_failures=c),
]
),
}
smab = SmabBernoulliMO(actions=actions)
expected_state = {"actions": actions, "strategy": {}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliMO"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -498,3 +585,35 @@ def test_smab_mo_cc_predict():
forbidden = ["a1", "a3"]
with pytest.raises(ValueError):
s.predict(n_samples=n_samples, forbidden_actions=forbidden)


@given(st.lists(st.integers(min_value=1), min_size=8, max_size=8))
def test_smab_mocc_get_state(a_list):
a, b, c, d, e, f, g, h = a_list

actions = {
"a1": BetaMOCC(
counters=[
Beta(n_successes=a, n_failures=b),
Beta(n_successes=c, n_failures=d),
Beta(n_successes=e, n_failures=f),
],
cost=g,
),
"a2": BetaMOCC(
counters=[
Beta(n_successes=d, n_failures=a),
Beta(n_successes=e, n_failures=b),
Beta(n_successes=f, n_failures=c),
],
cost=h,
),
}
smab = SmabBernoulliMOCC(actions=actions)
expected_state = {"actions": actions, "strategy": {}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliMOCC"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import json


def is_serializable(something) -> bool:
try:
json.dumps(something)
return True
except Exception:
return False

0 comments on commit a2d4d07

Please sign in to comment.