From 8f0bd34091b03cb6c8dd73e4b70ecc04df22b041 Mon Sep 17 00:00:00 2001
From: kian-kd
Date: Wed, 10 Apr 2024 14:25:36 -0400
Subject: [PATCH 1/2] dump code

---
 README.md | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/README.md b/README.md
index 0c9e19e..8e16503 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,17 @@
 # Masked Autoencoders are Scalable Learners of Cellular Morphology
-Official repo for Recursion's accepted spotlight paper at [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio).
-
-Paper: https://arxiv.org/abs/2309.16064
+Official repo for Recursion's two recently accepted papers:
+- Spotlight full-length paper at [CVPR 2024](https://cvpr.thecvf.com/Conferences/2024/AcceptedPapers) -- Masked Autoencoders for Microscopy are Scalable Learners of Cellular Biology
+  - Paper: link to be shared soon!
+- Spotlight workshop paper at the [NeurIPS 2023 Generative AI & Biology workshop](https://openreview.net/group?id=NeurIPS.cc/2023/Workshop/GenBio)
+  - Paper: https://arxiv.org/abs/2309.16064
 
 ![vit_diff_mask_ratios](https://github.com/recursionpharma/maes_microscopy/assets/109550980/c15f46b1-cdb9-41a7-a4af-bdc9684a971d)
 
 ## Provided code
-The baseline Vision Transformer architecture backbone used in this work can be built with the following code snippet from Timm:
+See the repo for the ingredients required to define our MAEs. Users seeking to re-implement training will need to stitch together the Encoder and Decoder modules according to their use case.
+
+Furthermore, the baseline Vision Transformer backbone used in this work can be built with the following code snippet from Timm:
 
 ```
 import timm.models.vision_transformer as vit
@@ -29,7 +33,7 @@ def vit_base_patch16_256(**kwargs):
     return vit.vit_base_patch16_224(**default_kwargs)
 ```
 
-Additional code will be released as the date of the workshop gets closer.
-
 ## Provided models
-Stay tuned...
+A publicly available model for research can be found via Nvidia's BioNeMo platform, which handles inference and auto-scaling for you: https://www.rxrx.ai/phenom
+
+We are not able to release model weights at this time.
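As a rough illustration only (editor's sketch, not an official pipeline), the backbone defined in the README snippet above could be wrapped in the `MAEEncoder` provided in `mae_modules.py` from the second patch below to extract pooled embeddings from a 6-channel image. The import path, the channel-agnostic setting, and the dummy inputs are assumptions; adjust them to how you package the files:

```
import torch

from mae_modules import MAEEncoder  # from the second patch; adjust the import to your packaging

backbone = vit_base_patch16_256()  # factory from the README snippet above
encoder = MAEEncoder(backbone, max_in_chans=6, channel_agnostic=True)
encoder.eval()

with torch.no_grad():
    # dummy 6-channel 256x256 crops; real inputs are self-standardized first (see mae_modules.py)
    images = torch.rand(2, 6, 256, 256)
    embeddings = encoder(images)  # pooled class-token features, (2, 768) for this ViT-B backbone
```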
From 373b8b8d61269c88b9d3020d8c728be5dea4ed21 Mon Sep 17 00:00:00 2001
From: kian-kd
Date: Wed, 10 Apr 2024 14:25:49 -0400
Subject: [PATCH 2/2] dump

---
 config.yaml    |  15 +++
 loss.py        |  50 +++++++++
 mae_modules.py | 272 ++++++++++++++++++++++++++++++++++++++++++++++
 mae_utils.py   |  64 +++++++++++
 masking.py     |  46 ++++++++
 vit.py         | 284 +++++++++++++++++++++++++++++++++++++++++++++++++
 6 files changed, 731 insertions(+)
 create mode 100644 config.yaml
 create mode 100644 loss.py
 create mode 100644 mae_modules.py
 create mode 100644 mae_utils.py
 create mode 100644 masking.py
 create mode 100644 vit.py

diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000..8fea17a
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,15 @@
+loss:
+  _target_: torch.nn.MSELoss  # combine with the Fourier loss at a 0.01 mixing factor for best results
+  reduction: none
+optimizer:
+  _target_: timm.optim.lion.Lion
+  _partial_: true
+  lr: &lr 1e-4  # 1e-4 for <= ViT-B, and 3e-5 for ViT-L
+  weight_decay: 0.05
+  betas: [0.9, 0.95]
+lr_scheduler:
+  _target_: torch.optim.lr_scheduler.OneCycleLR
+  _partial_: true
+  max_lr: *lr
+  pct_start: 0.1
+  anneal_strategy: cos
\ No newline at end of file
diff --git a/loss.py b/loss.py
new file mode 100644
index 0000000..59865d5
--- /dev/null
+++ b/loss.py
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+
+
+class FourierLoss(nn.Module):
+    def __init__(
+        self,
+        use_l1_loss: bool = True,
+        num_multimodal_modalities: int = 1,  # set to 1 for vanilla MAE, 6 for channel-agnostic MAE
+    ) -> None:
+        """
+        The Fourier transform loss is only sound when using an L1 or L2 loss to compare the frequency
+        domains of the images / their radial histograms.
+
+        We will always set `reduction="none"` and enforce that the computation of any reductions from the
+        output of this loss be managed by the model in question.
+        """
+        super().__init__()
+        self.loss = nn.L1Loss(reduction="none") if use_l1_loss else nn.MSELoss(reduction="none")
+        self.num_modalities = num_multimodal_modalities
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        # input = reconstructed image, target = original image
+        # flattened images from the MAE are (B, H*W, C), so here we convert to B x C x H x W (note we assume H == W)
+        flattened_images = len(input.shape) == len(target.shape) == 3
+        if flattened_images:
+            B, H_W, C = input.shape
+            H_W = H_W // self.num_modalities
+            four_d_shape = (B, C * self.num_modalities, int(H_W**0.5), int(H_W**0.5))
+            input = input.view(*four_d_shape)
+            target = target.view(*four_d_shape)
+        else:
+            B, C, h, w = input.shape
+            H_W = h * w
+
+        if len(input.shape) != 4 or len(target.shape) != 4:
+            raise ValueError(f"Invalid input shape: got {input.shape} and {target.shape}.")
+
+        fft_reconstructed = torch.fft.fft2(input)
+        fft_original = torch.fft.fft2(target)
+
+        magnitude_reconstructed = torch.abs(fft_reconstructed)
+        magnitude_original = torch.abs(fft_original)
+
+        loss_tensor: torch.Tensor = self.loss(magnitude_reconstructed, magnitude_original)
+
+        if flattened_images:  # then the output loss should be reshaped to match the flattened input
+            loss_tensor = loss_tensor.reshape(B, H_W * self.num_modalities, C)
+
+        return loss_tensor
diff --git a/mae_modules.py b/mae_modules.py
new file mode 100644
index 0000000..08de947
--- /dev/null
+++ b/mae_modules.py
@@ -0,0 +1,272 @@
+from functools import partial
+from typing import Tuple, Union
+
+import torch
+import torch.nn as nn
+from timm.models.helpers import checkpoint_seq
+from timm.models.vision_transformer import Block, Mlp, VisionTransformer
+
+from .masking import transformer_random_masking
+from .vit import channel_agnostic_vit
+
+# If you want to train new MAEs, combine an encoder and a decoder into a new module, leveraging the
+# flattening and unflattening utilities from mae_utils.py as needed.
+# Be sure to use a Linear projection between the encoder and decoder to match the encoder output
+# dimension to the decoder dimension.
+# As described in the paper, images are self-standardized at the start.
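+
+# --- Editor's sketch (assumption, not the training module used in the paper) ---
+# The class below shows one way the pieces in this file could be stitched together for pre-training,
+# per the comment above: encoder -> linear projection -> decoder -> pixel head, with the reconstruction
+# loss applied only on masked tokens. The names, default dimensions, and loss handling are assumptions;
+# the config also mixes in FourierLoss (loss.py) at a 0.01 weight, which is omitted here for brevity.
+class MAEPretrainSketch(nn.Module):
+    def __init__(
+        self,
+        encoder: "MAEEncoder",
+        decoder: "MAEDecoder",
+        decoder_pos_embeddings: nn.Parameter,  # e.g. from vit.generate_2d_sincos_pos_embeddings
+        patch_size: int = 16,
+        in_chans: int = 6,
+        mask_ratio: float = 0.75,
+    ) -> None:
+        super().__init__()
+        self.self_standardize = SelfStandardize()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.decoder.pos_embeddings = decoder_pos_embeddings  # the decoder expects these to be set
+        # project encoder tokens into the (typically narrower) decoder width
+        self.encoder_to_decoder = nn.Linear(encoder.embed_dim, decoder.embed_dim)
+        # predict raw patch pixels; with a channel-agnostic encoder each token covers a single channel
+        out_dim = patch_size**2 if encoder.channel_agnostic else patch_size**2 * in_chans
+        self.pixel_head = nn.Linear(decoder.embed_dim, out_dim)
+        self.patch_size = patch_size
+        self.mask_ratio = mask_ratio
+        self.criterion = nn.MSELoss(reduction="none")
+
+    def forward(self, img: torch.Tensor) -> torch.Tensor:
+        from .mae_utils import flatten_images  # local import keeps this sketch self-contained
+
+        img = self.self_standardize(img)
+        latent, mask, ind_restore = self.encoder.forward_masked(img, self.mask_ratio)
+        latent = self.encoder_to_decoder(latent)
+        decoded = self.decoder.forward_masked(latent, ind_restore)
+        pred = self.pixel_head(decoded[:, 1:, :])  # drop the class token
+        target = flatten_images(img, self.patch_size, channel_agnostic=self.encoder.channel_agnostic)
+        loss = self.criterion(pred, target).mean(dim=-1)  # per-token loss, shape (N, L)
+        return (loss * mask).sum() / mask.sum()  # average over masked tokens only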
+ + +class SelfStandardize(nn.Module): + def __init__(self) -> None: + super().__init__() + self.self_standardize = nn.LazyInstanceNorm2d( + affine=False, track_running_stats=False + ) + + def forward(self, pixels: torch.Tensor) -> torch.Tensor: + x = pixels.float() / 255.0 + return self.self_standardize(x) + + +class MAEEncoder(nn.Module): + def __init__( + self, + vit_backbone: VisionTransformer, + max_in_chans: int = 6, + channel_agnostic: bool = False, + ) -> None: + super().__init__() + if channel_agnostic: + self.vit_backbone = channel_agnostic_vit( + vit_backbone, max_in_chans=max_in_chans + ) + else: + self.vit_backbone = vit_backbone + self.max_in_chans = max_in_chans + self.channel_agnostic = channel_agnostic + + @property + def embed_dim(self) -> int: + return int(self.vit_backbone.embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.vit_backbone.forward_features(x) + x = self.vit_backbone.forward_head(x) + return x # type: ignore[no-any-return] + + def forward_masked( + self, + x: torch.Tensor, + mask_ratio: float, + constant_noise: Union[torch.Tensor, None] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = self.vit_backbone.patch_embed(x) + x = self.vit_backbone._pos_embed(x) # adds class token + x_ = x[:, 1:, :] # no class token + x_, mask, ind_restore = transformer_random_masking( + x_, mask_ratio, constant_noise + ) + x = torch.cat([x[:, :1, :], x_], dim=1) # add class token + x = self.vit_backbone.norm_pre(x) + + if self.vit_backbone.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.vit_backbone.blocks, x) + else: + x = self.vit_backbone.blocks(x) + x = self.vit_backbone.norm(x) + return x, mask, ind_restore + + +class MAEDecoder(nn.Module): + def __init__( + self, + embed_dim: int = 512, + depth: int = 8, + num_heads: int = 16, + mlp_ratio: float = 4, + qkv_bias: bool = True, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment] + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.pos_embeddings = None # to be overwritten by MAE class + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.blocks = nn.Sequential( + *[ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + self.norm = norm_layer(embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.pos_embeddings + x = self.blocks(x) + x = self.norm(x) + return x # type: ignore[no-any-return] + + def forward_masked( + self, x: torch.Tensor, ind_restore: torch.Tensor + ) -> torch.Tensor: + mask_tokens = self.mask_token.repeat( + x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1 + ) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token + x_ = torch.gather( + x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # add class token + + x = x + self.pos_embeddings + x = self.blocks(x) + x = self.norm(x) + return x # type: ignore[no-any-return] + + +class CrossAttention(nn.Module): + def __init__( + self, embed_dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0 + ): + super().__init__() + self.num_heads = num_heads + head_dim = embed_dim // num_heads + self.scale = head_dim**-0.5 + + self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias) + self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dim, embed_dim) 
+ self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, context): + B, N, C = x.shape + _, M, _ = context.shape + + q = ( + self.q(x) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) + kv = ( + self.kv(context) + .reshape(B, M, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class CAMAEDecoder(nn.Module): + def __init__( + self, + num_modalities: int = 6, + tokens_per_modality: int = 256, + embed_dim: int = 256, + depth: int = 2, + num_heads: int = 16, + mlp_ratio: float = 4, + qkv_bias: bool = True, + norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type: ignore[assignment] + ) -> None: + super().__init__() + self.num_modalities = num_modalities + self.tokens_per_modality = tokens_per_modality + self.embed_dim = embed_dim + self.pos_embeddings = None # to be overwritten by MAE class + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.placeholder = nn.Parameter( + torch.zeros(1, 1, embed_dim), requires_grad=False + ) + self.modality_tokens = nn.ParameterList( + [ + nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + for modality in range(self.num_modalities) + ] + ) + + self.cross_attention = CrossAttention(embed_dim=self.embed_dim) + self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio)) + + self.decoders = nn.ModuleList( + [ + nn.Sequential( + *[ + Block( + embed_dim, + num_heads, + mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + for modality in range(self.num_modalities) + ] + ) + # self.norm = norm_layer(embed_dim) # we decided to drop the last layer norm + self.context_norm = norm_layer(embed_dim) + self.query_norm = norm_layer(embed_dim) + self.out_norm = norm_layer(embed_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_m_s = [] + + modality_tokens_concat = torch.cat( + [ + self.placeholder, + ] # placeholder for class token + + [ + m_t.repeat(1, self.tokens_per_modality, 1) + for m_t in self.modality_tokens + ], + dim=1, + ) + + x = ( + x + self.pos_embeddings + modality_tokens_concat + ) # add pos and tiled modality tokens + x_ = x[:, 1:, :] # no class token + for m, decoder in enumerate( + self.decoders + ): # iterate through modalities and decoders + x_m = x_[ + :, m * self.tokens_per_modality : (m + 1) * self.tokens_per_modality, : + ] + x_m = self.cross_attention(self.query_norm(x_m), self.context_norm(x_)) + x_m = x_m + self.mlp(self.out_norm(x_m)) + x_m = decoder(x_m) + x_m_s.append(x_m) + x_m_s = torch.cat(x_m_s, dim=1) # concat all tokens + # x_m_s = self.norm(x_m_s) # we decided to drop the last layer norm + x_m_s = torch.cat([x[:, :1, :], x_m_s], dim=1) # add back class token + + return x_m_s + + def forward_masked( + self, x: torch.Tensor, ind_restore: torch.Tensor + ) -> torch.Tensor: + mask_tokens = self.mask_token.repeat( + x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1 + ) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # remove class token + x_ = torch.gather( + x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # add class token + x = self.forward(x) + return x diff --git a/mae_utils.py b/mae_utils.py new file mode 100644 index 0000000..0bf7018 --- /dev/null +++ 
b/mae_utils.py
@@ -0,0 +1,64 @@
+import math
+
+import torch
+
+
+def flatten_images(img: torch.Tensor, patch_size: int, channel_agnostic: bool = False) -> torch.Tensor:
+    """
+    Flattens 2D images into patch tokens that preserve the pixel values
+
+    Parameters
+    ----------
+    img : input image tensor (N, C, H, W)
+    patch_size : side length in pixels of the (square) patches
+    channel_agnostic : if True, each channel is tokenized separately (one token per channel per patch location)
+
+    Returns
+    -------
+    flattened_img : flattened image tensor (N, L, patch_size**2 * C),
+        or (N, C * L, patch_size**2) if channel_agnostic
+    """
+
+    if (img.shape[2] != img.shape[3]) or (img.shape[2] % patch_size != 0):
+        raise ValueError("image H must equal image W and be divisible by patch_size")
+    in_chans = img.shape[1]
+
+    h = w = int(img.shape[2] // patch_size)
+    x = img.reshape(shape=(img.shape[0], in_chans, h, patch_size, w, patch_size))
+
+    if channel_agnostic:
+        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHPWQ -> NCHWPQ
+        x = x.reshape(shape=(img.shape[0], in_chans * h * w, int(patch_size**2)))
+    else:
+        x = torch.permute(x, (0, 2, 4, 3, 5, 1))  # NCHPWQ -> NHWPQC
+        x = x.reshape(shape=(img.shape[0], h * w, int(patch_size**2 * in_chans)))
+    return x
+
+
+def unflatten_tokens(
+    tokens: torch.Tensor, patch_size: int, num_modalities: int = 1, channel_agnostic: bool = False
+) -> torch.Tensor:
+    """
+    Unflattens tokens (N, L, patch_size**2 * C) back into an image tensor (N, C, H, W) with the pixel values
+
+    Parameters
+    ----------
+    tokens : input token tensor (N, L, patch_size**2 * C)
+    patch_size : side length in pixels of the (square) patches
+    num_modalities : number of channel modalities encoded in the token sequence
+    channel_agnostic : must be True when num_modalities > 1
+
+    Returns
+    -------
+    img : image tensor (N, C, H, W)
+    """
+    if num_modalities > 1 and not channel_agnostic:
+        raise ValueError("Multiple modalities require channel-agnostic unflattening.")
+
+    h = w = int(math.sqrt(tokens.shape[1] // num_modalities))
+    if h * w != (tokens.shape[1] // num_modalities):
+        raise ValueError("sqrt of number of tokens is not an integer")
+
+    if channel_agnostic:
+        x = tokens.reshape(shape=(tokens.shape[0], -1, h, w, patch_size, patch_size))
+        x = torch.permute(x, (0, 1, 2, 4, 3, 5))  # NCHWPQ -> NCHPWQ
+    else:
+        x = tokens.reshape(shape=(tokens.shape[0], h, w, patch_size, patch_size, -1))
+        x = torch.permute(x, (0, 5, 1, 3, 2, 4))  # NHWPQC -> NCHPWQ
+    img = x.reshape(shape=(x.shape[0], -1, h * patch_size, w * patch_size))
+
+    return img
diff --git a/masking.py b/masking.py
new file mode 100644
index 0000000..1ddb788
--- /dev/null
+++ b/masking.py
@@ -0,0 +1,46 @@
+from typing import Tuple, Union
+
+import torch
+
+
+def transformer_random_masking(
+    x: torch.Tensor, mask_ratio: float, constant_noise: Union[torch.Tensor, None] = None
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Randomly mask patches per sample
+
+    Parameters
+    ----------
+    x : token tensor (N, L, D)
+    mask_ratio : float - fraction of tokens to mask
+    constant_noise : optional tensor of shape (N, L); if provided, used instead of fresh random noise to produce consistent masks
+
+    Returns
+    -------
+    x_masked : sub-sampled version of x (N, int(L * (1 - mask_ratio)), D)
+    mask : binary mask indicating masked tokens (1 where masked) (N, L)
+    ind_restore : indices that restore the original token order, needed by the decoder
+    """
+
+    N, L, D = x.shape  # batch, length, dim
+    len_keep = int(L * (1 - mask_ratio))
+
+    # use random noise to generate batch-based random masks
+    if constant_noise is not None:
+        noise = constant_noise
+    else:
+        noise = torch.rand(N, L, device=x.device)
+
+    shuffled_tokens = torch.argsort(noise, dim=1)  # shuffled index
+    ind_restore = torch.argsort(shuffled_tokens, dim=1)  # unshuffled index
+
+    # get masked input
+    tokens_to_keep = shuffled_tokens[:, :len_keep]  # keep the first len_keep indices
+    x_masked = torch.gather(x, dim=1, index=tokens_to_keep.unsqueeze(-1).repeat(1, 1, D))
+
+    # get binary mask used for loss masking: 0 is keep, 1 is remove
+    mask = torch.ones([N, L], device=x.device)
+    mask[:, :len_keep] = 0
+    mask = torch.gather(mask, dim=1, index=ind_restore)  # unshuffle to get the binary mask
+
+    return x_masked, mask, ind_restore
diff --git a/vit.py b/vit.py
new file mode 100644
index 0000000..79ad568
--- /dev/null
+++ b/vit.py
@@ -0,0 +1,284 @@
+import timm.models.vision_transformer as vit
+import torch
+
+
+def generate_2d_sincos_pos_embeddings(
+    embedding_dim: int, length: int, scale: float = 10000.0, use_class_token: bool = True, num_modality: int = 1
+) -> torch.nn.Parameter:
+    """
+    Generate 2-dimensional sin/cos positional embeddings
+
+    Parameters
+    ----------
+    embedding_dim : int
+        embedding dimension used in the ViT
+    length : int
+        number of tokens along the height or width of the image after patching (assuming square)
+    scale : float
+        scale for the sin/cos functions
+    use_class_token : bool
+        True - prepend a zero vector for the class token, False - no vector prepended
+    num_modality : int
+        number of channel modalities; the spatial sin/cos encoding is tiled across modalities
+        (use 1 for a standard single-modality ViT)
+
+    Returns
+    -------
+    positional_encoding : torch.nn.Parameter
+        positional encoding to add to the ViT patch encodings, of shape
+        [1, num_modality*length**2, embedding_dim] without a class token, or
+        [1, 1 + num_modality*length**2, embedding_dim] with one
+    """
+
+    linear_positions = torch.arange(length, dtype=torch.float32)
+    height_mesh, width_mesh = torch.meshgrid(linear_positions, linear_positions, indexing="ij")
+    positional_dim = embedding_dim // 4  # accommodate sin and cos embeddings for each of h and w
+    positional_weights = torch.arange(positional_dim, dtype=torch.float32) / positional_dim
+    positional_weights = 1.0 / (scale**positional_weights)
+
+    height_weights = torch.outer(height_mesh.flatten(), positional_weights)
+    width_weights = torch.outer(width_mesh.flatten(), positional_weights)
+
+    positional_encoding = torch.cat(
+        [torch.sin(height_weights), torch.cos(height_weights), torch.sin(width_weights), torch.cos(width_weights)],
+        dim=1,
+    )[None, :, :]
+
+    # repeat the positional encoding for multiple channel modalities
+    positional_encoding = positional_encoding.repeat(1, num_modality, 1)
+
+    if use_class_token:
+        class_token = torch.zeros([1, 1, embedding_dim], dtype=torch.float32)
+        positional_encoding = torch.cat([class_token, positional_encoding], dim=1)
+
+    positional_encoding = torch.nn.Parameter(positional_encoding, requires_grad=False)
+
+    return positional_encoding
+
+
+class ChannelAgnosticPatchEmbed(vit.PatchEmbed):  # type: ignore[misc]
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        embed_dim: int,
+        bias: bool = True,
+    ) -> None:
+        super().__init__(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=1,  # in_chans is used by self.proj, which we override anyway
+            embed_dim=embed_dim,
+            norm_layer=None,
+            flatten=False,
+            bias=bias,
+        )
+        # the channel-agnostic MAE has a single projection shared by all channels
+        self.proj = torch.nn.Conv2d(1, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        in_chans = x.shape[1]
+        x = torch.stack([self.proj(x[:, i : i + 1]) for i in range(in_chans)], dim=2)  # single projection for all chans
+        x = x.flatten(2).transpose(1, 2)  # BCMHW -> BNC
+        return x
+
+
+class ChannelAgnosticViT(vit.VisionTransformer):  # type: ignore[misc]
+    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
+        # rewrite of https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L586
+        to_cat = []
+        if self.cls_token is not None:
+            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
+
+        # TODO: upgrade timm to get access to register tokens
+        # if self.reg_token is not None:
+        #     to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))
+
+        # MAIN DIFFERENCE from timm - we DYNAMICALLY ADD POS EMBEDDINGS based on the shape of the input;
+        # this supports having CA-MAEs actually be channel-agnostic at inference time
+        if self.no_embed_class:
+            x = x + self.pos_embed[:, : x.shape[1]]
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
+        else:
+            if to_cat:
+                x = torch.cat(to_cat + [x], dim=1)
+            x = x + self.pos_embed[:, : x.shape[1]]
+        return self.pos_drop(x)  # type: ignore[no-any-return]
+
+
+def channel_agnostic_vit(vit_backbone: vit.VisionTransformer, max_in_chans: int) -> vit.VisionTransformer:
+    # replace the patch embedding with the channel-agnostic version
+    vit_backbone.patch_embed = ChannelAgnosticPatchEmbed(
+        img_size=vit_backbone.patch_embed.img_size[0],
+        patch_size=vit_backbone.patch_embed.patch_size[0],
+        embed_dim=vit_backbone.embed_dim,
+    )
+
+    # replace the positional embedding with the channel-agnostic version
+    vit_backbone.pos_embed = generate_2d_sincos_pos_embeddings(
+        embedding_dim=vit_backbone.embed_dim,
+        length=vit_backbone.patch_embed.grid_size[0],
+        use_class_token=vit_backbone.cls_token is not None,
+        num_modality=max_in_chans,
+    )
+
+    # change the class to ChannelAgnosticViT so that it actually uses the new _pos_embed
+    vit_backbone.__class__ = ChannelAgnosticViT
+    return vit_backbone
+
+
+def sincos_positional_encoding_vit(
+    vit_backbone: vit.VisionTransformer, scale: float = 10000.0
+) -> vit.VisionTransformer:
+    """Attaches no-grad sin-cos positional embeddings to a pre-constructed ViT backbone model.
+ + Parameters + ---------- + vit_backbone : timm.models.vision_transformer.VisionTransformer + the constructed vision transformer from timm + scale : float (default 10000.0) + hyperparameter for sincos positional embeddings, recommend keeping at 10,000 + + Returns + ------- + timm.models.vision_transformer.VisionTransformer + the same ViT but with fixed no-grad positional encodings to add to vit patch encodings + """ + # length: number of tokens along height or width of image after patching (assuming square) + length = vit_backbone.patch_embed.img_size[0] // vit_backbone.patch_embed.patch_size[0] + pos_embeddings = generate_2d_sincos_pos_embeddings( + vit_backbone.embed_dim, length=length, scale=scale, use_class_token=vit_backbone.cls_token is not None + ) + # note, if the model had weight_init == 'skip', this might get overwritten + vit_backbone.pos_embed = pos_embeddings + return vit_backbone + + +def vit_small_patch16_256(**kwargs): + default_kwargs = dict( + img_size=256, + in_chans=6, + num_classes=0, + fc_norm=None, + class_token=True, + drop_path_rate=0.1, + init_values=0.0001, + block_fn=vit.ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + for k, v in kwargs.items(): + default_kwargs[k] = v + return vit.vit_small_patch16_224(**default_kwargs) + + +def vit_small_patch32_512(**kwargs): + default_kwargs = dict( + img_size=512, + in_chans=6, + num_classes=0, + fc_norm=None, + class_token=True, + drop_path_rate=0.1, + init_values=0.0001, + block_fn=vit.ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + for k, v in kwargs.items(): + default_kwargs[k] = v + return vit.vit_small_patch32_384(**default_kwargs) + + +def vit_base_patch8_256(**kwargs): + default_kwargs = dict( + img_size=256, + in_chans=6, + num_classes=0, + fc_norm=None, + class_token=True, + drop_path_rate=0.1, + init_values=0.0001, + block_fn=vit.ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + for k, v in kwargs.items(): + default_kwargs[k] = v + return vit.vit_base_patch8_224(**default_kwargs) + + +def vit_base_patch16_256(**kwargs): + default_kwargs = dict( + img_size=256, + in_chans=6, + num_classes=0, + fc_norm=None, + class_token=True, + drop_path_rate=0.1, + init_values=0.0001, + block_fn=vit.ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + for k, v in kwargs.items(): + default_kwargs[k] = v + return vit.vit_base_patch16_224(**default_kwargs) + + +def vit_base_patch32_512(**kwargs): + default_kwargs = dict( + img_size=512, + in_chans=6, + num_classes=0, + fc_norm=None, + class_token=True, + drop_path_rate=0.1, + init_values=0.0001, + block_fn=vit.ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + for k, v in kwargs.items(): + default_kwargs[k] = v + return vit.vit_base_patch32_384(**default_kwargs) + + +def vit_large_patch8_256(**kwargs): + default_kwargs = dict( + img_size=256, + in_chans=6, + num_classes=0, + fc_norm=None, + class_token=True, + patch_size=8, + embed_dim=1024, + depth=24, + num_heads=16, + drop_path_rate=0.3, + init_values=0.0001, + block_fn=vit.ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + for k, v in kwargs.items(): + default_kwargs[k] = v + return vit.VisionTransformer(**default_kwargs) + + +def vit_large_patch16_256(**kwargs): + default_kwargs = dict( + img_size=256, + in_chans=6, + num_classes=0, + fc_norm=None, + class_token=True, + drop_path_rate=0.3, + init_values=0.0001, + block_fn=vit.ParallelScalingBlock, + qkv_bias=False, + qk_norm=True, + ) + for k, v in kwargs.items(): + default_kwargs[k] = v + return 
vit.vit_large_patch16_384(**default_kwargs)
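+
+
+if __name__ == "__main__":
+    # Editor's note: this smoke test is not part of the original dump; it is a minimal, hypothetical
+    # check of the helpers above, assuming 256x256 inputs, 16x16 patches, and 6 channels.
+    backbone = sincos_positional_encoding_vit(vit_small_patch16_256())
+    assert backbone.pos_embed.shape == (1, 1 + (256 // 16) ** 2, backbone.embed_dim)
+
+    ca_backbone = channel_agnostic_vit(vit_small_patch16_256(), max_in_chans=6)
+    tokens = ca_backbone.patch_embed(torch.rand(1, 6, 256, 256))
+    assert tokens.shape == (1, 6 * (256 // 16) ** 2, ca_backbone.embed_dim)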