From fd2a520065e8a25485de3f5712f675f9a5e73921 Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Tue, 2 Jul 2024 22:02:15 +0000 Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20main=20@=205?= =?UTF-8?q?92177bab4df09337e98d1d0837082a034a4facf=20=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- _modules/cyto_dl/nn/vits/blocks/patchify.html | 79 +++++++++++-------- _modules/cyto_dl/nn/vits/cross_mae.html | 28 +++++-- _modules/cyto_dl/nn/vits/mae.html | 7 +- _modules/cyto_dl/nn/vits/utils.html | 4 +- cyto_dl.nn.vits.blocks.patchify.html | 6 +- searchindex.js | 2 +- 6 files changed, 79 insertions(+), 47 deletions(-) diff --git a/_modules/cyto_dl/nn/vits/blocks/patchify.html b/_modules/cyto_dl/nn/vits/blocks/patchify.html index 23a2a820..88068ae4 100644 --- a/_modules/cyto_dl/nn/vits/blocks/patchify.html +++ b/_modules/cyto_dl/nn/vits/blocks/patchify.html @@ -418,16 +418,15 @@

Source code for cyto_dl.nn.vits.blocks.patchify

< import numpy as np import torch import torch.nn as nn -from einops.layers.torch import Rearrange +from einops.layers.torch import Rearrange, Reduce from timm.models.layers import trunc_normal_ from cyto_dl.nn.vits.utils import take_indexes -
[docs]def random_indexes(size: int): - forward_indexes = np.arange(size) - np.random.shuffle(forward_indexes) - backward_indexes = np.argsort(forward_indexes) +
[docs]def random_indexes(size: int, device): + forward_indexes = torch.randperm(size, device=device, dtype=torch.long) + backward_indexes = torch.argsort(forward_indexes) return forward_indexes, backward_indexes
@@ -448,11 +447,11 @@

Source code for cyto_dl.nn.vits.blocks.patchify

< Parameters ---------- patch_size: List[int] - Size of each patch + Size of each patch in pix (ZYX order for 3D, YX order for 2D) emb_dim: int Dimension of encoder n_patches: List[int] - Number of patches in each spatial dimension + Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) spatial_dims: int Number of spatial dimensions context_pixels: List[int] @@ -480,12 +479,24 @@

Source code for cyto_dl.nn.vits.blocks.patchify

