ViT Speedup #399

Merged · 2 commits · Jul 2, 2024

Changes from all commits:
79 changes: 45 additions & 34 deletions cyto_dl/nn/vits/blocks/patchify.py
@@ -3,16 +3,15 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from einops.layers.torch import Rearrange
+from einops.layers.torch import Rearrange, Reduce
 from timm.models.layers import trunc_normal_

 from cyto_dl.nn.vits.utils import take_indexes


-def random_indexes(size: int):
-    forward_indexes = np.arange(size)
-    np.random.shuffle(forward_indexes)
-    backward_indexes = np.argsort(forward_indexes)
+def random_indexes(size: int, device):
+    forward_indexes = torch.randperm(size, device=device, dtype=torch.long)
+    backward_indexes = torch.argsort(forward_indexes)
     return forward_indexes, backward_indexes
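
This hunk generates the shuffle/unshuffle index pairs with torch.randperm directly on the input's device, rather than building them in NumPy on the host and converting to tensors afterwards. A minimal standalone sketch (toy size, not the repository's code) of the invariant the rest of the file relies on: argsort of a permutation is its inverse.

    # Sketch, not the repository's code: indexes are created directly on the
    # target device, avoiding a NumPy round-trip and a host-to-device copy.
    import torch

    def random_indexes(size: int, device):
        forward_indexes = torch.randperm(size, device=device, dtype=torch.long)
        backward_indexes = torch.argsort(forward_indexes)
        return forward_indexes, backward_indexes

    fwd, bwd = random_indexes(6, device="cpu")
    # backward_indexes undoes the shuffle: gathering by bwd restores order
    assert torch.equal(fwd[bwd], torch.arange(6))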


@@ -33,11 +32,11 @@ def __init__(
         Parameters
         ----------
         patch_size: List[int]
-            Size of each patch
+            Size of each patch in pix (ZYX order for 3D, YX order for 2D)
         emb_dim: int
             Dimension of encoder
         n_patches: List[int]
-            Number of patches in each spatial dimension
+            Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D)
         spatial_dims: int
             Number of spatial dimensions
         context_pixels: List[int]
@@ -65,12 +64,24 @@ def __init__(
                 padding=context_pixels,
             )
             self.img2token = Rearrange("b c z y x -> (z y x) b c")
-            self.patch2img = Rearrange(
-                "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x",
-                n_patch_z=n_patches[0],
-                n_patch_y=n_patches[1],
-                n_patch_x=n_patches[2],
-            )
+            self.patch2img = torch.nn.Sequential(
+                *[
+                    Rearrange(
+                        "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x",
+                        n_patch_z=n_patches[0],
+                        n_patch_y=n_patches[1],
+                        n_patch_x=n_patches[2],
+                    ),
+                    Reduce(
+                        "b c n_patch_z n_patch_y n_patch_x -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
+                        reduction="repeat",
+                        patch_size_z=patch_size[0],
+                        patch_size_y=patch_size[1],
+                        patch_size_x=patch_size[2],
+                    ),
+                ]
+            )

         elif spatial_dims == 2:
             self.conv = nn.Conv2d(
                 in_channels=input_channels,
@@ -80,12 +91,21 @@ def __init__(
                 padding=context_pixels,
             )
             self.img2token = Rearrange("b c y x -> (y x) b c")
-            self.patch2img = Rearrange(
-                "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x",
-                n_patch_y=n_patches[0],
-                n_patch_x=n_patches[1],
-            )
+            self.patch2img = torch.nn.Sequential(
+                *[
+                    Rearrange(
+                        "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x",
+                        n_patch_y=n_patches[0],
+                        n_patch_x=n_patches[1],
+                    ),
+                    Reduce(
+                        "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
+                        reduction="repeat",
+                        patch_size_y=patch_size[0],
+                        patch_size_x=patch_size[1],
+                    ),
+                ]
+            )

         self.task_embedding = torch.nn.ParameterDict(
             {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks}
         )
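
In both the 3D and 2D branches, patch2img now tiles each patch token across its full patch footprint using einops' Reduce with reduction="repeat", which lets get_mask below drop its nearest-neighbor interpolate call. A standalone sketch with toy sizes (names ny/nx/py/px are illustrative, not the repository's) checking that the repeat is exactly equivalent to nearest-neighbor upsampling:

    # Sketch with toy sizes: a 2x3 grid of patch tokens upsampled to 8x12 pixels.
    import torch
    from einops.layers.torch import Rearrange, Reduce

    ny, nx, py, px = 2, 3, 4, 4  # patch grid and per-patch pixel sizes (toy values)
    patch2img = torch.nn.Sequential(
        Rearrange("(ny nx) b c -> b c ny nx", ny=ny, nx=nx),
        # "repeat" tiles each token over a py x px block instead of reducing
        Reduce("b c ny nx -> b c (ny py) (nx px)", reduction="repeat", py=py, px=px),
    )
    tokens = torch.randn(ny * nx, 1, 1)  # (tokens, batch, channels)
    upsampled = patch2img(tokens)        # (1, 1, 8, 12)
    # identical to the nearest-neighbor interpolate it replaces
    reference = torch.nn.functional.interpolate(
        tokens.permute(1, 2, 0).reshape(1, 1, ny, nx), size=(ny * py, nx * px), mode="nearest"
    )
    assert torch.equal(upsampled, reference)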
@@ -99,27 +119,20 @@ def _init_weight(self):

     def get_mask(self, img, n_visible_patches, num_patches):
         B = img.shape[0]

-        indexes = [random_indexes(num_patches) for _ in range(B)]
+        indexes = [random_indexes(num_patches, img.device) for _ in range(B)]
         # forward indexes : index in image -> shuffled patch
-        forward_indexes = torch.as_tensor(
-            np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long
-        )
+        forward_indexes = torch.stack([i[0] for i in indexes], axis=-1)
         # backward indexes : shuffled patch -> index in image
-        backward_indexes = torch.as_tensor(
-            np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long
-        )
+        backward_indexes = torch.stack([i[1] for i in indexes], axis=-1)

-        mask = torch.zeros(num_patches, B, 1)
+        mask = torch.zeros(num_patches, B, 1, device=img.device, dtype=torch.bool)
         # visible patches are first
-        mask[:n_visible_patches] = 1
+        mask[:n_visible_patches] = True
         mask = take_indexes(mask, backward_indexes)
         mask = self.patch2img(mask)
-        # one pixel per masked patch, interpolate to size of input image
-        mask = torch.nn.functional.interpolate(
-            mask, img.shape[-self.spatial_dims :], mode="nearest"
-        )

-        return mask.to(img), forward_indexes, backward_indexes
+        return mask, forward_indexes, backward_indexes

     def forward(self, img, mask_ratio, task=None):
         # generate mask
@@ -141,6 +154,4 @@ def forward(self, img, mask_ratio, task=None):
             tokens = tokens + self.task_embedding[task]

         # mask is used above to mask out patches, we need to invert it for loss calculation
-        mask = (1 - mask).bool()
-
-        return tokens, mask, forward_indexes, backward_indexes
+        return tokens, ~mask, forward_indexes, backward_indexes
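
With the mask allocated as torch.bool on the image's device, the final inversion in forward becomes a bitwise NOT instead of a float subtraction plus cast, and the trailing .to(img) copy disappears. A small equivalence sketch (toy shapes):

    # Sketch: boolean masks invert with ~ instead of (1 - mask).bool()
    import torch

    mask_float = torch.zeros(5, 1, 1)                   # old: float mask
    mask_float[:2] = 1
    mask_bool = torch.zeros(5, 1, 1, dtype=torch.bool)  # new: boolean mask
    mask_bool[:2] = True
    assert torch.equal((1 - mask_float).bool(), ~mask_bool)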
28 changes: 23 additions & 5 deletions cyto_dl/nn/vits/cross_mae.py
@@ -90,18 +90,29 @@ def init_weight(self):

     def forward(self, features, forward_indexes, backward_indexes):
         # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers
-        features = features.squeeze(0)
+        # features can be n t b c (if intermediate feature weighter used) or t b c if not
+        features = features[0] if len(features.shape) == 4 else features
         T, B, C = features.shape
         # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
         features = self.projection_norm(self.projection(features))

         # add cls token
         backward_indexes = torch.cat(
-            [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
+            [
+                torch.zeros(
+                    1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long
+                ),
+                backward_indexes + 1,
+            ],
             dim=0,
         )
         forward_indexes = torch.cat(
-            [torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1],
+            [
+                torch.zeros(
+                    1, forward_indexes.shape[1], device=forward_indexes.device, dtype=torch.long
+                ),
+                forward_indexes + 1,
+            ],
             dim=0,
         )
         # fill in masked regions
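
Indexing with features[0] is also more robust than squeeze(0): squeeze only drops size-1 dims, so a stack produced by more than one intermediate-feature weighter would stay 4-dimensional and break the T, B, C unpacking. A toy illustration (hypothetical shapes):

    # Hypothetical shapes: n=2 feature weighters, 5 tokens, batch 3, dim 8
    import torch

    features = torch.randn(2, 5, 3, 8)  # (n, t, b, c)
    assert features.squeeze(0).ndim == 4      # squeeze(0) is a no-op when n > 1
    features = features[0] if len(features.shape) == 4 else features
    assert features.shape == (5, 3, 8)        # (t, b, c), as the decoder expects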
@@ -138,11 +149,18 @@ def forward(self, features, forward_indexes, backward_indexes):

         # add back in visible/encoded tokens that we don't calculate loss on
         patches = torch.cat(
-            [torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches],
+            [
+                torch.zeros(
+                    (T - 1, B, patches.shape[-1]),
+                    requires_grad=False,
+                    device=patches.device,
+                    dtype=patches.dtype,
+                ),
+                patches,
+            ],
             dim=0,
         )
         patches = take_indexes(patches, backward_indexes[1:] - 1)
         # patches to image
         img = self.patch2img(patches)

         return img
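
The remaining cross_mae.py and mae.py changes follow a single allocation pattern: build zero tensors with explicit device= and dtype= arguments instead of allocating on the CPU and copying with .to(). A side-by-side sketch (toy shapes; ref stands in for backward_indexes or patches):

    # Toy shapes; "ref" stands in for backward_indexes / patches.
    import torch

    ref = torch.randn(3, 4)
    # before: allocate on CPU, then copy device and dtype from ref
    old = torch.zeros(1, ref.shape[1]).to(ref)
    # after: allocate once, directly with the right device and dtype
    new = torch.zeros(1, ref.shape[1], device=ref.device, dtype=ref.dtype)
    assert torch.equal(old, new)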
7 changes: 6 additions & 1 deletion cyto_dl/nn/vits/mae.py
@@ -168,7 +168,12 @@ def forward(self, features, forward_indexes, backward_indexes):
         features = self.projection_norm(self.projection(features))

         backward_indexes = torch.cat(
-            [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
+            [
+                torch.zeros(
+                    1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long
+                ),
+                backward_indexes + 1,
+            ],
             dim=0,
         )
         # fill in masked regions
4 changes: 1 addition & 3 deletions cyto_dl/nn/vits/utils.py
@@ -3,6 +3,4 @@


 def take_indexes(sequences, indexes):
-    return torch.gather(
-        sequences, 0, repeat(indexes.to(sequences.device), "t b -> t b c", c=sequences.shape[-1])
-    )
+    return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))
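
Dropping indexes.to(sequences.device) is safe because the patchify changes above now create index tensors on the same device as the sequences they gather from. For reference, a standalone sketch (toy tensors) of what take_indexes computes: out[t, b, c] = sequences[indexes[t, b], b, c].

    # Toy tensors: gather tokens along the sequence axis, per batch element.
    import torch
    from einops import repeat

    sequences = torch.arange(12.0).reshape(3, 2, 2)    # (t, b, c)
    indexes = torch.tensor([[2, 0], [1, 1], [0, 2]])   # (t, b), on the same device
    gathered = torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=2))
    assert gathered[0, 0, 0] == sequences[2, 0, 0]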