fix: dps batch loading and stick to master 3
levje committed Nov 8, 2024
1 parent 3669453 commit 6bbc096
Showing 4 changed files with 6 additions and 6 deletions.
1 change: 0 additions & 1 deletion dwi_ml/models/projects/ae_models.py
@@ -18,7 +18,6 @@ class ModelAE(MainModelAbstract):
     deterministic (3D vectors) or probabilistic (based on probability
     distribution parameters).
     """
-
     def __init__(self,
                  experiment_name: str,
                  step_size: float = None,
3 changes: 1 addition & 2 deletions dwi_ml/models/projects/learn2track_model.py
@@ -289,8 +289,7 @@ def forward(self, x: List[torch.tensor],
         unsorted_indices = invert_permutation(sorted_indices)
         x = [x[i] for i in sorted_indices]
         if input_streamlines is not None:
-            input_streamlines = [input_streamlines[i]
-                                 for i in sorted_indices]
+            input_streamlines = [input_streamlines[i] for i in sorted_indices]

         # ==== 0. Previous dirs.
         n_prev_dirs = None
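For context, the surrounding lines of forward() sort the inputs (typically by sequence length, for packing) and keep the inverse permutation so the original order can be restored afterwards; the change itself only re-joins a wrapped list comprehension. A minimal sketch of that sort/restore pattern, with invert_permutation written out by hand (dwi_ml imports its own helper; the toy values here are hypothetical):

    import torch

    def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
        # inv[perm[i]] = i, so indexing with inv undoes the sort.
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(perm.numel(), device=perm.device)
        return inv

    lengths = torch.tensor([3, 5, 2])
    _, sorted_indices = lengths.sort(descending=True)      # [1, 0, 2]
    unsorted_indices = invert_permutation(sorted_indices)

    x = ['a', 'b', 'c']
    x_sorted = [x[i] for i in sorted_indices]               # ['b', 'a', 'c']
    x_restored = [x_sorted[i] for i in unsorted_indices]    # ['a', 'b', 'c']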
3 changes: 0 additions & 3 deletions dwi_ml/models/projects/transformer_models.py
@@ -117,7 +117,6 @@ class AbstractTransformerModel(ModelWithNeighborhood, ModelWithDirectionGetter,
     https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
     the embedding probably adapts to leave place for the positional encoding.
     """
-
     def __init__(self,
                  experiment_name: str,
                  # Target preprocessing params for the batch loader + tracker
@@ -828,7 +827,6 @@ class OriginalTransformerModel(AbstractTransformerModelWithTarget):
         emb_choice_x
     """
-
     def __init__(self, input_embedded_size, n_layers_d: int, **kw):
         """
         d_model = input_embedded_size = target_embedded_size.
@@ -970,7 +968,6 @@ class TransformerSrcAndTgtModel(AbstractTransformerModelWithTarget):
         [ emb_choice_x ; emb_choice_y ]
     """
-
     def __init__(self, **kw):
         """
         No additional params. d_model = input size + target size.
5 changes: 5 additions & 0 deletions dwi_ml/training/batch_loaders.py
@@ -302,6 +302,7 @@ def load_batch_streamlines(
     # the loaded, processed streamlines, not to the ids in the hdf5 file.
     final_s_ids_per_subj = defaultdict(slice)
     batch_streamlines = []
+    batch_dps = defaultdict(list)
     for subj, s_ids in streamline_ids_per_subj:
         logger.debug(
             " Data loader: Processing data preparation for "
@@ -332,6 +333,10 @@
         sft.to_corner()
         batch_streamlines.extend(sft.streamlines)

+        # Add data per streamline for the batch elements
+        for key, value in sft.data_per_streamline.items():
+            batch_dps[key].extend(value)
+
     batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]
     data_per_streamline = _dps_to_tensors(sft.data_per_streamline)
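The substantive part of the commit is this last file: each subject's data_per_streamline is now collected into a shared batch_dps dict while the batch is assembled, so dps values from every subject are available together. A minimal, self-contained sketch of the accumulation pattern (the toy dictionaries are hypothetical stand-ins for each subject's sft.data_per_streamline, and the final tensor conversion only approximates what a helper like _dps_to_tensors does):

    from collections import defaultdict

    import numpy as np
    import torch

    # Two subjects' data_per_streamline, keyed identically.
    subj1_dps = {'score': np.array([0.1, 0.2])}  # subject 1: 2 streamlines
    subj2_dps = {'score': np.array([0.3])}       # subject 2: 1 streamline

    batch_dps = defaultdict(list)
    for dps in (subj1_dps, subj2_dps):
        for key, value in dps.items():
            batch_dps[key].extend(value)

    # All three values survive, not only the last subject's.
    data_per_streamline = {k: torch.as_tensor(v) for k, v in batch_dps.items()}
    # data_per_streamline['score'] -> tensor of [0.1, 0.2, 0.3]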
