diff --git a/cyto_dl/nn/vits/blocks/patchify.py b/cyto_dl/nn/vits/blocks/patchify.py index 620ad0721..ee5a2c563 100644 --- a/cyto_dl/nn/vits/blocks/patchify.py +++ b/cyto_dl/nn/vits/blocks/patchify.py @@ -3,16 +3,15 @@ import numpy as np import torch import torch.nn as nn -from einops.layers.torch import Rearrange +from einops.layers.torch import Rearrange, Reduce from timm.models.layers import trunc_normal_ from cyto_dl.nn.vits.utils import take_indexes -def random_indexes(size: int): - forward_indexes = np.arange(size) - np.random.shuffle(forward_indexes) - backward_indexes = np.argsort(forward_indexes) +def random_indexes(size: int, device): + forward_indexes = torch.randperm(size, device=device, dtype=torch.long) + backward_indexes = torch.argsort(forward_indexes) return forward_indexes, backward_indexes @@ -33,11 +32,11 @@ def __init__( Parameters ---------- patch_size: List[int] - Size of each patch + Size of each patch in pix (ZYX order for 3D, YX order for 2D) emb_dim: int Dimension of encoder n_patches: List[int] - Number of patches in each spatial dimension + Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) spatial_dims: int Number of spatial dimensions context_pixels: List[int] @@ -65,12 +64,24 @@ def __init__( padding=context_pixels, ) self.img2token = Rearrange("b c z y x -> (z y x) b c") - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", - n_patch_z=n_patches[0], - n_patch_y=n_patches[1], - n_patch_x=n_patches[2], + self.patch2img = torch.nn.Sequential( + *[ + Rearrange( + "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", + n_patch_z=n_patches[0], + n_patch_y=n_patches[1], + n_patch_x=n_patches[2], + ), + Reduce( + "b c n_patch_z n_patch_y n_patch_x -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + reduction="repeat", + patch_size_z=patch_size[0], + patch_size_y=patch_size[1], + patch_size_x=patch_size[2], + ), + ] ) + elif spatial_dims == 2: self.conv = nn.Conv2d( in_channels=input_channels, @@ -80,12 +91,21 @@ def __init__( padding=context_pixels, ) self.img2token = Rearrange("b c y x -> (y x) b c") - self.patch2img = Rearrange( - "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", - n_patch_y=n_patches[0], - n_patch_x=n_patches[1], + self.patch2img = torch.nn.Sequential( + *[ + Rearrange( + "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", + n_patch_y=n_patches[0], + n_patch_x=n_patches[1], + ), + Reduce( + "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + reduction="repeat", + patch_size_y=patch_size[0], + patch_size_x=patch_size[1], + ), + ] ) - self.task_embedding = torch.nn.ParameterDict( {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} ) @@ -99,27 +119,20 @@ def _init_weight(self): def get_mask(self, img, n_visible_patches, num_patches): B = img.shape[0] - indexes = [random_indexes(num_patches) for _ in range(B)] + indexes = [random_indexes(num_patches, img.device) for _ in range(B)] # forward indexes : index in image -> shuffledpatch - forward_indexes = torch.as_tensor( - np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long - ) + forward_indexes = torch.stack([i[0] for i in indexes], axis=-1) + # backward indexes : shuffled patch -> index in image - backward_indexes = torch.as_tensor( - np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long - ) + backward_indexes = torch.stack([i[1] for i in indexes], axis=-1) - mask = torch.zeros(num_patches, B, 1) + mask = torch.zeros(num_patches, B, 1, device=img.device, dtype=torch.bool) # visible patches are first - mask[:n_visible_patches] = 1 + mask[:n_visible_patches] = True mask = take_indexes(mask, backward_indexes) mask = self.patch2img(mask) - # one pixel per masked patch, interpolate to size of input image - mask = torch.nn.functional.interpolate( - mask, img.shape[-self.spatial_dims :], mode="nearest" - ) - return mask.to(img), forward_indexes, backward_indexes + return mask, forward_indexes, backward_indexes def forward(self, img, mask_ratio, task=None): # generate mask @@ -141,6 +154,4 @@ def forward(self, img, mask_ratio, task=None): tokens = tokens + self.task_embedding[task] # mask is used above to mask out patches, we need to invert it for loss calculation - mask = (1 - mask).bool() - - return tokens, mask, forward_indexes, backward_indexes + return tokens, ~mask, forward_indexes, backward_indexes diff --git a/cyto_dl/nn/vits/cross_mae.py b/cyto_dl/nn/vits/cross_mae.py index c1ba65d9b..e8bde3f43 100644 --- a/cyto_dl/nn/vits/cross_mae.py +++ b/cyto_dl/nn/vits/cross_mae.py @@ -90,18 +90,29 @@ def init_weight(self): def forward(self, features, forward_indexes, backward_indexes): # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers - features = features.squeeze(0) + # features can be n t b c (if intermediate feature weighter used) or t b c if not + features = features[0] if len(features.shape) == 4 else features T, B, C = features.shape # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now features = self.projection_norm(self.projection(features)) # add cls token backward_indexes = torch.cat( - [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], + [ + torch.zeros( + 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long + ), + backward_indexes + 1, + ], dim=0, ) forward_indexes = torch.cat( - [torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1], + [ + torch.zeros( + 1, forward_indexes.shape[1], device=forward_indexes.device, dtype=torch.long + ), + forward_indexes + 1, + ], dim=0, ) # fill in masked regions @@ -138,11 +149,18 @@ def forward(self, features, forward_indexes, backward_indexes): # add back in visible/encoded tokens that we don't calculate loss on patches = torch.cat( - [torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches], + [ + torch.zeros( + (T - 1, B, patches.shape[-1]), + requires_grad=False, + device=patches.device, + dtype=patches.dtype, + ), + patches, + ], dim=0, ) patches = take_indexes(patches, backward_indexes[1:] - 1) # patches to image img = self.patch2img(patches) - return img diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index 320a6c659..1cc906dea 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -168,7 +168,12 @@ def forward(self, features, forward_indexes, backward_indexes): features = self.projection_norm(self.projection(features)) backward_indexes = torch.cat( - [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], + [ + torch.zeros( + 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long + ), + backward_indexes + 1, + ], dim=0, ) # fill in masked regions diff --git a/cyto_dl/nn/vits/utils.py b/cyto_dl/nn/vits/utils.py index 8c38261a9..61263ccd0 100644 --- a/cyto_dl/nn/vits/utils.py +++ b/cyto_dl/nn/vits/utils.py @@ -3,6 +3,4 @@ def take_indexes(sequences, indexes): - return torch.gather( - sequences, 0, repeat(indexes.to(sequences.device), "t b -> t b c", c=sequences.shape[-1]) - ) + return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))