Skip to content

Commit

Permalink
test(graph): add basic tests for the graph itself (+ docs, restructur…
Browse files Browse the repository at this point in the history
…ing, …)
  • Loading branch information
LilithWittmann committed Oct 24, 2023
1 parent b5e9e5c commit d26bd20
Show file tree
Hide file tree
Showing 10 changed files with 421 additions and 93 deletions.
1 change: 1 addition & 0 deletions causy/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .pc import PC, ParallelPC
File renamed without changes.
2 changes: 1 addition & 1 deletion causy/exit_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class ExitOnNoActions(ExitConditionInterface):
def check(self, graph, graph_model_instance_, actions_taken, iteration):
def check(self, graph, graph_model_instance_, actions_taken, iteration) -> bool:
"""
Check if there are no actions taken in the last iteration and if so, break the loop
If it is the first iteration, do not break the loop (we need to execute the first step)
Expand Down
168 changes: 123 additions & 45 deletions causy/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,17 @@ def __hash__(self):
return hash(self.id)


class UndirectedGraphError(Exception):
class GraphError(Exception):
pass


class UndirectedGraph(BaseGraphInterface):
class Graph(BaseGraphInterface):
"""
The graph represents the internal data structure of causy. It is a simple graph with nodes and edges.
But it supports to be handled as a directed graph, undirected graph and bidirected graph, which is important to implement different algorithms in different stages.
It also stores the history of the actions taken on the graph.
"""

nodes: Dict[str, Node]
edges: Dict[str, Dict[str, Dict]]
edge_history: Dict[Tuple[str, str], List[TestResult]]
Expand All @@ -60,10 +66,15 @@ def add_edge(self, u: Node, v: Node, value: Dict):
:param v: v node
:return:
"""

if u.id not in self.nodes:
raise UndirectedGraphError(f"Node {u} does not exist")
raise GraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
raise GraphError(f"Node {v} does not exist")

if u.id == v.id:
raise GraphError("Self loops are currently not allowed")

