Skip to content

Commit

Permalink
Refactor MAB and Strategy Classes with Cold Start Methods and Enhance…
Browse files Browse the repository at this point in the history
…d Validation

 Change log:
 1. Moved Strategy, Model, and MAB to strategy.py, model.py, and to the new mab.py. base.py is now only for definitions and abstract PyBanditsBaseModel. The abstract MAB now allows for all childs to either accept strategy instance as parameter, or to get the strategy parameters and instantiate correspondingly.
 2. The from_state functionality is now directly inherited by all MABs from BaseMab.
 3. Replaced all cold_start methods in cmab.py and smab.py with cold_start stemming from BaseMab. Correspondingly, updated test cases to use the new cold_start_instantiate methods.
 4. Introduced numerize_field and get_expected_value_from_state methods in the Strategy class to handle default values and state extraction. Added field_validator for exploit_p in BestActionIdentification and subsidy_factor in CostControlBandit to ensure proper default handling and validation.
 5. Merged common functionality into a new CostControlStrategy abstract class, which is now inherited by CostControlBandit and MultiObjectiveCostControlBandit. Simplified the select_action methods by using helper methods like _evaluate_and_select and _reduce.
 6. Plugged get_pareto_front into a new MultiObjectiveStrategy abstract class, which is now inherited by MultiObjectiveBandit and MultiObjectiveCostControlBandit.
 7. In model.py. Removed the redundant BaseBetaMO and BaseBayesianLogisticRegression. Added cold_start_instantiate method to BetaMO and BayesianLogisticRegression models.
 8. Added extract_argument_names_from_function under utils.py to allow extract function parameter names by handle.
 9. Changed test_base.py into test_mab.py.
 10. Updated deprecated linter settings in pyproject.toml.
 11. Added test_smab_mo_cc_update test on test_smab.py.
 12. Changed version to 1.0.0 on pyproject.toml.
  • Loading branch information
Shahar-Bar committed Sep 4, 2024
1 parent fcd0896 commit 5bcb9b5
Show file tree
Hide file tree
Showing 16 changed files with 1,106 additions and 1,260 deletions.
70 changes: 35 additions & 35 deletions docs/tutorials/mab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"from rich import print\n",
"\n",
"from pybandits.model import Beta\n",
"from pybandits.smab import SmabBernoulli, create_smab_bernoulli_cold_start"
"from pybandits.smab import SmabBernoulli"
]
},
{
Expand Down Expand Up @@ -104,14 +104,14 @@
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n",
" \u001b[32m'a1'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m}\u001b[0m,\n",
" \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[1m)\u001b[0m\n"
"\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n",
" \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n",
" \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\n",
" \u001B[1m}\u001B[0m,\n",
" \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n",
"\u001B[1m)\u001B[0m\n"
]
},
"metadata": {},
Expand All @@ -137,7 +137,7 @@
"id": "564914fd-73cc-4854-8ec7-548970f794a6",
"metadata": {},
"source": [
"You can initialize the bandit via the utility function `create_smab_bernoulli_mo_cc_cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
"You can initialize the bandit via the utility function `SmabBernoulliMOCC.cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
]
},
{
Expand All @@ -148,7 +148,7 @@
"outputs": [],
"source": [
"# generate a smab bernoulli in cold start settings\n",
"mab = create_smab_bernoulli_cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
"mab = SmabBernoulli.cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
]
},
{
Expand All @@ -171,14 +171,14 @@
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n",
" \u001b[32m'a1'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m}\u001b[0m,\n",
" \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[1m)\u001b[0m\n"
"\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n",
" \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n",
" \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m\n",
" \u001B[1m}\u001B[0m,\n",
" \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n",
"\u001B[1m)\u001B[0m\n"
]
},
"metadata": {},
Expand Down Expand Up @@ -424,14 +424,14 @@
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n",
" \u001b[32m'a1'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m1\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m2\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m3\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m3\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m}\u001b[0m,\n",
" \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[1m)\u001b[0m\n"
"\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n",
" \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n",
" \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m1\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m2\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m1\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m3\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m3\u001B[0m\u001B[1m)\u001B[0m\n",
" \u001B[1m}\u001B[0m,\n",
" \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n",
"\u001B[1m)\u001B[0m\n"
]
},
"metadata": {},
Expand Down Expand Up @@ -496,14 +496,14 @@
"</pre>\n"
],
"text/plain": [
"\u001b[1;35mSmabBernoulli\u001b[0m\u001b[1m(\u001b[0m\n",
" \u001b[33mactions\u001b[0m=\u001b[1m{\u001b[0m\n",
" \u001b[32m'a1'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m337\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m369\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a3'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m4448\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m4315\u001b[0m\u001b[1m)\u001b[0m,\n",
" \u001b[32m'a2'\u001b[0m: \u001b[1;35mBeta\u001b[0m\u001b[1m(\u001b[0m\u001b[33mn_successes\u001b[0m=\u001b[1;36m246\u001b[0m, \u001b[33mn_failures\u001b[0m=\u001b[1;36m296\u001b[0m\u001b[1m)\u001b[0m\n",
" \u001b[1m}\u001b[0m,\n",
" \u001b[33mstrategy\u001b[0m=\u001b[1;35mClassicBandit\u001b[0m\u001b[1m(\u001b[0m\u001b[1m)\u001b[0m\n",
"\u001b[1m)\u001b[0m\n"
"\u001B[1;35mSmabBernoulli\u001B[0m\u001B[1m(\u001B[0m\n",
" \u001B[33mactions\u001B[0m=\u001B[1m{\u001B[0m\n",
" \u001B[32m'a1'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m337\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m369\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a3'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m4448\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m4315\u001B[0m\u001B[1m)\u001B[0m,\n",
" \u001B[32m'a2'\u001B[0m: \u001B[1;35mBeta\u001B[0m\u001B[1m(\u001B[0m\u001B[33mn_successes\u001B[0m=\u001B[1;36m246\u001B[0m, \u001B[33mn_failures\u001B[0m=\u001B[1;36m296\u001B[0m\u001B[1m)\u001B[0m\n",
" \u001B[1m}\u001B[0m,\n",
" \u001B[33mstrategy\u001B[0m=\u001B[1;35mClassicBandit\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m\n",
"\u001B[1m)\u001B[0m\n"
]
},
"metadata": {},
Expand Down
Loading

0 comments on commit 5bcb9b5

Please sign in to comment.