diff --git a/mask2former/modeling/transformer_decoder/position_encoding.py b/mask2former/modeling/transformer_decoder/position_encoding.py index f32532e0..336dda4b 100644 --- a/mask2former/modeling/transformer_decoder/position_encoding.py +++ b/mask2former/modeling/transformer_decoder/position_encoding.py @@ -38,7 +38,7 @@ def forward(self, x, mask=None): x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='trunc') / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t