Skip to content

Commit

Permalink
Refactor MAB and Strategy Classes with Cold Start Methods and Enhance…
Browse files Browse the repository at this point in the history
…d Validation (#63)

Refactor MAB and Strategy Classes with Cold Start Methods and Enhanced Validation (#35)

### Changes
 * Moved Strategy, Model, and MAB to strategy.py, model.py, and to the new mab.py. base.py is now only for definitions and abstract PyBanditsBaseModel. The abstract MAB now allows for all childs to either accept strategy instance as parameter, or to get the strategy parameters and instantiate correspondingly.
 * The from_state functionality is now directly inherited by all MABs from BaseMab.
 * Replaced all cold_start methods in cmab.py and smab.py with cold_start stemming from BaseMab. Correspondingly, updated test cases to use the new cold_start_instantiate methods.
 * Introduced numerize_field and get_expected_value_from_state methods in the Strategy class to handle default values and state extraction. Added field_validator for exploit_p in BestActionIdentification and subsidy_factor in CostControlBandit to ensure proper default handling and validation.
 * Merged common functionality into a new CostControlStrategy abstract class, which is now inherited by CostControlBandit and MultiObjectiveCostControlBandit. Simplified the select_action methods by using helper methods like _evaluate_and_select and _reduce.
 * Plugged get_pareto_front into a new MultiObjectiveStrategy abstract class, which is now inherited by MultiObjectiveBandit and MultiObjectiveCostControlBandit.
 * In model.py. Removed the redundant BaseBetaMO and BaseBayesianLogisticRegression. Added cold_start_instantiate method to BetaMO and BayesianLogisticRegression models.
 * Added extract_argument_names_from_function under utils.py to allow extract function parameter names by handle.
 * Changed test_base.py into test_mab.py.
 * Updated deprecated linter settings in pyproject.toml.
 * Added test_smab_mo_cc_update test on test_smab.py.
 * Changed version to 1.0.0 on pyproject.toml.
  • Loading branch information
shaharbar1 authored Sep 26, 2024
1 parent 9c15f78 commit e90f4bc
Show file tree
Hide file tree
Showing 17 changed files with 1,161 additions and 1,326 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/continuous_delivery.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ 3.8, 3.9 ]
python-version: [ "3.8", "3.9", "3.10" ]

steps:
- name: Checkout repository
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/continuous_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9]
python-version: [ "3.8", "3.9", "3.10" ]

