Skip to content

Commit

Permalink
update with ritviks comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Morris committed Jul 2, 2024
1 parent e17c540 commit cd4c359
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
12 changes: 6 additions & 6 deletions cyto_dl/nn/vits/blocks/patchify.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ def __init__(
Parameters
----------
patch_size: List[int]
Size of each patch
Size of each patch in pix (ZYX order for 3D, YX order for 2D)
emb_dim: int
Dimension of encoder
n_patches: List[int]
Number of patches in each spatial dimension
Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D)
spatial_dims: int
Number of spatial dimensions
context_pixels: List[int]
Expand Down Expand Up @@ -95,14 +95,14 @@ def __init__(
*[
Rearrange(
"(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x",
n_patch_y=n_patches[1],
n_patch_x=n_patches[2],
n_patch_y=n_patches[0],
n_patch_x=n_patches[1],
),
Reduce(
"b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
reduction="repeat",
patch_size_y=patch_size[1],
patch_size_x=patch_size[2],
patch_size_y=patch_size[0],
patch_size_x=patch_size[1],
),
]
)
Expand Down
3 changes: 2 additions & 1 deletion cyto_dl/nn/vits/cross_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def init_weight(self):

def forward(self, features, forward_indexes, backward_indexes):
# HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers
features = features[0]
# features can be n t b c (if intermediate feature weighter used) or t b c if not
features = features[0] if len(features.shape) == 4 else features
T, B, C = features.shape
# we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
features = self.projection_norm(self.projection(features))
Expand Down

0 comments on commit cd4c359

Please sign in to comment.