diff --git a/auto_causality/optimiser.py b/auto_causality/optimiser.py
index e56f4985..5f0fbb07 100644
--- a/auto_causality/optimiser.py
+++ b/auto_causality/optimiser.py
@@ -82,7 +82,6 @@ class AutoCausality:

     def __init__(
         self,
-        data_df=None,
         metric="energy_distance",
         metrics_to_report=None,
         time_budget=None,
@@ -107,7 +106,6 @@ def __init__(
         """constructor.

         Args:
-            data_df (pandas.DataFrame): dataset to perform causal inference on
             metric (str): metric to optimise.
                 Defaults to "erupt" for CATE, "energy_distance" for IV
             metrics_to_report (list). additional metrics to compute and report.
@@ -185,7 +183,6 @@ def __init__(
         self._best_estimators = defaultdict(lambda: (float("-inf"), None))

         self.original_estimator_list = estimator_list
-        self.data_df = data_df or pd.DataFrame()
         self.causal_model = None
         self.identified_estimand = None
         self.problem = None
@@ -257,13 +254,14 @@ def fit(
         estimator_list: Optional[Union[str, List[str]]] = None,
         resume: Optional[bool] = False,
         time_budget: Optional[int] = None,
+        store_data: Optional[bool] = True,
     ):
         """Performs AutoML on list of causal inference estimators
         - If estimator has a search space specified in its parameters, HPO is performed on the whole model.
         - Otherwise, only its component models are optimised

         Args:
-            data_df (pandas.DataFrame): dataset for causal inference
+            data (pandas.DataFrame): dataset for causal inference
             treatment (str): name of treatment variable
             outcome (str): name of outcome variable
             common_causes (List[str]): list of names of common causes
@@ -273,6 +271,7 @@ def fit(
             estimator_list (Optional[Union[str, List[str]]]): subset of estimators to consider
             resume (Optional[bool]): set to True to continue previous fit
             time_budget (Optional[int]): change new time budget allocated to fit, useful for warm starts.
+            store_data (Optional[bool]): set to True to keep the train_df, test_df and data attributes after fit
         """

         if not isinstance(data, CausalityDataset):
@@ -456,6 +455,11 @@ def fit(
         )

         self.update_summary_scores()
+
+        if not store_data:
+            delattr(self, 'train_df')
+            delattr(self, 'test_df')
+            delattr(self, 'data')

     def update_summary_scores(self):
         self.scores = Scorer.best_score_by_estimator(self.results.results, self.metric)
diff --git a/tests/autocausality/test_drop_data_after_fit.py b/tests/autocausality/test_drop_data_after_fit.py
new file mode 100644
index 00000000..4ca349a7
--- /dev/null
+++ b/tests/autocausality/test_drop_data_after_fit.py
@@ -0,0 +1,80 @@
+import pytest
+import warnings
+
+from auto_causality import AutoCausality
+from auto_causality.datasets import synth_ihdp, linear_multi_dataset
+from auto_causality.params import SimpleParamService
+
+warnings.filterwarnings("ignore")  # suppress sklearn deprecation warnings for now..
+
+
+class TestDropDataAfterFit(object):
+    def test_fit_and_drop_data(self):
+        """tests that fit() discards train_df, test_df and data when store_data=False"""
+
+        from auto_causality.shap import shap_values  # noqa F401
+
+        data = synth_ihdp()
+        data.preprocess_dataset()
+
+        cfg = SimpleParamService(
+            propensity_model=None,
+            outcome_model=None,
+            n_jobs=-1,
+            include_experimental=False,
+            multivalue=False,
+        )
+        estimator_list = cfg.estimator_names_from_patterns("backdoor", "all", 1)
+        # outcome = targets[0]
+        auto_causality = AutoCausality(
+            num_samples=len(estimator_list),
+            components_time_budget=5,
+            estimator_list=estimator_list,  # "all", #
+            use_ray=False,
+            verbose=3,
+            components_verbose=2,
+            resources_per_trial={"cpu": 0.5},
+        )
+
+        auto_causality.fit(data, store_data=False)
+        auto_causality.effect(data.data)
+        auto_causality.score_dataset(data.data, "test")
+
+        # now let's test Shapley values calculation
+        for est_name, scores in auto_causality.scores.items():
+            # Dummy model doesn't support Shapley values
+            # Orthoforest shapley calc is VERY slow
+            if "Dummy" not in est_name and "Ortho" not in est_name:
+
+                print("Calculating Shapley values for", est_name)
+                shap_values(scores["estimator"], data.data[:10])
+
+        print(f"Best estimator: {auto_causality.best_estimator}")
+
+    def test_fit_and_keep_data(self):
+        data = linear_multi_dataset(10000)
+        cfg = SimpleParamService(
+            propensity_model=None,
+            outcome_model=None,
+            n_jobs=-1,
+            include_experimental=False,
+            multivalue=True,
+        )
+        estimator_list = cfg.estimator_names_from_patterns(
+            "backdoor", "all", data_rows=len(data)
+        )
+
+        data.preprocess_dataset()
+
+        ac = AutoCausality(
+            estimator_list="all",
+            num_samples=len(estimator_list),
+            components_time_budget=5,
+        )
+        ac.fit(data)
+        # TODO add an effect() call and an effect_tt call
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
+    # TestDropDataAfterFit().test_fit_and_drop_data()