Skip to content

Commit

Permalink
better random cpd
Browse files Browse the repository at this point in the history
  • Loading branch information
tom4everitt committed Mar 31, 2021
1 parent 8dbec1a commit 59f05bc
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pycid/random/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from pycid.random.random_cid import random_cid, random_cids # noqa
from pycid.random.random_cpd import random_cpd # noqa
from pycid.random.random_cpd import RandomCPD # noqa
from pycid.random.random_dag import random_dag # noqa
4 changes: 2 additions & 2 deletions pycid/random/random_cid.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pycid.core.cid import CID
from pycid.core.cpd import DecisionDomain
from pycid.core.get_paths import find_active_path
from pycid.random.random_cpd import random_cpd
from pycid.random.random_cpd import RandomCPD

# TODO add a random_macid function

Expand Down Expand Up @@ -40,7 +40,7 @@ def random_cid(
if node in cid.decisions:
cid.add_cpds(DecisionDomain(node, [0, 1]))
else:
cid.add_cpds(random_cpd(node))
cid.add_cpds(RandomCPD(node))
return cid


Expand Down
63 changes: 42 additions & 21 deletions pycid/random/random_cpd.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,52 @@
from typing import List
import contextlib
from typing import Dict, Iterator, List

import numpy as np

from pycid import StochasticFunctionCPD
from pycid.core.cpd import Outcome


def random_cpd(variable: str, domain: List = [0, 1], smoothness: float = 1.0) -> StochasticFunctionCPD:
"""
Sample a random CPD, with outcomes in the given domain
Parameters
----------
variable: Name of variable
@contextlib.contextmanager
def temp_seed(seed: int) -> Iterator[None]:
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)

domain: List of possible outcomes, defaults to [0, 1]

smoothness: How different the probabilities for different probabilities are.
When small (e.g. 0.001), most probability mass falls on a single outcome, and
when large (e.g. 1000), the distribution approaches a uniform distribution.
class RandomCPD(StochasticFunctionCPD):
"""
Sample a random CPD, with outcomes in the given domain
"""

def prob_vector() -> List:
return np.random.dirichlet(np.ones(len(domain)) * smoothness, size=1).flat # type: ignore

return StochasticFunctionCPD(
variable,
lambda **pv: {domain[i]: prob for i, prob in enumerate(prob_vector())},
domain=domain if domain else [0, 1],
label="random cpd",
)
def __init__(self, variable: str, domain: List = [0, 1], smoothness: float = 1.0, seed: int = None) -> None:
"""
Parameters
----------
variable: Name of variable
domain: List of possible outcomes, defaults to [0, 1]
smoothness: How different the probabilities for different probabilities are.
When small (e.g. 0.001), most probability mass falls on a single outcome, and
when large (e.g. 1000), the distribution approaches a uniform distribution.
seed: Set the random seed
"""
self.seed = seed or np.random.randint(0, 10000)
self.smoothness = smoothness

def random_stochastic_function(**pv: Outcome) -> Dict[Outcome, float]:
with temp_seed(self.seed + hash(frozenset(pv.items())) % 2 ** 31 - 1):
prob_vec = np.random.dirichlet(np.ones(len(self.domain)) * self.smoothness, size=1).flat # type: ignore
return {self.domain[i]: prob for i, prob in enumerate(prob_vec)} # type: ignore

super().__init__(
variable,
random_stochastic_function,
domain=domain if domain else [0, 1],
label=f"RandomCPD({self.smoothness}, {self.seed})",
)
5 changes: 2 additions & 3 deletions tests/test_causal_bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import pytest
from pgmpy.factors.discrete import TabularCPD # type: ignore

from pycid import CID, MACID, CausalBayesianNetwork, RandomCPD
from pycid.examples.simple_cbns import get_3node_cbn
from pycid.examples.simple_cids import get_3node_cid, get_minimal_cid
from pycid.examples.story_macids import taxi_competition

from pycid import CID, MACID, CausalBayesianNetwork, random_cpd


@pytest.fixture
def cid_3node() -> CID:
Expand Down Expand Up @@ -61,7 +60,7 @@ def test_query(cbn_3node: CausalBayesianNetwork) -> None:
@staticmethod
def test_query_disconnected_components() -> None:
cbn = CausalBayesianNetwork([("A", "B")])
cbn.add_cpds(random_cpd("A"), random_cpd("B"))
cbn.add_cpds(RandomCPD("A"), RandomCPD("B"))
cbn.query(["A"], {}, intervention={"B": 0}) # the intervention separates A and B into separare components

@staticmethod
Expand Down
13 changes: 12 additions & 1 deletion tests/test_random.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import sys

import pytest
import networkx as nx
import pytest

from pycid import CausalBayesianNetwork, FunctionCPD, RandomCPD
from pycid.random.random_cid import random_cid, random_cids
from pycid.random.random_dag import random_dag

Expand All @@ -22,5 +23,15 @@ def test_random_dag_create_one() -> None:
assert nx.is_directed_acyclic_graph(dag)


def test_random_cpd() -> None:
cbn = CausalBayesianNetwork([("Y", "A"), ("Y", "D")])
cbn.add_cpds(
RandomCPD("Y"),
FunctionCPD("A", lambda y: y),
FunctionCPD("D", lambda y: y),
)
assert cbn.expected_value(["D"], {}, intervention={"A": 0}) == cbn.expected_value(["D"], {}, intervention={"A": 1})


if __name__ == "__main__":
pytest.main(sys.argv)

0 comments on commit 59f05bc

Please sign in to comment.