From f131ad8a6589abac30c5d03cdaca77c5f8d685d3 Mon Sep 17 00:00:00 2001 From: ngbountos Date: Sat, 1 Jul 2023 03:54:45 +0300 Subject: [PATCH] add basic code for MAE --- main.py | 9 +- self_supervised/mae/mae_model.py | 332 +++++++++++++++++ self_supervised/mae/mae_scheduler.py | 15 + self_supervised/mae/mae_utils.py | 340 ++++++++++++++++++ .../{train_ssl.py => train_contrastive.py} | 8 +- training/train_mae.py | 286 +++++++++++++++ 6 files changed, 984 insertions(+), 6 deletions(-) create mode 100644 self_supervised/mae/mae_model.py create mode 100644 self_supervised/mae/mae_scheduler.py create mode 100644 self_supervised/mae/mae_utils.py rename training/{train_ssl.py => train_contrastive.py} (98%) create mode 100644 training/train_mae.py diff --git a/main.py b/main.py index bc187c9..046cb25 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,11 @@ import pyjson5 as json from self_supervised.mocov2 import builder +from self_supervised.mae import mae_model from utilities.utils import prepare_configuration -from training import train_ssl, train_supervised +from training import train_contrastive, train_supervised import argparse import pprint +import os if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -39,7 +41,10 @@ config["moco_t"], config["mlp"], ) + elif config['method'] == "mae": + raise NotImplementedError(f'{config["method"]} is not supported.') + else: raise NotImplementedError(f'{config["method"]} is not supported.') - train_ssl.exec_model(model, config) + train_contrastive.exec_model(model, config) diff --git a/self_supervised/mae/mae_model.py b/self_supervised/mae/mae_model.py new file mode 100644 index 0000000..dd777e2 --- /dev/null +++ b/self_supervised/mae/mae_model.py @@ -0,0 +1,332 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- + +from functools import partial + +import torch +import torch.nn as nn + +from timm.models.vision_transformer import PatchEmbed, Block + +# -------------------------------------------------------- +# 2D sine-cosine position embedding +# References: +# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py +# MoCo v3: https://github.com/facebookresearch/moco-v3 +# -------------------------------------------------------- +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float) + omega /= embed_dim / 2. + omega = 1. / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +# -------------------------------------------------------- +# Interpolate position embeddings for high-resolution +# References: +# DeiT: https://github.com/facebookresearch/deit +# -------------------------------------------------------- +def interpolate_pos_embed(model, checkpoint_model): + if 'pos_embed' in checkpoint_model: + pos_embed_checkpoint = checkpoint_model['pos_embed'] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model['pos_embed'] = new_pos_embed + + +class MaskedAutoencoderViT(nn.Module): + """ Masked Autoencoder with VisionTransformer backbone + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, + embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False): + super().__init__() + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding + + self.blocks = nn.ModuleList([ + Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) + # -------------------------------------------------------------------------- + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True) + + self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) + + self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding + + self.decoder_blocks = nn.ModuleList([ + Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, qk_scale=None, norm_layer=norm_layer) + for i in range(decoder_depth)]) + + self.decoder_norm = norm_layer(decoder_embed_dim) + self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True) # decoder to patch + # -------------------------------------------------------------------------- + + self.norm_pix_loss = norm_pix_loss + + self.initialize_weights() + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) + self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + torch.nn.init.normal_(self.cls_token, std=.02) + torch.nn.init.normal_(self.mask_token, std=.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1]**.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore + + def forward_encoder(self, x, mask_ratio): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + x, mask, ids_restore = self.random_masking(x, mask_ratio) + + # append cls token + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1) + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + # remove cls token + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.e-6)**.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches + return loss + + def forward(self, imgs, mask_ratio=0.75): + latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio) + pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] + loss = self.forward_loss(imgs, pred, mask) + return loss, pred, mask + + +def mae_vit_base_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=768, depth=12, num_heads=12, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_large_patch16_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +def mae_vit_huge_patch14_dec512d8b(**kwargs): + model = MaskedAutoencoderViT( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, + decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16, + mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) + return model + + +# set recommended archs +mae_vit_base_patch16 = mae_vit_base_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_large_patch16 = mae_vit_large_patch16_dec512d8b # decoder: 512 dim, 8 blocks +mae_vit_huge_patch14 = mae_vit_huge_patch14_dec512d8b # decoder: 512 dim, 8 blocks \ No newline at end of file diff --git a/self_supervised/mae/mae_scheduler.py b/self_supervised/mae/mae_scheduler.py new file mode 100644 index 0000000..5953846 --- /dev/null +++ b/self_supervised/mae/mae_scheduler.py @@ -0,0 +1,15 @@ +import math + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate with half-cycle cosine after warmup""" + if epoch < args.warmup_epochs: + lr = args.lr * epoch / args.warmup_epochs + else: + lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ + (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) + for param_group in optimizer.param_groups: + if "lr_scale" in param_group: + param_group["lr"] = lr * param_group["lr_scale"] + else: + param_group["lr"] = lr + return lr \ No newline at end of file diff --git a/self_supervised/mae/mae_utils.py b/self_supervised/mae/mae_utils.py new file mode 100644 index 0000000..ad9a786 --- /dev/null +++ b/self_supervised/mae/mae_utils.py @@ -0,0 +1,340 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# References: +# DeiT: https://github.com/facebookresearch/deit +# BEiT: https://github.com/microsoft/unilm/tree/master/beit +# -------------------------------------------------------- + +import builtins +import datetime +import os +import time +from collections import defaultdict, deque +from pathlib import Path + +import torch +import torch.distributed as dist +from torch._six import inf + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if v is None: + continue + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + log_msg = [ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ] + if torch.cuda.is_available(): + log_msg.append('max mem: {memory:.0f}') + log_msg = self.delimiter.join(log_msg) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + builtin_print = builtins.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + force = force or (get_world_size() > 8) + if is_master or force: + now = datetime.datetime.now().time() + builtin_print('[{}] '.format(now), end='') # print with time stamp + builtin_print(*args, **kwargs) + + builtins.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if args.dist_on_itp: + args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) + args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) + args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) + os.environ['LOCAL_RANK'] = str(args.gpu) + os.environ['RANK'] = str(args.rank) + os.environ['WORLD_SIZE'] = str(args.world_size) + # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] + elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + setup_for_distributed(is_master=True) # hack + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}, gpu {}'.format( + args.rank, args.dist_url, args.gpu), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +class NativeScalerWithGradNormCount: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) + else: + self._scaler.unscale_(optimizer) + norm = get_grad_norm_(parameters) + self._scaler.step(optimizer) + self._scaler.update() + else: + norm = None + return norm + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) + + +def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = [p for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(parameters) == 0: + return torch.tensor(0.) + device = parameters[0].grad.device + if norm_type == inf: + total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) + else: + total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) + return total_norm + + +def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): + output_dir = Path(args.output_dir) + epoch_name = str(epoch) + if loss_scaler is not None: + checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] + for checkpoint_path in checkpoint_paths: + to_save = { + 'model': model_without_ddp.state_dict(), + 'optimizer': optimizer.state_dict(), + 'epoch': epoch, + 'scaler': loss_scaler.state_dict(), + 'args': args, + } + + save_on_master(to_save, checkpoint_path) + else: + client_state = {'epoch': epoch} + model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state) + + +def load_model(args, model_without_ddp, optimizer, loss_scaler): + if args.resume: + if args.resume.startswith('https'): + checkpoint = torch.hub.load_state_dict_from_url( + args.resume, map_location='cpu', check_hash=True) + else: + checkpoint = torch.load(args.resume, map_location='cpu') + model_without_ddp.load_state_dict(checkpoint['model']) + print("Resume checkpoint %s" % args.resume) + if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): + optimizer.load_state_dict(checkpoint['optimizer']) + args.start_epoch = checkpoint['epoch'] + 1 + if 'scaler' in checkpoint: + loss_scaler.load_state_dict(checkpoint['scaler']) + print("With optim & sched!") + + +def all_reduce_mean(x): + world_size = get_world_size() + if world_size > 1: + x_reduce = torch.tensor(x).cuda() + dist.all_reduce(x_reduce) + x_reduce /= world_size + return x_reduce.item() + else: + return x \ No newline at end of file diff --git a/training/train_ssl.py b/training/train_contrastive.py similarity index 98% rename from training/train_ssl.py rename to training/train_contrastive.py index 9d407da..bc56cb7 100644 --- a/training/train_ssl.py +++ b/training/train_contrastive.py @@ -162,15 +162,15 @@ def print_pass(*args): json.dump({"wandb_id": id}, open(args["checkpoint_path"] + "/id.json", "w")) wandb.watch(model) - print("=> creating model '{}'".format(config["architecture"])) + print("=> creating model '{}'".format(args["architecture"])) print(model) model.cuda() optimizer = torch.optim.SGD( model.parameters(), - config["lr"], - momentum=config["momentum"], - weight_decay=config["weight_decay"], + args["lr"], + momentum=args["momentum"], + weight_decay=args["weight_decay"], ) if args["resume_checkpoint"]: diff --git a/training/train_mae.py b/training/train_mae.py new file mode 100644 index 0000000..bba49ed --- /dev/null +++ b/training/train_mae.py @@ -0,0 +1,286 @@ +import builtins +import copy +import json +import math +import os +import random +import shutil +import time +import warnings + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms.functional as F +import wandb +import webdataset as wds +import mae_utils as misc +import mae_scheduler as lr_sched +from torchvision.transforms import ( + Compose, + Grayscale, + Normalize, + RandomCrop, + Resize, + ToTensor, +) +import sys +import dataset.Dataset as Dataset +from self_supervised.mocov2 import builder +from utilities.utils import prepare_configuration, is_distributed, is_global_master, world_info_from_env, save_checkpoint, load_checkpoint, AverageMeter, ProgressMeter, adjust_learning_rate, accuracy + + +def train(train_loader, model, criterion, optimizer, epoch, args,loss_scaler): + print("Training epoch: ", epoch) + metric_logger = misc.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + # switch to train mode + model.train() + accum_iter = args['accum_iter'] + end = time.time() + for i, (img1, _) in enumerate(train_loader): + + # we use a per iteration (instead of per epoch) lr scheduler + if i % accum_iter == 0: + lr_sched.adjust_learning_rate(optimizer, i / len(train_loader) + epoch, args) + img1 = img1.cuda(non_blocking=True) + + with torch.cuda.amp.autocast(): + loss, _, _ = model(img1, mask_ratio=args['mask_ratio']) + + loss_value = loss.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + sys.exit(1) + + loss /= accum_iter + loss_scaler(loss, optimizer, parameters=model.parameters(), + update_grad=(i + 1) % accum_iter == 0) + if (i + 1) % accum_iter == 0: + optimizer.zero_grad() + + torch.cuda.synchronize() + + metric_logger.update(loss=loss_value) + + lr = optimizer.param_groups[0]["lr"] + metric_logger.update(lr=lr) + + loss_value_reduce = misc.all_reduce_mean(loss_value) + if (i + 1) % accum_iter == 0: + """ We use epoch_1000x as the x-axis in tensorboard. + This calibrates different curves when batch size changes. + """ + epoch_1000x = int((i / len(train_loader) + epoch) * 1000) + #log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x) + #log_writer.add_scalar('lr', lr, epoch_1000x) + wandb.log({'train_loss':loss_value_reduce,'lr':lr,'epoch':epoch_1000x}) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def exec_model(model, args): + if args["seed"] is not None: + random.seed(args["seed"]) + torch.manual_seed(args["seed"]) + cudnn.deterministic = True + warnings.warn( + "You have chosen to seed training. " + "This will turn on the CUDNN deterministic setting, " + "which can slow down your training considerably! " + "You may see unexpected behavior when restarting " + "from checkpoints." + ) + + if is_distributed(): + if "SLURM_PROCID" in os.environ: + args["local_rank"], args["rank"], args["world_size"] = world_info_from_env() + args["num_workers"] = int(os.environ["SLURM_CPUS_PER_TASK"]) + os.environ["LOCAL_RANK"] = str(args["local_rank"]) + os.environ["RANK"] = str(args["rank"]) + os.environ["WORLD_SIZE"] = str(args["world_size"]) + dist.init_process_group( + backend="nccl", + world_size=args["world_size"], + rank=args["rank"], + ) + os.environ["WANDB_MODE"] = "offline" + else: + args["local_rank"], _, _ = world_info_from_env() + dist.init_process_group(backend="nccl") + args["world_size"] = dist.get_world_size() + args["rank"] = dist.get_rank() + + torch.cuda.set_device(args["local_rank"]) + + # suppress printing if not master + if not is_global_master(args): + + def print_pass(*args): + pass + + builtins.print = print_pass + else: + raise NotImplementedError("Only DistributedDataParallel is supported.") + + if is_global_master(args): + # Initialize wandb + print("Initializing Wandb") + if args["resume_wandb"]: + id_json = json.load(open(args["checkpoint_path"] + "/id.json")) + args["wandb_id"] = id_json["wandb_id"] + wandb.init( + project=args["wandb_project"], + entity=args["wandb_entity"], + id=args["wandb_id"], + resume=args["resume_wandb"], + ) + else: + id = wandb.sdk.lib.runid.generate_id() + args["wandb_id"] = id + wandb.init( + project=args["wandb_project"], + entity=args["wandb_entity"], + config=args, + id=id, + resume="allow", + ) + json.dump({"wandb_id": id}, open(args["checkpoint_path"] + "/id.json", "w")) + wandb.watch(model) + + print("=> creating model '{}'".format(args["architecture"])) + print(model) + + model.cuda() + optimizer = torch.optim.SGD( + model.parameters(), + args["lr"], + momentum=args["momentum"], + weight_decay=args["weight_decay"], + ) + + if args["resume_checkpoint"]: + load_checkpoint(model, optimizer, args) + + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[args["local_rank"]] + ) + + # define loss function (criterion) + criterion = nn.CrossEntropyLoss().cuda() + + cudnn.benchmark = True + + print("Initializing Dataset") + if args["webdataset"]: + from utilities.augmentations import get_augmentations + + def stack_and_augment(src): + for sample in src: + cc, diff = sample + img1 = torch.cat((cc, diff), 0).permute(1, 2, 0).numpy() + img2 = copy.deepcopy(img1) + tranform = Compose( + [ + ToTensor(), + Normalize( + mean=[0.5472, 0.7416], + std=[0.4142, 0.2995], + ), + ] + ) + yield ( + tranform(augmentations(image=img1)["image"]), + tranform(augmentations(image=img2)["image"]), + ) + + base_transform = Compose( + [ + Resize(size=args["resolution"]), + RandomCrop(size=args["resolution"]), + Grayscale(), + ] + ) + + augmentations = get_augmentations(args) + + global_batch_size = args["batch_size"] * args["world_size"] + num_batches = math.floor(args["wds_size"] / global_batch_size) + num_worker_batches = math.floor(num_batches / args["num_workers"]) + args["num_batches"] = num_worker_batches * args["num_workers"] + + train_dataset = ( + wds.DataPipeline( + wds.ResampledShards(args["data_path"]), + wds.tarfile_to_samples(), + wds.shuffle(1000), + wds.decode("torch"), + wds.to_tuple("cc.png", "diff.png"), + wds.map_tuple(base_transform, base_transform), + stack_and_augment, + wds.batched(args["batch_size"], partial=False), + ) + .with_epoch(num_worker_batches) + .with_length(args["wds_size"]) + ) + + train_loader = wds.WebLoader( + train_dataset, + batch_size=None, + shuffle=False, + num_workers=args["num_workers"], + persistent_workers=True, + pin_memory=True, + ) + + else: + train_dataset = Dataset.Dataset(args) + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) + + train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args["batch_size"], + shuffle=(train_sampler is None), + num_workers=args["num_workers"], + pin_memory=True, + sampler=train_sampler, + drop_last=True, + ) + args["num_batches"] = len(train_loader) + + print("Dataset initialized. Size: ", len(train_dataset)) + + for epoch in range(args["start_epoch"], args["epochs"]): + if not args["webdataset"]: + train_sampler.set_epoch(epoch) + + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + if is_global_master(args): + save_checkpoint( + { + "epoch": epoch + 1, + "arch": args["architecture"], + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + }, + is_best=False, + filename=args["checkpoint_path"] + + "/checkpoint_{:04d}.pth.tar".format(epoch), + ) + if is_global_master(args): + wandb.finish()