Skip to content

Commit

Permalink
various minor
Browse files Browse the repository at this point in the history
  • Loading branch information
tom4everitt committed Mar 31, 2021
1 parent 59f05bc commit 1ab2c5a
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
4 changes: 2 additions & 2 deletions pycid/core/cpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import types
from inspect import getsourcelines
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union, Iterator
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, List, Optional, Sequence, Union

import numpy as np
from pgmpy.factors.discrete import TabularCPD # type: ignore
Expand Down Expand Up @@ -217,7 +217,7 @@ def __repr__(self) -> str:
mapping = "\n".join([str(key) + " -> " + str(dictionary[key]) for key in sorted(list(dictionary.keys()))])
else:
mapping = ""
return f"<FunctionCPD {self.variable}:{self.stochastic_function}> \n{mapping}"
return f"{type(self).__name__}<{self.variable}:{self.stochastic_function}> \n{mapping}"

def __str__(self) -> str:
return self.__repr__()
Expand Down
11 changes: 9 additions & 2 deletions pycid/random/random_cpd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import contextlib
from typing import Dict, Iterator, List
from typing import Dict, Iterator, Sequence

import numpy as np

Expand All @@ -22,7 +24,9 @@ class RandomCPD(StochasticFunctionCPD):
Sample a random CPD, with outcomes in the given domain
"""

def __init__(self, variable: str, domain: List = [0, 1], smoothness: float = 1.0, seed: int = None) -> None:
def __init__(
self, variable: str, domain: Sequence[Outcome] = [0, 1], smoothness: float = 1.0, seed: int = None
) -> None:
"""
Parameters
----------
Expand Down Expand Up @@ -50,3 +54,6 @@ def random_stochastic_function(**pv: Outcome) -> Dict[Outcome, float]:
domain=domain if domain else [0, 1],
label=f"RandomCPD({self.smoothness}, {self.seed})",
)

def copy(self) -> RandomCPD:
return RandomCPD(self.variable, self.domain, self.smoothness, self.seed) # type: ignore
13 changes: 7 additions & 6 deletions tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ def test_random_dag_create_one() -> None:
assert nx.is_directed_acyclic_graph(dag)


def test_random_cpd() -> None:
cbn = CausalBayesianNetwork([("Y", "A"), ("Y", "D")])
def test_random_cpd_copy() -> None:
"""check that a copy of a random cpd yields the same distribution"""
cbn = CausalBayesianNetwork([("A", "B")])
cbn.add_cpds(
RandomCPD("Y"),
FunctionCPD("A", lambda y: y),
FunctionCPD("D", lambda y: y),
RandomCPD("A"),
FunctionCPD("B", lambda a: a),
)
assert cbn.expected_value(["D"], {}, intervention={"A": 0}) == cbn.expected_value(["D"], {}, intervention={"A": 1})
cbn2 = cbn.copy()
assert cbn.expected_value(["B"], {}) == cbn2.expected_value(["B"], {})


if __name__ == "__main__":
Expand Down

0 comments on commit 1ab2c5a

Please sign in to comment.