diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index 13b44ae3c..d75912225 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -176,7 +176,7 @@ def init_weights(self):
             nn.init.normal_(self.register_tokens, std=1e-6)
         named_apply(init_weights_vit_timm, self)

-    def interpolate_pos_encoding(self, x, w, h):
+    def interpolate_pos_encoding(self, x, h, w):
         previous_dtype = x.dtype
         npatch = x.shape[1] - 1
         N = self.pos_embed.shape[1] - 1
@@ -196,28 +196,28 @@ def interpolate_pos_encoding(self, x, w, h):
             # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
             sx = float(w0 + self.interpolate_offset) / M
             sy = float(h0 + self.interpolate_offset) / M
-            kwargs["scale_factor"] = (sx, sy)
+            kwargs["scale_factor"] = (sy, sx)
         else:
             # Simply specify an output size instead of a scale factor
-            kwargs["size"] = (w0, h0)
+            kwargs["size"] = (h0, w0)
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
             **kwargs,
         )
-        assert (w0, h0) == patch_pos_embed.shape[-2:]
+        assert (h0, w0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

     def prepare_tokens_with_masks(self, x, masks=None):
-        B, nc, w, h = x.shape
+        B, nc, h, w = x.shape
         x = self.patch_embed(x)
         if masks is not None:
             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-        x = x + self.interpolate_pos_encoding(x, w, h)
+        x = x + self.interpolate_pos_encoding(x, h, w)

         if self.register_tokens is not None:
             x = torch.cat(
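
The change above reorders width/height so they match PyTorch's NCHW layout: the image tensor unpacks as (B, C, H, W), and nn.functional.interpolate interprets size/scale_factor as (H, W). Below is a minimal standalone sketch (not part of the patch) illustrating the same resizing step on a non-square input; the grid size 37x37, patch size 14, and dim 384 are illustrative assumptions, not values taken from this diff.

    # Hypothetical sketch: resize a square positional-embedding grid to the
    # patch grid of a non-square image, passing size in (H, W) order.
    import math
    import torch
    import torch.nn as nn

    patch_size = 14
    N = 37 * 37                      # assumed pretrained grid of 37 x 37 patches
    dim = 384                        # assumed embedding dimension
    pos_embed = torch.randn(1, N, dim)

    h, w = 518, 700                  # non-square input, so h0 != w0
    h0, w0 = h // patch_size, w // patch_size
    M = int(math.sqrt(N))
    assert N == M * M

    grid = pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2)   # (1, dim, M, M)
    resized = nn.functional.interpolate(grid, size=(h0, w0), mode="bicubic")

    # With size=(h0, w0) the output matches the image's patch grid; passing
    # (w0, h0) instead would produce the transposed grid for h != w.
    assert resized.shape[-2:] == (h0, w0)
    print(resized.permute(0, 2, 3, 1).reshape(1, -1, dim).shape)  # (1, h0 * w0, dim)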