Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BRAIN-15566 - Access model state with get_state() function #25

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,17 @@ def predict(self, forbidden_actions: Optional[Set[ActionId]] = None):
probs: List[Dict[ActionId, float]] of shape (n_samples,)
The probabilities of getting a positive reward for each action.
"""

def get_state(self) -> (str, dict):
"""
Access the complete model internal state, enough to create an exact copy of the same model from it.
Returns
-------
model_class_name: str
The name of the class of the model.
model_state: dict
The internal state of the model (actions, scores, etc.).
"""
model_name = self.__class__.__name__
state: dict = self.dict()
adarmiento marked this conversation as resolved.
Show resolved Hide resolved
return model_name, state
5 changes: 5 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ def update(
def predict():
pass

def get_state(self) -> (str, dict):
model_name = self.__class__.__name__
state: dict = {"actions": self.actions}
return model_name, state


def test_base_mab_raise_on_less_than_2_actions():
with pytest.raises(ValidationError):
Expand Down
102 changes: 101 additions & 1 deletion tests/test_cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import json

import numpy as np
import pandas as pd
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
from pydantic import ValidationError
from pydantic import NonNegativeFloat, ValidationError

from pybandits.base import Float01
from pybandits.cmab import (
CmabBernoulli,
CmabBernoulliBAI,
Expand All @@ -37,6 +40,7 @@
)
from pybandits.model import (
BayesianLogisticRegression,
BayesianLogisticRegressionCC,
StudentT,
create_bayesian_logistic_regression_cc_cold_start,
create_bayesian_logistic_regression_cold_start,
Expand All @@ -46,6 +50,7 @@
ClassicBandit,
CostControlBandit,
)
from tests.test_utils import is_serializable

########################################################################################################################

Expand Down Expand Up @@ -312,6 +317,29 @@ def run_predict(mab):
run_predict(mab=mab)


@settings(deadline=500)
@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=2, max_value=100))
def test_cmab_get_state(mu, sigma, n_features):
actions: dict = {
"a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]),
"a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features),
}

cmab = CmabBernoulli(actions=actions)
expected_state = json.loads(
json.dumps(
{"actions": actions, "strategy": {}, "predict_with_proba": False, "predict_actions_randomly": False},
default=dict,
)
)

class_name, cmab_state = cmab.get_state()
assert class_name == "CmabBernoulli"
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -451,6 +479,39 @@ def test_cmab_bai_update(n_samples=100, n_features=3):
assert not mab.predict_actions_randomly


@settings(deadline=500)
@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=2, max_value=100),
st.floats(min_value=0, max_value=1),
)
def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01):
actions: dict = {
"a1": BayesianLogisticRegression(alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()]),
"a2": create_bayesian_logistic_regression_cold_start(n_betas=n_features),
}

cmab = CmabBernoulliBAI(actions=actions, exploit_p=exploit_p)
expected_state = json.loads(
json.dumps(
{
"actions": actions,
"strategy": {"exploit_p": exploit_p},
"predict_with_proba": False,
"predict_actions_randomly": False,
},
default=dict,
)
)

class_name, cmab_state = cmab.get_state()
assert class_name == "CmabBernoulliBAI"
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -597,3 +658,42 @@ def test_cmab_cc_update(n_samples=100, n_features=3):
]
)
assert not mab.predict_actions_randomly


@settings(deadline=500)
@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=2, max_value=100),
st.floats(min_value=0),
st.floats(min_value=0),
st.floats(min_value=0, max_value=1),
)
def test_cmab_cc_get_state(
mu, sigma, n_features, cost_1: NonNegativeFloat, cost_2: NonNegativeFloat, subsidy_factor: Float01
):
actions: dict = {
"a1": BayesianLogisticRegressionCC(
alpha=StudentT(mu=mu, sigma=sigma), betas=n_features * [StudentT()], cost=cost_1
),
"a2": create_bayesian_logistic_regression_cc_cold_start(n_betas=n_features, cost=cost_2),
}

cmab = CmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor)
expected_state = json.loads(
json.dumps(
{
"actions": actions,
"strategy": {"subsidy_factor": subsidy_factor},
"predict_with_proba": True,
"predict_actions_randomly": False,
},
default=dict,
)
)

class_name, cmab_state = cmab.get_state()
assert class_name == "CmabBernoulliCC"
assert cmab_state == expected_state

assert is_serializable(cmab_state), "Internal state is not serializable"
123 changes: 121 additions & 2 deletions tests/test_smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
import pytest
from hypothesis import given
from hypothesis import strategies as st
from pydantic import ValidationError
from pydantic import NonNegativeFloat, ValidationError

from pybandits.base import BinaryReward
from pybandits.base import BinaryReward, Float01
from pybandits.model import Beta, BetaCC, BetaMO, BetaMOCC
from pybandits.smab import (
SmabBernoulli,
Expand All @@ -48,6 +48,7 @@
MultiObjectiveBandit,
MultiObjectiveCostControlBandit,
)
from tests.test_utils import is_serializable

########################################################################################################################

Expand Down Expand Up @@ -200,6 +201,19 @@ def test_smab_accepts_only_valid_actions(s):
SmabBernoulli(actions={s: Beta(), s + "_": Beta()})


@given(st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=1), st.integers(min_value=1))
def test_smab_get_state(a, b, c, d):
actions = {"action1": Beta(n_successes=a, n_failures=b), "action2": Beta(n_successes=c, n_failures=d)}
smab = SmabBernoulli(actions=actions)

expected_state = {"actions": actions, "strategy": {}}
smab_state = smab.get_state()

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulli"
assert smab_state == expected_state


########################################################################################################################


Expand Down Expand Up @@ -265,6 +279,25 @@ def test_smabbai_with_betacc():
)


@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.floats(min_value=0, max_value=1),
)
def test_smab_bai_get_state(a, b, c, d, exploit_p: Float01):
actions = {"action1": Beta(n_successes=a, n_failures=b), "action2": Beta(n_successes=c, n_failures=d)}
smab = SmabBernoulliBAI(actions=actions, exploit_p=exploit_p)
expected_state = {"actions": actions, "strategy": {"exploit_p": exploit_p}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliBAI"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -327,6 +360,30 @@ def test_smabcc_update():
s.update(actions=["a1", "a1"], rewards=[1, 0])


@given(
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.integers(min_value=1),
st.floats(min_value=0),
st.floats(min_value=0),
st.floats(min_value=0, max_value=1),
)
def test_smab_cc_get_state(a, b, c, d, cost1: NonNegativeFloat, cost2: NonNegativeFloat, subsidy_factor: Float01):
actions = {
"action1": BetaCC(n_successes=a, n_failures=b, cost=cost1),
"action2": BetaCC(n_successes=c, n_failures=d, cost=cost2),
}
smab = SmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor)
expected_state = {"actions": actions, "strategy": {"subsidy_factor": subsidy_factor}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliCC"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -414,6 +471,36 @@ def test_smab_mo_update():
mab.update(actions=["a1", "a1"], rewards=[[1, 0, 1], [1, 1, 0]])


@given(st.lists(st.integers(min_value=1), min_size=6, max_size=6))
def test_smab_mo_get_state(a_list):
a, b, c, d, e, f = a_list

actions = {
"a1": BetaMO(
counters=[
Beta(n_successes=a, n_failures=b),
Beta(n_successes=c, n_failures=d),
Beta(n_successes=e, n_failures=f),
]
),
"a2": BetaMO(
counters=[
Beta(n_successes=d, n_failures=a),
Beta(n_successes=e, n_failures=b),
Beta(n_successes=f, n_failures=c),
]
),
}
smab = SmabBernoulliMO(actions=actions)
expected_state = {"actions": actions, "strategy": {}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliMO"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"


########################################################################################################################


Expand Down Expand Up @@ -498,3 +585,35 @@ def test_smab_mo_cc_predict():
forbidden = ["a1", "a3"]
with pytest.raises(ValueError):
s.predict(n_samples=n_samples, forbidden_actions=forbidden)


@given(st.lists(st.integers(min_value=1), min_size=8, max_size=8))
def test_smab_mocc_get_state(a_list):
a, b, c, d, e, f, g, h = a_list

actions = {
"a1": BetaMOCC(
counters=[
Beta(n_successes=a, n_failures=b),
Beta(n_successes=c, n_failures=d),
Beta(n_successes=e, n_failures=f),
],
cost=g,
),
"a2": BetaMOCC(
counters=[
Beta(n_successes=d, n_failures=a),
Beta(n_successes=e, n_failures=b),
Beta(n_successes=f, n_failures=c),
],
cost=h,
),
}
smab = SmabBernoulliMOCC(actions=actions)
expected_state = {"actions": actions, "strategy": {}}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliMOCC"
assert smab_state == expected_state

assert is_serializable(smab_state), "Internal state is not serializable"
9 changes: 9 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import json


def is_serializable(something) -> bool:
try:
json.dumps(something)
return True
except Exception:
return False
Loading