Commit
Fix tests
levje committed Oct 7, 2024
1 parent 2e4a191 commit 2de0a43
Showing 4 changed files with 26 additions and 14 deletions.
29 changes: 19 additions & 10 deletions dwi_ml/data/dataset/streamline_containers.py
@@ -106,17 +106,25 @@ def _get_one_streamline(self, idx: int):
 
     def get_array_sequence(self, item=None):
         if item is None:
-            streamlines, dps = _load_all_streamlines_from_hdf(self.hdf_group)
+            streamlines, data_per_streamline = _load_all_streamlines_from_hdf(
+                self.hdf_group)
         else:
             streamlines = ArraySequence()
-            dps_dict = defaultdict(list)
+            data_per_streamline = defaultdict(list)
+
+            # If data_per_streamline is not in the hdf5, use an empty dict
+            # so that we don't add anything to the data_per_streamline in the
+            # following steps.
+            hdf_dps_group = self.hdf_group['data_per_streamline'] if \
+                'data_per_streamline' in self.hdf_group.keys() else {}
 
             if isinstance(item, int):
                 data = self._get_one_streamline(item)
                 streamlines.append(data)
 
-                for dps_key in self.hdf_group['dps_keys']:
-                    dps_dict[dps_key].append(self.hdf_group[dps_key][item])
+                for dps_key in hdf_dps_group.keys():
+                    data_per_streamline[dps_key].append(
+                        hdf_dps_group[dps_key][item])
 
             elif isinstance(item, list) or isinstance(item, np.ndarray):
                 # Getting a list of value from a hdf5: slow. Uses fancy indexing.
@@ -129,8 +137,9 @@ def get_array_sequence(self, item=None):
                     data = self._get_one_streamline(i)
                     streamlines.append(data, cache_build=True)
 
-                    for dps_key in self.hdf_group['dps_keys']:
-                        dps_dict[dps_key].append(self.hdf_group[dps_key][item])
+                    for dps_key in hdf_dps_group.keys():
+                        data_per_streamline[dps_key].append(
+                            hdf_dps_group[dps_key][item])
 
                 streamlines.finalize_append()
 
@@ -141,16 +150,16 @@ def get_array_sequence(self, item=None):
                     streamline = self.hdf_group['data'][offset:offset + length]
                     streamlines.append(streamline, cache_build=True)
 
-                    for dps_key in self.hdf_group['dps_keys']:
-                        dps_dict[dps_key].append(
-                            self.hdf_group[dps_key][offset:offset + length])
+                    for dps_key in hdf_dps_group.keys():
+                        data_per_streamline[dps_key].append(
+                            hdf_dps_group[dps_key][offset:offset + length])
                 streamlines.finalize_append()
 
             else:
                 raise ValueError('Item should be either a int, list, '
                                  'np.ndarray or slice but we received {}'
                                  .format(type(item)))
-        return streamlines, dps
+        return streamlines, data_per_streamline
 
     @property
     def lengths(self):
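For context on this first change: below is a minimal, self-contained sketch, not part of the commit (the file name and dataset are invented, and it assumes h5py and numpy are installed), of the fallback pattern added above, where a missing 'data_per_streamline' group in the HDF5 file is replaced by an empty dict so the per-key loops simply do nothing.

from collections import defaultdict

import h5py
import numpy as np

# Build a tiny HDF5 file that, like older files, has no 'data_per_streamline'.
with h5py.File('example_no_dps.hdf5', 'w') as hdf_group:
    hdf_group.create_dataset('data', data=np.random.rand(10, 3))

    hdf_dps_group = hdf_group['data_per_streamline'] if \
        'data_per_streamline' in hdf_group.keys() else {}

    data_per_streamline = defaultdict(list)
    for dps_key in hdf_dps_group.keys():
        # Never reached here: hdf_dps_group is the empty-dict fallback.
        data_per_streamline[dps_key].append(hdf_dps_group[dps_key][0])

    print(dict(data_per_streamline))  # -> {}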
4 changes: 3 additions & 1 deletion dwi_ml/models/projects/learn2track_model.py
@@ -227,6 +227,7 @@ def computed_params_for_display(self):
 
     def forward(self, x: List[torch.tensor],
                 input_streamlines: List[torch.tensor] = None,
+                data_per_streamline: List[torch.tensor] = {},
                 hidden_recurrent_states: List = None, return_hidden=False,
                 point_idx: int = None):
         """Run the model on a batch of sequences.
@@ -284,7 +285,8 @@ def forward(self, x: List[torch.tensor],
             unsorted_indices = invert_permutation(sorted_indices)
             x = [x[i] for i in sorted_indices]
             if input_streamlines is not None:
-                input_streamlines = [input_streamlines[i] for i in sorted_indices]
+                input_streamlines = [input_streamlines[i]
+                                     for i in sorted_indices]
 
         # ==== 0. Previous dirs.
         n_prev_dirs = None
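As a side note, here is a small sketch, not from the repository (it substitutes PackedSequence's unsorted_indices for the model's invert_permutation call), of the sort/unsort pattern the second hunk reflows: inputs are reordered by sorted_indices for packing, and the streamline list must follow the same permutation.

import torch
from torch.nn.utils.rnn import pack_sequence

x = [torch.rand(5, 3), torch.rand(2, 3), torch.rand(8, 3)]
input_streamlines = [torch.rand(6, 3), torch.rand(3, 3), torch.rand(9, 3)]

packed = pack_sequence(x, enforce_sorted=False)
sorted_indices = packed.sorted_indices      # longest-first permutation
unsorted_indices = packed.unsorted_indices  # its inverse, to restore order later

x = [x[i] for i in sorted_indices]
if input_streamlines is not None:
    input_streamlines = [input_streamlines[i]
                         for i in sorted_indices]

# Restore the original ordering when needed:
x = [x[i] for i in unsorted_indices]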
2 changes: 1 addition & 1 deletion dwi_ml/training/trainers_withGV.py
@@ -242,7 +242,7 @@ def gv_phase_one_batch(self, data, compute_all_scores=False):
         seeds and first few segments. Expected results are the batch's
         validation streamlines.
         """
-        real_lines, ids_per_subj = data
+        real_lines, ids_per_subj, data_per_streamline = data
 
         # Possibly sending again to GPU even if done in the local loss
         # computation, but easier with current implementation.
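A short, hypothetical illustration (the loader below is made up) of why this one-line change was needed to fix the tests: the batch handed to gv_phase_one_batch now carries three elements, so the old two-value unpacking would raise a ValueError.

import torch


def fake_batch_loader():
    # Invented stand-in for whatever yields the GV batch in the tests.
    real_lines = [torch.rand(5, 3), torch.rand(7, 3)]
    ids_per_subj = {'subj1': slice(0, 2)}
    data_per_streamline = {'seeds': torch.rand(2, 3)}
    return real_lines, ids_per_subj, data_per_streamline


data = fake_batch_loader()
real_lines, ids_per_subj, data_per_streamline = data  # new 3-value unpacking
# real_lines, ids_per_subj = data  # old form: "too many values to unpack"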
5 changes: 3 additions & 2 deletions dwi_ml/unit_tests/utils/data_and_models_for_tests.py
@@ -84,7 +84,7 @@ def compute_loss(self, model_outputs, target_streamlines=None,
         else:
             return torch.zeros(n, device=self.device), 1
 
-    def forward(self, inputs: list, streamlines):
+    def forward(self, inputs: list, streamlines, data_per_streamline):
         # Not using streamlines. Pretending to use inputs.
         _ = self.fake_parameter
         regressed_dir = torch.as_tensor([1., 1., 1.])
@@ -143,7 +143,8 @@ def get_tracking_directions(self, regressed_dirs, algo,
             raise ValueError("'algo' should be 'det' or 'prob'.")
 
     def forward(self, inputs: List[torch.tensor],
-                target_streamlines: List[torch.tensor]):
+                target_streamlines: List[torch.tensor],
+                data_per_streamline: List[torch.tensor]):
         # Previous dirs
         if self.nb_previous_dirs > 0:
             target_dirs = compute_directions(target_streamlines)
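To summarize the test-utility change, here is a hedged sketch (the class name is invented, not the repository's) of the shared signature the test models now expose, forward(inputs, streamlines, data_per_streamline), with the extra argument accepted even when unused.

from typing import List

import torch


class FakeModelForTest(torch.nn.Module):
    """Invented stand-in mirroring the updated test-model forward signature."""

    def __init__(self):
        super().__init__()
        self.fake_parameter = torch.nn.Parameter(torch.as_tensor(42.0))

    def forward(self, inputs: List[torch.Tensor], streamlines,
                data_per_streamline):
        # Not using streamlines nor data_per_streamline; pretending to use inputs.
        _ = self.fake_parameter
        regressed_dir = torch.as_tensor([1., 1., 1.])
        return [regressed_dir for _ in inputs]


model = FakeModelForTest()
outputs = model([torch.rand(3, 4)], streamlines=None, data_per_streamline={})
print(len(outputs))  # -> 1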
