From cd4c3599d1d4fa5c8c50d446486723d08d8aeb14 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Tue, 2 Jul 2024 14:54:10 -0700 Subject: [PATCH] update with ritviks comments --- cyto_dl/nn/vits/blocks/patchify.py | 12 ++++++------ cyto_dl/nn/vits/cross_mae.py | 3 ++- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cyto_dl/nn/vits/blocks/patchify.py b/cyto_dl/nn/vits/blocks/patchify.py index 5ec262ef..ee5a2c56 100644 --- a/cyto_dl/nn/vits/blocks/patchify.py +++ b/cyto_dl/nn/vits/blocks/patchify.py @@ -32,11 +32,11 @@ def __init__( Parameters ---------- patch_size: List[int] - Size of each patch + Size of each patch in pix (ZYX order for 3D, YX order for 2D) emb_dim: int Dimension of encoder n_patches: List[int] - Number of patches in each spatial dimension + Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) spatial_dims: int Number of spatial dimensions context_pixels: List[int] @@ -95,14 +95,14 @@ def __init__( *[ Rearrange( "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", - n_patch_y=n_patches[1], - n_patch_x=n_patches[2], + n_patch_y=n_patches[0], + n_patch_x=n_patches[1], ), Reduce( "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", reduction="repeat", - patch_size_y=patch_size[1], - patch_size_x=patch_size[2], + patch_size_y=patch_size[0], + patch_size_x=patch_size[1], ), ] ) diff --git a/cyto_dl/nn/vits/cross_mae.py b/cyto_dl/nn/vits/cross_mae.py index 00fecce7..e8bde3f4 100644 --- a/cyto_dl/nn/vits/cross_mae.py +++ b/cyto_dl/nn/vits/cross_mae.py @@ -90,7 +90,8 @@ def init_weight(self): def forward(self, features, forward_indexes, backward_indexes): # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers - features = features[0] + # features can be n t b c (if intermediate feature weighter used) or t b c if not + features = features[0] if len(features.shape) == 4 else features T, B, C = features.shape # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now features = self.projection_norm(self.projection(features))