From 6b61ff438fed9dc0e83d8fbeddabed50e135e979 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?=
Date: Tue, 14 Feb 2023 09:29:44 -0500
Subject: [PATCH] ENH: Add property to discard logging training to `Comet ML`

Add training configuration flag and the corresponding experiment class
property to discard logging the training process to `Comet ML`. It may
be useful for users that do not have a `Comet ML` account or are not
willing to log the experiment.
---
 configs/train_config.yaml               |  2 ++
 scripts/ae_train.py                     | 30 ++++++++++++++++---------
 tractolearn/config/experiment.py        |  5 +++++
 tractolearn/learning/trainer_manager.py |  3 ++-
 4 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/configs/train_config.yaml b/configs/train_config.yaml
index 0370904..1744ab1 100644
--- a/configs/train_config.yaml
+++ b/configs/train_config.yaml
@@ -43,5 +43,7 @@ weights:
 
 viz: False
 viz_num_batches: 10
+# Whether the experiment is to be logged to Comet ML
+log_to_comet: False
 
 num_workers: 24
diff --git a/scripts/ae_train.py b/scripts/ae_train.py
index ed1829b..9854c39 100644
--- a/scripts/ae_train.py
+++ b/scripts/ae_train.py
@@ -104,11 +104,14 @@ def main():
     random.seed(seed)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # TODO: Find a better way to import API key (eventually remove comet.ml)
-    logger.info(comet_ml.get_comet_version())
-    experiment_recorder = experiment.record_experiment(
-        api_key=os.environ["COMETML"]
-    )
+    experiment_recorder = None
+    log_to_comet = experiment.log_to_comet
+    if log_to_comet:
+        # TODO: Find a better way to import API key (eventually remove comet.ml)
+        logger.info(comet_ml.get_comet_version())
+        experiment_recorder = experiment.record_experiment(
+            api_key=os.environ["COMETML"]
+        )
 
     ref_anat_img = nib.load(experiment_dict["ref_anat_fname"])
     isocenter = compute_isocenter(ref_anat_img)
@@ -140,7 +143,7 @@ def main():
         (data_manager.point_dims, data_manager.num_points),
         isocenter,
         volume,
-        experiment_recorder,
+        experiment_recorder=experiment_recorder,
     )
 
     logger.info("Finished building model and trainer.")
@@ -148,9 +151,13 @@ def main():
     # Start training run
     logger.info("Starting training...")
     for epoch in range(1, experiment_dict["epochs"] + 1):
-        with experiment_recorder.train():
+        if log_to_comet:
+            with experiment_recorder.train():
+                trainer.train(epoch)
+            with experiment_recorder.validate():
+                trainer.valid(epoch)
+        else:
             trainer.train(epoch)
-        with experiment_recorder.validate():
             trainer.valid(epoch)
 
         # Project the valid set
@@ -174,9 +181,10 @@ def main():
                 experiment_dict["rbx_classes"],
             )
             # Log the latent space plot to Comet
-            experiment_recorder.log_image(
-                latent_plot_filename, name="latent_umap", step=epoch
-            )
+            if log_to_comet:
+                experiment_recorder.log_image(
+                    latent_plot_filename, name="latent_umap", step=epoch
+                )
 
     logger.info("Finished training.")
     torch.cuda.empty_cache()
diff --git a/tractolearn/config/experiment.py b/tractolearn/config/experiment.py
index dac25a8..dd62413 100644
--- a/tractolearn/config/experiment.py
+++ b/tractolearn/config/experiment.py
@@ -78,6 +78,7 @@ class ExperimentKeys:
     NUM_WORKERS = "num_workers"
     DISTANCE_FUNCTION = "distance_function"
     TO_SWAP = "to_swap"
+    LOG_TO_COMET = "log_to_comet"
 
 
 class ThresholdTestKeys:
@@ -168,6 +169,10 @@ def __init__(
         # Copy the YAML configuration file to the experiment directory
         shutil.copy(config, self.experiment_dir)
 
+    @property
+    def log_to_comet(self):
+        return self.config[ExperimentKeys.LOG_TO_COMET]
+
     def setup_experiment(self):
         return self.config
 
diff --git a/tractolearn/learning/trainer_manager.py b/tractolearn/learning/trainer_manager.py
index 881eb76..3141a7f 100644
--- a/tractolearn/learning/trainer_manager.py
+++ b/tractolearn/learning/trainer_manager.py
@@ -2,6 +2,7 @@
 import itertools
 import logging
 import sys
+import typing
 from os.path import join as pjoin
 from typing import Tuple
 
@@ -42,7 +43,7 @@ def __init__(
         input_size: Tuple[int, int],
         isocenter: np.array,
         volume: np.array,
-        experiment_recorder: Experiment,
+        experiment_recorder: typing.Union[Experiment, None],
     ):
         self._device = device