if u.id not in self.edges:
self.edges[u.id] = {}
if v.id not in self.edges:
Expand Down Expand Up @@ -94,33 +105,34 @@ def retrieve_edge_history(
return [i for i in self.edge_history[(u.id, v.id)] if i.action == action]

def add_edge_history(self, u, v, action: TestResult):
"""
Add an action to the edge history
:param u:
:param v:
:param action:
:return:
"""
if (u.id, v.id) not in self.edge_history:
self.edge_history[(u.id, v.id)] = []
self.edge_history[(u.id, v.id)].append(action)

def remove_edge(self, u: Node, v: Node):
"""
Remove an edge from the graph
Remove an edge from the graph (undirected)
:param u: u node
:param v: v node
:return:
"""
if u.id not in self.nodes:
raise UndirectedGraphError(f"Node {u} does not exist")
raise GraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
if u.id not in self.edges:
raise UndirectedGraphError(f"Node {u} does not have any nodes")
if v.id not in self.edges:
raise UndirectedGraphError(f"Node {v} does not have any nodes")
raise GraphError(f"Node {v} does not exist")

if v.id not in self.edges[u.id]:
return
del self.edges[u.id][v.id]
if u.id in self.edges and v.id in self.edges[u.id]:
del self.edges[u.id][v.id]

if u.id not in self.edges[v.id]:
return
del self.edges[v.id][u.id]
if v.id in self.edges and u.id in self.edges[v.id]:
del self.edges[v.id][u.id]

def remove_directed_edge(self, u: Node, v: Node):
"""
Expand All @@ -130,33 +142,39 @@ def remove_directed_edge(self, u: Node, v: Node):
:return:
"""
if u.id not in self.nodes:
raise UndirectedGraphError(f"Node {u} does not exist")
raise GraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
if u.id not in self.edges:
raise UndirectedGraphError(f"Node {u} does not have any nodes")
if v.id not in self.edges:
raise UndirectedGraphError(f"Node {v} does not have any nodes")
raise GraphError(f"Node {v} does not exist")

if u.id not in self.edges:
return # no edges from u
if v.id not in self.edges[u.id]:
return

del self.edges[u.id][v.id]

def update_edge(self, u: Node, v: Node, value: Dict):
"""
Update an edge in the graph
Update an undirected edge in the graph
:param u: u node
:param v: v node
:return:
"""

if u.id not in self.nodes:
raise UndirectedGraphError(f"Node {u} does not exist")
raise GraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
raise GraphError(f"Node {v} does not exist")
if u.id not in self.edges:
raise UndirectedGraphError(f"Node {u} does not have any edges")
raise GraphError(f"Node {u} does not have any edges")
if v.id not in self.edges:
raise UndirectedGraphError(f"Node {v} does not have any edges")
raise GraphError(f"Node {v} does not have any edges")

if u.id not in self.edges[v.id]:
raise GraphError(f"There is no edge from {u} to {v}")

if v.id not in self.edges[u.id]:
raise GraphError(f"There is no edge from {v} to {u}")

self.edges[u.id][v.id] = value
self.edges[v.id][u.id] = value
Expand All @@ -168,16 +186,16 @@ def update_directed_edge(self, u: Node, v: Node, value: Dict):
:param v: v node
:return:
"""
if u.name not in self.nodes:
raise UndirectedGraphError(f"Node {u} does not exist")
if v.name not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
if u not in self.edges:
raise UndirectedGraphError(f"Node {u} does not have any edges")
if v not in self.edges[u]:
raise UndirectedGraphError(f"There is no edge from {u} to {v}")
if u.id not in self.nodes:
raise GraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise GraphError(f"Node {v} does not exist")
if u.id not in self.edges:
raise GraphError(f"Node {u} does not have any edges")
if v.id not in self.edges[u.id]:
raise GraphError(f"There is no edge from {u} to {v}")

self.edges[u][v] = value
self.edges[u.id][v.id] = value

def edge_exists(self, u: Node, v: Node):
"""
Expand Down Expand Up @@ -254,7 +272,19 @@ def bidirected_edge_exists(self, u: Node, v: Node):
return True
return False

def edge_value(self, u: Node, v: Node):
def edge_value(self, u: Node, v: Node) -> Optional[Dict]:
"""
retrieve the value of an edge
:param u:
:param v:
:return:
"""

if u.id not in self.edges:
return None
if v.id not in self.edges[u.id]:
return None

return self.edges[u.id][v.id]

def add_node(self, name: str, values: List[float], id_: str = None) -> Node:
Expand All @@ -269,9 +299,17 @@ def add_node(self, name: str, values: List[float], id_: str = None) -> Node:
"""
if id_ is None:
id_ = str(uuid4())
node = Node(
name=name, id=id_, values=torch.tensor(values, dtype=torch.float64)
)

if id_ in self.nodes:
raise ValueError(f"Node with id {id_} already exists")

try:
tensor_values = torch.tensor(values, dtype=torch.float64)
except TypeError as e:
raise ValueError(f"Currently only numeric values are supported. {e}")

node = Node(name=name, id=id_, values=tensor_values)

self.nodes[id_] = node
return node

Expand Down Expand Up @@ -330,6 +368,18 @@ def unpack_run(args):


class AbstractGraphModel(GraphModelInterface, ABC):
"""
The graph model is the main class of causy. It is responsible for creating a graph from data and executing the pipeline_steps.
The graph model is responsible for the following tasks:
- Create a graph from data (create_graph_from_data)
- Execute the pipeline_steps (execute_pipeline_steps)
- Take actions on the graph (execute_pipeline_step & _take_action which is called by execute_pipeline_step)
It also initializes and takes care of the multiprocessing pool.
"""

pipeline_steps: List[IndependenceTestInterface]
graph: BaseGraphInterface
pool: mp.Pool
Expand Down Expand Up @@ -361,7 +411,7 @@ def create_graph_from_data(self, data: List[Dict[str, float]]):
for key in keys:
nodes[key].append(row[key])

graph = UndirectedGraph()
graph = Graph()
for key in keys:
graph.add_node(key, nodes[key])

Expand All @@ -370,7 +420,8 @@ def create_graph_from_data(self, data: List[Dict[str, float]]):

def create_all_possible_edges(self):
"""
Create all possible nodes
Create all possible edges on a graph
TODO: replace me with the skeleton builders
:return:
"""
for u in self.graph.nodes.values():
Expand Down Expand Up @@ -400,10 +451,27 @@ def execute_pipeline_steps(self):
return action_history

def _format_yield(self, test_fn, graph, generator):
"""
Format the yield for the parallel processing
:param test_fn: the pipeline_step test function
:param graph: the graph
:param generator: the generator object which generates the combinations
:return: yields the test function with its inputs
"""
for i in generator:
yield [test_fn, [*i], graph]

def _take_action(self, results):
"""
Take the actions returned by the test
In causy changes on the graph are not executed directly. Instead, the test returns an action which should be executed on the graph.
This is done to make it possible to execute the tests in parallel as well as to decide proactively at which point in the decisions taken by the pipeline step should be executed.
Actions are returned by the test and are executed on the graph. The actions are stored in the action history to make it possible to revert the actions or use them in a later step.
:param results:
:return:
"""
actions_taken = []
for result_items in results:
if result_items is None:
Expand Down Expand Up @@ -460,7 +528,7 @@ def _take_action(self, results):

def execute_pipeline_step(self, test_fn: IndependenceTestInterface):
"""
Filter the graph
Execute a single pipeline_step on the graph. either in parallel or in a single process depending on the test_fn.PARALLEL flag
:param test_fn: the test function
:param threshold: the threshold
:return:
Expand Down Expand Up @@ -511,7 +579,7 @@ def graph_model_factory(
pipeline_steps: Optional[List[IndependenceTestInterface]] = None,
) -> type[AbstractGraphModel]:
"""
Create a graph model
Create a graph model based on a List of pipeline_steps
:param pipeline_steps: a list of pipeline_steps which should be applied to the graph
:return: the graph model
"""
Expand All @@ -524,9 +592,19 @@ def __init__(self):


class Loop(LogicStepInterface):
"""
A loop which executes a list of pipeline_steps until the exit_condition is met.
"""

def execute(
self, graph: BaseGraphInterface, graph_model_instance_: GraphModelInterface
):
"""
Executes the loop til self.exit_condition is met
:param graph:
:param graph_model_instance_:
:return:
"""
n = 0
steps = None
while not self.exit_condition(
Expand Down
2 changes: 1 addition & 1 deletion causy/independence_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test(
cor_xy = graph.edge_value(x, y)["correlation"]
cor_xz = graph.edge_value(x, z)["correlation"]
cor_yz = graph.edge_value(y, z)["correlation"]
except KeyError:
except (KeyError, TypeError):
return

numerator = cor_xy - cor_xz * cor_yz
Expand Down
4 changes: 2 additions & 2 deletions causy/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class ExitConditionInterface(ABC):
def check(
self,
graph: BaseGraphInterface,
graph_model_instance_: dict,
graph_model_instance_: GraphModelInterface,
actions_taken: List[TestResult],
iteration: int,
) -> bool:
Expand All @@ -255,7 +255,7 @@ def check(
def __call__(
self,
graph: BaseGraphInterface,
graph_model_instance_: dict,
graph_model_instance_: GraphModelInterface,
actions_taken: List[TestResult],
iteration: int,
) -> bool:
Expand Down
Loading

0 comments on commit d26bd20

Please sign in to comment.