Skip to content

Commit

Permalink
fix type checks
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamesfox1 committed Sep 18, 2023
1 parent 362b949 commit 76e9f3f
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pycid/analyze/reasoning_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _path_is_effective(mb: MACIDBase, path: List[str], effective_set: Set[str])


def _directed_effective_path_not_through_set_y(
mb: MACIDBase, start: str, finish: str, effective_set: Set[str], y: Set[str] = None
mb: MACIDBase, start: str, finish: str, effective_set: Set[str], y: Optional[Set[str]] = None
) -> bool:
"""Check whether a directed effective path exists that doesn't pass through any of the nodes in the set y."""
if y is None:
Expand Down
12 changes: 6 additions & 6 deletions pycid/core/causal_bayesian_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def to_tabular_cpd(self, variable: str, relationship: Relationship) -> TabularCP
elif isinstance(relationship, Mapping):
return ConstantCPD(variable, relationship, self.cbn)

def __init__(self, edges: Iterable[Tuple[str, str]] = None, **kwargs: Any):
def __init__(self, edges: Optional[Iterable[Tuple[str, str]]] = None, **kwargs: Any):
"""Initialize a Causal Bayesian Network
Parameters
Expand Down Expand Up @@ -129,7 +129,7 @@ def is_structural_causal_model(self) -> bool:
return True

def query(
self, query: Iterable[str], context: Dict[str, Outcome], intervention: Dict[str, Outcome] = None
self, query: Iterable[str], context: Dict[str, Outcome], intervention: Optional[Dict[str, Outcome]] = None
) -> BeliefPropagation:
"""Return P(query|context, do(intervention))*P(context | do(intervention)).
Expand Down Expand Up @@ -194,7 +194,7 @@ def expected_value(
self,
variables: Iterable[str],
context: Dict[str, Outcome],
intervention: Dict[str, Outcome] = None,
intervention: Optional[Dict[str, Outcome]] = None,
) -> List[float]:
"""Compute the expected value of a real-valued variable for a given context,
under an optional intervention
Expand Down Expand Up @@ -266,9 +266,9 @@ def _get_label(self, node: str) -> str:

def draw(
self,
node_color: Callable[[str], Union[str, np.ndarray]] = None,
node_shape: Callable[[str], str] = None,
node_label: Callable[[str], str] = None,
node_color: Optional[Callable[[str], Union[str, np.ndarray]]] = None,
node_shape: Optional[Callable[[str], str]] = None,
node_label: Optional[Callable[[str], str]] = None,
layout: Optional[Callable[[Any], Dict[Any, Any]]] = None,
) -> None:
"""
Expand Down
8 changes: 4 additions & 4 deletions pycid/core/cpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
cbn: CausalBayesianNetwork,
domain: Optional[Sequence[Outcome]] = None,
state_names: Optional[Mapping[str, List]] = None,
label: str = None,
label: Optional[str] = None,
) -> None:
"""Initialize StochasticFunctionCPD with a variable name and a stochastic function.
Expand Down Expand Up @@ -130,7 +130,7 @@ def stochastic_function(self, **pv: Outcome) -> Mapping[Outcome, float]:
else:
return {ret: 1}

def compute_label(self, function: Callable = None) -> str:
def compute_label(self, function: Optional[Callable] = None) -> str:
"""Try to generate a string that succinctly describes the relationship"""
function = function if function is not None else self.func
if hasattr(function, "__name__") and function.__name__ != "<lambda>":
Expand Down Expand Up @@ -224,7 +224,7 @@ def __init__(
variable: str,
dictionary: Mapping,
cbn: CausalBayesianNetwork,
domain: Sequence[Outcome] = None,
domain: Optional[Sequence[Outcome]] = None,
label: Optional[str] = None,
):
super().__init__(variable, lambda **pv: dictionary, cbn, domain=domain, label=label or str(dictionary))
Expand Down Expand Up @@ -267,7 +267,7 @@ def discrete_uniform(domain: List[Outcome]) -> Dict[Outcome, float]:


def noisy_copy(
value: Outcome, probability: float = 0.9, domain: List[Outcome] = None
value: Outcome, probability: float = 0.9, domain: Optional[List[Outcome]] = None
) -> Dict[Outcome, Optional[float]]:
"""specify a variable's CPD as copying the value of some other variable with a certain probability."""
dist = dict.fromkeys(domain) if domain else {}
Expand Down
8 changes: 4 additions & 4 deletions pycid/core/get_paths.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Iterable, Iterator, List, Sequence, Set, Tuple
from typing import Callable, Iterable, Iterator, List, Optional, Sequence, Set, Tuple

