From 14517126935705adafc5e1cad972f7ac8b75ae95 Mon Sep 17 00:00:00 2001
From: Patrick Labatut
Date: Sat, 25 Nov 2023 23:49:43 +0100
Subject: [PATCH 1/2] Fix interpolation of positional embeddings

---
 dinov2/models/vision_transformer.py | 26 +++++++++++++++-----------
 1 file changed, 15 insertions(+), 11 deletions(-)

diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index 4926108b3..7b937afff 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -188,21 +188,25 @@ def interpolate_pos_encoding(self, x, w, h):
         dim = x.shape[-1]
         w0 = w // self.patch_size
         h0 = h // self.patch_size
-        # we add a small number to avoid floating point error in the interpolation
-        # see discussion at https://github.com/facebookresearch/dino/issues/8
-        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
-
-        sqrt_N = math.sqrt(N)
-        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
+        assert N == M * M
+        kwargs = {}
+        if self.interpolate_offset:
+            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+            sx = float(w0 + self.interpolate_offset) / M
+            sy = float(h0 + self.interpolate_offset) / M
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
-            scale_factor=(sx, sy),
+            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
             mode="bicubic",
             antialias=self.interpolate_antialias,
+            **kwargs,
         )
-
-        assert int(w0) == patch_pos_embed.shape[-2]
-        assert int(h0) == patch_pos_embed.shape[-1]
+        assert (w0, h0) == patch_pos_embed.shape[-2:]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

From f7425b9a30bb83a897cab0c6ecc2237b2520d23a Mon Sep 17 00:00:00 2001
From: Patrick Labatut
Date: Thu, 22 Feb 2024 18:12:15 +0100
Subject: [PATCH 2/2] Lint

---
 dinov2/models/vision_transformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dinov2/models/vision_transformer.py b/dinov2/models/vision_transformer.py
index 7b937afff..13b44ae3c 100644
--- a/dinov2/models/vision_transformer.py
+++ b/dinov2/models/vision_transformer.py
@@ -310,7 +310,7 @@ def get_intermediate_layers(
         if norm:
             outputs = [self.norm(out) for out in outputs]
         class_tokens = [out[:, 0] for out in outputs]
-        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
         if reshape:
             B, _, w, h = x.shape
             outputs = [
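
---

Reviewer note (not part of the patch series): below is a minimal, self-contained sketch of the patched interpolation logic, for exercising the two code paths outside the model. The free-standing name interpolate_pos_encoding_sketch and its signature are illustrative only; in the repo this is a method of DinoVisionTransformer that reads self.patch_size, self.interpolate_offset and self.interpolate_antialias (the 0.1 default below mirrors what I believe is the repo default). The body follows the diff in PATCH 1/2: a truthy interpolate_offset keeps the historical scale-factor kludge for backward compatibility, while interpolate_offset=0 takes the new path that hands nn.functional.interpolate an explicit output size, so no rounding is involved.

import math

import torch
import torch.nn as nn


def interpolate_pos_encoding_sketch(pos_embed, patch_size, w, h, interpolate_offset=0.1, antialias=False):
    # pos_embed: (1, 1 + N, dim), class token embedding first, then
    # N = M * M patch embeddings laid out on a square M x M grid.
    class_pos_embed = pos_embed[:, 0]
    patch_pos_embed = pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    N = patch_pos_embed.shape[1]
    M = int(math.sqrt(N))  # number of patches per side in the pretraining grid
    assert N == M * M
    w0, h0 = w // patch_size, h // patch_size  # target grid size

    kwargs = {}
    if interpolate_offset:
        # Legacy path: nudged scale factors; the operator derives the output
        # size as floor(M * scale), which the small offset is meant to steer
        # onto (w0, h0) despite floating point error
        kwargs["scale_factor"] = (
            float(w0 + interpolate_offset) / M,
            float(h0 + interpolate_offset) / M,
        )
    else:
        # Fixed path: request the output size directly
        kwargs["size"] = (w0, h0)

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
        mode="bicubic",
        antialias=antialias,
        **kwargs,
    )
    assert (w0, h0) == patch_pos_embed.shape[-2:]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)


# Example: embeddings pretrained on a 37 x 37 grid (518 x 518 inputs at patch
# size 14), interpolated for a 700 x 700 input -> 50 x 50 = 2500 patch tokens.
pos_embed = torch.randn(1, 1 + 37 * 37, 384)
out = interpolate_pos_encoding_sketch(pos_embed, patch_size=14, w=700, h=700, interpolate_offset=0)
assert out.shape == (1, 1 + 50 * 50, 384)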