Merge pull request #104 from nhsx/harry/improving-eval-handover
Improving handover of evaluations and experiments
HarrisonWilde authored Sep 1, 2023
2 parents bc63eb2 + c066bbd commit 5d12079
Showing 15 changed files with 163 additions and 228 deletions.
7 changes: 5 additions & 2 deletions config/test_pipeline.yaml
@@ -8,14 +8,17 @@ model:
architecture:
- VAE
- DPVAE
num_epochs: 30
num_epochs:
- 30
- 50
patience: 10
target_epsilon:
- 1.0
- 3.0
- 6.0
max_grad_norm: 5.0
secure_mode: false
repeats: 3
repeats: 4
evaluation:
downstream_tasks: true
column_shape_metrics:
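With `num_epochs` and `target_epsilon` now taking lists, a single pipeline config appears to describe a grid of experiment configurations, each repeated `repeats` times. The sketch below is a rough, hypothetical illustration of that expansion; the helper code is illustrative only and is not nhssynth's actual experiment-building logic.

# Hypothetical sketch: expanding list-valued config fields into a grid of
# experiment configurations. Names and structure are illustrative, not nhssynth API.
from itertools import product

config = {
    "architecture": ["VAE", "DPVAE"],
    "num_epochs": [30, 50],
    "target_epsilon": [1.0, 3.0, 6.0],
    "repeats": 4,
}

grid_keys = ["architecture", "num_epochs", "target_epsilon"]
experiments = [
    dict(zip(grid_keys, values), repeat=i)
    for values in product(*(config[k] for k in grid_keys))
    for i in range(config["repeats"])
]
print(len(experiments))  # 2 * 2 * 3 * 4 = 48 runs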
16 changes: 10 additions & 6 deletions src/nhssynth/cli/common_arguments.py
@@ -99,16 +99,20 @@ def get_parser(overrides: bool = False) -> argparse.ArgumentParser:
"sdv_metadata",
"filename of the metadata formatted for use with SDV",
),
"synthetic": suffix_parser_generator(
"synthetic",
"filename of the synthetic data",
),
"experiments": suffix_parser_generator(
"experiments",
"filename of the experiment bundle, i.e. the collection of all seeds, models, and synthetic datasets",
),
"evaluation_bundle": suffix_parser_generator(
"evaluation_bundle",
"synthetic_datasets": suffix_parser_generator(
"synthetic_datasets",
"filename of the collection of synthetic datasets generated by a given set of `experiments`",
),
"model": suffix_parser_generator(
"model",
"name for each of the collection of models trained during a given set of `experiments`",
),
"evaluations": suffix_parser_generator(
"evaluations",
"filename of the (collection of) evaluation(s) for a given set of `experiments`",
),
}
7 changes: 1 addition & 6 deletions src/nhssynth/cli/module_arguments.py
@@ -95,12 +95,6 @@ def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides:
default=1,
help="how many times to repeat the training process per model architecture (<SEED> is incremented each time)",
)
group.add_argument(
"--model-file",
type=str,
default="_model",
help="specify the filename of the model to be saved in `experiments/<EXPERIMENT_NAME>/`, defaults to `<DATASET>_model.pt`",
)
group.add_argument(
"--batch-size",
type=int,
@@ -118,6 +112,7 @@ def add_model_args(parser: argparse.ArgumentParser, group_title: str, overrides:
group.add_argument(
"--patience",
type=int,
nargs="+",
default=5,
help="how many epochs the model is allowed to train for without improvement",
)
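The `--patience` flag now takes `nargs="+"`, matching the other swept hyperparameters. A minimal standalone sketch of what that means for the CLI follows; this is plain argparse behaviour, not the package's full parser setup.

# Minimal sketch of nargs="+" with a scalar default; plain argparse only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--patience", type=int, nargs="+", default=5)

print(parser.parse_args([]).patience)                         # 5 (scalar default)
print(parser.parse_args(["--patience", "5", "10"]).patience)  # [5, 10]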
8 changes: 4 additions & 4 deletions src/nhssynth/cli/module_setup.py
@@ -114,28 +114,28 @@ def add_config_args(parser: argparse.ArgumentParser) -> None:
add_args=add_model_args,
description="run the model architecture module, to train a synthetic data generator",
help="train a model",
common_parsers=["transformed", "metatransformer", "synthetic", "experiments"],
common_parsers=["transformed", "metatransformer", "experiments", "synthetic_datasets", "model"],
),
"evaluation": ModuleConfig(
func=evaluation.run,
add_args=add_evaluation_args,
description="run the evaluation module, to evaluate an experiment",
help="evaluate an experiment",
common_parsers=["sdv_metadata", "typed", "experiments", "evaluation_bundle"],
common_parsers=["sdv_metadata", "typed", "experiments", "synthetic_datasets", "evaluations"],
),
"plotting": ModuleConfig(
func=plotting.run,
add_args=add_plotting_args,
description="run the plotting module, to generate plots for a given model and / or evaluation",
help="generate plots",
common_parsers=["typed", "evaluation_bundle"],
common_parsers=["typed", "evaluations"],
),
"dashboard": ModuleConfig(
func=dashboard.run,
add_args=add_dashboard_args,
description="run the dashboard module, to produce a streamlit dashboard",
help="start up a streamlit dashboard to view the results of an evaluation",
common_parsers=["typed", "experiments", "evaluation_bundle"],
common_parsers=["typed", "experiments", "synthetic_datasets", "evaluations"],
no_seed=True,
),
"pipeline": ModuleConfig(
8 changes: 7 additions & 1 deletion src/nhssynth/modules/dataloader/io.py
@@ -3,11 +3,17 @@
from pathlib import Path

import numpy as np
import pandas as pd
from nhssynth.common.io import *
from nhssynth.modules.dataloader.metatransformer import MetaTransformer
from tqdm import tqdm


class TypedDataset:
def __init__(self, typed_dataset: pd.DataFrame):
self.contents = typed_dataset


def check_input_paths(
fn_input: str,
fn_metadata: str,
@@ -108,7 +114,7 @@ def write_data_outputs(
metatransformer.save_metadata(dir_experiment / fn_metadata, args.collapse_yaml)
metatransformer.save_constraint_graphs(dir_experiment / fn_constraint_graph)
with open(dir_experiment / fn_typed, "wb") as f:
pickle.dump(metatransformer.get_typed_dataset(), f)
pickle.dump(TypedDataset(metatransformer.get_typed_dataset()), f)
transformed_dataset = metatransformer.get_transformed_dataset()
transformed_dataset.to_pickle(dir_experiment / fn_transformed)
if args.write_csv:
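The typed dataset is now pickled inside the new `TypedDataset` wrapper and unwrapped via `.contents` downstream (see `evaluation/io.py` below). A minimal round-trip sketch, using a placeholder DataFrame and filename:

# Sketch of the TypedDataset round-trip; the DataFrame and filename are
# placeholders, only the wrapper mirrors the class added in this commit.
import pickle
import pandas as pd

class TypedDataset:
    def __init__(self, typed_dataset: pd.DataFrame):
        self.contents = typed_dataset

df = pd.DataFrame({"age": [34, 58], "sex": ["F", "M"]})
with open("example_typed.pkl", "wb") as f:
    pickle.dump(TypedDataset(df), f)
with open("example_typed.pkl", "rb") as f:
    real_data = pickle.load(f).contents  # downstream modules unwrap via .contents
assert real_data.equals(df)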
7 changes: 1 addition & 6 deletions src/nhssynth/modules/dataloader/metatransformer.py
@@ -261,7 +261,7 @@ def get_typed_dataset(self) -> pd.DataFrame:
raise ValueError(
"The typed dataset has not yet been created. Call `mt.apply()` (or `mt.apply_dtypes()`) first."
)
return TypedDataset(self.typed_dataset)
return self.typed_dataset

def get_prepared_dataset(self) -> pd.DataFrame:
if not hasattr(self, "prepared_dataset"):
@@ -301,8 +301,3 @@ def save_metadata(self, path: pathlib.Path, collapse_yaml: bool = False) -> None

def save_constraint_graphs(self, path: pathlib.Path) -> None:
return self.metadata.constraints._output_graphs_html(path)


class TypedDataset:
def __init__(self, typed_dataset: pd.DataFrame):
self.internal = typed_dataset
2 changes: 1 addition & 1 deletion src/nhssynth/modules/dataloader/run.py
@@ -32,7 +32,7 @@ def run(args: argparse.Namespace) -> argparse.Namespace:
}
)
if "evaluation" in args.modules_to_run:
args.module_handover.update({"typed": mt.get_typed_dataset().internal})
args.module_handover.update({"typed": mt.get_typed_dataset()})

print("\033[0m")

59 changes: 33 additions & 26 deletions src/nhssynth/modules/evaluation/io.py
@@ -5,57 +5,64 @@

import pandas as pd
from nhssynth.common.io import *
from nhssynth.modules.evaluation.utils import EvalFrame


class Evaluations:
def __init__(self, evaluations: dict[str, dict[str, Any]]):
self.contents = evaluations


def check_input_paths(
fn_dataset: str, fn_typed: str, fn_experiments: str, fn_sdv_metadata: str, dir_experiment: Path
fn_dataset: str, fn_typed: str, fn_synthetic_datasets: str, fn_sdv_metadata: str, dir_experiment: Path
) -> tuple[str, str]:
"""
Sets up the input and output paths for the model files.
Args:
fn_dataset: The base name of the dataset.
fn_typed: The name of the typed data file.
fn_experiments: The name of the metatransformer file.
fn_typed: The name of the typed real dataset file.
fn_synthetic_datasets: The filename of the collection of synthetic datasets.
fn_sdv_metadata: The name of the SDV metadata file.
dir_experiment: The path to the experiment directory.
Returns:
The paths to the data, metadata and metatransformer files.
"""
fn_dataset = Path(fn_dataset).stem
fn_typed, fn_experiments, fn_sdv_metadata = consistent_endings([fn_typed, fn_experiments, fn_sdv_metadata])
fn_typed, fn_experiments, fn_sdv_metadata = potential_suffixes(
[fn_typed, fn_experiments, fn_sdv_metadata], fn_dataset
fn_typed, fn_synthetic_datasets, fn_sdv_metadata = consistent_endings(
[fn_typed, fn_synthetic_datasets, fn_sdv_metadata]
)
fn_typed, fn_synthetic_datasets, fn_sdv_metadata = potential_suffixes(
[fn_typed, fn_synthetic_datasets, fn_sdv_metadata], fn_dataset
)
warn_if_path_supplied([fn_typed, fn_experiments, fn_sdv_metadata], dir_experiment)
check_exists([fn_typed, fn_experiments, fn_sdv_metadata], dir_experiment)
return fn_dataset, fn_typed, fn_experiments, fn_sdv_metadata
warn_if_path_supplied([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment)
check_exists([fn_typed, fn_synthetic_datasets, fn_sdv_metadata], dir_experiment)
return fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata


def output_eval(
eval_frame: EvalFrame,
evaluations: pd.DataFrame,
fn_dataset: Path,
fn_evaluation_bundle: str,
fn_evaluations: str,
dir_experiment: Path,
) -> None:
"""
Sets up the input and output paths for the model files.
Args:
evaluations: The evaluations to output.
fn_dataset: The base name of the dataset.
fn_evaluation_bundle: The name of the evaluation bundle file.
fn_evaluations: The filename of the collection of evaluations.
dir_experiment: The path to the experiment output directory.
Returns:
The path to output the model.
"""
fn_evaluation_bundle = consistent_ending(fn_evaluation_bundle)
fn_evaluation_bundle = potential_suffix(fn_evaluation_bundle, fn_dataset)
warn_if_path_supplied([fn_evaluation_bundle], dir_experiment)
with open(dir_experiment / fn_evaluation_bundle, "wb") as f:
pickle.dump(eval_frame, f)
fn_evaluations = consistent_ending(fn_evaluations)
fn_evaluations = potential_suffix(fn_evaluations, fn_dataset)
warn_if_path_supplied([fn_evaluations], dir_experiment)
with open(dir_experiment / fn_evaluations, "wb") as f:
pickle.dump(Evaluations(evaluations), f)


def load_required_data(
@@ -71,22 +78,22 @@ def load_required_data(
Returns:
The dataset name, the real data, the bundle of synthetic data from the modelling stage, and the SDV metadata.
"""
if all(x in args.module_handover for x in ["dataset", "typed", "experiments", "sdv_metadata"]):
if all(x in args.module_handover for x in ["dataset", "typed", "synthetic_datasets", "sdv_metadata"]):
return (
args.module_handover["dataset"],
args.module_handover["typed"],
args.module_handover["experiments"],
args.module_handover["synthetic_datasets"],
args.module_handover["sdv_metadata"],
)
else:
fn_dataset, fn_typed, fn_experiments, fn_sdv_metadata = check_input_paths(
args.dataset, args.typed, args.experiments, args.sdv_metadata, dir_experiment
fn_dataset, fn_typed, fn_synthetic_datasets, fn_sdv_metadata = check_input_paths(
args.dataset, args.typed, args.synthetic_datasets, args.sdv_metadata, dir_experiment
)
with open(dir_experiment / fn_typed, "rb") as f:
real_data = pickle.load(f).internal
real_data = pickle.load(f).contents
with open(dir_experiment / fn_sdv_metadata, "rb") as f:
sdv_metadata = pickle.load(f)
with open(dir_experiment / fn_experiments, "rb") as f:
experiments = pickle.load(f)
with open(dir_experiment / fn_synthetic_datasets, "rb") as f:
synthetic_datasets = pickle.load(f).contents

return fn_dataset, real_data, experiments, sdv_metadata
return fn_dataset, real_data, synthetic_datasets, sdv_metadata
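Evaluations follow the same pattern: `output_eval` pickles an `Evaluations` wrapper and later modules read `.contents` back. A small, self-contained sketch of that round-trip with dummy results; the filename and metric names are placeholders, only the wrapper mirrors the class added here.

# Sketch of the Evaluations wrapper round-trip with dummy results.
import pickle
from typing import Any

class Evaluations:
    def __init__(self, evaluations: dict[str, dict[str, Any]]):
        self.contents = evaluations

dummy = {"VAE_seed0": {"downstream_auc": 0.81}, "DPVAE_seed0": {"downstream_auc": 0.74}}
with open("example_evaluations.pkl", "wb") as f:
    pickle.dump(Evaluations(dummy), f)
with open("example_evaluations.pkl", "rb") as f:
    evaluations = pickle.load(f).contents
for experiment, results in evaluations.items():
    print(experiment, results)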
10 changes: 5 additions & 5 deletions src/nhssynth/modules/evaluation/run.py
@@ -6,12 +6,12 @@


def run(args: argparse.Namespace) -> argparse.Namespace:
print("mRunning evaluation module...\n\033[32")
print("Running evaluation module...\n\033[32m")

set_seed(args.seed)
dir_experiment = experiment_io(args.experiment_name)

fn_dataset, real_dataset, experiments, sdv_metadata = load_required_data(args, dir_experiment)
fn_dataset, real_dataset, synthetic_datasets, sdv_metadata = load_required_data(args, dir_experiment)

args, tasks, metrics = validate_metric_args(args, fn_dataset, real_dataset.columns)

@@ -27,14 +27,14 @@ def run(args: argparse.Namespace) -> argparse.Namespace:
args.sensitive_categorical_fields,
)

eval_frame.evaluate(real_dataset, experiments)
eval_frame.evaluate(real_dataset, synthetic_datasets)

output_eval(eval_frame.get_evaluation_bundle(), fn_dataset, args.evaluation_bundle, dir_experiment)
output_eval(eval_frame.get_evaluations(), fn_dataset, args.evaluations, dir_experiment)

if "dashboard" in args.modules_to_run or "plotting" in args.modules_to_run:
args.module_handover.update({"fn_dataset": fn_dataset})
if "plotting" in args.modules_to_run:
args.module_handover.update({"evaluation_bundle": eval_frame, "experiments": experiments})
args.module_handover.update({"evaluations": eval_frame, "synthetic_datasets": synthetic_datasets})

print("\033[0m")
