Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(DINOv2) Implement positional encoding interpolation #343

Merged
merged 3 commits into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ venv/

# tests' model weights
tests/weights/
tests/repos/

# ruff
.ruff_cache
Expand Down
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ Then, download and convert all the necessary weights. Be aware that this will us
python scripts/prepare_test_weights.py
```

Some tests require cloning the original implementation of the model as they use `torch.hub.load`:

```bash
git clone git@github.com:facebookresearch/dinov2.git tests/repos/dinov2
```

Finally, run the tests:

```bash
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ exclude_also = [

[tool.typos.default]
extend-words = { adaptee = "adaptee" }
extend-ignore-identifiers-re = ["NDArray*", "interm"]
extend-ignore-identifiers-re = ["NDArray*", "interm", "af000ded"]

[tool.pytest.ini_options]
filterwarnings = [
Expand Down
2 changes: 1 addition & 1 deletion scripts/conversion/convert_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def convert_dinov2_facebook(weights: dict[str, torch.Tensor]) -> None:

rename_keys: list[tuple[str, str]] = [
("cls_token", "Concatenate.ClassToken.Parameter.weight"),
("pos_embed", "PositionalEncoder.Parameter.weight"),
("pos_embed", "PositionalEncoder.PositionalEmbedding.Parameter.weight"),
("patch_embed.proj.weight", "Concatenate.PatchEncoder.Conv2d.weight"),
("patch_embed.proj.bias", "Concatenate.PatchEncoder.Conv2d.bias"),
("norm.weight", "LayerNorm.weight"),
Expand Down
22 changes: 6 additions & 16 deletions scripts/prepare_test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,16 +388,6 @@ def download_dinov2():
]
download_files(urls, weights_folder)

# For testing (note: versions with registers are not available yet on HuggingFace)
for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
base_folder = os.path.join(test_weights_dir, "facebook", repo)
urls = [
f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
]
download_files(urls, base_folder)


def download_lcm_base():
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
Expand Down Expand Up @@ -688,37 +678,37 @@ def convert_dinov2():
"convert_dinov2.py",
"tests/weights/dinov2_vits14_pretrain.pth",
"tests/weights/dinov2_vits14_pretrain.safetensors",
expected_hash="b7f9b294",
expected_hash="af000ded",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitb14_pretrain.pth",
"tests/weights/dinov2_vitb14_pretrain.safetensors",
expected_hash="d72c767b",
expected_hash="d6294087",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitl14_pretrain.pth",
"tests/weights/dinov2_vitl14_pretrain.safetensors",
expected_hash="71eb98d1",
expected_hash="ddd4819f",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vits14_reg4_pretrain.pth",
"tests/weights/dinov2_vits14_reg4_pretrain.safetensors",
expected_hash="89118b46",
expected_hash="080247c7",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitb14_reg4_pretrain.pth",
"tests/weights/dinov2_vitb14_reg4_pretrain.safetensors",
expected_hash="b0296f77",
expected_hash="5cd4d408",
)
run_conversion_script(
"convert_dinov2.py",
"tests/weights/dinov2_vitl14_reg4_pretrain.pth",
"tests/weights/dinov2_vitl14_reg4_pretrain.safetensors",
expected_hash="b3d877dc",
expected_hash="b1221702",
)


Expand Down
15 changes: 13 additions & 2 deletions src/refiners/fluxion/layers/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,26 @@ class Interpolate(Module):
This layer wraps [`torch.nn.functional.interpolate`][torch.nn.functional.interpolate].
"""

def __init__(self) -> None:
def __init__(
self,
mode: str = "nearest",
antialias: bool = False,
) -> None:
super().__init__()
self.mode = mode
self.antialias = antialias

def forward(
self,
x: Tensor,
shape: Size,
) -> Tensor:
return interpolate(x, shape)
return interpolate(
x=x,
size=shape,
mode=self.mode,
antialias=self.antialias,
)


class Downsample(Chain):
Expand Down
18 changes: 12 additions & 6 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,18 @@ def pad(x: Tensor, pad: Iterable[int], value: float = 0.0, mode: str = "constant
return _pad(input=x, pad=pad, value=value, mode=mode) # type: ignore


def interpolate(x: Tensor, factor: float | torch.Size, mode: str = "nearest") -> Tensor:
return (
_interpolate(x, scale_factor=factor, mode=mode)
if isinstance(factor, float | int)
else _interpolate(x, size=factor, mode=mode)
) # type: ignore
def interpolate(
    x: Tensor,
    size: torch.Size,
    mode: str = "nearest",
    antialias: bool = False,
) -> Tensor:
    """Resize a tensor to an explicit output size.

    Thin keyword-only forwarding wrapper around `_interpolate`
    (presumably `torch.nn.functional.interpolate` imported under an alias
    at the top of this module -- confirm against the import block).

    Args:
        x: The tensor to resize.
        size: The target size, forwarded as the `size` argument.
        mode: The interpolation algorithm, forwarded unchanged.
        antialias: Whether to request antialiasing, forwarded unchanged.

    Returns:
        The resized tensor produced by `_interpolate`.
    """
    resized: Tensor = _interpolate(input=x, size=size, mode=mode, antialias=antialias)  # type: ignore
    return resized


# Adapted from https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py
Expand Down
7 changes: 7 additions & 0 deletions src/refiners/foundationals/dinov2/dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class DINOv2_small_reg(ViT):
num_layers (int): 12
num_heads (int): 6
num_registers (int): 4
interpolate_antialias (bool): True
"""

def __init__(
Expand All @@ -166,6 +167,7 @@ def __init__(
num_layers=12,
num_heads=6,
num_registers=4,
interpolate_antialias=True,
device=device,
dtype=dtype,
)
Expand All @@ -185,6 +187,7 @@ class DINOv2_base_reg(ViT):
num_layers (int): 12
num_heads (int): 12
num_registers (int): 4
interpolate_antialias (bool): True
"""

def __init__(
Expand All @@ -205,6 +208,7 @@ def __init__(
num_layers=12,
num_heads=12,
num_registers=4,
interpolate_antialias=True,
device=device,
dtype=dtype,
)
Expand All @@ -224,6 +228,7 @@ class DINOv2_large_reg(ViT):
num_layers (int): 24
num_heads (int): 16
num_registers (int): 4
interpolate_antialias (bool): True
"""

def __init__(
Expand All @@ -244,6 +249,7 @@ def __init__(
num_layers=24,
num_heads=16,
num_registers=4,
interpolate_antialias=True,
device=device,
dtype=dtype,
)
Expand All @@ -263,6 +269,7 @@ def __init__(
# num_layers=40,
# num_heads=24,
# num_registers=4,
# interpolate_antialias=True,
# device=device,
# dtype=dtype,
# )
101 changes: 92 additions & 9 deletions src/refiners/foundationals/dinov2/vit.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from math import sqrt
from typing import cast

import torch
from torch import Tensor

import refiners.fluxion.layers as fl
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.activations import Activation
from refiners.fluxion.utils import interpolate


class ClassToken(fl.Chain):
Expand All @@ -27,18 +30,20 @@ def __init__(
)


class PositionalEncoder(fl.Residual):
"""Encode the position of each patch in the input."""
class PositionalEmbedding(fl.Chain):
"""Learnable positional embedding."""

def __init__(
self,
sequence_length: int,
embedding_dim: int,
patch_size: int,
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
self.num_patches = sequence_length
self.sequence_length = sequence_length
self.embedding_dim = embedding_dim
self.patch_size = patch_size

super().__init__(
fl.Parameter(
Expand All @@ -49,6 +54,55 @@ def __init__(
)


class InterpolateEmbedding(fl.Module):
    """Resample the positional embeddings so they match the patch grid of the input.

    The first position of the embedding (the class-token slot) is passed
    through untouched; the remaining positions are viewed as a square grid,
    resized with `interpolate`, and flattened back into a sequence.
    """

    def __init__(
        self,
        mode: str,
        antialias: bool,
        patch_size: int,
    ) -> None:
        """Store the interpolation settings.

        Args:
            mode: Interpolation mode forwarded to `interpolate`.
            antialias: Antialiasing flag forwarded to `interpolate`.
            patch_size: Divisor applied to the input's last two dimensions
                to obtain the target grid size.
        """
        super().__init__()
        self.mode = mode
        self.antialias = antialias
        self.patch_size = patch_size

    def forward(
        self,
        x: torch.Tensor,
        input: torch.Tensor,
    ) -> torch.Tensor:
        """Return the embeddings resized to the input's patch grid.

        Args:
            x: Positional embeddings; position 0 along dim 1 is kept as-is and
                the remaining positions must form a square grid.
            input: Tensor whose dims 2 and 3 define the target grid
                (presumably the (B, C, H, W) image batch -- confirm against caller).

        Returns:
            The first embedding concatenated with the resized grid embeddings
            along dim 1.
        """
        class_part = x[:, :1, :]  # -> (1, 1, D)
        grid_part = x[:, 1:, :]  # -> (1, N, D)

        num_positions = grid_part.shape[1]
        dim = grid_part.shape[2]
        side = int(sqrt(num_positions))
        # Dims 2 and 3 of the input; for a (B, C, H, W) batch these are height and width.
        height = input.shape[2]
        width = input.shape[3]
        assert side * side == num_positions, "The sequence length must be a square number."

        target_grid = torch.Size((height // self.patch_size, width // self.patch_size))

        # Sequence -> channels-first square grid: (1, N, D) -> (1, side, side, D) -> (1, D, side, side)
        grid = grid_part.reshape(1, side, side, dim).permute(0, 3, 1, 2)

        # Interpolate in float32, then cast back to the dtype of the incoming embeddings.
        grid = interpolate(
            x=grid.to(dtype=torch.float32),
            mode=self.mode,
            antialias=self.antialias,
            size=target_grid,
        )
        grid = grid.to(dtype=class_part.dtype)  # -> (1, D, h, w)

        # Channels-last, then flatten the grid back into a sequence: (1, h, w, D) -> (1, h*w, D)
        grid = grid.permute(0, 2, 3, 1).reshape(1, -1, dim)

        return torch.cat((class_part, grid), dim=1)  # -> (1, h*w+1, D)


class LayerScale(fl.WeightedModule):
"""Scale the input tensor by a learnable parameter."""

Expand Down Expand Up @@ -125,6 +179,7 @@ def __init__(
self.patch_size = patch_size

super().__init__(
fl.SetContext(context="dinov2_vit", key="input"), # save the original input
fl.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
Expand Down Expand Up @@ -201,6 +256,10 @@ class Transformer(fl.Chain):
"""Alias for a Chain of TransformerLayer."""


class PositionalEncoder(fl.Residual):
    """Alias for a Residual.

    The subclass adds no behavior of its own. ViT instantiates it around a
    `PositionalEmbedding` followed by a chain ending in `InterpolateEmbedding`.
    NOTE(review): the class name also appears in converted weight keys
    ("PositionalEncoder.PositionalEmbedding.Parameter.weight"), so renaming it
    would presumably require updating the conversion script -- confirm.
    """


class Registers(fl.Concatenate):
"""Insert register tokens between CLS token and patches."""

Expand Down Expand Up @@ -243,6 +302,8 @@ def __init__(
norm_eps: float = 1e-6,
mlp_ratio: int = 4,
num_registers: int = 0,
interpolate_antialias: bool = False,
interpolate_mode: str = "bicubic",
device: torch.device | str | None = None,
dtype: torch.dtype | None = None,
) -> None:
Expand All @@ -257,6 +318,8 @@ def __init__(
norm_eps: The epsilon value for normalization.
mlp_ratio: The ratio for the multi-layer perceptron (MLP).
num_registers: The number of registers.
interpolate_antialias: Whether to use antialiasing for interpolation.
interpolate_mode: The interpolation mode.
device: The PyTorch device to use.
dtype: The PyTorch data type to use.
"""
Expand Down Expand Up @@ -286,19 +349,32 @@ def __init__(
),
dim=1,
),
# TODO: support https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
PositionalEncoder(
sequence_length=num_patches**2 + 1,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
PositionalEmbedding(
sequence_length=num_patches**2 + 1,
embedding_dim=embedding_dim,
patch_size=patch_size,
device=device,
dtype=dtype,
),
fl.Chain(
fl.Parallel(
fl.Identity(),
fl.UseContext(context="dinov2_vit", key="input"),
),
InterpolateEmbedding(
mode=interpolate_mode,
antialias=interpolate_antialias,
patch_size=patch_size,
),
),
),
Transformer(
TransformerLayer(
embedding_dim=embedding_dim,
num_heads=num_heads,
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
norm_eps=norm_eps,
device=device,
dtype=dtype,
)
Expand All @@ -320,3 +396,10 @@ def __init__(
dtype=dtype,
)
self.insert_before_type(Transformer, registers)

def init_context(self) -> Contexts:
    """Declare the context used to carry the original input through the model.

    Returns:
        A mapping with a single "dinov2_vit" context whose "input" slot starts
        as None; the slot is written with `fl.SetContext` and read back with
        `fl.UseContext` elsewhere in this module.
    """
    vit_context = {"input": None}
    return {"dinov2_vit": vit_context}
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def set_inpainting_conditions(

mask_tensor = torch.tensor(data=np.array(object=mask).astype(dtype=np.float32) / 255.0).to(device=self.device)
mask_tensor = (mask_tensor > 0.5).unsqueeze(dim=0).unsqueeze(dim=0).to(dtype=self.dtype)
self.mask_latents = interpolate(x=mask_tensor, factor=torch.Size(latents_size))
self.mask_latents = interpolate(x=mask_tensor, size=torch.Size(latents_size))

init_image_tensor = image_to_tensor(image=target_image, device=self.device, dtype=self.dtype) * 2 - 1
masked_init_image = init_image_tensor * (1 - mask_tensor)
Expand Down
Loading