import networkx as nx

Expand Down Expand Up @@ -184,7 +184,7 @@ def _get_path_edges(cbn: CausalBayesianNetwork, path: Sequence[str]) -> List[Tup
return structure


def is_active_path(cbn: CausalBayesianNetwork, path: Sequence[str], observed: Set[str] = None) -> bool:
def is_active_path(cbn: CausalBayesianNetwork, path: Sequence[str], observed: Optional[Set[str]] = None) -> bool:
"""
Check if a specifc path remains active given the 'observed' set of variables.
"""
Expand Down Expand Up @@ -213,7 +213,7 @@ def is_active_path(cbn: CausalBayesianNetwork, path: Sequence[str], observed: Se


def is_active_indirect_frontdoor_trail(
cbn: CausalBayesianNetwork, start_node: str, end_node: str, observed: Set[str] = None
cbn: CausalBayesianNetwork, start_node: str, end_node: str, observed: Optional[Set[str]] = None
) -> bool:
"""
checks whether an active indirect frontdoor path exists given the 'observed' set of variables.
Expand All @@ -240,7 +240,7 @@ def is_active_indirect_frontdoor_trail(


def is_active_backdoor_trail(
cbn: CausalBayesianNetwork, start_node: str, end_node: str, observed: Set[str] = None
cbn: CausalBayesianNetwork, start_node: str, end_node: str, observed: Optional[Set[str]] = None
) -> bool:
"""
Returns true if there is a backdoor path that's active given the 'observed' set of nodes.
Expand Down
16 changes: 8 additions & 8 deletions pycid/core/macid_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def to_tabular_cpd(self, variable: str, relationship: Union[Relationship, Sequen

def __init__(
self,
edges: Iterable[Tuple[str, str]] = None,
agent_decisions: Mapping[AgentLabel, List[str]] = None,
agent_utilities: Mapping[AgentLabel, List[str]] = None,
edges: Optional[Iterable[Tuple[str, str]]] = None,
agent_decisions: Optional[Mapping[AgentLabel, List[str]]] = None,
agent_utilities: Optional[Mapping[AgentLabel, List[str]]] = None,
**kwargs: Any,
):
"""Initialize a new MACIDBase instance.
Expand Down Expand Up @@ -142,7 +142,7 @@ def add_cpds(self, *cpds: TabularCPD, **relationships: Union[Relationship, List[
super().add_cpds(*cpds, **relationships)

def query(
self, query: Iterable[str], context: Dict[str, Outcome], intervention: Dict[str, Outcome] = None
self, query: Iterable[str], context: Dict[str, Outcome], intervention: Optional[Dict[str, Outcome]] = None
) -> BeliefPropagation:
"""Return P(query|context, do(intervention))*P(context | do(intervention)).
Expand Down Expand Up @@ -184,7 +184,7 @@ def query(
return super().query(query, context, intervention)

def expected_utility(
self, context: Dict[str, Outcome], intervention: Dict[str, Outcome] = None, agent: AgentLabel = 0
self, context: Dict[str, Outcome], intervention: Optional[Dict[str, Outcome]] = None, agent: AgentLabel = 0
) -> float:
"""Compute the expected utility of an agent for a given context and optional intervention
Expand Down Expand Up @@ -286,15 +286,15 @@ def pure_decision_rules(self, decision: str) -> Iterator[StochasticFunctionCPD]:

# We begin by representing each possible decision rule as a tuple of outcomes, with
# one element for each possible decision context
number_of_decision_contexts = int(np.product(parent_cardinalities))
number_of_decision_contexts = int(np.prod(parent_cardinalities))
functions_as_tuples = itertools.product(domain, repeat=number_of_decision_contexts)

def arg2idx(pv: Dict[str, Outcome]) -> int:
"""Convert a decision context into an index for the function list"""
idx = 0
for i, parent in enumerate(parents):
name_to_no: Dict[Outcome, int] = self.get_cpds(parent).name_to_no[parent]
idx += name_to_no[pv[parent]] * int(np.product(parent_cardinalities[:i]))
idx += name_to_no[pv[parent]] * int(np.prod(parent_cardinalities[:i]))
assert 0 <= idx <= number_of_decision_contexts
return idx

Expand Down Expand Up @@ -436,7 +436,7 @@ def _get_color(self, node: str) -> Union[np.ndarray, str]:
Assign a unique colour to each new agent's decision and utility nodes
"""
agents = list(self.agents)
colors = cm.rainbow(np.linspace(0, 1, len(agents)))
colors = cm.rainbow(np.linspace(0, 1, len(agents))) # type: ignore
try:
agent = self.decision_agent[node]
except KeyError:
Expand Down
6 changes: 3 additions & 3 deletions pycid/core/relevance_graph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Sequence
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence

import matplotlib.cm as cm
import matplotlib.pyplot as plt
Expand All @@ -21,7 +21,7 @@ class RelevanceGraph(nx.DiGraph):
- an edge D -> D' exists iff D' is r-reachable from D (ie D strategically or probabilistically relies on D')
"""

def __init__(self, cid: MACIDBase, decisions: Iterable[str] = None):
def __init__(self, cid: MACIDBase, decisions: Optional[Iterable[str]] = None):
super().__init__()
if decisions is None:
decisions = cid.decisions
Expand All @@ -47,7 +47,7 @@ def get_sccs(self) -> List[set]:

def _set_color_scc(self, node: str, sccs: Sequence[Any]) -> np.ndarray:
"Assign a unique color to the set of nodes in each SCC."
colors = cm.rainbow(np.linspace(0, 1, len(sccs)))
colors = cm.rainbow(np.linspace(0, 1, len(sccs))) # type: ignore
scc_index = 0
for idx, scc in enumerate(sccs):
if node in scc:
Expand Down
8 changes: 4 additions & 4 deletions pycid/examples/story_macids.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,8 @@ def robot_warehouse() -> MACID:
("B", "R"),
("B", "D2"),
("D2", "R"),
("D2", "O"),
("O", "U1"),
("D2", "Ob"),
("Ob", "U1"),
("R", "U2"),
],
agent_decisions={
Expand All @@ -258,8 +258,8 @@ def robot_warehouse() -> MACID:
Q=lambda D1: noisy_copy(D1, domain=[0, 1]),
B=lambda D1: noisy_copy(D1, probability=0.3, domain=[0, 1]),
R=lambda B, D2: int(not B or D2),
O=lambda D2: noisy_copy(D2, probability=0.6, domain=[0, 1]),
U1=lambda Q, B, O: int(Q and not O) - int(B),
Ob=lambda D2: noisy_copy(D2, probability=0.6, domain=[0, 1]),
U1=lambda Q, B, Ob: int(Q and not Ob) - int(B),
U2=lambda R: R,
)
return macid
Expand Down
2 changes: 1 addition & 1 deletion pycid/export/gambit.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def behavior_to_cpd(
macid: MACIDBase,
parents_to_infoset: Mapping[Tuple[Hashable, Tuple[Tuple[str, Any], ...]], pygambit.Infoset],
behavior: pygambit.lib.libgambit.MixedStrategyProfile,
decisions_in_sg: Union[KeysView[str], Set[str]] = None,
decisions_in_sg: Optional[Union[KeysView[str], Set[str]]] = None,
) -> List[StochasticFunctionCPD]:
"""Convert a pygambit behavior strategy to list of CPDs for each decision node.
Args:
Expand Down
6 changes: 4 additions & 2 deletions pycid/random/random_cpd.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import contextlib
from typing import Iterator, Mapping, Sequence
from typing import Iterator, Mapping, Optional, Sequence

import numpy as np

Expand All @@ -23,7 +23,9 @@ class RandomCPD:
Sample a random CPD, with outcomes in the given domain
"""

def __init__(self, domain: Sequence[Outcome] = None, smoothness: float = 1.0, seed: int = None) -> None:
def __init__(
self, domain: Optional[Sequence[Outcome]] = None, smoothness: float = 1.0, seed: Optional[int] = None
) -> None:
"""
Parameters
----------
Expand Down

0 comments on commit 76e9f3f

Please sign in to comment.