From 914ebaa98b46971951f957fa2a39fc3fae242481 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 15 Jan 2024 14:33:45 -0500 Subject: [PATCH 01/23] Update hdf5 doc --- dwi_ml/data/hdf5/hdf5_creation.py | 24 +++--- dwi_ml/data/hdf5/utils.py | 27 ++++--- dwi_ml/io_utils.py | 3 +- scripts_python/dwiml_create_hdf5_dataset.py | 21 ++--- source/2_A_creating_the_hdf5.rst | 80 ++++++++----------- ...rst => 2_B_advanced_hdf5_organization.rst} | 17 ++-- source/2_B_preprocessing.rst | 69 ---------------- 7 files changed, 83 insertions(+), 158 deletions(-) rename source/{2_C_advanced_hdf5_organization.rst => 2_B_advanced_hdf5_organization.rst} (63%) delete mode 100644 source/2_B_preprocessing.rst diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index e46e534b..c2e56f3d 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -569,7 +569,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, streamlines_group.create_dataset( 'connectivity_matrix', data=connectivity_matrix) if conn_info[0] == 'from_label': - streamlines_group.attrs['connectivity_labels_volume'] = \ + streamlines_group.attrs['connectivity_label_volume'] = \ conn_info[1] else: streamlines_group.attrs['connectivity_nb_blocs'] = \ @@ -692,30 +692,30 @@ def _process_one_streamline_group( conn_info = None if 'connectivity_matrix' in self.groups_config[group]: logging.info(" Now preparing connectivity matrix") - if not ("connectivty_nb_blocs" in self.groups_config[group] or - "connectivty_labels" in self.groups_config[group]): + if not ("connectivity_nb_blocs" in self.groups_config[group] or + "connectivity_labels" in self.groups_config[group]): raise ValueError( "The config file must provide either the " - "connectivty_nb_blocs or the connectivty_labels information " + "connectivity_nb_blocs or the connectivity_labels option " "associated with the streamline group '{}'" .format(group)) - elif ("connectivty_nb_blocs" in self.groups_config[group] and - "connectivty_labels" in self.groups_config[group]): + elif ("connectivity_nb_blocs" in self.groups_config[group] and + "connectivity_labels" in self.groups_config[group]): raise ValueError( "The config file must only provide ONE of the " - "connectivty_nb_blocs or the connectivty_labels information " + "connectivity_nb_blocs or the connectivity_labels option " "associated with the streamline group '{}'" .format(group)) - elif "connectivty_nb_blocs" in self.groups_config[group]: + elif "connectivity_nb_blocs" in self.groups_config[group]: nb_blocs = format_nb_blocs_connectivity( - self.groups_config[group]['connectivty_nb_blocs']) + self.groups_config[group]['connectivity_nb_blocs']) conn_info = ['from_blocs', nb_blocs] else: - labels = self.groups_config[group]['connectivty_labels'] - if labels not in self.volume_groups: + labels_group = self.groups_config[group]['connectivity_labels'] + if labels_group not in self.volume_groups: raise ValueError("connectivity_labels_volume must be " "an existing volume group.") - conn_info = ['from_labels', labels] + conn_info = ['from_labels', labels_group] conn_file = subj_dir.joinpath( self.groups_config[group]['connectivity_matrix']) diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py index f0b84fe5..a8818deb 100644 --- a/dwi_ml/data/hdf5/utils.py +++ b/dwi_ml/data/hdf5/utils.py @@ -6,6 +6,12 @@ def format_nb_blocs_connectivity(connectivity_nb_blocs) -> List: + """ + Convert the raw option for connectivity into a list of 3 values. + Ex: [10, 20, 10] is returned without modification. + Ex: 20 becomes [20, 20, 20] + With other values (ex, a list of <>3 values), an error is raised. + """ if connectivity_nb_blocs is None: # Default/const value with argparser '+' not possible. # Setting it manually. @@ -50,14 +56,14 @@ def add_hdf5_creation_args(p: ArgumentParser): "-> https://dwi-ml.readthedocs.io/en/latest/" "creating_hdf5.html") p.add_argument('training_subjs', - help="txt file containing the list of subjects ids to use " - "for training.") + help="A txt file containing the list of subjects ids to " + "use for training. \n(Can be an empty file.)") p.add_argument('validation_subjs', - help="txt file containing the list of subjects ids to use " - "for validation.") + help="A txt file containing the list of subjects ids to use " + "for validation. \n(Can be an empty file.)") p.add_argument('testing_subjs', - help="txt file containing the list of subjects ids to use " - "for testing.") + help="A txt file containing the list of subjects ids to use " + "for testing. \n(Can be an empty file.)") # Optional arguments p.add_argument('--enforce_files_presence', type=bool, default=True, @@ -68,12 +74,9 @@ def add_hdf5_creation_args(p: ArgumentParser): p.add_argument('--save_intermediate', action="store_true", help="If set, save intermediate processing files for " "each subject inside the \nhdf5 folder, in sub-" - "folders named subjid_intermediate.") - - p.add_argument('--logging', - choices=['error', 'warning', 'info', 'debug'], - default='warning', - help="Logging level. [warning]") + "folders named subjid_intermediate.\n" + "(Final concatenated standardized volumes and \n" + "final concatenated resampled/compressed streamlines.)") def add_mri_processing_args(p: ArgumentParser): diff --git a/dwi_ml/io_utils.py b/dwi_ml/io_utils.py index 752f8d4c..72543170 100644 --- a/dwi_ml/io_utils.py +++ b/dwi_ml/io_utils.py @@ -10,7 +10,8 @@ def add_logging_arg(p): '--logging', default='WARNING', metavar='level', choices=['ERROR', 'WARNING', 'INFO', 'DEBUG'], help="Logging level. Note that, for readability, not all debug logs \n" - "are printed in DEBUG mode, only the main ones.") + "are printed in DEBUG mode, only the main ones. \n" + "Default: WARNING.") def add_resample_or_compress_arg(p: ArgumentParser): diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py index 3b5f16d7..611b2bb2 100644 --- a/scripts_python/dwiml_create_hdf5_dataset.py +++ b/scripts_python/dwiml_create_hdf5_dataset.py @@ -2,16 +2,15 @@ # -*- coding: utf-8 -*- """ This script combines multiple diffusion MRI volumes and their streamlines into -a single .hdf5 file. A hdf5 folder will be created alongside dwi_ml_ready. It -will contain the .hdf5 file and possibly intermediate files. - -** You should have a file dwi_ml_ready organized as described in our doc: -https://dwi-ml.readthedocs.io/en/latest/data_organization.html - -** You should have a config file as described in our doc: -https://dwi-ml.readthedocs.io/en/latest/config_file.html - +a single .hdf5 file. +-------------------------------------- +See here for the complete explanation! + - How to organize your data + - How to prepare the config file + - How to run this script. + https://dwi-ml.readthedocs.io/en/latest/2_B_preprocessing.html +-------------------------------------- ** Note: The memory is a delicate question here, but checks have been made, and it appears that the SFT's garbage collector may not be working entirely well. @@ -36,6 +35,7 @@ add_hdf5_creation_args, add_mri_processing_args, add_streamline_processing_args) from dwi_ml.experiment_utils.timer import Timer +from dwi_ml.io_utils import add_logging_arg def _initialize_intermediate_subdir(hdf5_file, save_intermediate): @@ -104,6 +104,7 @@ def _parse_args(): add_mri_processing_args(p) add_streamline_processing_args(p) add_overwrite_arg(p) + add_logging_arg(p) return p @@ -114,7 +115,7 @@ def main(): args = p.parse_args() # Initialize logger - logging.getLogger().setLevel(level=str(args.logging).upper()) + logging.getLogger().setLevel(level=args.logging) # Silencing SFT's logger if our logging is in DEBUG mode, because it # typically produces a lot of outputs! diff --git a/source/2_A_creating_the_hdf5.rst b/source/2_A_creating_the_hdf5.rst index 18755dd7..86abf217 100644 --- a/source/2_A_creating_the_hdf5.rst +++ b/source/2_A_creating_the_hdf5.rst @@ -75,7 +75,8 @@ To create the hdf5 file, you will need a config file such as below. HDF groups w "input": { "type": "volume", "files": ["dwi/dwi.nii.gz", "anat/t1.nii.gz", "dwi/*__dwi.nii.gz], --> Will get, for instance, subX__dwi.nii.gz - "standardization": "all" + "standardization": "all", + "std_mask": [masks/some_mask.nii.gz] }, "target": { "type": "streamlines", @@ -94,59 +95,59 @@ To create the hdf5 file, you will need a config file such as below. HDF groups w } } -General group attributes in the config file: - -- The group's **name** could be 'input_volume', 'target_volume', 'target_directions', or anything. +| - - We will see further how to tell your model and your batch loader the group names of interest. +General group attributes in the config file: +"""""""""""""""""""""""""""""""""""""""""""" -- The group's **"files"** must exist in every subject folder inside a repository. That is: the files must be organized correctly on your computer. See (except if option 'enforce_files_presence is set to False). +Each group key will become the group's **name** in the hdf5. It can be anything you want. We suggest you keep it significative, ex 'input_volume', 'target_volume', 'target_directions'. In other scripts (ex, l2t_train_model.py, tt_train_model.py, etc), you will often be asked for the labels given to your groups. - - There is the possibility to add a wildcard (*) that will be replaced by the subject's id while loading. Ex: anat/\*__t1.nii.gz would become anat/subjX__t1.nii.gz. - - For streamlines, there is the possibility to use 'ALL' to load all tractograms present in a folder. - - The files from each group will be concatenated in the hdf5 (either as a final volume or as a final tractogram). +Each group may have a number of parameters: -- The groups **"type"** must be recognized in dwi_ml. Currently, accepted datatype are: + - **"type"**: It must be recognized in dwi_ml. Currently, accepted datatype are: - - 'volume': for instance, a dwi, an anat, mask, t1, fa, etc. - - 'streamlines': for instance, a .trk, .tck file (anything accepted by Dipy's Stateful Tractogram). + - 'volume': for instance, a dwi, an anat, mask, t1, fa, etc. + - 'streamlines': for instance, a .trk, .tck file (any format accepted by Dipy's *Stateful Tractogram*). -Additional attribute for volume groups: + - **"files"**: The listed file(s) must exist in every subject folder inside the root repository. That is: the files must be organized correctly on your computer (except if option 'enforce_files_presence is set to False). If there are more than one files, they will be concatenated (on the 4th dimension for volumes, using the union of tractograms for streamlines). -- The groups **"standardization"** must be one of: + - There is the possibility to add a wildcard (\*) that will be replaced by the subject's id while loading. Ex: anat/\*__t1.nii.gz would become anat/subjX__t1.nii.gz. + - For streamlines, there is the possibility to use 'ALL' to load all tractograms present in a folder. - - "all", to apply standardization (normalization) to the final (concatenated) file - - "independent", to apply it independently on the last dimension of the data (ex, for a fODF, it would apply it independently on each SH). - - "per_file", to apply it independently on each file included in the group. - - "none", to skip this step (ex: for binary masks, which must stay binary). +Additional attributes for volume groups: +"""""""""""""""""""""""""""""""""""""""" -****A note about data standardization** + - **std_mask**: The name of the standardization mask. Data is standardized (normalized) during data creation: data = (data - mean_in_mask) / std_in_mask. If more than one files are given, the union (logical_or) of all masks is used (ex of usage: ["masks/wm_mask.nii.gz", "masks/gm_mask.nii.gz"] would use a mask of all the brain). - Data is standardized (normalized) during data creation: data = (data - mean) / std. + - **"standardization"**: It defined the standardization option applied to the volume group. It must be one of: - If all voxel were to be used, most of them would probably contain the background of the data, bringing the mean and std probably very close to 0. Thus, non-zero voxels only are used to compute the mean and std, or voxels inside the provided mask if any. If a mask is provided, voxels outside the mask could have been set to NaN, but the simpler choice made here was to simply modify all voxels [ data = (data - mean) / std ], even voxels outside the mask, with the mean and std of voxels in the mask. Mask name for each subject is provided using --std_mask in the script create_hdf5_dataset.py. + - "all", to apply standardization (normalization) to the final (concatenated) file. + - "independent", to apply it independently on the last dimension of the data (ex, for a fODF, it would apply it independently on each SH). + - "per_file", to apply it independently on each file included in the group. + - "none", to skip this step (default) +****A note about data standardization** -Additional attribute for streamlines groups: +If all voxel were to be used, most of them would probably contain the background of the data, bringing the mean and std probably very close to 0. Thus, non-zero voxels only are used to compute the mean and std, or voxels inside the provided mask if any. If a mask is provided, voxels outside the mask could have been set to NaN, but the simpler choice made here was to simply modify all voxels [ data = (data - mean) / std ], even voxels outside the mask, with the mean and std of voxels in the mask. Mask name is provided through the config file. It is formatted as a list: if many files are listed, the union of the binary masks will be used. - - connectivity_nb_blocs: See dwiml_compute_connectivity_matrix_from_blocs for a description. - OR +Additional attributes for streamlines groups: +""""""""""""""""""""""""""""""""""""""""""""" - - connectivity_labels: The name of one volume group. + - **connectivity_matrix**: The name of the connectivity matrix to associate to the streamline group. This matrix will probably be used as a mean of validation during training. Then, you also need to explain how the matrix was created, so that you can create the connectivity matrix of the streamlines being validated, in order to compare it with the expected result. ONE of the two next options must be given: + - **connectivity_nb_blocs**: This explains that the connectivity matrix was created by dividing the volume space into regular blocs. See dwiml_compute_connectivity_matrix_from_blocs for a description. The value should be either an integers or a list of three integers. + - **connectivity_labels_volume**: This explains that the connectivity matrix was created by dividing the cortex into a list of regions associated with labels. The value must be the name of another volume group in the same config file, which refers to a map with one label per region. NOTE: This option is offered in preparation of future use only. Currently, you can create the hdf5 with this option, but connectivity computation using labels is not yet implemented in dwi_ml. 2.4. Creating the hdf5 ********************** -You will use the **create_hdf5_dataset.py** script to create a hdf5 file. You need to prepare config files to use this script (see :ref:`ref_config_file`). - -Exemple of use: (See also please_copy_and_adapt/ALL_STEPS.sh) +You will use the **dwiml_create_hdf5_dataset.py** script to create a hdf5 file. .. code-block:: bash dwi_ml_folder=YOUR_PATH - hdf5_folder=YOUR_PATH + hdf5_file=YOUR_OUT_FILE.hdf5 config_file=YOUR_FILE.json training_subjs=YOUR_FILE.txt validation_subjs=YOUR_FILE.txt @@ -154,28 +155,11 @@ Exemple of use: (See also please_copy_and_adapt/ALL_STEPS.sh) dwiml_create_hdf5_dataset.py --name $name --std_mask $mask --space $space \ --enforce_files_presence True \ - $dwi_ml_folder $hdf5_folder $config_file \ + $dwi_ml_folder $hdf5_file $config_file \ $training_subjs $validation_subjs $testing_subjs .. toctree:: :maxdepth: 1 :caption: Detailed explanations for developers: - 2_C_advanced_hdf5_organization - - -P.S How to get data? -******************** - -Here in the SCIL lab, we often suggest to use the Tractoflow pipeline to process your data. If you need help for the pre-processing and reorgnization of your database, consult the following pages: - -.. toctree:: - :maxdepth: 2 - - 2_B_preprocessing - - -Organizing data from tractoflow to dwi_ml_ready ------------------------------------------------ - -If you used tractoflow to preprocess your data, you may organize automatically the dwi_ml_ready folder. We have started to prepare a script for you, which you can find in bash_utilities/**organizse_from_tractoflow.sh**, which creates symlinks between your tractoflow results and a dwi_ml_ready folder. However, Tractoflow may have changed since we create this help, filenames could not correspond to your files. We encourage you to modify this script in your own project depending on your needs. + 2_B_advanced_hdf5_organization diff --git a/source/2_C_advanced_hdf5_organization.rst b/source/2_B_advanced_hdf5_organization.rst similarity index 63% rename from source/2_C_advanced_hdf5_organization.rst rename to source/2_B_advanced_hdf5_organization.rst index 92eda1e0..56a8dbbc 100644 --- a/source/2_C_advanced_hdf5_organization.rst +++ b/source/2_B_advanced_hdf5_organization.rst @@ -3,7 +3,7 @@ The hdf5 organization ===================== -Here is the output format created by create_hdf5_dataset.py and recognized by our scripts: +Here is the output format created by dwiml_create_hdf5_dataset.py and recognized by our scripts: .. code-block:: bash @@ -12,24 +12,29 @@ Here is the output format created by create_hdf5_dataset.py and recognized by ou hdf5.attrs['validation_subjs'] = the list of str representing the validation subjects. hdf5.attrs['testing_subjs'] = the list of str representing the testing subjects. - hdf5.keys() are the subjects. + # hdf5.keys() are the subjects. hdf5['subj1'].keys() are the groups from the config_file. hdf5['subj1']['group1'].attrs['type'] = 'volume' or 'streamlines'. hdf5['subj1']['group1']['data'] is the data. - For streamlines, other available data: - (from the data:) + # For streamlines, other available data: + # (from the data:) hdf5['subj1']['group1']['offsets'] hdf5['subj1']['group1']['lengths'] hdf5['subj1']['group1']['euclidean_lengths'] - (from the space attributes:) + # (from the space attributes:) hdf5['subj1']['group1']['space'] hdf5['subj1']['group1']['affine'] hdf5['subj1']['group1']['dimensions'] hdf5['subj1']['group1']['voxel_sizes'] hdf5['subj1']['group1']['voxel_order'] + # (others:) + hdf5['subj1']['group1']['connectivity_matrix'] + hdf5['subj1']['group1']['connectivity_matrix_type'] = 'from_blocs' or 'from_labels' + hdf5['subj1']['group1']['connectivity_label_volume'] (the labels' volume group) OR + hdf5['subj1']['group1']['connectivity_nb_blocs'] (a list of three integers) - For volumes, other available data: + # For volumes, other available data: hdf5['sub1']['group1']['affine'] hdf5['sub1']['group1']['voxres'] hdf5['sub1']['group1']['nb_features'] diff --git a/source/2_B_preprocessing.rst b/source/2_B_preprocessing.rst deleted file mode 100644 index 2f68a6a1..00000000 --- a/source/2_B_preprocessing.rst +++ /dev/null @@ -1,69 +0,0 @@ -.. _ref_preprocessing: - -Preprocessing your data (using scilpy) -====================================== - -Running the whole tractography process -************************************** - -We suggest using scilpy's `tractoflow `_ to preprocess dwi data and create tractograms. - -Obtaining clean bundles -*********************** - -Here is an example of steps that could be useful to preprocess bundles. Here, we consider that tractoflow has already been ran. - -Separating your tractogram into bundles -''''''''''''''''''''''''''''''''''''''' - -You might want to separate some good bundles from the whole-brain tractogram (ex, for bundle-specific tractography algorithms, BST, or simply to ensure that you train your ML algorithm on true-positive streamlines only). - -One possible technique to create bundles is to simply regroup the streamlines that are close based on some metric (ex, the MDF). See `scil_compute_qbx.py `_. - -However, separating data into known bundles (from an atlas) is probably a better way to clean your tractogram and to remove all possibly false positive streamlines. For a clustering based on atlases, Dipy offers Recobundles, or you can use scilpy's version RecobundlesX, which is a little different. You will need bundle models and their associated json file. You may check `scilpy's doc `_ RecobundlesX tab for a basic bash script example. - - -Tractseg -'''''''' - -`Tractseg `_ is one of the most used published techniques using machine learning for diffusion. If you want to compare your work with theirs, you might want to use their bundles. Here is how to use it: - - .. code-block:: bash - - TractSeg -i YOUR_DATA -o OUT_NAME --bvals original/SUBJ/bval \ - --bvecs original/SUBJ/bvec --raw_diffusion_input \ - --brain_mask preprocessed/SUBJ/some_mask \ - --output_type endings_segmentation --csd_type csd_msmt_5tt - -Other tools and tricks -*********************** - - -Tractogram conversions -'''''''''''''''''''''' - -Here is how to convert from trk to tck: - - .. code-block:: bash - - scil_convert_tractogram.py TRK_FILE TCK_OUT_NAME \ - --reference processed/SUBJ/some_ref.nii.gz - -Bundle masks -'''''''''''' - -Here is how to create a mask of voxels touched by a bundle: - - .. code-block:: bash - - scil_compute_density_map_from_streamlines.py BUNDLE.tck \ - preprocessed/subj/DTI_Metrics/SUBJ__fa.nii.gz OUT_NAME --binary - -Merging bundles -''''''''''''''' - -Here is how you can merge bundles together: - - .. code-block:: bash - - scil_mask_math.py union ALL_BUNDLES preprocessed/SUBJ/some_mask.nii.gz From d415173b86454916a4bf625a92a58a31b30753bb Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 15 Jan 2024 09:46:26 -0500 Subject: [PATCH 02/23] Move std mask option to config file --- dwi_ml/data/hdf5/hdf5_creation.py | 98 ++++++++++----------- dwi_ml/data/hdf5/utils.py | 14 +-- scripts_python/dwiml_create_hdf5_dataset.py | 8 +- 3 files changed, 49 insertions(+), 71 deletions(-) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index c2e56f3d..f5d95eed 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -107,8 +107,7 @@ class HDF5Creator: def __init__(self, root_folder: Path, out_hdf_filename: Path, training_subjs: List[str], validation_subjs: List[str], testing_subjs: List[str], groups_config: dict, - std_mask: str, step_size: float = None, - compress: float = None, + step_size: float = None, compress: float = None, enforce_files_presence: bool = True, save_intermediate: bool = False, intermediate_folder: Path = None): @@ -126,8 +125,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, List of subject names for each data set. groups_config: dict Information from json file loaded as a dict. - std_mask: str - Name of the standardization mask inside each subject's folder. step_size: float Step size to resample streamlines. Default: None. compress: float @@ -152,7 +149,6 @@ def __init__(self, root_folder: Path, out_hdf_filename: Path, self.compress = compress # Optional - self.std_mask = std_mask # (could be None) self.save_intermediate = save_intermediate self.enforce_files_presence = enforce_files_presence self.intermediate_folder = intermediate_folder @@ -295,20 +291,11 @@ def _check_files_presence(self): config_file_list = sum(nested_lookup('files', self.groups_config), []) config_file_list += nested_lookup( 'connectivity_matrix', self.groups_config) + config_file_list += nested_lookup('std_mask', self.groups_config) for subj_id in self.all_subjs: subj_input_dir = Path(self.root_folder).joinpath(subj_id) - # Find subject's standardization mask - if self.std_mask is not None: - for sub_mask in self.std_mask: - sub_std_mask_file = subj_input_dir.joinpath( - sub_mask.replace('*', subj_id)) - if not sub_std_mask_file.is_file(): - raise FileNotFoundError( - "Standardization mask {} not found for subject {}!" - .format(sub_std_mask_file, subj_id)) - # Find subject's files from group_config for this_file in config_file_list: this_file = this_file.replace('*', subj_id) @@ -368,31 +355,14 @@ def _create_one_subj(self, subj_id, hdf_handle): subj_hdf_group = hdf_handle.create_group(subj_id) - # Find subject's standardization mask - subj_std_mask_data = None - if self.std_mask is not None: - for sub_mask in self.std_mask: - sub_mask = sub_mask.replace('*', subj_id) - logging.info(" - Loading standardization mask {}" - .format(sub_mask)) - sub_mask_file = subj_input_dir.joinpath(sub_mask) - sub_mask_img = nib.load(sub_mask_file) - sub_mask_data = np.asanyarray(sub_mask_img.dataobj) > 0 - if subj_std_mask_data is None: - subj_std_mask_data = sub_mask_data - else: - subj_std_mask_data = np.logical_or(sub_mask_data, - subj_std_mask_data) - # Add the subj data based on groups in the json config file - ref = self._create_volume_groups( - subj_id, subj_input_dir, subj_std_mask_data, subj_hdf_group) + ref = self._create_volume_groups(subj_id, subj_input_dir, + subj_hdf_group) self._create_streamline_groups(ref, subj_input_dir, subj_id, subj_hdf_group) - def _create_volume_groups(self, subj_id, subj_input_dir, - subj_std_mask_data, subj_hdf_group): + def _create_volume_groups(self, subj_id, subj_input_dir, subj_hdf_group): """ Create the hdf5 groups for all volume groups in the config_file for a given subject. @@ -407,7 +377,7 @@ def _create_volume_groups(self, subj_id, subj_input_dir, (group_data, group_affine, group_header, group_res) = self._process_one_volume_group( - group, subj_id, subj_input_dir, subj_std_mask_data) + group, subj_id, subj_input_dir) if ref_header is None: ref_header = group_header else: @@ -431,8 +401,7 @@ def _create_volume_groups(self, subj_id, subj_input_dir, return ref_header def _process_one_volume_group(self, group: str, subj_id: str, - subj_input_path: Path, - subj_std_mask_data: np.ndarray = None): + subj_input_dir: Path): """ Processes each volume group from the json config file for a given subject: @@ -448,10 +417,8 @@ def _process_one_volume_group(self, group: str, subj_id: str, Group name. subj_id: str The subject's id. - subj_input_path: Path + subj_input_dir: Path Path where the files from file_list should be found. - subj_std_mask_data: np.ndarray of bools, optional - Binary mask that will be used for data standardization. Returns ------- @@ -460,19 +427,44 @@ def _process_one_volume_group(self, group: str, subj_id: str, group_affine: np.ndarray Affine for the group. """ - standardization = self.groups_config[group]['standardization'] + std_mask = None + std_option = 'none' + if 'standardization' in self.groups_config[group]: + std_option = self.groups_config[group]['standardization'] + if 'std_mask' in self.groups_config[group]: + if std_option == 'none': + logging.warning("You provided a std_mask for volume group {}, " + "but std_option is 'none'. Skipping.") + else: + # Load subject's standardization mask. Can be a list of files. + std_masks = self.groups_config[group]['std_mask'] + if isinstance(std_masks, str): + std_masks = [std_masks] + + for sub_mask in std_masks: + sub_mask = sub_mask.replace('*', subj_id) + logging.info(" - Loading standardization mask {}" + .format(sub_mask)) + sub_mask_file = subj_input_dir.joinpath(sub_mask) + sub_mask_img = nib.load(sub_mask_file) + sub_mask_data = np.asanyarray(sub_mask_img.dataobj) > 0 + if std_mask is None: + std_mask = sub_mask_data + else: + std_mask = np.logical_or(sub_mask_data, std_mask) + file_list = self.groups_config[group]['files'] # First file will define data dimension and affine file_name = file_list[0].replace('*', subj_id) - first_file = subj_input_path.joinpath(file_name) + first_file = subj_input_dir.joinpath(file_name) logging.info(" - Processing file {}".format(file_name)) group_data, group_affine, group_res, group_header = load_file_to4d( first_file) - if standardization == 'per_file': + if std_option == 'per_file': logging.debug(' *Standardizing sub-data') - group_data = standardize_data(group_data, subj_std_mask_data, + group_data = standardize_data(group_data, std_mask, independent=False) # Other files must fit (data shape, affine, voxel size) @@ -480,12 +472,12 @@ def _process_one_volume_group(self, group: str, subj_id: str, # is a minimal check. for file_name in file_list[1:]: file_name = file_name.replace('*', subj_id) - data = _load_and_verify_file(file_name, subj_input_path, group, + data = _load_and_verify_file(file_name, subj_input_dir, group, group_affine, group_res) - if standardization == 'per_file': + if std_option == 'per_file': logging.debug(' *Standardizing sub-data') - data = standardize_data(data, subj_std_mask_data, + data = standardize_data(data, std_mask, independent=False) # Append file data to hdf group. @@ -497,15 +489,15 @@ def _process_one_volume_group(self, group: str, subj_id: str, 'Wrong dimensions?'.format(file_name, group)) # Standardize data (per channel) (if not done 'per_file' yet). - if standardization == 'independent': + if std_option == 'independent': logging.debug(' *Standardizing data on each feature.') - group_data = standardize_data(group_data, subj_std_mask_data, + group_data = standardize_data(group_data, std_mask, independent=True) - elif standardization == 'all': + elif std_option == 'all': logging.debug(' *Standardizing data as a whole.') - group_data = standardize_data(group_data, subj_std_mask_data, + group_data = standardize_data(group_data, std_mask, independent=False) - elif standardization not in ['none', 'per_file']: + elif std_option not in ['none', 'per_file']: raise ValueError("standardization must be one of " "['all', 'independent', 'per_file', 'none']") diff --git a/dwi_ml/data/hdf5/utils.py b/dwi_ml/data/hdf5/utils.py index a8818deb..cbd7704b 100644 --- a/dwi_ml/data/hdf5/utils.py +++ b/dwi_ml/data/hdf5/utils.py @@ -44,7 +44,7 @@ def add_hdf5_creation_args(p: ArgumentParser): "-> https://dwi-ml.readthedocs.io/en/latest/" "creating_hdf5.html") p.add_argument('out_hdf5_file', - help="Path and name of the output hdf5 file.\n If " + help="Path and name of the output hdf5 file. \nIf " "--save_intermediate is set, the intermediate files " "will be saved in \nthe same location, in a folder " "name based on date and hour of creation.\n" @@ -79,18 +79,6 @@ def add_hdf5_creation_args(p: ArgumentParser): "final concatenated resampled/compressed streamlines.)") -def add_mri_processing_args(p: ArgumentParser): - g = p.add_argument_group('Volumes processing options:') - g.add_argument( - '--std_mask', nargs='+', metavar='m', - help="Mask defining the voxels used for data standardization. \n" - "-> Should be the name of a file inside dwi_ml_ready/{subj_id}.\n" - "-> You may add wildcards (*) that will be replaced by the " - "subject's id. \n" - "-> If none is given, all non-zero voxels will be used.\n" - "-> If more than one are given, masks will be combined.") - - def add_streamline_processing_args(p: ArgumentParser): g = p.add_argument_group('Streamlines processing options:') add_resample_or_compress_arg(g) diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py index 611b2bb2..61fb176c 100644 --- a/scripts_python/dwiml_create_hdf5_dataset.py +++ b/scripts_python/dwiml_create_hdf5_dataset.py @@ -32,8 +32,7 @@ from dwi_ml.data.hdf5.hdf5_creation import HDF5Creator from dwi_ml.data.hdf5.utils import ( - add_hdf5_creation_args, add_mri_processing_args, - add_streamline_processing_args) + add_hdf5_creation_args, add_streamline_processing_args) from dwi_ml.experiment_utils.timer import Timer from dwi_ml.io_utils import add_logging_arg @@ -89,8 +88,8 @@ def prepare_hdf5_creator(args): # Instantiate a creator and perform checks creator = HDF5Creator(Path(args.dwi_ml_ready_folder), args.out_hdf5_file, training_subjs, validation_subjs, testing_subjs, - groups_config, args.std_mask, args.step_size, - args.compress, args.enforce_files_presence, + groups_config, args.step_size, args.compress, + args.enforce_files_presence, args.save_intermediate, intermediate_subdir) return creator @@ -101,7 +100,6 @@ def _parse_args(): formatter_class=argparse.RawTextHelpFormatter) add_hdf5_creation_args(p) - add_mri_processing_args(p) add_streamline_processing_args(p) add_overwrite_arg(p) add_logging_arg(p) From c24e9844bac17a9a9fb8bdacb114573b82dc331e Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 17 Jan 2024 10:28:07 -0500 Subject: [PATCH 03/23] Add util script to print info on hdf5 --- .../dwiml_print_hdf5_architecture.py | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 scripts_python/dwiml_print_hdf5_architecture.py diff --git a/scripts_python/dwiml_print_hdf5_architecture.py b/scripts_python/dwiml_print_hdf5_architecture.py new file mode 100644 index 00000000..7b574f96 --- /dev/null +++ b/scripts_python/dwiml_print_hdf5_architecture.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +import argparse + +import h5py +from scilpy.io.utils import assert_inputs_exist + + +def _prepare_argparser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('hdf5_file', + help="Path to the hdf5 file.") + return p + + +def main(): + p = _prepare_argparser() + args = p.parse_args() + + assert_inputs_exist(p, args.hdf5_file) + + with h5py.File(args.hdf5_file, 'r') as hdf_handle: + + print("\n\nHere is the architecture of your hdf5:\n" + "--------------------------------------\n") + print("- Main hdf5 attributes: {}\n" + .format(list(hdf_handle.attrs.keys()))) + + if 'training_subjs' in hdf_handle.attrs: + print("- List of training subjects: {}\n" + .format(hdf_handle.attrs['training_subjs'])) + + if 'validation_subjs' in hdf_handle.attrs: + print("- List of validation subjects: {}\n" + .format(hdf_handle.attrs['validation_subjs'])) + + if 'testing_subjs' in hdf_handle.attrs: + print("- List of testing subjects: {}\n" + .format(hdf_handle.attrs['testing_subjs'])) + + print("- For each subject, caracteristics are:") + first_subj = list(hdf_handle.keys())[0] + for key, val in hdf_handle[first_subj].items(): + print(" - {}, with attributes {}" + .format(key, list(hdf_handle[first_subj][key].attrs.keys()))) + + +if __name__ == '__main__': + main() From 3d1014496e30580cb3587599c136f49516bc3849 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 1 Feb 2024 08:12:19 -0500 Subject: [PATCH 04/23] Fixes in connectivity_from_labels --- .../processing/streamlines/post_processing.py | 62 ++++++++++++------- ...compute_connectivity_matrix_from_labels.py | 47 +++++++++++--- 2 files changed, 79 insertions(+), 30 deletions(-) diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index f09e3186..a2764820 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import logging from typing import List import numpy as np @@ -320,13 +321,28 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, compressed streamlines.' Else, uses simple computation from endpoints. Faster. Also, works with incomplete parcellation. + + Returns + ------- + matrix: np.ndarray + With use_scilpy: shape (nb_labels + 1, nb_labels + 1) + (last label is "Not Found") + Else, shape (nb_labels, nb_labels) + labels: List + The list of labels """ - real_labels = np.unique(data_labels)[1:] + real_labels = list(np.sort(np.unique(data_labels))) nb_labels = len(real_labels) - matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int) + logging.debug("Computing connectivity matrix for {} labels." + .format(nb_labels)) - start_blocs = [] - end_blocs = [] + if use_scilpy: + matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int) + else: + matrix = np.zeros((nb_labels, nb_labels), dtype=int) + + start_labels = [] + end_labels = [] if use_scilpy: indices, points_to_idx = uncompress(streamlines, return_mapping=True) @@ -334,29 +350,33 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, for strl_vox_indices in indices: segments_info = segmenting_func(strl_vox_indices, data_labels) if len(segments_info) > 0: - start = segments_info[0]['start_label'] - end = segments_info[0]['end_label'] - start_blocs.append(start) - end_blocs.append(end) + start = real_labels.index(segments_info[0]['start_label']) + end = real_labels.index(segments_info[0]['end_label']) + else: + start = nb_labels + end = nb_labels - matrix[start, end] += 1 - if start != end: - matrix[end, start] += 1 + start_labels.append(start) + end_labels.append(end) + + matrix[start, end] += 1 + if start != end: + matrix[end, start] += 1 + + real_labels = real_labels + [np.NaN] - else: - # Putting it in 0,0, we will remember that this means 'other' - matrix[0, 0] += 1 - start_blocs.append(0) - end_blocs.append(0) else: for s in streamlines: # Vox space, corner origin # = we can get the nearest neighbor easily. # Coord 0 = voxel 0. Coord 0.9 = voxel 0. Coord 1 = voxel 1. - start = data_labels[tuple(np.floor(s[0, :]).astype(int))] - end = data_labels[tuple(np.floor(s[-1, :]).astype(int))] - start_blocs.append(start) - end_blocs.append(end) + start = real_labels.index( + data_labels[tuple(np.floor(s[0, :]).astype(int))]) + end = real_labels.index( + data_labels[tuple(np.floor(s[-1, :]).astype(int))]) + + start_labels.append(start) + end_labels.append(end) matrix[start, end] += 1 if start != end: matrix[end, start] += 1 @@ -367,7 +387,7 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, if binary: matrix = matrix.astype(bool) - return matrix, start_blocs, end_blocs + return matrix, real_labels, start_labels, end_labels def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs, diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py index f662f684..fb956c49 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py @@ -1,5 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- + +""" +Computes the connectivity matrix. +Labels associated with each line / row will be printed. +""" + import argparse import logging import os.path @@ -65,10 +71,16 @@ def main(): args = p.parse_args() if args.verbose: + # Currenlty, with debug, matplotlib prints a lot of stuff. Why?? logging.getLogger().setLevel(logging.INFO) tmp, ext = os.path.splitext(args.out_file) + + if ext != '.npy': + p.error("--out_file should have a .npy extension.") + out_fig = tmp + '.png' + out_ordered_labels = tmp + '_labels.txt' assert_inputs_exist(p, [args.in_labels, args.streamlines]) assert_outputs_exist(p, args, [args.out_file, out_fig], [args.save_biggest, args.save_smallest]) @@ -80,26 +92,36 @@ def main(): p.error("Streamlines not compatible with chosen volume.") else: args.reference = args.in_labels + + logging.info("Loading tractogram.") in_sft = load_tractogram_with_reference(p, args, args.streamlines) in_img = nib.load(args.in_labels) data_labels = get_data_as_labels(in_img) in_sft.to_vox() in_sft.to_corner() - matrix, start_blocs, end_blocs = compute_triu_connectivity_from_labels( - in_sft.streamlines, data_labels, - use_scilpy=args.use_longest_segment) + matrix, ordered_labels, start_blocs, end_blocs = \ + compute_triu_connectivity_from_labels( + in_sft.streamlines, data_labels, + use_scilpy=args.use_longest_segment) if args.hide_background is not None: - matrix[args.hide_background, :] = 0 - matrix[:, args.hide_background] = 0 + idx = ordered_labels.idx(args.hide_background) + matrix[idx, :] = 0 + matrix[:, idx] = 0 + ordered_labels[idx] = ("Hidden background ({})" + .format(args.hide_background)) + + logging.info("Labels are, in order: {}".format(ordered_labels)) # Options to try to investigate the connectivity matrix: # masking point (0,0) = streamline ending in wm. if args.save_biggest is not None: i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape) print("Saving biggest bundle: {} streamlines. From label {} to label " - "{}".format(matrix[i, j], i, j)) + "{} (line {}, column {} in the matrix)" + .format(matrix[i, j], ordered_labels[i], ordered_labels[j], + i, j)) biggest = find_streamlines_with_chosen_connectivity( in_sft.streamlines, i, j, start_blocs, end_blocs) sft = in_sft.from_sft(biggest, in_sft) @@ -109,15 +131,22 @@ def main(): tmp_matrix = np.ma.masked_equal(matrix, 0) i, j = np.unravel_index(tmp_matrix.argmin(axis=None), matrix.shape) print("Saving smallest bundle: {} streamlines. From label {} to label " - "{}".format(matrix[i, j], i, j)) - biggest = find_streamlines_with_chosen_connectivity( + "{} (line {}, column {} in the matrix)" + .format(matrix[i, j], ordered_labels[i], ordered_labels[j], + i, j)) + smallest = find_streamlines_with_chosen_connectivity( in_sft.streamlines, i, j, start_blocs, end_blocs) - sft = in_sft.from_sft(biggest, in_sft) + sft = in_sft.from_sft(smallest, in_sft) save_tractogram(sft, args.save_smallest) + ordered_labels = str(ordered_labels) + with open(out_ordered_labels, "w") as text_file: + text_file.write(ordered_labels) + if args.show_now: plt.imshow(matrix) plt.colorbar() + plt.title("Raw streamline count") plt.figure() plt.imshow(matrix > 0) From 8393e4f31225c0d37052a5118ce5981c95e2728f Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 1 Feb 2024 08:41:38 -0500 Subject: [PATCH 05/23] Add missing tests --- scripts_python/dwiml_divide_volume_into_blocs.py | 2 +- scripts_python/tests/test_compute_connectivity_score.py | 8 ++++++++ scripts_python/tests/test_divide_volume_into_blocs.py | 8 ++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 scripts_python/tests/test_compute_connectivity_score.py create mode 100644 scripts_python/tests/test_divide_volume_into_blocs.py diff --git a/scripts_python/dwiml_divide_volume_into_blocs.py b/scripts_python/dwiml_divide_volume_into_blocs.py index 8e702dec..439912f4 100644 --- a/scripts_python/dwiml_divide_volume_into_blocs.py +++ b/scripts_python/dwiml_divide_volume_into_blocs.py @@ -16,7 +16,7 @@ def _build_arg_parser(): help='Input file name, in nifti format.') p.add_argument( - 'out', metavar='OUT_FILE', dest='out_filename', + 'out_filename', help='name of the output file, which will be saved as a text file.') add_overwrite_arg(p) diff --git a/scripts_python/tests/test_compute_connectivity_score.py b/scripts_python/tests/test_compute_connectivity_score.py new file mode 100644 index 00000000..f9e6bf08 --- /dev/null +++ b/scripts_python/tests/test_compute_connectivity_score.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +def test_help_option(script_runner): + ret = script_runner.run('dwiml_compute_connectivity_score.py', + '--help') + assert ret.success diff --git a/scripts_python/tests/test_divide_volume_into_blocs.py b/scripts_python/tests/test_divide_volume_into_blocs.py new file mode 100644 index 00000000..95d0e9cc --- /dev/null +++ b/scripts_python/tests/test_divide_volume_into_blocs.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + + +def test_help_option(script_runner): + ret = script_runner.run('dwiml_divide_volume_into_blocs.py', + '--help') + assert ret.success From 126b4966723f9a22df1fa9cc14017d9e5e1663a0 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 1 Feb 2024 09:35:59 -0500 Subject: [PATCH 06/23] Improve figure in both connectivity scripts --- .../processing/streamlines/post_processing.py | 31 +++++++++++++++++++ ..._compute_connectivity_matrix_from_blocs.py | 10 ++---- ...compute_connectivity_matrix_from_labels.py | 15 +++------ 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index a2764820..bc4ca750 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -4,6 +4,9 @@ import numpy as np import torch +from matplotlib import pyplot as plt +from matplotlib.colors import LogNorm +from mpl_toolkits.axes_grid1 import make_axes_locatable from scilpy.tractanalysis.tools import \ extract_longest_segments_from_profile as segmenting_func @@ -433,6 +436,34 @@ def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs, return matrix, start_block, end_block +def prepare_figure_connectivity(matrix): + matrix = np.copy(matrix) + + fig, axs = plt.subplots(2, 2) + im = axs[0, 0].imshow(matrix) + divider = make_axes_locatable(axs[0, 0]) + cax = divider.append_axes('right', size='5%', pad=0.05) + fig.colorbar(im, cax=cax, orientation='vertical') + axs[0, 0].set_title("Raw streamline count") + + im = axs[0, 1].imshow(matrix + np.min(matrix[matrix > 0]), norm=LogNorm()) + divider = make_axes_locatable(axs[0, 1]) + cax = divider.append_axes('right', size='5%', pad=0.05) + fig.colorbar(im, cax=cax, orientation='vertical') + axs[0, 1].set_title("Raw streamline count (log view)") + + matrix = matrix / matrix.sum() * 100 + im = axs[1, 0].imshow(matrix) + divider = make_axes_locatable(axs[1, 0]) + cax = divider.append_axes('right', size='5%', pad=0.05) + fig.colorbar(im, cax=cax, orientation='vertical') + axs[1, 0].set_title("Percentage") + + matrix = matrix > 0 + axs[1, 1].imshow(matrix) + axs[1, 1].set_title("Binary") + + def find_streamlines_with_chosen_connectivity( streamlines, label1, label2, start_labels, end_labels): """ diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py index 7d30cba5..ae559f9a 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py @@ -15,7 +15,8 @@ from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity from dwi_ml.data.processing.streamlines.post_processing import \ - compute_triu_connectivity_from_blocs, find_streamlines_with_chosen_connectivity + compute_triu_connectivity_from_blocs, \ + find_streamlines_with_chosen_connectivity, prepare_figure_connectivity def _build_arg_parser(): @@ -102,12 +103,7 @@ def main(): sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_smallest) - plt.imshow(matrix) - plt.colorbar() - - plt.figure() - plt.imshow(matrix > 0) - plt.title('Binary') + prepare_figure_connectivity(matrix) if args.binary: matrix = matrix > 0 diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py index fb956c49..24db8333 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py @@ -22,7 +22,7 @@ from dwi_ml.data.processing.streamlines.post_processing import \ find_streamlines_with_chosen_connectivity, \ - compute_triu_connectivity_from_labels + compute_triu_connectivity_from_labels, prepare_figure_connectivity def _build_arg_parser(): @@ -143,14 +143,7 @@ def main(): with open(out_ordered_labels, "w") as text_file: text_file.write(ordered_labels) - if args.show_now: - plt.imshow(matrix) - plt.colorbar() - plt.title("Raw streamline count") - - plt.figure() - plt.imshow(matrix > 0) - plt.title('Binary') + prepare_figure_connectivity(matrix) if args.binary: matrix = matrix > 0 @@ -158,7 +151,9 @@ def main(): # Save results. np.save(args.out_file, matrix) plt.savefig(out_fig) - plt.show() + + if args.show_now: + plt.show() if __name__ == '__main__': From d3c27d224f93b2d66ee4d332afc61d9722647cc5 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 1 Feb 2024 15:10:17 -0500 Subject: [PATCH 07/23] Final improvements connectivity_from_labels --- .../processing/streamlines/post_processing.py | 4 +- ...compute_connectivity_matrix_from_labels.py | 45 +++++++++++-------- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index bc4ca750..eabc079d 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -336,8 +336,8 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, """ real_labels = list(np.sort(np.unique(data_labels))) nb_labels = len(real_labels) - logging.debug("Computing connectivity matrix for {} labels." - .format(nb_labels)) + logging.info("Computing connectivity matrix for {} labels." + .format(nb_labels)) if use_scilpy: matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int) diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py index 24db8333..9c682512 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py @@ -39,7 +39,7 @@ def _build_arg_parser(): "streamline count is saved.") p.add_argument('--show_now', action='store_true', help="If set, shows the matrix with matplotlib.") - p.add_argument('--hide_background', nargs='?', const=0, + p.add_argument('--hide_background', nargs='?', const=0, type=float, help="If true, set the connectivity matrix for chosen " "label (default: 0), to 0.") p.add_argument( @@ -80,9 +80,10 @@ def main(): p.error("--out_file should have a .npy extension.") out_fig = tmp + '.png' + out_fig_noback = tmp + '_hidden_background.png' out_ordered_labels = tmp + '_labels.txt' assert_inputs_exist(p, [args.in_labels, args.streamlines]) - assert_outputs_exist(p, args, [args.out_file, out_fig], + assert_outputs_exist(p, args, [args.out_file, out_fig, out_fig_noback], [args.save_biggest, args.save_smallest]) ext = os.path.splitext(args.streamlines)[1] @@ -105,14 +106,33 @@ def main(): in_sft.streamlines, data_labels, use_scilpy=args.use_longest_segment) + prepare_figure_connectivity(matrix) + plt.savefig(out_fig) + if args.hide_background is not None: - idx = ordered_labels.idx(args.hide_background) + idx = ordered_labels.index(args.hide_background) + nb_hidden = np.sum(matrix[idx, :]) + np.sum(matrix[:, idx]) - \ + matrix[idx, idx] + if nb_hidden > 0: + logging.info("CAREFUL! Deleting from the matrix {} streamlines " + "with one or both endpoints in a non-labelled area " + "(background = {}; line/column {})" + .format(nb_hidden, args.hide_background, idx)) + else: + logging.info("No streamlines with endpoints in the background :)") matrix[idx, :] = 0 matrix[:, idx] = 0 ordered_labels[idx] = ("Hidden background ({})" .format(args.hide_background)) - logging.info("Labels are, in order: {}".format(ordered_labels)) + prepare_figure_connectivity(matrix) + plt.savefig(out_fig_noback) + + if args.binary: + matrix = matrix > 0 + + # Save results. + np.save(args.out_file, matrix) # Options to try to investigate the connectivity matrix: # masking point (0,0) = streamline ending in wm. @@ -139,18 +159,10 @@ def main(): sft = in_sft.from_sft(smallest, in_sft) save_tractogram(sft, args.save_smallest) - ordered_labels = str(ordered_labels) with open(out_ordered_labels, "w") as text_file: - text_file.write(ordered_labels) - - prepare_figure_connectivity(matrix) - - if args.binary: - matrix = matrix > 0 - - # Save results. - np.save(args.out_file, matrix) - plt.savefig(out_fig) + logging.info("Labels are saved in: {}".format(out_ordered_labels)) + for i, label in enumerate(ordered_labels): + text_file.write("{} = {}\n".format(i, label)) if args.show_now: plt.show() @@ -158,6 +170,3 @@ def main(): if __name__ == '__main__': main() - - - From af2e3de3855df6379e78716167a3b3bb37d488f0 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 1 Feb 2024 15:17:04 -0500 Subject: [PATCH 08/23] Fix pep8 --- ..._compute_connectivity_matrix_from_blocs.py | 26 ++++++++----------- .../dwiml_divide_volume_into_blocs.py | 2 +- ...compute_connectivity_matrix_from_labels.py | 3 ++- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py index ae559f9a..b7f42f79 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py @@ -23,8 +23,8 @@ def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawTextHelpFormatter) p.add_argument('in_volume', - help='Input nifti volume. Only used to get the shape of the ' - 'volume.') + help='Input nifti volume. Only used to get the shape of ' + 'the volume.') p.add_argument('streamlines', help='Tractogram (trk or tck).') p.add_argument('out_file', @@ -85,6 +85,15 @@ def main(): matrix, start_blocs, end_blocs = compute_triu_connectivity_from_blocs( in_sft.streamlines, in_img.shape, args.connectivity_nb_blocs) + prepare_figure_connectivity(matrix) + + if args.binary: + matrix = matrix > 0 + + # Save results. + np.save(args.out_file, matrix) + plt.savefig(out_fig) + # Options to try to investigate the connectivity matrix: if args.save_biggest is not None: i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape) @@ -103,22 +112,9 @@ def main(): sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_smallest) - prepare_figure_connectivity(matrix) - - if args.binary: - matrix = matrix > 0 - - # Save results. - np.save(args.out_file, matrix) - - plt.savefig(out_fig) - if args.show_now: plt.show() if __name__ == '__main__': main() - - - diff --git a/scripts_python/dwiml_divide_volume_into_blocs.py b/scripts_python/dwiml_divide_volume_into_blocs.py index 439912f4..bbdb1bdc 100644 --- a/scripts_python/dwiml_divide_volume_into_blocs.py +++ b/scripts_python/dwiml_divide_volume_into_blocs.py @@ -50,7 +50,7 @@ def main(): assert_inputs_exist(parser, args.in_image) assert_outputs_exist(parser, args, required=args.out_filename) - + volume = nib.load(args.in_image) final_volume = color_mri_connectivity_blocs([6, 6, 6], volume.shape) img = nib.Nifti1Image(final_volume, volume.affine) diff --git a/scripts_python/tests/test_compute_connectivity_matrix_from_labels.py b/scripts_python/tests/test_compute_connectivity_matrix_from_labels.py index 990b207e..faed1242 100644 --- a/scripts_python/tests/test_compute_connectivity_matrix_from_labels.py +++ b/scripts_python/tests/test_compute_connectivity_matrix_from_labels.py @@ -2,5 +2,6 @@ # -*- coding: utf-8 -*- def test_help_option(script_runner): - ret = script_runner.run('dwiml_compute_connectivity_matrix_from_labels.py', '--help') + ret = script_runner.run('dwiml_compute_connectivity_matrix_from_labels.py', + '--help') assert ret.success From 37cdbb8f48a6766e190c218fb5e4d4edf7f052d5 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 14 Feb 2024 12:44:05 -0500 Subject: [PATCH 09/23] Manage wildcards like real wildcards --- dwi_ml/data/hdf5/hdf5_creation.py | 169 ++++++++++---------- dwi_ml/data/io.py | 4 +- scripts_python/dwiml_create_hdf5_dataset.py | 2 +- source/2_A_creating_the_hdf5.rst | 5 +- 4 files changed, 89 insertions(+), 91 deletions(-) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index f5d95eed..650c1cce 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import datetime +import glob import logging import os from pathlib import Path @@ -24,6 +25,39 @@ from dwi_ml.data.processing.dwi.dwi import standardize_data +def format_filelist(filenames, enforce_presence, folder=None) -> List[str]: + """ + If folder is not None, it will be added as prefix to all files. + """ + if isinstance(filenames, str): + filenames = [filenames] + + new_files = [] + for i, f in enumerate(filenames): + if folder is not None: + f = str(folder.joinpath(f)) + if '*' in f: + tmp = glob.glob(f) + if len(tmp) == 0: + msg = "File not found, even with the wildcard: {}".format(f) + if enforce_presence: + raise FileNotFoundError(msg) + else: + logging.warning(msg) + else: + new_files.extend(f) + else: + if not Path(f).is_file(): + msg = "File not found: {}".format(f) + if enforce_presence: + raise FileNotFoundError(msg) + else: + logging.warning(msg) + else: + new_files.append(f) + return new_files + + def _load_and_verify_file(filename: str, subj_input_path, group_name: str, group_affine, group_res): """ @@ -297,19 +331,9 @@ def _check_files_presence(self): subj_input_dir = Path(self.root_folder).joinpath(subj_id) # Find subject's files from group_config - for this_file in config_file_list: - this_file = this_file.replace('*', subj_id) - if this_file.endswith('/ALL'): - logging.debug( - " Keyword 'ALL' detected; we will load all " - "files in the folder '{}'" - .format(this_file.replace('/ALL', ''))) - else: - this_file = subj_input_dir.joinpath(this_file) - if not this_file.is_file(): - raise FileNotFoundError( - "File from groups_config ({}) not found for " - "subject {}!".format(this_file, subj_id)) + config_file_list = format_filelist(config_file_list, + self.enforce_files_presence, + folder=subj_input_dir) def create_database(self): """ @@ -441,26 +465,25 @@ def _process_one_volume_group(self, group: str, subj_id: str, if isinstance(std_masks, str): std_masks = [std_masks] - for sub_mask in std_masks: - sub_mask = sub_mask.replace('*', subj_id) + std_masks = format_filelist(std_masks, folder=subj_input_dir) + for mask in std_masks: logging.info(" - Loading standardization mask {}" - .format(sub_mask)) - sub_mask_file = subj_input_dir.joinpath(sub_mask) - sub_mask_img = nib.load(sub_mask_file) - sub_mask_data = np.asanyarray(sub_mask_img.dataobj) > 0 + .format(os.path.basename(mask))) + sub_mask_data = nib.load(mask).get_fdata() > 0 if std_mask is None: std_mask = sub_mask_data else: std_mask = np.logical_or(sub_mask_data, std_mask) file_list = self.groups_config[group]['files'] + file_list = format_filelist(file_list, self.enforce_files_presence, + folder=subj_input_dir) # First file will define data dimension and affine - file_name = file_list[0].replace('*', subj_id) - first_file = subj_input_dir.joinpath(file_name) - logging.info(" - Processing file {}".format(file_name)) + logging.info(" - Processing file {}" + .format(os.path.basename(file_list[0]))) group_data, group_affine, group_res, group_header = load_file_to4d( - first_file) + file_list[0]) if std_option == 'per_file': logging.debug(' *Standardizing sub-data') @@ -470,23 +493,24 @@ def _process_one_volume_group(self, group: str, subj_id: str, # Other files must fit (data shape, affine, voxel size) # It is not a promise that data has been correctly registered, but it # is a minimal check. - for file_name in file_list[1:]: - file_name = file_name.replace('*', subj_id) - data = _load_and_verify_file(file_name, subj_input_dir, group, - group_affine, group_res) - - if std_option == 'per_file': - logging.debug(' *Standardizing sub-data') - data = standardize_data(data, std_mask, - independent=False) - - # Append file data to hdf group. - try: - group_data = np.append(group_data, data, axis=-1) - except ImportError: - raise ImportError( - 'Data file {} could not be added to data group {}. ' - 'Wrong dimensions?'.format(file_name, group)) + if len(file_list) > 1: + for file_name in file_list[1:]: + logging.info(" - Processing file {}" + .format(os.path.basename(file_name))) + data = _load_and_verify_file(file_name, subj_input_dir, group, + group_affine, group_res) + + if std_option == 'per_file': + logging.debug(' *Standardizing sub-data') + data = standardize_data(data, std_mask, independent=False) + + # Append file data to hdf group. + try: + group_data = np.append(group_data, data, axis=-1) + except ImportError: + raise ImportError( + 'Data file {} could not be added to data group {}. ' + 'Wrong dimensions?'.format(file_name, group)) # Standardize data (per channel) (if not done 'per_file' yet). if std_option == 'independent': @@ -590,9 +614,6 @@ def _process_one_streamline_group( Loads and processes a group of tractograms and merges all streamlines together. - Note. Wildcards will be replaced by the subject id. If the list is - folder/ALL, all tractograms in the folder will be used. - Parameters ---------- subj_dir : Path @@ -628,41 +649,26 @@ def _process_one_streamline_group( final_sft = None output_lengths = [] - for instructions in tractograms: - if instructions.endswith('/ALL'): - # instructions are to get all tractograms in given folder. - tractograms_dir = instructions.split('/ALL') - tractograms_dir = ''.join(tractograms_dir[:-1]) - tractograms_sublist = [ - instructions.replace('/ALL', '/' + os.path.basename(p)) - for p in subj_dir.glob(tractograms_dir + '/*')] - else: - # instruction is to get one specific tractogram - tractograms_sublist = [instructions] - - # Either a loop on "ALL" or a loop on only one file. - for tractogram_name in tractograms_sublist: - tractogram_name = tractogram_name.replace('*', subj_id) - tractogram_file = subj_dir.joinpath(tractogram_name) + tractograms = format_filelist(tractograms, self.enforce_files_presence, + folder=subj_dir) + for tractogram_file in tractograms: + sft = self._load_and_process_sft(tractogram_file, header) - sft = self._load_and_process_sft( - tractogram_file, tractogram_name, header) + if sft is not None: + # Compute euclidean lengths (rasmm space) + sft.to_space(Space.RASMM) + output_lengths.extend(length(sft.streamlines)) - if sft is not None: - # Compute euclidean lengths (rasmm space) - sft.to_space(Space.RASMM) - output_lengths.extend(length(sft.streamlines)) + # Sending to common space + sft.to_vox() + sft.to_corner() - # Sending to common space - sft.to_vox() - sft.to_corner() - - # Add processed tractogram to final big tractogram - if final_sft is None: - final_sft = sft - else: - final_sft = concatenate_sft([final_sft, sft], - erase_metadata=False) + # Add processed tractogram to final big tractogram + if final_sft is None: + final_sft = sft + else: + final_sft = concatenate_sft([final_sft, sft], + erase_metadata=False) if self.save_intermediate: output_fname = self.intermediate_folder.joinpath( @@ -716,16 +722,7 @@ def _process_one_streamline_group( return final_sft, output_lengths, conn_matrix, conn_info - def _load_and_process_sft(self, tractogram_file, tractogram_name, header): - if not tractogram_file.is_file(): - logging.debug( - " Skipping file {} because it was not found in this " - "subject's folder".format(tractogram_name)) - # Note: if args.enforce_files_presence was set to true, - # this case is not possible, already checked in - # create_hdf5_dataset - return None - + def _load_and_process_sft(self, tractogram_file, header): # Check file extension _, file_extension = os.path.splitext(str(tractogram_file)) if file_extension not in ['.trk', '.tck']: @@ -742,7 +739,7 @@ def _load_and_process_sft(self, tractogram_file, tractogram_name, header): # Loading tractogram and sending to wanted space logging.info(" - Processing tractogram {}" - .format(os.path.basename(tractogram_name))) + .format(os.path.basename(tractogram_file))) sft = load_tractogram(str(tractogram_file), header) # Resample or compress streamlines diff --git a/dwi_ml/data/io.py b/dwi_ml/data/io.py index df202c96..77bcc40a 100644 --- a/dwi_ml/data/io.py +++ b/dwi_ml/data/io.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +import os + import nibabel as nib import numpy as np @@ -19,7 +21,7 @@ def load_file_to4d(data_file): voxel_size: np.array with size 3, header: nibabel header. """ - ext = data_file.suffix + _, ext = os.path.splitext(data_file) if ext != '.gz' and ext != '.nii': raise ValueError('All data files should be nifti (.nii or .nii.gz) ' diff --git a/scripts_python/dwiml_create_hdf5_dataset.py b/scripts_python/dwiml_create_hdf5_dataset.py index 61fb176c..3377bc2f 100644 --- a/scripts_python/dwiml_create_hdf5_dataset.py +++ b/scripts_python/dwiml_create_hdf5_dataset.py @@ -9,7 +9,7 @@ - How to organize your data - How to prepare the config file - How to run this script. - https://dwi-ml.readthedocs.io/en/latest/2_B_preprocessing.html + https://dwi-ml.readthedocs.io/en/latest/2_A_creating_the_hdf5.html -------------------------------------- ** Note: The memory is a delicate question here, but checks have been made, and diff --git a/source/2_A_creating_the_hdf5.rst b/source/2_A_creating_the_hdf5.rst index 86abf217..ea7e1f55 100644 --- a/source/2_A_creating_the_hdf5.rst +++ b/source/2_A_creating_the_hdf5.rst @@ -87,7 +87,7 @@ To create the hdf5 file, you will need a config file such as below. HDF groups w } "bad_streamlines": { "type": "streamlines", - "files": ["bad_tractograms/ALL"] ---> Will get all trk and tck files. + "files": ["bad_tractograms/*"] ---> Will get all trk and tck files. } "wm_mask": { "type": "volume", @@ -111,8 +111,7 @@ Each group may have a number of parameters: - **"files"**: The listed file(s) must exist in every subject folder inside the root repository. That is: the files must be organized correctly on your computer (except if option 'enforce_files_presence is set to False). If there are more than one files, they will be concatenated (on the 4th dimension for volumes, using the union of tractograms for streamlines). - - There is the possibility to add a wildcard (\*) that will be replaced by the subject's id while loading. Ex: anat/\*__t1.nii.gz would become anat/subjX__t1.nii.gz. - - For streamlines, there is the possibility to use 'ALL' to load all tractograms present in a folder. + - There is the possibility to add a wildcard (\*). Additional attributes for volume groups: """""""""""""""""""""""""""""""""""""""" From e02b44f9a7de694b928c8913b8b48c95448f46c5 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 14 Feb 2024 14:16:39 -0500 Subject: [PATCH 10/23] Fix possible nested lists --- dwi_ml/data/hdf5/hdf5_creation.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index 650c1cce..34e7b1be 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -309,6 +309,7 @@ def _verify_subjects_list(self): "testing set!".format(ignored_subj)) return unique_subjs + def _check_files_presence(self): """ Verifying now the list of files. Prevents stopping after a long @@ -320,20 +321,29 @@ def _check_files_presence(self): """ logging.debug("Verifying files presence") + def flatten_list(a_list): + new_list = [] + for element in a_list: + if isinstance(element, list): + new_list.extend(flatten_list(element)) + else: + new_list.append(element) + return new_list + # concatenating files from all groups files: - # sum: concatenates list of sub-lists - config_file_list = sum(nested_lookup('files', self.groups_config), []) - config_file_list += nested_lookup( - 'connectivity_matrix', self.groups_config) - config_file_list += nested_lookup('std_mask', self.groups_config) + config_file_list = [ + nested_lookup('files', self.groups_config), + nested_lookup('connectivity_matrix', self.groups_config), + nested_lookup('std_mask', self.groups_config)] + config_file_list = flatten_list(config_file_list) for subj_id in self.all_subjs: subj_input_dir = Path(self.root_folder).joinpath(subj_id) # Find subject's files from group_config - config_file_list = format_filelist(config_file_list, - self.enforce_files_presence, - folder=subj_input_dir) + _ = format_filelist(config_file_list, + self.enforce_files_presence, + folder=subj_input_dir) def create_database(self): """ @@ -465,7 +475,9 @@ def _process_one_volume_group(self, group: str, subj_id: str, if isinstance(std_masks, str): std_masks = [std_masks] - std_masks = format_filelist(std_masks, folder=subj_input_dir) + std_masks = format_filelist(std_masks, + self.enforce_files_presence, + folder=subj_input_dir) for mask in std_masks: logging.info(" - Loading standardization mask {}" .format(os.path.basename(mask))) From 38d61512d7d80e09b748c95ab45991d2d32da8a5 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Fri, 16 Feb 2024 08:57:16 -0500 Subject: [PATCH 11/23] Compute connectivity from labels: save rejected streamlines --- .../processing/streamlines/post_processing.py | 44 ++++++++++++------- ..._compute_connectivity_matrix_from_blocs.py | 4 +- ...compute_connectivity_matrix_from_labels.py | 35 ++++++++------- 3 files changed, 49 insertions(+), 34 deletions(-) diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index eabc079d..aee4852b 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -302,7 +302,6 @@ def _compute_origin_finish_blocs(streamlines, volume_size, nb_blocs): def compute_triu_connectivity_from_labels(streamlines, data_labels, - binary: bool = False, use_scilpy=False): """ Compute a connectivity matrix. @@ -313,8 +312,6 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, Streamlines, in vox space, corner origin. data_labels: np.ndarray The loaded nifti image. - binary: bool - If True, return a binary matrix. use_scilpy: bool If True, uses scilpy's method: 'Strategy is to keep the longest streamline segment @@ -380,6 +377,7 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, start_labels.append(start) end_labels.append(end) + matrix[start, end] += 1 if start != end: matrix[end, start] += 1 @@ -387,9 +385,6 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, matrix = np.triu(matrix) assert matrix.sum() == len(streamlines) - if binary: - matrix = matrix.astype(bool) - return matrix, real_labels, start_labels, end_labels @@ -463,9 +458,11 @@ def prepare_figure_connectivity(matrix): axs[1, 1].imshow(matrix) axs[1, 1].set_title("Binary") + plt.suptitle("All versions of the connectivity matrix.") + def find_streamlines_with_chosen_connectivity( - streamlines, label1, label2, start_labels, end_labels): + streamlines, start_labels, end_labels, label1, label2=None): """ Returns streamlines corresponding to a (label1, label2) or (label2, label1) connection. @@ -474,19 +471,32 @@ def find_streamlines_with_chosen_connectivity( ---------- streamlines: list of np arrays or list of tensors. Streamlines, in vox space, corner origin. - label1: int - The bloc of interest, either as starting or finishing point. - label2: int - The bloc of interest, either as starting or finishing point. start_labels: list[int] The starting bloc for each streamline. end_labels: list[int] The ending bloc for each streamline. + label1: int + The bloc of interest, either as starting or finishing point. + label2: int, optional + The bloc of interest, either as starting or finishing point. + If label2 is None, then all connections (label1, Y) and (X, label1) + are found. """ + start_labels = np.asarray(start_labels) + end_labels = np.asarray(end_labels) - str_ind1 = np.logical_and(start_labels == label1, - end_labels == label2) - str_ind2 = np.logical_and(start_labels == label2, - end_labels == label1) - str_ind = np.logical_or(str_ind1, str_ind2) - return [s for i, s in enumerate(streamlines) if str_ind[i]] + if label2 is None: + labels2 = np.unique(np.concatenate((start_labels[:], end_labels[:]))) + else: + labels2 = [label2] + + found = np.zeros(len(streamlines)) + for label2 in labels2: + str_ind1 = np.logical_and(start_labels == label1, + end_labels == label2) + str_ind2 = np.logical_and(start_labels == label2, + end_labels == label1) + str_ind = np.logical_or(str_ind1, str_ind2) + found = np.logical_or(found, str_ind) + + return [s for i, s in enumerate(streamlines) if found[i]] diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py index b7f42f79..2f76885a 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_blocs.py @@ -99,7 +99,7 @@ def main(): i, j = np.unravel_index(np.argmax(matrix, axis=None), matrix.shape) print("Saving biggest bundle: {} streamlines.".format(matrix[i, j])) biggest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, start_blocs, end_blocs, i, j) sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_biggest) @@ -108,7 +108,7 @@ def main(): i, j = np.unravel_index(tmp_matrix.argmin(axis=None), matrix.shape) print("Saving smallest bundle: {} streamlines.".format(matrix[i, j])) biggest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, start_blocs, end_blocs, i, j) sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_smallest) diff --git a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py index 9c682512..877653fd 100644 --- a/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py +++ b/scripts_python/dwiml_compute_connectivity_matrix_from_labels.py @@ -39,7 +39,7 @@ def _build_arg_parser(): "streamline count is saved.") p.add_argument('--show_now', action='store_true', help="If set, shows the matrix with matplotlib.") - p.add_argument('--hide_background', nargs='?', const=0, type=float, + p.add_argument('--hide_background', nargs='?', const=0, type=int, help="If true, set the connectivity matrix for chosen " "label (default: 0), to 0.") p.add_argument( @@ -80,10 +80,11 @@ def main(): p.error("--out_file should have a .npy extension.") out_fig = tmp + '.png' - out_fig_noback = tmp + '_hidden_background.png' out_ordered_labels = tmp + '_labels.txt' + out_rejected_streamlines = tmp + '_rejected_from_background.trk' assert_inputs_exist(p, [args.in_labels, args.streamlines]) - assert_outputs_exist(p, args, [args.out_file, out_fig, out_fig_noback], + assert_outputs_exist(p, args, + [args.out_file, out_fig, out_rejected_streamlines], [args.save_biggest, args.save_smallest]) ext = os.path.splitext(args.streamlines)[1] @@ -101,23 +102,26 @@ def main(): in_sft.to_vox() in_sft.to_corner() - matrix, ordered_labels, start_blocs, end_blocs = \ + matrix, ordered_labels, start_labels, end_labels = \ compute_triu_connectivity_from_labels( in_sft.streamlines, data_labels, use_scilpy=args.use_longest_segment) - prepare_figure_connectivity(matrix) - plt.savefig(out_fig) - if args.hide_background is not None: idx = ordered_labels.index(args.hide_background) nb_hidden = np.sum(matrix[idx, :]) + np.sum(matrix[:, idx]) - \ matrix[idx, idx] if nb_hidden > 0: - logging.info("CAREFUL! Deleting from the matrix {} streamlines " - "with one or both endpoints in a non-labelled area " - "(background = {}; line/column {})" - .format(nb_hidden, args.hide_background, idx)) + logging.warning("CAREFUL! Deleting from the matrix {} streamlines " + "with one or both endpoints in a non-labelled " + "area (background = {}; line/column {})" + .format(nb_hidden, args.hide_background, idx)) + rejected = find_streamlines_with_chosen_connectivity( + in_sft.streamlines, start_labels, end_labels, idx) + logging.info("Saving rejected streamlines in {}" + .format(out_rejected_streamlines)) + sft = in_sft.from_sft(rejected, in_sft) + save_tractogram(sft, out_rejected_streamlines) else: logging.info("No streamlines with endpoints in the background :)") matrix[idx, :] = 0 @@ -125,8 +129,9 @@ def main(): ordered_labels[idx] = ("Hidden background ({})" .format(args.hide_background)) - prepare_figure_connectivity(matrix) - plt.savefig(out_fig_noback) + # Save figure will all versions of the matrix. + prepare_figure_connectivity(matrix) + plt.savefig(out_fig) if args.binary: matrix = matrix > 0 @@ -143,7 +148,7 @@ def main(): .format(matrix[i, j], ordered_labels[i], ordered_labels[j], i, j)) biggest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, i, j, start_labels, end_labels) sft = in_sft.from_sft(biggest, in_sft) save_tractogram(sft, args.save_biggest) @@ -155,7 +160,7 @@ def main(): .format(matrix[i, j], ordered_labels[i], ordered_labels[j], i, j)) smallest = find_streamlines_with_chosen_connectivity( - in_sft.streamlines, i, j, start_blocs, end_blocs) + in_sft.streamlines, i, j, start_labels, end_labels) sft = in_sft.from_sft(smallest, in_sft) save_tractogram(sft, args.save_smallest) From 8a63853d518468cc06554ff4f01b36d0ed570d3d Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 15 Feb 2024 16:07:32 -0500 Subject: [PATCH 12/23] Create hdf5: add connectivity labels volume --- dwi_ml/data/hdf5/hdf5_creation.py | 40 ++++++++++++++----------------- source/2_A_creating_the_hdf5.rst | 4 ++-- 2 files changed, 20 insertions(+), 24 deletions(-) diff --git a/dwi_ml/data/hdf5/hdf5_creation.py b/dwi_ml/data/hdf5/hdf5_creation.py index 34e7b1be..b1f2d170 100644 --- a/dwi_ml/data/hdf5/hdf5_creation.py +++ b/dwi_ml/data/hdf5/hdf5_creation.py @@ -11,6 +11,7 @@ from dipy.io.utils import is_header_compatible from dipy.tracking.utils import length import h5py +from scilpy.image.labels import get_data_as_labels from dwi_ml.data.hdf5.utils import format_nb_blocs_connectivity from dwi_ml.data.processing.streamlines.data_augmentation import \ @@ -45,7 +46,7 @@ def format_filelist(filenames, enforce_presence, folder=None) -> List[str]: else: logging.warning(msg) else: - new_files.extend(f) + new_files.extend(tmp) else: if not Path(f).is_file(): msg = "File not found: {}".format(f) @@ -80,8 +81,6 @@ def _load_and_verify_file(filename: str, subj_input_path, group_name: str, """ data_file = subj_input_path.joinpath(filename) - logging.info(" - Processing file {}".format(filename)) - if not data_file.is_file(): logging.debug(" Skipping file {} because it was not " "found in this subject's folder".format(filename)) @@ -309,7 +308,6 @@ def _verify_subjects_list(self): "testing set!".format(ignored_subj)) return unique_subjs - def _check_files_presence(self): """ Verifying now the list of files. Prevents stopping after a long @@ -334,6 +332,7 @@ def flatten_list(a_list): config_file_list = [ nested_lookup('files', self.groups_config), nested_lookup('connectivity_matrix', self.groups_config), + nested_lookup('connectivity_labels', self.groups_config), nested_lookup('std_mask', self.groups_config)] config_file_list = flatten_list(config_file_list) @@ -472,14 +471,11 @@ def _process_one_volume_group(self, group: str, subj_id: str, else: # Load subject's standardization mask. Can be a list of files. std_masks = self.groups_config[group]['std_mask'] - if isinstance(std_masks, str): - std_masks = [std_masks] - std_masks = format_filelist(std_masks, self.enforce_files_presence, folder=subj_input_dir) for mask in std_masks: - logging.info(" - Loading standardization mask {}" + logging.info(" - Loading standardization mask {}" .format(os.path.basename(mask))) sub_mask_data = nib.load(mask).get_fdata() > 0 if std_mask is None: @@ -492,7 +488,7 @@ def _process_one_volume_group(self, group: str, subj_id: str, folder=subj_input_dir) # First file will define data dimension and affine - logging.info(" - Processing file {}" + logging.info(" - Processing file {} (first file=reference) " .format(os.path.basename(file_list[0]))) group_data, group_affine, group_res, group_header = load_file_to4d( file_list[0]) @@ -513,7 +509,7 @@ def _process_one_volume_group(self, group: str, subj_id: str, group_affine, group_res) if std_option == 'per_file': - logging.debug(' *Standardizing sub-data') + logging.info(' - Standardizing') data = standardize_data(data, std_mask, independent=False) # Append file data to hdf group. @@ -526,11 +522,11 @@ def _process_one_volume_group(self, group: str, subj_id: str, # Standardize data (per channel) (if not done 'per_file' yet). if std_option == 'independent': - logging.debug(' *Standardizing data on each feature.') + logging.info(' - Standardizing data on each feature.') group_data = standardize_data(group_data, std_mask, independent=True) elif std_option == 'all': - logging.debug(' *Standardizing data as a whole.') + logging.info(' - Standardizing data as a whole.') group_data = standardize_data(group_data, std_mask, independent=False) elif std_option not in ['none', 'per_file']: @@ -596,9 +592,9 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, 'connectivity_matrix_type'] = conn_info[0] streamlines_group.create_dataset( 'connectivity_matrix', data=connectivity_matrix) - if conn_info[0] == 'from_label': - streamlines_group.attrs['connectivity_label_volume'] = \ - conn_info[1] + if conn_info[0] == 'from_labels': + streamlines_group.create_dataset( + 'connectivity_label_volume', data=conn_info[1]) else: streamlines_group.attrs['connectivity_nb_blocs'] = \ conn_info[1] @@ -606,7 +602,8 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id, if len(sft.data_per_point) > 0: logging.debug('sft contained data_per_point. Data not kept.') if len(sft.data_per_streamline) > 0: - logging.debug('sft contained data_per_streamlines. Data not kept.') + logging.debug('sft contained data_per_streamlines. Data not ' + 'kept.') # Accessing private Dipy values, but necessary. # We need to deconstruct the streamlines into arrays with @@ -720,12 +717,11 @@ def _process_one_streamline_group( nb_blocs = format_nb_blocs_connectivity( self.groups_config[group]['connectivity_nb_blocs']) conn_info = ['from_blocs', nb_blocs] - else: - labels_group = self.groups_config[group]['connectivity_labels'] - if labels_group not in self.volume_groups: - raise ValueError("connectivity_labels_volume must be " - "an existing volume group.") - conn_info = ['from_labels', labels_group] + else: # labels + labels_file = self.groups_config[group]['connectivity_labels'] + labels_file = os.path.join(subj_dir, labels_file) + labels_data = get_data_as_labels(nib.load(labels_file)) + conn_info = ['from_labels', labels_data] conn_file = subj_dir.joinpath( self.groups_config[group]['connectivity_matrix']) diff --git a/source/2_A_creating_the_hdf5.rst b/source/2_A_creating_the_hdf5.rst index ea7e1f55..e7a96156 100644 --- a/source/2_A_creating_the_hdf5.rst +++ b/source/2_A_creating_the_hdf5.rst @@ -83,7 +83,7 @@ To create the hdf5 file, you will need a config file such as below. HDF groups w "files": ["tractograms/bundle1.trk", "tractograms/wholebrain.trk", "tractograms/*__wholebrain.trk"], ----> Will get, for instance, sub1000__bundle1.trk "connectivity_matrix": "my_file.npy", "connectivity_nb_blocs": 6 ---> OR - "connectivity_labels": labels_volume + "connectivity_labels": labels_volume_group } "bad_streamlines": { "type": "streamlines", @@ -136,7 +136,7 @@ Additional attributes for streamlines groups: - **connectivity_matrix**: The name of the connectivity matrix to associate to the streamline group. This matrix will probably be used as a mean of validation during training. Then, you also need to explain how the matrix was created, so that you can create the connectivity matrix of the streamlines being validated, in order to compare it with the expected result. ONE of the two next options must be given: - **connectivity_nb_blocs**: This explains that the connectivity matrix was created by dividing the volume space into regular blocs. See dwiml_compute_connectivity_matrix_from_blocs for a description. The value should be either an integers or a list of three integers. - - **connectivity_labels_volume**: This explains that the connectivity matrix was created by dividing the cortex into a list of regions associated with labels. The value must be the name of another volume group in the same config file, which refers to a map with one label per region. NOTE: This option is offered in preparation of future use only. Currently, you can create the hdf5 with this option, but connectivity computation using labels is not yet implemented in dwi_ml. + - **connectivity_labels**: This explains that the connectivity matrix was created by dividing the cortex into a list of regions associated with labels. The value must be the name of the associated labels file (typically a nifti file filled with integers). 2.4. Creating the hdf5 ********************** From 9f1b9922527c842a234faeb13e4133a024ab4942 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Mon, 26 Feb 2024 14:19:21 -0500 Subject: [PATCH 13/23] Add connectivity labels volume in SFTData --- dwi_ml/data/dataset/streamline_containers.py | 75 +++++++++++++------ .../processing/streamlines/post_processing.py | 12 ++- .../training/with_generation/batch_loader.py | 7 +- dwi_ml/training/with_generation/trainer.py | 26 +++++-- 4 files changed, 80 insertions(+), 40 deletions(-) diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index 34241dfa..0c2991b7 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -46,6 +46,27 @@ def _load_all_streamlines_from_hdf(hdf_group: h5py.Group): return streamlines +def _load_connectivity_info(hdf_group: h5py.Group): + connectivity_nb_blocs = None + connectivity_labels = None + if 'connectivity_matrix' in hdf_group: + contains_connectivity = True + if 'connectivity_nb_blocs' in hdf_group.attrs: + connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs'] + elif 'connectivity_label_volume' in hdf_group: + connectivity_labels = np.asarray( + hdf_group['connectivity_label_volume'], dtype=int) + else: + raise ValueError( + "Information stored in the hdf5 is that it contains a " + "connectivity matrix, but we don't know how it was " + "created. Either 'connectivity_nb_blocs' or " + "'connectivity_labels' should be set.") + else: + contains_connectivity = False + return contains_connectivity, connectivity_nb_blocs, connectivity_labels + + class _LazyStreamlinesGetter(object): def __init__(self, hdf_group): self.hdf_group = hdf_group @@ -141,20 +162,30 @@ class SFTDataAbstract(object): """ def __init__(self, space_attributes: Tuple, space: Space, origin: Origin, contains_connectivity: bool, - connectivity_nb_blocs: List): + connectivity_nb_blocs: List = None, + connectivity_labels: np.ndarray = None): """ - Params - ------ - group: str - The current streamlines group id, as loaded in the hdf5 file (it - had type "streamlines"). Probabaly 'streamlines'. + The lazy/non-lazy versions will have more parameters, such as the + streamlines, the connectivity_matrix. In the case of the lazy version, + through the LazyStreamlinesGetter. + + Parameters + ---------- space_attributes: Tuple The space attributes consist of a tuple: (affine, dimensions, voxel_sizes, voxel_order) space: Space The space from dipy's Space format. - subject_id: str: - The subject's name + origin: Origin + The origin from dipy's Origin format. + contains_connectivity: bool + If true, will search for either the connectivity_nb_blocs or the + connectivity_from_labels information. + connectivity_nb_blocs: List + The information how to recreate the connectivity matrix. + connectivity_labels: np.ndarray + The 3D volume stating how to recreate the labels. + (toDo: Could be managed to be lazy) """ self.space_attributes = space_attributes self.space = space @@ -162,6 +193,7 @@ def __init__(self, space_attributes: Tuple, space: Space, origin: Origin, self.is_lazy = None self.contains_connectivity = contains_connectivity self.connectivity_nb_blocs = connectivity_nb_blocs + self.connectivity_labels = connectivity_labels def __len__(self): raise NotImplementedError @@ -195,7 +227,7 @@ def get_connectivity_matrix_and_info(self, ind=None): (_, ref_volume_shape, _, _) = self.space_attributes return (self._access_connectivity_matrix(ind), ref_volume_shape, - self.connectivity_nb_blocs) + self.connectivity_nb_blocs, self.connectivity_labels) def _access_connectivity_matrix(self, ind): raise NotImplementedError @@ -277,15 +309,14 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): streamlines = _load_all_streamlines_from_hdf(hdf_group) # Adding non-hidden parameters for nicer later access lengths_mm = hdf_group['euclidean_lengths'] - if 'connectivity_matrix' in hdf_group: - contains_connectivity = True - connectivity_matrix = np.asarray(hdf_group['connectivity_matrix'], - dtype=int) - connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs'] + + contains_connectivity, connectivity_nb_blocs, connectivity_labels = \ + _load_connectivity_info(hdf_group) + if contains_connectivity: + connectivity_matrix = np.asarray( + hdf_group['connectivity_matrix'], dtype=int) # int or bool? else: - contains_connectivity = False connectivity_matrix = None - connectivity_nb_blocs = None space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group) @@ -296,7 +327,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): space_attributes=space_attributes, space=space, origin=origin, contains_connectivity=contains_connectivity, - connectivity_nb_blocs=connectivity_nb_blocs) + connectivity_nb_blocs=connectivity_nb_blocs, + connectivity_labels=connectivity_labels) def _get_streamlines_as_list(self, streamline_ids): if streamline_ids is not None: @@ -337,12 +369,9 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): @classmethod def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group) - if 'connectivity_matrix' in hdf_group: - contains_connectivity = True - connectivity_nb_blocs = hdf_group.attrs['connectivity_nb_blocs'] - else: - contains_connectivity = False - connectivity_nb_blocs = None + + contains_connectivity, connectivity_nb_blocs, connectivity_labels = \ + _load_connectivity_info(hdf_group) streamlines = _LazyStreamlinesGetter(hdf_group) diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index aee4852b..1e546c3f 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -330,6 +330,10 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, Else, shape (nb_labels, nb_labels) labels: List The list of labels + start_labels: List + For each streamline, the label at starting point. + end_labels: List + For each streamline, the label at ending point. """ real_labels = list(np.sort(np.unique(data_labels))) nb_labels = len(real_labels) @@ -388,8 +392,7 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, return matrix, real_labels, start_labels, end_labels -def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs, - binary: bool = False): +def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs): """ Compute a connectivity matrix. @@ -405,8 +408,6 @@ def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs, In 3D, with 20x20x20, this is an 8000 x 8000 matrix (triangular). It probably contains a lot of zeros with the background being included. Can be saved as sparse. - binary: bool - If true, return a binary matrix. """ nb_blocs = np.asarray(nb_blocs) start_block, end_block = _compute_origin_finish_blocs( @@ -425,9 +426,6 @@ def compute_triu_connectivity_from_blocs(streamlines, volume_size, nb_blocs, matrix = np.triu(matrix) assert matrix.sum() == len(streamlines) - if binary: - matrix = matrix.astype(bool) - return matrix, start_block, end_block diff --git a/dwi_ml/training/with_generation/batch_loader.py b/dwi_ml/training/with_generation/batch_loader.py index 6fca9851..77db3f8e 100644 --- a/dwi_ml/training/with_generation/batch_loader.py +++ b/dwi_ml/training/with_generation/batch_loader.py @@ -24,6 +24,7 @@ def load_batch_connectivity_matrices( matrices = [None] * nb_subjs volume_sizes = [None] * nb_subjs connectivity_nb_blocs = [None] * nb_subjs + connectivity_labels = [None] * nb_subjs for i, subj in enumerate(subjs): # No cache for the sft data. Accessing it directly. # Note: If this is used through the dataloader, multiprocessing @@ -34,7 +35,9 @@ def load_batch_connectivity_matrices( # We could access it only at required index, maybe. Loading the # whole matrix here. - matrices[i], volume_sizes[i], connectivity_nb_blocs[i] = \ + (matrices[i], volume_sizes[i], + connectivity_nb_blocs[i], connectivity_labels[i]) = \ subj_sft_data.get_connectivity_matrix_and_info() - return matrices, volume_sizes, connectivity_nb_blocs + return (matrices, volume_sizes, + connectivity_nb_blocs, connectivity_labels) diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/with_generation/trainer.py index d516b16e..5288639d 100644 --- a/dwi_ml/training/with_generation/trainer.py +++ b/dwi_ml/training/with_generation/trainer.py @@ -43,7 +43,7 @@ from torch.nn import PairwiseDistance from dwi_ml.data.processing.streamlines.post_processing import \ - compute_triu_connectivity_from_blocs + compute_triu_connectivity_from_blocs, compute_triu_connectivity_from_labels from dwi_ml.models.main_models import ModelWithDirectionGetter from dwi_ml.tracking.propagation import propagate_multiple_lines from dwi_ml.tracking.io_utils import prepare_tracking_mask @@ -317,7 +317,8 @@ def _compare_connectivity(self, lines, ids_per_subj): compares with expected values for the subject. """ if self.compute_connectivity: - connectivity_matrices, volume_sizes, connectivity_nb_blocs = \ + (connectivity_matrices, volume_sizes, + connectivity_nb_blocs, connectivity_labels) = \ self.batch_loader.load_batch_connectivity_matrices(ids_per_subj) score = 0.0 @@ -325,20 +326,29 @@ def _compare_connectivity(self, lines, ids_per_subj): real_matrix = connectivity_matrices[i] volume_size = volume_sizes[i] nb_blocs = connectivity_nb_blocs[i] + labels = connectivity_labels[i] _lines = lines[ids_per_subj[subj]] - batch_matrix, _, _ = compute_triu_connectivity_from_blocs( - _lines, volume_size, nb_blocs, binary=False) + # Reference matrices are saved as binary in create_hdf5, + # but still. Ensuring. + real_matrix = real_matrix > 0 + + # But our matrix here won't be! + if nb_blocs is not None: + batch_matrix, _, _ = compute_triu_connectivity_from_blocs( + _lines, volume_size, nb_blocs) + else: + # ToDo. Allow use_scilpy? + batch_matrix, _, _ = compute_triu_connectivity_from_labels( + _lines, labels, use_scilpy=False) # Where our batch has a 0: not important, maybe it was simply # not in this batch. # Where our batch has a 1, if there was really a one: score - # should be 0. = 1 - 1. - # Else, score should be high (1). = 1 - 0. + # should be 0. = 1 - 1 = 1 - real + # Else, score should be high (1). = 1 - 0 = 1 - real # If two streamlines have the same connection, score is # either 0 or 2 for that voxel. ==> nb * (1 - real). - - # Reference matrices are saved as binary in create_hdf5. where_one = np.where(batch_matrix > 0) score += np.sum(batch_matrix[where_one] * (1.0 - real_matrix[where_one])) From 06fda9e788a76515f871ab94f05c7a4faf1b39f6 Mon Sep 17 00:00:00 2001 From: Emmanuelle Renauld Date: Tue, 27 Feb 2024 10:28:20 -0500 Subject: [PATCH 14/23] Fix unncessarty call of iterator --- dwi_ml/training/trainers.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index 7d8b45a5..f2936e91 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -798,12 +798,6 @@ def train_one_epoch(self, epoch): train_iterator = enumerate(pbar) for batch_id, data in train_iterator: - # Break if maximum number of batches has been reached - if batch_id == self.nb_batches_train: - # Explicitly close tqdm's progress bar to fix possible bugs - # when breaking the loop - pbar.close() - break # Enable gradients for backpropagation. Uses torch's module # train(), which "turns on" the training mode. @@ -814,6 +808,15 @@ def train_one_epoch(self, epoch): self.unclipped_grad_norm_monitor.update(unclipped_grad_norm) self.grad_norm_monitor.update(grad_norm) + # Break if maximum number of batches has been reached + # Break before calling the next train_iterator because it would load + # the batch. + if batch_id == self.nb_batches_train - 1: + # Explicitly close tqdm's progress bar to fix possible bugs + # when breaking the loop + pbar.close() + break + # Explicitly delete iterator to kill threads and free memory before # running validation del train_iterator @@ -854,17 +857,19 @@ def validate_one_epoch(self, epoch): tqdm_class=tqdm) as pbar: valid_iterator = enumerate(pbar) for batch_id, data in valid_iterator: + + # Validate this batch: forward propagation + loss + with torch.no_grad(): + self.validate_one_batch(data, epoch) + # Break if maximum number of epochs has been reached - if batch_id == self.nb_batches_valid: + # Break before calling the next valid_iterator to avoid loading batch + if batch_id == self.nb_batches_valid - 1: # Explicitly close tqdm's progress bar to fix possible bugs # when breaking the loop pbar.close() break - # Validate this batch: forward propagation + loss - with torch.no_grad(): - self.validate_one_batch(data, epoch) - # Explicitly delete iterator to kill threads and free memory before # running training again del valid_iterator From 3f540a4a0c211239422f4f19eeab8de301bff58d Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 27 Feb 2024 14:51:16 -0500 Subject: [PATCH 15/23] Fix l2t propagation during validation for many subjs --- dwi_ml/models/projects/learn2track_model.py | 21 +++--- .../tracking/projects/learn2track_tracker.py | 2 +- .../training/projects/learn2track_trainer.py | 74 ++++++++++++------- 3 files changed, 59 insertions(+), 38 deletions(-) diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py index 00d57332..dd3e9e85 100644 --- a/dwi_ml/models/projects/learn2track_model.py +++ b/dwi_ml/models/projects/learn2track_model.py @@ -218,7 +218,7 @@ def computed_params_for_display(self): def forward(self, x: List[torch.tensor], input_streamlines: List[torch.tensor] = None, - hidden_recurrent_states: tuple = None, return_hidden=False, + hidden_recurrent_states: List = None, return_hidden=False, point_idx: int = None): """Run the model on a batch of sequences. @@ -234,7 +234,7 @@ def forward(self, x: List[torch.tensor], Batch of streamlines. Only used if previous directions are added to the model. Used to compute directions; its last point will not be used. - hidden_recurrent_states : tuple + hidden_recurrent_states : list[states] The current hidden states of the (stacked) RNN model. return_hidden: bool point_idx: int @@ -442,8 +442,7 @@ def copy_prev_dir(self, dirs): return copy_prev_dir - def remove_lines_in_hidden_state( - self, hidden_recurrent_states, lines_to_keep): + def take_lines_in_hidden_state(self, hidden_states, lines_to_keep): """ Utilitary method to remove a few streamlines from the hidden state. @@ -451,14 +450,12 @@ def remove_lines_in_hidden_state( if self.rnn_model.rnn_torch_key == 'lstm': # LSTM: For each layer, states are tuples; (h_t, C_t) # Size of tensors are each [1, nb_streamlines, nb_neurons] - hidden_recurrent_states = [ - (layer_states[0][:, lines_to_keep, :], - layer_states[1][:, lines_to_keep, :]) for - layer_states in hidden_recurrent_states] + hidden_states = [(layer_states[0][:, lines_to_keep, :], + layer_states[1][:, lines_to_keep, :]) for + layer_states in hidden_states] else: # GRU: For each layer, states are tensors; h_t. # Size of tensors are [1, nb_streamlines, nb_neurons]. - hidden_recurrent_states = [ - layer_states[:, lines_to_keep, :] for - layer_states in hidden_recurrent_states] - return hidden_recurrent_states + hidden_states = [layer_states[:, lines_to_keep, :] for + layer_states in hidden_states] + return hidden_states diff --git a/dwi_ml/tracking/projects/learn2track_tracker.py b/dwi_ml/tracking/projects/learn2track_tracker.py index 06ec3216..76336ac5 100644 --- a/dwi_ml/tracking/projects/learn2track_tracker.py +++ b/dwi_ml/tracking/projects/learn2track_tracker.py @@ -89,5 +89,5 @@ def update_memory_after_removing_lines(self, can_continue: np.ndarray, _): Indexes of lines that are kept. """ # Hidden states: list[states] (One value per layer). - self.hidden_recurrent_states = self.model.remove_lines_in_hidden_state( + self.hidden_recurrent_states = self.model.take_lines_in_hidden_state( self.hidden_recurrent_states, can_continue) diff --git a/dwi_ml/training/projects/learn2track_trainer.py b/dwi_ml/training/projects/learn2track_trainer.py index d6eb8a10..ee7275b9 100644 --- a/dwi_ml/training/projects/learn2track_trainer.py +++ b/dwi_ml/training/projects/learn2track_trainer.py @@ -28,44 +28,67 @@ def __init__(self, **kwargs): def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): assert self.model.step_size is not None, \ "We can't propagate compressed streamlines." + theta = 2 * np.pi # theta = 360 degrees + max_nbr_pts = int(200 / self.model.step_size) - # Running the beginning of the streamlines to get the hidden states - # (using one less point. The next will be done during propagation). - if self.tracking_phase_nb_segments_init > 0: - tmp_lines = [line[:-1, :] for line in lines] - inputs = self.batch_loader.load_batch_inputs(tmp_lines, ids_per_subj) - _, hidden_states = self.model(inputs, tmp_lines, return_hidden=True) - del tmp_lines, inputs - else: - hidden_states = None - + # These methods will be used during the loop on subjects def update_memory_after_removing_lines(can_continue: np.ndarray, _): - nonlocal hidden_states - hidden_states = self.model.remove_lines_in_hidden_state( - hidden_states, can_continue) + nonlocal subjs_hidden_states + nonlocal subj_batch_nb + subjs_hidden_states[subj_batch_nb] = ( + self.model.take_lines_in_hidden_state( + subjs_hidden_states[subj_batch_nb], can_continue)) def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): - nonlocal hidden_states + # Get dirs for current subject: run model + nonlocal subjs_hidden_states + nonlocal subj_idx + nonlocal subj_batch_nb n_last_pos = [pos[None, :] for pos in n_last_pos] - batch_inputs = self.batch_loader.load_batch_inputs(n_last_pos, - ids_per_subj) - - model_outputs, hidden_states = self.model( - batch_inputs, _lines, hidden_recurrent_states=hidden_states, + # Amongst the current batch of streamlines (n_pos), the ones + # belonging to current subject are: all of them! + subj_dict = {subj_idx: slice(0, len(n_last_pos))} + subj_inputs = self.batch_loader.load_batch_inputs( + n_last_pos, subj_dict) + + model_outputs, subjs_hidden_states[subj_batch_nb] = self.model( + subj_inputs, _lines, + hidden_recurrent_states=subjs_hidden_states[subj_batch_nb], return_hidden=True, point_idx=-1) next_dirs = self.model.get_tracking_directions( model_outputs, algo='det', eos_stopping_thresh=0.5) return next_dirs - theta = 2 * np.pi # theta = 360 degrees - max_nbr_pts = int(200 / self.model.step_size) + # Running the beginning of the streamlines to get the hidden states + # (using one less point. The next will be done during propagation). + # Here, subjs_hidden_states will be a list of hidden_states per subj. + if self.tracking_phase_nb_segments_init > 0: + tmp_lines = [line[:-1, :] for line in lines] + inputs = self.batch_loader.load_batch_inputs( + tmp_lines, ids_per_subj) + _, whole_hidden_states = self.model(inputs, tmp_lines, + return_hidden=True) + + subjs_hidden_states = [ + self.model.take_lines_in_hidden_state(whole_hidden_states, + subj_slice) + for subj, subj_slice in ids_per_subj.items()] + del tmp_lines, inputs, whole_hidden_states + else: + subjs_hidden_states = None + # Running the propagation separately for each subject + # (because they all need their own tracking mask) final_lines = [] - for subj_idx, line_idx in ids_per_subj.items(): + subj_batch_nb = -1 + for subj_idx, subj_line_idx_slice in ids_per_subj.items(): + subj_batch_nb += 1 # Will be used as non-local in methods above - with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + # Load the subject's tracking mask + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r' + ) as hdf_handle: subj_id = self.batch_loader.context_subset.subjects[subj_idx] logging.debug("Loading subj {} ({})'s tracking mask." .format(subj_idx, subj_id)) @@ -74,9 +97,10 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): mask_interp='nearest') tracking_mask.move_to(self.device) + # Propagates all lines for this subject final_lines.extend(propagate_multiple_lines( - lines[line_idx], update_memory_after_removing_lines, - get_dirs_at_last_pos, theta=theta, + lines[subj_line_idx_slice], update_memory_after_removing_lines, + get_next_dirs=get_dirs_at_last_pos, theta=theta, step_size=self.model.step_size, verify_opposite_direction=False, mask=tracking_mask, max_nbr_pts=max_nbr_pts, append_last_point=False, normalize_directions=True)) From 45a6245cb43bb634d3201d58425599f56ae0f878 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Tue, 27 Feb 2024 14:56:33 -0500 Subject: [PATCH 16/23] Fix return of labels. --- dwi_ml/data/dataset/streamline_containers.py | 3 ++- dwi_ml/training/with_generation/trainer.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index 0c2991b7..e7b2c894 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -379,7 +379,8 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): space_attributes=space_attributes, space=space, origin=origin, contains_connectivity=contains_connectivity, - connectivity_nb_blocs=connectivity_nb_blocs) + connectivity_nb_blocs=connectivity_nb_blocs, + connectivity_labels=connectivity_labels) def _get_streamlines_as_list(self, streamline_ids): streamlines = self.streamlines_getter.get_array_sequence(streamline_ids) diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/with_generation/trainer.py index 5288639d..263cbf7e 100644 --- a/dwi_ml/training/with_generation/trainer.py +++ b/dwi_ml/training/with_generation/trainer.py @@ -317,6 +317,8 @@ def _compare_connectivity(self, lines, ids_per_subj): compares with expected values for the subject. """ if self.compute_connectivity: + # toDo. See if it's too much to keep them all in memory. Could be + # done in the loop for each subject. (connectivity_matrices, volume_sizes, connectivity_nb_blocs, connectivity_labels) = \ self.batch_loader.load_batch_connectivity_matrices(ids_per_subj) @@ -329,6 +331,9 @@ def _compare_connectivity(self, lines, ids_per_subj): labels = connectivity_labels[i] _lines = lines[ids_per_subj[subj]] + # Move to cpu, numpy now. + _lines = [line.cpu().numpy() for line in _lines] + # Reference matrices are saved as binary in create_hdf5, # but still. Ensuring. real_matrix = real_matrix > 0 @@ -338,9 +343,10 @@ def _compare_connectivity(self, lines, ids_per_subj): batch_matrix, _, _ = compute_triu_connectivity_from_blocs( _lines, volume_size, nb_blocs) else: - # ToDo. Allow use_scilpy? - batch_matrix, _, _ = compute_triu_connectivity_from_labels( - _lines, labels, use_scilpy=False) + # Note: scilpy usage not ready! Simple endpoints position + batch_matrix, _, _, _ =\ + compute_triu_connectivity_from_labels( + _lines, labels, use_scilpy=False) # Where our batch has a 0: not important, maybe it was simply # not in this batch. From de0d20d9556f7ca3a91a12515ca999c56c1ea935 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 28 Feb 2024 08:39:25 -0500 Subject: [PATCH 17/23] Change logging level in compute connectivity --- dwi_ml/data/dataset/streamline_containers.py | 6 ++++-- dwi_ml/data/processing/streamlines/post_processing.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py index e7b2c894..d8c370ac 100644 --- a/dwi_ml/data/dataset/streamline_containers.py +++ b/dwi_ml/data/dataset/streamline_containers.py @@ -368,7 +368,8 @@ def _access_connectivity_matrix(self, indxyz: Tuple = None): @classmethod def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): - space_attributes, space, origin = _load_space_attributes_from_hdf(hdf_group) + space_attributes, space, origin = _load_space_attributes_from_hdf( + hdf_group) contains_connectivity, connectivity_nb_blocs, connectivity_labels = \ _load_connectivity_info(hdf_group) @@ -383,5 +384,6 @@ def init_sft_data_from_hdf_info(cls, hdf_group: h5py.Group): connectivity_labels=connectivity_labels) def _get_streamlines_as_list(self, streamline_ids): - streamlines = self.streamlines_getter.get_array_sequence(streamline_ids) + streamlines = self.streamlines_getter.get_array_sequence( + streamline_ids) return streamlines diff --git a/dwi_ml/data/processing/streamlines/post_processing.py b/dwi_ml/data/processing/streamlines/post_processing.py index 1e546c3f..21356b83 100644 --- a/dwi_ml/data/processing/streamlines/post_processing.py +++ b/dwi_ml/data/processing/streamlines/post_processing.py @@ -337,8 +337,8 @@ def compute_triu_connectivity_from_labels(streamlines, data_labels, """ real_labels = list(np.sort(np.unique(data_labels))) nb_labels = len(real_labels) - logging.info("Computing connectivity matrix for {} labels." - .format(nb_labels)) + logging.debug("Computing connectivity matrix for {} labels." + .format(nb_labels)) if use_scilpy: matrix = np.zeros((nb_labels + 1, nb_labels + 1), dtype=int) From 2e350cd5e6eac5037a704703318bd54d0b90c5fc Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 28 Feb 2024 08:42:30 -0500 Subject: [PATCH 18/23] Fix a few lines too long --- dwi_ml/training/with_generation/trainer.py | 44 +++++++++++++--------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/dwi_ml/training/with_generation/trainer.py b/dwi_ml/training/with_generation/trainer.py index 263cbf7e..bac638db 100644 --- a/dwi_ml/training/with_generation/trainer.py +++ b/dwi_ml/training/with_generation/trainer.py @@ -25,8 +25,8 @@ - Connectivity fit: Percentage of streamlines ending in a block of the volume indeed connected in the validation subject. Real connectivity matrices must be saved in the - hdf5. Right now, volumes are simply split into blocs (the same way as in the - hdf5, ex, to 10x10x10 volumes for a total of 1000 blocks), not based on + hdf5. Right now, volumes are simply split into blocs (the same way as in + the hdf5, ex, to 10x10x10 volumes for a total of 1000 blocks), not based on anatomical ROIs. It has the advantage that it does not rely on the quality of segmentation. It had the drawback that a generated streamline ending very close to the "true" streamline, but in another block, if the @@ -175,9 +175,10 @@ def validate_one_batch(self, data, epoch): logger.debug("Additional tracking-like generation validation " "from batch.") (gen_n, mean_final_dist, mean_clipped_final_dist, - percent_IS_very_good, percent_IS_acceptable, percent_IS_very_far, - diverging_pnt, connectivity) = self.validation_generation_one_batch( - data, compute_all_scores=True) + percent_IS_very_good, percent_IS_acceptable, + percent_IS_very_far, diverging_pnt, connectivity) = \ + self.validation_generation_one_batch( + data, compute_all_scores=True) self.tracking_very_good_IS_monitor.update( percent_IS_very_good, weight=gen_n) @@ -196,8 +197,9 @@ def validate_one_batch(self, data, epoch): self.tracking_connectivity_score_monitor.update( connectivity, weight=gen_n) elif len(self.tracking_mean_final_distance_monitor.average_per_epoch) == 0: - logger.info("Skipping tracking-like generation validation from " - "batch. No values yet: adding fake initial values.") + logger.info("Skipping tracking-like generation validation " + "from batch. No values yet: adding fake initial " + "values.") # Fake values at the beginning # Bad IS = 100% self.tracking_very_good_IS_monitor.update(100.0) @@ -216,8 +218,8 @@ def validate_one_batch(self, data, epoch): self.tracking_connectivity_score_monitor.update(1) else: - logger.info("Skipping tracking-like generation validation from " - "batch. Copying previous epoch's values.") + logger.info("Skipping tracking-like generation validation " + "from batch. Copying previous epoch's values.") # Copy previous value for monitor in [self.tracking_very_good_IS_monitor, self.tracking_acceptable_IS_monitor, @@ -238,7 +240,8 @@ def validation_generation_one_batch(self, data, compute_all_scores=False): # Possibly sending again to GPU even if done in the local loss # computation, but easier with current implementation. - real_lines = [line.to(self.device, non_blocking=True, dtype=torch.float) + real_lines = [line.to(self.device, non_blocking=True, + dtype=torch.float) for line in real_lines] last_pos = torch.vstack([line[-1, :] for line in real_lines]) @@ -267,7 +270,8 @@ def validation_generation_one_batch(self, data, compute_all_scores=False): final_dist_clipped = torch.mean(final_dist_clipped) # 2. Connectivity scores, if available (else None) - connectivity_score = self._compare_connectivity(lines, ids_per_subj) + connectivity_score = self._compare_connectivity(lines, + ids_per_subj) # 3. "IS ratio", i.e. percentage of streamlines ending inside a # predefined radius. @@ -280,9 +284,9 @@ def validation_generation_one_batch(self, data, compute_all_scores=False): final_dist = torch.mean(final_dist) # 4. Verify point where streamline starts diverging. - # abs(100 - score): 0 = good. 100 = bad (either abs(100) -> diverged - # at first point or abs(-100) = diverged after twice the expected - # length. + # abs(100 - score): 0 = good. 100 = bad (either + # abs(100) -> diverged at first point or + # abs(-100) = diverged after twice the expected length. total_point = 0 for line, real_line in zip(lines, real_lines): expected_nb = len(real_line) @@ -321,7 +325,8 @@ def _compare_connectivity(self, lines, ids_per_subj): # done in the loop for each subject. (connectivity_matrices, volume_sizes, connectivity_nb_blocs, connectivity_labels) = \ - self.batch_loader.load_batch_connectivity_matrices(ids_per_subj) + self.batch_loader.load_batch_connectivity_matrices( + ids_per_subj) score = 0.0 for i, subj in enumerate(ids_per_subj.keys()): @@ -365,7 +370,8 @@ def _compare_connectivity(self, lines, ids_per_subj): score = None return score - def propagate_multiple_lines(self, lines: List[torch.Tensor], ids_per_subj): + def propagate_multiple_lines(self, lines: List[torch.Tensor], + ids_per_subj): """ Tractography propagation of 'lines'. """ @@ -396,7 +402,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): # accept multiple masks or manage it differently. final_lines = [] for subj_idx, line_idx in ids_per_subj.items(): - with h5py.File(self.batch_loader.dataset.hdf5_file, 'r') as hdf_handle: + with h5py.File(self.batch_loader.dataset.hdf5_file, 'r' + ) as hdf_handle: subj_id = self.batch_loader.context_subset.subjects[subj_idx] logging.debug("Loading subj {} ({})'s tracking mask." .format(subj_idx, subj_id)) @@ -408,7 +415,8 @@ def get_dirs_at_last_pos(_lines: List[torch.Tensor], n_last_pos): final_lines.extend(propagate_multiple_lines( lines[line_idx], update_memory_after_removing_lines, get_dirs_at_last_pos, theta=theta, - step_size=self.model.step_size, verify_opposite_direction=False, + step_size=self.model.step_size, + verify_opposite_direction=False, mask=tracking_mask, max_nbr_pts=max_nbr_pts, append_last_point=False, normalize_directions=True)) From ed053638c37d7091a43e0927dba745c815ac915d Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Wed, 28 Feb 2024 09:02:38 -0500 Subject: [PATCH 19/23] More line too longs. Fix tqdm appearance --- dwi_ml/training/trainers.py | 49 +++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/dwi_ml/training/trainers.py b/dwi_ml/training/trainers.py index f2936e91..ed28fb47 100644 --- a/dwi_ml/training/trainers.py +++ b/dwi_ml/training/trainers.py @@ -14,7 +14,8 @@ from dwi_ml.experiment_utils.memory import log_gpu_memory_usage from dwi_ml.experiment_utils.tqdm_logging import tqdm_logging_redirect -from dwi_ml.models.main_models import MainModelAbstract, ModelWithDirectionGetter +from dwi_ml.models.main_models import (MainModelAbstract, + ModelWithDirectionGetter) from dwi_ml.training.batch_loaders import ( DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput) from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler @@ -220,8 +221,10 @@ def __init__(self, os.mkdir(self.saving_path) os.mkdir(self.log_dir) - assert np.all([param == self.batch_loader.dataset.params_for_checkpoint[key] for - key, param in self.batch_sampler.dataset.params_for_checkpoint.items()]) + assert np.all( + [param == self.batch_loader.dataset.params_for_checkpoint[key] for + key, param in + self.batch_sampler.dataset.params_for_checkpoint.items()]) if self.batch_sampler.dataset.validation_set.nb_subjects == 0: self.use_validation = False logger.warning( @@ -324,7 +327,8 @@ def __init__(self, cls = torch.optim.SGD # Learning rate will be set at each epoch. - self.optimizer = cls(self.model.parameters(), weight_decay=weight_decay) + self.optimizer = cls(self.model.parameters(), + weight_decay=weight_decay) @property def params_for_checkpoint(self): @@ -373,7 +377,8 @@ def save_params_to_json(self): }, indent=4, separators=(',', ': '))) - json_filename2 = os.path.join(self.saving_path, "parameters_latest.json") + json_filename2 = os.path.join(self.saving_path, + "parameters_latest.json") shutil.copyfile(json_filename, json_filename2) def save_checkpoint(self): @@ -486,7 +491,8 @@ def init_from_checkpoint( return trainer @staticmethod - def load_params_from_checkpoint(experiments_path: str, experiment_name: str): + def load_params_from_checkpoint(experiments_path: str, + experiment_name: str): total_path = os.path.join( experiments_path, experiment_name, "checkpoint", "checkpoint_state.pkl") @@ -619,7 +625,8 @@ def train_and_validate(self): - For each epoch - uses _train_one_epoch and _validate_one_epoch, - saves a checkpoint, - - checks for earlyStopping if the loss is bad or patience is reached, + - checks for earlyStopping if the loss is bad or patience is + reached, - saves the model if the loss is good. """ logger.debug("Trainer {}: \nRunning the model {}.\n\n" @@ -690,7 +697,8 @@ def train_and_validate(self): if self.comet_exp: self.comet_exp.log_metric( - "best_loss", self.best_epoch_monitor.best_value, step=epoch) + "best_loss", self.best_epoch_monitor.best_value, + step=epoch) # End of epoch, save checkpoint for resuming later self.save_checkpoint() @@ -804,14 +812,21 @@ def train_one_epoch(self, epoch): with grad_context(): mean_loss = self.train_one_batch(data) - unclipped_grad_norm, grad_norm = self.back_propagation(mean_loss) + unclipped_grad_norm, grad_norm = self.back_propagation( + mean_loss) self.unclipped_grad_norm_monitor.update(unclipped_grad_norm) self.grad_norm_monitor.update(grad_norm) # Break if maximum number of batches has been reached - # Break before calling the next train_iterator because it would load - # the batch. if batch_id == self.nb_batches_train - 1: + # Explicitly breaking the loop here, else it calls the + # train_dataloader one more time, which samples a new + # batch that is not used (if we have not finished sampling + # all after nb_batches). + # Sending one more step to the tqdm bar, else it finishes + # at nb - 1. + pbar.update(1) + # Explicitly close tqdm's progress bar to fix possible bugs # when breaking the loop pbar.close() @@ -863,8 +878,15 @@ def validate_one_epoch(self, epoch): self.validate_one_batch(data, epoch) # Break if maximum number of epochs has been reached - # Break before calling the next valid_iterator to avoid loading batch if batch_id == self.nb_batches_valid - 1: + # Explicitly breaking the loop here, else it calls the + # train_dataloader one more time, which samples a new + # batch that is not used (if we have not finished sampling + # all after nb_batches). + # Sending one more step to the tqdm bar, else it finishes + # at nb - 1. + pbar.update(1) + # Explicitly close tqdm's progress bar to fix possible bugs # when breaking the loop pbar.close() @@ -992,7 +1014,8 @@ def fix_parameters(self): unclipped_grad_norm = torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.clip_grad) else: - unclipped_grad_norm = compute_gradient_norm(self.model.parameters()) + unclipped_grad_norm = compute_gradient_norm( + self.model.parameters()) if torch.isnan(unclipped_grad_norm): raise ValueError("Exploding gradients. Experiment failed.") From a609b9477a2d5380d3eee4af68e2adf0993c4586 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 29 Feb 2024 09:01:53 -0500 Subject: [PATCH 20/23] Update scilpy script --- bash_utilities/scil_score_ismrm_Renauld2023.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bash_utilities/scil_score_ismrm_Renauld2023.sh b/bash_utilities/scil_score_ismrm_Renauld2023.sh index 7de205b2..d788b3bf 100644 --- a/bash_utilities/scil_score_ismrm_Renauld2023.sh +++ b/bash_utilities/scil_score_ismrm_Renauld2023.sh @@ -54,7 +54,7 @@ then fi echo '------------- FINAL SCORING ------------' -scil_score_bundles.py -v $config_file_tractometry $out_dir \ +scil_bundle_score_many_bundles_one_tractogram.py -v $config_file_tractometry $out_dir \ --gt_dir $scoring_data --reference $ref --no_bbox_check cat $out_dir/results.json \ No newline at end of file From 8d343ba2bb843fbe7157cecd24c8cc8a33f6cb59 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 29 Feb 2024 15:00:53 -0500 Subject: [PATCH 21/23] Remove new logging from scilpy. Warning max for root logger --- .../l2t_resume_training_from_checkpoint.py | 5 +- scripts_python/l2t_train_from_pretrained.py | 141 ------------------ scripts_python/l2t_train_model.py | 3 +- .../tt_resume_training_from_checkpoint.py | 5 +- scripts_python/tt_train_model.py | 4 +- 5 files changed, 9 insertions(+), 149 deletions(-) delete mode 100644 scripts_python/l2t_train_from_pretrained.py diff --git a/scripts_python/l2t_resume_training_from_checkpoint.py b/scripts_python/l2t_resume_training_from_checkpoint.py index d4f91b54..26e6cb2d 100644 --- a/scripts_python/l2t_resume_training_from_checkpoint.py +++ b/scripts_python/l2t_resume_training_from_checkpoint.py @@ -80,9 +80,8 @@ def main(): p = prepare_arg_parser() args = p.parse_args() - # Setting root logger with high level, but we will set trainer to - # user-defined level. - logging.getLogger().setLevel(level=logging.INFO) + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Verify if a checkpoint has been saved. checkpoint_path = os.path.join( diff --git a/scripts_python/l2t_train_from_pretrained.py b/scripts_python/l2t_train_from_pretrained.py deleted file mode 100644 index 14b3ff56..00000000 --- a/scripts_python/l2t_train_from_pretrained.py +++ /dev/null @@ -1,141 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -""" -Train a model for Learn2Track -""" -import argparse -import logging -import os - -# comet_ml not used, but comet_ml requires to be imported before torch. -# See bug report here https://github.com/Lightning-AI/lightning/issues/5829 -# Importing now to solve issues later. -import comet_ml -import torch - -from scilpy.io.utils import assert_inputs_exist, assert_outputs_exist - -from dwi_ml.data.dataset.utils import prepare_multisubjectdataset -from dwi_ml.experiment_utils.prints import format_dict_to_str -from dwi_ml.experiment_utils.timer import Timer -from dwi_ml.io_utils import add_logging_arg, add_memory_args -from dwi_ml.models.projects.learn2track_model import Learn2TrackModel -from dwi_ml.training.projects.learn2track_trainer import Learn2TrackTrainer -from dwi_ml.training.utils.batch_samplers import (add_args_batch_sampler, - prepare_batch_sampler) -from dwi_ml.training.utils.batch_loaders import (add_args_batch_loader, - prepare_batch_loader) -from dwi_ml.training.utils.experiment import ( - add_mandatory_args_experiment_and_hdf5_path) -from dwi_ml.training.utils.trainer import run_experiment, add_training_args, \ - format_lr - - -def prepare_arg_parser(): - p = argparse.ArgumentParser(description=__doc__, - formatter_class=argparse.RawTextHelpFormatter) - add_mandatory_args_experiment_and_hdf5_path(p) - p.add_argument('pretrained_model', - help="Name of the pretrained experiment (from the same " - "experiments path) from which to load the model. " - "Should contain a 'best_model' folder with pickle " - "information to load the model") - add_args_batch_sampler(p) - add_args_batch_loader(p) - training_group = add_training_args(p, add_a_tracking_validation_phase=True) - add_memory_args(p, add_lazy_options=True, add_rng=True) - add_logging_arg(p) - - # Additional arg for projects - training_group.add_argument( - '--clip_grad', type=float, default=None, - help="Value to which the gradient norms to avoid exploding gradients." - "\nDefault = None (not clipping).") - - return p - - -def init_from_args(args, sub_loggers_level): - torch.manual_seed(args.rng) # Set torch seed - - # Prepare the dataset - dataset = prepare_multisubjectdataset(args, load_testing=False, - log_level=sub_loggers_level) - - # Loading an existing model - logging.info("Loading existing model") - best_model_path = os.path.join(args.experiments_path, - args.pretrained_model, 'best_model') - model = Learn2TrackModel.load_model_from_params_and_state( - best_model_path, sub_loggers_level) - - # Preparing the batch samplers - batch_sampler = prepare_batch_sampler(dataset, args, sub_loggers_level) - batch_loader = prepare_batch_loader(dataset, model, args, sub_loggers_level) - - # Instantiate trainer - with Timer("\n\nPreparing trainer", newline=True, color='red'): - lr = format_lr(args.learning_rate) - trainer = Learn2TrackTrainer( - model=model, experiments_path=args.experiments_path, - experiment_name=args.experiment_name, batch_sampler=batch_sampler, - batch_loader=batch_loader, - # COMET - comet_project=args.comet_project, - comet_workspace=args.comet_workspace, - # TRAINING - learning_rates=lr, weight_decay=args.weight_decay, - optimizer=args.optimizer, max_epochs=args.max_epochs, - max_batches_per_epoch_training=args.max_batches_per_epoch_training, - max_batches_per_epoch_validation=args.max_batches_per_epoch_validation, - patience=args.patience, patience_delta=args.patience_delta, - from_checkpoint=False, clip_grad=args.clip_grad, - # (generation validation:) - add_a_tracking_validation_phase=args.add_a_tracking_validation_phase, - tracking_phase_frequency=args.tracking_phase_frequency, - tracking_phase_nb_segments_init=args.tracking_phase_nb_segments_init, - tracking_phase_mask_group=args.tracking_mask, - # MEMORY - nb_cpu_processes=args.nbr_processes, use_gpu=args.use_gpu, - log_level=args.logging) - logging.info("Trainer params : " + - format_dict_to_str(trainer.params_for_checkpoint)) - - return trainer - - -def main(): - p = prepare_arg_parser() - args = p.parse_args() - - # Setting log level to INFO maximum for sub-loggers, else it becomes ugly, - # but we will set trainer to user-defined level. - sub_loggers_level = args.logging - if args.logging == 'DEBUG': - sub_loggers_level = 'INFO' - - logging.getLogger().setLevel(level=logging.INFO) - - # Check that all files exist - assert_inputs_exist(p, [args.hdf5_file]) - assert_outputs_exist(p, args, args.experiments_path) - - # Verify if a checkpoint has been saved. Else create an experiment. - if os.path.exists(os.path.join(args.experiments_path, args.experiment_name, - "checkpoint")): - raise FileExistsError("This experiment already exists. Delete or use " - "script l2t_resume_training_from_checkpoint.py.") - - trainer = init_from_args(args, sub_loggers_level) - - # Supervising that we loaded everything correctly. - print("Validation 0 = Initial verification: pre-trained results!") - trainer.validate_one_epoch(-1) - - print("Now starting training") - run_experiment(trainer) - - -if __name__ == '__main__': - main() diff --git a/scripts_python/l2t_train_model.py b/scripts_python/l2t_train_model.py index 50fa6b69..ff4d16d9 100755 --- a/scripts_python/l2t_train_model.py +++ b/scripts_python/l2t_train_model.py @@ -149,7 +149,8 @@ def main(): if args.logging == 'DEBUG': sub_loggers_level = 'INFO' - logging.getLogger().setLevel(level=logging.INFO) + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Check that all files exist assert_inputs_exist(p, [args.hdf5_file]) diff --git a/scripts_python/tt_resume_training_from_checkpoint.py b/scripts_python/tt_resume_training_from_checkpoint.py index a5cd91ea..6f9f614f 100644 --- a/scripts_python/tt_resume_training_from_checkpoint.py +++ b/scripts_python/tt_resume_training_from_checkpoint.py @@ -91,9 +91,8 @@ def main(): p = prepare_arg_parser() args = p.parse_args() - # Setting root logger with high level, but we will set trainer to - # user-defined level. - logging.getLogger().setLevel(level=logging.INFO) + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Verify if a checkpoint has been saved. checkpoint_path = os.path.join( diff --git a/scripts_python/tt_train_model.py b/scripts_python/tt_train_model.py index 6c639032..912645fa 100755 --- a/scripts_python/tt_train_model.py +++ b/scripts_python/tt_train_model.py @@ -166,7 +166,9 @@ def main(): sub_loggers_level = args.logging if args.logging == 'DEBUG': sub_loggers_level = 'INFO' - logging.getLogger().setLevel(level=args.logging) + + # General logging (ex, scilpy: Warning) + logging.getLogger().setLevel(level=logging.WARNING) # Check that all files exist assert_inputs_exist(p, [args.hdf5_file]) From 29fbae8a2c5918ccaef944ae155fa2666b6ec669 Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Thu, 29 Feb 2024 15:07:14 -0500 Subject: [PATCH 22/23] Fix sampler: batch size must be int --- dwi_ml/training/batch_samplers.py | 43 +++++++++++++++++++------------ 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/dwi_ml/training/batch_samplers.py b/dwi_ml/training/batch_samplers.py index ba612377..ea03812f 100644 --- a/dwi_ml/training/batch_samplers.py +++ b/dwi_ml/training/batch_samplers.py @@ -253,10 +253,15 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]: # Choose subjects from which to sample streamlines for the next # few cycles. if self.nb_subjects_per_batch: - # Sampling first from subjects that were not seed a lot yet + # Sampling first from subjects that were not seen a lot yet weights = streamlines_per_subj / np.sum(streamlines_per_subj) # Choosing only non-empty subjects + # NOTE. THIS IS QUESTIONNABLE! It means that the last batch of + # every epoch is ~1 subject: the one with the most streamlines. + # Other choice could be to break as soon as at least one + # subject is done. With batches not too big, we would still + # have seen most of the data of unfinished subjects. nb_subjects = min(self.nb_subjects_per_batch, np.count_nonzero(weights)) sampled_subjs = self.np_rng.choice( @@ -271,9 +276,9 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]: # Final subject's batch size could be smaller if no streamlines are # left for this subject. - max_batch_size_per_subj = self.context_batch_size / nb_subjects + max_batch_size_per_subj = int(self.context_batch_size / nb_subjects) if self.batch_size_units == 'nb_streamlines': - chunk_size = int(max_batch_size_per_subj) + chunk_size = max_batch_size_per_subj else: chunk_size = self.nb_streamlines_per_chunk or DEFAULT_CHUNK_SIZE @@ -296,9 +301,10 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]: batch_ids_per_subj = [] for subj in sampled_subjs: self.logger.debug(" Subj {}".format(subj)) - sampled_ids = self._sample_streamlines_for_subj( - subj, ids_per_subjs, global_unused_streamlines, - max_batch_size_per_subj, chunk_size) + sampled_ids, global_unused_streamlines = \ + self._sample_streamlines_for_subj( + subj, ids_per_subjs, global_unused_streamlines, + max_batch_size_per_subj, chunk_size) # Append tuple (subj, list_sampled_ids) to the batch if len(sampled_ids) > 0: @@ -331,8 +337,8 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs, ------ subj: int The subject's id. - ids_per_subjs: dict - The list of this subject's streamlines' global ids. + ids_per_subjs: dict[slice] + This subject's streamlines' global ids (slices). global_unused_streamlines: array One flag per global streamline id: 0 if already used, else 1. max_batch_size_per_subj: @@ -344,6 +350,8 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs, # subject subj_slice = ids_per_subjs[subj] + slice_to_list = list(range(subj_slice.start, subj_slice.stop)) + # We will continue iterating on this subject until we # break (i.e. when we reach the maximum batch size for this # subject) @@ -364,12 +372,12 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs, if len(chunk_rel_ids) == 0: raise ValueError( "Implementation error? Got no streamline for this subject " - "in this batch, but there are streamlines left. \n" - "Possibly means that the allowed batch size does not even " - "allow one streamline per batch.\n Check your batch size " - "choice!") + "in this batch, but there are streamlines left. To be " + "discussed with the implemetors.") - # Mask the sampled streamlines + # Mask the sampled streamlines. + # Technically done in-place, wouldn't need to return, but + # returning to be sure. global_unused_streamlines[chunk_global_ids] = 0 # Add sub-sampled ids to subject's batch @@ -383,7 +391,7 @@ def _sample_streamlines_for_subj(self, subj, ids_per_subjs, # Update size and get a new chunk total_subj_batch_size += subj_batch_size - return sampled_ids + return sampled_ids, global_unused_streamlines def _get_a_chunk_of_streamlines(self, subj_slice, global_unused_streamlines, @@ -443,15 +451,16 @@ def _get_a_chunk_of_streamlines(self, subj_slice, chosen_global_ids) tmp_computed_chunk_size = int(np.sum(size_per_streamline)) - # If batch_size has been exceeded, taking a little less streamlines - # for this chunk. if subj_batch_size + tmp_computed_chunk_size >= max_subj_batch_size: reached_max_heaviness = True + # If batch_size has been exceeded, taking a little less streamlines + # for this chunk. if subj_batch_size + tmp_computed_chunk_size > max_subj_batch_size: self.logger.debug( " Chunk_size was {}, but max batch size for this " - "subj is {} (we already had acculumated {})." + "subj is {} (we already had acculumated {}). Taking a bit " + "less streamlines." .format(tmp_computed_chunk_size, max_subj_batch_size, subj_batch_size)) From c1ec1890917b770f1a6c08f1bc57f9aae735ba9d Mon Sep 17 00:00:00 2001 From: EmmaRenauld Date: Fri, 1 Mar 2024 09:27:17 -0500 Subject: [PATCH 23/23] Fix -v usage from scilpy update --- bash_utilities/scil_score_ismrm_Renauld2023.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bash_utilities/scil_score_ismrm_Renauld2023.sh b/bash_utilities/scil_score_ismrm_Renauld2023.sh index d788b3bf..cd19cffc 100644 --- a/bash_utilities/scil_score_ismrm_Renauld2023.sh +++ b/bash_utilities/scil_score_ismrm_Renauld2023.sh @@ -54,7 +54,7 @@ then fi echo '------------- FINAL SCORING ------------' -scil_bundle_score_many_bundles_one_tractogram.py -v $config_file_tractometry $out_dir \ - --gt_dir $scoring_data --reference $ref --no_bbox_check +scil_bundle_score_many_bundles_one_tractogram.py $config_file_tractometry $out_dir \ + --gt_dir $scoring_data --reference $ref --no_bbox_check -v cat $out_dir/results.json \ No newline at end of file