Skip to content

Commit

Permalink
ViT Speedup (#399)
Browse files Browse the repository at this point in the history
* simplify vit indexing

* update with ritviks comments

---------

Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
3 people authored Jul 2, 2024
1 parent 4b4e37a commit 592177b
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 43 deletions.
79 changes: 45 additions & 34 deletions cyto_dl/nn/vits/blocks/patchify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,15 @@
import numpy as np
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from einops.layers.torch import Rearrange, Reduce
from timm.models.layers import trunc_normal_

from cyto_dl.nn.vits.utils import take_indexes


def random_indexes(size: int):
forward_indexes = np.arange(size)
np.random.shuffle(forward_indexes)
backward_indexes = np.argsort(forward_indexes)
def random_indexes(size: int, device):
forward_indexes = torch.randperm(size, device=device, dtype=torch.long)
backward_indexes = torch.argsort(forward_indexes)
return forward_indexes, backward_indexes


Expand All @@ -33,11 +32,11 @@ def __init__(
Parameters
----------
patch_size: List[int]
Size of each patch
Size of each patch in pix (ZYX order for 3D, YX order for 2D)
emb_dim: int
Dimension of encoder
n_patches: List[int]
Number of patches in each spatial dimension
Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D)
spatial_dims: int
Number of spatial dimensions
context_pixels: List[int]
Expand Down Expand Up @@ -65,12 +64,24 @@ def __init__(
padding=context_pixels,
)
self.img2token = Rearrange("b c z y x -> (z y x) b c")
self.patch2img = Rearrange(
"(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x",
n_patch_z=n_patches[0],
n_patch_y=n_patches[1],
n_patch_x=n_patches[2],
self.patch2img = torch.nn.Sequential(
*[
Rearrange(
"(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x",
n_patch_z=n_patches[0],
n_patch_y=n_patches[1],
n_patch_x=n_patches[2],
),
Reduce(
"b c n_patch_z n_patch_y n_patch_x -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
reduction="repeat",
patch_size_z=patch_size[0],
patch_size_y=patch_size[1],
patch_size_x=patch_size[2],
),
]
)

elif spatial_dims == 2:
self.conv = nn.Conv2d(
in_channels=input_channels,
Expand All @@ -80,12 +91,21 @@ def __init__(
padding=context_pixels,
)
self.img2token = Rearrange("b c y x -> (y x) b c")
self.patch2img = Rearrange(
"(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x",
n_patch_y=n_patches[0],
n_patch_x=n_patches[1],
self.patch2img = torch.nn.Sequential(
*[
Rearrange(
"(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x",
n_patch_y=n_patches[0],
n_patch_x=n_patches[1],
),
Reduce(
"b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
reduction="repeat",
patch_size_y=patch_size[0],
patch_size_x=patch_size[1],
),
]
)

self.task_embedding = torch.nn.ParameterDict(
{task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks}
)
Expand All @@ -99,27 +119,20 @@ def _init_weight(self):
def get_mask(self, img, n_visible_patches, num_patches):
B = img.shape[0]

indexes = [random_indexes(num_patches) for _ in range(B)]
indexes = [random_indexes(num_patches, img.device) for _ in range(B)]
# forward indexes : index in image -> shuffledpatch
forward_indexes = torch.as_tensor(
np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long
)
forward_indexes = torch.stack([i[0] for i in indexes], axis=-1)

# backward indexes : shuffled patch -> index in image
backward_indexes = torch.as_tensor(
np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long
)
backward_indexes = torch.stack([i[1] for i in indexes], axis=-1)

mask = torch.zeros(num_patches, B, 1)
mask = torch.zeros(num_patches, B, 1, device=img.device, dtype=torch.bool)
# visible patches are first
mask[:n_visible_patches] = 1
mask[:n_visible_patches] = True
mask = take_indexes(mask, backward_indexes)
mask = self.patch2img(mask)
# one pixel per masked patch, interpolate to size of input image
mask = torch.nn.functional.interpolate(
mask, img.shape[-self.spatial_dims :], mode="nearest"
)

return mask.to(img), forward_indexes, backward_indexes
return mask, forward_indexes, backward_indexes

def forward(self, img, mask_ratio, task=None):
# generate mask
Expand All @@ -141,6 +154,4 @@ def forward(self, img, mask_ratio, task=None):
tokens = tokens + self.task_embedding[task]

# mask is used above to mask out patches, we need to invert it for loss calculation
mask = (1 - mask).bool()

return tokens, mask, forward_indexes, backward_indexes
return tokens, ~mask, forward_indexes, backward_indexes
28 changes: 23 additions & 5 deletions cyto_dl/nn/vits/cross_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,18 +90,29 @@ def init_weight(self):

def forward(self, features, forward_indexes, backward_indexes):
# HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers
features = features.squeeze(0)
# features can be n t b c (if intermediate feature weighter used) or t b c if not
features = features[0] if len(features.shape) == 4 else features
T, B, C = features.shape
# we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
features = self.projection_norm(self.projection(features))

# add cls token
backward_indexes = torch.cat(
[torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
[
torch.zeros(
1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long
),
backward_indexes + 1,
],
dim=0,
)
forward_indexes = torch.cat(
[torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1],
[
torch.zeros(
1, forward_indexes.shape[1], device=forward_indexes.device, dtype=torch.long
),
forward_indexes + 1,
],
dim=0,
)
# fill in masked regions
Expand Down Expand Up @@ -138,11 +149,18 @@ def forward(self, features, forward_indexes, backward_indexes):

# add back in visible/encoded tokens that we don't calculate loss on
patches = torch.cat(
[torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches],
[
torch.zeros(
(T - 1, B, patches.shape[-1]),
requires_grad=False,
device=patches.device,
dtype=patches.dtype,
),
patches,
],
dim=0,
)
patches = take_indexes(patches, backward_indexes[1:] - 1)
# patches to image
img = self.patch2img(patches)

return img
7 changes: 6 additions & 1 deletion cyto_dl/nn/vits/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,12 @@ def forward(self, features, forward_indexes, backward_indexes):
features = self.projection_norm(self.projection(features))

backward_indexes = torch.cat(
[torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
[
torch.zeros(
1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long
),
backward_indexes + 1,
],
dim=0,
)
# fill in masked regions
Expand Down
4 changes: 1 addition & 3 deletions cyto_dl/nn/vits/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,4 @@


def take_indexes(sequences, indexes):
return torch.gather(
sequences, 0, repeat(indexes.to(sequences.device), "t b -> t b c", c=sequences.shape[-1])
)
return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))

0 comments on commit 592177b

Please sign in to comment.