steps:
- name: Checkout repository
Expand Down
8 changes: 3 additions & 5 deletions docs/tutorials/mab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"from rich import print\n",
"\n",
"from pybandits.model import Beta\n",
"from pybandits.smab import SmabBernoulli, create_smab_bernoulli_cold_start"
"from pybandits.smab import SmabBernoulli"
]
},
{
Expand Down Expand Up @@ -73,8 +73,6 @@
"metadata": {},
"outputs": [],
"source": [
"n_objectives = 2\n",
"\n",
"mab = SmabBernoulli(\n",
" actions={\n",
" \"a1\": Beta(n_successes=1, n_failures=1),\n",
Expand Down Expand Up @@ -137,7 +135,7 @@
"id": "564914fd-73cc-4854-8ec7-548970f794a6",
"metadata": {},
"source": [
"You can initialize the bandit via the utility function `create_smab_bernoulli_mo_cc_cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
"You can initialize the bandit via the utility function `SmabBernoulliMOCC.cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
]
},
{
Expand All @@ -148,7 +146,7 @@
"outputs": [],
"source": [
"# generate a smab bernoulli in cold start settings\n",
"mab = create_smab_bernoulli_cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
"mab = SmabBernoulli.cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
]
},
{
Expand Down
9 changes: 3 additions & 6 deletions docs/tutorials/smab_mo_cc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"from rich import print\n",
"\n",
"from pybandits.model import Beta, BetaMOCC\n",
"from pybandits.smab import SmabBernoulliMOCC, create_smab_bernoulli_mo_cc_cold_start"
"from pybandits.smab import SmabBernoulliMOCC"
]
},
{
Expand Down Expand Up @@ -72,8 +72,6 @@
"metadata": {},
"outputs": [],
"source": [
"n_objectives = 2\n",
"\n",
"mab = SmabBernoulliMOCC(\n",
" actions={\n",
" \"a1\": BetaMOCC(counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=30),\n",
Expand Down Expand Up @@ -153,7 +151,7 @@
"id": "564914fd-73cc-4854-8ec7-548970f794a6",
"metadata": {},
"source": [
"You can initialize the bandit via the utility function `create_smab_bernoulli_mo_cc_cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
"You can initialize the bandit via the utility function `SmabBernoulliMOCC.cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
]
},
{
Expand All @@ -165,10 +163,9 @@
"source": [
"# list of action IDs with their cost\n",
"action_ids_cost = {\"a1\": 30, \"a2\": 10, \"a3\": 20}\n",
"n_objectives = 2\n",
"\n",
"# generate a smab bernoulli in cold start settings\n",
"mab = create_smab_bernoulli_mo_cc_cold_start(action_ids_cost=action_ids_cost, n_objectives=n_objectives)"
"mab = SmabBernoulliMOCC.cold_start(action_ids_cost=action_ids_cost)"
]
},
{
Expand Down
238 changes: 12 additions & 226 deletions pybandits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,241 +21,27 @@
# SOFTWARE.


from abc import ABC, abstractmethod
from typing import Any, Dict, List, NewType, Optional, Set, Tuple, Union
from typing import Dict, List, NewType, Tuple, Union

import numpy as np
from pydantic import (
BaseModel,
NonNegativeInt,
confloat,
conint,
constr,
field_validator,
model_validator,
validate_call,
)
from pydantic import BaseModel, confloat, conint, constr

ActionId = NewType("ActionId", constr(min_length=1))
Float01 = NewType("Float_0_1", confloat(ge=0, le=1))
Probability = NewType("Probability", Float01)
Predictions = NewType("Predictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]]])
SmabPredictions = NewType("SmabPredictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]]])
CmabPredictions = NewType(
"CmabPredictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]], List[Dict[ActionId, float]]]
)
Predictions = NewType("Predictions", Union[SmabPredictions, CmabPredictions])
BinaryReward = NewType("BinaryReward", conint(ge=0, le=1))
ActionRewardLikelihood = NewType(
"ActionRewardLikelihood",
Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
)
ACTION_IDS_PREFIX = "action_ids_"


class PyBanditsBaseModel(BaseModel, extra="forbid"):
"""
BaseModel of the PyBandits library.
"""


class Model(PyBanditsBaseModel, ABC):
"""
Class to model the prior distributions.
"""

@abstractmethod
def sample_proba(self) -> Probability:
"""
Sample the probability of getting a positive reward.
"""

@abstractmethod
def update(self, rewards: List[Any]):
"""
Update the model parameters.
"""


class Strategy(PyBanditsBaseModel, ABC):
"""
Strategy to select actions in multi-armed bandits.
"""

@abstractmethod
def select_action(self, p: Dict[ActionId, Probability], actions: Optional[Dict[ActionId, Model]]) -> ActionId:
"""
Select the action.
"""


class BaseMab(PyBanditsBaseModel, ABC):
"""
Multi-armed bandit superclass.
Parameters
----------
actions: Dict[ActionId, Model]
The list of possible actions, and their associated Model.
strategy: Strategy
The strategy used to select actions.
epsilon: Optional[Float01]
The probability of selecting a random action.
default_action: Optional[ActionId]
The default action to select with a probability of epsilon when using the epsilon-greedy approach.
If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
"""

actions: Dict[ActionId, Model]
strategy: Strategy
epsilon: Optional[Float01]
default_action: Optional[ActionId]

@field_validator("actions", mode="before")
@classmethod
def at_least_2_actions_are_defined(cls, v):
# validate that at least 2 actions are defined
if len(v) < 2:
raise AttributeError("At least 2 actions should be defined.")
# validate that all actions are of the same configuration
action_models = list(v.values())
first_action = action_models[0]
first_action_type = type(first_action)
if any(not isinstance(action, first_action_type) for action in action_models[1:]):
raise AttributeError("All actions should follow the same type.")

return v

@model_validator(mode="after")
def check_default_action(self):
if not self.epsilon and self.default_action:
raise AttributeError("A default action should only be defined when epsilon is defined.")
if self.default_action and self.default_action not in self.actions:
raise AttributeError("The default action should be defined in the actions.")
return self

def _get_valid_actions(self, forbidden_actions: Optional[Set[ActionId]]) -> Set[ActionId]:
"""
Given a set of forbidden action IDs, return a set of valid action IDs.
Parameters
----------
forbidden_actions: Optional[Set[ActionId]]
The set of forbidden action IDs.
Returns
-------
valid_actions: Set[ActionId]
The list of valid (i.e. not forbidden) action IDs.
"""
if forbidden_actions is None:
forbidden_actions = set()

if not all(a in self.actions.keys() for a in forbidden_actions):
raise ValueError("forbidden_actions contains invalid action IDs.")
valid_actions = set(self.actions.keys()) - forbidden_actions
if len(valid_actions) == 0:
raise ValueError("All actions are forbidden. You must allow at least 1 action.")
if self.default_action and self.default_action not in valid_actions:
raise ValueError("The default action is forbidden.")

return valid_actions

def _check_update_params(self, actions: List[ActionId], rewards: List[Union[NonNegativeInt, List[NonNegativeInt]]]):
"""
Verify that the given list of action IDs is a subset of the currently defined actions.
Parameters
----------
actions : List[ActionId]
The selected action for each sample.
rewards: List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
"""
invalid = set(actions) - set(self.actions.keys())
if invalid:
raise AttributeError(f"The following invalid action(s) were specified: {invalid}.")
if len(actions) != len(rewards):
raise AttributeError(f"Shape mismatch: actions and rewards should have the same length {len(actions)}.")

@abstractmethod
@validate_call
def update(self, actions: List[ActionId], rewards: List[Union[BinaryReward, List[BinaryReward]]], *args, **kwargs):
"""
Update the stochastic multi-armed bandit model.
actions: List[ActionId]
The selected action for each sample.
rewards: List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
"""

@abstractmethod
@validate_call
def predict(self, forbidden_actions: Optional[Set[ActionId]] = None):
"""
Predict actions.
Parameters
----------
forbidden_actions : Optional[Set[ActionId]], default=None
Set of forbidden actions. If specified, the model will discard the forbidden_actions and it will only
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.
Returns
-------
actions: List[ActionId] of shape (n_samples,)
The actions selected by the multi-armed bandit model.
probs: List[Dict[ActionId, float]] of shape (n_samples,)
The probabilities of getting a positive reward for each action.
"""

def get_state(self) -> (str, dict):
"""
Access the complete model internal state, enough to create an exact copy of the same model from it.
Returns
-------
model_class_name: str
The name of the class of the model.
model_state: dict
The internal state of the model (actions, scores, etc.).
"""
model_name = self.__class__.__name__
state: dict = self.dict()
return model_name, state

@validate_call
def _select_epsilon_greedy_action(
self,
p: Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
actions: Optional[Dict[ActionId, Model]] = None,
) -> ActionId:
"""
Wraps self.strategy.select_action function with epsilon-greedy strategy,
such that with probability epsilon a default_action is selected,
and with probability 1-epsilon the select_action function is triggered to choose action.
If no default_action is provided, a random action is selected.
Reference: Reinforcement Learning: An Introduction, Ch. 2 (Sutton and Burto, 2018)
https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf&ved=2ahUKEwjMy8WV9N2HAxVe0gIHHVjjG5sQFnoECEMQAQ&usg=AOvVaw3bKK-Y_1kf6XQVwR-UYrBY
Parameters
----------
p: Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]]
The dictionary or actions and their sampled probability of getting a positive reward.
For MO strategy, the sampled probability is a list with elements corresponding to the objectives.
actions: Optional[Dict[ActionId, Model]]
The dictionary of actions and their associated Model.
Returns
-------
selected_action: ActionId
The selected action.
Raises
------
KeyError
If self.default_action is not present as a key in the probabilities dictionary.
"""

if self.epsilon:
if self.default_action and self.default_action not in p.keys():
raise KeyError(f"Default action {self.default_action} not in actions.")
if np.random.binomial(1, self.epsilon):
selected_action = self.default_action if self.default_action else np.random.choice(list(p.keys()))
else:
selected_action = self.strategy.select_action(p=p, actions=actions)
else:
selected_action = self.strategy.select_action(p=p, actions=actions)
return selected_action
Loading

0 comments on commit e90f4bc

Please sign in to comment.