From 36ead79d88e152ab057e3ac838b176caa080361a Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Thu, 31 Aug 2023 18:31:15 +0300 Subject: [PATCH 1/4] rename cli args for eval and experiment handover --- src/nhssynth/cli/common_arguments.py | 16 ++++++++++------ src/nhssynth/cli/module_arguments.py | 7 +------ src/nhssynth/cli/module_setup.py | 8 ++++---- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/nhssynth/cli/common_arguments.py b/src/nhssynth/cli/common_arguments.py index 46b62331..4e099be3 100644 --- a/src/nhssynth/cli/common_arguments.py +++ b/src/nhssynth/cli/common_arguments.py @@ -99,16 +99,20 @@ def get_parser(overrides: bool = False) -> argparse.ArgumentParser: "sdv_metadata", "filename of the metadata formatted for use with SDV", ), - "synthetic": suffix_parser_generator( - "synthetic", - "filename of the synthetic data", - ), "experiments": suffix_parser_generator( "experiments", "filename of the experiment bundle, i.e. the collection of all seeds, models, and synthetic datasets", ), - "evaluation_bundle": suffix_parser_generator( - "evaluation_bundle", + "synthetic_datasets": suffix_parser_generator( + "synthetic_datasets", + "filename of the collection of synthetic datasets generated by a given set of `experiments`", + ), + "model": suffix_parser_generator( + "model", + "name for each of the collection of models trained during a given set of `experiments`", + ), + "evaluations": suffix_parser_generator( + "evaluations", "filename of the (collection of) evaluation(s) for a given set of `experiments`", ), } diff --git a/src/nhssynth/cli/module_arguments.py b/src/nhssynth/cli/module_arguments.py index 4967eec4..d9b42de6 100644 --- a/src/nhssynth/cli/module_arguments.py +++ b/src/nhssynth/cli/module_arguments.py @@ -95,12 +95,6 @@ def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides: default=1, help="how many times to repeat the training process per model architecture ( is incremented each time)", ) - group.add_argument( - "--model-file", - type=str, - default="_model", - help="specify the filename of the model to be saved in `experiments//`, defaults to `_model.pt`", - ) group.add_argument( "--batch-size", type=int, @@ -118,6 +112,7 @@ def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides: group.add_argument( "--patience", type=int, + nargs="+", default=5, help="how many epochs the model is allowed to train for without improvement", ) diff --git a/src/nhssynth/cli/module_setup.py b/src/nhssynth/cli/module_setup.py index b766b0dd..374eee22 100644 --- a/src/nhssynth/cli/module_setup.py +++ b/src/nhssynth/cli/module_setup.py @@ -114,28 +114,28 @@ def add_config_args(parser: argparse.ArgumentParser) -> None: add_args=add_model_args, description="run the model architecture module, to train a synthetic data generator", help="train a model", - common_parsers=["transformed", "metatransformer", "synthetic", "experiments"], + common_parsers=["transformed", "metatransformer", "experiments", "synthetic_datasets", "model"], ), "evaluation": ModuleConfig( func=evaluation.run, add_args=add_evaluation_args, description="run the evaluation module, to evaluate an experiment", help="evaluate an experiment", - common_parsers=["sdv_metadata", "typed", "experiments", "evaluation_bundle"], + common_parsers=["sdv_metadata", "typed", "experiments", "synthetic_datasets", "evaluations"], ), "plotting": ModuleConfig( func=plotting.run, add_args=add_plotting_args, description="run the plotting module, to generate plots for a given 
model and / or evaluation", help="generate plots", - common_parsers=["typed", "evaluation_bundle"], + common_parsers=["typed", "evaluations"], ), "dashboard": ModuleConfig( func=dashboard.run, add_args=add_dashboard_args, description="run the dashboard module, to produce a streamlit dashboard", help="start up a streamlit dashboard to view the results of an evaluation", - common_parsers=["typed", "experiments", "evaluation_bundle"], + common_parsers=["typed", "experiments", "synthetic_datasets", "evaluations"], no_seed=True, ), "pipeline": ModuleConfig( From ab1d06d8d3970facd0f08700f3d2a2ee0c78359e Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Thu, 31 Aug 2023 18:31:30 +0300 Subject: [PATCH 2/4] fix bug in handover naming --- src/nhssynth/modules/plotting/io.py | 28 +++++++++++++--------------- src/nhssynth/modules/plotting/run.py | 4 ++-- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/nhssynth/modules/plotting/io.py b/src/nhssynth/modules/plotting/io.py index 86f40303..7125fb4e 100644 --- a/src/nhssynth/modules/plotting/io.py +++ b/src/nhssynth/modules/plotting/io.py @@ -7,26 +7,24 @@ from nhssynth.common.io import * -def check_input_paths( - fn_dataset: str, fn_typed: str, fn_evaluation_bundle: str, dir_experiment: Path -) -> tuple[str, str]: +def check_input_paths(fn_dataset: str, fn_typed: str, fn_evaluations: str, dir_experiment: Path) -> tuple[str, str]: """ Sets up the input and output paths for the model files. Args: fn_dataset: The base name of the dataset. fn_typed: The name of the typed data file. - fn_evaluation_bundle: The name of the file containing the evaluation bundle. + fn_evaluations: The name of the file containing the evaluation bundle. dir_experiment: The path to the experiment directory. Returns: The paths to the data, metadata and metatransformer files. """ - fn_dataset, fn_typed, fn_evaluation_bundle = consistent_endings([fn_dataset, fn_typed, fn_evaluation_bundle]) - fn_typed, fn_evaluation_bundle = potential_suffixes([fn_typed, fn_evaluation_bundle], fn_dataset) - warn_if_path_supplied([fn_dataset, fn_typed, fn_evaluation_bundle], dir_experiment) + fn_dataset, fn_typed, fn_evaluations = consistent_endings([fn_dataset, fn_typed, fn_evaluations]) + fn_typed, fn_evaluations = potential_suffixes([fn_typed, fn_evaluations], fn_dataset) + warn_if_path_supplied([fn_dataset, fn_typed, fn_evaluations], dir_experiment) check_exists([fn_typed], dir_experiment) - return fn_dataset, fn_typed, fn_evaluation_bundle + return fn_dataset, fn_typed, fn_evaluations def load_required_data( @@ -42,20 +40,20 @@ def load_required_data( Returns: The data, metadata and metatransformer. 
""" - if all(x in args.module_handover for x in ["dataset", "typed", "evaluation_bundle"]): + if all(x in args.module_handover for x in ["dataset", "typed", "evaluations"]): return ( args.module_handover["dataset"], args.module_handover["typed"], - args.module_handover["evaluation_bundle"], + args.module_handover["evaluations"], ) else: - fn_dataset, fn_typed, fn_evaluation_bundle = check_input_paths( - args.dataset, args.typed, args.evaluation_bundle, dir_experiment + fn_dataset, fn_typed, fn_evaluations = check_input_paths( + args.dataset, args.typed, args.evaluations, dir_experiment ) with open(dir_experiment / fn_typed, "rb") as f: real_data = pickle.load(f) - with open(dir_experiment / fn_evaluation_bundle, "rb") as f: - evaluation_bundle = pickle.load(f) + with open(dir_experiment / fn_evaluations, "rb") as f: + evaluations = pickle.load(f) - return fn_dataset, real_data, evaluation_bundle + return fn_dataset, real_data, evaluations diff --git a/src/nhssynth/modules/plotting/run.py b/src/nhssynth/modules/plotting/run.py index f40def3b..209b617c 100644 --- a/src/nhssynth/modules/plotting/run.py +++ b/src/nhssynth/modules/plotting/run.py @@ -35,9 +35,9 @@ def run(args: argparse.Namespace) -> argparse.Namespace: set_seed(args.seed) dir_experiment = experiment_io(args.experiment_name) - fn_dataset, real_data, evaluation_bundle = load_required_data(args, dir_experiment) + fn_dataset, real_data, evaluations = load_required_data(args, dir_experiment) - for architecture, architecture_bundle in evaluation_bundle.items(): + for architecture, architecture_bundle in evaluations.items(): if isinstance(architecture_bundle, dict): for seed, seed_bundle in architecture_bundle.items(): print(f"\nModel architecture: {architecture} Seed: {seed}") From 7710935aefe7a3e1a243320af03216fefe05dceb Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Fri, 1 Sep 2023 15:56:11 +0300 Subject: [PATCH 3/4] Update pipeline test --- config/test_pipeline.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/config/test_pipeline.yaml b/config/test_pipeline.yaml index 6ad50fc9..b4bd507f 100644 --- a/config/test_pipeline.yaml +++ b/config/test_pipeline.yaml @@ -8,14 +8,17 @@ model: architecture: - VAE - DPVAE - num_epochs: 30 + num_epochs: + - 30 + - 50 + patience: 10 target_epsilon: - 1.0 - 3.0 - 6.0 max_grad_norm: 5.0 secure_mode: false - repeats: 3 + repeats: 4 evaluation: downstream_tasks: true column_shape_metrics: From c066bbd9cf661ccd4a185b5433d282df1e7b643b Mon Sep 17 00:00:00 2001 From: HarrisonWilde Date: Fri, 1 Sep 2023 15:56:36 +0300 Subject: [PATCH 4/4] Improving experiment and evaluation handover --- src/nhssynth/modules/dataloader/io.py | 8 +- .../modules/dataloader/metatransformer.py | 7 +- src/nhssynth/modules/dataloader/run.py | 2 +- src/nhssynth/modules/evaluation/io.py | 59 ++++++++------ src/nhssynth/modules/evaluation/run.py | 10 +-- src/nhssynth/modules/evaluation/utils.py | 81 +++++-------------- src/nhssynth/modules/model/io.py | 78 +++++++----------- src/nhssynth/modules/model/run.py | 38 ++++----- src/nhssynth/modules/model/utils.py | 38 +++++---- 9 files changed, 128 insertions(+), 193 deletions(-) diff --git a/src/nhssynth/modules/dataloader/io.py b/src/nhssynth/modules/dataloader/io.py index dc4d52da..3f4e2001 100644 --- a/src/nhssynth/modules/dataloader/io.py +++ b/src/nhssynth/modules/dataloader/io.py @@ -3,11 +3,17 @@ from pathlib import Path import numpy as np +import pandas as pd from nhssynth.common.io import * from nhssynth.modules.dataloader.metatransformer 
import MetaTransformer from tqdm import tqdm +class TypedDataset: + def __init__(self, typed_dataset: pd.DataFrame): + self.contents = typed_dataset + + def check_input_paths( fn_input: str, fn_metadata: str, @@ -108,7 +114,7 @@ def write_data_outputs( metatransformer.save_metadata(dir_experiment / fn_metadata, args.collapse_yaml) metatransformer.save_constraint_graphs(dir_experiment / fn_constraint_graph) with open(dir_experiment / fn_typed, "wb") as f: - pickle.dump(metatransformer.get_typed_dataset(), f) + pickle.dump(TypedDataset(metatransformer.get_typed_dataset()), f) transformed_dataset = metatransformer.get_transformed_dataset() transformed_dataset.to_pickle(dir_experiment / fn_transformed) if args.write_csv: diff --git a/src/nhssynth/modules/dataloader/metatransformer.py b/src/nhssynth/modules/dataloader/metatransformer.py index 34195c7e..b0d8d34d 100644 --- a/src/nhssynth/modules/dataloader/metatransformer.py +++ b/src/nhssynth/modules/dataloader/metatransformer.py @@ -261,7 +261,7 @@ def get_typed_dataset(self) -> pd.DataFrame: raise ValueError( "The typed dataset has not yet been created. Call `mt.apply()` (or `mt.apply_dtypes()`) first." ) - return TypedDataset(self.typed_dataset) + return self.typed_dataset def get_prepared_dataset(self) -> pd.DataFrame: if not hasattr(self, "prepared_dataset"): @@ -301,8 +301,3 @@ def save_metadata(self, path: pathlib.Path, collapse_yaml: bool = False) -> None def save_constraint_graphs(self, path: pathlib.Path) -> None: return self.metadata.constraints._output_graphs_html(path) - - -class TypedDataset: - def __init__(self, typed_dataset: pd.DataFrame): - self.internal = typed_dataset diff --git a/src/nhssynth/modules/dataloader/run.py b/src/nhssynth/modules/dataloader/run.py index 83978f2d..7ad8d08b 100644 --- a/src/nhssynth/modules/dataloader/run.py +++ b/src/nhssynth/modules/dataloader/run.py @@ -32,7 +32,7 @@ def run(args: argparse.Namespace) -> argparse.Namespace: } ) if "evaluation" in args.modules_to_run: - args.module_handover.update({"typed": mt.get_typed_dataset().internal}) + args.module_handover.update({"typed": mt.get_typed_dataset()}) print("\033[0m") diff --git a/src/nhssynth/modules/evaluation/io.py b/src/nhssynth/modules/evaluation/io.py index d0bbe838..b4ee1b58 100644 --- a/src/nhssynth/modules/evaluation/io.py +++ b/src/nhssynth/modules/evaluation/io.py @@ -5,19 +5,23 @@ import pandas as pd from nhssynth.common.io import * -from nhssynth.modules.evaluation.utils import EvalFrame + + +class Evaluations: + def __init__(self, evaluations: dict[str, dict[str, Any]]): + self.contents = evaluations def check_input_paths( - fn_dataset: str, fn_typed: str, fn_experiments: str, fn_sdv_metadata: str, dir_experiment: Path + fn_dataset: str, fn_typed: str, fn_synthetic_datasets: str, fn_sdv_metadata: str, dir_experiment: Path ) -> tuple[str, str]: """ Sets up the input and output paths for the model files. Args: fn_dataset: The base name of the dataset. - fn_typed: The name of the typed data file. - fn_experiments: The name of the metatransformer file. + fn_typed: The name of the typed real dataset file. + fn_synthetic_datasets: The filename of the collection of synethtic datasets. fn_sdv_metadata: The name of the SDV metadata file. dir_experiment: The path to the experiment directory. @@ -25,37 +29,40 @@ def check_input_paths( The paths to the data, metadata and metatransformer files. 
""" fn_dataset = Path(fn_dataset).stem - fn_typed, fn_experiments, fn_sdv_metadata = consistent_endings([fn_typed, fn_experiments, fn_sdv_metadata]) - fn_typed, fn_experiments, fn_sdv_metadata = potential_suffixes( - [fn_typed, fn_experiments, fn_sdv_metadata], fn_dataset + fn_typed, fn_synthetic_datasets, fn_sdv_metadata = consistent_endings( + [fn_typed, fn_synthetic_datasets, fn_sdv_metadata] + ) + fn_typed, fn_synthetic_datasets, fn_sdv_metadata = potential_suffixes( + [fn_typed, fn_synthetic_datasets, fn_sdv_metadata], fn_dataset ) - warn_if_path_supplied([fn_typed, fn_experiments, fn_sdv_metadata], dir_experiment) - check_exists([fn_typed, fn_experiments, fn_sdv_metadata], dir_experiment) - return fn_dataset, fn_typed, fn_experiments, fn_sdv_metadata + warn_if_path_supplied([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment) + check_exists([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment) + return fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata def output_eval( - eval_frame: EvalFrame, + evaluations: pd.DataFrame, fn_dataset: Path, - fn_evaluation_bundle: str, + fn_evaluations: str, dir_experiment: Path, ) -> None: """ Sets up the input and output paths for the model files. Args: + evaluations: The evaluations to output. fn_dataset: The base name of the dataset. - fn_evaluation_bundle: The name of the evaluation bundle file. + fn_evaluations: The filename of the collection of evaluations. dir_experiment: The path to the experiment output directory. Returns: The path to output the model. """ - fn_evaluation_bundle = consistent_ending(fn_evaluation_bundle) - fn_evaluation_bundle = potential_suffix(fn_evaluation_bundle, fn_dataset) - warn_if_path_supplied([fn_evaluation_bundle], dir_experiment) - with open(dir_experiment / fn_evaluation_bundle, "wb") as f: - pickle.dump(eval_frame, f) + fn_evaluations = consistent_ending(fn_evaluations) + fn_evaluations = potential_suffix(fn_evaluations, fn_dataset) + warn_if_path_supplied([fn_evaluations], dir_experiment) + with open(dir_experiment / fn_evaluations, "wb") as f: + pickle.dump(Evaluations(evaluations), f) def load_required_data( @@ -71,22 +78,22 @@ def load_required_data( Returns: The dataset name, the real data, the bundle of synthetic data from the modelling stage, and the SDV metadata. 
""" - if all(x in args.module_handover for x in ["dataset", "typed", "experiments", "sdv_metadata"]): + if all(x in args.module_handover for x in ["dataset", "typed", "synthetic_datasets", "sdv_metadata"]): return ( args.module_handover["dataset"], args.module_handover["typed"], - args.module_handover["experiments"], + args.module_handover["synthetic_datasets"], args.module_handover["sdv_metadata"], ) else: - fn_dataset, fn_typed, fn_experiments, fn_sdv_metadata = check_input_paths( - args.dataset, args.typed, args.experiments, args.sdv_metadata, dir_experiment + fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata = check_input_paths( + args.dataset, args.typed, args.synthetic_datasets, args.sdv_metadata, dir_experiment ) with open(dir_experiment / fn_typed, "rb") as f: - real_data = pickle.load(f).internal + real_data = pickle.load(f).contents with open(dir_experiment / fn_sdv_metadata, "rb") as f: sdv_metadata = pickle.load(f) - with open(dir_experiment / fn_experiments, "rb") as f: - experiments = pickle.load(f) + with open(dir_experiment / fn_synthetic_datasets, "rb") as f: + synthetic_datasets = pickle.load(f).contents - return fn_dataset, real_data, experiments, sdv_metadata + return fn_dataset, real_data, synthetic_datasets, sdv_metadata diff --git a/src/nhssynth/modules/evaluation/run.py b/src/nhssynth/modules/evaluation/run.py index bc4e1084..de146eb3 100644 --- a/src/nhssynth/modules/evaluation/run.py +++ b/src/nhssynth/modules/evaluation/run.py @@ -6,12 +6,12 @@ def run(args: argparse.Namespace) -> argparse.Namespace: - print("mRunning evaluation module...\n\033[32") + print("Running evaluation module...\n\033[32m") set_seed(args.seed) dir_experiment = experiment_io(args.experiment_name) - fn_dataset, real_dataset, experiments, sdv_metadata = load_required_data(args, dir_experiment) + fn_dataset, real_dataset, synthetic_datasets, sdv_metadata = load_required_data(args, dir_experiment) args, tasks, metrics = validate_metric_args(args, fn_dataset, real_dataset.columns) @@ -27,14 +27,14 @@ def run(args: argparse.Namespace) -> argparse.Namespace: args.sensitive_categorical_fields, ) - eval_frame.evaluate(real_dataset, experiments) + eval_frame.evaluate(real_dataset, synthetic_datasets) - output_eval(eval_frame.get_evaluation_bundle(), fn_dataset, args.evaluation_bundle, dir_experiment) + output_eval(eval_frame.get_evaluations(), fn_dataset, args.evaluations, dir_experiment) if "dashboard" in args.modules_to_run or "plotting" in args.modules_to_run: args.module_handover.update({"fn_dataset": fn_dataset}) if "plotting" in args.modules_to_run: - args.module_handover.update({"evaluation_bundle": eval_frame, "experiments": experiments}) + args.module_handover.update({"evaluations": eval_frame, "synthetic_datasets": synthetic_datasets}) print("\033[0m") diff --git a/src/nhssynth/modules/evaluation/utils.py b/src/nhssynth/modules/evaluation/utils.py index d03e5cb3..71f636c1 100644 --- a/src/nhssynth/modules/evaluation/utils.py +++ b/src/nhssynth/modules/evaluation/utils.py @@ -9,7 +9,6 @@ NUMERICAL_PRIVACY_METRICS, TABLE_METRICS, ) -from nhssynth.common.dicts import filter_dict from nhssynth.modules.evaluation.aequitas import run_aequitas from nhssynth.modules.evaluation.tasks import Task, get_tasks from sdmetrics.single_table import MultiColumnPairsMetric, MultiSingleColumnMetric @@ -49,7 +48,7 @@ def __init__( self.metric_groups = self._build_metric_groups() - def _build_metric_groups(self) -> set[str]: + def _build_metric_groups(self) -> list[str]: metric_groups = set() if 
self.tasks: metric_groups.add("task") @@ -64,66 +63,30 @@ def _build_metric_groups(self) -> set[str]: metric_groups.add("columnwise") if metric in TABLE_METRICS and issubclass(TABLE_METRICS[metric], MultiColumnPairsMetric): metric_groups.add("pairwise") - return metric_groups - - def evaluate(self, real_dataset: pd.DataFrame, experiments: list[dict[str, Any]]) -> dict[str, pd.DataFrame]: - assert "Real" not in [experiment["id"] for experiment in experiments], "Real is a reserved experiment ID." - self._experiments = pd.DataFrame( - [{"architecture": "Real", "id": "Real"}] - + [ - { - **filter_dict(row, {"model_config", "dataset", "num_configs"}), - **row["model_config"], - } - for row in experiments - ] - ) - assert len(self._experiments["id"]) == len(self._experiments["id"].unique()), "Experiment IDs must be unique." - self._dict = {id: {} for id in self._experiments["id"]} - self._step(real_dataset, "Real") - pbar = tqdm(experiments, desc="Evaluating") - for experiment in pbar: - pbar.set_description( - f"Evaluating {experiment['architecture']}, config {experiment['config_idx']}, repeat {experiment['repeat']}" - ) - self._step(real_dataset, experiment["id"], experiment["dataset"]) + return list(metric_groups) + + def evaluate(self, real_dataset: pd.DataFrame, synthetic_datasets: list[dict[str, Any]]) -> dict[str, pd.DataFrame]: + assert not any("Real" in i for i in synthetic_datasets.index), "Real is a reserved dataset ID." + assert synthetic_datasets.index.is_unique, "Dataset IDs must be unique." + self._evaluations = pd.DataFrame(index=synthetic_datasets.index, columns=self.metric_groups) + self._evaluations.loc[("Real", None, None)] = self._step(real_dataset) + pbar = tqdm(synthetic_datasets.iterrows(), desc="Evaluating", total=len(synthetic_datasets)) + for i, dataset in pbar: + pbar.set_description(f"Evaluating {i[0]}, repeat {i[1]}, config {i[2]}") + self._evaluations.loc[i] = self._step(real_dataset, dataset.values[0]) def get_evaluations(self) -> dict[str, pd.DataFrame]: """ Return a dict of dataframes each with one column for seed, one for architecture, and one per metric / report """ - assert hasattr(self, "_dict"), "You must first run `evaluate` on a `real_dataset` and set of `experiments`." - out = {} - for metric_group in self.metric_groups: - out[metric_group] = [] - for id in self._experiments["id"]: - if id != "Real" or metric_group in {"task", "aequitas"}: - out[metric_group].append({"id": id, **self._dict[id][metric_group]}) + assert hasattr( + self, "_evaluations" + ), "You must first run `evaluate` on a `real_dataset` and set of `synthetic_datasets`." return { - metric_group: pd.DataFrame(metric_group_dict) - if metric_group not in {"columnwise", "pairwise"} - else metric_group_dict - for metric_group, metric_group_dict in out.items() + metric_group: self._evaluations[metric_group].apply(pd.Series).dropna(how="all") + for metric_group in self.metric_groups } - def get_experiments(self) -> pd.DataFrame: - """ - Return a dataframe of experiment configurations - """ - assert hasattr( - self, "_experiments" - ), "You must first run `evaluate` on a `real_dataset` and set of `experiments`." 
- return self._experiments - - def get_evaluation_bundle(self) -> tuple[dict[str, pd.DataFrame], pd.DataFrame]: - """ - Return a tuple of (evaluations, experiments) - """ - return EvalBundle(self.get_evaluations(), self.get_experiments()) - - def _update(self, eval_dict, id: str) -> None: - self._dict[id].update(eval_dict) - def _task_step(self, data: pd.DataFrame) -> dict[str, dict]: metric_dict = {metric_group: {} for metric_group in self.metric_groups} for task in tqdm(self.tasks, desc="Running downstream tasks", leave=False): @@ -168,20 +131,14 @@ def _compute_metric( ) return metric_dict - def _step(self, real_data: pd.DataFrame, id: str, synthetic_data: pd.DataFrame = None) -> None: + def _step(self, real_data: pd.DataFrame, synthetic_data: pd.DataFrame = None) -> None: if synthetic_data is None: metric_dict = self._task_step(real_data) else: metric_dict = self._task_step(synthetic_data) for metric in tqdm(self.metrics, desc="Running metrics", leave=False): metric_dict = self._compute_metric(metric_dict, metric, real_data, synthetic_data) - self._update(metric_dict, id) - - -class EvalBundle: - def __init__(self, evaluations: pd.DataFrame, experiments: pd.DataFrame): - self.evaluations = evaluations - self.experiments = experiments + return metric_dict def validate_metric_args( diff --git a/src/nhssynth/modules/model/io.py b/src/nhssynth/modules/model/io.py index 92cf5353..7d007fc9 100644 --- a/src/nhssynth/modules/model/io.py +++ b/src/nhssynth/modules/model/io.py @@ -1,12 +1,20 @@ import argparse import pickle from pathlib import Path -from typing import Optional import pandas as pd from nhssynth.common.io import * from nhssynth.modules.dataloader.metatransformer import MetaTransformer -from nhssynth.modules.model.common.model import Model + + +class Experiments: + def __init__(self, experiments: pd.DataFrame): + self.contents = experiments + + +class SyntheticDatasets: + def __init__(self, synthetic_datasets: pd.DataFrame): + self.contents = synthetic_datasets def check_input_paths( @@ -32,59 +40,31 @@ def check_input_paths( return fn_dataset, fn_transformed, fn_metatransformer -def check_output_paths( - fn_dataset: Path, - fn_synthetic: str, - fn_model: str, - dir_experiment: Path, - suffix: str, -) -> tuple[str, str]: - """ - Sets up the input and output paths for the model files. - - Args: - fn_dataset: The base name of the dataset. - fn_synthetic: The name of the synthetic data file. - fn_model: The name of the model file. - dir_experiment: The path to the experiment output directory. - suffix: The suffix to append to the output files, usually the model architecture (and seed if applicable). - - Returns: - The path to output the model. 
- """ - fn_synthetic, fn_model = consistent_endings([(fn_synthetic, ".csv", suffix), (fn_model, ".pt", suffix)]) - fn_synthetic, fn_model = potential_suffixes([fn_synthetic, fn_model], fn_dataset) - warn_if_path_supplied([fn_synthetic, fn_model], dir_experiment) - return fn_synthetic, fn_model - - -def output_iter( - model: Model, - synthetic: pd.DataFrame, +def write_data_outputs( + experiments: pd.DataFrame, + synthetic_datasets: pd.DataFrame, + models: pd.DataFrame, fn_dataset: str, - synthetic_name: str, - model_name: str, dir_experiment: Path, - suffix: str, + args: argparse.Namespace, ) -> None: - dir_iter = dir_experiment / suffix - dir_iter.mkdir(parents=True, exist_ok=True) - fn_output, fn_model = check_output_paths(fn_dataset, synthetic_name, model_name, dir_experiment, suffix) - synthetic.to_csv(dir_iter / fn_output, index=False) - model.save(dir_iter / fn_model) + train_configs = experiments["train_config"].apply(pd.Series) + model_configs = experiments["model_config"].apply(pd.Series) + experiments = experiments.drop(columns=["train_config", "model_config"]).join(train_configs).join(model_configs) + fn_experiments, fn_synthetic_datasets = consistent_endings([args.experiments, args.synthetic_datasets]) + fn_experiments, fn_synthetic_datasets = potential_suffixes([fn_experiments, fn_synthetic_datasets], fn_dataset) + warn_if_path_supplied([fn_experiments, fn_synthetic_datasets], dir_experiment) -def output_full( - experiments: list[tuple[int, str, pd.DataFrame]], - fn_dataset: str, - experiments_name: str, - dir_experiment: Path, -) -> None: - fn_experiments = consistent_ending(experiments_name) - fn_experiments = potential_suffix(fn_experiments, fn_dataset) - warn_if_path_supplied(fn_experiments, dir_experiment) with open(dir_experiment / fn_experiments, "wb") as f: - pickle.dump(experiments, f) + pickle.dump(Experiments(experiments), f) + with open(dir_experiment / fn_synthetic_datasets, "wb") as f: + pickle.dump(SyntheticDatasets(synthetic_datasets), f) + (dir_experiment / "models").mkdir(parents=True, exist_ok=True) + for i, model in models.iterrows(): + fn_model = consistent_ending(args.model, ending=".pt", suffix=f"{i[0]}_repeat_{i[1]}_config_{i[2]}") + fn_model = potential_suffix(fn_model, fn_dataset) + model["model"].save(dir_experiment / "models" / fn_model) def load_required_data( diff --git a/src/nhssynth/modules/model/run.py b/src/nhssynth/modules/model/run.py index 68137ed6..d0559f74 100644 --- a/src/nhssynth/modules/model/run.py +++ b/src/nhssynth/modules/model/run.py @@ -4,28 +4,27 @@ import pandas as pd from nhssynth.common import * from nhssynth.modules.dataloader.metatransformer import MetaTransformer -from nhssynth.modules.model.io import load_required_data, output_full, output_iter +from nhssynth.modules.model.io import load_required_data, write_data_outputs from nhssynth.modules.model.models import MODELS from nhssynth.modules.model.utils import get_experiments def run_iter( experiment: dict[str, Any], + architecture: str, real_dataset: pd.DataFrame, metatransformer: MetaTransformer, - patience: int, displayed_metrics: list[str], num_samples: int, ) -> pd.DataFrame: set_seed(experiment["seed"]) - model = MODELS[experiment["architecture"]](real_dataset, metatransformer, **experiment["model_config"]) + model = MODELS[architecture](real_dataset, metatransformer, **experiment["model_config"]) _, _ = model.train( - num_epochs=experiment["num_epochs"], - patience=patience, + **experiment["train_config"], displayed_metrics=displayed_metrics.copy(), ) - synthetic = 
model.generate(num_samples) - return model, synthetic + dataset = model.generate(num_samples) + return {"dataset": dataset}, {"model": model} def run(args: argparse.Namespace) -> argparse.Namespace: @@ -37,30 +36,23 @@ def run(args: argparse.Namespace) -> argparse.Namespace: fn_dataset, real_dataset, metatransformer = load_required_data(args, dir_experiment) experiments = get_experiments(args) - for experiment in experiments: + models = pd.DataFrame(index=experiments.index, columns=["model"]) + synthetic_datasets = pd.DataFrame(index=experiments.index, columns=["dataset"]) + + for i, experiment in experiments.iterrows(): print( - f"\nRunning the {experiment['architecture']} architecture with configuration {experiment['config_idx']} of {experiment['num_configs']}, repeat {experiment['repeat']} of {args.repeats} 🤖\033[31m" - ) - model, synthetic_dataset = run_iter( - experiment, real_dataset, metatransformer, args.patience, args.displayed_metrics.copy(), args.num_samples + f"\nRunning the {i[0]} architecture, repeat {i[1]} of {args.repeats}, with configuration {i[2]} of {experiment['num_configs']} 🤖\033[31m" ) - output_iter( - model, - synthetic_dataset, - fn_dataset, - args.synthetic, - args.model_file, - dir_experiment, - experiment["id"], + synthetic_datasets.loc[i], models.loc[i] = run_iter( + experiment, i[0], real_dataset, metatransformer, args.displayed_metrics.copy(), args.num_samples ) - experiment["dataset"] = synthetic_dataset - output_full(experiments, fn_dataset, args.experiments, dir_experiment) + write_data_outputs(experiments, synthetic_datasets, models, fn_dataset, dir_experiment, args) if "dashboard" in args.modules_to_run or "evaluation" in args.modules_to_run or "plotting" in args.modules_to_run: args.module_handover.update({"fn_dataset": fn_dataset}) if "evaluation" in args.modules_to_run or "plotting" in args.modules_to_run: - args.module_handover.update({"experiments": experiments}) + args.module_handover.update({"synthetic_datasets": synthetic_datasets}) print("") diff --git a/src/nhssynth/modules/model/utils.py b/src/nhssynth/modules/model/utils.py index 342a3c63..be13e92a 100644 --- a/src/nhssynth/modules/model/utils.py +++ b/src/nhssynth/modules/model/utils.py @@ -13,30 +13,28 @@ def wrap_arg(arg) -> Union[list, tuple]: return arg +def configs_from_arg_combinations(args: argparse.Namespace, arg_list: list[str]): + wrapped_args = {arg: wrap_arg(getattr(args, arg)) for arg in arg_list} + combinations = list(itertools.product(*wrapped_args.values())) + return [{k: v for k, v in zip(wrapped_args.keys(), values) if v is not None} for values in combinations] + + def get_experiments(args: argparse.Namespace) -> list[dict[str, Any]]: - experiments = [] + experiments = pd.DataFrame( + columns=["architecture", "repeat", "config", "model_config", "seed", "train_config", "num_configs"] + ) + train_configs = configs_from_arg_combinations(args, ["num_epochs", "patience"]) for arch_name, repeat in itertools.product(*[wrap_arg(args.architecture), list(range(args.repeats))]): arch = MODELS[arch_name] - model_args = { - arg: wrap_arg(getattr(args, arg)) for arg in arch.get_args() + ["batch_size", "use_gpu", "num_epochs"] - } - model_configs = list(itertools.product(*model_args.values())) - for i, values in enumerate(model_configs): - model_config = {k: v for k, v in zip(model_args.keys(), values) if v is not None} - num_epochs = model_config.pop("num_epochs") - experiment = { + model_configs = configs_from_arg_combinations(args, arch.get_args() + ["batch_size", "use_gpu"]) + for i, 
(train_config, model_config) in enumerate(itertools.product(train_configs, model_configs)): + experiments.loc[len(experiments.index)] = { "architecture": arch_name, + "repeat": repeat + 1, + "config": i + 1, "model_config": model_config, - "config_idx": str(i + 1), - "num_configs": len(model_configs), + "num_configs": len(model_configs) * len(train_configs), "seed": args.seed + repeat if args.seed else None, - "repeat": str(repeat + 1), - "num_epochs": num_epochs, + "train_config": train_config, } - experiment["id"] = ( - arch_name - + (f"_config_{experiment['config_idx']}" if len(model_configs) > 1 else "") - + (f"_repeat_{experiment['repeat']}" if args.repeats > 1 else "") - ) - experiments.append(experiment) - return experiments + return experiments.set_index(["architecture", "repeat", "config"], drop=True)
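
A minimal sketch follows (not part of the patch itself; architecture names, seeds and config values are illustrative only) of the handover layout introduced in PATCH 4/4: `get_experiments` now returns a pandas DataFrame indexed by `(architecture, repeat, config)`, the model module fills `synthetic_datasets` and `models` frames aligned on that same index, and each pickled artefact is a thin wrapper (`TypedDataset`, `Experiments`, `SyntheticDatasets`, `Evaluations`) exposing the underlying object via `.contents`.

import pandas as pd

# Shared MultiIndex replacing string IDs such as "VAE_config_1_repeat_2".
index = pd.MultiIndex.from_tuples(
    [("VAE", 1, 1), ("VAE", 2, 1), ("DPVAE", 1, 1)],
    names=["architecture", "repeat", "config"],
)

# Shape of the frame returned by get_experiments (values are illustrative).
experiments = pd.DataFrame(
    {
        "model_config": [{"batch_size": 32}] * 3,
        "train_config": [{"num_epochs": 30, "patience": 10}] * 3,
        "seed": [123, 124, 123],
        "num_configs": [1, 1, 1],
    },
    index=index,
)

# The model module's run() fills frames aligned on the same index, so the
# evaluation and plotting modules can look a row up by label instead of
# parsing an ID string.
synthetic_datasets = pd.DataFrame(index=index, columns=["dataset"])
models = pd.DataFrame(index=index, columns=["model"])


class SyntheticDatasets:  # mirrors the wrapper added in modules/model/io.py
    def __init__(self, synthetic_datasets: pd.DataFrame):
        self.contents = synthetic_datasets


# Downstream modules unpickle the wrapper and read `.contents`, as
# evaluation/io.py now does with `pickle.load(f).contents`; the evaluation
# frame additionally reserves a ("Real", None, None) row for the real dataset.
bundle = SyntheticDatasets(synthetic_datasets)
assert list(bundle.contents.index.names) == ["architecture", "repeat", "config"]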