
[bug] Fix interpolation of positional embeddings #378

Merged (2 commits) on Feb 22, 2024

Conversation

patricklabatut (Contributor) commented:

Use `size` instead of `scale_factor` to specify the output size of `nn.functional.interpolate()`: this avoids any rounding issue leading to a mismatching output size, and it consistently generates the same output size as with the previous kludge (from facebookresearch/dino#8).
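
A minimal sketch of the change, paraphrasing the positional-embedding interpolation rather than quoting the diff (the helper name is illustrative; the 0.1 offset is the kludge value from dino#8):

```python
import math
import torch
import torch.nn.functional as F

def interpolate_pos_embed(patch_pos_embed, w, h, patch_size):
    # patch_pos_embed: (1, N, dim) positional embeddings learned on an
    # M x M patch grid (N = M * M); resize them to the patch grid of a
    # w x h input image.
    N, dim = patch_pos_embed.shape[1], patch_pos_embed.shape[-1]
    w0, h0 = w // patch_size, h // patch_size  # target patch grid
    M = int(math.sqrt(N))
    assert M * M == N
    grid = patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2)

    # Before: scale_factor plus a small 0.1 offset (the dino#8 kludge) to
    # dodge rounding, which could still yield an off-by-one output size:
    #   grid = F.interpolate(grid, mode="bicubic",
    #                        scale_factor=((w0 + 0.1) / M, (h0 + 0.1) / M))
    # After: request the output size explicitly, so no rounding is involved:
    grid = F.interpolate(grid, size=(w0, h0), mode="bicubic")

    return grid.permute(0, 2, 3, 1).reshape(1, w0 * h0, dim)

pos = torch.randn(1, 37 * 37, 384)            # 37 x 37 grid, as in a 518/14 model
out = interpolate_pos_embed(pos, 602, 602, patch_size=14)
print(out.shape)                              # torch.Size([1, 1849, 384]), 43 x 43
```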

Test:

Before (using scale_factor without interpolate offset):

Simulating the computation of the output size: each mismatching line reports `[N_image_pixels] (N_image_patches / sqrt(N_total_model_patches)) x sqrt(N_total_model_patches) != interpolate_output_size`

size: 224
  patch size: 14
  patch size: 16
    [976] (61.0 / 14.0) x 14.0 != 60
    [1840] (115.0 / 14.0) x 14.0 != 114
    [1952] (122.0 / 14.0) x 14.0 != 121
size: 518
  patch size: 14
    [546] (39.0 / 37.0) x 37.0 != 38
    [602] (43.0 / 37.0) x 37.0 != 42
    [1092] (78.0 / 37.0) x 37.0 != 77
    [1204] (86.0 / 37.0) x 37.0 != 85
    [1610] (115.0 / 37.0) x 37.0 != 114
    [1722] (123.0 / 37.0) x 37.0 != 122

(i.e. 3 failing image sizes for 224/16 models, 6 failing image sizes for 518/14; a sketch of this simulation follows below)
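
The PR doesn't include the simulation script itself, so the following is a hedged reconstruction: it mimics PyTorch's floor-based output-size computation for `scale_factor` and prints mismatches in the same format as above (the function name and the pixel range are illustrative):

```python
import math

def find_mismatches(model_size, patch_size, max_pixels=2048):
    # Pretraining patch grid side, e.g. 224 // 16 = 14 or 518 // 14 = 37.
    M = model_size // patch_size
    for pixels in range(patch_size, max_pixels + 1, patch_size):
        n_patches = pixels // patch_size   # expected output grid side
        scale = n_patches / M              # scale_factor, without the 0.1 offset
        out = int(math.floor(scale * M))   # interpolate's rounded output size
        if out != n_patches:
            print(f"    [{pixels}] ({n_patches:.1f} / {M:.1f}) x {M:.1f} != {out}")

for model_size in (224, 518):
    print(f"size: {model_size}")
    for patch_size in (14, 16) if model_size == 224 else (14,):
        print(f"  patch size: {patch_size}")
        find_mismatches(model_size, patch_size)
```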

Forward (with DINOv2 ViT-S/14 w/ registers) on a 1x3xHxW tensor for H = W = one of the offending sizes:

Traceback (most recent call last):
  File "/Users/plabatut/github/facebookresearchdinov2/./test_interp_pos_embed.py", line 76, in <module>
    y = model(x)
  File "/Users/plabatut/opt/micromamba/envs/dinov2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/plabatut/opt/micromamba/envs/dinov2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/plabatut/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 321, in forward
    ret = self.forward_features(*args, **kwargs)
  File "/Users/plabatut/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 254, in forward_features
    x = self.prepare_tokens_with_masks(x, masks)
  File "/Users/plabatut/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 216, in prepare_tokens_with_masks
    x = x + self.interpolate_pos_encoding(x, w, h)
  File "/Users/plabatut/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 204, in interpolate_pos_encoding
    assert int(w0) == patch_pos_embed.shape[-2]
AssertionError
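
For reference, a minimal reproduction of this forward test (`dinov2_vits14_reg` is the published torch.hub entrypoint for ViT-S/14 with registers; 546 is one of the offending sizes listed above):

```python
import torch

# Load DINOv2 ViT-S/14 with registers from torch.hub and forward a square
# input whose side is one of the offending sizes before the fix.
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_reg")
model.eval()

x = torch.randn(1, 3, 546, 546)  # H = W = 546, a failing size for 518/14
with torch.no_grad():
    y = model(x)                 # raised the AssertionError above before the fix
print(y.shape)
```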

After (using size without interpolate offset):

Simulating the computation of the output size: no reported mismatch.

size: 224
  patch size: 14
  patch size: 16
size: 518
  patch size: 14

(i.e. no failing image size)

Also checked that the output size always matches the one obtained with the kludge.

Forward (with DINOv2 ViT-S/14 w/ registers) on a 1x3xHxW tensor for H = W = one of the offending sizes for the before case: no error.

Forward (with DINOv2 ViT-S/14 w/ registers) on a 1x3xHxW tensor for an arbitrary H = W: same result as with the kludge (a sketch of this check follows below).
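
A hedged sketch of that consistency check, reusing the `interpolate_pos_embed` sketch above together with a hypothetical `interpolate_pos_embed_kludge` variant that keeps the old `scale_factor` path:

```python
import math
import torch
import torch.nn.functional as F

def interpolate_pos_embed_kludge(patch_pos_embed, w, h, patch_size):
    # Old behaviour: scale_factor with the 0.1 offset from dino#8.
    N, dim = patch_pos_embed.shape[1], patch_pos_embed.shape[-1]
    w0, h0 = w // patch_size, h // patch_size
    M = int(math.sqrt(N))
    grid = patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, mode="bicubic",
                         scale_factor=((w0 + 0.1) / M, (h0 + 0.1) / M))
    return grid.permute(0, 2, 3, 1).reshape(1, -1, dim)

pos = torch.randn(1, 37 * 37, 384)    # 37 x 37 grid of 384-d embeddings, as in 518/14
for size in (224, 518, 1022):         # sizes where the kludge also succeeds
    a = interpolate_pos_embed(pos, size, size, patch_size=14)
    b = interpolate_pos_embed_kludge(pos, size, size, patch_size=14)
    assert a.shape == b.shape         # output sizes always match
    print(size, (a - b).abs().max())  # values may differ at float precision
```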

@patricklabatut linked an issue on Feb 22, 2024 that may be closed by this pull request.

@MichaelRamamonjisoa (Contributor) left a comment:

LGTM, you can merge it after linting, thanks a lot!

Successfully merging this pull request may close these issues: Positional encoding fails with rectangular input.