Skip to content

Commit

Permalink
add suuuper basic cli
Browse files Browse the repository at this point in the history
  • Loading branch information
LilithWittmann committed Sep 29, 2023
1 parent 5182c1f commit 67a4b6e
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 29 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# causalib

Causal discovery made easy.

## Dev usage

```bash
poetry run python main.py excute --help
poetry run python main.py pipelines/pc.json tests/fixtures/rki.json
```
24 changes: 22 additions & 2 deletions graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import importlib
import itertools
import json
from abc import ABC
from dataclasses import dataclass
from typing import List, Optional, Dict, Set
Expand Down Expand Up @@ -26,6 +28,10 @@
LogicStepInterface,
)

import logging

logger = logging.getLogger(__name__)

DEFAULT_INDEPENDENCE_TEST = CorrelationCoefficientTest


Expand All @@ -46,6 +52,7 @@ class UndirectedGraph(BaseGraphInterface):
nodes: Dict[str, Node]
edges: Dict[Node, Dict[Node, Dict]]
edge_history: Dict[Set[Node], List[CorrelationTestResult]]
action_history: List[Dict[str, List[CorrelationTestResult]]]

def __init__(self):
self.nodes = {}
Expand Down Expand Up @@ -240,12 +247,19 @@ def execute_pipeline_steps(self):
Execute all pipeline_steps
:return:
"""
action_history = []

for filter in self.pipeline_steps:
if isinstance(filter, LogicStepInterface):
filter.execute(self.graph, self)
continue

self.execute_pipeline_step(filter)
result = self.execute_pipeline_step(filter)
action_history.append(
{"step": filter.__class__.__name__, "actions": result}
)

self.graph.action_history = action_history

def execute_pipeline_step(self, test_fn: IndependenceTestInterface):
"""
Expand All @@ -255,6 +269,7 @@ def execute_pipeline_step(self, test_fn: IndependenceTestInterface):
:return:
"""
combinations = []
actions_taken = []

if type(test_fn.NUM_OF_COMPARISON_ELEMENTS) is int:
combinations = itertools.combinations(
Expand Down Expand Up @@ -307,7 +322,11 @@ def execute_pipeline_step(self, test_fn: IndependenceTestInterface):
if i is None:
continue
if i.x is not None and i.y is not None:
print(f"Action: {i.action} on {i.x.name} and {i.y.name}")
logger.info(f"Action: {i.action} on {i.x.name} and {i.y.name}")

# add the action to the actions history
actions_taken.append(i)

# execute the action returned by the test
if i.action == CorrelationTestResultAction.REMOVE_EDGE_UNDIRECTED:
self.graph.remove_edge(i.x, i.y)
Expand All @@ -322,6 +341,7 @@ def execute_pipeline_step(self, test_fn: IndependenceTestInterface):
elif i.action == CorrelationTestResultAction.REMOVE_EDGE_DIRECTED:
self.graph.remove_directed_edge(i.x, i.y)
self.graph.add_edge_history(i.x, i.y, i)
return actions_taken


def graph_model_factory(
Expand Down
31 changes: 17 additions & 14 deletions independence_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@

from utils import get_correlation

import logging

logger = logging.getLogger(__name__)

from interfaces import (
IndependenceTestInterface,
BaseGraphInterface,
Expand Down Expand Up @@ -72,8 +76,7 @@ def test(
t, critical_t = get_t_and_critial_t(
sample_size, nb_of_control_vars, corr, self.threshold
)
print("critical_t")
print(t, critical_t)
logger.debug(f"t, critical_t {t} {critical_t}")
if abs(t) < critical_t:
return CorrelationTestResult(
x=x,
Expand Down Expand Up @@ -117,7 +120,7 @@ def test(
cor_xz = graph.edge_value(x, z)["correlation"]
cor_yz = graph.edge_value(y, z)["correlation"]
except KeyError:
print("k_error")
logger.debug(f"KeyError {x} {y} {z}")
return CorrelationTestResult(
x=x, y=y, action=CorrelationTestResultAction.DO_NOTHING
)
Expand All @@ -139,8 +142,7 @@ def test(
t, critical_t = get_t_and_critial_t(
sample_size, nb_of_control_vars, par_corr, self.threshold
)
print("critical_t")
print(t, critical_t)
logger.debug(f"t, critical_t {t} {critical_t}")

if abs(t) < critical_t:
return CorrelationTestResult(
Expand All @@ -156,8 +158,8 @@ def test(

class ExtendedPartialCorrelationTest(IndependenceTestInterface):
NUM_OF_COMPARISON_ELEMENTS = ComparisonSettings(min=5, max=AS_MANY_AS_FIELDS)
CHUNK_SIZE_PARALLEL_PROCESSING = 1
PARALLEL = False
CHUNK_SIZE_PARALLEL_PROCESSING = 1000
PARALLEL = True

def test(
self, nodes: List[str], graph: BaseGraphInterface
Expand All @@ -184,7 +186,7 @@ def test(
if idx not in exclude_indices
]
par_corr = get_correlation(x, y, other_nodes)
print(par_corr)
logger.debug(f"par_corr {par_corr}")
# make t test for independence of a and y given other nodes
t, critical_t = get_t_and_critial_t(
sample_size, nb_of_control_vars, par_corr, self.threshold
Expand Down Expand Up @@ -252,8 +254,8 @@ def test(

class ExtendedPartialCorrelationTest2(IndependenceTestInterface):
NUM_OF_COMPARISON_ELEMENTS = ComparisonSettings(min=4, max=AS_MANY_AS_FIELDS)
CHUNK_SIZE_PARALLEL_PROCESSING = 50
PARALLEL = False
CHUNK_SIZE_PARALLEL_PROCESSING = 100
PARALLEL = True

def test(
self, nodes: List[str], graph: BaseGraphInterface
Expand Down Expand Up @@ -289,7 +291,7 @@ def test(
if i == k:
continue

print(partial_correlation_coefficients[i][k])
# print(partial_correlation_coefficients[i][k])
try:
t, critical_t = get_t_and_critial_t(
sample_size,
Expand All @@ -299,8 +301,9 @@ def test(
)
except ValueError:
# TODO: @sof fiugre out why this happens
print("ValueError")
print(partial_correlation_coefficients[i][k])
logger.debug(
f"ValueError {i} {k} ({partial_correlation_coefficients[i][k]})"
)
continue

if abs(t) < critical_t:
Expand Down Expand Up @@ -329,7 +332,7 @@ class PlaceholderTest(IndependenceTestInterface):
def test(
self, nodes: Tuple[str], graph: BaseGraphInterface
) -> List[CorrelationTestResult] | CorrelationTestResult:
print("PlaceholderTest")
logger.debug(f"PlaceholderTest {nodes}")
return CorrelationTestResult(
x=None, y=None, action=CorrelationTestResultAction.DO_NOTHING, data={}
)
23 changes: 18 additions & 5 deletions interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Dict
import logging

logger = logging.getLogger(__name__)

DEFAULT_THRESHOLD = 0.01

Expand All @@ -19,12 +22,15 @@ class NodeInterface:
name: str
values: List[float]

def to_dict(self):
return self.name


class CorrelationTestResultAction(enum.Enum):
REMOVE_EDGE_UNDIRECTED = 1
UPDATE_EDGE = 2
DO_NOTHING = 3
REMOVE_EDGE_DIRECTED = 4
class CorrelationTestResultAction(enum.StrEnum):
REMOVE_EDGE_UNDIRECTED = "REMOVE_EDGE_UNDIRECTED"
UPDATE_EDGE = "UPDATE_EDGE"
DO_NOTHING = "DO_NOTHING"
REMOVE_EDGE_DIRECTED = "REMOVE_EDGE_DIRECTED"


@dataclass
Expand All @@ -34,6 +40,13 @@ class CorrelationTestResult:
action: CorrelationTestResultAction
data: Dict = None

def to_dict(self):
return {
"x": self.x.to_dict(),
"y": self.y.to_dict(),
"action": self.action.name,
}


class BaseGraphInterface(ABC):
nodes: Dict[str, NodeInterface]
Expand Down
79 changes: 79 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import importlib
import json
from json import JSONEncoder

import typer

from graph import graph_model_factory

app = typer.Typer()
import logging


def load_json(pipeline_file: str):
with open(pipeline_file, "r") as file:
pipeline = json.loads(file.read())
return pipeline


def create_pipeline(pipeline_config: dict):
pipeline = []
for step in pipeline_config["steps"]:
path = ".".join(step["step"].split(".")[:-1])
cls = step["step"].split(".")[-1]
st_function = importlib.import_module(path)
st_function = getattr(st_function, cls)
if "params" not in step.keys():
pipeline.append(st_function())
else:
pipeline.append(st_function(**step["params"]))

return pipeline


def show_edges(graph):
for u in graph.edges:
for v in graph.edges[u]:
print(f"{u.name} -> {v.name}: {graph.edges[u][v]}")


class MyJSONEncoder(JSONEncoder):
def default(self, obj):
return obj.to_dict()


@app.command()
def execute(
pipeline_file: str,
data_file: str,
graph_actions_save_file: str = None,
log_level: str = "ERROR",
):
typer.echo(f"💾 Loading pipeline from {pipeline_file}")
pipeline_config = load_json(pipeline_file)
# set log level
logging.basicConfig(level=log_level)
pipeline = create_pipeline(pipeline_config)
model = graph_model_factory(pipeline_steps=pipeline)()

model.create_graph_from_data(load_json(data_file))
model.create_all_possible_edges()
typer.echo("🕵🏻‍♀ Executing pipeline steps...")
model.execute_pipeline_steps()
show_edges(model.graph)
if graph_actions_save_file:
typer.echo(f"💾 Saving graph actions to {graph_actions_save_file}")
with open(graph_actions_save_file, "w") as file:
file.write(
json.dumps(model.graph.action_history, cls=MyJSONEncoder, indent=4)
)
# model.save_graph_actions(graph_actions_save_file)


@app.command()
def visualize(output: str):
raise NotImplementedError()


if __name__ == "__main__":
app()
28 changes: 28 additions & 0 deletions pipelines/pc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"name": "PC",
"steps": [
{
"step": "independence_tests.CalculateCorrelations",
"params": {
}
},
{
"step": "independence_tests.CorrelationCoefficientTest",
"params": {
"threshold": 0.1
}
},
{
"step": "independence_tests.ExtendedPartialCorrelationTest2",
"params": {
"threshold": 0.1
}
},
{
"step": "independence_tests.UnshieldedTriplesTest",
"params": {
}
}
]

}
23 changes: 22 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ numpy = "^1.26.0"
scipy = "^1.11.2"
black = "^23.9.1"
pre-commit = "^3.4.0"
typer = "^0.9.0"


[tool.poetry.group.dev.dependencies]
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/rki.json

Large diffs are not rendered by default.

Loading

0 comments on commit 67a4b6e

Please sign in to comment.