Refactor MAB and Strategy Classes with Cold Start Methods and Enhanced Validation #35

Merged
1 commit merged on Sep 26, 2024
2 changes: 1 addition & 1 deletion .github/workflows/continuous_delivery.yml
@@ -11,7 +11,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ 3.8, 3.9 ]
python-version: [ "3.8", "3.9", "3.10" ]

steps:
- name: Checkout repository
2 changes: 1 addition & 1 deletion .github/workflows/continuous_integration.yml
@@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.8, 3.9]
python-version: [ "3.8", "3.9", "3.10" ]

steps:
- name: Checkout repository
8 changes: 3 additions & 5 deletions docs/tutorials/mab.ipynb
@@ -20,7 +20,7 @@
"from rich import print\n",
"\n",
"from pybandits.model import Beta\n",
"from pybandits.smab import SmabBernoulli, create_smab_bernoulli_cold_start"
"from pybandits.smab import SmabBernoulli"
]
},
{
@@ -73,8 +73,6 @@
"metadata": {},
"outputs": [],
"source": [
"n_objectives = 2\n",
"\n",
"mab = SmabBernoulli(\n",
" actions={\n",
" \"a1\": Beta(n_successes=1, n_failures=1),\n",
@@ -137,7 +135,7 @@
"id": "564914fd-73cc-4854-8ec7-548970f794a6",
"metadata": {},
"source": [
"You can initialize the bandit via the utility function `create_smab_bernoulli_mo_cc_cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
"You can initialize the bandit via the utility function `SmabBernoulliMOCC.cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
]
},
{
@@ -148,7 +146,7 @@
"outputs": [],
"source": [
"# generate a smab bernoulli in cold start settings\n",
"mab = create_smab_bernoulli_cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
"mab = SmabBernoulli.cold_start(action_ids=[\"a1\", \"a2\", \"a3\"])"
]
},
{
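
For context, the notebook edits above replace the module-level factory with a classmethod on the bandit class. A minimal before/after sketch, using only the imports and calls shown in this diff (in the cold-start case all Beta priors start at n_successes=1 and n_failures=1, as the notebook text states):

# before this PR: module-level factory
# from pybandits.smab import create_smab_bernoulli_cold_start
# mab = create_smab_bernoulli_cold_start(action_ids=["a1", "a2", "a3"])

# after this PR: classmethod on the bandit class
from pybandits.smab import SmabBernoulli

mab = SmabBernoulli.cold_start(action_ids=["a1", "a2", "a3"])
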
9 changes: 3 additions & 6 deletions docs/tutorials/smab_mo_cc.ipynb
@@ -20,7 +20,7 @@
"from rich import print\n",
"\n",
"from pybandits.model import Beta, BetaMOCC\n",
"from pybandits.smab import SmabBernoulliMOCC, create_smab_bernoulli_mo_cc_cold_start"
"from pybandits.smab import SmabBernoulliMOCC"
]
},
{
@@ -72,8 +72,6 @@
"metadata": {},
"outputs": [],
"source": [
"n_objectives = 2\n",
"\n",
"mab = SmabBernoulliMOCC(\n",
" actions={\n",
" \"a1\": BetaMOCC(counters=[Beta(n_successes=1, n_failures=1), Beta(n_successes=1, n_failures=1)], cost=30),\n",
@@ -153,7 +151,7 @@
"id": "564914fd-73cc-4854-8ec7-548970f794a6",
"metadata": {},
"source": [
"You can initialize the bandit via the utility function `create_smab_bernoulli_mo_cc_cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
"You can initialize the bandit via the utility function `SmabBernoulliMOCC.cold_start()`. This is particulary useful in a cold start setting when there is no prior knowledge on the Beta distruibutions. In this case for all Betas `n_successes` and `n_failures` are set to `1`."
]
},
{
@@ -165,10 +163,9 @@
"source": [
"# list of action IDs with their cost\n",
"action_ids_cost = {\"a1\": 30, \"a2\": 10, \"a3\": 20}\n",
"n_objectives = 2\n",
"\n",
"# generate a smab bernoulli in cold start settings\n",
"mab = create_smab_bernoulli_mo_cc_cold_start(action_ids_cost=action_ids_cost, n_objectives=n_objectives)"
"mab = SmabBernoulliMOCC.cold_start(action_ids_cost=action_ids_cost)"
]
},
{
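
The multi-objective cost-control notebook follows the same pattern: the cold-start call now takes only the action-to-cost mapping, and the explicit n_objectives argument is dropped from the cell (how the number of objectives is determined after this change is not visible in this diff). A minimal sketch based on the cell above:

from pybandits.smab import SmabBernoulliMOCC

# list of action IDs with their cost, as in the notebook cell
action_ids_cost = {"a1": 30, "a2": 10, "a3": 20}

# generate a smab bernoulli MO-CC bandit in cold start settings
mab = SmabBernoulliMOCC.cold_start(action_ids_cost=action_ids_cost)
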
238 changes: 12 additions & 226 deletions pybandits/base.py
@@ -21,241 +21,27 @@
# SOFTWARE.


from abc import ABC, abstractmethod
from typing import Any, Dict, List, NewType, Optional, Set, Tuple, Union
from typing import Dict, List, NewType, Tuple, Union

import numpy as np
from pydantic import (
BaseModel,
NonNegativeInt,
confloat,
conint,
constr,
field_validator,
model_validator,
validate_call,
)
from pydantic import BaseModel, confloat, conint, constr

ActionId = NewType("ActionId", constr(min_length=1))
Float01 = NewType("Float_0_1", confloat(ge=0, le=1))
Probability = NewType("Probability", Float01)
Predictions = NewType("Predictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]]])
SmabPredictions = NewType("SmabPredictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]]])
CmabPredictions = NewType(
"CmabPredictions", Tuple[List[ActionId], List[Dict[ActionId, Probability]], List[Dict[ActionId, float]]]
)
Predictions = NewType("Predictions", Union[SmabPredictions, CmabPredictions])
BinaryReward = NewType("BinaryReward", conint(ge=0, le=1))
ActionRewardLikelihood = NewType(
"ActionRewardLikelihood",
Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
)
ACTION_IDS_PREFIX = "action_ids_"


class PyBanditsBaseModel(BaseModel, extra="forbid"):
"""
BaseModel of the PyBandits library.
"""


class Model(PyBanditsBaseModel, ABC):
"""
Class to model the prior distributions.
"""

@abstractmethod
def sample_proba(self) -> Probability:
"""
Sample the probability of getting a positive reward.
"""

@abstractmethod
def update(self, rewards: List[Any]):
"""
Update the model parameters.
"""


class Strategy(PyBanditsBaseModel, ABC):
"""
Strategy to select actions in multi-armed bandits.
"""

@abstractmethod
def select_action(self, p: Dict[ActionId, Probability], actions: Optional[Dict[ActionId, Model]]) -> ActionId:
"""
Select the action.
"""


class BaseMab(PyBanditsBaseModel, ABC):
"""
Multi-armed bandit superclass.

Parameters
----------
actions: Dict[ActionId, Model]
The list of possible actions, and their associated Model.
strategy: Strategy
The strategy used to select actions.
epsilon: Optional[Float01]
The probability of selecting a random action.
default_action: Optional[ActionId]
The default action to select with a probability of epsilon when using the epsilon-greedy approach.
If `default_action` is None, a random action from the action set will be selected with a probability of epsilon.
"""

actions: Dict[ActionId, Model]
strategy: Strategy
epsilon: Optional[Float01]
default_action: Optional[ActionId]

@field_validator("actions", mode="before")
@classmethod
def at_least_2_actions_are_defined(cls, v):
# validate that at least 2 actions are defined
if len(v) < 2:
raise AttributeError("At least 2 actions should be defined.")
# validate that all actions are of the same configuration
action_models = list(v.values())
first_action = action_models[0]
first_action_type = type(first_action)
if any(not isinstance(action, first_action_type) for action in action_models[1:]):
raise AttributeError("All actions should follow the same type.")

return v

@model_validator(mode="after")
def check_default_action(self):
if not self.epsilon and self.default_action:
raise AttributeError("A default action should only be defined when epsilon is defined.")
if self.default_action and self.default_action not in self.actions:
raise AttributeError("The default action should be defined in the actions.")
return self

def _get_valid_actions(self, forbidden_actions: Optional[Set[ActionId]]) -> Set[ActionId]:
"""
Given a set of forbidden action IDs, return a set of valid action IDs.

Parameters
----------
forbidden_actions: Optional[Set[ActionId]]
The set of forbidden action IDs.

Returns
-------
valid_actions: Set[ActionId]
The list of valid (i.e. not forbidden) action IDs.
"""
if forbidden_actions is None:
forbidden_actions = set()

if not all(a in self.actions.keys() for a in forbidden_actions):
raise ValueError("forbidden_actions contains invalid action IDs.")
valid_actions = set(self.actions.keys()) - forbidden_actions
if len(valid_actions) == 0:
raise ValueError("All actions are forbidden. You must allow at least 1 action.")
if self.default_action and self.default_action not in valid_actions:
raise ValueError("The default action is forbidden.")

return valid_actions

def _check_update_params(self, actions: List[ActionId], rewards: List[Union[NonNegativeInt, List[NonNegativeInt]]]):
"""
Verify that the given list of action IDs is a subset of the currently defined actions.

Parameters
----------
actions : List[ActionId]
The selected action for each sample.
rewards: List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
"""
invalid = set(actions) - set(self.actions.keys())
if invalid:
raise AttributeError(f"The following invalid action(s) were specified: {invalid}.")
if len(actions) != len(rewards):
raise AttributeError(f"Shape mismatch: actions and rewards should have the same length {len(actions)}.")

@abstractmethod
@validate_call
def update(self, actions: List[ActionId], rewards: List[Union[BinaryReward, List[BinaryReward]]], *args, **kwargs):
"""
Update the stochastic multi-armed bandit model.

actions: List[ActionId]
The selected action for each sample.
rewards: List[Union[BinaryReward, List[BinaryReward]]]
The reward for each sample.
"""

@abstractmethod
@validate_call
def predict(self, forbidden_actions: Optional[Set[ActionId]] = None):
"""
Predict actions.

Parameters
----------
forbidden_actions : Optional[Set[ActionId]], default=None
Set of forbidden actions. If specified, the model will discard the forbidden_actions and it will only
consider the remaining allowed_actions. By default, the model considers all actions as allowed_actions.
Note that: actions = allowed_actions U forbidden_actions.

Returns
-------
actions: List[ActionId] of shape (n_samples,)
The actions selected by the multi-armed bandit model.
probs: List[Dict[ActionId, float]] of shape (n_samples,)
The probabilities of getting a positive reward for each action.
"""

def get_state(self) -> (str, dict):
"""
Access the complete model internal state, enough to create an exact copy of the same model from it.
Returns
-------
model_class_name: str
The name of the class of the model.
model_state: dict
The internal state of the model (actions, scores, etc.).
"""
model_name = self.__class__.__name__
state: dict = self.dict()
return model_name, state

@validate_call
def _select_epsilon_greedy_action(
self,
p: Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]],
actions: Optional[Dict[ActionId, Model]] = None,
) -> ActionId:
"""
Wraps self.strategy.select_action function with epsilon-greedy strategy,
such that with probability epsilon a default_action is selected,
and with probability 1-epsilon the select_action function is triggered to choose action.
If no default_action is provided, a random action is selected.

Reference: Reinforcement Learning: An Introduction, Ch. 2 (Sutton and Barto, 2018)
https://web.stanford.edu/class/psych209/Readings/SuttonBartoIPRLBook2ndEd.pdf

Parameters
----------
p: Union[Dict[ActionId, float], Dict[ActionId, Probability], Dict[ActionId, List[Probability]]]
The dictionary of actions and their sampled probability of getting a positive reward.
For MO strategy, the sampled probability is a list with elements corresponding to the objectives.
actions: Optional[Dict[ActionId, Model]]
The dictionary of actions and their associated Model.

Returns
-------
selected_action: ActionId
The selected action.

Raises
------
KeyError
If self.default_action is not present as a key in the probabilities dictionary.
"""

if self.epsilon:
if self.default_action and self.default_action not in p.keys():
raise KeyError(f"Default action {self.default_action} not in actions.")
if np.random.binomial(1, self.epsilon):
selected_action = self.default_action if self.default_action else np.random.choice(list(p.keys()))
else:
selected_action = self.strategy.select_action(p=p, actions=actions)
else:
selected_action = self.strategy.select_action(p=p, actions=actions)
return selected_action
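
The Model, Strategy, and BaseMab definitions are removed from base.py in this diff (the remainder of the file change is collapsed, so their new location is not visible here), but the removed block documents the epsilon-greedy wrapper clearly. A standalone sketch of that selection logic, with strategy_select standing in for self.strategy.select_action:

import numpy as np

def select_epsilon_greedy_action(p, strategy_select, epsilon=None, default_action=None):
    # with probability epsilon, fall back to the default action
    # (or a uniformly random action if no default is configured)
    if epsilon:
        if default_action and default_action not in p:
            raise KeyError(f"Default action {default_action} not in actions.")
        if np.random.binomial(1, epsilon):
            return default_action if default_action else np.random.choice(list(p.keys()))
    # otherwise (or when epsilon is unset), defer to the strategy
    return strategy_select(p)
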