From afa1dd991a3379911693d6e9c5ebf2323c0952d4 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 14 May 2024 11:58:34 -0400 Subject: [PATCH 1/5] Prepare new testing data --- dwi_ml/data/hdf5/hdf5_creation.py | 17 +++++++---------- dwi_ml/unit_tests/utils/expected_values.py | 2 +- scripts_python/tests/test_all_steps_l2t.py | 3 +-- scripts_python/tests/test_all_steps_tto.py | 2 +- .../tests/test_create_hdf5_dataset.py | 2 +- 5 files changed, 11 insertions(+), 15 deletions(-) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index b1f2d170..74715f1e 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -59,8 +59,8 @@ def format_filelist(filenames, enforce_presence, folder=None) -> List[str]: return new_files -def _load_and_verify_file(filename: str, subj_input_path, group_name: str, - group_affine, group_res): +def _load_and_verify_file(filename: str, group_name: str, group_affine, + group_res): """ Loads a 3D or 4D nifti file. If it is a 3D dataset, adds a dimension to make it 4D. Then checks that it is compatible with a given group based on @@ -70,8 +70,6 @@ def _load_and_verify_file(filename: str, subj_input_path, group_name: str, ------ filename: str File's name. Must be .nii or .nii.gz. - subj_input_path: Path - Path where to load the nifti file from. group_name: str Name of the group with which 'filename' file must be compatible. group_affine: np.array @@ -79,9 +77,7 @@ def _load_and_verify_file(filename: str, subj_input_path, group_name: str, group_res: np.array The loaded file's resolution must be equal (or very close) to this res. """ - data_file = subj_input_path.joinpath(filename) - - if not data_file.is_file(): + if not os.path.isfile(filename): logging.debug(" Skipping file {} because it was not " "found in this subject's folder".format(filename)) # Note: if args.enforce_files_presence was set to true, this @@ -89,7 +85,7 @@ def _load_and_verify_file(filename: str, subj_input_path, group_name: str, # create_hdf5_dataset.py. return None - data, affine, res, _ = load_file_to4d(data_file) + data, affine, res, _ = load_file_to4d(filename) if not np.allclose(affine, group_affine, atol=1e-5): # Note. When keeping default options on tolerance, we have run @@ -483,6 +479,7 @@ def _process_one_volume_group(self, group: str, subj_id: str, else: std_mask = np.logical_or(sub_mask_data, std_mask) + # Get the files and add the subject_dir as prefix. file_list = self.groups_config[group]['files'] file_list = format_filelist(file_list, self.enforce_files_presence, folder=subj_input_dir) @@ -505,8 +502,8 @@ def _process_one_volume_group(self, group: str, subj_id: str, for file_name in file_list[1:]: logging.info(" - Processing file {}" .format(os.path.basename(file_name))) - data = _load_and_verify_file(file_name, subj_input_dir, group, - group_affine, group_res) + data = _load_and_verify_file(file_name, group, group_affine, + group_res) if std_option == 'per_file': logging.info(' - Standardizing') diff --git a/dwi_ml/unit_tests/utils/expected_values.py b/dwi_ml/unit_tests/utils/expected_values.py index 98448847..512ca977 100644 --- a/dwi_ml/unit_tests/utils/expected_values.py +++ b/dwi_ml/unit_tests/utils/expected_values.py @@ -7,7 +7,7 @@ │ ├── code.sh │ ├── config_file.json │ ├── empty_subjs_list.txt -│ └── training_subjs.txt +│ └── subjs_list.txt ├── dwi_ml_ready │ └── subjX │ ├── anat diff --git a/scripts_python/tests/test_all_steps_l2t.py b/scripts_python/tests/test_all_steps_l2t.py index 41eac458..67f38816 100644 --- a/scripts_python/tests/test_all_steps_l2t.py +++ b/scripts_python/tests/test_all_steps_l2t.py @@ -141,8 +141,7 @@ def test_visu(script_runner, experiments_path): assert ret.success -def future_test_training_with_generation_validation(script_runner, experiments_path): - # toDo NOT DOING ANYTHING NOW BECAUSE HDF5 DOES NOT CONTAIN A VALIDATION SUBJ! +def test_training_with_generation_validation(script_runner, experiments_path): if torch.cuda.is_available(): option = '--use_gpu' diff --git a/scripts_python/tests/test_all_steps_tto.py b/scripts_python/tests/test_all_steps_tto.py index 49fcd6d4..1b8ddb5b 100644 --- a/scripts_python/tests/test_all_steps_tto.py +++ b/scripts_python/tests/test_all_steps_tto.py @@ -53,7 +53,7 @@ def test_execution(script_runner, experiments_path): ret = script_runner.run('tt_train_model.py', experiments_path, experiment_name, hdf5_file, input_group_name, streamline_group_name, - '--model', 'TTO', + '--model', 'TTO', '--dg_key', 'gaussian', '--max_epochs', '1', '--batch_size_training', '5', '--batch_size_units', 'nb_streamlines', '--max_batches_per_epoch_training', '2', diff --git a/scripts_python/tests/test_create_hdf5_dataset.py b/scripts_python/tests/test_create_hdf5_dataset.py index 9ef4c5bb..fd2ade91 100644 --- a/scripts_python/tests/test_create_hdf5_dataset.py +++ b/scripts_python/tests/test_create_hdf5_dataset.py @@ -20,7 +20,7 @@ def test_execution_bst(script_runner): dwi_ml_folder = os.path.join(data_dir, 'dwi_ml_ready') config_file = os.path.join(data_dir, 'code_creation/config_file.json') - training_subjs = os.path.join(data_dir, 'code_creation/training_subjs.txt') + training_subjs = os.path.join(data_dir, 'code_creation/subjs_list.txt') validation_subjs = os.path.join(data_dir, 'code_creation/empty_subjs_list.txt') testing_subjs = validation_subjs From d415f8d6f28de9c01dd906c21af758c86fb943ea Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 14 May 2024 12:00:37 -0400 Subject: [PATCH 2/5] Fix wrong return in validation from recent PR --- dwi_ml/training/trainers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index ad0d4947..c8732275 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -918,7 +918,7 @@ def validate_one_batch(self, data, epoch): """ Computes the loss(es) for the current batch and updates monitors. """ - mean_local_loss, n, _ = self.run_one_batch(data) + mean_local_loss, n = self.run_one_batch(data) self.valid_local_loss_monitor.update(mean_local_loss.cpu().item(), weight=n) From 2ec62bb9fcd066880873c26086f2f027974cc828 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 14 May 2024 13:46:06 -0400 Subject: [PATCH 3/5] Fix unit test now that we have a validation set --- .../data/dataset/multi_subject_containers.py | 3 +- .../data/dataset/single_subject_containers.py | 3 +- dwi_ml/models/main_models.py | 4 +- dwi_ml/training/batch_samplers.py | 43 +++++++++-------- dwi_ml/training/trainers.py | 48 +++++++++++-------- .../unit_tests/test_train_trainerOneInput.py | 11 +++-- .../utils/data_and_models_for_tests.py | 6 +-- 7 files changed, 66 insertions(+), 52 deletions(-) diff --git a/dwi_ml/data/dataset/multi_subject_containers.py b/dwi_ml/data/dataset/multi_subject_containers.py index 10b0438e..0c6bded5 100644 --- a/dwi_ml/data/dataset/multi_subject_containers.py +++ b/dwi_ml/data/dataset/multi_subject_containers.py @@ -281,7 +281,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None): hdf_handle, subj_id, ref_group_info) # Add subject to the list - logger.debug(" Adding it to the list of subjects.") subj_idx = self.subjs_data_list.add_subject(subj_data) # Arrange streamlines @@ -290,7 +289,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None): if subj_data.is_lazy: subj_data.add_handle(hdf_handle) - logger.debug(" Counting streamlines") for group in range(len(self.streamline_groups)): subj_sft_data = subj_data.sft_data_list[group] n_streamlines = len(subj_sft_data) @@ -302,6 +300,7 @@ def load(self, hdf_handle: h5py.File, subj_id=None): subj_data.hdf_handle = None # Arrange final data properties: Concatenate all subjects + logging.debug("All subjects added. Final verifications.") self.streamline_lengths_mm = \ [np.concatenate(lengths_mm[group], axis=0) for group in range(len(self.streamline_groups))] diff --git a/dwi_ml/data/dataset/single_subject_containers.py b/dwi_ml/data/dataset/single_subject_containers.py index 73afbca6..fbd9b6cc 100644 --- a/dwi_ml/data/dataset/single_subject_containers.py +++ b/dwi_ml/data/dataset/single_subject_containers.py @@ -110,7 +110,8 @@ def init_single_subject_from_hdf( subject_mri_data_list.append(subject_mri_group_data) for group in streamline_groups: - logger.debug(" Loading subject's streamlines") + logger.debug(" Loading streamlines group '{}'" + .format(group)) sft_data = SFTData.init_sft_data_from_hdf_info( hdf_file[subject_id][group]) subject_sft_data_list.append(sft_data) diff --git a/dwi_ml/models/main_models.py b/dwi_ml/models/main_models.py index 23c1f717..596f9e0f 100644 --- a/dwi_ml/models/main_models.py +++ b/dwi_ml/models/main_models.py @@ -99,7 +99,7 @@ def add_args_main_model(p): add_resample_or_compress_arg(p) def set_context(self, context): - assert context in ['training', 'tracking'] + assert context in ['training', 'validation'] self._context = context @property @@ -730,7 +730,7 @@ def __init__(self, dg_key: str = 'cosine-regression', .format(self.positional_encoding_key)) def set_context(self, context): - assert context in ['training', 'tracking', 'visu'] + assert context in ['training', 'validation', 'tracking', 'visu'] self._context = context def instantiate_direction_getter(self, dg_input_size): diff --git a/dwi_ml/training/batch_samplers.py b/dwi_ml/training/batch_samplers.py index 81622af8..56f1bf51 100644 --- a/dwi_ml/training/batch_samplers.py +++ b/dwi_ml/training/batch_samplers.py @@ -43,7 +43,7 @@ def __init__(self, dataset: MultiSubjectDataset, batch_size_validation: Union[int, None], batch_size_units: str, nb_streamlines_per_chunk: int = None, rng: int = None, nb_subjects_per_batch: int = None, - cycles: int = None, log_level=logging.root.level): + cycles: int = None, log_level=logger.root.level): """ Parameters ---------- @@ -56,7 +56,8 @@ def __init__(self, dataset: MultiSubjectDataset, Batch size. Can be defined in number of streamlines or in total length_mm (specified through batch_size_units). batch_size_validation: Union[int, None] - Idem + Idem. If None, it is expected that there will not be a validation + set. batch_size_units: str 'nb_streamlines' or 'length_mm' (which should hopefully be correlated to the number of input data points). @@ -80,9 +81,14 @@ def __init__(self, dataset: MultiSubjectDataset, """ super().__init__(None) # This does nothing but python likes it. + # Batch sampler's logging level can be changed separately from main + # scripts. + logger.setLevel(log_level) + self.logger = logger + # Checking that batch_size is correct for batch_size in [batch_size_training, batch_size_validation]: - if batch_size and batch_size <= 0: + if batch_size is not None and batch_size <= 0: raise ValueError("batch_size (i.e. number of total timesteps " "in the batch) should be a positive int " "value, but got batch_size={}" @@ -90,20 +96,21 @@ def __init__(self, dataset: MultiSubjectDataset, if batch_size_units == 'nb_streamlines': if nb_streamlines_per_chunk is not None: - logging.warning("With a max_batch_size computed in terms of " - "number of streamlines, the chunk size is not " - "used. Ignored") + logger.warning("With a max_batch_size computed in terms of " + "number of streamlines, the chunk size is not " + "used. Ignored.") + nb_streamlines_per_chunk = None elif batch_size_units == 'length_mm': if nb_streamlines_per_chunk is None: - logging.debug("Chunk size was not set. Using default {}" - .format(DEFAULT_CHUNK_SIZE)) + logger.debug("Chunk size was not set. Using default {}" + .format(DEFAULT_CHUNK_SIZE)) else: raise ValueError("batch_size_unit should either be " "'nb_streamlines' or 'length_mm', got {}" .format(batch_size_units)) # Checking that n_volumes was given if cycles was given - if cycles and not nb_subjects_per_batch: + if cycles and nb_subjects_per_batch is None: raise ValueError("If `cycles` is defined, " "`nb_subjects_per_batch` should be defined. Got: " "nb_subjects_per_batch={}, cycles={}" @@ -127,11 +134,6 @@ def __init__(self, dataset: MultiSubjectDataset, self.rng = rng self.np_rng = np.random.RandomState(self.rng) - # Batch sampler's logging level can be changed separately from main - # scripts. - self.logger = logger - self.logger.setLevel(log_level) - # For later use, context self.context = None self.context_subset = None @@ -164,8 +166,8 @@ def init_from_checkpoint(cls, dataset, checkpoint_state: dict, else: batch_sampler = cls(dataset=dataset, **checkpoint_state) - logging.info("Batch sampler's user-defined parameters: " + - format_dict_to_str(batch_sampler.params_for_checkpoint)) + logger.info("Batch sampler's user-defined parameters: " + + format_dict_to_str(batch_sampler.params_for_checkpoint)) return batch_sampler @@ -221,9 +223,11 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]: # This is the list of all possible streamline ids global_streamlines_ids = np.arange( - self.context_subset.total_nb_streamlines[self.streamline_group_idx]) + self.context_subset.total_nb_streamlines[ + self.streamline_group_idx]) ids_per_subjs = \ - self.context_subset.streamline_ids_per_subj[self.streamline_group_idx] + self.context_subset.streamline_ids_per_subj[ + self.streamline_group_idx] # This contains one bool per streamline: # 1 = this streamline has not been used yet. @@ -279,7 +283,8 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]: # Final subject's batch size could be smaller if no streamlines are # left for this subject. - max_batch_size_per_subj = int(self.context_batch_size / nb_subjects) + max_batch_size_per_subj = int( + self.context_batch_size / nb_subjects) if self.batch_size_units == 'nb_streamlines': chunk_size = max_batch_size_per_subj else: diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index c8732275..4355ad0f 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -178,27 +178,6 @@ def __init__(self, self.batch_loader = batch_loader self.model = model - # Create DataLoaders from the BatchSamplers - # * Before usage, context must be set for the batch sampler and the - # batch loader, to use appropriate parameters. - # * Pin memory if interpolation is done by workers; this means that - # dataloader output is on GPU, ready to be fed to the model. - # Otherwise, dataloader output is kept on CPU, and the main thread - # sends volumes and coords on GPU for interpolation. - logger.debug("- Instantiating dataloaders...") - self.train_dataloader = DataLoader( - dataset=self.batch_sampler.dataset.training_set, - batch_sampler=self.batch_sampler, - num_workers=self.nb_cpu_processes, - collate_fn=self.batch_loader.load_batch_streamlines, - pin_memory=self.use_gpu) - self.valid_dataloader = DataLoader( - dataset=self.batch_sampler.dataset.validation_set, - batch_sampler=self.batch_sampler, - num_workers=self.nb_cpu_processes, - collate_fn=self.batch_loader.load_batch_streamlines, - pin_memory=self.use_gpu) - # ---------------------- # Checks # ---------------------- @@ -233,11 +212,38 @@ def __init__(self, "Best practice is to have a validation set.") else: self.use_validation = True + if max_batches_per_epoch_validation is None: + self.max_batches_per_epoch_validation = 1000 if optimizer not in ['SGD', 'Adam', 'RAdam']: raise ValueError("Optimizer choice {} not recognized." .format(optimizer)) + # ---------------- + # Create DataLoaders from the BatchSamplers + # ---------------- + # * Before usage, context must be set for the batch sampler and the + # batch loader, to use appropriate parameters. + # * Pin memory if interpolation is done by workers; this means that + # dataloader output is on GPU, ready to be fed to the model. + # Otherwise, dataloader output is kept on CPU, and the main thread + # sends volumes and coords on GPU for interpolation. + logger.debug("- Instantiating dataloaders...") + self.train_dataloader = DataLoader( + dataset=self.batch_sampler.dataset.training_set, + batch_sampler=self.batch_sampler, + num_workers=self.nb_cpu_processes, + collate_fn=self.batch_loader.load_batch_streamlines, + pin_memory=self.use_gpu) + self.valid_dataloader = None + if self.use_validation: + self.valid_dataloader = DataLoader( + dataset=self.batch_sampler.dataset.validation_set, + batch_sampler=self.batch_sampler, + num_workers=self.nb_cpu_processes, + collate_fn=self.batch_loader.load_batch_streamlines, + pin_memory=self.use_gpu) + # ---------------------- # Evolving values. They will need to be updated if initialized from # checkpoint. diff --git a/dwi_ml/unit_tests/test_train_trainerOneInput.py b/dwi_ml/unit_tests/test_train_trainerOneInput.py index 0a11a105..c8554a5c 100644 --- a/dwi_ml/unit_tests/test_train_trainerOneInput.py +++ b/dwi_ml/unit_tests/test_train_trainerOneInput.py @@ -24,12 +24,16 @@ def experiments_path(tmp_path_factory): def test_trainer_and_models(experiments_path): - data_dir = fetch_testing_data() + # This unit test uses the test data. + data_dir = fetch_testing_data() hdf5_filename = os.path.join(data_dir, 'hdf5_file.hdf5') # Initializing dataset + logging.info("Initializing dataset") dataset = MultiSubjectDataset(hdf5_filename, lazy=False) + + logging.info("Loading data (non-lazy)") dataset.load_data() # Initializing model 1 + associated batch sampler. @@ -57,7 +61,6 @@ def test_trainer_and_models(experiments_path): def _create_sampler_and_loader(dataset, model): # Initialize batch sampler - logging.debug('\nInitializing sampler...') batch_sampler = create_test_batch_sampler( dataset, batch_size=batch_size, batch_size_units='nb_streamlines', log_level=logging.WARNING) @@ -77,7 +80,7 @@ def _create_trainer(batch_sampler, batch_loader, model, experiments_path, model=model, experiments_path=str(experiments_path), experiment_name=experiment_name, log_level='DEBUG', max_batches_per_epoch_training=2, - max_batches_per_epoch_validation=None, max_epochs=2, patience=None, + max_batches_per_epoch_validation=2, max_epochs=2, patience=None, use_gpu=False) # Note. toDo Test fails with nb_cpu_processes=1. Why?? @@ -86,5 +89,5 @@ def _create_trainer(batch_sampler, batch_loader, model, experiments_path, if __name__ == '__main__': tmp_dir = tempfile.TemporaryDirectory() - logging.getLogger().setLevel('INFO') + logging.getLogger().setLevel('DEBUG') test_trainer_and_models(tmp_dir.name) diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index 50e0b6ea..bf80b9ce 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -179,10 +179,10 @@ def create_test_batch_sampler( test_default_cycles = 1 test_default_rng = 1234 - logging.debug(' Initializing batch sampler...') + logging.debug('Initializing batch sampler...') batch_sampler = DWIMLBatchIDSampler( subset, TEST_EXPECTED_STREAMLINE_GROUPS[0], - batch_size_training=batch_size, batch_size_validation=0, + batch_size_training=batch_size, batch_size_validation=batch_size, batch_size_units=batch_size_units, nb_streamlines_per_chunk=chunk_size, rng=test_default_rng, @@ -195,7 +195,7 @@ def create_test_batch_sampler( def create_batch_loader( subset, model, noise_size=0., split_ratio=0., reverse_ratio=0., log_level=logging.DEBUG): - logging.debug(' Initializing batch loader...') + logging.debug('Initializing batch loader...') batch_loader = DWIMLBatchLoaderOneInput( dataset=subset, input_group_name=TEST_EXPECTED_VOLUME_GROUPS[0], streamline_group_name=TEST_EXPECTED_STREAMLINE_GROUPS[0], rng=1234, From ffc4f3a0855ff776cfc22f98ae88a44a194446da Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 14 May 2024 14:19:11 -0400 Subject: [PATCH 4/5] Fix tests for L2T, GV phase, from adding validation set --- dwi_ml/training/batch_loaders.py | 35 +++++++++++++++ .../training/projects/learn2track_trainer.py | 2 +- .../training/projects/transformer_trainer.py | 2 +- .../trainer.py => trainers_withGV.py} | 40 +++++++++-------- dwi_ml/training/utils/batch_loaders.py | 5 +-- dwi_ml/training/with_generation/__init__.py | 0 .../training/with_generation/batch_loader.py | 43 ------------------- .../l2t_resume_training_from_checkpoint.py | 5 +-- scripts_python/l2t_update_deprecated_exp.py | 5 +-- scripts_python/tests/test_all_steps_l2t.py | 1 + .../tt_resume_training_from_checkpoint.py | 5 +-- 11 files changed, 69 insertions(+), 74 deletions(-) rename dwi_ml/training/{with_generation/trainer.py => trainers_withGV.py} (94%) delete mode 100644 dwi_ml/training/with_generation/__init__.py delete mode 100644 dwi_ml/training/with_generation/batch_loader.py diff --git a/dwi_ml/training/batch_loaders.py b/dwi_ml/training/batch_loaders.py index cd9c1f46..2152468b 100644 --- a/dwi_ml/training/batch_loaders.py +++ b/dwi_ml/training/batch_loaders.py @@ -119,6 +119,9 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract, # Find idx of streamline group self.streamline_group_idx = self.dataset.streamline_groups.index( self.streamline_group_name) + self.data_contains_connectivity = \ + self.dataset.streamlines_contain_connectivity[ + self.streamline_group_idx] # Set random numbers self.rng = rng @@ -314,6 +317,38 @@ def load_batch_streamlines( return batch_streamlines, final_s_ids_per_subj + def load_batch_connectivity_matrices( + self, streamline_ids_per_subj: Dict[int, slice]): + if not self.data_contains_connectivity: + raise ValueError("No connectivity matrix in this dataset.") + + # The batch's streamline ids will change throughout processing because + # of data augmentation, so we need to do it subject by subject to + # keep track of the streamline ids. These final ids will correspond to + # the loaded, processed streamlines, not to the ids in the hdf5 file. + subjs = list(streamline_ids_per_subj.keys()) + nb_subjs = len(subjs) + matrices = [None] * nb_subjs + volume_sizes = [None] * nb_subjs + connectivity_nb_blocs = [None] * nb_subjs + connectivity_labels = [None] * nb_subjs + for i, subj in enumerate(subjs): + # No cache for the sft data. Accessing it directly. + # Note: If this is used through the dataloader, multiprocessing + # is used. Each process will open a handle. + subj_data = \ + self.context_subset.subjs_data_list.get_subj_with_handle(subj) + subj_sft_data = subj_data.sft_data_list[self.streamline_group_idx] + + # We could access it only at required index, maybe. Loading the + # whole matrix here. + (matrices[i], volume_sizes[i], + connectivity_nb_blocs[i], connectivity_labels[i]) = \ + subj_sft_data.get_connectivity_matrix_and_info() + + return (matrices, volume_sizes, + connectivity_nb_blocs, connectivity_labels) + class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader): """ diff --git a/dwi_ml/training/projects/learn2track_trainer.py b/dwi_ml/training/projects/learn2track_trainer.py index 4f10474e..ef31ccee 100644 --- a/dwi_ml/training/projects/learn2track_trainer.py +++ b/dwi_ml/training/projects/learn2track_trainer.py @@ -9,7 +9,7 @@ from dwi_ml.models.projects.learn2track_model import Learn2TrackModel from dwi_ml.tracking.io_utils import prepare_tracking_mask from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.training.with_generation.trainer import \ +from dwi_ml.training.trainers_withGV import \ DWIMLTrainerForTrackingOneInput logger = logging.getLogger('trainer_logger') diff --git a/dwi_ml/training/projects/transformer_trainer.py b/dwi_ml/training/projects/transformer_trainer.py index 62f40ed4..295b5feb 100644 --- a/dwi_ml/training/projects/transformer_trainer.py +++ b/dwi_ml/training/projects/transformer_trainer.py @@ -9,7 +9,7 @@ from dwi_ml.tracking.io_utils import prepare_tracking_mask from dwi_ml.tracking.propagation import propagate_multiple_lines -from dwi_ml.training.with_generation.trainer import \ +from dwi_ml.training.trainers_withGV import \ DWIMLTrainerForTrackingOneInput diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/trainers_withGV.py similarity index 94% rename from dwi_ml/training/with_generation/trainer.py rename to dwi_ml/training/trainers_withGV.py index bac638db..a0aebfcb 100644 --- a/dwi_ml/training/with_generation/trainer.py +++ b/dwi_ml/training/trainers_withGV.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Adds a tracking step to verify the generation process. Metrics on the -streamlines are: +Adds a generation-validation phase: a tracking step. Metrics on the streamlines +are: - Very good / acceptable / very far IS threshold: Percentage of streamlines ending inside a radius of 15 / 25 / 40 voxels of @@ -47,10 +47,9 @@ from dwi_ml.models.main_models import ModelWithDirectionGetter from dwi_ml.tracking.propagation import propagate_multiple_lines from dwi_ml.tracking.io_utils import prepare_tracking_mask +from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput from dwi_ml.training.trainers import DWIMLTrainerOneInput from dwi_ml.training.utils.monitoring import BatchHistoryMonitor -from dwi_ml.training.with_generation.batch_loader import \ - DWIMLBatchLoaderWithConnectivity logger = logging.getLogger('train_logger') @@ -65,7 +64,7 @@ class DWIMLTrainerForTrackingOneInput(DWIMLTrainerOneInput): model: ModelWithDirectionGetter - batch_loader: DWIMLBatchLoaderWithConnectivity + batch_loader: DWIMLBatchLoaderOneInput def __init__(self, add_a_tracking_validation_phase: bool = False, tracking_phase_frequency: int = 1, @@ -105,6 +104,12 @@ def __init__(self, add_a_tracking_validation_phase: bool = False, self.compute_connectivity = self.batch_loader.data_contains_connectivity + # -------- Checks + if add_a_tracking_validation_phase and \ + tracking_phase_mask_group is None: + raise NotImplementedError("Not ready to run without a tracking " + "mask.") + # -------- Monitors # At training time: only the one metric used for training. # At validation time: A lot of exploratory metrics monitors. @@ -177,8 +182,7 @@ def validate_one_batch(self, data, epoch): (gen_n, mean_final_dist, mean_clipped_final_dist, percent_IS_very_good, percent_IS_acceptable, percent_IS_very_far, diverging_pnt, connectivity) = \ - self.validation_generation_one_batch( - data, compute_all_scores=True) + self.gv_phase_one_batch(data, compute_all_scores=True) self.tracking_very_good_IS_monitor.update( percent_IS_very_good, weight=gen_n) @@ -194,8 +198,9 @@ def validate_one_batch(self, data, epoch): self.tracking_valid_diverg_monitor.update( diverging_pnt, weight=gen_n) - self.tracking_connectivity_score_monitor.update( - connectivity, weight=gen_n) + if self.compute_connectivity: + self.tracking_connectivity_score_monitor.update( + connectivity, weight=gen_n) elif len(self.tracking_mean_final_distance_monitor.average_per_epoch) == 0: logger.info("Skipping tracking-like generation validation " "from batch. No values yet: adding fake initial " @@ -216,7 +221,8 @@ def validate_one_batch(self, data, epoch): self.tracking_clipped_final_distance_monitor.update( ACCEPTABLE_THRESHOLD) - self.tracking_connectivity_score_monitor.update(1) + if self.compute_connectivity: + self.tracking_connectivity_score_monitor.update(1) else: logger.info("Skipping tracking-like generation validation " "from batch. Copying previous epoch's values.") @@ -230,7 +236,7 @@ def validate_one_batch(self, data, epoch): self.tracking_connectivity_score_monitor]: monitor.update(monitor.average_per_epoch[-1]) - def validation_generation_one_batch(self, data, compute_all_scores=False): + def gv_phase_one_batch(self, data, compute_all_scores=False): """ Use tractography to generate streamlines starting from the "true" seeds and first few segments. Expected results are the batch's @@ -304,12 +310,12 @@ def validation_generation_one_batch(self, data, compute_all_scores=False): total_point += abs(100 - div_point) diverging_point = total_point / len(lines) - invalid_ratio_severe = invalid_ratio_severe.cpu().numpy().astype(np.float32) - invalid_ratio_acceptable = invalid_ratio_acceptable.cpu().numpy().astype(np.float32) - invalid_ratio_loose = invalid_ratio_loose.cpu().numpy().astype(np.float32) - final_dist = final_dist.cpu().numpy().astype(np.float32) - final_dist_clipped = final_dist_clipped.cpu().numpy().astype(np.float32) - diverging_point = np.asarray(diverging_point, dtype=np.float32) + invalid_ratio_severe = invalid_ratio_severe.item() + invalid_ratio_acceptable = invalid_ratio_acceptable.item() + invalid_ratio_loose = invalid_ratio_loose.item() + final_dist = final_dist.item() + final_dist_clipped = final_dist_clipped.item() + return (len(lines), final_dist, final_dist_clipped, invalid_ratio_severe, invalid_ratio_acceptable, invalid_ratio_loose, diverging_point, diff --git a/dwi_ml/training/utils/batch_loaders.py b/dwi_ml/training/utils/batch_loaders.py index ff8b4166..24df17a1 100644 --- a/dwi_ml/training/utils/batch_loaders.py +++ b/dwi_ml/training/utils/batch_loaders.py @@ -4,8 +4,7 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.training.with_generation.batch_loader import \ - DWIMLBatchLoaderWithConnectivity +from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput def add_args_batch_loader(p: argparse.ArgumentParser): @@ -44,7 +43,7 @@ def add_args_batch_loader(p: argparse.ArgumentParser): def prepare_batch_loader(dataset, model, args, sub_loggers_level): # Preparing the batch loader. with Timer("\nPreparing batch loader...", newline=True, color='pink'): - batch_loader = DWIMLBatchLoaderWithConnectivity( + batch_loader = DWIMLBatchLoaderOneInput( dataset=dataset, model=model, input_group_name=args.input_group_name, streamline_group_name=args.streamline_group_name, diff --git a/dwi_ml/training/with_generation/__init__.py b/dwi_ml/training/with_generation/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/dwi_ml/training/with_generation/batch_loader.py b/dwi_ml/training/with_generation/batch_loader.py deleted file mode 100644 index 77db3f8e..00000000 --- a/dwi_ml/training/with_generation/batch_loader.py +++ /dev/null @@ -1,43 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Dict - -from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput - - -class DWIMLBatchLoaderWithConnectivity(DWIMLBatchLoaderOneInput): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.data_contains_connectivity = \ - self.dataset.streamlines_contain_connectivity[self.streamline_group_idx] - - def load_batch_connectivity_matrices( - self, streamline_ids_per_subj: Dict[int, slice]): - if not self.data_contains_connectivity: - raise ValueError("No connectivity matrix in this dataset.") - - # The batch's streamline ids will change throughout processing because - # of data augmentation, so we need to do it subject by subject to - # keep track of the streamline ids. These final ids will correspond to - # the loaded, processed streamlines, not to the ids in the hdf5 file. - subjs = list(streamline_ids_per_subj.keys()) - nb_subjs = len(subjs) - matrices = [None] * nb_subjs - volume_sizes = [None] * nb_subjs - connectivity_nb_blocs = [None] * nb_subjs - connectivity_labels = [None] * nb_subjs - for i, subj in enumerate(subjs): - # No cache for the sft data. Accessing it directly. - # Note: If this is used through the dataloader, multiprocessing - # is used. Each process will open a handle. - subj_data = \ - self.context_subset.subjs_data_list.get_subj_with_handle(subj) - subj_sft_data = subj_data.sft_data_list[self.streamline_group_idx] - - # We could access it only at required index, maybe. Loading the - # whole matrix here. - (matrices[i], volume_sizes[i], - connectivity_nb_blocs[i], connectivity_labels[i]) = \ - subj_sft_data.get_connectivity_matrix_and_info() - - return (matrices, volume_sizes, - connectivity_nb_blocs, connectivity_labels) diff --git a/scripts_python/l2t_resume_training_from_checkpoint.py b/scripts_python/l2t_resume_training_from_checkpoint.py index 754d6a26..21573424 100644 --- a/scripts_python/l2t_resume_training_from_checkpoint.py +++ b/scripts_python/l2t_resume_training_from_checkpoint.py @@ -13,12 +13,11 @@ from dwi_ml.experiment_utils.timer import Timer from dwi_ml.io_utils import add_verbose_arg from dwi_ml.models.projects.learn2track_model import Learn2TrackModel +from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer from dwi_ml.training.utils.experiment import add_args_resuming_experiment from dwi_ml.training.utils.trainer import run_experiment -from dwi_ml.training.with_generation.batch_loader import \ - DWIMLBatchLoaderWithConnectivity def prepare_arg_parser(): @@ -60,7 +59,7 @@ def init_from_checkpoint(args, checkpoint_path): dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level) # Prepare batch loader - batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint( + batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint( dataset, model, checkpoint_state['batch_loader_params'], sub_loggers_level) diff --git a/scripts_python/l2t_update_deprecated_exp.py b/scripts_python/l2t_update_deprecated_exp.py index ee6288eb..1f3304dd 100644 --- a/scripts_python/l2t_update_deprecated_exp.py +++ b/scripts_python/l2t_update_deprecated_exp.py @@ -18,10 +18,9 @@ from dwi_ml.experiment_utils.prints import format_dict_to_str from dwi_ml.io_utils import add_verbose_arg from dwi_ml.models.projects.learn2track_model import Learn2TrackModel +from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer -from dwi_ml.training.with_generation.batch_loader import \ - DWIMLBatchLoaderWithConnectivity def prepare_arg_parser(): @@ -217,7 +216,7 @@ def fix_checkpoint(args, model): # Init stuff will succeed if ok. batch_sampler = DWIMLBatchIDSampler.init_from_checkpoint( dataset, checkpoint_state['batch_sampler_params']) - batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint( + batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint( dataset, model, checkpoint_state['batch_loader_params']) experiments_path, experiment_name = os.path.split(args.out_experiment) trainer = Learn2TrackTrainer.init_from_checkpoint( diff --git a/scripts_python/tests/test_all_steps_l2t.py b/scripts_python/tests/test_all_steps_l2t.py index 67f38816..0b3210dd 100644 --- a/scripts_python/tests/test_all_steps_l2t.py +++ b/scripts_python/tests/test_all_steps_l2t.py @@ -160,5 +160,6 @@ def test_training_with_generation_validation(script_runner, experiments_path): '--max_batches_per_epoch_validation', '1', '-v', 'INFO', '--step_size', '0.5', '--add_a_tracking_validation_phase', + '--tracking_mask', 'wm_mask', '--tracking_phase_frequency', '1', option) assert ret.success diff --git a/scripts_python/tt_resume_training_from_checkpoint.py b/scripts_python/tt_resume_training_from_checkpoint.py index b78e5926..191d0877 100644 --- a/scripts_python/tt_resume_training_from_checkpoint.py +++ b/scripts_python/tt_resume_training_from_checkpoint.py @@ -13,12 +13,11 @@ from dwi_ml.experiment_utils.timer import Timer from dwi_ml.io_utils import add_verbose_arg, verify_which_model_in_path from dwi_ml.models.projects.transformer_models import find_transformer_class +from dwi_ml.training.batch_loaders import DWIMLBatchLoaderOneInput from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler from dwi_ml.training.projects.transformer_trainer import TransformerTrainer from dwi_ml.training.utils.experiment import add_args_resuming_experiment from dwi_ml.training.utils.trainer import run_experiment -from dwi_ml.training.with_generation.batch_loader import \ - DWIMLBatchLoaderWithConnectivity def prepare_arg_parser(): @@ -62,7 +61,7 @@ def init_from_checkpoint(args, checkpoint_path): dataset, checkpoint_state['batch_sampler_params'], sub_loggers_level) # Prepare batch loader - batch_loader = DWIMLBatchLoaderWithConnectivity.init_from_checkpoint( + batch_loader = DWIMLBatchLoaderOneInput.init_from_checkpoint( dataset, model, checkpoint_state['batch_loader_params'], sub_loggers_level) From e0a074fd1ae51469ea2d225067d3c9e679578dc4 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 15 May 2024 14:19:54 -0400 Subject: [PATCH 5/5] Remove GPU option in test --- dwi_ml/unit_tests/utils/data_and_models_for_tests.py | 2 +- scripts_python/tests/test_all_steps_l2t.py | 7 +------ 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py index bf80b9ce..f1bcf6c0 100644 --- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py +++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py @@ -31,7 +31,7 @@ def fetch_testing_data(): # Access to the file dwi_ml.zip: # https://drive.google.com/uc?id=1beRWAorhaINCncttgwqVAP2rNOfx842Q name_as_dict = { - 'data_for_tests_dwi_ml.zip': "da6c94fbef7ac13029acdb8b94325096"} + 'data_for_tests_dwi_ml.zip': "59c9275d2fe83b7e2d6154877ab32b8b"} fetch_data(name_as_dict) return testing_data_dir diff --git a/scripts_python/tests/test_all_steps_l2t.py b/scripts_python/tests/test_all_steps_l2t.py index 0b3210dd..fc2d9311 100644 --- a/scripts_python/tests/test_all_steps_l2t.py +++ b/scripts_python/tests/test_all_steps_l2t.py @@ -143,11 +143,6 @@ def test_visu(script_runner, experiments_path): def test_training_with_generation_validation(script_runner, experiments_path): - if torch.cuda.is_available(): - option = '--use_gpu' - else: - option = '' - logging.info("************ TESTING TRAINING WITH GENERATION ************") experiment_name = 'test2' ret = script_runner.run('l2t_train_model.py', @@ -161,5 +156,5 @@ def test_training_with_generation_validation(script_runner, experiments_path): '-v', 'INFO', '--step_size', '0.5', '--add_a_tracking_validation_phase', '--tracking_mask', 'wm_mask', - '--tracking_phase_frequency', '1', option) + '--tracking_phase_frequency', '1') assert ret.success