fixed bug in knn evaluation: probabilities didn't sum to 1
clemsgrs committed Mar 15, 2024
1 parent bf961ac commit 79e28ab
Showing 4 changed files with 99 additions and 95 deletions.
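The substantive change is in `DictKeysModule.forward` in `dinov2/eval/knn.py`: the per-class k-NN vote scores are now divided by their per-sample sum before being returned as `preds`, so each row forms a valid probability distribution. A minimal sketch of the idea, using illustrative tensor names rather than the actual variables from the diff:

```python
import torch

# Hypothetical vote scores: one row per query sample, one column per class.
# In the k-NN evaluation these are similarity-weighted votes, so rows do not
# sum to 1 on their own.
scores = torch.tensor([[2.0, 1.0, 1.0],
                       [0.5, 0.5, 3.0]])

# Normalize each row by its sum, mirroring
# `features_dict / features_dict.sum(dim=-1).unsqueeze(-1)` in the diff below.
preds = scores / scores.sum(dim=-1, keepdim=True)

assert torch.allclose(preds.sum(dim=-1), torch.ones(preds.shape[0]))
print(preds)  # rows now sum to 1 and can be treated as class probabilities
```

Probability-style metrics such as AUC expect scores that behave like probabilities, which is presumably why the unnormalized votes led to the reported bug.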
179 changes: 93 additions & 86 deletions dinov2/eval/knn.py
@@ -3,95 +3,52 @@
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
import datetime
import argparse
from functools import partial
import json
import logging
import sys
from pathlib import Path
from typing import List, Optional
from typing import Optional

import torch
from torch.nn.functional import one_hot, softmax

import dinov2.distributed as distributed
from dinov2.data import SamplerType, make_data_loader
from dinov2.data import SamplerType, make_data_loader, make_dataset
from dinov2.eval.metrics import AccuracyAveraging, build_metric
from dinov2.eval.setup import get_args_parser as get_setup_args_parser
from dinov2.utils.utils import initialize_wandb
from dinov2.utils.config import setup, write_config
from dinov2.eval.setup import setup_and_build_model
from dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features
from dinov2.data.transforms import make_classification_eval_transform


logger = logging.getLogger("dinov2")


def get_args_parser(
description: Optional[str] = None,
parents: Optional[List[argparse.ArgumentParser]] = None,
add_help: bool = True,
):
parents = parents or []
setup_args_parser = get_setup_args_parser(parents=parents, add_help=False)
parents = [setup_args_parser]
parser = argparse.ArgumentParser(
description=description,
parents=parents,
add_help=add_help,
)
def get_args_parser(add_help: bool = True):
parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help)
parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
parser.add_argument(
"--query-dataset",
dest="query_dataset_str",
type=str,
help="Query dataset",
"opts",
help="""
Modify config options at the end of the command. For Yacs configs, use
space-separated "PATH.KEY VALUE" pairs.
For python-based LazyConfig, use "path.key=value".
""".strip(),
default=None,
nargs=argparse.REMAINDER,
)
parser.add_argument(
"--test-dataset",
dest="test_dataset_str",
"--output-dir",
"--output_dir",
default="./output",
type=str,
help="Test dataset",
)
parser.add_argument(
"--nb_knn",
nargs="+",
type=int,
help="Number of NN to use. 20 is usually working the best.",
)
parser.add_argument(
"--temperature",
type=float,
help="Temperature used in the voting coefficient",
)
parser.add_argument(
"--gather-on-cpu",
action="store_true",
help="Whether to gather the query features on cpu, slower"
"but useful to avoid OOM for large datasets (e.g. ImageNet22k).",
)
parser.add_argument(
"--batch-size",
type=int,
help="Batch size.",
)
parser.add_argument(
"--n-per-class-list",
nargs="+",
type=int,
help="Number to take per class",
)
parser.add_argument(
"--n-tries",
type=int,
help="Number of tries",
)
parser.set_defaults(
query_dataset_str="ImageNet:split=QUERY",
test_dataset_str="ImageNet:split=TEST",
nb_knn=[10, 20, 100, 200],
temperature=0.07,
batch_size=256,
n_per_class_list=[-1],
n_tries=1,
help="Output directory to save logs and checkpoints",
)

return parser


@@ -191,7 +148,8 @@ def __init__(self, keys):
def forward(self, features_dict, targets):
for k in self.keys:
features_dict = features_dict[k]
return {"preds": features_dict, "target": targets}
preds = features_dict / features_dict.sum(dim=-1).unsqueeze(-1)
return {"preds": preds, "target": targets}


def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, query_features, query_labels):
@@ -271,8 +229,9 @@ def eval_knn(
header="Query",
verbose=verbose,
)
# given model went through ModelWithNormalize, query_features are already normalized
if verbose:
logger.info(f"Query features created, shape {query_features.shape}.")
logger.info(f"Query features created, shape {tuple(query_features.shape)}.")

test_dataloader = make_data_loader(
dataset=test_dataset,
@@ -376,9 +335,9 @@ def eval_knn_with_model(
results_dict[f"{k} Accuracy"] = acc
results_dict[f"{k} AUC"] = auc
if model_name and verbose:
logger.info(f"{model_name.title()} | {k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.2f}")
logger.info(f"{model_name.title()} | {k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.5f}")
elif verbose:
logger.info(f"{k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.2f}")
logger.info(f"{k}-NN classifier result: Accuracy: {acc:.2f} | AUC: {auc:.5f}")

metrics_file_path = Path(output_dir, "results_eval_knn.json")
with open(metrics_file_path, "a") as f:
@@ -394,29 +353,77 @@ def eval_knn_with_model(


def main(args):
model, autocast_dtype = setup_and_build_model(args)
cfg = setup(args)

run_distributed = torch.cuda.device_count() > 1
if run_distributed:
gpu_id = int(os.environ["LOCAL_RANK"])
else:
gpu_id = -1

if distributed.is_main_process():
print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
# set up wandb
if cfg.wandb.enable:
key = os.environ.get("WANDB_API_KEY")
wandb_run = initialize_wandb(cfg, key=key)
run_id = wandb_run.id
else:
run_id = ""

if run_distributed:
obj = [run_id]
torch.distributed.broadcast_object_list(obj, 0, device=torch.device(f"cuda:{gpu_id}"))
run_id = obj[0]

output_dir = Path(cfg.train.output_dir, run_id)
if distributed.is_main_process():
output_dir.mkdir(exist_ok=True, parents=True)
cfg.train.output_dir = str(output_dir)

if distributed.is_main_process():
write_config(cfg, cfg.train.output_dir)

model, autocast_dtype = setup_and_build_model(cfg)

transform = make_classification_eval_transform()
query_dataset_str = cfg.data.query_dataset
test_dataset_str = cfg.data.test_dataset
query_dataset = make_dataset(
dataset_str=query_dataset_str,
transform=transform,
)
test_dataset = make_dataset(
dataset_str=test_dataset_str,
transform=transform,
)

eval_knn_with_model(
model=model,
output_dir=args.output_dir,
query_dataset_str=args.query_dataset_str,
test_dataset_str=args.test_dataset_str,
nb_knn=args.nb_knn,
temperature=args.temperature,
output_dir=cfg.train.output_dir,
query_dataset=query_dataset,
test_dataset=test_dataset,
nb_knn=cfg.knn.nb_knn,
temperature=cfg.knn.temperature,
autocast_dtype=autocast_dtype,
accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
gpu_id=-1,
transform=None,
gather_on_cpu=args.gather_on_cpu,
batch_size=args.batch_size,
num_workers=5,
n_per_class_list=args.n_per_class_list,
n_tries=args.n_tries,
gather_on_cpu=cfg.speed.gather_on_cpu,
batch_size=cfg.data.batch_size,
num_workers=cfg.speed.num_workers,
n_per_class_list=cfg.knn.n_per_class_list,
n_tries=cfg.knn.n_tries,
verbose=True,
)

return 0


if __name__ == "__main__":
description = "DINOv2 k-NN evaluation"
args_parser = get_args_parser(description=description)
args = args_parser.parse_args()
sys.exit(main(args))
import warnings

warnings.filterwarnings("ignore", category=UserWarning, message="TypedStorage is deprecated")

args = get_args_parser(add_help=True).parse_args()
main(args)
10 changes: 4 additions & 6 deletions dinov2/eval/setup.py
@@ -10,7 +10,6 @@
import torch.backends.cudnn as cudnn

from dinov2.models import build_model_from_cfg
from dinov2.utils.config import setup
import dinov2.utils.utils as dinov2_utils


@@ -59,17 +58,16 @@ def get_autocast_dtype(config):
return torch.float


def build_model_for_eval(config, pretrained_weights):
def build_model_for_eval(config):
model, _ = build_model_from_cfg(config, only_teacher=True)
dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher")
dinov2_utils.load_pretrained_weights(model, config.student.pretrained_weights, "teacher")
model.eval()
model.cuda()
return model


def setup_and_build_model(args) -> Tuple[Any, torch.dtype]:
def setup_and_build_model(config) -> Tuple[Any, torch.dtype]:
cudnn.benchmark = True
config = setup(args)
model = build_model_for_eval(config, args.pretrained_weights)
model = build_model_for_eval(config)
autocast_dtype = get_autocast_dtype(config)
return model, autocast_dtype
3 changes: 1 addition & 2 deletions dinov2/eval/utils.py
@@ -69,6 +69,7 @@ def evaluate(
header = "Test"

for samples, targets, *_ in metric_logger.log_every(data_loader, 10, device, header):
# given model went through ModelWithNormalize, outputs are already normalized
outputs = model(samples.to(device))
targets = targets.to(device)
one_hot_targets = one_hot(targets, num_classes=num_classes)
@@ -139,8 +140,6 @@ def extract_features_with_dataloader(
labels_shape = list(labels_rank.shape)
labels_shape[0] = sample_count
all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device)
if verbose:
logger.info(f"Storing features into tensor of shape {features.shape}")

# share indexes, features and labels between processes
index_all = all_gather_and_flatten(index).to(gather_device)
2 changes: 1 addition & 1 deletion dinov2/inference/extract_features.py
@@ -139,7 +139,7 @@ def main(args):
filenames.append(fname)
feature_paths.append(feature_path)
if cfg.wandb.enable and not run_distributed:
wandb.log({"processed": i + 1})
wandb.log({"processed": i + imgs.shape[0]})

features_df = pd.DataFrame.from_dict(
{
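The `dinov2/inference/extract_features.py` change makes the wandb progress counter advance by the number of images in each batch rather than by one per loop iteration. A rough, self-contained sketch of the intent, with made-up batch data standing in for the actual dataloader and `print` standing in for `wandb.log`:

```python
import torch

# Fake batches: two of size 4 and one of size 2, in place of the real dataloader.
batches = [torch.zeros(4, 3, 224, 224), torch.zeros(4, 3, 224, 224), torch.zeros(2, 3, 224, 224)]

processed = 0
for imgs in batches:
    # ... feature extraction and saving would happen here ...
    processed += imgs.shape[0]           # advance by batch size, not by 1
    print({"processed": processed})      # stands in for wandb.log(...)
```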
