Merge pull request #241 from EmmaRenauld/small_fixes
Small fixes
EmmaRenauld authored May 17, 2024
2 parents 38ae414 + e0a074f commit 6266c2a
Showing 22 changed files with 149 additions and 149 deletions.
3 changes: 1 addition & 2 deletions dwi_ml/data/dataset/multi_subject_containers.py
@@ -281,7 +281,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
hdf_handle, subj_id, ref_group_info)

# Add subject to the list
- logger.debug(" Adding it to the list of subjects.")
subj_idx = self.subjs_data_list.add_subject(subj_data)

# Arrange streamlines
@@ -290,7 +289,6 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
if subj_data.is_lazy:
subj_data.add_handle(hdf_handle)

- logger.debug(" Counting streamlines")
for group in range(len(self.streamline_groups)):
subj_sft_data = subj_data.sft_data_list[group]
n_streamlines = len(subj_sft_data)
@@ -302,6 +300,7 @@ def load(self, hdf_handle: h5py.File, subj_id=None):
subj_data.hdf_handle = None

# Arrange final data properties: Concatenate all subjects
+ logging.debug("All subjects added. Final verifications.")
self.streamline_lengths_mm = \
[np.concatenate(lengths_mm[group], axis=0)
for group in range(len(self.streamline_groups))]
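
Below is a minimal, standalone sketch of what this final concatenation step does, with hypothetical stand-ins for lengths_mm and streamline_groups (the real values are accumulated while looping over subjects above):

import numpy as np

# Hypothetical stand-ins: one list of per-subject length arrays per group.
streamline_groups = ['streamlines']
lengths_mm = {0: [np.array([50.0, 62.5]),           # subject 1, group 0
                  np.array([48.2, 71.0, 55.3])]}    # subject 2, group 0

streamline_lengths_mm = \
    [np.concatenate(lengths_mm[group], axis=0)
     for group in range(len(streamline_groups))]

print(streamline_lengths_mm[0])   # all lengths for group 0, both subjects
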
3 changes: 2 additions & 1 deletion dwi_ml/data/dataset/single_subject_containers.py
@@ -110,7 +110,8 @@ def init_single_subject_from_hdf(
subject_mri_data_list.append(subject_mri_group_data)

for group in streamline_groups:
- logger.debug(" Loading subject's streamlines")
+ logger.debug(" Loading streamlines group '{}'"
+ .format(group))
sft_data = SFTData.init_sft_data_from_hdf_info(
hdf_file[subject_id][group])
subject_sft_data_list.append(sft_data)
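
For illustration, a self-contained sketch of the per-group loading loop above, using plain h5py on a toy file; the file layout here is an assumption, not dwi_ml's exact HDF5 schema, and SFTData.init_sft_data_from_hdf_info remains the real entry point:

import logging
import h5py
import numpy as np

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Toy file with one subject and one streamline group (assumed layout).
with h5py.File('toy_dataset.hdf5', 'w') as f:
    f.create_group('subj1/streamlines').create_dataset(
        'data', data=np.zeros((10, 3)))

with h5py.File('toy_dataset.hdf5', 'r') as hdf_file:
    subject_id = 'subj1'
    for group in ['streamlines']:
        logger.debug("Loading streamlines group '{}'".format(group))
        hdf_group = hdf_file[subject_id][group]
        # In dwi_ml, hdf_group would be handed to
        # SFTData.init_sft_data_from_hdf_info(hdf_group).
        logger.debug("Datasets in this group: %s", list(hdf_group.keys()))
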
17 changes: 7 additions & 10 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -59,8 +59,8 @@ def format_filelist(filenames, enforce_presence, folder=None) -> List[str]:
return new_files


- def _load_and_verify_file(filename: str, subj_input_path, group_name: str,
- group_affine, group_res):
+ def _load_and_verify_file(filename: str, group_name: str, group_affine,
+ group_res):
"""
Loads a 3D or 4D nifti file. If it is a 3D dataset, adds a dimension to
make it 4D. Then checks that it is compatible with a given group based on
@@ -70,26 +70,22 @@ def _load_and_verify_file(filename: str, subj_input_path, group_name: str,
------
filename: str
File's name. Must be .nii or .nii.gz.
- subj_input_path: Path
- Path where to load the nifti file from.
group_name: str
Name of the group with which 'filename' file must be compatible.
group_affine: np.array
The loaded file's affine must be equal (or very close) to this affine.
group_res: np.array
The loaded file's resolution must be equal (or very close) to this res.
"""
- data_file = subj_input_path.joinpath(filename)

- if not data_file.is_file():
+ if not os.path.isfile(filename):
logging.debug(" Skipping file {} because it was not "
"found in this subject's folder".format(filename))
# Note: if args.enforce_files_presence was set to true, this
# case is not possible, already checked in
# create_hdf5_dataset.py.
return None

- data, affine, res, _ = load_file_to4d(data_file)
+ data, affine, res, _ = load_file_to4d(filename)

if not np.allclose(affine, group_affine, atol=1e-5):
# Note. When keeping default options on tolerance, we have run
@@ -483,6 +479,7 @@ def _process_one_volume_group(self, group: str, subj_id: str,
else:
std_mask = np.logical_or(sub_mask_data, std_mask)

+ # Get the files and add the subject_dir as prefix.
file_list = self.groups_config[group]['files']
file_list = format_filelist(file_list, self.enforce_files_presence,
folder=subj_input_dir)
@@ -505,8 +502,8 @@ def _process_one_volume_group(self, group: str, subj_id: str,
for file_name in file_list[1:]:
logging.info(" - Processing file {}"
.format(os.path.basename(file_name)))
- data = _load_and_verify_file(file_name, subj_input_dir, group,
- group_affine, group_res)
+ data = _load_and_verify_file(file_name, group, group_affine,
+ group_res)

if std_option == 'per_file':
logging.info(' - Standardizing')
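
The refactor above makes _load_and_verify_file receive a full path instead of a folder plus a file name. Here is a rough, standalone sketch of the same checks using nibabel directly (load_file_to4d is dwi_ml's own helper; this only mirrors the idea and is not the library's implementation):

import os
import numpy as np
import nibabel as nib

def load_and_check(filename, group_affine, group_res):
    """Load a 3D or 4D nifti file and check it against a reference group."""
    if not os.path.isfile(filename):
        return None  # optionally skipped; presence may be enforced upstream

    img = nib.load(filename)
    data = np.asanyarray(img.dataobj)
    if data.ndim == 3:
        data = data[..., np.newaxis]        # make 3D volumes 4D

    res = img.header.get_zooms()[:3]
    if not np.allclose(img.affine, group_affine, atol=1e-5):
        raise ValueError("Affine of {} does not match its group."
                         .format(filename))
    if not np.allclose(res, group_res):
        raise ValueError("Resolution of {} does not match its group."
                         .format(filename))
    return data
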
4 changes: 2 additions & 2 deletions dwi_ml/models/main_models.py
@@ -99,7 +99,7 @@ def add_args_main_model(p):
add_resample_or_compress_arg(p)

def set_context(self, context):
- assert context in ['training', 'tracking']
+ assert context in ['training', 'validation']
self._context = context

@property
@@ -730,7 +730,7 @@ def __init__(self, dg_key: str = 'cosine-regression',
.format(self.positional_encoding_key))

def set_context(self, context):
- assert context in ['training', 'tracking', 'visu']
+ assert context in ['training', 'validation', 'tracking', 'visu']
self._context = context

def instantiate_direction_getter(self, dg_input_size):
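
The two set_context changes widen the accepted contexts differently per class. A toy sketch of that pattern, with illustrative class names that are not dwi_ml's:

class BaseModel:
    """Base model: only knows the training and validation contexts."""
    def __init__(self):
        self._context = None

    def set_context(self, context):
        assert context in ['training', 'validation']
        self._context = context


class TrackingCapableModel(BaseModel):
    """Subclass also usable at tracking and visualization time."""
    def set_context(self, context):
        assert context in ['training', 'validation', 'tracking', 'visu']
        self._context = context


model = TrackingCapableModel()
model.set_context('tracking')   # ok
# BaseModel().set_context('tracking') would raise an AssertionError.
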
35 changes: 35 additions & 0 deletions dwi_ml/training/batch_loaders.py
@@ -119,6 +119,9 @@ def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
# Find idx of streamline group
self.streamline_group_idx = self.dataset.streamline_groups.index(
self.streamline_group_name)
self.data_contains_connectivity = \
self.dataset.streamlines_contain_connectivity[
self.streamline_group_idx]

# Set random numbers
self.rng = rng
@@ -314,6 +317,38 @@ def load_batch_streamlines(

return batch_streamlines, final_s_ids_per_subj

def load_batch_connectivity_matrices(
self, streamline_ids_per_subj: Dict[int, slice]):
if not self.data_contains_connectivity:
raise ValueError("No connectivity matrix in this dataset.")

# The batch's streamline ids will change throughout processing because
# of data augmentation, so we need to do it subject by subject to
# keep track of the streamline ids. These final ids will correspond to
# the loaded, processed streamlines, not to the ids in the hdf5 file.
subjs = list(streamline_ids_per_subj.keys())
nb_subjs = len(subjs)
matrices = [None] * nb_subjs
volume_sizes = [None] * nb_subjs
connectivity_nb_blocs = [None] * nb_subjs
connectivity_labels = [None] * nb_subjs
for i, subj in enumerate(subjs):
# No cache for the sft data. Accessing it directly.
# Note: If this is used through the dataloader, multiprocessing
# is used. Each process will open a handle.
subj_data = \
self.context_subset.subjs_data_list.get_subj_with_handle(subj)
subj_sft_data = subj_data.sft_data_list[self.streamline_group_idx]

# We could access it only at required index, maybe. Loading the
# whole matrix here.
(matrices[i], volume_sizes[i],
connectivity_nb_blocs[i], connectivity_labels[i]) = \
subj_sft_data.get_connectivity_matrix_and_info()

return (matrices, volume_sizes,
connectivity_nb_blocs, connectivity_labels)


class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader):
"""
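
The new load_batch_connectivity_matrices gathers one matrix per subject, keeping per-subject bookkeeping because data augmentation changes streamline ids. A standalone sketch of that accumulation pattern with fake data (the real method reads each subject's matrix from its hdf5 handle via get_connectivity_matrix_and_info):

from typing import Dict

import numpy as np


def collect_matrices(streamline_ids_per_subj: Dict[int, slice]):
    subjs = list(streamline_ids_per_subj.keys())
    nb_subjs = len(subjs)
    matrices = [None] * nb_subjs
    volume_sizes = [None] * nb_subjs
    for i, subj in enumerate(subjs):
        # Stand-in for get_connectivity_matrix_and_info(): fake values.
        matrices[i] = np.zeros((20, 20), dtype=int)
        volume_sizes[i] = (112, 112, 70)
    return matrices, volume_sizes


mats, sizes = collect_matrices({0: slice(0, 150), 4: slice(20, 80)})
print(len(mats), mats[0].shape, sizes[0])
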
43 changes: 24 additions & 19 deletions dwi_ml/training/batch_samplers.py
@@ -43,7 +43,7 @@ def __init__(self, dataset: MultiSubjectDataset,
batch_size_validation: Union[int, None],
batch_size_units: str, nb_streamlines_per_chunk: int = None,
rng: int = None, nb_subjects_per_batch: int = None,
- cycles: int = None, log_level=logging.root.level):
+ cycles: int = None, log_level=logger.root.level):
"""
Parameters
----------
@@ -56,7 +56,8 @@ def __init__(self, dataset: MultiSubjectDataset,
Batch size. Can be defined in number of streamlines or in total
length_mm (specified through batch_size_units).
batch_size_validation: Union[int, None]
- Idem
+ Idem. If None, it is expected that there will not be a validation
+ set.
batch_size_units: str
'nb_streamlines' or 'length_mm' (which should hopefully be
correlated to the number of input data points).
@@ -80,30 +81,36 @@
"""
super().__init__(None) # This does nothing but python likes it.

+ # Batch sampler's logging level can be changed separately from main
+ # scripts.
+ logger.setLevel(log_level)
+ self.logger = logger

# Checking that batch_size is correct
for batch_size in [batch_size_training, batch_size_validation]:
- if batch_size and batch_size <= 0:
+ if batch_size is not None and batch_size <= 0:
raise ValueError("batch_size (i.e. number of total timesteps "
"in the batch) should be a positive int "
"value, but got batch_size={}"
.format(batch_size))

if batch_size_units == 'nb_streamlines':
if nb_streamlines_per_chunk is not None:
- logging.warning("With a max_batch_size computed in terms of "
- "number of streamlines, the chunk size is not "
- "used. Ignored")
+ logger.warning("With a max_batch_size computed in terms of "
+ "number of streamlines, the chunk size is not "
+ "used. Ignored.")
nb_streamlines_per_chunk = None
elif batch_size_units == 'length_mm':
if nb_streamlines_per_chunk is None:
- logging.debug("Chunk size was not set. Using default {}"
- .format(DEFAULT_CHUNK_SIZE))
+ logger.debug("Chunk size was not set. Using default {}"
+ .format(DEFAULT_CHUNK_SIZE))
else:
raise ValueError("batch_size_unit should either be "
"'nb_streamlines' or 'length_mm', got {}"
.format(batch_size_units))

# Checking that n_volumes was given if cycles was given
- if cycles and not nb_subjects_per_batch:
+ if cycles and nb_subjects_per_batch is None:
raise ValueError("If `cycles` is defined, "
"`nb_subjects_per_batch` should be defined. Got: "
"nb_subjects_per_batch={}, cycles={}"
@@ -127,11 +134,6 @@ def __init__(self, dataset: MultiSubjectDataset,
self.rng = rng
self.np_rng = np.random.RandomState(self.rng)

- # Batch sampler's logging level can be changed separately from main
- # scripts.
- self.logger = logger
- self.logger.setLevel(log_level)

# For later use, context
self.context = None
self.context_subset = None
@@ -164,8 +166,8 @@ def init_from_checkpoint(cls, dataset, checkpoint_state: dict,
else:
batch_sampler = cls(dataset=dataset, **checkpoint_state)

- logging.info("Batch sampler's user-defined parameters: " +
- format_dict_to_str(batch_sampler.params_for_checkpoint))
+ logger.info("Batch sampler's user-defined parameters: " +
+ format_dict_to_str(batch_sampler.params_for_checkpoint))

return batch_sampler

@@ -221,9 +223,11 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]:

# This is the list of all possible streamline ids
global_streamlines_ids = np.arange(
- self.context_subset.total_nb_streamlines[self.streamline_group_idx])
+ self.context_subset.total_nb_streamlines[
+ self.streamline_group_idx])
ids_per_subjs = \
- self.context_subset.streamline_ids_per_subj[self.streamline_group_idx]
+ self.context_subset.streamline_ids_per_subj[
+ self.streamline_group_idx]

# This contains one bool per streamline:
# 1 = this streamline has not been used yet.
@@ -279,7 +283,8 @@ def __iter__(self) -> Iterator[List[Tuple[int, list]]]:

# Final subject's batch size could be smaller if no streamlines are
# left for this subject.
- max_batch_size_per_subj = int(self.context_batch_size / nb_subjects)
+ max_batch_size_per_subj = int(
+ self.context_batch_size / nb_subjects)
if self.batch_size_units == 'nb_streamlines':
chunk_size = max_batch_size_per_subj
else:
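
As a side note, a standalone sketch of the batch-size-units and chunk-size rules enforced above; the DEFAULT_CHUNK_SIZE value is an assumption for this sketch, not dwi_ml's constant:

import logging

logging.basicConfig(level=logging.DEBUG)

DEFAULT_CHUNK_SIZE = 256   # assumed value for this sketch


def resolve_chunk_size(batch_size_units, nb_streamlines_per_chunk=None):
    if batch_size_units == 'nb_streamlines':
        if nb_streamlines_per_chunk is not None:
            logging.warning("With a batch size in number of streamlines, "
                            "the chunk size is not used. Ignored.")
        return None
    elif batch_size_units == 'length_mm':
        if nb_streamlines_per_chunk is None:
            logging.debug("Chunk size was not set. Using default %s",
                          DEFAULT_CHUNK_SIZE)
            return DEFAULT_CHUNK_SIZE
        return nb_streamlines_per_chunk
    raise ValueError("batch_size_units should either be 'nb_streamlines' "
                     "or 'length_mm', got {}".format(batch_size_units))


print(resolve_chunk_size('length_mm'))            # -> 256 (default)
print(resolve_chunk_size('nb_streamlines', 100))  # -> None, logs a warning
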
2 changes: 1 addition & 1 deletion dwi_ml/training/projects/learn2track_trainer.py
@@ -9,7 +9,7 @@
from dwi_ml.models.projects.learn2track_model import Learn2TrackModel
from dwi_ml.tracking.io_utils import prepare_tracking_mask
from dwi_ml.tracking.propagation import propagate_multiple_lines
- from dwi_ml.training.with_generation.trainer import \
+ from dwi_ml.training.trainers_withGV import \
DWIMLTrainerForTrackingOneInput

logger = logging.getLogger('trainer_logger')
2 changes: 1 addition & 1 deletion dwi_ml/training/projects/transformer_trainer.py
@@ -9,7 +9,7 @@
from dwi_ml.tracking.io_utils import prepare_tracking_mask
from dwi_ml.tracking.propagation import propagate_multiple_lines

- from dwi_ml.training.with_generation.trainer import \
+ from dwi_ml.training.trainers_withGV import \
DWIMLTrainerForTrackingOneInput


50 changes: 28 additions & 22 deletions dwi_ml/training/trainers.py
@@ -178,27 +178,6 @@ def __init__(self,
self.batch_loader = batch_loader
self.model = model

- # Create DataLoaders from the BatchSamplers
- # * Before usage, context must be set for the batch sampler and the
- # batch loader, to use appropriate parameters.
- # * Pin memory if interpolation is done by workers; this means that
- # dataloader output is on GPU, ready to be fed to the model.
- # Otherwise, dataloader output is kept on CPU, and the main thread
- # sends volumes and coords on GPU for interpolation.
- logger.debug("- Instantiating dataloaders...")
- self.train_dataloader = DataLoader(
- dataset=self.batch_sampler.dataset.training_set,
- batch_sampler=self.batch_sampler,
- num_workers=self.nb_cpu_processes,
- collate_fn=self.batch_loader.load_batch_streamlines,
- pin_memory=self.use_gpu)
- self.valid_dataloader = DataLoader(
- dataset=self.batch_sampler.dataset.validation_set,
- batch_sampler=self.batch_sampler,
- num_workers=self.nb_cpu_processes,
- collate_fn=self.batch_loader.load_batch_streamlines,
- pin_memory=self.use_gpu)

# ----------------------
# Checks
# ----------------------
@@ -233,11 +212,38 @@ def __init__(self,
"Best practice is to have a validation set.")
else:
self.use_validation = True
+ if max_batches_per_epoch_validation is None:
+ self.max_batches_per_epoch_validation = 1000

if optimizer not in ['SGD', 'Adam', 'RAdam']:
raise ValueError("Optimizer choice {} not recognized."
.format(optimizer))

+ # ----------------
+ # Create DataLoaders from the BatchSamplers
+ # ----------------
+ # * Before usage, context must be set for the batch sampler and the
+ # batch loader, to use appropriate parameters.
+ # * Pin memory if interpolation is done by workers; this means that
+ # dataloader output is on GPU, ready to be fed to the model.
+ # Otherwise, dataloader output is kept on CPU, and the main thread
+ # sends volumes and coords on GPU for interpolation.
+ logger.debug("- Instantiating dataloaders...")
+ self.train_dataloader = DataLoader(
+ dataset=self.batch_sampler.dataset.training_set,
+ batch_sampler=self.batch_sampler,
+ num_workers=self.nb_cpu_processes,
+ collate_fn=self.batch_loader.load_batch_streamlines,
+ pin_memory=self.use_gpu)
+ self.valid_dataloader = None
+ if self.use_validation:
+ self.valid_dataloader = DataLoader(
+ dataset=self.batch_sampler.dataset.validation_set,
+ batch_sampler=self.batch_sampler,
+ num_workers=self.nb_cpu_processes,
+ collate_fn=self.batch_loader.load_batch_streamlines,
+ pin_memory=self.use_gpu)

# ----------------------
# Evolving values. They will need to be updated if initialized from
# checkpoint.
@@ -918,7 +924,7 @@ def validate_one_batch(self, data, epoch):
"""
Computes the loss(es) for the current batch and updates monitors.
"""
- mean_local_loss, n, _ = self.run_one_batch(data)
+ mean_local_loss, n = self.run_one_batch(data)
self.valid_local_loss_monitor.update(mean_local_loss.cpu().item(),
weight=n)

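
To illustrate the DataLoader wiring that this commit moves below the checks, here is a minimal example with toy stand-ins for the dataset, batch sampler and collate function (none of these are dwi_ml classes); as in the new code, the validation loader is only built when a validation set is available:

import torch
from torch.utils.data import DataLoader, Dataset


class ToySet(Dataset):
    def __len__(self):
        return 6

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])


def toy_collate(items):
    # Stand-in for load_batch_streamlines: just stack the items.
    return torch.stack(items)


use_gpu = torch.cuda.is_available()
use_validation = False   # pretend no validation set was given

# batch_sampler yields lists of indices, as dwi_ml's batch sampler does.
train_dataloader = DataLoader(dataset=ToySet(),
                              batch_sampler=[[0, 1, 2], [3, 4, 5]],
                              num_workers=0,
                              collate_fn=toy_collate,
                              pin_memory=use_gpu)

valid_dataloader = None
if use_validation:
    valid_dataloader = DataLoader(dataset=ToySet(),
                                  batch_sampler=[[0, 1], [2, 3]],
                                  num_workers=0,
                                  collate_fn=toy_collate,
                                  pin_memory=use_gpu)

for batch in train_dataloader:
    print(batch.shape)   # torch.Size([3, 1])

pin_memory only page-locks host memory to speed up later copies to the GPU, so it is only worth enabling when a GPU is actually used, which is what the pin_memory=self.use_gpu argument above reflects.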