small tweaks to try to retrieve each step of inference
GeorgePearse committed Sep 7, 2024
1 parent 84835fb commit 31c7379
Showing 1 changed file with 69 additions and 1 deletion.
70 changes: 69 additions & 1 deletion mmdet/models/backbones/swin.py
@@ -485,7 +485,9 @@ class SwinTransformer(BaseModule):
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
Default: 4.
depths (tuple[int]): Depths of each Swin Transformer stage.
            Default: (2, 2, 6, 2).
            This means that the model has 4 "stages". <-- George comment: the number of stages is retrieved by running len(depths).
num_heads (tuple[int]): Parallel attention heads of each Swin
Transformer stage. Default: (3, 6, 12, 24).
strides (tuple[int]): The patch merging or patch embedding stride of
@@ -571,6 +573,10 @@ def __init__(self,

super(SwinTransformer, self).__init__(init_cfg=init_cfg)

# George comment
# default value of depths --> (2, 2, 6, 2)
# therefore num_layers = 4, which then ends up being
# the number of stages
num_layers = len(depths)
self.out_indices = out_indices
self.use_abs_pos_embed = use_abs_pos_embed
@@ -603,6 +609,7 @@ def __init__(self,

self.stages = ModuleList()
in_channels = embed_dims

for i in range(num_layers):
if i < num_layers - 1:
downsample = PatchMerging(
@@ -614,6 +621,11 @@
else:
downsample = None

# George comment:
# one stage for every layer
# very annoying terminology switch
# don't see why it wouldn't just be num_stages
# instead of num_layers
stage = SwinBlockSequence(
embed_dims=in_channels,
num_heads=num_heads[i],
@@ -817,3 +829,59 @@ def correct_unfold_norm_order(x):
new_ckpt['backbone.' + new_k] = new_v

return new_ckpt


@MODELS.register_module()
class SwinTransformerFirst3Stages(SwinTransformer):

def forward(self, x):
x, hw_shape = self.patch_embed(x)

if self.use_abs_pos_embed:
x = x + self.absolute_pos_embed
x = self.drop_after_pos(x)

outs = []

        # Only run the frozen stages: stop before executing any stage at or
        # beyond self.frozen_stages, since its output would be discarded anyway.
        for i, stage in enumerate(self.stages):
            if i >= self.frozen_stages:
                break

            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)

if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
out = norm_layer(out)
out = out.view(-1, *out_hw_shape,
self.num_features[i]).permute(0, 3, 1,
2).contiguous()
outs.append(out)

return outs
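
A minimal sketch of how the backbone registered above could be selected by name in a detector config; the argument values are the usual Swin-T defaults and are assumptions for illustration, not taken from this change:

backbone = dict(
    type='SwinTransformerFirst3Stages',
    embed_dims=96,
    depths=(2, 2, 6, 2),
    num_heads=(3, 6, 12, 24),
    out_indices=(0, 1, 2),
    frozen_stages=3)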



@MODELS.register_module()
class SwinTransformerLastStage(SwinTransformer):

    def forward(self, x, hw_shape):
        # x and hw_shape are assumed to be the running feature map and
        # spatial shape produced by the earlier stages (e.g. by
        # SwinTransformerFirst3Stages), not a raw image.
        stage = self.stages[-1]
        i = len(self.stages) - 1

        x, hw_shape, out, out_hw_shape = stage(x, hw_shape)

        norm_layer = getattr(self, f'norm{i}')
        out = norm_layer(out)
        out = out.view(-1, *out_hw_shape,
                       self.num_features[i]).permute(0, 3, 1,
                                                     2).contiguous()

        # Return as a list to keep the output format consistent;
        # this would then be fed through to the rpn and roi_heads.
        return [out]
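
A rough sketch of chaining the two backbones to retrieve each step of inference. As written, SwinTransformerFirst3Stages returns only outs, so handing over to SwinTransformerLastStage also requires the running x and hw_shape; exposing those from the first backbone is assumed here rather than provided by this change:

import torch

first = SwinTransformerFirst3Stages(out_indices=(0, 1, 2), frozen_stages=3)
last = SwinTransformerLastStage(out_indices=(0, 1, 2, 3))

img = torch.randn(1, 3, 224, 224)
early_feats = first(img)          # multi-scale features from stages 0-2
# x, hw_shape = ...               # would also need to be returned by first
# final_feats = last(x, hw_shape) # list with the last-stage feature map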
