Skip to content

Commit

Permalink
Added a class for parametric distributions, defined a tabular distribution, and updated the policy distribution algorithm to use it.

Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 440070085
Change-Id: Ib88c20cc855aa8764c074ffa8cdc9bd892013a87
  • Loading branch information
sgirgin authored and lanctot committed Apr 7, 2022
1 parent 7a56efc commit 6e36388
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 139 deletions.
229 changes: 95 additions & 134 deletions open_spiel/python/mfg/algorithms/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Computes the distribution of a policy."""
import collections

from typing import Dict, List, Tuple
from typing import List, Tuple
from open_spiel.python import policy as policy_module
from open_spiel.python.mfg import distribution as distribution_module
from open_spiel.python.mfg import tabular_distribution
from open_spiel.python.mfg.tabular_distribution import DistributionDict
import pyspiel


Expand All @@ -28,116 +28,20 @@ def type_from_states(states):
return types[0]


def state_to_str(state):
  """Returns the string key identifying `state` in distribution dicts."""
  # TODO(author15): Consider switching to state.mean_field_population(). For
  # now, this does not matter in practice since games don't have different
  # observation strings for different player IDs.
  return state.observation_string(pyspiel.PlayerId.DEFAULT_PLAYER_ID)


def forward_actions(
    current_states: List[pyspiel.State], distribution: Dict[str, float],
    actions_and_probs_fn) -> Tuple[List[pyspiel.State], Dict[str, float]]:
  """Applies one action to each current state.

  Args:
    current_states: The states to apply actions on.
    distribution: Current distribution.
    actions_and_probs_fn: Function that maps one state to the corresponding list
      of (action, proba). For decision nodes, this should be the policy, and for
      chance nodes, this should be chance outcomes.

  Returns:
    A pair:
      - new_states: List of new states after applying one action on
        each input state.
      - new_distribution: Probabilities for each of these states.
  """
  next_states = []
  next_distribution = collections.defaultdict(float)
  for state in current_states:
    parent_key = state_to_str(state)
    for action, action_prob in actions_and_probs_fn(state):
      child = state.child(action)
      child_key = state_to_str(child)
      # Each child state is recorded only the first time its key appears.
      if child_key not in next_distribution:
        next_states.append(child)
      next_distribution[child_key] += action_prob * distribution[parent_key]
  return next_states, next_distribution


def one_forward_step(current_states: List[pyspiel.State],
                     distribution: Dict[str, float],
                     policy: policy_module.Policy):
  """Performs one step of the forward equation.

  Namely, this takes as input a list of current state, the current
  distribution, and performs one step of the forward equation, using
  actions coming from the policy or from the chance node
  probabilities, or propagating the distribution to the MFG nodes.

  Args:
    current_states: The states to perform the forward step on. All states are
      assumed to be of the same type.
    distribution: Current distribution.
    policy: Policy that will be used if states are decision nodes.

  Returns:
    A pair:
      - new_states: List of new states after applying one step of the
        forward equation (either performing one action or doing one
        distribution update).
      - new_distribution: Probabilities for each of these states.

  Raises:
    ValueError: If the states are of an unsupported type.
  """
  state_types = type_from_states(current_states)
  if state_types == pyspiel.StateType.CHANCE:
    return forward_actions(current_states, distribution,
                           lambda state: state.chance_outcomes())

  if state_types == pyspiel.StateType.MEAN_FIELD:
    new_states = []
    new_distribution = {}
    for state in current_states:
      dist = [
          # We need to default to 0, since the support requested by
          # the state in `state.distribution_support()` might have
          # states that we might not have reached yet. A probability
          # of 0. should be given for them.
          distribution.get(str_state, 0.)
          for str_state in state.distribution_support()
      ]
      new_state = state.clone()
      new_state.update_distribution(dist)
      new_state_str = state_to_str(new_state)
      if new_state_str not in new_distribution:
        new_states.append(new_state)
        new_distribution[new_state_str] = 0.0
      new_distribution[new_state_str] += distribution.get(
          state_to_str(state), 0)
    return new_states, new_distribution

  if state_types == pyspiel.StateType.DECISION:
    return forward_actions(
        current_states, distribution,
        lambda state: policy.action_probabilities(state).items())

  # Fixed typos in the original message ("Unpexpected state_stypes").
  raise ValueError(
      f"Unexpected state_types: {state_types}, states: {current_states}")


def check_distribution_sum(distribution: Dict[str, float], expected_sum: int):
def _check_distribution_sum(distribution: DistributionDict, expected_sum: int):
"""Sanity check that the distribution sums to a given value."""
sum_state_probabilities = sum(distribution.values())
assert abs(sum_state_probabilities - expected_sum) < 1e-4, (
"Sum of probabilities of all possible states should be the number of "
f"population, it is {sum_state_probabilities}.")


class DistributionPolicy(distribution_module.Distribution):
class DistributionPolicy(tabular_distribution.TabularDistribution):
"""Computes the distribution of a specified strategy."""

def __init__(self, game: pyspiel.Game, policy: policy_module.Policy,
def __init__(self,
game: pyspiel.Game,
policy: policy_module.Policy,
root_state: pyspiel.State = None):
"""Initializes the distribution calculation.
Expand All @@ -153,7 +57,6 @@ def __init__(self, game: pyspiel.Game, policy: policy_module.Policy,
self._root_states = game.new_initial_states()
else:
self._root_states = [root_state]
self.distribution = None
self.evaluate()

def evaluate(self):
Expand All @@ -164,54 +67,112 @@ def evaluate(self):
# Distribution at the current timestep. Maps state strings to
# floats. For each group of states for a given population, these
# floats represent a probability distribution.
current_distribution = {state_to_str(state): 1
for state in current_states}
current_distribution = {
self.state_to_str(state): 1 for state in current_states
}
# List of all distributions computed so far.
all_distributions = [current_distribution]

while type_from_states(current_states) != pyspiel.StateType.TERMINAL:
new_states, new_distribution = one_forward_step(current_states,
current_distribution,
self._policy)
check_distribution_sum(new_distribution, self.game.num_players())
new_states, new_distribution = self._one_forward_step(
current_states, current_distribution, self._policy)
_check_distribution_sum(new_distribution, self.game.num_players())
current_distribution = new_distribution
current_states = new_states
all_distributions.append(new_distribution)

# Merge all per-timestep distributions into `self.distribution`.
self.distribution = {}
for dist in all_distributions:
for state_str, prob in dist.items():
if state_str in self.distribution:
raise ValueError(
f"{state_str} has already been seen in distribution.")
self.distribution[state_str] = prob

def value(self, state):
return self.value_str(state_to_str(state))

def value_str(self, state_str, default_value=None):
"""Return probability of the state encoded by state_str.
def _forward_actions(
    self, current_states: List[pyspiel.State], distribution: DistributionDict,
    actions_and_probs_fn) -> Tuple[List[pyspiel.State], DistributionDict]:
  """Applies one action to each current state.

  Args:
    current_states: The states to apply actions on.
    distribution: Current distribution.
    actions_and_probs_fn: Function that maps one state to the corresponding
      list of (action, proba). For decision nodes, this should be the policy,
      and for chance nodes, this should be chance outcomes.

  Returns:
    A pair:
      - new_states: List of new states after applying one action on
        each input state.
      - new_distribution: Probabilities for each of these states.
  """
  new_states = []
  new_distribution = collections.defaultdict(float)
  for state in current_states:
    state_str = self.state_to_str(state)
    for action, prob in actions_and_probs_fn(state):
      new_state = state.child(action)
      new_state_str = self.state_to_str(new_state)
      # A child state is appended only the first time its key is produced, so
      # `new_states` stays duplicate-free while mass still accumulates below.
      if new_state_str not in new_distribution:
        new_states.append(new_state)
      new_distribution[new_state_str] += prob * distribution[state_str]
  return new_states, new_distribution

def _one_forward_step(self, current_states: List[pyspiel.State],
                      distribution: DistributionDict,
                      policy: policy_module.Policy):
  """Performs one step of the forward equation.

  Namely, this takes as input a list of current state, the current
  distribution, and performs one step of the forward equation, using
  actions coming from the policy or from the chance node
  probabilities, or propagating the distribution to the MFG nodes.

  Args:
    current_states: The states to perform the forward step on. All states are
      assumed to be of the same type.
    distribution: Current distribution.
    policy: Policy that will be used if states are decision nodes.

  Returns:
    A pair:
      - new_states: List of new states after applying one step of the
        forward equation (either performing one action or doing one
        distribution update).
      - new_distribution: Probabilities for each of these states.

  Raises:
    ValueError: If the states are of an unsupported type.
  """
  state_types = type_from_states(current_states)
  if state_types == pyspiel.StateType.CHANCE:
    return self._forward_actions(current_states, distribution,
                                 lambda state: state.chance_outcomes())

  if state_types == pyspiel.StateType.MEAN_FIELD:
    new_states = []
    new_distribution = {}
    for state in current_states:
      dist = [
          # We need to default to 0, since the support requested by
          # the state in `state.distribution_support()` might have
          # states that we might not have reached yet. A probability
          # of 0. should be given for them.
          distribution.get(str_state, 0.)
          for str_state in state.distribution_support()
      ]
      new_state = state.clone()
      new_state.update_distribution(dist)
      new_state_str = self.state_to_str(new_state)
      if new_state_str not in new_distribution:
        new_states.append(new_state)
        new_distribution[new_state_str] = 0.0
      new_distribution[new_state_str] += distribution.get(
          self.state_to_str(state), 0)
    return new_states, new_distribution

  if state_types == pyspiel.StateType.DECISION:
    return self._forward_actions(
        current_states, distribution,
        lambda state: policy.action_probabilities(state).items())

  # Fixed typos in the original message ("Unpexpected state_stypes").
  raise ValueError(
      f"Unexpected state_types: {state_types}, states: {current_states}")
31 changes: 26 additions & 5 deletions open_spiel/python/mfg/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,13 @@
The main way of using a distribution is to call `value(state)`.
"""

import abc
from typing import Any, Optional

class Distribution(object):
import pyspiel


class Distribution(abc.ABC):
"""Base class for distributions.
This represents a probability distribution over the states of a game.
Expand All @@ -31,15 +36,16 @@ class Distribution(object):
game: the game for which this distribution is derives
"""

def __init__(self, game):
def __init__(self, game: pyspiel.Game):
  """Initializes a distribution.

  Args:
    game: the game for which this distribution is derived.
  """
  # The game this distribution refers to; stored for use by subclasses.
  self.game = game

def value(self, state):
@abc.abstractmethod
def value(self, state: pyspiel.State) -> float:
"""Returns the probability of the distribution on the state.
Args:
Expand All @@ -50,7 +56,10 @@ def value(self, state):
"""
raise NotImplementedError()

def value_str(self, state_str, default_value=None):
@abc.abstractmethod
def value_str(self,
state_str: str,
default_value: Optional[float] = None) -> float:
"""Returns the probability of the distribution on the state string given.
Args:
Expand All @@ -63,7 +72,7 @@ def value_str(self, state_str, default_value=None):
"""
raise NotImplementedError()

def __call__(self, state):
def __call__(self, state: pyspiel.State) -> float:
"""Turns the distribution into a callable.
Args:
Expand All @@ -73,3 +82,15 @@ def __call__(self, state):
Float: probability.
"""
return self.value(state)


class ParametricDistribution(Distribution):
  """A distribution represented by a set of parameters.

  Concrete subclasses define the parameter representation returned by
  `get_params` and consumed by `set_params`.
  """

  @abc.abstractmethod
  def get_params(self) -> Any:
    """Returns the distribution parameters."""

  @abc.abstractmethod
  def set_params(self, params: Any):
    """Sets the distribution parameters.

    Args:
      params: New parameters, in the same representation as `get_params`.
    """
Loading

0 comments on commit 6e36388

Please sign in to comment.