< padding=context_pixels, ) self.img2token = Rearrange("b c z y x -> (z y x) b c") - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", - n_patch_z=n_patches[0], - n_patch_y=n_patches[1], - n_patch_x=n_patches[2], + self.patch2img = torch.nn.Sequential( + *[ + Rearrange( + "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", + n_patch_z=n_patches[0], + n_patch_y=n_patches[1], + n_patch_x=n_patches[2], + ), + Reduce( + "b c n_patch_z n_patch_y n_patch_x -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + reduction="repeat", + patch_size_z=patch_size[0], + patch_size_y=patch_size[1], + patch_size_x=patch_size[2], + ), + ] ) + elif spatial_dims == 2: self.conv = nn.Conv2d( in_channels=input_channels, @@ -495,12 +506,21 @@

Source code for cyto_dl.nn.vits.blocks.patchify

< padding=context_pixels, ) self.img2token = Rearrange("b c y x -> (y x) b c") - self.patch2img = Rearrange( - "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", - n_patch_y=n_patches[0], - n_patch_x=n_patches[1], + self.patch2img = torch.nn.Sequential( + *[ + Rearrange( + "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", + n_patch_y=n_patches[0], + n_patch_x=n_patches[1], + ), + Reduce( + "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + reduction="repeat", + patch_size_y=patch_size[0], + patch_size_x=patch_size[1], + ), + ] ) - self.task_embedding = torch.nn.ParameterDict( {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} ) @@ -514,27 +534,20 @@

Source code for cyto_dl.nn.vits.blocks.patchify

<
[docs] def get_mask(self, img, n_visible_patches, num_patches): B = img.shape[0] - indexes = [random_indexes(num_patches) for _ in range(B)] + indexes = [random_indexes(num_patches, img.device) for _ in range(B)] # forward indexes : index in image -> shuffledpatch - forward_indexes = torch.as_tensor( - np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long - ) + forward_indexes = torch.stack([i[0] for i in indexes], axis=-1) + # backward indexes : shuffled patch -> index in image - backward_indexes = torch.as_tensor( - np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long - ) + backward_indexes = torch.stack([i[1] for i in indexes], axis=-1) - mask = torch.zeros(num_patches, B, 1) + mask = torch.zeros(num_patches, B, 1, device=img.device, dtype=torch.bool) # visible patches are first - mask[:n_visible_patches] = 1 + mask[:n_visible_patches] = True mask = take_indexes(mask, backward_indexes) mask = self.patch2img(mask) - # one pixel per masked patch, interpolate to size of input image - mask = torch.nn.functional.interpolate( - mask, img.shape[-self.spatial_dims :], mode="nearest" - ) - return mask.to(img), forward_indexes, backward_indexes
+ return mask, forward_indexes, backward_indexes
[docs] def forward(self, img, mask_ratio, task=None): # generate mask @@ -556,9 +569,7 @@

Source code for cyto_dl.nn.vits.blocks.patchify

< tokens = tokens + self.task_embedding[task] # mask is used above to mask out patches, we need to invert it for loss calculation - mask = (1 - mask).bool() - - return tokens, mask, forward_indexes, backward_indexes
+ return tokens, ~mask, forward_indexes, backward_indexes
diff --git a/_modules/cyto_dl/nn/vits/cross_mae.html b/_modules/cyto_dl/nn/vits/cross_mae.html index a9c52b69..45a8c87a 100644 --- a/_modules/cyto_dl/nn/vits/cross_mae.html +++ b/_modules/cyto_dl/nn/vits/cross_mae.html @@ -505,18 +505,29 @@

Source code for cyto_dl.nn.vits.cross_mae

 
 
[docs] def forward(self, features, forward_indexes, backward_indexes): # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers - features = features.squeeze(0) + # features can be n t b c (if intermediate feature weighter used) or t b c if not + features = features[0] if len(features.shape) == 4 else features T, B, C = features.shape # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now features = self.projection_norm(self.projection(features)) # add cls token backward_indexes = torch.cat( - [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], + [ + torch.zeros( + 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long + ), + backward_indexes + 1, + ], dim=0, ) forward_indexes = torch.cat( - [torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1], + [ + torch.zeros( + 1, forward_indexes.shape[1], device=forward_indexes.device, dtype=torch.long + ), + forward_indexes + 1, + ], dim=0, ) # fill in masked regions @@ -553,13 +564,20 @@

Source code for cyto_dl.nn.vits.cross_mae

 
         # add back in visible/encoded tokens that we don't calculate loss on
         patches = torch.cat(
-            [torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches],
+            [
+                torch.zeros(
+                    (T - 1, B, patches.shape[-1]),
+                    requires_grad=False,
+                    device=patches.device,
+                    dtype=patches.dtype,
+                ),
+                patches,
+            ],
             dim=0,
         )
         patches = take_indexes(patches, backward_indexes[1:] - 1)
         # patches to image
         img = self.patch2img(patches)
-
         return img
diff --git a/_modules/cyto_dl/nn/vits/mae.html b/_modules/cyto_dl/nn/vits/mae.html index 017e885b..7c8b8272 100644 --- a/_modules/cyto_dl/nn/vits/mae.html +++ b/_modules/cyto_dl/nn/vits/mae.html @@ -583,7 +583,12 @@

Source code for cyto_dl.nn.vits.mae

         features = self.projection_norm(self.projection(features))
 
         backward_indexes = torch.cat(
-            [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
+            [
+                torch.zeros(
+                    1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long
+                ),
+                backward_indexes + 1,
+            ],
             dim=0,
         )
         # fill in masked regions
diff --git a/_modules/cyto_dl/nn/vits/utils.html b/_modules/cyto_dl/nn/vits/utils.html
index 01bb1259..e6e2540f 100644
--- a/_modules/cyto_dl/nn/vits/utils.html
+++ b/_modules/cyto_dl/nn/vits/utils.html
@@ -418,9 +418,7 @@ 

Source code for cyto_dl.nn.vits.utils

 
 
 
[docs]def take_indexes(sequences, indexes): - return torch.gather( - sequences, 0, repeat(indexes.to(sequences.device), "t b -> t b c", c=sequences.shape[-1]) - )
+ return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))
diff --git a/cyto_dl.nn.vits.blocks.patchify.html b/cyto_dl.nn.vits.blocks.patchify.html index 56ead891..5db2bfbf 100644 --- a/cyto_dl.nn.vits.blocks.patchify.html +++ b/cyto_dl.nn.vits.blocks.patchify.html @@ -424,9 +424,9 @@
Parameters: