Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Migrate to pytorch #7

Merged
merged 18 commits into from
Oct 22, 2023
Merged
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions causy/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from causy.exit_conditions import ExitOnNoActions
from causy.generators import PairsWithNeighboursGenerator
from causy.graph import graph_model_factory, Loop
from causy.independence_tests import (
CalculateCorrelations,
CorrelationCoefficientTest,
PartialCorrelationTest,
ExtendedPartialCorrelationTestMatrix,
)
from causy.interfaces import AS_MANY_AS_FIELDS, ComparisonSettings
from causy.orientation_tests import (
ColliderTest,
NonColliderTest,
@@ -32,3 +34,32 @@
),
]
)


# ParallelPC: a model produced by graph_model_factory from the ordered
# pipeline steps below. All three threshold-taking tests use threshold=0.01.
ParallelPC = graph_model_factory(
    pipeline_steps=[
        CalculateCorrelations(),
        CorrelationCoefficientTest(threshold=0.01),
        PartialCorrelationTest(threshold=0.01),
        # The only step in this pipeline that is configured with parallel=True.
        # NOTE(review): chunk_size_parallel_processing=1000 is presumably the
        # number of work items handed to the pool at once -- confirm against
        # the pipeline-step implementation.
        ExtendedPartialCorrelationTestMatrix(
            threshold=0.01,
            chunk_size_parallel_processing=1000,
            parallel=True,
            generator=PairsWithNeighboursGenerator(
                # chunked and shuffle_combinations are passed explicitly rather
                # than relying on the generator's defaults.
                chunked=False,
                shuffle_combinations=True,
                # NOTE(review): min=4 / max=AS_MANY_AS_FIELDS presumably bound
                # the size of the generated node combinations -- confirm.
                comparison_settings=ComparisonSettings(min=4, max=AS_MANY_AS_FIELDS),
            ),
        ),
        ColliderTest(),
        # The orientation steps are wrapped in a Loop whose exit condition is
        # ExitOnNoActions (presumably: stop once a pass takes no action).
        Loop(
            pipeline_steps=[
                NonColliderTest(),
                FurtherOrientTripleTest(),
                OrientQuadrupleTest(),
                FurtherOrientQuadrupleTest(),
            ],
            exit_condition=ExitOnNoActions(),
        ),
    ]
)
8 changes: 4 additions & 4 deletions causy/cli.py
Original file line number Diff line number Diff line change
@@ -95,12 +95,12 @@ def execute(
edges = []
for edge in retrieve_edges(model.graph):
print(
f"{edge[0].name} -> {edge[1].name}: {model.graph.edges[edge[0]][edge[1]]}"
f"{model.graph.nodes[edge[0]].name} -> {model.graph.nodes[edge[1]].name}: {model.graph.edges[edge[0]][edge[1]]}"
)
edges.append(
{
"from": edge[0].to_dict(),
"to": edge[1].to_dict(),
"from": model.graph.nodes[edge[0]].to_dict(),
"to": model.graph.nodes[edge[1]].to_dict(),
"value": model.graph.edges[edge[0]][edge[1]],
}
)
@@ -127,7 +127,7 @@ def execute(
n_graph = nx.DiGraph()
for u in model.graph.edges:
for v in model.graph.edges[u]:
n_graph.add_edge(u.name, v.name)
n_graph.add_edge(model.graph.nodes[u].name, model.graph.nodes[v].name)
fig = plt.figure(figsize=(10, 10))
nx.draw(n_graph, with_labels=True, ax=fig.add_subplot(111))
fig.savefig(render_save_file)
51 changes: 43 additions & 8 deletions causy/generators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import itertools
import logging

@@ -8,7 +9,7 @@
GraphModelInterface,
AS_MANY_AS_FIELDS,
)

from causy.utils import serialize_module_name

logger = logging.getLogger(__name__)

@@ -43,6 +44,24 @@ def generate(


class PairsWithNeighboursGenerator(GeneratorInterface):
shuffle_combinations = True
chunked = True

def __init__(
    self,
    comparison_settings: ComparisonSettings,
    chunked: bool = None,
    shuffle_combinations: bool = None,
):
    """Set up the generator.

    :param comparison_settings: forwarded unchanged to the parent initializer.
    :param chunked: forwarded unchanged to the parent initializer.
    :param shuffle_combinations: when not ``None``, stored on the instance and
        thereby shadows the class-level ``shuffle_combinations`` value.
    """
    super().__init__(comparison_settings, chunked)
    if shuffle_combinations is None:
        # Nothing supplied: leave the class-level value in effect.
        return
    self.shuffle_combinations = shuffle_combinations

def to_dict(self):
    """Return the parent's dict form with ``shuffle_combinations`` added.

    The value is written under ``["params"]["shuffle_combinations"]`` of the
    dictionary returned by the parent class.
    """
    serialized = super().to_dict()
    serialized["params"]["shuffle_combinations"] = self.shuffle_combinations
    return serialized

def generate(
self, graph: BaseGraphInterface, graph_model_instance_: GraphModelInterface
):
@@ -71,20 +90,36 @@ def generate(
for i in range(start, stop):
logger.debug(f"PairsWithNeighboursGenerator: i={i}")
checked_combinations = set()
for node in graph.edges:
for neighbour in graph.edges[node]:
local_edges = copy.deepcopy(graph.edges)
for node in local_edges:
for neighbour in local_edges[node]:
if (node, neighbour) in checked_combinations:
continue

checked_combinations.add((node, neighbour))
if i == 2:
yield (node.id, neighbour.id)
yield (node, neighbour)
continue

other_neighbours = set(graph.edges[node])
other_neighbours.remove(neighbour)
if neighbour in other_neighbours:
other_neighbours.remove(neighbour)
else:
continue
if len(other_neighbours) + 2 < i:
continue

for k in itertools.combinations(other_neighbours, i):
yield [node.id, neighbour.id] + [ks.id for ks in k]
combinations = itertools.combinations(other_neighbours, i)
if self.shuffle_combinations:
combinations = list(combinations)
import random

random.shuffle(combinations)

if self.chunked:
chunk = []
for k in combinations:
chunk.append([node, neighbour] + [ks for ks in k])
yield chunk
else:
for k in combinations:
yield [node, neighbour] + [ks for ks in k]
102 changes: 55 additions & 47 deletions causy/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Dict, Set
from typing import List, Optional, Dict, Set, Tuple

import torch
import torch.multiprocessing as mp
@@ -47,8 +47,8 @@ class UndirectedGraphError(Exception):

class UndirectedGraph(BaseGraphInterface):
nodes: Dict[str, Node]
edges: Dict[Node, Dict[Node, Dict]]
edge_history: Dict[Set[Node], List[TestResult]]
edges: Dict[str, Dict[str, Dict]]
edge_history: Dict[Tuple[str, str], List[TestResult]]
action_history: List[Dict[str, List[TestResult]]]

def __init__(self):
@@ -68,16 +68,16 @@ def add_edge(self, u: Node, v: Node, value: Dict):
raise UndirectedGraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
if u not in self.edges:
self.edges[u] = {}
if v not in self.edges:
self.edges[v] = {}
if u.id not in self.edges:
self.edges[u.id] = {}
if v.id not in self.edges:
self.edges[v.id] = {}

self.edges[u][v] = value
self.edges[v][u] = value
self.edges[u.id][v.id] = value
self.edges[v.id][u.id] = value

self.edge_history[(u, v)] = []
self.edge_history[(v, u)] = []
self.edge_history[(u.id, v.id)] = []
self.edge_history[(v.id, u.id)] = []

def retrieve_edge_history(
self, u, v, action: TestResultAction = None
@@ -90,14 +90,14 @@ def retrieve_edge_history(
:return:
"""
if action is None:
return self.edge_history[(u, v)]
return self.edge_history[(u.id, v.id)]

return [i for i in self.edge_history[(u, v)] if i.action == action]
return [i for i in self.edge_history[(u.id, v.id)] if i.action == action]

def add_edge_history(self, u, v, action: TestResultAction):
    """Append ``action`` to the history recorded for the directed pair (u, v).

    The history is keyed by the tuple of node ids ``(u.id, v.id)``, not by the
    node objects themselves; a missing entry is created on first use.

    :param u: first node of the pair; only its ``id`` attribute is read.
    :param v: second node of the pair; only its ``id`` attribute is read.
    :param action: the entry to append to the pair's history list.
    """
    # setdefault creates the empty list once and returns the stored list, so
    # the append always lands in the dictionary's own value.
    self.edge_history.setdefault((u.id, v.id), []).append(action)

def remove_edge(self, u: Node, v: Node):
"""
@@ -110,18 +110,18 @@ def remove_edge(self, u: Node, v: Node):
raise UndirectedGraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
if u not in self.edges:
if u.id not in self.edges:
raise UndirectedGraphError(f"Node {u} does not have any nodes")
if v not in self.edges:
if v.id not in self.edges:
raise UndirectedGraphError(f"Node {v} does not have any nodes")

if v not in self.edges[u]:
if v.id not in self.edges[u.id]:
return
del self.edges[u][v]
del self.edges[u.id][v.id]

if u not in self.edges[v]:
if u.id not in self.edges[v.id]:
return
del self.edges[v][u]
del self.edges[v.id][u.id]

def remove_directed_edge(self, u: Node, v: Node):
"""
@@ -134,14 +134,14 @@ def remove_directed_edge(self, u: Node, v: Node):
raise UndirectedGraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
if u not in self.edges:
if u.id not in self.edges:
raise UndirectedGraphError(f"Node {u} does not have any nodes")
if v not in self.edges:
if v.id not in self.edges:
raise UndirectedGraphError(f"Node {v} does not have any nodes")

if v not in self.edges[u]:
if v.id not in self.edges[u.id]:
return
del self.edges[u][v]
del self.edges[u.id][v.id]

def update_edge(self, u: Node, v: Node, value: Dict):
"""
@@ -154,13 +154,13 @@ def update_edge(self, u: Node, v: Node, value: Dict):
raise UndirectedGraphError(f"Node {u} does not exist")
if v.id not in self.nodes:
raise UndirectedGraphError(f"Node {v} does not exist")
if u not in self.edges:
if u.id not in self.edges:
raise UndirectedGraphError(f"Node {u} does not have any edges")
if v not in self.edges:
if v.id not in self.edges:
raise UndirectedGraphError(f"Node {v} does not have any edges")

self.edges[u][v] = value
self.edges[v][u] = value
self.edges[u.id][v.id] = value
self.edges[v.id][u.id] = value

def edge_exists(self, u: Node, v: Node):
"""
@@ -173,9 +173,9 @@ def edge_exists(self, u: Node, v: Node):
return False
if v.id not in self.nodes:
return False
if u in self.edges and v in self.edges[u]:
if u.id in self.edges and v.id in self.edges[u.id]:
return True
if v in self.edges and u in self.edges[v]:
if v.id in self.edges and u.id in self.edges[v.id]:
return True
return False

@@ -190,9 +190,9 @@ def directed_edge_exists(self, u: Node, v: Node):
return False
if v.id not in self.nodes:
return False
if u not in self.edges:
if u.id not in self.edges:
return False
if v not in self.edges[u]:
if v.id not in self.edges[u.id]:
return False
return True

@@ -238,7 +238,7 @@ def bidirected_edge_exists(self, u: Node, v: Node):
return False

def edge_value(self, u: Node, v: Node):
    """Return the value stored for the edge from ``u`` to ``v``.

    Edges are looked up by node id (``self.edges[u.id][v.id]``), not by the
    node objects. No existence check is made, so a missing entry propagates
    the lookup error of the underlying mapping.

    :param u: source node; only its ``id`` attribute is read.
    :param v: target node; only its ``id`` attribute is read.
    :return: whatever value is stored for that edge.
    """
    return self.edges[u.id][v.id]

def add_node(self, name: str, values: List[float], id: str = None):
"""
@@ -265,8 +265,8 @@ def directed_path_exists(self, u: Node, v: Node):
"""
if self.directed_edge_exists(u, v):
return True
for w in self.edges[u]:
if self.directed_path_exists(w, v):
for w in self.edges[u.id]:
if self.directed_path_exists(self.nodes[w], v):
return True
return False

@@ -280,8 +280,8 @@ def directed_paths(self, u: Node, v: Node):
if self.directed_edge_exists(u, v):
return [[(u, v)]]
paths = []
for w in self.edges[u]:
for path in self.directed_paths(w, v):
for w in self.edges[u.id]:
for path in self.directed_paths(self.nodes[w], v):
paths.append([(u, w)] + path)
return paths

@@ -310,7 +310,6 @@ def unpack_run(args):


class AbstractGraphModel(GraphModelInterface, ABC):

pipeline_steps: List[IndependenceTestInterface]
graph: BaseGraphInterface
pool: mp.Pool
@@ -456,14 +455,23 @@ def execute_pipeline_step(self, test_fn: IndependenceTestInterface):
result = [result]
actions_taken.extend(self._take_action(result))
else:
iterator = [
unpack_run(i)
for i in [
[test_fn, [*i], self.graph]
for i in test_fn.GENERATOR.generate(self.graph, self)
if test_fn.GENERATOR.chunked:
for chunk in test_fn.GENERATOR.generate(self.graph, self):
iterator = [
unpack_run(i)
for i in [[test_fn, [*c], self.graph] for c in chunk]
]
actions_taken.extend(self._take_action(iterator))
else:
iterator = [
unpack_run(i)
for i in [
[test_fn, [*i], self.graph]
for i in test_fn.GENERATOR.generate(self.graph, self)
]
]
]
actions_taken.extend(self._take_action(iterator))
actions_taken.extend(self._take_action(iterator))

self.graph.action_history.append(
{"step": type(test_fn).__name__, "actions": actions_taken}
)
Loading