-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8dbec1a
commit 59f05bc
Showing
5 changed files
with
59 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from pycid.random.random_cid import random_cid, random_cids # noqa | ||
from pycid.random.random_cpd import random_cpd # noqa | ||
from pycid.random.random_cpd import RandomCPD # noqa | ||
from pycid.random.random_dag import random_dag # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,52 @@ | ||
from typing import List | ||
import contextlib | ||
from typing import Dict, Iterator, List | ||
|
||
import numpy as np | ||
|
||
from pycid import StochasticFunctionCPD | ||
from pycid.core.cpd import Outcome | ||
|
||
|
||
def random_cpd(variable: str, domain: List = [0, 1], smoothness: float = 1.0) -> StochasticFunctionCPD: | ||
""" | ||
Sample a random CPD, with outcomes in the given domain | ||
Parameters | ||
---------- | ||
variable: Name of variable | ||
@contextlib.contextmanager | ||
def temp_seed(seed: int) -> Iterator[None]: | ||
state = np.random.get_state() | ||
np.random.seed(seed) | ||
try: | ||
yield | ||
finally: | ||
np.random.set_state(state) | ||
|
||
domain: List of possible outcomes, defaults to [0, 1] | ||
|
||
smoothness: How different the probabilities for different probabilities are. | ||
When small (e.g. 0.001), most probability mass falls on a single outcome, and | ||
when large (e.g. 1000), the distribution approaches a uniform distribution. | ||
class RandomCPD(StochasticFunctionCPD): | ||
""" | ||
Sample a random CPD, with outcomes in the given domain | ||
""" | ||
|
||
def prob_vector() -> List: | ||
return np.random.dirichlet(np.ones(len(domain)) * smoothness, size=1).flat # type: ignore | ||
|
||
return StochasticFunctionCPD( | ||
variable, | ||
lambda **pv: {domain[i]: prob for i, prob in enumerate(prob_vector())}, | ||
domain=domain if domain else [0, 1], | ||
label="random cpd", | ||
) | ||
def __init__(self, variable: str, domain: List = [0, 1], smoothness: float = 1.0, seed: int = None) -> None: | ||
""" | ||
Parameters | ||
---------- | ||
variable: Name of variable | ||
domain: List of possible outcomes, defaults to [0, 1] | ||
smoothness: How different the probabilities for different probabilities are. | ||
When small (e.g. 0.001), most probability mass falls on a single outcome, and | ||
when large (e.g. 1000), the distribution approaches a uniform distribution. | ||
seed: Set the random seed | ||
""" | ||
self.seed = seed or np.random.randint(0, 10000) | ||
self.smoothness = smoothness | ||
|
||
def random_stochastic_function(**pv: Outcome) -> Dict[Outcome, float]: | ||
with temp_seed(self.seed + hash(frozenset(pv.items())) % 2 ** 31 - 1): | ||
prob_vec = np.random.dirichlet(np.ones(len(self.domain)) * self.smoothness, size=1).flat # type: ignore | ||
return {self.domain[i]: prob for i, prob in enumerate(prob_vec)} # type: ignore | ||
|
||
super().__init__( | ||
variable, | ||
random_stochastic_function, | ||
domain=domain if domain else [0, 1], | ||
label=f"RandomCPD({self.smoothness}, {self.seed})", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters