Skip to content

Commit

Permalink
refactor dinov2 tests, check against official implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurent2916 committed Apr 2, 2024
1 parent 4f94dfb commit 1a8ea91
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 118 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ venv/

# tests' model weights
tests/weights/
tests/repos/

# ruff
.ruff_cache
Expand Down
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ Then, download and convert all the necessary weights. Be aware that this will us
python scripts/prepare_test_weights.py
```

Some tests require a local clone of the model's original implementation, because they load the reference model through `torch.hub.load`:

```bash
git clone [email protected]:facebookresearch/dinov2.git tests/repos/dinov2
```

Finally, run the tests:

```bash
Expand Down
10 changes: 0 additions & 10 deletions scripts/prepare_test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,16 +388,6 @@ def download_dinov2():
]
download_files(urls, weights_folder)

# For testing (note: versions with registers are not available yet on HuggingFace)
for repo in ["dinov2-small", "dinov2-base", "dinov2-large"]:
base_folder = os.path.join(test_weights_dir, "facebook", repo)
urls = [
f"https://huggingface.co/facebook/{repo}/raw/main/config.json",
f"https://huggingface.co/facebook/{repo}/raw/main/preprocessor_config.json",
f"https://huggingface.co/facebook/{repo}/resolve/main/pytorch_model.bin",
]
download_files(urls, base_folder)


def download_lcm_base():
base_folder = os.path.join(test_weights_dir, "latent-consistency/lcm-sdxl")
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ def test_weights_path() -> Path:
return Path(from_env) if from_env else PARENT_PATH / "weights"


@fixture(scope="session")
def test_repos_path() -> Path:
    """Directory holding local clones of reference repositories used by tests.

    Defaults to `<tests>/repos`; can be overridden with the
    REFINERS_TEST_REPOS_DIR environment variable.
    """
    override = os.getenv("REFINERS_TEST_REPOS_DIR")
    if override:
        return Path(override)
    return PARENT_PATH / "repos"


@fixture(scope="session")
def test_e2e_path() -> Path:
    # End-to-end test assets live next to this conftest, under `e2e/`
    # (no environment-variable override, unlike the weights/repos fixtures).
    return PARENT_PATH / "e2e"
Expand Down
215 changes: 107 additions & 108 deletions tests/foundationals/dinov2/test_dinov2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from math import isclose
from pathlib import Path
from typing import Any
from warnings import warn

import pytest
import torch
from transformers import AutoModel # type: ignore
from transformers.models.dinov2.modeling_dinov2 import Dinov2Model # type: ignore

from refiners.fluxion.utils import load_from_safetensors, manual_seed, no_grad
from refiners.foundationals.dinov2 import (
from refiners.fluxion.utils import load_from_safetensors, load_tensors, manual_seed, no_grad
from refiners.foundationals.dinov2.dinov2 import (
DINOv2_base,
DINOv2_base_reg,
DINOv2_large,
Expand All @@ -18,130 +16,131 @@
)
from refiners.foundationals.dinov2.vit import ViT

FLAVORS = [
"dinov2_vits14",
"dinov2_vitb14",
"dinov2_vitl14",
"dinov2_vits14_reg4",
"dinov2_vitb14_reg4",
"dinov2_vitl14_reg4",
]
# Maps each torch.hub flavor name (as exposed by facebookresearch/dinov2)
# to the corresponding refiners model constructor. The `_reg` suffix marks
# variants with register tokens.
FLAVORS_MAP = {
    "dinov2_vits14": DINOv2_small,
    "dinov2_vits14_reg": DINOv2_small_reg,
    "dinov2_vitb14": DINOv2_base,
    "dinov2_vitb14_reg": DINOv2_base_reg,
    "dinov2_vitl14": DINOv2_large,
    "dinov2_vitl14_reg": DINOv2_large_reg,
    # TODO: support giant flavors
    # "dinov2_vitg14": DINOv2_giant,
    # "dinov2_vitg14_reg": DINOv2_giant_reg,
}


@pytest.fixture(scope="module", params=[224, 518])
def resolution(request: pytest.FixtureRequest) -> int:
    # Square input resolution (in pixels) each test runs at; both values
    # are multiples of the DINOv2 patch size (14).
    return request.param


@pytest.fixture(scope="module", params=FLAVORS)
@pytest.fixture(scope="module", params=FLAVORS_MAP.keys())
def flavor(request: pytest.FixtureRequest) -> str:
return request.param


# Temporary: see comments in `test_encoder_only`
@pytest.fixture(scope="module")
def seed_expected_norm(flavor: str) -> tuple[int, float]:
match flavor:
case "dinov2_vits14":
return (42, 1977.9213867)
case "dinov2_vitb14":
return (42, 1902.6384277)
case "dinov2_vitl14":
return (42, 1763.9187011)
case "dinov2_vits14_reg4":
return (42, 989.2380981)
case "dinov2_vitb14_reg4":
return (42, 974.4362182)
case "dinov2_vitl14_reg4":
return (42, 924.8797607)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
def dinov2_repo_path(test_repos_path: Path) -> Path:
    """Path to a local clone of facebookresearch/dinov2.

    The clone is expected at `<test_repos_path>/dinov2` (see CONTRIBUTING.md);
    the module's tests are skipped when it is absent.
    """
    repo = test_repos_path / "dinov2"
    if not repo.exists():
        warn(f"could not find DINOv2 GitHub repo at {repo}, skipping")
        pytest.skip(allow_module_level=True)
    return repo


@pytest.fixture(scope="module")
def our_backbone(test_weights_path: Path, flavor: str, test_device: torch.device) -> ViT:
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
def ref_model(
flavor: str,
dinov2_repo_path: Path,
test_weights_path: Path,
test_device: torch.device,
) -> torch.nn.Module:
kwargs: dict[str, Any] = {}
if "reg" not in flavor:
kwargs["interpolate_offset"] = 0.0

model = torch.hub.load( # type: ignore
model=flavor,
repo_or_dir=str(dinov2_repo_path),
source="local",
pretrained=False, # to turn off automatic weights download (see load_state_dict below)
**kwargs,
).to(device=test_device)

flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.pth"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
backbone = DINOv2_small(device=test_device)
case "dinov2_vitb14":
backbone = DINOv2_base(device=test_device)
case "dinov2_vitl14":
backbone = DINOv2_large(device=test_device)
case "dinov2_vits14_reg4":
backbone = DINOv2_small_reg(device=test_device)
case "dinov2_vitb14_reg4":
backbone = DINOv2_base_reg(device=test_device)
case "dinov2_vitl14_reg4":
backbone = DINOv2_large_reg(device=test_device)
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
tensors = load_from_safetensors(weights)
backbone.load_state_dict(tensors)
return backbone
model.load_state_dict(load_tensors(weights, device=test_device))


@pytest.fixture(scope="module")
def dinov2_weights_path(test_weights_path: Path, flavor: str):
# TODO: At the time of writing, those are not yet supported in transformers
# (https://github.com/huggingface/transformers/issues/27379). Alternatively, it is also possible to use
# facebookresearch/dinov2 directly (https://github.com/finegrain-ai/refiners/pull/132).
if flavor.endswith("_reg4"):
warn(f"DINOv2 with registers are not yet supported in Hugging Face, skipping")
pytest.skip(allow_module_level=True)
match flavor:
case "dinov2_vits14":
name = "dinov2-small"
case "dinov2_vitb14":
name = "dinov2-base"
case "dinov2_vitl14":
name = "dinov2-large"
case _:
raise ValueError(f"Unexpected DINOv2 flavor: {flavor}")
r = test_weights_path / "facebook" / name
if not r.is_dir():
warn(f"could not find DINOv2 weights at {r}, skipping")
pytest.skip(allow_module_level=True)
return r
assert isinstance(model, torch.nn.Module)
return model


@pytest.fixture(scope="module")
def ref_backbone(dinov2_weights_path: Path, test_device: torch.device) -> Dinov2Model:
backbone = AutoModel.from_pretrained(dinov2_weights_path) # type: ignore
assert isinstance(backbone, Dinov2Model)
return backbone.to(test_device) # type: ignore


def test_encoder(
ref_backbone: Dinov2Model,
our_backbone: ViT,
def our_model(
test_weights_path: Path,
flavor: str,
test_device: torch.device,
):
manual_seed(42)
) -> ViT:
model = FLAVORS_MAP[flavor](device=test_device)

# Position encoding interpolation [1] at runtime is not supported yet. So stick to the default image resolution
# e.g. using (224, 224) pixels as input would give a runtime error (sequence size mismatch)
# [1]: https://github.com/facebookresearch/dinov2/blob/2302b6b/dinov2/models/vision_transformer.py#L179
assert our_backbone.image_size == 518

x = torch.randn(1, 3, 518, 518).to(test_device)
flavor = flavor.replace("_reg", "_reg4")
weights = test_weights_path / f"{flavor}_pretrain.safetensors"
if not weights.is_file():
warn(f"could not find weights at {weights}, skipping")
pytest.skip(allow_module_level=True)

with no_grad():
ref_features = ref_backbone(x).last_hidden_state
our_features = our_backbone(x)
tensors = load_from_safetensors(weights)
model.load_state_dict(tensors)

assert (our_features - ref_features).abs().max() < 1e-3
return model


# Mainly for DINOv2 + registers coverage (this test can be removed once `test_encoder` supports all flavors)
def test_encoder_only(
our_backbone: ViT,
seed_expected_norm: tuple[int, float],
@no_grad()
def test_dinov2_facebook_weights(
ref_model: torch.nn.Module,
our_model: ViT,
resolution: int,
test_device: torch.device,
):
seed, expected_norm = seed_expected_norm
manual_seed(seed)

x = torch.randn(1, 3, 518, 518).to(test_device)

our_features = our_backbone(x)

assert isclose(our_features.norm().item(), expected_norm, rel_tol=1e-04)
) -> None:
manual_seed(2)
input_data = torch.randn(
(1, 3, resolution, resolution),
device=test_device,
)

ref_output = ref_model(input_data, is_training=True)
ref_cls = ref_output["x_norm_clstoken"]
ref_reg = ref_output["x_norm_regtokens"]
ref_patch = ref_output["x_norm_patchtokens"]

our_output = our_model(input_data)
our_cls = our_output[:, 0]
our_reg = our_output[:, 1 : our_model.num_registers + 1]
our_patch = our_output[:, our_model.num_registers + 1 :]

assert torch.allclose(ref_cls, our_cls, atol=1e-4)
assert torch.allclose(ref_reg, our_reg, atol=1e-4)
assert torch.allclose(ref_patch, our_patch, atol=3e-3)


@no_grad()
def test_dinov2_float16(
    resolution: int,
    test_device: torch.device,
) -> None:
    """Smoke test: DINOv2-small runs end-to-end in float16 with randomly
    initialized weights, and the output keeps the expected shape and dtype."""
    dinov2 = DINOv2_small(device=test_device, dtype=torch.float16)

    manual_seed(2)
    x = torch.randn(1, 3, resolution, resolution, device=test_device, dtype=torch.float16)

    features = dinov2(x)

    # One token per 14x14 patch, plus the class token.
    expected_tokens = (resolution // dinov2.patch_size) ** 2 + 1
    assert features.shape == (1, expected_tokens, dinov2.embedding_dim)
    assert features.dtype == torch.float16

0 comments on commit 1a8ea91

Please sign in to comment.