Skip to content

Commit

Permalink
UTs Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Shahar-Bar committed Jul 28, 2024
1 parent 3911584 commit e694209
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
9 changes: 6 additions & 3 deletions tests/test_cmab.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,10 @@ def test_cmab_get_state(mu, sigma, n_features):
json.dumps(
{
"actions": actions,
"strategy": {},
"strategy": {
"epsilon": None,
"default_action": None,
},
"predict_with_proba": False,
"predict_actions_randomly": False,
"epsilon": None,
Expand Down Expand Up @@ -551,7 +554,7 @@ def test_cmab_bai_get_state(mu, sigma, n_features, exploit_p: Float01):
json.dumps(
{
"actions": actions,
"strategy": {"exploit_p": exploit_p},
"strategy": {"epsilon": None, "default_action": None, "exploit_p": exploit_p},
"predict_with_proba": False,
"predict_actions_randomly": False,
"epsilon": None,
Expand Down Expand Up @@ -796,7 +799,7 @@ def test_cmab_cc_get_state(
json.dumps(
{
"actions": actions,
"strategy": {"subsidy_factor": subsidy_factor},
"strategy": {"epsilon": None, "default_action": None, "subsidy_factor": subsidy_factor},
"predict_with_proba": True,
"predict_actions_randomly": False,
"epsilon": None,
Expand Down
39 changes: 34 additions & 5 deletions tests/test_smab.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ def test_smab_get_state(a, b, c, d):

expected_state = {
"actions": actions,
"strategy": {},
"strategy": {
"epsilon": None,
"default_action": None,
},
"epsilon": None,
"default_action": None,
}
Expand Down Expand Up @@ -323,7 +326,14 @@ def test_smabbai_with_betacc():
def test_smab_bai_get_state(a, b, c, d, exploit_p: Float01):
actions = {"action1": Beta(n_successes=a, n_failures=b), "action2": Beta(n_successes=c, n_failures=d)}
smab = SmabBernoulliBAI(actions=actions, exploit_p=exploit_p)
expected_state = {"actions": actions, "strategy": {"exploit_p": exploit_p}}
expected_state = {
"actions": actions,
"strategy": {
"exploit_p": exploit_p,
"epsilon": None,
"default_action": None,
},
}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliBAI"
Expand Down Expand Up @@ -448,7 +458,14 @@ def test_smab_cc_get_state(a, b, c, d, cost1: NonNegativeFloat, cost2: NonNegati
"action2": BetaCC(n_successes=c, n_failures=d, cost=cost2),
}
smab = SmabBernoulliCC(actions=actions, subsidy_factor=subsidy_factor)
expected_state = {"actions": actions, "strategy": {"subsidy_factor": subsidy_factor}}
expected_state = {
"actions": actions,
"strategy": {
"subsidy_factor": subsidy_factor,
"epsilon": None,
"default_action": None,
},
}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliCC"
Expand Down Expand Up @@ -605,7 +622,13 @@ def test_smab_mo_get_state(a_list):
),
}
smab = SmabBernoulliMO(actions=actions)
expected_state = {"actions": actions, "strategy": {}}
expected_state = {
"actions": actions,
"strategy": {
"epsilon": None,
"default_action": None,
},
}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliMO"
Expand Down Expand Up @@ -761,7 +784,13 @@ def test_smab_mocc_get_state(a_list):
),
}
smab = SmabBernoulliMOCC(actions=actions)
expected_state = {"actions": actions, "strategy": {}}
expected_state = {
"actions": actions,
"strategy": {
"epsilon": None,
"default_action": None,
},
}

class_name, smab_state = smab.get_state()
assert class_name == "SmabBernoulliMOCC"
Expand Down

0 comments on commit e694209

Please sign in to comment.