From fd2a520065e8a25485de3f5712f675f9a5e73921 Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Tue, 2 Jul 2024 22:02:15 +0000 Subject: [PATCH] =?UTF-8?q?Deploying=20to=20gh-pages=20from=20main=20@=205?= =?UTF-8?q?92177bab4df09337e98d1d0837082a034a4facf=20=F0=9F=9A=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- _modules/cyto_dl/nn/vits/blocks/patchify.html | 79 +++++++++++-------- _modules/cyto_dl/nn/vits/cross_mae.html | 28 +++++-- _modules/cyto_dl/nn/vits/mae.html | 7 +- _modules/cyto_dl/nn/vits/utils.html | 4 +- cyto_dl.nn.vits.blocks.patchify.html | 6 +- searchindex.js | 2 +- 6 files changed, 79 insertions(+), 47 deletions(-) diff --git a/_modules/cyto_dl/nn/vits/blocks/patchify.html b/_modules/cyto_dl/nn/vits/blocks/patchify.html index 23a2a820..88068ae4 100644 --- a/_modules/cyto_dl/nn/vits/blocks/patchify.html +++ b/_modules/cyto_dl/nn/vits/blocks/patchify.html @@ -418,16 +418,15 @@
[docs] def forward(self, features, forward_indexes, backward_indexes):
# HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers
- features = features.squeeze(0)
+ # features can be n t b c (if intermediate feature weighter used) or t b c if not
+ features = features[0] if len(features.shape) == 4 else features
T, B, C = features.shape
# we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
features = self.projection_norm(self.projection(features))
# add cls token
backward_indexes = torch.cat(
- [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
+ [
+ torch.zeros(
+ 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long
+ ),
+ backward_indexes + 1,
+ ],
dim=0,
)
forward_indexes = torch.cat(
- [torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1],
+ [
+ torch.zeros(
+ 1, forward_indexes.shape[1], device=forward_indexes.device, dtype=torch.long
+ ),
+ forward_indexes + 1,
+ ],
dim=0,
)
# fill in masked regions
@@ -553,13 +564,20 @@ Source code for cyto_dl.nn.vits.cross_mae
# add back in visible/encoded tokens that we don't calculate loss on
patches = torch.cat(
- [torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches],
+ [
+ torch.zeros(
+ (T - 1, B, patches.shape[-1]),
+ requires_grad=False,
+ device=patches.device,
+ dtype=patches.dtype,
+ ),
+ patches,
+ ],
dim=0,
)
patches = take_indexes(patches, backward_indexes[1:] - 1)
# patches to image
img = self.patch2img(patches)
-
return img
features = self.projection_norm(self.projection(features))
backward_indexes = torch.cat(
- [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
+ [
+ torch.zeros(
+ 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long
+ ),
+ backward_indexes + 1,
+ ],
dim=0,
)
# fill in masked regions
diff --git a/_modules/cyto_dl/nn/vits/utils.html b/_modules/cyto_dl/nn/vits/utils.html
index 01bb1259..e6e2540f 100644
--- a/_modules/cyto_dl/nn/vits/utils.html
+++ b/_modules/cyto_dl/nn/vits/utils.html
@@ -418,9 +418,7 @@ Source code for cyto_dl.nn.vits.utils
[docs]def take_indexes(sequences, indexes):
- return torch.gather(
- sequences, 0, repeat(indexes.to(sequences.device), "t b -> t b c", c=sequences.shape[-1])
- )
+ return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))
patch_size (List[int]) – Size of each patch
patch_size (List[int]) – Size of each patch in pix (ZYX order for 3D, YX order for 2D)
emb_dim (int) – Dimension of encoder
n_patches (List[int]) – Number of patches in each spatial dimension
n_patches (List[int]) – Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D)
spatial_dims (int) – Number of spatial dimensions
context_pixels (List[int]) – Number of extra pixels around each patch to include in convolutional embedding to encoder dimension.
input_channels (int) – Number of input channels