[NF] Auto-encoders - streamlines - FINTA #220

Merged: 28 commits (Oct 2, 2024)
Changes from 23 commits
Commits (28)
cd9275a
add ae - finta
arnaudbore Nov 23, 2023
24bb79b
Merge branch 'master' into add_autoencoder_streamlines
arnaudbore Nov 23, 2023
31e2e85
modif with em
arnaudbore Nov 28, 2023
346c383
add visu
arnaudbore Nov 28, 2023
5411a21
merge master
arnaudbore Sep 13, 2024
a341891
fix pep8
arnaudbore Sep 13, 2024
72789f3
answer em comments from nov 2023
arnaudbore Sep 13, 2024
02cab7b
fix naming class
Sep 17, 2024
7791875
fix script
Sep 17, 2024
f4701be
fix viz
arnaudbore Sep 18, 2024
c4bd181
quick fix ae_vis_streamline
arnaudbore Sep 18, 2024
448043b
WIP: transformer ae
AntoineTheb Sep 19, 2024
7bf9fe3
Revert "WIP: transformer ae"
AntoineTheb Sep 19, 2024
72be6fe
set bbox to false when saving trk
arnaudbore Sep 19, 2024
424d2f1
Merge branch 'add_autoencoder_streamlines' of github.com:arnaudbore/d…
arnaudbore Sep 19, 2024
4f04fca
add jeremi comments
arnaudbore Sep 26, 2024
29d690f
Revert "add jeremi comments"
arnaudbore Sep 26, 2024
8511db6
make it a little bit prettier waiting for PR244 to be merged
arnaudbore Sep 26, 2024
f478333
Merge branch 'master' into add_autoencoder_streamlines
arnaudbore Sep 27, 2024
517a530
rename vis streamline to autoencode tractogram, add save tractogram o…
arnaudbore Sep 30, 2024
327b659
change permission script
arnaudbore Sep 30, 2024
3711e4d
remove scripts wait for Antoine script
arnaudbore Oct 1, 2024
7521801
fix hdf5
arnaudbore Oct 1, 2024
97d1169
answer em comments
arnaudbore Oct 1, 2024
765ad56
fix init modelAE
arnaudbore Oct 1, 2024
de40049
add unused nb_points for others models
arnaudbore Oct 1, 2024
e57153b
fix condition resampling and compressing
arnaudbore Oct 1, 2024
fc55338
fix pep8
arnaudbore Oct 1, 2024
4 changes: 1 addition & 3 deletions dwi_ml/data/hdf5/hdf5_creation.py
@@ -634,7 +634,7 @@ def _create_streamline_groups(self, ref, subj_input_dir, subj_id,
raise ValueError(
"The data_per_streamline key '{}' was not found in "
"the sft. Check your tractogram file.".format(dps_key))

logging.debug(" Include dps \"{}\" in the HDF5.".format(dps_key))
streamlines_group.create_dataset('dps_' + dps_key,
data=sft.data_per_streamline[dps_key])
@@ -669,8 +669,6 @@ def _process_one_streamline_group(
Reference used to load and send the streamlines in voxel space and
to create final merged SFT. If the file is a .trk, 'same' is used
instead.
remove_invalid : bool
If True, invalid streamlines will be removed

Returns
-------
189 changes: 189 additions & 0 deletions dwi_ml/models/projects/ae_models.py
@@ -0,0 +1,189 @@
# -*- coding: utf-8 -*-
import logging
from typing import List

import torch
from torch.nn import functional as F

from dwi_ml.models.main_models import MainModelAbstract


class ModelAE(MainModelAbstract):
"""
Recurrent tracking model.

Composed of an embedding for the imaging data's input + for the previous
direction's input, an RNN model to process the sequences, and a direction
getter model to convert the RNN outputs to the right structure, e.g.
deterministic (3D vectors) or probabilistic (based on probability
distribution parameters).
"""
def __init__(self, kernel_size, latent_space_dims,
experiment_name: str,
# Target preprocessing params for the batch loader + tracker
step_size: float = None,
compress_lines: float = False,
# Other
log_level=logging.root.level):
super().__init__(experiment_name, step_size, compress_lines, log_level)

self.kernel_size = kernel_size
self.latent_space_dims = latent_space_dims

self.pad = torch.nn.ReflectionPad1d(1)

def pre_pad(m):
return torch.nn.Sequential(self.pad, m)

self.fc1 = torch.nn.Linear(8192,
self.latent_space_dims) # 8192 = 1024*8
self.fc2 = torch.nn.Linear(self.latent_space_dims, 8192)

"""
Encode convolutions
"""
self.encod_conv1 = pre_pad(
torch.nn.Conv1d(3, 32, self.kernel_size, stride=2, padding=0)
)
self.encod_conv2 = pre_pad(
torch.nn.Conv1d(32, 64, self.kernel_size, stride=2, padding=0)
)
self.encod_conv3 = pre_pad(
torch.nn.Conv1d(64, 128, self.kernel_size, stride=2, padding=0)
)
self.encod_conv4 = pre_pad(
torch.nn.Conv1d(128, 256, self.kernel_size, stride=2, padding=0)
)
self.encod_conv5 = pre_pad(
torch.nn.Conv1d(256, 512, self.kernel_size, stride=2, padding=0)
)
self.encod_conv6 = pre_pad(
torch.nn.Conv1d(512, 1024, self.kernel_size, stride=1, padding=0)
)

"""
Decode convolutions
"""
self.decod_conv1 = pre_pad(
torch.nn.Conv1d(1024, 512, self.kernel_size, stride=1, padding=0)
)
self.upsampl1 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv2 = pre_pad(
torch.nn.Conv1d(512, 256, self.kernel_size, stride=1, padding=0)
)
self.upsampl2 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv3 = pre_pad(
torch.nn.Conv1d(256, 128, self.kernel_size, stride=1, padding=0)
)
self.upsampl3 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv4 = pre_pad(
torch.nn.Conv1d(128, 64, self.kernel_size, stride=1, padding=0)
)
self.upsampl4 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv5 = pre_pad(
torch.nn.Conv1d(64, 32, self.kernel_size, stride=1, padding=0)
)
self.upsampl5 = torch.nn.Upsample(
scale_factor=2, mode="linear", align_corners=False
)
self.decod_conv6 = pre_pad(
torch.nn.Conv1d(32, 3, self.kernel_size, stride=1, padding=0)
)

@property
def params_for_checkpoint(self):
"""All parameters necessary to create again the same model. Will be
used in the trainer, when saving the checkpoint state. Params here
will be used to re-create the model when starting an experiment from
checkpoint. You should be able to re-create an instance of your
model with those params."""
# p = super().params_for_checkpoint()
p = {'kernel_size': self.kernel_size,
'latent_space_dims': self.latent_space_dims,
'experiment_name': self.experiment_name}
return p

@classmethod
def _load_params(cls, model_dir):
p = super()._load_params(model_dir)
p['kernel_size'] = 3
p['latent_space_dims'] = 32
return p

def forward(self,
input_streamlines: List[torch.tensor],
):
"""Run the model on a batch of sequences.

Parameters
----------
input_streamlines: List[torch.tensor],
Batch of streamlines to reconstruct. Each tensor is one streamline
of shape (nb_points, 3); the fixed fc1 size implies a fixed number
of points per streamline.

Returns
-------
model_outputs : Tensor
Reconstructed streamlines as a single (batch, 3, nb_points) tensor,
ready to be passed to `compute_loss()`.
"""

x = self.decode(self.encode(input_streamlines))
return x

def encode(self, x):
# x: list of tensors
x = torch.stack(x)
x = torch.swapaxes(x, 1, 2)

h1 = F.relu(self.encod_conv1(x))
h2 = F.relu(self.encod_conv2(h1))
h3 = F.relu(self.encod_conv3(h2))
h4 = F.relu(self.encod_conv4(h3))
h5 = F.relu(self.encod_conv5(h4))
h6 = self.encod_conv6(h5)

self.encoder_out_size = (h6.shape[1], h6.shape[2])

# Flatten
h7 = h6.view(-1, self.encoder_out_size[0] * self.encoder_out_size[1])

fc1 = self.fc1(h7)

return fc1

def decode(self, z):
fc = self.fc2(z)
fc_reshape = fc.view(
-1, self.encoder_out_size[0], self.encoder_out_size[1]
)
h1 = F.relu(self.decod_conv1(fc_reshape))
h2 = self.upsampl1(h1)
h3 = F.relu(self.decod_conv2(h2))
h4 = self.upsampl2(h3)
h5 = F.relu(self.decod_conv3(h4))
h6 = self.upsampl3(h5)
h7 = F.relu(self.decod_conv4(h6))
h8 = self.upsampl4(h7)
h9 = F.relu(self.decod_conv5(h8))
h10 = self.upsampl5(h9)
h11 = self.decod_conv6(h10)

return h11

def compute_loss(self, model_outputs, targets, average_results=True):

targets = torch.stack(targets)
targets = torch.swapaxes(targets, 1, 2)
reconstruction_loss = torch.nn.MSELoss(reduction="sum")
mse = reconstruction_loss(model_outputs, targets)
# Summed squared error over the whole batch; the second value is the
# count associated with the loss (a single value here).
return mse, 1
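
A minimal usage sketch of the new model (not part of the PR: the experiment name is hypothetical, and the 256-point streamline length is inferred from fc1's input size, 8192 = 1024 * 8, since 256 / 2**5 = 8 after the five stride-2 encoder convolutions):

import torch
from dwi_ml.models.projects.ae_models import ModelAE

model = ModelAE(kernel_size=3, latent_space_dims=32,
                experiment_name='ae_sketch')  # hypothetical name
streamlines = [torch.rand(256, 3) for _ in range(4)]  # 4 fake streamlines
reconstructed = model(streamlines)  # tensor of shape (4, 3, 256)
loss, n = model.compute_loss(reconstructed, streamlines)
loss.backward()  # ready for an optimizer step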
11 changes: 8 additions & 3 deletions dwi_ml/training/batch_loaders.py
@@ -58,7 +58,7 @@
logger = logging.getLogger('batch_loader_logger')


class DWIMLAbstractBatchLoader:
class DWIMLStreamlinesBatchLoader:
def __init__(self, dataset: MultiSubjectDataset, model: MainModelAbstract,
streamline_group_name: str, rng: int,
split_ratio: float = 0.,
@@ -197,7 +197,11 @@ def _data_augmentation_sft(self, sft):
self.context_subset.compress == self.model.compress_lines:
logger.debug("Compression rate is the same as when creating "
"the hdf5 dataset. Not compressing again.")
else:
elif self.model.step_size is not None and \
self.model.compress_lines is not None:
logger.debug("Resample streamlines using: \n" +
"- step_size: {}\n".format(self.model.step_size) +
"- compress_lines: {}".format(self.model.compress_lines))
sft = resample_or_compress(sft, self.model.step_size,
self.model.nb_points,
self.model.compress_lines)
@@ -314,6 +318,7 @@ def load_batch_streamlines(
sft.to_vox()
sft.to_corner()
batch_streamlines.extend(sft.streamlines)

batch_streamlines = [torch.as_tensor(s) for s in batch_streamlines]

return batch_streamlines, final_s_ids_per_subj
@@ -351,7 +356,7 @@ def load_batch_connectivity_matrices(
connectivity_nb_blocs, connectivity_labels)


class DWIMLBatchLoaderOneInput(DWIMLAbstractBatchLoader):
class DWIMLBatchLoaderOneInput(DWIMLStreamlinesBatchLoader):
"""
Loads:
input = one volume group
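Since the new elif above is easy to misread in hunk form, here is a standalone paraphrase of the resampling decision (a sketch only: attribute names are simplified, and the first branch, cut off by the hunk, is assumed to compare step sizes the same way the visible compression branch does):

def needs_resampling(model_step, model_compress, hdf5_step, hdf5_compress):
    """Decide whether streamlines must be resampled or compressed again."""
    if hdf5_step == model_step:
        return False  # same step size as at hdf5 creation; skip
    if hdf5_compress == model_compress:
        return False  # same compression rate as at hdf5 creation; skip
    # New in this PR: only resample/compress when the model defines both.
    return model_step is not None and model_compress is not None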
54 changes: 49 additions & 5 deletions dwi_ml/training/trainers.py
@@ -17,7 +17,7 @@
from dwi_ml.models.main_models import (MainModelAbstract,
ModelWithDirectionGetter)
from dwi_ml.training.batch_loaders import (
DWIMLAbstractBatchLoader, DWIMLBatchLoaderOneInput)
DWIMLStreamlinesBatchLoader, DWIMLBatchLoaderOneInput)
from dwi_ml.training.batch_samplers import DWIMLBatchIDSampler
from dwi_ml.training.utils.gradient_norm import compute_gradient_norm
from dwi_ml.training.utils.monitoring import (
@@ -53,7 +53,7 @@ class DWIMLAbstractTrainer:
def __init__(self,
model: MainModelAbstract, experiments_path: str,
experiment_name: str, batch_sampler: DWIMLBatchIDSampler,
batch_loader: DWIMLAbstractBatchLoader,
batch_loader: DWIMLStreamlinesBatchLoader,
learning_rates: Union[List, float] = None,
weight_decay: float = 0.01,
optimizer: str = 'Adam', max_epochs: int = 10,
@@ -78,7 +78,7 @@ def __init__(self,
batch_sampler: DWIMLBatchIDSampler
Instantiated class used for sampling batches.
Data in batch_sampler.dataset must be already loaded.
batch_loader: DWIMLAbstractBatchLoader
batch_loader: DWIMLStreamlinesBatchLoader
Instantiated class with a load_batch method able to load data
associated to sampled batch ids. Data in batch_sampler.dataset must
be already loaded.
@@ -461,7 +461,7 @@ def _prepare_checkpoint_info(self) -> dict:
def init_from_checkpoint(
cls, model: MainModelAbstract, experiments_path, experiment_name,
batch_sampler: DWIMLBatchIDSampler,
batch_loader: DWIMLAbstractBatchLoader,
batch_loader: DWIMLStreamlinesBatchLoader,
checkpoint_state: dict, new_patience, new_max_epochs, log_level):
"""
Loads checkpoint information (parameters and states) to instantiate
@@ -1013,7 +1013,51 @@ def run_one_batch(self, data):
Any other data returned when computing loss. Not used in the
trainer, but could be useful anywhere else.
"""
raise NotImplementedError
# Data interpolation has not been done yet. GPU computations are done
# here in the main thread.
targets, ids_per_subj = data

# Dataloader always works on CPU. Sending to right device.
# (model is already moved).
targets = [s.to(self.device, non_blocking=True, dtype=torch.float)
for s in targets]

# Getting the inputs points from the volumes.
# Uses the model's method, with the batch_loader's data.
# Possibly skipping the last point if not useful.
streamlines_f = targets
if isinstance(self.model, ModelWithDirectionGetter) and \
not self.model.direction_getter.add_eos:
# No EOS = We don't use the last coord because it does not have an
# associated target direction.
streamlines_f = [s[:-1, :] for s in streamlines_f]

# Possibly add noise to inputs here.
logger.debug('*** Computing forward propagation')

# Now possibly add noise to streamlines (training / valid)
streamlines_f = self.batch_loader.add_noise_streamlines_forward(
streamlines_f, self.device)

# Possibly computing directions twice (during forward and loss)
# but ok, shouldn't be too heavy. Easier to deal with multiple
# projects' requirements by sending whole streamlines rather
# than only directions.
model_outputs = self.model(streamlines_f)
del streamlines_f

logger.debug('*** Computing loss')
targets = self.batch_loader.add_noise_streamlines_loss(
targets, self.device)

results = self.model.compute_loss(model_outputs, targets,
average_results=True)

if self.use_gpu:
log_gpu_memory_usage(logger)

# The mean loss tensor is a single value; the caller converts it to
# float using item().
return results

def fix_parameters(self):
"""