diff --git a/README.md b/README.md index dc2afee..eaba9ed 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ This runs a web server on port 8000 which allows you to visualize the causal dis Use a default algorithm ```python -from causy.algorithms import PC +from causy.causal_discovery.constraint.algorithms import PC from causy.graph_utils import retrieve_edges model = PC() @@ -90,61 +90,61 @@ from causy.common_pipeline_steps.exit_conditions import ExitOnNoActions from causy.graph_model import graph_model_factory from causy.common_pipeline_steps.logic import Loop from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations -from causy.independence_tests.common import ( - CorrelationCoefficientTest, - PartialCorrelationTest, - ExtendedPartialCorrelationTestMatrix, +from causy.causal_discovery.constraint.independence_tests.common import ( + CorrelationCoefficientTest, + PartialCorrelationTest, + ExtendedPartialCorrelationTestMatrix, ) -from causy.orientation_rules.pc import ( - ColliderTest, - NonColliderTest, - FurtherOrientTripleTest, - OrientQuadrupleTest, - FurtherOrientQuadrupleTest, +from causy.causal_discovery.constraint.orientation_rules.pc import ( + ColliderTest, + NonColliderTest, + FurtherOrientTripleTest, + OrientQuadrupleTest, + FurtherOrientQuadrupleTest, ) -from causy.interfaces import CausyAlgorithm -from causy.algorithms.pc import PC_EDGE_TYPES +from causy.models import Algorithm +from causy.causal_discovery.constraint.algorithms.pc import PC_EDGE_TYPES from causy.graph_utils import retrieve_edges CustomPC = graph_model_factory( - CausyAlgorithm( - pipeline_steps=[ - CalculatePearsonCorrelations(), - CorrelationCoefficientTest(threshold=0.05), - PartialCorrelationTest(threshold=0.05), - ExtendedPartialCorrelationTestMatrix(threshold=0.05), - ColliderTest(), - Loop( + Algorithm( pipeline_steps=[ - NonColliderTest(), - FurtherOrientTripleTest(), - OrientQuadrupleTest(), - FurtherOrientQuadrupleTest(), + CalculatePearsonCorrelations(), + CorrelationCoefficientTest(threshold=0.05), + PartialCorrelationTest(threshold=0.05), + ExtendedPartialCorrelationTestMatrix(threshold=0.05), + ColliderTest(), + Loop( + pipeline_steps=[ + NonColliderTest(), + FurtherOrientTripleTest(), + OrientQuadrupleTest(), + FurtherOrientQuadrupleTest(), + ], + exit_condition=ExitOnNoActions(), + ), ], - exit_condition=ExitOnNoActions(), - ), - ], - name="CustomPC", - edge_types=PC_EDGE_TYPES, - ) + name="CustomPC", + edge_types=PC_EDGE_TYPES, + ) ) model = CustomPC() model.create_graph_from_data( - [ - {"a": 1, "b": 0.3}, - {"a": 0.5, "b": 0.2} - ] + [ + {"a": 1, "b": 0.3}, + {"a": 0.5, "b": 0.2} + ] ) model.create_all_possible_edges() model.execute_pipeline_steps() edges = retrieve_edges(model.graph) for edge in edges: - print( - f"{edge[0].name} -> {edge[1].name}: {model.graph.edges[edge[0]][edge[1]]}" - ) + print( + f"{edge[0].name} -> {edge[1].name}: {model.graph.edges[edge[0]][edge[1]]}" + ) ``` ### Supported algorithms diff --git a/causy/causal_discovery/__init__.py b/causy/causal_discovery/__init__.py new file mode 100644 index 0000000..14aff63 --- /dev/null +++ b/causy/causal_discovery/__init__.py @@ -0,0 +1 @@ +from .constraint.algorithms import AVAILABLE_ALGORITHMS diff --git a/causy/independence_tests/__init__.py b/causy/causal_discovery/constraint/__init__.py similarity index 100% rename from causy/independence_tests/__init__.py rename to causy/causal_discovery/constraint/__init__.py diff --git a/causy/algorithms/__init__.py 
b/causy/causal_discovery/constraint/algorithms/__init__.py similarity index 100% rename from causy/algorithms/__init__.py rename to causy/causal_discovery/constraint/algorithms/__init__.py diff --git a/causy/algorithms/fci.py b/causy/causal_discovery/constraint/algorithms/fci.py similarity index 100% rename from causy/algorithms/fci.py rename to causy/causal_discovery/constraint/algorithms/fci.py diff --git a/causy/algorithms/pc.py b/causy/causal_discovery/constraint/algorithms/pc.py similarity index 74% rename from causy/algorithms/pc.py rename to causy/causal_discovery/constraint/algorithms/pc.py index 95352f9..f454484 100644 --- a/causy/algorithms/pc.py +++ b/causy/causal_discovery/constraint/algorithms/pc.py @@ -12,7 +12,7 @@ from causy.generators import PairsWithNeighboursGenerator, RandomSampleGenerator from causy.graph_model import graph_model_factory from causy.common_pipeline_steps.logic import Loop, ApplyActionsTogether -from causy.independence_tests.common import ( +from causy.causal_discovery.constraint.independence_tests.common import ( CorrelationCoefficientTest, PartialCorrelationTest, ExtendedPartialCorrelationTestMatrix, @@ -20,14 +20,20 @@ from causy.common_pipeline_steps.calculation import ( CalculatePearsonCorrelations, ) -from causy.interfaces import AS_MANY_AS_FIELDS, ComparisonSettings, CausyAlgorithm -from causy.orientation_rules.pc import ( +from causy.interfaces import AS_MANY_AS_FIELDS +from causy.models import ComparisonSettings, Algorithm +from causy.causal_discovery.constraint.orientation_rules.pc import ( ColliderTest, NonColliderTest, FurtherOrientTripleTest, OrientQuadrupleTest, FurtherOrientQuadrupleTest, ) +from causy.variables import ( + FloatVariable, + VariableReference, + IntegerVariable, +) PC_DEFAULT_THRESHOLD = 0.005 @@ -55,18 +61,19 @@ PC_EDGE_TYPES = [DirectedEdge(), UndirectedEdge()] PC = graph_model_factory( - CausyAlgorithm( + Algorithm( pipeline_steps=[ CalculatePearsonCorrelations(display_name="Calculate Pearson Correlations"), CorrelationCoefficientTest( - threshold=PC_DEFAULT_THRESHOLD, + threshold=VariableReference(name="threshold"), display_name="Correlation Coefficient Test", ), PartialCorrelationTest( - threshold=PC_DEFAULT_THRESHOLD, display_name="Partial Correlation Test" + threshold=VariableReference(name="threshold"), + display_name="Partial Correlation Test", ), ExtendedPartialCorrelationTestMatrix( - threshold=PC_DEFAULT_THRESHOLD, + threshold=VariableReference(name="threshold"), display_name="Extended Partial Correlation Test Matrix", ), *PC_ORIENTATION_RULES, @@ -77,19 +84,24 @@ edge_types=PC_EDGE_TYPES, extensions=[PC_GRAPH_UI_EXTENSION], name="PC", + variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], ) ) PCStable = graph_model_factory( - CausyAlgorithm( + Algorithm( pipeline_steps=[ CalculatePearsonCorrelations(), ApplyActionsTogether( pipeline_steps=[ - CorrelationCoefficientTest(threshold=PC_DEFAULT_THRESHOLD), - PartialCorrelationTest(threshold=PC_DEFAULT_THRESHOLD), + CorrelationCoefficientTest( + threshold=VariableReference(name="threshold") + ), + PartialCorrelationTest( + threshold=VariableReference(name="threshold"), + ), ExtendedPartialCorrelationTestMatrix( - threshold=PC_DEFAULT_THRESHOLD + threshold=VariableReference(name="threshold"), ), ] ), @@ -99,27 +111,28 @@ edge_types=PC_EDGE_TYPES, extensions=[PC_GRAPH_UI_EXTENSION], name="PCStable", + variables=[FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD)], ) ) ParallelPC = graph_model_factory( - CausyAlgorithm( + Algorithm( 
pipeline_steps=[ CalculatePearsonCorrelations(display_name="Calculate Pearson Correlations"), CorrelationCoefficientTest( - threshold=PC_DEFAULT_THRESHOLD, + threshold=VariableReference(name="threshold"), display_name="Correlation Coefficient Test", ), PartialCorrelationTest( - threshold=PC_DEFAULT_THRESHOLD, + threshold=VariableReference(name="threshold"), parallel=True, chunk_size_parallel_processing=50000, display_name="Partial Correlation Test", ), ExtendedPartialCorrelationTestMatrix( # run first a sampled version of the test so we can minimize the number of tests in the full version - threshold=PC_DEFAULT_THRESHOLD, + threshold=VariableReference(name="threshold"), display_name="Sampled Extended Partial Correlation Test Matrix", chunk_size_parallel_processing=5000, parallel=True, @@ -132,11 +145,11 @@ ), ), chunked=False, - every_nth=200, + every_nth=VariableReference(name="every_nth_sample"), ), ), ExtendedPartialCorrelationTestMatrix( - threshold=PC_DEFAULT_THRESHOLD, + threshold=VariableReference(name="threshold"), display_name="Extended Partial Correlation Test Matrix", chunk_size_parallel_processing=20000, parallel=True, @@ -156,5 +169,9 @@ edge_types=PC_EDGE_TYPES, extensions=[PC_GRAPH_UI_EXTENSION], name="ParallelPC", + variables=[ + FloatVariable(name="threshold", value=PC_DEFAULT_THRESHOLD), + IntegerVariable(name="every_nth_sample", value=200), + ], ) ) diff --git a/causy/orientation_rules/__init__.py b/causy/causal_discovery/constraint/independence_tests/__init__.py similarity index 100% rename from causy/orientation_rules/__init__.py rename to causy/causal_discovery/constraint/independence_tests/__init__.py diff --git a/causy/independence_tests/common.py b/causy/causal_discovery/constraint/independence_tests/common.py similarity index 95% rename from causy/independence_tests/common.py rename to causy/causal_discovery/constraint/independence_tests/common.py index daf9a1f..fe4d4b7 100644 --- a/causy/independence_tests/common.py +++ b/causy/causal_discovery/constraint/independence_tests/common.py @@ -10,13 +10,12 @@ PipelineStepInterface, BaseGraphInterface, NodeInterface, - TestResult, - TestResultAction, AS_MANY_AS_FIELDS, - ComparisonSettings, GeneratorInterface, PipelineStepInterfaceType, ) +from causy.models import ComparisonSettings, TestResultAction, TestResult +from causy.variables import IntegerParameter, BoolParameter logger = logging.getLogger(__name__) @@ -27,8 +26,8 @@ class CorrelationCoefficientTest( generator: Optional[GeneratorInterface] = AllCombinationsGenerator( comparison_settings=ComparisonSettings(min=2, max=2) ) - chunk_size_parallel_processing: int = 1 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1 + parallel: BoolParameter = False def process( self, nodes: List[str], graph: BaseGraphInterface @@ -64,8 +63,8 @@ class PartialCorrelationTest( generator: Optional[GeneratorInterface] = AllCombinationsGenerator( comparison_settings=ComparisonSettings(min=3, max=3) ) - chunk_size_parallel_processing: int = 1 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1 + parallel: BoolParameter = False def process( self, nodes: Tuple[str], graph: BaseGraphInterface @@ -141,8 +140,8 @@ class ExtendedPartialCorrelationTestMatrix( comparison_settings=ComparisonSettings(min=4, max=AS_MANY_AS_FIELDS), shuffle_combinations=False, ) - chunk_size_parallel_processing: int = 1000 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1000 + parallel: BoolParameter = False def process( self, 
nodes: List[str], graph: BaseGraphInterface @@ -245,8 +244,8 @@ class ExtendedPartialCorrelationTestLinearRegression( comparison_settings=ComparisonSettings(min=4, max=AS_MANY_AS_FIELDS), shuffle_combinations=False, ) - chunk_size_parallel_processing: int = 1000 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1000 + parallel: BoolParameter = False def process( self, nodes: List[str], graph: BaseGraphInterface diff --git a/causy/causal_discovery/constraint/orientation_rules/__init__.py b/causy/causal_discovery/constraint/orientation_rules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/causy/orientation_rules/fci.py b/causy/causal_discovery/constraint/orientation_rules/fci.py similarity index 98% rename from causy/orientation_rules/fci.py rename to causy/causal_discovery/constraint/orientation_rules/fci.py index 70c9549..724a72d 100644 --- a/causy/orientation_rules/fci.py +++ b/causy/causal_discovery/constraint/orientation_rules/fci.py @@ -2,14 +2,12 @@ from causy.generators import AllCombinationsGenerator from causy.interfaces import ( - TestResultAction, PipelineStepInterface, - ComparisonSettings, BaseGraphInterface, - TestResult, GeneratorInterface, PipelineStepInterfaceType, ) +from causy.models import ComparisonSettings, TestResultAction, TestResult class ColliderRuleFCI( diff --git a/causy/orientation_rules/pc.py b/causy/causal_discovery/constraint/orientation_rules/pc.py similarity index 96% rename from causy/orientation_rules/pc.py rename to causy/causal_discovery/constraint/orientation_rules/pc.py index 496255a..e6de91d 100644 --- a/causy/orientation_rules/pc.py +++ b/causy/causal_discovery/constraint/orientation_rules/pc.py @@ -4,13 +4,13 @@ from causy.generators import AllCombinationsGenerator from causy.interfaces import ( BaseGraphInterface, - TestResult, - TestResultAction, PipelineStepInterface, - ComparisonSettings, GeneratorInterface, PipelineStepInterfaceType, ) +from causy.models import ComparisonSettings, TestResultAction, TestResult +from causy.variables import IntegerParameter, BoolParameter + # theory for all orientation rules with pictures: # https://hpi.de/fileadmin/user_upload/fachgebiete/plattner/teaching/CausalInference/2019/Introduction_to_Constraint-Based_Causal_Structure_Learning.pdf @@ -24,8 +24,8 @@ class ColliderTest( generator: Optional[GeneratorInterface] = AllCombinationsGenerator( comparison_settings=ComparisonSettings(min=2, max=2) ) - chunk_size_parallel_processing: int = 1 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1 + parallel: BoolParameter = False def process( self, nodes: Tuple[str], graph: BaseGraphInterface @@ -92,8 +92,8 @@ class NonColliderTest( generator: Optional[GeneratorInterface] = AllCombinationsGenerator( comparison_settings=ComparisonSettings(min=2, max=2) ) - chunk_size_parallel_processing: int = 1 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1 + parallel: BoolParameter = False def process( self, nodes: Tuple[str], graph: BaseGraphInterface @@ -158,8 +158,8 @@ class FurtherOrientTripleTest( generator: Optional[GeneratorInterface] = AllCombinationsGenerator( comparison_settings=ComparisonSettings(min=2, max=2) ) - chunk_size_parallel_processing: int = 1 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1 + parallel: BoolParameter = False def process( self, nodes: Tuple[str], graph: BaseGraphInterface @@ -216,8 +216,8 @@ class OrientQuadrupleTest( generator: Optional[GeneratorInterface] = 
AllCombinationsGenerator( comparison_settings=ComparisonSettings(min=2, max=2) ) - chunk_size_parallel_processing: int = 1 - parallel: bool = False + chunk_size_parallel_processing: IntegerParameter = 1 + parallel: BoolParameter = False def process( self, nodes: Tuple[str], graph: BaseGraphInterface diff --git a/causy/causal_effect_estimation/multivariate_regression.py b/causy/causal_effect_estimation/multivariate_regression.py index 888f4ca..9c456ac 100644 --- a/causy/causal_effect_estimation/multivariate_regression.py +++ b/causy/causal_effect_estimation/multivariate_regression.py @@ -5,11 +5,10 @@ from causy.generators import PairsWithEdgesInBetweenGenerator from causy.interfaces import ( BaseGraphInterface, - TestResult, PipelineStepInterface, - TestResultAction, GeneratorInterface, ) +from causy.models import TestResultAction, TestResult class ComputeDirectEffectsMultivariateRegression(PipelineStepInterface): diff --git a/causy/cli.py b/causy/cli.py index 7df9e4f..f1ade70 100644 --- a/causy/cli.py +++ b/causy/cli.py @@ -3,50 +3,41 @@ import typer +from causy.data_loader import JSONDataLoader from causy.graph_model import graph_model_factory -from causy.interfaces import ( - CausyResult, - ActionHistoryStep, - CausyAlgorithmReferenceType, +from causy.models import ( + Result, + AlgorithmReferenceType, ) from causy.serialization import ( serialize_algorithm, load_algorithm_from_specification, CausyJSONEncoder, - load_algorithm_from_reference, load_json, ) from causy.graph_utils import ( retrieve_edges, ) -from causy.ui import server -from causy.workspaces.cli import app as workspaces_app +from causy.ui.cli import ui as ui_app +from causy.workspaces.cli import workspace_app as workspaces_app +from causy.causal_discovery import AVAILABLE_ALGORITHMS app = typer.Typer() app.add_typer(workspaces_app, name="workspace") +app.command(name="ui", help="run causy ui")(ui_app) @app.command() def eject(algorithm: str, output_file: str): typer.echo(f"💾 Loading algorithm {algorithm}") - model = load_algorithm_from_reference(algorithm)() + model = AVAILABLE_ALGORITHMS[algorithm]() result = serialize_algorithm(model, algorithm_name=algorithm) typer.echo(f"💾 Saving algorithm {algorithm} to {output_file}") with open(output_file, "w") as file: file.write(json.dumps(result, indent=4)) -@app.command() -def ui(result_file: str): - result = load_json(result_file) - - server_config, server_runner = server(result) - typer.launch(f"http://{server_config.host}:{server_config.port}") - typer.echo(f"🚀 Starting server at http://{server_config.host}:{server_config.port}") - server_runner.run() - - @app.command() def execute( data_file: str, @@ -62,22 +53,23 @@ def execute( algorithm = load_algorithm_from_specification(model_dict) model = graph_model_factory(algorithm=algorithm)() algorithm_reference = { - "type": CausyAlgorithmReferenceType.FILE, + "type": AlgorithmReferenceType.FILE, "reference": pipeline, # TODO: how to reference pipeline in a way that it can be loaded? 
} elif algorithm: typer.echo(f"💾 Creating pipeline from algorithm {algorithm}") - model = load_algorithm_from_reference(algorithm)() + model = AVAILABLE_ALGORITHMS[algorithm]() algorithm_reference = { - "type": CausyAlgorithmReferenceType.NAME, + "type": AlgorithmReferenceType.NAME, "reference": algorithm, } else: raise ValueError("Either pipeline_file or algorithm must be specified") + dl = JSONDataLoader(data_file) # initialize from json - model.create_graph_from_data(load_json(data_file)) + model.create_graph_from_data(dl) # TODO: I should become a configurable skeleton builder model.create_all_possible_edges() @@ -90,7 +82,7 @@ def execute( f"{model.graph.nodes[edge[0]].name} -> {model.graph.nodes[edge[1]].name}: {model.graph.edges[edge[0]][edge[1]]}" ) - result = CausyResult( + result = Result( algorithm=algorithm_reference, action_history=model.graph.graph.action_history, edges=model.graph.retrieve_edges(), diff --git a/causy/common_pipeline_steps/calculation.py b/causy/common_pipeline_steps/calculation.py index ff2f98c..cafb6f3 100644 --- a/causy/common_pipeline_steps/calculation.py +++ b/causy/common_pipeline_steps/calculation.py @@ -1,18 +1,15 @@ from typing import Tuple, Optional, Generic import torch -from pydantic import BaseModel from causy.generators import AllCombinationsGenerator from causy.interfaces import ( PipelineStepInterface, - ComparisonSettings, BaseGraphInterface, - TestResult, - TestResultAction, GeneratorInterface, PipelineStepInterfaceType, ) +from causy.models import ComparisonSettings, TestResultAction, TestResult class CalculatePearsonCorrelations( diff --git a/causy/common_pipeline_steps/logic.py b/causy/common_pipeline_steps/logic.py index bd19f3d..11c4ed8 100644 --- a/causy/common_pipeline_steps/logic.py +++ b/causy/common_pipeline_steps/logic.py @@ -1,23 +1,17 @@ import time -from typing import Optional, List, Union, Dict, Any, Generic +from typing import Optional, Generic -from pydantic import BaseModel from causy.interfaces import ( LogicStepInterface, BaseGraphInterface, GraphModelInterface, - PipelineStepInterface, ExitConditionInterface, - PipelineStepInterfaceType, LogicStepInterfaceType, - ActionHistoryStep, -) -from causy.graph_utils import ( - load_pipeline_artefact_by_definition, - load_pipeline_steps_by_definition, ) +from causy.models import ActionHistoryStep + class Loop(LogicStepInterface[LogicStepInterfaceType], Generic[LogicStepInterfaceType]): """ diff --git a/causy/common_pipeline_steps/placeholder.py b/causy/common_pipeline_steps/placeholder.py index a915a96..81f4ba9 100644 --- a/causy/common_pipeline_steps/placeholder.py +++ b/causy/common_pipeline_steps/placeholder.py @@ -1,13 +1,13 @@ import logging -from typing import Tuple, List, Generic +from typing import Tuple, List, Generic, Optional from causy.interfaces import ( PipelineStepInterface, - TestResult, BaseGraphInterface, - TestResultAction, PipelineStepInterfaceType, ) +from causy.models import TestResultAction, TestResult +from causy.variables import StringVariable, IntegerVariable, FloatVariable, BoolVariable logger = logging.getLogger(__name__) @@ -15,8 +15,24 @@ class PlaceholderTest( PipelineStepInterface[PipelineStepInterfaceType], Generic[PipelineStepInterfaceType] ): - chunk_size_parallel_processing: int = 10 - parallel: bool = False + def __init__(self, **kwargs): + super().__init__(**kwargs) + if "placeholder_str" in kwargs: + self.placeholder_str = kwargs["placeholder_str"] + + if "placeholder_int" in kwargs: + self.placeholder_int = kwargs["placeholder_int"] + + if 
"placeholder_float" in kwargs: + self.placeholder_float = kwargs["placeholder_float"] + + if "placeholder_bool" in kwargs: + self.placeholder_bool = kwargs["placeholder_bool"] + + placeholder_str: Optional[StringVariable] = "placeholder" + placeholder_int: Optional[IntegerVariable] = 1 + placeholder_float: Optional[FloatVariable] = 1.0 + placeholder_bool: Optional[BoolVariable] = True def process( self, nodes: Tuple[str], graph: BaseGraphInterface diff --git a/causy/contrib/graph_ui.py b/causy/contrib/graph_ui.py index 3ccd1f2..c28ea3c 100644 --- a/causy/contrib/graph_ui.py +++ b/causy/contrib/graph_ui.py @@ -3,7 +3,7 @@ from pydantic import BaseModel -from causy.interfaces import CausyExtension, CausyExtensionType +from causy.interfaces import ExtensionInterface, ExtensionType class EdgeUIConfig(BaseModel): @@ -41,5 +41,5 @@ class EdgeTypeConfig(BaseModel): conditional_ui_configs: Optional[List[ConditionalEdgeUIConfig]] = None -class GraphUIExtension(CausyExtension[CausyExtensionType], Generic[CausyExtensionType]): +class GraphUIExtension(ExtensionInterface[ExtensionType], Generic[ExtensionType]): edges: List[EdgeTypeConfig] diff --git a/causy/data_loader.py b/causy/data_loader.py new file mode 100644 index 0000000..769f4fe --- /dev/null +++ b/causy/data_loader.py @@ -0,0 +1,143 @@ +import enum +import hashlib +import importlib +import json +from abc import ABC, abstractmethod +from typing import Iterator, Dict, Union, Any, Optional + +from pydantic import BaseModel + +from causy.graph_utils import hash_dictionary + + +class DataLoaderType(enum.StrEnum): + DYNAMIC = "dynamic" # python function which yields data + JSON = "json" + JSONL = "jsonl" + + +class DataLoaderReference(BaseModel): + """represents a single data loader + :param type: the type of dataloader + :param reference: path to either the python class which can be executed to load the data or the data source file itself + """ + + type: DataLoaderType + reference: str + options: Optional[Dict[str, Any]] = None + + +class AbstractDataLoader(ABC): + @abstractmethod + def __init__(self, reference: str, options: Optional[Dict[str, Any]] = None): + pass + + reference: str + options: Optional[Dict[str, Any]] = None + + @abstractmethod + def load(self) -> Iterator[Dict[str, Union[float, int, str]]]: + """ + loads the data from the source and returns it as an iterator + :return: + """ + pass + + @abstractmethod + def hash(self) -> str: + """ + returns a hash of the data that is loaded + :return: + """ + pass + + def _hash_options(self): + return hash_dictionary(self.options) + + +class FileDataLoader(AbstractDataLoader, ABC): + """ + A data loader which loads data from a file reference (e.g. json, csv, etc.) + """ + + def __init__(self, reference: str, options: Optional[Dict[str, Any]] = None): + self.reference = reference + self.options = options + + reference: str + + def hash(self) -> str: + with open(self.reference, "rb") as f: + return ( + f'{hashlib.file_digest(f, "sha256").hexdigest()}_{self._hash_options()}' + ) + + +class JSONDataLoader(FileDataLoader): + """ + A data loader which loads data from a json file + """ + + def load(self) -> Iterator[Dict[str, Union[float, int, str]]]: + with open(self.reference, "r") as f: + data = json.loads(f.read()) + if isinstance(data, list): + for item in data: + yield item + elif isinstance(data, dict): + yield {"_dict": data} + return + else: + raise ValueError( + f"Invalid JSON format. Data in {self.reference} is of type {type(data)}." 
+ ) + + +class JSONLDataLoader(FileDataLoader): + """ + A data loader which loads data from a jsonl file + """ + + def load(self) -> Iterator[Dict[str, Union[float, int, str]]]: + with open(self.reference, "r") as f: + for line in f: + yield json.loads(line) + + +class DynamicDataLoader(AbstractDataLoader): + """ + A data loader which loads another data loader dynamically based on the reference + """ + + def __init__(self, reference: str, options: Optional[Dict[str, Any]] = None): + self.reference = reference + self.data_loader = self._load_data_loader() + self.options = options + + reference: str + data_loader: AbstractDataLoader + + def _load_data_loader(self) -> AbstractDataLoader: + module = importlib.import_module(self.reference) + # todo: should the cls be referenced here? + return module.DataLoader(**self.options) + + def load(self) -> Iterator[Dict[str, Union[float, int, str]]]: + return self.data_loader.load() + + +DATA_LOADERS = { + DataLoaderType.JSON: JSONDataLoader, + DataLoaderType.JSONL: JSONLDataLoader, + DataLoaderType.DYNAMIC: DynamicDataLoader, +} + + +def load_data_loader(reference: DataLoaderReference) -> AbstractDataLoader: + """ + loads the data loader based on the reference + :param reference: a data loader reference + :return: + """ + + return DATA_LOADERS[reference.type](reference.reference, reference.options) diff --git a/causy/edge_types.py b/causy/edge_types.py index d0d45e4..fe69cf5 100644 --- a/causy/edge_types.py +++ b/causy/edge_types.py @@ -1,4 +1,5 @@ -from typing import Optional, List +import enum +from typing import Optional, List, Generic from causy.contrib.graph_ui import ( EdgeTypeConfig, @@ -8,11 +9,20 @@ ) from causy.interfaces import ( EdgeTypeInterface, + EdgeTypeInterfaceType, ) -class DirectedEdge(EdgeTypeInterface): - name: str = "DIRECTED" +class EdgeTypeEnum(enum.StrEnum): + DIRECTED = "directed" + UNDIRECTED = "undirected" + BIDIRECTED = "bidirected" + + +class DirectedEdge(EdgeTypeInterface, Generic[EdgeTypeInterfaceType]): + name: str = EdgeTypeEnum.DIRECTED.name + IS_DIRECTED: bool = True + STR_REPRESENTATION: str = "-->" # u --> v class DirectedEdgeUIConfig(EdgeTypeConfig): @@ -47,8 +57,10 @@ class DirectedEdgeUIConfig(EdgeTypeConfig): ] -class UndirectedEdge(EdgeTypeInterface): - name: str = "UNDIRECTED" +class UndirectedEdge(EdgeTypeInterface, Generic[EdgeTypeInterfaceType]): + name: str = EdgeTypeEnum.UNDIRECTED.name + IS_DIRECTED: bool = False + STR_REPRESENTATION: str = "---" # u --- v class UndirectedEdgeUIConfig(EdgeTypeConfig): @@ -66,8 +78,10 @@ class UndirectedEdgeUIConfig(EdgeTypeConfig): ) -class BiDirectedEdge(EdgeTypeInterface): - name: str = "BIDIRECTED" +class BiDirectedEdge(EdgeTypeInterface, Generic[EdgeTypeInterfaceType]): + name: str = EdgeTypeEnum.BIDIRECTED.name + IS_DIRECTED: bool = False # This is a bi-directed edge - so it is not directed in the traditional sense + STR_REPRESENTATION: str = "<->" # u <-> v class BiDirectedEdgeUIConfig(EdgeTypeConfig): @@ -82,3 +96,10 @@ class BiDirectedEdgeUIConfig(EdgeTypeConfig): marker_start="ArrowClosed", marker_end="ArrowClosed", ) + + +EDGE_TYPES = { + DirectedEdge().name: DirectedEdge, + UndirectedEdge().name: UndirectedEdge, + BiDirectedEdge().name: BiDirectedEdge, +} diff --git a/causy/generators.py b/causy/generators.py index 9808771..73a3219 100644 --- a/causy/generators.py +++ b/causy/generators.py @@ -6,13 +6,14 @@ from pydantic import BaseModel from causy.interfaces import ( - ComparisonSettings, GeneratorInterface, BaseGraphInterface, GraphModelInterface, 
AS_MANY_AS_FIELDS, ) +from causy.models import ComparisonSettings from causy.graph_utils import load_pipeline_artefact_by_definition +from causy.variables import IntegerParameter, BoolParameter logger = logging.getLogger(__name__) @@ -59,14 +60,14 @@ class PairsWithEdgesInBetweenGenerator(GeneratorInterface): However, if it is an edge which points in both/no directions, it will be iterated over them twice. """ - chunk_size: int = 100 - chunked: Optional[bool] = None + chunk_size: IntegerParameter = 100 + chunked: Optional[BoolParameter] = None def __init__( self, *args, - chunk_size: Optional[int] = None, - chunked: Optional[bool] = None, + chunk_size: Optional[IntegerParameter] = None, + chunked: Optional[BoolParameter] = None, **kwargs, ): super().__init__(*args, **kwargs) @@ -101,8 +102,8 @@ class PairsWithNeighboursGenerator(GeneratorInterface): (if, among others, X and Y are neighbours and X and Z are neighbours) """ - shuffle_combinations: bool = True - chunked: bool = True + shuffle_combinations: BoolParameter = True + chunked: BoolParameter = True def __init__( self, @@ -129,8 +130,6 @@ def generate( if stop > len(graph.nodes) + 1: stop = len(graph.nodes) + 1 - print(stop) - # if stop is smaller then start, we can't create any combinations if stop < start: return @@ -138,7 +137,7 @@ def generate( if start < 2: raise ValueError("PairsWithNeighboursGenerator: start must be at least 2") for range_size in range(start, stop): - print(f"range_size = {range_size}") + logger.info(f"range_size = {range_size}") logger.debug(f"PairsWithNeighboursGenerator: range_size={range_size}") checked_combinations = set() for node in graph.edges: @@ -161,7 +160,7 @@ def generate( if graph.directed_edge_exists(k, node) ] ) - print(f"other_neighbors before removal={other_neighbours}") + logger.info(f"other_neighbors before removal={other_neighbours}") if neighbour in other_neighbours: other_neighbours.remove(neighbour) @@ -169,8 +168,8 @@ def generate( logger.debug( "PairsWithNeighboursGenerator: neighbour not in other_neighbours. This should not happen." ) - print(f"node={node}, neighbour={neighbour}") - print(f"other_neighbours={other_neighbours}") + logger.info(f"node={node}, neighbour={neighbour}") + logger.info(f"other_neighbours={other_neighbours}") combinations = list( itertools.combinations(other_neighbours, range_size - 2) ) @@ -196,7 +195,7 @@ class RandomSampleGenerator(GeneratorInterface, BaseModel): Executes another generator and returns a random sample of the results """ - every_nth: int = 100 + every_nth: IntegerParameter = 100 generator: Optional[GeneratorInterface] = None def __init__( diff --git a/causy/graph.py b/causy/graph.py index 43c422b..7c31c58 100644 --- a/causy/graph.py +++ b/causy/graph.py @@ -11,12 +11,10 @@ from causy.interfaces import ( BaseGraphInterface, NodeInterface, - TestResultAction, - TestResult, EdgeInterface, EdgeTypeInterface, - ActionHistoryStep, ) +from causy.models import TestResultAction, TestResult, ActionHistoryStep logger = logging.getLogger(__name__) @@ -341,7 +339,7 @@ def directed_paths( def inducing_path_exists(self, u: Union[Node, str], v: Union[Node, str]) -> bool: """ Check if an inducing path from u to v exists. - An inducing path from u to v is a directed path from u to v on which all mediators are colliders. + An inducing path from u to v is a directed reference from u to v on which all mediators are colliders. 
:param u: node u :param v: node v :return: True if an inducing path exists, False otherwise diff --git a/causy/graph_model.py b/causy/graph_model.py index c7f5cd7..6d389fa 100644 --- a/causy/graph_model.py +++ b/causy/graph_model.py @@ -3,21 +3,24 @@ from abc import ABC from copy import deepcopy import time -from typing import Optional, List, Dict, Callable, Union +from typing import Optional, List, Dict, Callable, Union, Any import torch.multiprocessing as mp +from causy.data_loader import AbstractDataLoader from causy.edge_types import DirectedEdge from causy.graph import GraphManager from causy.graph_utils import unpack_run from causy.interfaces import ( PipelineStepInterface, - TestResultAction, LogicStepInterface, BaseGraphInterface, GraphModelInterface, - CausyAlgorithm, - ActionHistoryStep, +) +from causy.models import TestResultAction, Algorithm, ActionHistoryStep +from causy.variables import ( + resolve_variables_to_algorithm_for_pipeline_steps, + resolve_variables, ) logger = logging.getLogger(__name__) @@ -36,7 +39,7 @@ class AbstractGraphModel(GraphModelInterface, ABC): """ - algorithm: CausyAlgorithm + algorithm: Algorithm pipeline_steps: List[PipelineStepInterface] graph: BaseGraphInterface pool: mp.Pool @@ -44,7 +47,7 @@ class AbstractGraphModel(GraphModelInterface, ABC): def __init__( self, graph=None, - algorithm: CausyAlgorithm = None, + algorithm: Algorithm = None, ): self.graph = graph self.algorithm = algorithm @@ -124,8 +127,38 @@ def __create_graph_from_list(self, data: List[Dict[str, float]]): return graph + def _create_from_data_loader(self, data_loader: AbstractDataLoader): + """ + Create a graph from a data loader + :param data_loader: the data loader + :return: the graph + """ + nodes: Dict[str, List[float]] = {} + keys = None + + # load nodes into node dict + for row in data_loader.load(): + if isinstance(row, dict) and "_dict" in row: + # edge case for when data is in a dict of lists + return self.__create_graph_from_dict(row["_dict"]) + + if keys is None: + keys = row.keys() + for key in sorted(keys): + nodes[key] = [] + + for key in keys: + nodes[key].append(row[key]) + + graph = GraphManager() + for key in keys: + graph.add_node(key, nodes[key], id_=key) + + return graph + def create_graph_from_data( - self, data: Union[List[Dict[str, float]], Dict[str, List[float]]] + self, + data: Union[List[Dict[str, float]], Dict[str, List[float]], AbstractDataLoader], ): """ Create a graph from data @@ -133,7 +166,9 @@ def create_graph_from_data( :return: """ - if isinstance(data, dict): + if isinstance(data, AbstractDataLoader): + graph = self._create_from_data_loader(data) + elif isinstance(data, dict): graph = self.__create_graph_from_dict(data) else: graph = self.__create_graph_from_list(data) @@ -350,15 +385,29 @@ def execute_pipeline_step( def graph_model_factory( - algorithm: CausyAlgorithm = None, + algorithm: Algorithm = None, + variables: Dict[str, Any] = None, ) -> type[AbstractGraphModel]: """ Create a graph model based on a List of pipeline_steps :param algorithm: the algorithm which should be used to create the graph model :return: the graph model """ + original_algorithm = deepcopy(algorithm) + if variables is None and algorithm.variables is not None: + variables = resolve_variables(algorithm.variables, {}) + elif variables is None: + variables = {} + + if len(variables) > 0: + algorithm.pipeline_steps = resolve_variables_to_algorithm_for_pipeline_steps( + algorithm.pipeline_steps, variables + ) class GraphModel(AbstractGraphModel): + # store the 
original algorithm for later use like ejecting it without the resolved variables + _original_algorithm = original_algorithm + def __init__(self): super().__init__(algorithm=algorithm) diff --git a/causy/graph_utils.py b/causy/graph_utils.py index d48c7b3..b49c6ca 100644 --- a/causy/graph_utils.py +++ b/causy/graph_utils.py @@ -1,5 +1,9 @@ +import hashlib import importlib -from typing import List, Tuple +import json +from typing import List, Tuple, Dict + +from causy.variables import deserialize_variable_references def unpack_run(args): @@ -30,6 +34,7 @@ def load_pipeline_steps_by_definition(steps): pipeline = [] for step in steps: st_function = load_pipeline_artefact_by_definition(step) + st_function = deserialize_variable_references(st_function) pipeline.append(st_function) return pipeline @@ -45,3 +50,20 @@ def retrieve_edges(graph) -> List[Tuple[str, str]]: for v in graph.edges[u]: edges.append((u, v)) return edges + + +def hash_dictionary(dct: Dict): + """ + Hash a dictionary using SHA256 (e.g. for caching) + :param dct: + :return: + """ + return hashlib.sha256( + json.dumps( + dct, + ensure_ascii=False, + sort_keys=True, + indent=None, + separators=(",", ":"), + ).encode() + ).hexdigest() diff --git a/causy/interfaces.py b/causy/interfaces.py index d457420..543b14e 100644 --- a/causy/interfaces.py +++ b/causy/interfaces.py @@ -1,12 +1,9 @@ -import enum import multiprocessing from abc import ABC, abstractmethod -from datetime import datetime - -from pydantic.dataclasses import dataclass from typing import List, Dict, Optional, Union, TypeVar, Generic, Any import logging +from pydantic import BaseModel, computed_field, Field import torch from causy.graph_utils import ( @@ -14,8 +11,12 @@ serialize_module_name, load_pipeline_steps_by_definition, ) - -from pydantic import BaseModel, computed_field, AwareDatetime, Field +from causy.variables import ( + StringParameter, + IntegerParameter, + BoolParameter, + FloatParameter, +) logger = logging.getLogger(__name__) @@ -24,9 +25,9 @@ AS_MANY_AS_FIELDS = 0 -class ComparisonSettings(BaseModel): - min: int = 2 - max: int = AS_MANY_AS_FIELDS +class ComparisonSettingsInterface(BaseModel, ABC): + min: IntegerParameter + max: IntegerParameter @computed_field @property @@ -34,7 +35,7 @@ def name(self) -> str: return serialize_module_name(self) -class NodeInterface(BaseModel): +class NodeInterface(BaseModel, ABC): """ Node interface for the graph. A node is defined by a name and a value. """ @@ -47,7 +48,10 @@ class Config: arbitrary_types_allowed = True -class EdgeTypeInterface(BaseModel): +EdgeTypeInterfaceType = TypeVar("EdgeTypeInterfaceType") + + +class EdgeTypeInterface(ABC, BaseModel, Generic[EdgeTypeInterfaceType]): """ Edge type interface for the graph An edge type is defined by a name @@ -55,6 +59,10 @@ class EdgeTypeInterface(BaseModel): name: str + # define if it is a directed or undirected edge type (default is undirected). We use this e.g. when we compare the graph. + IS_DIRECTED: bool = True + STR_REPRESENTATION: str = "-" + def __hash__(self): return hash(self.name) @@ -68,7 +76,7 @@ def __repr__(self): return self.name -class EdgeInterface(BaseModel): +class EdgeInterface(BaseModel, ABC): """ Edge interface for the graph A graph edge is defined by two nodes and an edge type. It can also have metadata. 
@@ -82,38 +90,56 @@ class EdgeInterface(BaseModel): class Config: arbitrary_types_allowed = True - -class TestResultAction(enum.StrEnum): + def __eq__(self, other): + """ + Check if two edges are equal by comparing the nodes and the edge type + :param other: + :return: + """ + if not isinstance(other, EdgeInterface): + return False + + if self.edge_type != other.edge_type: + return False + if self.edge_type.IS_DIRECTED: + return self.u == other.u and self.v == other.v + else: + return self.is_connection_between_same_nodes(other) + + def is_connection_between_same_nodes(self, edge): + return ( + self.u == edge.u + and self.v == edge.v + or self.u == edge.v + and self.v == edge.u + ) + + +class TestResultInterface(BaseModel, ABC): """ - Actions that can be taken on the graph. These actions are used to keep track of the history of the graph. + Test result interface for the graph + A test result is defined by two nodes and an action. It can also have metadata. """ - REMOVE_EDGE_UNDIRECTED = "REMOVE_EDGE_UNDIRECTED" - UPDATE_EDGE = "UPDATE_EDGE" - UPDATE_EDGE_TYPE = "UPDATE_EDGE_TYPE" - UPDATE_EDGE_DIRECTED = "UPDATE_EDGE_DIRECTED" - UPDATE_EDGE_TYPE_DIRECTED = "UPDATE_EDGE_TYPE_DIRECTED" - DO_NOTHING = "DO_NOTHING" - REMOVE_EDGE_DIRECTED = "REMOVE_EDGE_DIRECTED" - - -class TestResult(BaseModel): u: NodeInterface v: NodeInterface - action: TestResultAction + action: str data: Optional[Dict] = None + class Config: + arbitrary_types_allowed = True + class BaseGraphInterface(ABC): nodes: Dict[str, NodeInterface] edges: Dict[str, Dict[str, Dict]] @abstractmethod - def retrieve_edge_history(self, u, v, action: TestResultAction) -> List[TestResult]: + def retrieve_edge_history(self, u, v, action: str) -> List[TestResultInterface]: pass @abstractmethod - def add_edge_history(self, u, v, action: TestResult): + def add_edge_history(self, u, v, action: TestResultInterface): pass @abstractmethod @@ -174,11 +200,11 @@ def execute_pipeline_step(self, step, apply_to_graph: bool = True): class GeneratorInterface(ABC, BaseModel): - comparison_settings: Optional[ComparisonSettings] = None - chunked: Optional[bool] = False - every_nth: Optional[int] = None + comparison_settings: Optional[ComparisonSettingsInterface] = None + chunked: Optional[BoolParameter] = False + every_nth: Optional[IntegerParameter] = None generator: Optional["GeneratorInterface"] = None - shuffle_combinations: Optional[bool] = None + shuffle_combinations: Optional[BoolParameter] = None @abstractmethod def generate(self, graph: BaseGraphInterface, graph_model_instance_: dict): @@ -186,12 +212,12 @@ def generate(self, graph: BaseGraphInterface, graph_model_instance_: dict): def __init__( self, - comparison_settings: Optional[ComparisonSettings] = None, - chunked: bool = None, - every_nth: int = None, - generator: "GeneratorInterface" = None, - shuffle_combinations: bool = None, - ): + comparison_settings: Optional[ComparisonSettingsInterface] = None, + chunked: Optional[BoolParameter] = None, + every_nth: Optional[IntegerParameter] = None, + generator: Optional["GeneratorInterface"] = None, + shuffle_combinations: Optional[BoolParameter] = None, + ) -> None: super().__init__(comparison_settings=comparison_settings) if isinstance(comparison_settings, dict): comparison_settings = load_pipeline_artefact_by_definition( @@ -223,20 +249,21 @@ def name(self) -> str: class PipelineStepInterface(ABC, BaseModel, Generic[PipelineStepInterfaceType]): generator: Optional[GeneratorInterface] = None - threshold: Optional[float] = DEFAULT_THRESHOLD - 
chunk_size_parallel_processing: int = 1 - parallel: bool = True + threshold: Optional[FloatParameter] = DEFAULT_THRESHOLD + chunk_size_parallel_processing: IntegerParameter = 1 + parallel: BoolParameter = True - display_name: Optional[str] = None + display_name: Optional[StringParameter] = None def __init__( self, - threshold: float = None, + threshold: Optional[FloatParameter] = None, generator: Optional[GeneratorInterface] = None, - chunk_size_parallel_processing: int = None, - parallel: bool = None, - display_name: Optional[str] = None, - ): + chunk_size_parallel_processing: Optional[IntegerParameter] = None, + parallel: Optional[BoolParameter] = None, + display_name: Optional[StringParameter] = None, + **kwargs, + ) -> None: super().__init__() if generator: if isinstance(generator, dict): @@ -264,7 +291,7 @@ def name(self) -> str: @abstractmethod def process( self, nodes: List[str], graph: BaseGraphInterface - ) -> Optional[TestResult]: + ) -> Optional[TestResultInterface]: """ Test if u and v are independent :param u: u values @@ -275,7 +302,7 @@ def process( def __call__( self, nodes: List[str], graph: BaseGraphInterface - ) -> Optional[TestResult]: + ) -> Optional[TestResultInterface]: return self.process(nodes, graph) @@ -285,8 +312,8 @@ def check( self, graph: BaseGraphInterface, graph_model_instance_: GraphModelInterface, - actions_taken: List[TestResult], - iteration: int, + actions_taken: List[TestResultInterface], + iteration: IntegerParameter, ) -> bool: """ :param graph: @@ -301,7 +328,7 @@ def __call__( self, graph: BaseGraphInterface, graph_model_instance_: GraphModelInterface, - actions_taken: List[TestResult], + actions_taken: List[TestResultInterface], iteration: int, ) -> bool: return self.check(graph, graph_model_instance_, actions_taken, iteration) @@ -319,7 +346,7 @@ class LogicStepInterface(ABC, BaseModel, Generic[LogicStepInterfaceType]): pipeline_steps: Optional[List[Union[PipelineStepInterfaceType]]] = None exit_condition: Optional[ExitConditionInterface] = None - display_name: Optional[str] = None + display_name: Optional[StringParameter] = None @abstractmethod def execute(self, graph: BaseGraphInterface, graph_model_instance_: dict): @@ -354,45 +381,11 @@ def __init__( self.display_name = display_name -CausyExtensionType = TypeVar("CausyExtensionType", bound="CausyExtension") +ExtensionType = TypeVar("ExtensionType", bound="ExtensionInterface") -class CausyExtension(BaseModel, Generic[CausyExtensionType]): +class ExtensionInterface(BaseModel, Generic[ExtensionType]): @computed_field @property def name(self) -> str: return serialize_module_name(self) - - -class CausyAlgorithm(BaseModel): - name: str - pipeline_steps: List[Union[PipelineStepInterfaceType, LogicStepInterface]] - pipeline_steps: List[Union[PipelineStepInterfaceType, LogicStepInterface]] - edge_types: List[EdgeTypeInterface] - extensions: Optional[List[CausyExtensionType]] = None - - -class CausyAlgorithmReferenceType(enum.StrEnum): - FILE = "file" - NAME = "name" - PYTHON_MODULE = "python_module" - - -class CausyAlgorithmReference(BaseModel): - reference: str - type: CausyAlgorithmReferenceType - - -class ActionHistoryStep(BaseModel): - name: str - duration: Optional[float] = None # seconds - actions: Optional[List[TestResult]] = [] - steps: Optional[List["ActionHistoryStep"]] = [] - - -class CausyResult(BaseModel): - algorithm: CausyAlgorithmReference - created_at: datetime = Field(default_factory=datetime.now) - nodes: Dict[str, NodeInterface] - edges: List[EdgeInterface] - action_history: 
List[ActionHistoryStep] diff --git a/causy/models.py b/causy/models.py new file mode 100644 index 0000000..a057c75 --- /dev/null +++ b/causy/models.py @@ -0,0 +1,96 @@ +import enum +import hashlib +import json +from datetime import datetime +from typing import Optional, Dict, List, Union, Any + +from pydantic import BaseModel, computed_field, Field + +from causy.graph_utils import serialize_module_name, hash_dictionary +from causy.interfaces import ( + AS_MANY_AS_FIELDS, + NodeInterface, + PipelineStepInterfaceType, + LogicStepInterface, + EdgeTypeInterface, + ExtensionType, + EdgeInterface, + TestResultInterface, + ComparisonSettingsInterface, + ExtensionInterface, + EdgeTypeInterfaceType, +) +from causy.variables import IntegerParameter, VariableInterfaceType + + +class ComparisonSettings(ComparisonSettingsInterface): + min: IntegerParameter = 2 + max: IntegerParameter = AS_MANY_AS_FIELDS + + @computed_field + @property + def name(self) -> str: + return serialize_module_name(self) + + +class TestResultAction(enum.StrEnum): + """ + Actions that can be taken on the graph. These actions are used to keep track of the history of the graph. + """ + + REMOVE_EDGE_UNDIRECTED = "REMOVE_EDGE_UNDIRECTED" + UPDATE_EDGE = "UPDATE_EDGE" + UPDATE_EDGE_TYPE = "UPDATE_EDGE_TYPE" + UPDATE_EDGE_DIRECTED = "UPDATE_EDGE_DIRECTED" + UPDATE_EDGE_TYPE_DIRECTED = "UPDATE_EDGE_TYPE_DIRECTED" + DO_NOTHING = "DO_NOTHING" + REMOVE_EDGE_DIRECTED = "REMOVE_EDGE_DIRECTED" + + +class TestResult(TestResultInterface): + u: NodeInterface + v: NodeInterface + action: TestResultAction + data: Optional[Dict] = None + + +class AlgorithmReferenceType(enum.StrEnum): + FILE = "file" + NAME = "name" + PYTHON_MODULE = "python_module" + + +class AlgorithmReference(BaseModel): + reference: str + type: AlgorithmReferenceType + + +class Algorithm(BaseModel): + name: str + pipeline_steps: List[Union[PipelineStepInterfaceType, LogicStepInterface]] + pipeline_steps: List[Union[PipelineStepInterfaceType, LogicStepInterface]] + edge_types: List[EdgeTypeInterfaceType] + extensions: Optional[List[ExtensionType]] = None + variables: Optional[List[Union[VariableInterfaceType]]] = None + + def hash(self) -> str: + return hash_dictionary(self.model_dump()) + + +class ActionHistoryStep(BaseModel): + name: str + duration: Optional[float] = None # seconds + actions: Optional[List[TestResult]] = [] + steps: Optional[List["ActionHistoryStep"]] = [] + + +class Result(BaseModel): + algorithm: AlgorithmReference + created_at: datetime = Field(default_factory=datetime.now) + nodes: Dict[str, NodeInterface] + edges: List[EdgeInterface] + action_history: List[ActionHistoryStep] + variables: Optional[Dict[str, Any]] = None + data_loader_hash: Optional[str] = None + algorithm_hash: Optional[str] = None + variables_hash: Optional[str] = None diff --git a/causy/serialization.py b/causy/serialization.py index f6a5110..f388bea 100644 --- a/causy/serialization.py +++ b/causy/serialization.py @@ -1,3 +1,4 @@ +import copy import datetime import importlib import json @@ -6,18 +7,14 @@ import os import torch +import yaml from pydantic import parse_obj_as +from causy.edge_types import EDGE_TYPES from causy.graph_utils import load_pipeline_steps_by_definition -from causy.interfaces import CausyAlgorithmReferenceType - - -def load_algorithm_from_reference(algorithm: str): - st_function = importlib.import_module("causy.algorithms") - st_function = getattr(st_function, algorithm) - if not st_function: - raise ValueError(f"Algorithm {algorithm} not found") - return 
st_function +from causy.models import AlgorithmReferenceType, Result, AlgorithmReference +from causy.variables import deserialize_variable +from causy.causal_discovery import AVAILABLE_ALGORITHMS def serialize_algorithm(model, algorithm_name: str = None): @@ -33,14 +30,20 @@ def load_algorithm_from_specification(algorithm_dict: Dict[str, Any]): algorithm_dict["pipeline_steps"] = load_pipeline_steps_by_definition( algorithm_dict["pipeline_steps"] ) - from causy.interfaces import CausyAlgorithm + if "variables" not in algorithm_dict or algorithm_dict["variables"] is None: + algorithm_dict["variables"] = [] + + algorithm_dict["variables"] = [ + deserialize_variable(variable) for variable in algorithm_dict["variables"] + ] + from causy.models import Algorithm - return parse_obj_as(CausyAlgorithm, algorithm_dict) + return parse_obj_as(Algorithm, algorithm_dict) def load_algorithm_by_reference(reference_type: str, algorithm: str): # TODO: test me - if reference_type == CausyAlgorithmReferenceType.FILE: + if reference_type == AlgorithmReferenceType.FILE: # validate if the reference points only in the same directory or subdirectory # to avoid security issues absolute_path = os.path.realpath(algorithm) @@ -50,11 +53,24 @@ def load_algorithm_by_reference(reference_type: str, algorithm: str): with open(absolute_path, "r") as file: # load the algorithm from the file - algorithm = json.loads(file.read()) - return load_algorithm_from_specification(algorithm) - elif reference_type == CausyAlgorithmReferenceType.NAME: - return load_algorithm_from_reference(algorithm)().algorithm - elif reference_type == CausyAlgorithmReferenceType.PYTHON_MODULE: + # try first json + try: + algorithm = json.loads(file.read()) + return load_algorithm_from_specification(algorithm) + except json.JSONDecodeError: + pass + file.seek(0) + # then try yaml + try: + data = yaml.load(file.read(), Loader=yaml.FullLoader) + return load_algorithm_from_specification(data) + except yaml.YAMLError: + pass + raise ValueError("Invalid file format") + + elif reference_type == AlgorithmReferenceType.NAME: + return copy.deepcopy(AVAILABLE_ALGORITHMS[algorithm]()._original_algorithm) + elif reference_type == AlgorithmReferenceType.PYTHON_MODULE: st_function = importlib.import_module(algorithm) st_function = getattr(st_function, algorithm) if not st_function: @@ -75,3 +91,14 @@ def load_json(pipeline_file: str): with open(pipeline_file, "r") as file: pipeline = json.loads(file.read()) return pipeline + + +def deserialize_result(result: Dict[str, Any], klass=Result): + """Deserialize the result.""" + + result["algorithm"] = AlgorithmReference(**result["algorithm"]) + for i, edge in enumerate(result["edges"]): + result["edges"][i]["edge_type"] = EDGE_TYPES[edge["edge_type"]["name"]]( + **edge["edge_type"] + ) + return parse_obj_as(klass, result) diff --git a/causy/ui.py b/causy/ui.py deleted file mode 100644 index 64a5e97..0000000 --- a/causy/ui.py +++ /dev/null @@ -1,112 +0,0 @@ -import os - -import fastapi -import typer -import uvicorn -from typing import Any, Dict, Optional, Union - -from fastapi import APIRouter, HTTPException -from pydantic import BaseModel, UUID4 -from starlette.staticfiles import StaticFiles - -import logging - -from causy.interfaces import ( - NodeInterface, - CausyAlgorithm, - CausyAlgorithmReference, - CausyResult, - CausyAlgorithmReferenceType, -) -from causy.serialization import load_algorithm_by_reference - -logger = logging.getLogger(__name__) - -API_ROUTES = APIRouter() - -MODEL = None - - -class 
NodePosition(BaseModel): - x: Optional[float] - y: Optional[float] - - -class PositionedNode(NodeInterface): - position: Optional[NodePosition] = None - - -class CausyExtendedResult(CausyResult): - nodes: Dict[Union[UUID4, str], PositionedNode] - - -@API_ROUTES.get("/status", response_model=Dict[str, Any]) -async def get_status(): - """Get the current status of the API.""" - return {"status": "ok"} - - -@API_ROUTES.get("/model", response_model=CausyExtendedResult) -async def get_model(): - """Get the current model.""" - return MODEL - - -@API_ROUTES.get( - "/algorithm/{reference_type}/{reference}", response_model=CausyAlgorithm -) -async def get_algorithm(reference_type: str, reference: str): - """Get the current algorithm.""" - if reference.startswith("/") or ".." in reference: - raise HTTPException(400, "Invalid reference") - - if reference_type not in CausyAlgorithmReferenceType.__members__.values(): - raise HTTPException(400, "Invalid reference type") - - try: - algorithm = load_algorithm_by_reference(reference_type, reference) - return algorithm - except Exception as e: - raise HTTPException(400, str(e)) - - -def server(result: Dict[str, Any]): - """Create the FastAPI server.""" - app = fastapi.FastAPI( - title="causy-api", - version="0.0.1", - description="causys internal api to serve data from the result files", - ) - global MODEL - result["algorithm"] = CausyAlgorithmReference(**result["algorithm"]) - MODEL = CausyExtendedResult(**result) - - app.include_router(API_ROUTES, prefix="/api/v1", tags=["api"]) - app.mount( - "", - StaticFiles( - directory=os.path.join(os.path.dirname(__file__), "static"), html=True - ), - name="static", - ) - - host = os.getenv("HOST", "localhost") - port = int(os.getenv("PORT", "8000")) - cors_enabled = os.getenv("CORS_ENABLED", "false").lower() == "true" - - # cors e.g. 
for development of separate frontend - if cors_enabled: - logger.warning(typer.style("🌐 CORS enabled", fg=typer.colors.YELLOW)) - from fastapi.middleware.cors import CORSMiddleware - - app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - server_config = uvicorn.Config(app, host=host, port=port, log_level="error") - server = uvicorn.Server(server_config) - return server_config, server diff --git a/causy/ui/__init__.py b/causy/ui/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/causy/ui/cli.py b/causy/ui/cli.py new file mode 100644 index 0000000..c50727b --- /dev/null +++ b/causy/ui/cli.py @@ -0,0 +1,18 @@ +import typer +from causy.serialization import load_json +from causy.ui.server import server +from causy.workspaces.cli import _current_workspace + + +def ui(result_file: str = None): + """Start the causy UI.""" + if not result_file: + workspace = _current_workspace() + server_config, server_runner = server(workspace=workspace) + else: + result = load_json(result_file) + server_config, server_runner = server(result=result) + + typer.launch(f"http://{server_config.host}:{server_config.port}") + typer.echo(f"🚀 Starting server at http://{server_config.host}:{server_config.port}") + server_runner.run() diff --git a/causy/ui/models.py b/causy/ui/models.py new file mode 100644 index 0000000..bd18a6f --- /dev/null +++ b/causy/ui/models.py @@ -0,0 +1,30 @@ +from typing import Optional, List, Dict, Union + +from causy.interfaces import NodeInterface +from causy.models import Result +from causy.workspaces.models import Experiment +from pydantic import BaseModel, UUID4 + + +class NodePosition(BaseModel): + x: Optional[float] + y: Optional[float] + + +class ExperimentVersion(BaseModel): + version: int + name: str + + +class ExtendedExperiment(Experiment): + versions: Optional[List[ExperimentVersion]] = None + name: str = None + + +class PositionedNode(NodeInterface): + position: Optional[NodePosition] = None + + +class ExtendedResult(Result): + nodes: Dict[Union[UUID4, str], PositionedNode] + version: Optional[int] = None diff --git a/causy/ui/server.py b/causy/ui/server.py new file mode 100644 index 0000000..f75ac7e --- /dev/null +++ b/causy/ui/server.py @@ -0,0 +1,217 @@ +import logging +import os +from datetime import datetime +from importlib.metadata import version +from typing import Dict, Any, List, Optional + +import fastapi +import typer +import uvicorn + +from causy.models import ( + AlgorithmReference, + Algorithm, + AlgorithmReferenceType, +) +from causy.serialization import load_algorithm_by_reference +from causy.ui.models import ExtendedResult, ExtendedExperiment, ExperimentVersion +from causy.workspaces.cli import ( + _load_latest_experiment_result, + _load_experiment_versions, + _load_experiment_result, +) +from causy.workspaces.models import Workspace +from fastapi import APIRouter, HTTPException +from starlette.staticfiles import StaticFiles + +logger = logging.getLogger(__name__) +API_ROUTES = APIRouter() +MODEL: Optional[ExtendedResult] = None +WORKSPACE: Optional[Workspace] = None + + +@API_ROUTES.get("/status", response_model=Dict[str, Any]) +async def get_status(): + """Get the current status of the API.""" + return { + "status": "ok", + "model_loaded": MODEL is not None, + "workspace_loaded": WORKSPACE is not None, + "mode": "workspace" if WORKSPACE else "model", + "causy_version": version("causy"), + } + + +@API_ROUTES.get("/model", response_model=ExtendedResult) +async 
def get_model(): + """Get the current model.""" + if not MODEL: + raise HTTPException(404, "No model loaded") + return MODEL + + +@API_ROUTES.get("/workspace", response_model=Workspace) +async def get_workspace(): + if not WORKSPACE: + raise HTTPException(404, "No workspace loaded") + return WORKSPACE + + +@API_ROUTES.get("/experiments/{experiment_name}/latest", response_model=ExtendedResult) +async def get_latest_experiment(experiment_name: str): + """Get the current experiment.""" + if not WORKSPACE: + raise HTTPException(404, "No workspace loaded") + + if experiment_name not in WORKSPACE.experiments: + raise HTTPException(404, "Experiment not found") + + try: + experiment = _load_latest_experiment_result(WORKSPACE, experiment_name) + except Exception as e: + raise HTTPException(400, str(e)) + + version = _load_experiment_versions(WORKSPACE, experiment_name)[0] + + experiment["algorithm"] = AlgorithmReference(**experiment["algorithm"]) + experiment["version"] = version + experiment = ExtendedResult(**experiment) + + return experiment + + +@API_ROUTES.get( + "/experiments/{experiment_name}/{version_number}", + response_model=ExtendedResult, +) +async def get_experiment(experiment_name: str, version_number: int): + """Get the current experiment.""" + if not WORKSPACE: + raise HTTPException(404, "No workspace loaded") + + if experiment_name not in WORKSPACE.experiments: + raise HTTPException(404, "Experiment not found") + + try: + experiment = _load_experiment_result(WORKSPACE, experiment_name, version_number) + except Exception as e: + raise HTTPException(400, str(e)) + + experiment["algorithm"] = AlgorithmReference(**experiment["algorithm"]) + experiment["version"] = version_number + experiment = ExtendedResult(**experiment) + + return experiment + + +@API_ROUTES.get("/experiments", response_model=List[ExtendedExperiment]) +async def get_experiments(): + """Get the current experiment.""" + if not WORKSPACE: + raise HTTPException(404, "No workspace loaded") + + experiments = [] + for experiment_name, experiment in WORKSPACE.experiments.items(): + extended_experiment = ExtendedExperiment(**experiment.model_dump()) + versions = [] + for experiment_version in _load_experiment_versions(WORKSPACE, experiment_name): + versions.append( + ExperimentVersion( + version=experiment_version, + name=datetime.fromtimestamp(experiment_version).isoformat(), + ) + ) + extended_experiment.versions = versions + extended_experiment.name = experiment_name + experiments.append(extended_experiment) + + return experiments + + +@API_ROUTES.get("/algorithm/{reference_type}/{reference}", response_model=Algorithm) +async def get_algorithm(reference_type: str, reference: str): + """Get the current algorithm.""" + if reference.startswith("/") or ".." 
in reference: + raise HTTPException(400, "Invalid reference") + + if reference_type not in AlgorithmReferenceType.__members__.values(): + raise HTTPException(400, "Invalid reference type") + + try: + algorithm = load_algorithm_by_reference(reference_type, reference) + return algorithm + except Exception as e: + raise HTTPException(400, str(e)) + + +def _create_ui_app(with_static=True): + """Get the server.""" + app = fastapi.FastAPI( + title="causy-api", + version=version("causy"), + description="causys internal api to serve data from the result files", + ) + app.include_router(API_ROUTES, prefix="/api/v1", tags=["api"]) + if with_static: + app.mount( + "", + StaticFiles( + directory=os.path.join(os.path.dirname(__file__), "..", "static"), + html=True, + ), + name="static", + ) + return app + + +def _set_model(result: Dict[str, Any]): + """Set the model.""" + global MODEL + # for testing + if result is None: + MODEL = None + return + + result["algorithm"] = AlgorithmReference(**result["algorithm"]) + MODEL = ExtendedResult(**result) + + +def _set_workspace(workspace: Workspace): + """Set the workspace.""" + global WORKSPACE + WORKSPACE = workspace + + +def server(result: Dict[str, Any] = None, workspace: Workspace = None): + """Create the FastAPI server.""" + app = _create_ui_app() + + if result: + _set_model(result) + + if workspace: + _set_workspace(workspace) + + if not MODEL and not WORKSPACE: + raise ValueError("No model or workspace provided") + + host = os.getenv("HOST", "localhost") + port = int(os.getenv("PORT", "8000")) + cors_enabled = os.getenv("CORS_ENABLED", "false").lower() == "true" + + # cors e.g. for development of separate frontend + if cors_enabled: + logger.warning(typer.style("🌐 CORS enabled", fg=typer.colors.YELLOW)) + from fastapi.middleware.cors import CORSMiddleware + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + server_config = uvicorn.Config(app, host=host, port=port, log_level="error") + server = uvicorn.Server(server_config) + return server_config, server diff --git a/causy/variables.py b/causy/variables.py new file mode 100644 index 0000000..2d3236f --- /dev/null +++ b/causy/variables.py @@ -0,0 +1,261 @@ +import copy +import enum +from types import NoneType +from typing import Any, Union, TypeVar, Generic, Optional, List, Dict + +from pydantic import BaseModel, computed_field + + +VariableInterfaceType = TypeVar("VariableInterfaceType") + + +class VariableTypes(enum.Enum): + String = "string" + Integer = "integer" + Float = "float" + Bool = "bool" + + +class BaseVariable(BaseModel, Generic[VariableInterfaceType]): + """ + Represents a single variable. It can be a string, int, float or bool. The type of the variable is determined by the + type attribute. + """ + + def __init__(self, **data): + super().__init__(**data) + self.validate_value(self.value) + + name: str + value: Union[str, int, float, bool] + choices: Optional[List[Union[str, int, float, bool]]] = None + + def is_valid(self): + return self.is_valid_value(self.value) + + def is_valid_value(self, value): + try: + self.validate_value(value) + return True + except ValueError: + return False + + def validate_value(self, value): + if not isinstance(value, self._PYTHON_TYPE): + raise ValueError( + f"Variable {self.name} is not valid." 
+ f" (should be {self.type} but is {type(value)})" + ) + + if self.choices and value not in self.choices: + raise ValueError( + f"Value {value} is not in the list of choices: {self.choices}" + ) + + @computed_field + @property + def type(self) -> Optional[str]: + return self._TYPE + + +class StringVariable( + BaseVariable[VariableInterfaceType], Generic[VariableInterfaceType] +): + """ + Represents a single string variable. + """ + + value: str + name: str + + _TYPE: Optional[str] = VariableTypes.String.value + _PYTHON_TYPE: Optional[type] = str + + +class IntegerVariable( + BaseVariable[VariableInterfaceType], Generic[VariableInterfaceType] +): + """ + Represents a single int variable. + """ + + value: int + name: str + + _TYPE: str = VariableTypes.Integer.value + _PYTHON_TYPE: Optional[type] = int + + def validate_value(self, value): + # check if the value is a boolean and raise an error + # we do this because in python bool is a subclass of int + if isinstance(value, bool): + raise ValueError( + f"Variable {self.name} is not valid." + f" (should be {self.type} but is {type(value)})" + ) + super().validate_value(value) + + +class FloatVariable( + BaseVariable[VariableInterfaceType], Generic[VariableInterfaceType] +): + """ + Represents a single float variable. + """ + + value: float + name: str + + _TYPE: str = VariableTypes.Float.value + _PYTHON_TYPE: Optional[type] = float + + +class BoolVariable(BaseVariable[VariableInterfaceType], Generic[VariableInterfaceType]): + """ + Represents a single bool variable. + """ + + value: bool + name: str + + _TYPE: str = VariableTypes.Bool.value + _PYTHON_TYPE: Optional[type] = bool + + +class VariableReference(BaseModel, Generic[VariableInterfaceType]): + """ + Represents a reference to a variable. + """ + + name: str + + @computed_field + @property + def type(self) -> str: + return "reference" + + +VARIABLE_MAPPING = { + VariableTypes.String.value: StringVariable, + VariableTypes.Integer.value: IntegerVariable, + VariableTypes.Float.value: FloatVariable, + VariableTypes.Bool.value: BoolVariable, +} + +BoolParameter = Union[bool, VariableReference] +IntegerParameter = Union[int, VariableReference] +FloatParameter = Union[float, VariableReference] +StringParameter = Union[str, VariableReference] +CausyParameter = Union[BoolParameter, IntegerParameter, FloatParameter, StringParameter] + + +def validate_variable_values(algorithm, variable_values: Dict[str, Any]): + """ + Validate the variable values for the algorithm. + :param algorithm: + :param variable_values: + :return: + """ + algorithm_variables = {avar.name: avar for avar in algorithm.variables} + + for variable_name, variable_value in variable_values.items(): + if variable_name not in algorithm_variables.keys(): + raise ValueError( + f"Variable {variable_name} not found in the algorithm variables." + ) + algorithm_variables[variable_name].validate_value(variable_value) + + return True + + +def resolve_variables( + variables: List[BaseVariable], variable_values: Dict[str, Any] +) -> Dict[str, Any]: + """ + Resolve the variables from the list of variables and the variable values coming from the user. 
+ :param variables: + :param variable_values: + :return: + """ + resolved_variables = {} + for variable in variables: + if variable.name in variable_values: + resolved_variables[variable.name] = variable_values[variable.name] + else: + resolved_variables[variable.name] = variable.value + + return resolved_variables + + +def resolve_variable_to_object(obj: Any, variables): + """ + Resolve the variables to the object. + :param obj: + :param variables: + :return: + """ + for attribute, value in obj.__dict__.items(): + if isinstance(value, VariableReference): + if value.name in variables: + obj.__dict__[attribute] = variables[value.name] + else: + raise ValueError(f'Variable "{value.name}" not found in the variables.') + elif hasattr(value, "__dict__"): + obj.__dict__[attribute] = resolve_variable_to_object(value, variables) + return obj + + +def resolve_variables_to_algorithm_for_pipeline_steps(pipeline_steps, variables): + """ + Resolve the variables to the algorithm. + :param pipeline_steps: + :param variables: + :return: + """ + for k, pipeline_step in enumerate(pipeline_steps): + pipeline_steps[k] = resolve_variable_to_object(pipeline_step, variables) + # handle cases when we have sub-pipelines like in Loops + if hasattr(pipeline_step, "pipeline_steps"): + pipeline_step.pipeline_steps = ( + resolve_variables_to_algorithm_for_pipeline_steps( + pipeline_step.pipeline_steps, variables + ) + ) + + return pipeline_steps + + +def deserialize_variable(variable_dict: Dict[str, Any]) -> BaseVariable: + """ + Deserialize the variable from the dictionary. + :param variable_dict: + :return: + """ + if "type" not in variable_dict: + raise ValueError("Variable type not found.") + if variable_dict["type"] not in VARIABLE_MAPPING: + raise ValueError(f"Variable type {variable_dict['type']} not found.") + + return VARIABLE_MAPPING[variable_dict["type"]](**variable_dict) + + +def deserialize_variable_references(element: object) -> object: + """ + Deserialize the variable references from the pipeline step. 
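For reference, a serialized `VariableReference` is stored as a plain mapping with `type: "reference"`, which is the shape the walker below rebuilds; concrete variables are rebuilt through `deserialize_variable` and the `VARIABLE_MAPPING` lookup instead. A small illustration, with the dict shapes inferred from the checks in this module:

```python
# Illustration of the serialized shapes handled by this module.
from causy.variables import VariableReference, deserialize_variable

serialized_reference = {"type": "reference", "name": "threshold"}
if serialized_reference.get("type") == "reference":
    ref = VariableReference(name=serialized_reference["name"])

# concrete variables round-trip via deserialize_variable / VARIABLE_MAPPING
var = deserialize_variable({"type": "float", "name": "threshold", "value": 0.005})
```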
+ :param pipeline_step: + :return: + """ + for attribute, value in element.__dict__.items(): + if isinstance(value, dict) and "type" in value and value["type"] == "reference": + setattr(element, attribute, VariableReference(name=value["name"])) + + if hasattr(value, "__dict__"): + setattr(element, attribute, deserialize_variable_references(value)) + + if hasattr(element, "pipeline_steps"): + element.pipeline_steps = [ + deserialize_variable_references(pipeline_step) + for pipeline_step in element.pipeline_steps + ] + + return element diff --git a/causy/workspaces/cli.py b/causy/workspaces/cli.py index 4fa084b..adf47d5 100644 --- a/causy/workspaces/cli.py +++ b/causy/workspaces/cli.py @@ -1,28 +1,85 @@ -import click +import json +import logging +from datetime import datetime +from typing import List, Dict + import pydantic_yaml +import questionary import typer import os -from pydantic_yaml import parse_yaml_raw_as, to_yaml_str +from markdown.extensions.toc import slugify +from pydantic_yaml import to_yaml_str +from jinja2 import ( + Environment, + select_autoescape, + ChoiceLoader, + FileSystemLoader, + PackageLoader, +) +from rich.console import Console +from rich.table import Table + +from causy.graph_model import graph_model_factory +from causy.graph_utils import hash_dictionary +from causy.models import ( + AlgorithmReference, + AlgorithmReferenceType, + Result, +) +from causy.serialization import ( + load_algorithm_by_reference, + CausyJSONEncoder, + deserialize_result, +) +from causy.variables import validate_variable_values, resolve_variables +from causy.workspaces.models import Workspace, Experiment +from causy.data_loader import DataLoaderReference, load_data_loader -from causy.interfaces import CausyAlgorithmReference -from causy.workspaces.serializer_models import Workspace +workspace_app = typer.Typer() -app = typer.Typer() +pipeline_app = typer.Typer() +experiment_app = typer.Typer() +dataloader_app = typer.Typer() +logger = logging.getLogger(__name__) + +NO_COLOR = os.environ.get("NO_COLOR", False) WORKSPACE_FILE_NAME = "workspace.yml" +JINJA_ENV = Environment( + loader=ChoiceLoader( + [ + PackageLoader("causy", "workspaces/templates"), + FileSystemLoader("./templates"), + ] + ), + autoescape=select_autoescape(), +) + + +class WorkspaceNotFoundError(Exception): + pass + -def current_workspace(fail_if_none=True): +def show_error(message: str): + if NO_COLOR: + typer.echo(f"❌ {message}", err=True) + else: + typer.echo(typer.style(f"❌ {message}", fg=typer.colors.RED), err=True) + + +def show_success(message: str): + typer.echo(f"✅ {message}") + + +def _current_workspace(fail_if_none: bool = True) -> Workspace: """ Return the current workspace. 
- :return: + :param fail_if_none: if True, raise an exception if no workspace is found + :return: the workspace """ - # get current path - # check if there is a workspace - # return the workspace - workspace_data = None workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) if os.path.exists(workspace_path): @@ -30,7 +87,9 @@ def current_workspace(fail_if_none=True): workspace_data = f.read() if fail_if_none and workspace_data is None: - raise Exception("No workspace found in the current directory") + raise WorkspaceNotFoundError("No workspace found in the current directory") + + workspace = None if workspace_data is not None: workspace = pydantic_yaml.parse_yaml_raw_as(Workspace, workspace_data) @@ -38,16 +97,497 @@ def current_workspace(fail_if_none=True): return workspace -@app.command() -def main(): - print("Hello World!") +def _create_pipeline(workspace: Workspace = None) -> Workspace: + pipeline_creation = questionary.select( + "Do you want to use an existing pipeline or create a new one?", + choices=[ + questionary.Choice( + "Use an existing causy pipeline (preconfigured).", "PRECONFIGURED" + ), + questionary.Choice( + "Eject existing pipeline (allows you to change pipeline configs).", + "EJECT", + ), + questionary.Choice( + "Create a pipeline skeleton (as a python module).", "SKELETON" + ), + ], + ).ask() + + if pipeline_creation == "PRECONFIGURED": + from causy.causal_discovery.constraint.algorithms import AVAILABLE_ALGORITHMS + + pipeline_reference = questionary.select( + "Which pipeline do you want to use?", choices=AVAILABLE_ALGORITHMS.keys() + ).ask() + pipeline_reference = AVAILABLE_ALGORITHMS[pipeline_reference] + # make pipeline reference as string + pipeline = AlgorithmReference( + reference=pipeline_reference().algorithm.name, + type=AlgorithmReferenceType.NAME, + ) + + pipeline_name = questionary.text("Enter the name of the pipeline").ask() + pipeline_name = slugify(pipeline_name, "_") + + workspace.pipelines[pipeline_name] = pipeline + elif pipeline_creation == "EJECT": + from causy.causal_discovery.constraint.algorithms import AVAILABLE_ALGORITHMS + + pipeline_skeleton = questionary.select( + "Which pipeline do you want to use?", choices=AVAILABLE_ALGORITHMS.keys() + ).ask() + pipeline_reference = AVAILABLE_ALGORITHMS[pipeline_skeleton] + pipeline_name = questionary.text("Enter the name of the pipeline").ask() + pipeline_slug = slugify(pipeline_name, "_") + with open(f"{pipeline_slug}.yml", "w") as f: + f.write(to_yaml_str(pipeline_reference()._original_algorithm)) + + pipeline = AlgorithmReference( + reference=f"{pipeline_slug}.yml", type=AlgorithmReferenceType.FILE + ) + + workspace.pipelines[pipeline_name] = pipeline + elif pipeline_creation == "SKELETON": + pipeline_name = questionary.text("Enter the name of the pipeline").ask() + pipeline_slug = slugify(pipeline_name, "_") + JINJA_ENV.get_template("pipeline.py.tpl").stream( + pipeline_name=pipeline_name + ).dump(f"{pipeline_slug}.py") + pipeline = AlgorithmReference( + reference=f"{pipeline_slug}.PIPELINE", + type=AlgorithmReferenceType.PYTHON_MODULE, + ) + workspace.pipelines[pipeline_slug] = pipeline + + typer.echo(f'Pipeline "{pipeline_name}" created.') + + return workspace + + +def _create_experiment(workspace: Workspace) -> Workspace: + experiment_name = questionary.text("Enter the name of the experiment").ask() + experiment_pipeline = questionary.select( + "Select the pipeline for the experiment", choices=workspace.pipelines.keys() + ).ask() + experiment_data_loader = questionary.select( + "Select 
the data loader for the experiment", + choices=workspace.data_loaders.keys(), + ).ask() + + experiment_slug = slugify(experiment_name, "_") + + # extract and prefill the variables + variables = {} + pipeline = load_algorithm_by_reference( + workspace.pipelines[experiment_pipeline].type, + workspace.pipelines[experiment_pipeline].reference, + ) + if len(pipeline.variables) > 0: + variables = resolve_variables(pipeline.variables, {}) + + workspace.experiments[experiment_slug] = Experiment( + **{ + "pipeline": experiment_pipeline, + "data_loader": experiment_data_loader, + "variables": variables, + } + ) + + typer.echo(f'Experiment "{experiment_name}" created.') + + return workspace + + +def _create_data_loader(workspace: Workspace) -> Workspace: + data_loader_type = questionary.select( + "Do you want to use an existing pipeline or create a new one?", + choices=[ + questionary.Choice("Load a JSON File.", "json"), + questionary.Choice("Load a JSONL File.", "jsonl"), + questionary.Choice("Load data dynamically (via Python Script).", "dynamic"), + ], + ).ask() + + data_loader_name = questionary.text("Enter the name of the data loader").ask() + + if data_loader_type in ["json", "jsonl"]: + data_loader_path = questionary.path( + "Choose the file or enter the file name:", + ).ask() + data_loader_slug = slugify(data_loader_name, "_") + workspace.data_loaders[data_loader_slug] = DataLoaderReference( + **{ + "type": data_loader_type, + "reference": data_loader_path, + } + ) + elif data_loader_type == "dynamic": + data_loader_slug = slugify(data_loader_name, "_") + JINJA_ENV.get_template("dataloader.py.tpl").stream( + data_loader_name=data_loader_name + ).dump(f"{data_loader_slug}.py") + workspace.data_loaders[data_loader_slug] = DataLoaderReference( + **{ + "type": data_loader_type, + "reference": f"{data_loader_slug}.DataLoader", + } + ) + + typer.echo(f'Data loader "{data_loader_name}" created.') + + return workspace + + +def _execute_experiment(workspace: Workspace, experiment: Experiment) -> Result: + """ + Execute an experiment. This function will load the pipeline and the data loader and execute the pipeline. 
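To make the flow of `_execute_experiment` concrete, here is a hedged sketch of the same steps outside the workspace CLI. It assumes a pipeline referenced by name ("PC") and that `load_algorithm_by_reference` is called positionally, as it is elsewhere in this patch:

```python
# Sketch of what _execute_experiment does: load the pipeline, validate and
# resolve its variables, then build the model with the resolved values.
from causy.graph_model import graph_model_factory
from causy.models import AlgorithmReferenceType
from causy.serialization import load_algorithm_by_reference
from causy.variables import resolve_variables, validate_variable_values

pipeline = load_algorithm_by_reference(AlgorithmReferenceType.NAME, "PC")

overrides = {"threshold": 0.01}                 # per-experiment variable values
validate_variable_values(pipeline, overrides)   # rejects unknown or ill-typed values
variables = resolve_variables(pipeline.variables, overrides)

model = graph_model_factory(pipeline, variables)()
# create_graph_from_data(), create_all_possible_edges() and
# execute_pipeline_steps() then run exactly as in the function below.
```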
+ :param workspace: + :param experiment: + :return: + """ + typer.echo(f"Loading Pipeline: {experiment.pipeline}") + pipeline = load_algorithm_by_reference( + workspace.pipelines[experiment.pipeline].type, + workspace.pipelines[experiment.pipeline].reference, + ) + + validate_variable_values(pipeline, experiment.variables) + variables = resolve_variables(pipeline.variables, experiment.variables) + typer.echo(f"Using variables: {variables}") + + typer.echo(f"Loading Data: {experiment.data_loader}") + data_loader = load_data_loader(workspace.data_loaders[experiment.data_loader]) + model = graph_model_factory(pipeline, experiment.variables)() + model.create_graph_from_data(data_loader) + model.create_all_possible_edges() + model.execute_pipeline_steps() + + return Result( + algorithm=workspace.pipelines[experiment.pipeline], + action_history=model.graph.graph.action_history, + edges=model.graph.retrieve_edges(), + nodes=model.graph.nodes, + variables=variables, + data_loader_hash=data_loader.hash(), + algorithm_hash=pipeline.hash(), + variables_hash=hash_dictionary(variables), + ) + + +def _load_latest_experiment_result( + workspace: Workspace, experiment_name: str +) -> Experiment: + versions = _load_experiment_versions(workspace, experiment_name) + + if experiment_name not in workspace.experiments: + raise ValueError(f"Experiment {experiment_name} not found in the workspace") + + if len(versions) == 0: + raise ValueError(f"Experiment {experiment_name} not found in the file system") + + with open(f"{experiment_name}_{versions[0]}.json", "r") as f: + experiment = json.load(f) + + return experiment + + +def _load_experiment_result( + workspace: Workspace, experiment_name: str, version_number: int +) -> Dict[str, any]: + if experiment_name not in workspace.experiments: + raise ValueError(f"Experiment {experiment_name} not found in the workspace") + + if version_number not in _load_experiment_versions(workspace, experiment_name): + raise ValueError( + f"Version {version_number} not found for experiment {experiment_name}" + ) + + with open(f"{experiment_name}_{version_number}.json", "r") as f: + experiment = json.load(f) + + return experiment + + +def _load_experiment_versions(workspace: Workspace, experiment_name: str) -> List[int]: + versions = [] + for file in os.listdir(): + # check for files if they have the right prefix followed by a unix timestamp (int) and the file extension, e.g. experiment_123456789.json. 
+ # Extract the unix timestamp + if file.startswith(f"{experiment_name}_") and file.endswith(".json"): + segments = file.split("_") + timestamp = int(segments[-1].split(".")[0]) + name = "_".join(segments[:-1]) + if name != experiment_name: + # an experiment with a different name + continue + versions.append(timestamp) + return sorted(versions, reverse=True) + + +def _save_experiment_result(workspace: Workspace, experiment_name: str, result: Result): + timestamp = int(datetime.timestamp(result.created_at)) + with open(f"{experiment_name}_{timestamp}.json", "w") as f: + f.write(json.dumps(result.model_dump(), cls=CausyJSONEncoder, indent=4)) + + +@pipeline_app.command(name="add") +def create_pipeline(): + """Create a new pipeline in the current workspace.""" + workspace = _current_workspace() + workspace = _create_pipeline(workspace) + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + +@pipeline_app.command(name="rm") +def remove_pipeline(pipeline_name: str): + """Remove a pipeline from the current workspace.""" + workspace = _current_workspace() + + if pipeline_name not in workspace.pipelines: + show_error(f"Pipeline {pipeline_name} not found in the workspace.") + return + + # check if the pipeline is still in use + for experiment_name, experiment in workspace.experiments.items(): + if experiment.pipeline == pipeline_name: + show_error( + f"Pipeline {pipeline_name} is still in use by experiment {experiment_name}. Cannot remove." + ) + return + + del workspace.pipelines[pipeline_name] + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + show_success(f"Pipeline {pipeline_name} removed from the workspace.") + + +def _experiment_needs_reexecution(workspace: Workspace, experiment_name: str) -> bool: + """ + Check if an experiment needs to be re-executed. 
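The re-execution check that follows reduces to comparing three stored hashes (pipeline, data loader, variables) against freshly computed ones; a compact sketch of that decision, with illustrative hash values:

```python
# Sketch of the decision _experiment_needs_reexecution makes: re-run whenever
# any stored hash from the latest result no longer matches the recomputed one.
stored = {"algorithm": "a1", "data_loader": "d1", "variables": "v1"}   # from the result file
current = {"algorithm": "a1", "data_loader": "d2", "variables": "v1"}  # recomputed now

needs_rerun = any(stored[key] != current[key] for key in stored)
print(needs_rerun)  # True -> the data loader / dataset changed
```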
+ :param workspace: + :param experiment_name: + :return: + """ + if experiment_name not in workspace.experiments: + raise ValueError(f"Experiment {experiment_name} not found in the workspace") + + versions = _load_experiment_versions(workspace, experiment_name) + + if len(versions) == 0: + logger.info(f"Experiment {experiment_name} not found in the file system.") + return True + + latest_experiment = _load_latest_experiment_result(workspace, experiment_name) + experiment = workspace.experiments[experiment_name] + latest_experiment = deserialize_result(latest_experiment) + if ( + latest_experiment.algorithm_hash is None + or latest_experiment.data_loader_hash is None + ): + logger.info(f"Experiment {experiment_name} has no hashes.") + return True + + pipeline = load_algorithm_by_reference( + workspace.pipelines[experiment.pipeline].type, + workspace.pipelines[experiment.pipeline].reference, + ) + + validate_variable_values(pipeline, experiment.variables) + variables = resolve_variables(pipeline.variables, experiment.variables) + + if latest_experiment.variables_hash != hash_dictionary(variables): + logger.info(f"Experiment {experiment_name} has different variables.") + return True + + model = graph_model_factory(pipeline, variables)() + if latest_experiment.algorithm_hash != model.algorithm.hash(): + logger.info(f"Experiment {experiment_name} has a different pipeline.") + return True + + data_loder = load_data_loader(workspace.data_loaders[experiment.data_loader]) + if latest_experiment.data_loader_hash != data_loder.hash(): + logger.info( + f"Experiment {experiment_name} has a different data loader/dataset." + ) + return True + + return False + + +def _clear_experiment(experiment_name: str, workspace: Workspace): + versions = _load_experiment_versions(workspace, experiment_name) + versions_removed = 0 + for version in versions: + try: + os.remove(f"{experiment_name}_{version}.json") + versions_removed += 1 + except FileNotFoundError: + pass + return versions_removed + + +@experiment_app.command(name="add") +def create_experiment(): + """Create a new experiment in the current workspace.""" + workspace = _current_workspace() + workspace = _create_experiment(workspace) + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + +@experiment_app.command(name="rm") +def remove_experiment(experiment_name: str): + """Remove an experiment from the current workspace.""" + workspace = _current_workspace() + + if experiment_name not in workspace.experiments: + show_error(f"Experiment {experiment_name} not found in the workspace.") + return + + versions_removed = _clear_experiment(experiment_name, workspace) + + del workspace.experiments[experiment_name] + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + show_success( + f"Experiment {experiment_name} removed from the workspace. Removed {versions_removed} versions." 
+ ) + + +@experiment_app.command(name="clear") +def clear_experiment(experiment_name: str): + """Clear all versions of an experiment.""" + workspace = _current_workspace() + + if experiment_name not in workspace.experiments: + show_error(f"Experiment {experiment_name} not found in the workspace.") + return + + versions_removed = _clear_experiment(experiment_name, workspace) + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + show_success( + f"Experiment {experiment_name} cleared. Removed {versions_removed} versions." + ) + + +@experiment_app.command(name="update-variable") +def update_experiment_variable( + experiment_name: str, variable_name: str, variable_value: str +): + """Update a variable in an experiment.""" + workspace = _current_workspace() + + if experiment_name not in workspace.experiments: + show_error(f"Experiment {experiment_name} not found in the workspace.") + return + + experiment = workspace.experiments[experiment_name] + + pipeline = load_algorithm_by_reference( + workspace.pipelines[experiment.pipeline].type, + workspace.pipelines[experiment.pipeline].reference, + ) + + current_variable = None + for existing_variable in pipeline.variables: + if variable_name == existing_variable.name: + current_variable = existing_variable + break + else: + show_error(f"Variable {variable_name} not found in the experiment.") + return + + # try to cast the variable value to the correct type + try: + variable_value = current_variable._PYTHON_TYPE(variable_value) + except ValueError: + show_error( + f'Variable {variable_name} should be {current_variable.type}. But got "{variable_value}" which is not a valid value.' + ) + return + + # check if the variable is a valid value + if not validate_variable_values(pipeline, {variable_name: variable_value}): + show_error(f"Variable {variable_name} is not a valid value.") + return + + experiment.variables[variable_name] = variable_value + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + show_success(f"Variable {variable_name} updated in experiment {experiment_name}.") + + +@dataloader_app.command(name="add") +def create_data_loader(): + """Create a new data loader in the current workspace.""" + workspace = _current_workspace() + workspace = _create_data_loader(workspace) + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + +@dataloader_app.command(name="rm") +def remove_data_loader(data_loader_name: str): + """Remove a data loader from the current workspace.""" + workspace = _current_workspace() + + if data_loader_name not in workspace.data_loaders: + show_error(f"Data loader {data_loader_name} not found in the workspace.") + return + + # check if the data loader is still in use + for experiment_name, experiment in workspace.experiments.items(): + if experiment.data_loader == data_loader_name: + show_error( + f"Data loader {data_loader_name} is still in use by experiment {experiment_name}. Cannot remove." 
+ ) + return + + del workspace.data_loaders[data_loader_name] + + workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) + with open(workspace_path, "w") as f: + f.write(pydantic_yaml.to_yaml_str(workspace)) + + show_success(f"Data loader {data_loader_name} removed from the workspace.") -@app.command() +@workspace_app.command() +def info(): + """Show general information about the workspace.""" + workspace = _current_workspace() + typer.echo(f"Workspace: {workspace.name}") + typer.echo(f"Author: {workspace.author}") + typer.echo(f"Pipelines: {workspace.pipelines}") + typer.echo(f"Data loaders: {workspace.data_loaders}") + typer.echo(f"Experiments: {workspace.experiments}") + + +@workspace_app.command() def init(): """ Initialize a new workspace in the current directory. - :retur """ workspace_path = os.path.join(os.getcwd(), WORKSPACE_FILE_NAME) @@ -79,32 +619,212 @@ def init(): workspace.pipelines = {} if configure_pipeline: - use_existing_pipeline = typer.confirm( - "Do you want to use an existing pipeline?", default=False - ) - if use_existing_pipeline: - from causy.algorithms import AVAILABLE_ALGORITHMS + workspace = _create_pipeline(workspace) - pipeline_name = click.prompt( - "\nSelect an tag to deploy: ?", - type=click.Choice(AVAILABLE_ALGORITHMS.keys()), - ) - pipeline_reference = AVAILABLE_ALGORITHMS[pipeline_name] - # make pipeline reference as string - pipeline = CausyAlgorithmReference( - name=pipeline_name, reference=str(pipeline_reference) + configure_data_loader = typer.confirm( + "Do you want to configure a data loader?", default=False + ) + + workspace.data_loaders = {} + if configure_data_loader: + data_loader_type = questionary.select( + "Do you want to use an existing pipeline or create a new one?", + choices=[ + questionary.Choice("Load a JSON File.", "json"), + questionary.Choice("Load a JSONL File.", "jsonl"), + questionary.Choice( + "Load data dynamically (via Python Script).", "dynamic" + ), + ], + ).ask() + + if data_loader_type in ["json", "jsonl"]: + data_loader_path = questionary.path( + "Choose the file or enter the file name:", + ).ask() + data_loader_name = questionary.text( + "Enter the name of the data loader" + ).ask() + data_loader_slug = slugify(data_loader_name, "_") + workspace.data_loaders[data_loader_slug] = { + "type": data_loader_type, + "reference": data_loader_path, + } + elif data_loader_type == "dynamic": + data_loader_name = questionary.text( + "Enter the name of the data loader" + ).ask() + data_loader_slug = slugify(data_loader_name, "_") + JINJA_ENV.get_template("dataloader.py.tpl").stream( + data_loader_name=data_loader_name + ).dump(f"{data_loader_slug}.py") + workspace.data_loaders[data_loader_slug] = DataLoaderReference( + **{ + "type": data_loader_type, + "reference": f"{data_loader_slug}.DataLoader", + } ) - workspace.pipelines[pipeline_name] = pipeline + workspace.experiments = {} - workspace.data_loaders = None - workspace.experiments = None + if len(workspace.pipelines) > 0 and len(workspace.data_loaders) > 0: + configure_experiment = typer.confirm( + "Do you want to configure an experiment?", default=False + ) + + if configure_experiment: + workspace = _create_experiment(workspace) with open(workspace_path, "w") as f: f.write(pydantic_yaml.to_yaml_str(workspace)) - print(f"Workspace created in {workspace_path}") + typer.echo(f"Workspace created in {workspace_path}") -@app.command() -def execute(experiment_name=""): - pass +@workspace_app.command() +def execute(experiment_name: str = None, force_reexecution: bool = False): 
+ """ + Execute an experiment or all experiments in the workspace. + """ + workspace = _current_workspace() + if experiment_name is None: + # execute all experiments + for experiment_name, experiment in workspace.experiments.items(): + try: + needs_reexecution = _experiment_needs_reexecution( + workspace, experiment_name + ) + except ValueError as e: + show_error(str(e)) + needs_reexecution = True + + if needs_reexecution is False and force_reexecution is False: + typer.echo(f"Skipping experiment: {experiment_name}. (no changes)") + continue + typer.echo(f"Executing experiment: {experiment_name}") + result = _execute_experiment(workspace, experiment) + _save_experiment_result(workspace, experiment_name, result) + else: + if experiment_name not in workspace.experiments: + typer.echo(f"Experiment {experiment_name} not found in the workspace.") + return + experiment = workspace.experiments[experiment_name] + typer.echo(f"Executing experiment: {experiment_name}") + result = _execute_experiment(workspace, experiment) + + _save_experiment_result(workspace, experiment_name, result) + + +@workspace_app.command() +def diff(experiment_names: List[str], only_differences: bool = False): + """ + Show the differences between multiple experiment results. + """ + workspace = _current_workspace() + if len(experiment_names) < 2: + show_error("Please provide at least two experiment names/versions.") + return + + experiments_to_compare = [] + resolved_experiments = [] + + # check if the experiment strings are experiments or experiment_versions and load the respective experiments/versions + for experiment_name in experiment_names: + if experiment_name not in workspace.experiments: + potential_version = experiment_name.split("_")[-1] + try: + version = int(potential_version) + except ValueError: + show_error(f"Experiment {experiment_name} not found in the workspace") + return + experiment_name = "_".join(experiment_name.split("_")[:-1]) + if version not in _load_experiment_versions(workspace, experiment_name): + show_error( + f"Version {version} not found for experiment {experiment_name}" + ) + return + experiment_result = deserialize_result( + _load_experiment_result(workspace, experiment_name, version) + ) + experiment = workspace.experiments[experiment_name] + experiment_version = f"{experiment_name}_{version}" + + else: + experiment = workspace.experiments[experiment_name] + experiment_result = deserialize_result( + _load_latest_experiment_result(workspace, experiment_name) + ) + experiment_version = f"{experiment_name}_latest" + + experiments_to_compare.append( + { + "result": experiment_result, + "experiment": experiment, + "version": experiment_version, + } + ) + resolved_experiments.append(experiment_version) + find_equivalents = {} + + # find the differences between all experiments and the differences for each of the other edges + + for experiment in experiments_to_compare: + for edge in experiment["result"].edges: + u, v = sorted([edge.u.name, edge.v.name]) + if u not in find_equivalents: + find_equivalents[u] = {} + if v not in find_equivalents[u]: + find_equivalents[u][v] = {} + + if experiment["version"] in find_equivalents[u][v]: + if find_equivalents[u][v][experiment["version"]] != edge: + typer.echo( + f"Experiment {experiment['experiment']} has an inconsistent edge {u} -> {v}" + ) + else: + find_equivalents[u][v][experiment["version"]] = edge + + experiment_table = [] + + for node_u, s in find_equivalents.items(): + for node_v, result in s.items(): + experiment_table_row = {exp: None for exp in 
resolved_experiments} + for experiment, edge in result.items(): + experiment_table_row[experiment] = edge + + experiment_table.append(experiment_table_row) + + table = Table() + table.add_column("Edge") + + for experiment in resolved_experiments: + table.add_column(experiment, justify="center") + + for row in experiment_table: + elements = [key for key in row.values()] + first_element = None + for e in elements: + if e is not None: + first_element = e + break + + all_elements_same = all([e == first_element for e in elements]) + if only_differences and all_elements_same: + continue + table.add_row( + *[f"{first_element.u.name} - {first_element.v.name}"] + + [ + f"{row[experiment].edge_type.STR_REPRESENTATION}" + if row[experiment] + else "" + for experiment in resolved_experiments + ], + style="green" if all_elements_same else "red", + ) + + console = Console() + console.print(table) + + +workspace_app.add_typer(pipeline_app, name="pipeline", help="Manage pipelines") +workspace_app.add_typer(experiment_app, name="experiment", help="Manage experiments") +workspace_app.add_typer(dataloader_app, name="dataloader", help="Manage data loaders") diff --git a/causy/workspaces/models.py b/causy/workspaces/models.py new file mode 100644 index 0000000..c817e55 --- /dev/null +++ b/causy/workspaces/models.py @@ -0,0 +1,27 @@ +from typing import Optional, Dict, Any + +from pydantic import BaseModel + +from causy.data_loader import DataLoaderReference +from causy.models import AlgorithmReference, Algorithm + + +class Experiment(BaseModel): + """ + represents a single experiment + :param name: name of the experiment + :param pipeline: the name of the pipeline used + """ + + pipeline: str + data_loader: str + variables: Optional[Dict[str, Any]] = None + + +class Workspace(BaseModel): + name: str + author: Optional[str] + + pipelines: Optional[Dict[str, Algorithm | AlgorithmReference]] + data_loaders: Optional[Dict[str, DataLoaderReference]] + experiments: Optional[Dict[str, Experiment]] diff --git a/causy/workspaces/serializer_models.py b/causy/workspaces/serializer_models.py deleted file mode 100644 index e7b8e36..0000000 --- a/causy/workspaces/serializer_models.py +++ /dev/null @@ -1,42 +0,0 @@ -import enum -from typing import Optional, Dict, List, Any - -from pydantic import BaseModel - -from causy.interfaces import CausyAlgorithm, CausyAlgorithmReference - - -class DataLoaderType(enum.StrEnum): - DYNAMIC = "dynamic" # python function which yields data - JSON = "json" - JSONL = "jsonl" - - -class DataLoader(BaseModel): - """represents a single data loader - :param type: the type of dataloader - :param path: path to either the python class which can be executed to load the data or the data source file itself - """ - - type: DataLoaderType - path: str - - -class Experiment(BaseModel): - """ - represents a single experiment - :param name: name of the experiment - :param pipeline: the name of the pipeline used - """ - - pipeline: str - data_loader: str - - -class Workspace(BaseModel): - name: str - author: Optional[str] - - pipelines: Optional[Dict[str, CausyAlgorithm | CausyAlgorithmReference]] - data_loaders: Optional[Dict[str, DataLoader]] - experiments: Optional[Dict[str, Experiment]] diff --git a/causy/workspaces/templates/dataloader.py.tpl b/causy/workspaces/templates/dataloader.py.tpl new file mode 100644 index 0000000..f7bcbc8 --- /dev/null +++ b/causy/workspaces/templates/dataloader.py.tpl @@ -0,0 +1,30 @@ +from causy.data_loader import AbstractDataLoader + +class DataLoader(AbstractDataLoader): + 
""" + A causy dataloader. + + """ + + def __init__(self): + """ + Initialize your dataloader here + """ + pass + + def load(self): + """ + Load the data. This function should yield the data row by row. You have to yield it as a dictionary of column names and values (floats). + :return: + """ + yield { + "column1": 1.0, + "column2": 2.0 + } + + def hash(self): + """ + Returns a hash of the data. This is useful to check if the data has changed. + :return: + """ + return None diff --git a/causy/workspaces/templates/pipeline.py.tpl b/causy/workspaces/templates/pipeline.py.tpl new file mode 100644 index 0000000..f89369a --- /dev/null +++ b/causy/workspaces/templates/pipeline.py.tpl @@ -0,0 +1,12 @@ +from causy.graph_model import graph_model_factory +from causy.interfaces import CausyAlgorithm + +PIPELINE = graph_model_factory( + CausyAlgorithm( + pipeline_steps=[ + ], + edge_types=[], + extensions=[], + name="{{pipeline_name}}", + ) +) diff --git a/poetry.lock b/poetry.lock index ffde604..615e91e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -86,6 +86,17 @@ files = [ {file = "catalogue-2.0.10.tar.gz", hash = "sha256:4f56daa940913d3f09d589c191c74e5a6d51762b3a9e37dd53b7437afd6cda15"}, ] +[[package]] +name = "certifi" +version = "2024.2.2" +description = "Python package for providing Mozilla's CA Bundle." +optional = false +python-versions = ">=3.6" +files = [ + {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, + {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, +] + [[package]] name = "cfgv" version = "3.4.0" @@ -278,6 +289,51 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, +] + +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + +[[package]] +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." +optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "identify" version = "2.5.36" @@ -338,13 +394,13 @@ files = [ [[package]] name = "jinja2" -version = "3.1.3" +version = "3.1.4" description = "A very fast and expressive template engine." 
optional = false python-versions = ">=3.7" files = [ - {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"}, - {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"}, + {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, + {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, ] [package.dependencies] @@ -353,6 +409,21 @@ MarkupSafe = ">=2.0" [package.extras] i18n = ["Babel (>=2.7)"] +[[package]] +name = "markdown" +version = "3.6" +description = "Python implementation of John Gruber's Markdown." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Markdown-3.6-py3-none-any.whl", hash = "sha256:48f276f4d8cfb8ce6527c8f79e2ee29708508bf4d40aa410fbc3b4ee832c850f"}, + {file = "Markdown-3.6.tar.gz", hash = "sha256:ed4f41f6daecbeeb96e576ce414c41d2d876daa9a16cb35fa8ed8c2ddfad0224"}, +] + +[package.extras] +docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] +testing = ["coverage", "pyyaml"] + [[package]] name = "markdown-it-py" version = "3.0.0" @@ -843,6 +914,20 @@ nodeenv = ">=0.11.1" pyyaml = ">=5.1" virtualenv = ">=20.10.0" +[[package]] +name = "prompt-toolkit" +version = "3.0.36" +description = "Library for building powerful interactive command lines in Python" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "prompt_toolkit-3.0.36-py3-none-any.whl", hash = "sha256:aa64ad242a462c5ff0363a7b9cfe696c20d55d9fc60c11fd8e632d064804d305"}, + {file = "prompt_toolkit-3.0.36.tar.gz", hash = "sha256:3e163f254bef5a03b146397d7c1963bd3e2812f0964bb9a24e6ec761fd28db63"}, +] + +[package.dependencies] +wcwidth = "*" + [[package]] name = "pydantic" version = "2.7.1" @@ -1124,6 +1209,20 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "questionary" +version = "2.0.1" +description = "Python library to build pretty command line user prompts ⭐️" +optional = false +python-versions = ">=3.8" +files = [ + {file = "questionary-2.0.1-py3-none-any.whl", hash = "sha256:8ab9a01d0b91b68444dff7f6652c1e754105533f083cbe27597c8110ecc230a2"}, + {file = "questionary-2.0.1.tar.gz", hash = "sha256:bcce898bf3dbb446ff62830c86c5c6fb9a22a54146f0f5597d3da43b10d8fc8b"}, +] + +[package.dependencies] +prompt_toolkit = ">=2.0,<=3.0.36" + [[package]] name = "rich" version = "13.7.1" @@ -1532,6 +1631,17 @@ platformdirs = ">=3.9.1,<5" docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +[[package]] +name = "wcwidth" +version = "0.2.13" +description = "Measures the displayed width of unicode strings in a terminal" +optional = false +python-versions = "*" +files = [ + {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"}, + {file = 
"wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"}, +] + [[package]] name = "zipp" version = "3.18.1" @@ -1550,4 +1660,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.13" -content-hash = "1c95f85f333df977719c8650fbbbe41ab31bf47bf9c6e1b519be171766fa0b1e" +content-hash = "f5020ad339bd5bd37fb07e2265e5a798f3f258cbc8c74e03e69d2be9f5d7ff71" diff --git a/pyproject.toml b/pyproject.toml index 8814e6b..df40698 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,10 @@ uvicorn = "^0.27.0" srsly = "^2.4.8" pydantic-yaml = "^1.2.1" click = "^8.1.7" +questionary = "^2.0.1" +jinja2 = "^3.1.4" +markdown = "^3.6" +httpx = "^0.27.0" [tool.poetry.group.dev.dependencies] diff --git a/tests/test_generators.py b/tests/test_generators.py index 9377675..3080e88 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -1,13 +1,12 @@ -from causy.algorithms.pc import PC_EDGE_TYPES +from causy.causal_discovery.constraint.algorithms.pc import PC_EDGE_TYPES from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations from causy.generators import PairsWithNeighboursGenerator from causy.graph_model import graph_model_factory -from causy.graph_utils import retrieve_edges -from causy.independence_tests.common import ( +from causy.causal_discovery.constraint.independence_tests.common import ( CorrelationCoefficientTest, PartialCorrelationTest, ) -from causy.interfaces import CausyAlgorithm, ComparisonSettings +from causy.models import ComparisonSettings, Algorithm from causy.sample_generator import IIDSampleGenerator, SampleEdge, NodeReference from tests.utils import CausyTestCase @@ -28,7 +27,7 @@ def test_pairs_with_neighbours_generator(self): ) algo = graph_model_factory( - CausyAlgorithm( + Algorithm( pipeline_steps=[ CalculatePearsonCorrelations(), CorrelationCoefficientTest(threshold=0.005), @@ -44,7 +43,6 @@ def test_pairs_with_neighbours_generator(self): tst.create_graph_from_data(test_data) tst.create_all_possible_edges() tst.execute_pipeline_steps() - print(retrieve_edges(tst.graph)) result = PairsWithNeighboursGenerator( comparison_settings=ComparisonSettings(min=3, max=4) ).generate(tst.graph.graph, tst) @@ -52,4 +50,3 @@ def test_pairs_with_neighbours_generator(self): for i in result: all_results.extend(i) - print(all_results) diff --git a/tests/test_graph_model.py b/tests/test_graph_model.py index 7d703f0..7ba9b8e 100644 --- a/tests/test_graph_model.py +++ b/tests/test_graph_model.py @@ -1,6 +1,6 @@ import torch -from causy.algorithms import PC +from causy.causal_discovery.constraint.algorithms import PC from causy.sample_generator import IIDSampleGenerator, SampleEdge, NodeReference from tests.utils import CausyTestCase diff --git a/tests/test_independence_tests.py b/tests/test_independence_tests.py index a532d99..c78d7f6 100644 --- a/tests/test_independence_tests.py +++ b/tests/test_independence_tests.py @@ -1,13 +1,7 @@ -import random - -import torch - from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations from causy.graph_model import graph_model_factory -from causy.graph_utils import retrieve_edges -from causy.interfaces import CausyAlgorithm -from causy.math_utils import sum_lists -from causy.independence_tests.common import ( +from causy.models import Algorithm +from causy.causal_discovery.constraint.independence_tests.common import ( CorrelationCoefficientTest, 
     PartialCorrelationTest,
     ExtendedPartialCorrelationTestMatrix,
@@ -37,7 +31,7 @@ def test_correlation_coefficient_test(self):
             CorrelationCoefficientTest(threshold=0.1),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -65,7 +59,7 @@ def test_correlation_coefficient_test_2(self):
             CorrelationCoefficientTest(threshold=0.1),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -95,7 +89,7 @@ def test_correlation_coefficient_test_collider(self):
             CorrelationCoefficientTest(threshold=0.1),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -124,7 +118,7 @@ def test_partial_correlation_test(self):
             PartialCorrelationTest(threshold=0.01),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -154,7 +148,7 @@ def test_partial_correlation_test_2(self):
             PartialCorrelationTest(threshold=0.01),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -186,7 +180,7 @@ def test_partial_correlation_test_collider(self):
             PartialCorrelationTest(threshold=0.01),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -219,7 +213,7 @@ def test_extended_partial_correlation_test_matrix(self):
             ExtendedPartialCorrelationTestMatrix(threshold=0.01),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -253,7 +247,7 @@ def test_extended_partial_correlation_test_matrix2(self):
             ExtendedPartialCorrelationTestMatrix(threshold=0.01),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -288,7 +282,7 @@ def test_extended_partial_correlation_test_matrix3(self):
             ExtendedPartialCorrelationTestMatrix(threshold=0.01),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -323,7 +317,7 @@ def test_extended_partial_correlation_test_linear_regression2(self):
         data, graph = model.generate(1000000)

         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -360,7 +354,7 @@ def test_extended_partial_correlation_test_linear_regression3(self):
         data, graph = model.generate(1000000)

         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
@@ -393,7 +387,7 @@ def test_extended_partial_correlation_test_linear_regression(self):
             ExtendedPartialCorrelationTestLinearRegression(threshold=0.01),
         ]
         tst = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="CorrelationCoefficientTest",
diff --git a/tests/test_orientation_rules.py b/tests/test_orientation_rules.py
index 3d211dd..2074b0a 100644
--- a/tests/test_orientation_rules.py
+++ b/tests/test_orientation_rules.py
@@ -1,7 +1,7 @@
+from causy.causal_discovery.constraint.orientation_rules.fci import ColliderRuleFCI
 from causy.graph import GraphManager
 from causy.graph_model import graph_model_factory
-from causy.interfaces import TestResult, TestResultAction, CausyAlgorithm
-from causy.orientation_rules.fci import ColliderRuleFCI
+from causy.models import TestResultAction, TestResult, Algorithm
 from tests.utils import CausyTestCase


@@ -10,7 +10,7 @@ class OrientationTestCase(CausyTestCase):
     def test_collider_rule_fci(self):
         pipeline = [ColliderRuleFCI()]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="FCIColliderRule",
diff --git a/tests/test_orientation_tests.py b/tests/test_orientation_tests.py
index 886de09..10683ac 100644
--- a/tests/test_orientation_tests.py
+++ b/tests/test_orientation_tests.py
@@ -1,6 +1,6 @@
 from causy.graph import GraphManager
-from causy.interfaces import TestResult, TestResultAction, CausyAlgorithm
-from causy.orientation_rules.pc import (
+from causy.models import TestResultAction, TestResult, Algorithm
+from causy.causal_discovery.constraint.orientation_rules.pc import (
     ColliderTest,
     NonColliderTest,
     FurtherOrientTripleTest,
@@ -16,7 +16,7 @@ class OrientationRuleTestCase(CausyTestCase):
     def test_collider_test(self):
         pipeline = [ColliderTest()]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="TestCollider",
@@ -45,7 +45,7 @@ def test_collider_test(self):
     def test_collider_test_with_nonempty_separation_set(self):
         pipeline = [ColliderTest()]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="TestCollider",
@@ -73,7 +73,7 @@ def test_collider_test_with_nonempty_separation_set(self):
     def test_non_collider_test(self):
         pipeline = [NonColliderTest()]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="TestCollider",
@@ -93,7 +93,7 @@ def test_non_collider_test(self):
     def test_further_orient_triple_test(self):
         pipeline = [FurtherOrientTripleTest()]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="TestCollider",
@@ -117,7 +117,7 @@ def test_further_orient_triple_test(self):
     def test_orient_quadruple_test(self):
         pipeline = [OrientQuadrupleTest()]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="TestCollider",
@@ -144,7 +144,7 @@ def test_orient_quadruple_test(self):
     def test_further_orient_quadruple_test(self):
         pipeline = [FurtherOrientQuadrupleTest()]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="TestCollider",
diff --git a/tests/test_pc_e2e.py b/tests/test_pc_e2e.py
index 99c9dde..35fc21e 100644
--- a/tests/test_pc_e2e.py
+++ b/tests/test_pc_e2e.py
@@ -1,21 +1,16 @@
-import csv
-import torch
-
-from causy.algorithms import PC, ParallelPC
-from causy.algorithms.pc import PCStable, PC_EDGE_TYPES
+from causy.causal_discovery.constraint.algorithms.pc import PC_EDGE_TYPES
 from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations
 from causy.graph_model import graph_model_factory
-from causy.graph_utils import retrieve_edges
-from causy.independence_tests.common import (
+from causy.causal_discovery.constraint.independence_tests.common import (
     CorrelationCoefficientTest,
     PartialCorrelationTest,
     ExtendedPartialCorrelationTestMatrix,
 )
-from causy.interfaces import CausyAlgorithm
-from causy.orientation_rules.pc import ColliderTest
+from causy.models import Algorithm
+from causy.causal_discovery.constraint.orientation_rules.pc import ColliderTest
 from causy.sample_generator import IIDSampleGenerator, SampleEdge, NodeReference
-from tests.utils import CausyTestCase, dump_fixture_graph, load_fixture_graph
+from tests.utils import CausyTestCase, load_fixture_graph


 class PCTestTestCase(CausyTestCase):
@@ -39,7 +34,7 @@ def test_pc_calculate_pearson_correlations(self):
         Test conditional independence of ordered pairs given pairs of other variables works.
         """
         algo = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=[
                     CalculatePearsonCorrelations(),
                 ],
@@ -63,7 +58,7 @@ def test_pc_calculate_pearson_correlations(self):

     def test_pc_correlation_coefficient_test(self):
         algo = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=[
                     CalculatePearsonCorrelations(),
                     CorrelationCoefficientTest(threshold=0.05),
@@ -88,7 +83,7 @@ def test_pc_correlation_coefficient_test(self):

     def test_pc_partial_correlation_test(self):
         algo = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=[
                     CalculatePearsonCorrelations(),
                     CorrelationCoefficientTest(threshold=0.05),
@@ -114,7 +109,7 @@ def test_pc_partial_correlation_test(self):

     def test_pc_extended_partial_correlation_test_matrix(self):
         algo = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=[
                     CalculatePearsonCorrelations(),
                     CorrelationCoefficientTest(threshold=0.05),
@@ -141,7 +136,7 @@ def test_pc_extended_partial_correlation_test_matrix(self):

     def test_pc_collider_test(self):
         algo = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=[
                     CalculatePearsonCorrelations(),
                     CorrelationCoefficientTest(threshold=0.05),
diff --git a/tests/test_pc_graph.py b/tests/test_pc_graph.py
index 3e3801d..2e6a06d 100644
--- a/tests/test_pc_graph.py
+++ b/tests/test_pc_graph.py
@@ -1,6 +1,6 @@
 import csv

-from causy.algorithms import PC
+from causy.causal_discovery.constraint.algorithms import PC
 from causy.graph_utils import retrieve_edges
 from causy.sample_generator import IIDSampleGenerator, SampleEdge, NodeReference
diff --git a/tests/test_sample_generator.py b/tests/test_sample_generator.py
index 8f57b50..796440f 100644
--- a/tests/test_sample_generator.py
+++ b/tests/test_sample_generator.py
@@ -1,9 +1,7 @@
 import torch
-import random

 from causy.sample_generator import (
     TimeseriesSampleGenerator,
-    random_normal,
     SampleEdge,
     IIDSampleGenerator,
     TimeAwareNodeReference,
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 0eeb47c..9843df6 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -1,8 +1,10 @@
-from causy.algorithms.pc import PC_DEFAULT_THRESHOLD
+from causy.causal_discovery.constraint.algorithms.pc import PC_DEFAULT_THRESHOLD
 from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations
 from causy.graph_model import graph_model_factory
-from causy.independence_tests.common import CorrelationCoefficientTest
-from causy.interfaces import CausyAlgorithm
+from causy.causal_discovery.constraint.independence_tests.common import (
+    CorrelationCoefficientTest,
+)
+from causy.models import Algorithm
 from causy.serialization import serialize_algorithm, load_algorithm_from_specification
 from tests.utils import CausyTestCase

@@ -15,7 +17,7 @@ def test_serialize(self):
             CorrelationCoefficientTest(threshold=PC_DEFAULT_THRESHOLD),
         ]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="test_serialize",
@@ -34,7 +36,7 @@ def test_serialize_and_load(self):
             CorrelationCoefficientTest(threshold=PC_DEFAULT_THRESHOLD),
         ]
         model = graph_model_factory(
-            CausyAlgorithm(
+            Algorithm(
                 pipeline_steps=pipeline,
                 edge_types=[],
                 name="test_serialize",
diff --git a/tests/test_ui_api.py b/tests/test_ui_api.py
new file mode 100644
index 0000000..2a00d4f
--- /dev/null
+++ b/tests/test_ui_api.py
@@ -0,0 +1,263 @@
+from fastapi.testclient import TestClient
+
+from causy.data_loader import DataLoaderReference, DataLoaderType
+from causy.models import AlgorithmReference, AlgorithmReferenceType
+from causy.ui.server import _create_ui_app, _set_workspace, _set_model
+from causy.workspaces.models import Workspace, Experiment
+from tests.utils import CausyTestCase
+
+
+class UIApiTestCase(CausyTestCase):
+    def test_status_endpoint(self):
+        _set_workspace(None)
+        _set_model(None)
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+        response = client.get("/api/v1/status")
+        result = response.json()
+        self.assertEqual(result["status"], "ok")
+        self.assertEqual(response.status_code, 200)
+
+        self.assertEqual(result["causy_version"], "0.1.0")
+        self.assertEqual(result["model_loaded"], False)
+        self.assertEqual(result["workspace_loaded"], False)
+
+        _set_workspace(
+            Workspace(
+                name="test_workspace",
+                author="test_author",
+                pipelines=None,
+                experiments=None,
+                data_loaders=None,
+            )
+        )
+
+        response = client.get("/api/v1/status")
+        result = response.json()
+        self.assertEqual(result["status"], "ok")
+        self.assertEqual(response.status_code, 200)
+
+        self.assertEqual(result["model_loaded"], False)
+        self.assertEqual(result["workspace_loaded"], True)
+
+    def test_workspace(self):
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+        _set_workspace(None)
+        response = client.get("/api/v1/workspace")
+        self.assertEqual(response.status_code, 404)
+
+        _set_workspace(
+            Workspace(
+                name="test_workspace",
+                author="test_author",
+                pipelines={
+                    "PC": AlgorithmReference(
+                        type=AlgorithmReferenceType.NAME, reference="PC"
+                    )
+                },
+                experiments={
+                    "test_experiment": Experiment(
+                        pipeline="PC", data_loader="data_loader", variables=None
+                    )
+                },
+                data_loaders={
+                    "test_data_loader": DataLoaderReference(
+                        type=DataLoaderType.JSON, reference="data_loader"
+                    )
+                },
+            )
+        )
+
+        response = client.get("/api/v1/workspace")
+        self.assertEqual(response.status_code, 200)
+        result = response.json()
+        self.assertEqual(result["name"], "test_workspace")
+        self.assertEqual(result["author"], "test_author")
+        self.assertEqual(
+            result["pipelines"], {"PC": {"type": "name", "reference": "PC"}}
+        )
+
+        self.assertEqual(
+            result["data_loaders"],
+            {
+                "test_data_loader": {
+                    "type": "json",
+                    "reference": "data_loader",
+                    "options": None,
+                }
+            },
+        )
+
+        self.assertEqual(
+            result["experiments"],
+            {
+                "test_experiment": {
+                    "pipeline": "PC",
+                    "data_loader": "data_loader",
+                    "variables": None,
+                }
+            },
+        )
+
+    def test_get_model(self):
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+        result = {
+            "algorithm": {"type": "name", "reference": "PC"},
+            "nodes": {},
+            "edges": [],
+            "action_history": [],
+            "data_loader": {
+                "type": "json",
+                "reference": "data_loader",
+                "options": None,
+            },
+            "variables": {"test": "test"},
+            "result": {"test": "test"},
+        }
+        _set_model(result)
+        response = client.get("/api/v1/model")
+        self.assertEqual(response.status_code, 200)
+        result = response.json()
+        self.assertEqual(result["algorithm"], {"type": "name", "reference": "PC"})
+
+    def test_get_algorithm(self):
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+
+        response = client.get("/api/v1/algorithm/name/PC")
+        self.assertEqual(response.status_code, 200)
+        result = response.json()
+        self.assertEqual(result["name"], "PC")
+
+    def test_get_algorithm_invalid(self):
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+
+        response = client.get("/api/v1/algorithm/name/INVALID")
+        self.assertEqual(response.status_code, 400)
+
+        response = client.get("/api/v1/algorithm/python_module/INVALID")
+        self.assertEqual(response.status_code, 400)
+
+        response = client.get("/api/v1/algorithm/name/..PC")
+        self.assertEqual(response.status_code, 400)
+
+    def test_get_experiments(self):
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+        _set_workspace(None)
+
+        response = client.get("/api/v1/experiments")
+        self.assertEqual(response.status_code, 404)
+
+        _set_workspace(
+            Workspace(
+                name="test_workspace",
+                author="test_author",
+                pipelines={
+                    "PC": AlgorithmReference(
+                        type=AlgorithmReferenceType.NAME, reference="PC"
+                    )
+                },
+                experiments={
+                    "test_experiment": Experiment(
+                        pipeline="PC", data_loader="data_loader", variables=None
+                    )
+                },
+                data_loaders={
+                    "test_data_loader": DataLoaderReference(
+                        type=DataLoaderType.JSON, reference="data_loader"
+                    )
+                },
+            )
+        )
+
+        response = client.get("/api/v1/experiments")
+        self.assertEqual(response.status_code, 200)
+        result = response.json()
+        self.assertEqual(len(result), 1)
+        self.assertEqual(result[0]["name"], "test_experiment")
+        self.assertEqual(result[0]["pipeline"], "PC")
+        self.assertEqual(result[0]["data_loader"], "data_loader")
+        self.assertEqual(result[0]["variables"], None)
+        self.assertEqual(len(result[0]["versions"]), 0)
+
+    def test_get_latest_experiment(self):
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+        _set_workspace(None)
+
+        response = client.get("/api/v1/experiments/test_experiment/latest")
+        self.assertEqual(response.status_code, 404)
+
+        _set_workspace(
+            Workspace(
+                name="test_workspace",
+                author="test_author",
+                pipelines={
+                    "PC": AlgorithmReference(
+                        type=AlgorithmReferenceType.NAME, reference="PC"
+                    )
+                },
+                experiments={
+                    "test_experiment": Experiment(
+                        pipeline="PC", data_loader="data_loader", variables=None
+                    )
+                },
+                data_loaders={
+                    "test_data_loader": DataLoaderReference(
+                        type=DataLoaderType.JSON, reference="data_loader"
+                    )
+                },
+            )
+        )
+
+        response = client.get("/api/v1/experiments/test_experiment/latest")
+        self.assertEqual(response.status_code, 400)
+        result = response.json()
+        self.assertEqual(
+            result["detail"], "Experiment test_experiment not found in the file system"
+        )
+
+        # TODO: add test for when the experiment is found (how to mock the file system?)
+
+    def test_get_experiment(self):
+        app = _create_ui_app(with_static=False)
+        client = TestClient(app)
+        _set_workspace(None)
+
+        response = client.get("/api/v1/experiments/test_experiment/1")
+        self.assertEqual(response.status_code, 404)
+
+        _set_workspace(
+            Workspace(
+                name="test_workspace",
+                author="test_author",
+                pipelines={
+                    "PC": AlgorithmReference(
+                        type=AlgorithmReferenceType.NAME, reference="PC"
+                    )
+                },
+                experiments={
+                    "test_experiment": Experiment(
+                        pipeline="PC", data_loader="data_loader", variables=None
+                    )
+                },
+                data_loaders={
+                    "test_data_loader": DataLoaderReference(
+                        type=DataLoaderType.JSON, reference="data_loader"
+                    )
+                },
+            )
+        )
+
+        response = client.get("/api/v1/experiments/test_experiment/1")
+        self.assertEqual(response.status_code, 400)
+        result = response.json()
+        self.assertEqual(
+            result["detail"], "Version 1 not found for experiment test_experiment"
+        )
+
+        # TODO: add test for when the experiment is found (how to mock the file system?)
diff --git a/tests/test_variables.py b/tests/test_variables.py
new file mode 100644
index 0000000..501f36b
--- /dev/null
+++ b/tests/test_variables.py
@@ -0,0 +1,344 @@
+import copy
+from unittest import skip
+
+from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations
+from causy.common_pipeline_steps.placeholder import PlaceholderTest
+from causy.graph_model import graph_model_factory
+from causy.graph_utils import (
+    serialize_module_name,
+    load_pipeline_artefact_by_definition,
+    load_pipeline_steps_by_definition,
+)
+from causy.models import Algorithm
+from causy.sample_generator import IIDSampleGenerator, SampleEdge, NodeReference
+from causy.variables import (
+    StringVariable,
+    FloatVariable,
+    IntegerVariable,
+    BoolVariable,
+    validate_variable_values,
+    VariableReference,
+    resolve_variables,
+    resolve_variables_to_algorithm_for_pipeline_steps,
+)
+
+from tests.utils import CausyTestCase
+
+
+class VariablesTestCase(CausyTestCase):
+    def test_validate_str_variable(self):
+        variable = StringVariable(name="threshold", value="default")
+        self.assertEqual(variable.is_valid_value("test"), True)
+        self.assertEqual(variable.is_valid_value(1), False)
+        self.assertEqual(variable.is_valid_value(1.0), False)
+        self.assertEqual(variable.is_valid_value(True), False)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(1)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(1.0)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(True)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(None)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value([])
+
+        variable.validate_value("test")
+        variable.validate_value("default")
+
+        self.assertEqual(variable.is_valid(), True)
+
+        variable = StringVariable(
+            name="threshold", value="default", choices=["test", "default"]
+        )
+        self.assertEqual(variable.is_valid_value("test"), True)
+        self.assertEqual(variable.is_valid_value("default"), True)
+        self.assertEqual(variable.is_valid_value("test1"), False)
+        with self.assertRaises(ValueError):
+            variable.validate_value("test1")
+
+    def test_validate_float_variable(self):
+        variable = FloatVariable(name="threshold", value=0.5)
+        self.assertEqual(variable.is_valid_value(1.0), True)
+        self.assertEqual(variable.is_valid_value(1), False)
+        self.assertEqual(variable.is_valid_value(0.5), True)
+        self.assertEqual(variable.is_valid_value(True), False)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value("test")
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(True)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(None)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value([])
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(1)
+
+        self.assertEqual(variable.is_valid(), True)
+
+        variable.validate_value(0.5)
+        variable.validate_value(1.0)
+
+        variable = FloatVariable(name="threshold", value=0.5, choices=[0.5, 1.0])
+        self.assertEqual(variable.is_valid_value(0.5), True)
+        self.assertEqual(variable.is_valid_value(1.0), True)
+        self.assertEqual(variable.is_valid_value(0.6), False)
+        with self.assertRaises(ValueError):
+            variable.validate_value(0.6)
+
+    def test_validate_int_variable(self):
+        variable = IntegerVariable(name="threshold", value=1)
+        self.assertEqual(variable.is_valid_value(1), True)
+        self.assertEqual(variable.is_valid_value(1.0), False)
+        self.assertEqual(variable.is_valid_value(True), False)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value("test")
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(1.0)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(True)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(None)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value([])
+
+        variable.validate_value(1)
+        variable.validate_value(0)
+        variable.validate_value(100)
+        variable.validate_value(-100)
+        self.assertEqual(variable.is_valid(), True)
+
+        variable = IntegerVariable(name="threshold", value=1, choices=[0, 1, 100])
+        self.assertEqual(variable.is_valid_value(0), True)
+        self.assertEqual(variable.is_valid_value(1), True)
+        self.assertEqual(variable.is_valid_value(100), True)
+        self.assertEqual(variable.is_valid_value(101), False)
+        with self.assertRaises(ValueError):
+            variable.validate_value(101)
+
+    def test_validate_bool_variable(self):
+        variable = BoolVariable(name="threshold", value=True)
+        self.assertEqual(variable.is_valid_value(True), True)
+        self.assertEqual(variable.is_valid_value(False), True)
+        self.assertEqual(variable.is_valid_value(1), False)
+        self.assertEqual(variable.is_valid_value(1.0), False)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value("test")
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(1.0)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(None)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value([])
+
+        variable.validate_value(True)
+        variable.validate_value(False)
+        self.assertEqual(variable.is_valid(), True)
+
+        variable = BoolVariable(name="threshold", value=True, choices=[True])
+        self.assertEqual(variable.is_valid_value(True), True)
+        self.assertEqual(variable.is_valid_value(1), False)
+        self.assertEqual(variable.is_valid_value(0), False)
+        self.assertEqual(variable.is_valid_value("True"), False)
+        self.assertEqual(variable.is_valid_value("False"), False)
+
+        with self.assertRaises(ValueError):
+            variable.validate_value("True")
+
+        with self.assertRaises(ValueError):
+            variable.validate_value(False)
+
+    def test_validate_variable_values(self):
+        algorithm = graph_model_factory(
+            Algorithm(
+                pipeline_steps=[],
+                edge_types=[],
+                extensions=[],
+                name="Test variable validation",
+                variables=[
+                    StringVariable(name="a_string", value="default"),
+                    IntegerVariable(name="an_int", value=1),
+                    BoolVariable(name="a_bool", value=True),
+                    FloatVariable(name="a_float", value=0.1),
+                ],
+            )
+        )()
+
+        with self.assertRaises(ValueError):
+            validate_variable_values(algorithm.algorithm, {"a_string": 1})
+
+        with self.assertRaises(ValueError):
+            validate_variable_values(algorithm.algorithm, {"an_int": "test"})
+
+        with self.assertRaises(ValueError):
+            validate_variable_values(algorithm.algorithm, {"a_bool": 1})
+
+        with self.assertRaises(ValueError):
+            validate_variable_values(algorithm.algorithm, {"a_float": "test"})
+
+        with self.assertRaises(ValueError):
+            validate_variable_values(algorithm.algorithm, {"a_float": True})
+
+        validate_variable_values(algorithm.algorithm, {"a_string": "test"})
+        validate_variable_values(algorithm.algorithm, {"an_int": 2})
+        validate_variable_values(algorithm.algorithm, {"a_bool": False})
+        validate_variable_values(algorithm.algorithm, {"a_float": 0.2})
+
+        with self.assertRaises(ValueError):
+            validate_variable_values(algorithm.algorithm, {"another_var": 0.21})
+
+    def test_resolve_variables(self):
+        algorithm = graph_model_factory(
+            Algorithm(
+                pipeline_steps=[
+                    PlaceholderTest(
+                        placeholder_str=VariableReference(name="a_string"),
+                        placeholder_int=VariableReference(name="an_int"),
+                        placeholder_float=VariableReference(name="a_float"),
+                        placeholder_bool=VariableReference(name="a_bool"),
+                    )
+                ],
+                edge_types=[],
+                extensions=[],
+                name="Test variable resolution",
+                variables=[
+                    StringVariable(name="a_string", value="default"),
+                    IntegerVariable(name="an_int", value=1),
+                    BoolVariable(name="a_bool", value=True),
+                    FloatVariable(name="a_float", value=0.1),
+                ],
+            )
+        )()
+
+        resolved_variables = resolve_variables(
+            algorithm._original_algorithm.variables,
+            {"a_string": "test", "an_int": 2, "a_bool": False, "a_float": 0.2},
+        )
+
+        self.assertEqual(
+            resolved_variables,
+            {"a_string": "test", "an_int": 2, "a_bool": False, "a_float": 0.2},
+        )
+
+        resolved_variables = resolve_variables(
+            algorithm._original_algorithm.variables,
+            {
+                "a_string": "test",
+            },
+        )
+
+        self.assertEqual(
+            resolved_variables,
+            {"a_string": "test", "an_int": 1, "a_bool": True, "a_float": 0.1},
+        )
+
+    def test_resolve_variables_to_algorithm(self):
+        algorithm = graph_model_factory(
+            Algorithm(
+                pipeline_steps=[
+                    PlaceholderTest(
+                        placeholder_str=VariableReference(name="a_string"),
+                        placeholder_int=VariableReference(name="an_int"),
+                        placeholder_float=VariableReference(name="a_float"),
+                        placeholder_bool=VariableReference(name="a_bool"),
+                    )
+                ],
+                edge_types=[],
+                extensions=[],
+                name="Test variable resolution",
+                variables=[
+                    StringVariable(name="a_string", value="default"),
+                    IntegerVariable(name="an_int", value=1),
+                    BoolVariable(name="a_bool", value=True),
+                    FloatVariable(name="a_float", value=0.1),
+                ],
+            )
+        )
+
+        resolved_variables = resolve_variables_to_algorithm_for_pipeline_steps(
+            algorithm._original_algorithm.pipeline_steps,
+            {"a_string": "test", "an_int": 2, "a_bool": False, "a_float": 0.2},
+        )
+
+        algorithm = graph_model_factory(
+            Algorithm(
+                pipeline_steps=[
+                    PlaceholderTest(
+                        placeholder_str=VariableReference(name="a_string"),
+                        placeholder_int=VariableReference(name="an_int"),
+                        placeholder_float=VariableReference(name="a_float"),
+                        placeholder_bool=VariableReference(name="a_bool"),
+                    )
+                ],
+                edge_types=[],
+                extensions=[],
+                name="Test variable resolution",
+                variables=[
+                    StringVariable(name="a_string", value="default"),
+                    IntegerVariable(name="an_int", value=1),
+                    BoolVariable(name="a_bool", value=True),
+                    FloatVariable(name="a_float", value=0.1),
+                ],
+            )
+        )
+        with self.assertRaises(ValueError):
+            resolved_variables = resolve_variables_to_algorithm_for_pipeline_steps(
+                algorithm._original_algorithm.pipeline_steps,
+                {
+                    "a_string": "test",
+                },
+            )
+
+        algorithm = Algorithm(
+            pipeline_steps=[
+                PlaceholderTest(
+                    placeholder_str=VariableReference(name="not_defined"),
+                )
+            ],
+            edge_types=[],
+            extensions=[],
+            name="Test variable resolution",
+            variables=[],
+        )
+
+        resolved_variables = resolve_variables_to_algorithm_for_pipeline_steps(
+            algorithm.pipeline_steps,
+            {
+                "not_defined": "test",
+            },
+        )
+
+        algorithm = Algorithm(
+            pipeline_steps=[
+                PlaceholderTest(
+                    placeholder_str=VariableReference(name="not_defined"),
+                )
+            ],
+            edge_types=[],
+            extensions=[],
+            name="Test variable resolution",
+            variables=[],
+        )
+        with self.assertRaises(ValueError):
+            resolved_variables = resolve_variables_to_algorithm_for_pipeline_steps(
+                algorithm.pipeline_steps, {}
+            )