diff --git a/dwi_ml/data/dataset/streamline_containers.py b/dwi_ml/data/dataset/streamline_containers.py
index 034eccec..d78f01f4 100644
--- a/dwi_ml/data/dataset/streamline_containers.py
+++ b/dwi_ml/data/dataset/streamline_containers.py
@@ -106,17 +106,25 @@ def _get_one_streamline(self, idx: int):
 
     def get_array_sequence(self, item=None):
         if item is None:
-            streamlines, dps = _load_all_streamlines_from_hdf(self.hdf_group)
+            streamlines, data_per_streamline = _load_all_streamlines_from_hdf(
+                self.hdf_group)
         else:
             streamlines = ArraySequence()
-            dps_dict = defaultdict(list)
+            data_per_streamline = defaultdict(list)
+
+            # If data_per_streamline is not in the hdf5, use an empty dict
+            # so that we don't add anything to the data_per_streamline in the
+            # following steps.
+            hdf_dps_group = self.hdf_group['data_per_streamline'] if \
+                'data_per_streamline' in self.hdf_group.keys() else {}
 
             if isinstance(item, int):
                 data = self._get_one_streamline(item)
                 streamlines.append(data)
 
-                for dps_key in self.hdf_group['dps_keys']:
-                    dps_dict[dps_key].append(self.hdf_group[dps_key][item])
+                for dps_key in hdf_dps_group.keys():
+                    data_per_streamline[dps_key].append(
+                        hdf_dps_group[dps_key][item])
 
             elif isinstance(item, list) or isinstance(item, np.ndarray):
                 # Getting a list of value from a hdf5: slow. Uses fancy indexing.
@@ -129,8 +137,9 @@ def get_array_sequence(self, item=None):
                     data = self._get_one_streamline(i)
                     streamlines.append(data, cache_build=True)
 
-                    for dps_key in self.hdf_group['dps_keys']:
-                        dps_dict[dps_key].append(self.hdf_group[dps_key][item])
+                    for dps_key in hdf_dps_group.keys():
+                        data_per_streamline[dps_key].append(
+                            hdf_dps_group[dps_key][item])
 
                 streamlines.finalize_append()
 
@@ -141,16 +150,16 @@ def get_array_sequence(self, item=None):
                     streamline = self.hdf_group['data'][offset:offset + length]
                     streamlines.append(streamline, cache_build=True)
 
-                    for dps_key in self.hdf_group['dps_keys']:
-                        dps_dict[dps_key].append(
-                            self.hdf_group[dps_key][offset:offset + length])
+                    for dps_key in hdf_dps_group.keys():
+                        data_per_streamline[dps_key].append(
+                            hdf_dps_group[dps_key][offset:offset + length])
 
                 streamlines.finalize_append()
             else:
                 raise ValueError('Item should be either a int, list, '
                                  'np.ndarray or slice but we received {}'
                                  .format(type(item)))
-        return streamlines, dps
+        return streamlines, data_per_streamline
 
     @property
     def lengths(self):
diff --git a/dwi_ml/models/projects/learn2track_model.py b/dwi_ml/models/projects/learn2track_model.py
index 9ba8074c..d3b11237 100644
--- a/dwi_ml/models/projects/learn2track_model.py
+++ b/dwi_ml/models/projects/learn2track_model.py
@@ -227,6 +227,7 @@ def computed_params_for_display(self):
 
     def forward(self, x: List[torch.tensor],
                 input_streamlines: List[torch.tensor] = None,
+                data_per_streamline: List[torch.tensor] = {},
                 hidden_recurrent_states: List = None, return_hidden=False,
                 point_idx: int = None):
         """Run the model on a batch of sequences.
@@ -284,7 +285,8 @@ def forward(self, x: List[torch.tensor],
             unsorted_indices = invert_permutation(sorted_indices)
             x = [x[i] for i in sorted_indices]
             if input_streamlines is not None:
-                input_streamlines = [input_streamlines[i] for i in sorted_indices]
+                input_streamlines = [input_streamlines[i]
+                                     for i in sorted_indices]
 
         # ==== 0. Previous dirs.
         n_prev_dirs = None
diff --git a/dwi_ml/training/trainers_withGV.py b/dwi_ml/training/trainers_withGV.py
index a0aebfcb..c42ee214 100644
--- a/dwi_ml/training/trainers_withGV.py
+++ b/dwi_ml/training/trainers_withGV.py
@@ -242,7 +242,7 @@ def gv_phase_one_batch(self, data, compute_all_scores=False):
         seeds and first few segments. Expected results are the batch's
         validation streamlines.
         """
-        real_lines, ids_per_subj = data
+        real_lines, ids_per_subj, data_per_streamline = data
 
         # Possibly sending again to GPU even if done in the local loss
         # computation, but easier with current implementation.
diff --git a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py
index f1bcf6c0..aaa830db 100644
--- a/dwi_ml/unit_tests/utils/data_and_models_for_tests.py
+++ b/dwi_ml/unit_tests/utils/data_and_models_for_tests.py
@@ -84,7 +84,7 @@ def compute_loss(self, model_outputs, target_streamlines=None,
         else:
             return torch.zeros(n, device=self.device), 1
 
-    def forward(self, inputs: list, streamlines):
+    def forward(self, inputs: list, streamlines, data_per_streamline):
         # Not using streamlines. Pretending to use inputs.
         _ = self.fake_parameter
         regressed_dir = torch.as_tensor([1., 1., 1.])
@@ -143,7 +143,8 @@ def get_tracking_directions(self, regressed_dirs, algo,
             raise ValueError("'algo' should be 'det' or 'prob'.")
 
     def forward(self, inputs: List[torch.tensor],
-                target_streamlines: List[torch.tensor]):
+                target_streamlines: List[torch.tensor],
+                data_per_streamline: List[torch.tensor]):
         # Previous dirs
         if self.nb_previous_dirs > 0:
             target_dirs = compute_directions(target_streamlines)
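Note (not part of the patch): a minimal, self-contained sketch of the HDF5 access pattern the new code in get_array_sequence() relies on, assuming the group layout 'data' / 'offsets' / 'lengths' plus an optional 'data_per_streamline/<key>' sub-group. The file name 'example_streamlines.hdf5', the group path 'subj1/streamlines' and the key 'seeds' are illustrative assumptions only; what comes from the diff above is the layout and the empty-dict fallback when 'data_per_streamline' is absent.

from collections import defaultdict

import h5py
import numpy as np

# Build a tiny HDF5 group with the assumed layout: two streamlines of 3 and
# 4 points stored as one flat (7, 3) 'data' array, with per-streamline seeds.
with h5py.File('example_streamlines.hdf5', 'w') as f:
    grp = f.create_group('subj1/streamlines')
    grp.create_dataset('data', data=np.random.rand(7, 3).astype(np.float32))
    grp.create_dataset('offsets', data=np.array([0, 3]))
    grp.create_dataset('lengths', data=np.array([3, 4]))
    dps = grp.create_group('data_per_streamline')
    dps.create_dataset('seeds', data=np.random.rand(2, 3))

with h5py.File('example_streamlines.hdf5', 'r') as f:
    grp = f['subj1/streamlines']

    # Same fallback as in the patch: if 'data_per_streamline' was never
    # saved, loop over an empty dict and add nothing.
    hdf_dps_group = grp['data_per_streamline'] \
        if 'data_per_streamline' in grp.keys() else {}

    item = 1  # one streamline, as in the isinstance(item, int) branch
    offset, length = grp['offsets'][item], grp['lengths'][item]
    streamline = grp['data'][offset:offset + length]

    data_per_streamline = defaultdict(list)
    for dps_key in hdf_dps_group.keys():
        data_per_streamline[dps_key].append(hdf_dps_group[dps_key][item])

    print(streamline.shape, dict(data_per_streamline))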