From 082523b2b10e5fc96a4ec158342b8a61509d6a02 Mon Sep 17 00:00:00 2001
From: clemsgrs
Date: Tue, 20 Aug 2024 13:34:05 +0200
Subject: [PATCH] improved config setup & readability of train.py

---
 dinov2/configs/train/vit_tiny_14.yaml |  4 +-
 dinov2/eval/knn.py                    |  2 +-
 dinov2/train/train.py                 | 59 +++++----------------------
 dinov2/utils/config.py                | 22 ++++++++--
 4 files changed, 33 insertions(+), 54 deletions(-)

diff --git a/dinov2/configs/train/vit_tiny_14.yaml b/dinov2/configs/train/vit_tiny_14.yaml
index 05a08dda4..23535120b 100644
--- a/dinov2/configs/train/vit_tiny_14.yaml
+++ b/dinov2/configs/train/vit_tiny_14.yaml
@@ -3,7 +3,7 @@ dino:
 ibot:
   separate_head: true
 train:
-  batch_size_per_gpu: 128
+  batch_size_per_gpu: 32
   dataset_path: PathologyFoundation:root=/root/data
   centering: sinkhorn_knopp
   num_workers: 8
@@ -34,7 +34,7 @@ optim:
 crops:
   local_crops_size: 98
 wandb:
-  enable: true
+  enable: false
   project: 'dinov2'
   username: 'vlfm'
   exp_name: 'profiling'
diff --git a/dinov2/eval/knn.py b/dinov2/eval/knn.py
index a7d2723e9..5f9b18836 100644
--- a/dinov2/eval/knn.py
+++ b/dinov2/eval/knn.py
@@ -243,7 +243,7 @@ def eval_knn(
         persistent_workers=persistent_workers,
         verbose=verbose,
     )
-    num_classes = query_labels.max() + 1
+    num_classes = len(torch.unique(query_labels))
     metric_collection = build_metric(num_classes=num_classes, average_type=accuracy_averaging)

     device = torch.cuda.current_device()
diff --git a/dinov2/train/train.py b/dinov2/train/train.py
index d5c753eb4..8496e1ebf 100644
--- a/dinov2/train/train.py
+++ b/dinov2/train/train.py
@@ -1,8 +1,3 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-#
-# This source code is licensed under the Apache License, Version 2.0
-# found in the LICENSE file in the root directory of this source tree.
-
 import argparse
 import logging
 import math
@@ -11,7 +6,6 @@
 import json
 import wandb
 import tqdm
-import datetime
 from functools import partial
 from typing import Optional
 from pathlib import Path
@@ -27,7 +21,7 @@
 from dinov2.fsdp import FSDPCheckpointer
 from dinov2.logging import MetricLogger, SmoothedValue
 from dinov2.utils.config import setup, write_config
-from dinov2.utils.utils import CosineScheduler, initialize_wandb, load_weights
+from dinov2.utils.utils import CosineScheduler, load_weights
 from dinov2.models import build_model_from_cfg
 from dinov2.eval.knn import eval_knn_with_model
 from dinov2.eval.setup import get_autocast_dtype
@@ -64,7 +58,7 @@ def get_args_parser(add_help: bool = True):
     parser.add_argument(
         "--output-dir",
         "--output_dir",
-        default="",
+        default="output",
         type=str,
         help="Output directory to save logs and checkpoints",
     )
@@ -161,7 +155,6 @@ def do_tune(
     query_dataset,
     test_dataset,
     output_dir,
-    gpu_id,
     verbose: bool = True,
 ):
     # in DINOv2, they have on SSLMetaArch class
@@ -189,8 +182,8 @@ def do_tune(
     student = student.to(torch.device("cuda"))
     teacher = teacher.to(torch.device("cuda"))
-    # student = student.to(torch.device(f"cuda:{gpu_id}"))
-    # teacher = teacher.to(torch.device(f"cuda:{gpu_id}"))
+    # student = student.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
+    # teacher = teacher.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
     if verbose:
         tqdm.tqdm.write(f"Loading epoch {epoch} weights...")
     student_weights = model.student.state_dict()
@@ -217,7 +210,7 @@ def do_tune(
         temperature=cfg.tune.knn.temperature,
         autocast_dtype=autocast_dtype,
         accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
-        gpu_id=gpu_id,
+        gpu_id=distributed.get_global_rank(),
         gather_on_cpu=cfg.tune.knn.gather_on_cpu,
         batch_size=cfg.tune.knn.batch_size,
         num_workers=0,
@@ -237,7 +230,7 @@ def do_tune(
         temperature=cfg.tune.knn.temperature,
         autocast_dtype=autocast_dtype,
         accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY,
-        gpu_id=gpu_id,
+        gpu_id=distributed.get_global_rank(),
         gather_on_cpu=cfg.tune.knn.gather_on_cpu,
         batch_size=cfg.tune.knn.batch_size,
         num_workers=0,
@@ -263,7 +256,7 @@ def do_tune(
     return results


-def do_train(cfg, model, gpu_id, run_distributed, resume=False):
+def do_train(cfg, model, resume=False):
     model.train()
     inputs_dtype = torch.half
     fp16_scaler = model.fp16_scaler  # for mixed precision training
@@ -396,7 +389,7 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
     for data in metric_logger.log_every(
         data_loader,
-        gpu_id,
+        distributed.get_global_rank(),
         log_freq,
         header,
         max_iter,
@@ -493,7 +486,6 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
                     query_dataset,
                     test_dataset,
                     results_save_dir,
-                    gpu_id,
                     verbose=False,
                 )

@@ -502,7 +494,7 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
                 for name, value in metrics_dict.items():
                     update_log_dict(log_dict, f"tune/{model_name}.{name}", value, step="epoch")

-            early_stopper(epoch, tune_results, periodic_checkpointer, run_distributed, iteration)
+            early_stopper(epoch, tune_results, periodic_checkpointer, distributed.is_enabled(), iteration)
             if early_stopper.early_stop and cfg.tune.early_stopping.enable:
                 stop = True
@@ -523,7 +515,7 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):
             do_test(cfg, model, f"training_{iteration}")
             torch.cuda.synchronize()

-        periodic_checkpointer.step(iteration, run_distributed=run_distributed)
+        periodic_checkpointer.step(iteration, run_distributed=distributed.is_enabled())

         iteration = iteration + 1
@@ -535,35 +527,6 @@ def do_train(cfg, model, gpu_id, run_distributed, resume=False):


 def main(args):
     cfg = setup(args)
-
-    run_distributed = torch.cuda.device_count() > 1
-    if run_distributed:
-        gpu_id = int(os.environ["LOCAL_RANK"])
-    else:
-        gpu_id = -1
-
-    if distributed.is_main_process():
-        print(f"torch.cuda.device_count(): {torch.cuda.device_count()}")
-        run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
-        # set up wandb
-        if cfg.wandb.enable:
-            key = os.environ.get("WANDB_API_KEY")
-            wandb_run = initialize_wandb(cfg, key=key)
-            wandb_run.define_metric("epoch", summary="max")
-            run_id = wandb_run.id
-    else:
-        run_id = ""
-
-    if run_distributed:
-        obj = [run_id]
-        torch.distributed.broadcast_object_list(obj, 0, device=torch.device(f"cuda:{gpu_id}"))
-        run_id = obj[0]
-
-    output_dir = Path(cfg.train.output_dir, run_id)
-    if distributed.is_main_process():
-        output_dir.mkdir(exist_ok=True, parents=True)
-    cfg.train.output_dir = str(output_dir)
-
     if distributed.is_main_process():
         write_config(cfg, cfg.train.output_dir)
@@ -580,7 +543,7 @@ def main(args):
         )
         return do_test(cfg, model, f"manual_{iteration}")

-    do_train(cfg, model, gpu_id, run_distributed, resume=not args.no_resume)
+    do_train(cfg, model, resume=not args.no_resume)


 if __name__ == "__main__":
diff --git a/dinov2/utils/config.py b/dinov2/utils/config.py
index 78fa03670..411ea1633 100644
--- a/dinov2/utils/config.py
+++ b/dinov2/utils/config.py
@@ -6,7 +6,9 @@
 import math
 import logging
 import os
+import datetime

+from pathlib import Path
 from omegaconf import OmegaConf

 import dinov2.distributed as distributed
@@ -47,18 +49,32 @@ def get_cfg_from_args(args):
     return cfg


-def default_setup(args):
+def default_setup(args, cfg):
+    run_id = datetime.datetime.now().strftime("%Y-%m-%d_%H_%M")
+    # set up wandb
+    if cfg.wandb.enable:
+        key = os.environ.get("WANDB_API_KEY")
+        wandb_run = utils.initialize_wandb(cfg, key=key)
+        wandb_run.define_metric("epoch", summary="max")
+        run_id = wandb_run.id
+
+    output_dir = Path(cfg.train.output_dir, run_id)
+    if distributed.is_main_process():
+        output_dir.mkdir(exist_ok=True, parents=True)
+    cfg.train.output_dir = str(output_dir)
+
     distributed.enable(overwrite=True)
     seed = getattr(args, "seed", 0)
     rank = distributed.get_global_rank()

     global logger
-    setup_logging(output=args.output_dir, level=logging.INFO)
+    setup_logging(output=cfg.train.output_dir, level=logging.INFO)
     logger = logging.getLogger("dinov2")

     utils.fix_random_seeds(seed + rank)
     logger.info("git:\n  {}\n".format(utils.get_sha()))
     logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
+    return cfg


 def setup(args):
@@ -66,6 +82,6 @@
     """
     Create configs and perform basic setups.
     """
     cfg = get_cfg_from_args(args)
-    default_setup(args)
+    cfg = default_setup(args, cfg)
     apply_scaling_rules_to_cfg(cfg)
     return cfg
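
Not part of the patch: the short sketch below, reconstructed from the hunks above, shows how the entry point in dinov2/train/train.py is expected to look once run-directory and wandb setup live in dinov2/utils/config.py. Only identifiers that appear in the diff are used; any surrounding glue is an assumption, not the actual file contents.

# Sketch only (assumes the post-patch API shown in the hunks above).
import argparse

import dinov2.distributed as distributed
from dinov2.utils.config import setup, write_config


def main(args: argparse.Namespace) -> None:
    # setup() now calls default_setup(args, cfg), which picks a run_id
    # (the wandb run id when cfg.wandb.enable, otherwise a timestamp),
    # points cfg.train.output_dir at <output_dir>/<run_id> (created on the
    # main process), enables distributed mode, and configures logging there.
    cfg = setup(args)

    # Only the main process writes the resolved config to the run directory.
    if distributed.is_main_process():
        write_config(cfg, cfg.train.output_dir)

    # ...build the model and call do_train(cfg, model, resume=...) here;
    # the old gpu_id / run_distributed arguments are gone.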