diff --git a/modules/inpaint/lama/__init__.py b/modules/inpaint/lama/__init__.py
index b521546..98149b4 100644
--- a/modules/inpaint/lama/__init__.py
+++ b/modules/inpaint/lama/__init__.py
@@ -1,6 +1,5 @@
 # https://github.com/advimman/lama
 import os
-import yaml
 
 import torch
 import torch.nn.functional as F
@@ -8,7 +7,6 @@
 import comfy.model_management as model_management
 
 from ...model_utils import download_model
-from ...utils import ensure_package
 
 lama = None
@@ -35,11 +33,6 @@ def pad_tensor_to_modulo(img, mod):
 def load_model():
     global lama
     if lama is None:
-        ensure_package("omegaconf")
-
-        from omegaconf import OmegaConf
-        from .saicinpainting.training.trainers import load_checkpoint
-
         files = download_model(
             model_path=model_dir,
             model_url=model_url,
@@ -47,13 +40,8 @@ def load_model():
             download_name="big-lama.pt",
         )
 
-        cfg = yaml.safe_load(open(config_path, "rt"))
-        cfg = OmegaConf.create(cfg)
-        cfg.training_model.predict_only = True
-        cfg.visualizer.kind = "noop"
-
-        lama = load_checkpoint(cfg, files[0], strict=False, map_location="cpu")
-        lama.freeze()
+        lama = torch.jit.load(files[0], map_location="cpu")
+        lama.eval()
 
     return lama
@@ -98,13 +86,12 @@ def lama_inpaint(
         msk = (msk > 0) * 1.0
         msk = msk.unsqueeze(0).unsqueeze(0)
 
-        batch = {}
-        batch["image"] = pad_tensor_to_modulo(img, 8).to(device)
-        batch["mask"] = pad_tensor_to_modulo(msk, 8).to(device)
+
+        src_image = pad_tensor_to_modulo(img, 8).to(device)
+        src_mask = pad_tensor_to_modulo(msk, 8).to(device)
 
-        res = model(batch)
-        res = batch["inpainted"][0].permute(1, 2, 0)
-        res = res.detach().cpu()
+        res = model(src_image, src_mask)
+        res = res[0].permute(1, 2, 0).detach().cpu()
         res = res[:orig_h, :orig_w]
         inpainted.append(res)
diff --git a/modules/inpaint/lama/config.yaml b/modules/inpaint/lama/config.yaml
deleted file mode 100644
index 55fd91b..0000000
--- a/modules/inpaint/lama/config.yaml
+++ /dev/null
@@ -1,157 +0,0 @@
-run_title: b18_ffc075_batch8x15
-training_model:
-  kind: default
-  visualize_each_iters: 1000
-  concat_mask: true
-  store_discr_outputs_for_vis: true
-losses:
-  l1:
-    weight_missing: 0
-    weight_known: 10
-  perceptual:
-    weight: 0
-  adversarial:
-    kind: r1
-    weight: 10
-    gp_coef: 0.001
-    mask_as_fake_target: true
-    allow_scale_mask: true
-  feature_matching:
-    weight: 100
-  resnet_pl:
-    weight: 30
-    weights_path: ${env:TORCH_HOME}
-
-optimizers:
-  generator:
-    kind: adam
-    lr: 0.001
-  discriminator:
-    kind: adam
-    lr: 0.0001
-visualizer:
-  key_order:
-    - image
-    - predicted_image
-    - discr_output_fake
-    - discr_output_real
-    - inpainted
-  rescale_keys:
-    - discr_output_fake
-    - discr_output_real
-  kind: directory
-  outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
-location:
-  data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
-  out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
-  tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
-data:
-  batch_size: 15
-  val_batch_size: 2
-  num_workers: 3
-  train:
-    indir: ${location.data_root_dir}/train
-    out_size: 256
-    mask_gen_kwargs:
-      irregular_proba: 1
-      irregular_kwargs:
-        max_angle: 4
-        max_len: 200
-        max_width: 100
-        max_times: 5
-        min_times: 1
-      box_proba: 1
-      box_kwargs:
-        margin: 10
-        bbox_min_size: 30
-        bbox_max_size: 150
-        max_times: 3
-        min_times: 1
-      segm_proba: 0
-      segm_kwargs:
-        confidence_threshold:
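A minimal, self-contained sketch of the TorchScript inference path that the new load_model() and lama_inpaint() code relies on. The checkpoint path, the random test tensors, and the pad_to_modulo helper below are illustrative stand-ins (pad_tensor_to_modulo itself is defined elsewhere in the module and is not shown in this hunk):

import torch
import torch.nn.functional as F

def pad_to_modulo(t: torch.Tensor, mod: int = 8) -> torch.Tensor:
    # Illustrative stand-in for pad_tensor_to_modulo: the JIT model expects
    # spatial dimensions that are multiples of 8.
    h, w = t.shape[-2:]
    pad_h = (mod - h % mod) % mod
    pad_w = (mod - w % mod) % mod
    return F.pad(t, (0, pad_w, 0, pad_h), mode="reflect")

# big-lama.pt is the TorchScript export fetched by load_model(); the local
# path here is a placeholder.
model = torch.jit.load("big-lama.pt", map_location="cpu")
model.eval()

image = torch.rand(1, 3, 500, 500)                 # RGB in [0, 1], NCHW
mask = (torch.rand(1, 1, 500, 500) > 0.5).float()  # 1.0 marks pixels to inpaint

with torch.no_grad():
    res = model(pad_to_modulo(image), pad_to_modulo(mask))  # (1, 3, H_pad, W_pad)

result = res[0].permute(1, 2, 0)[:500, :500]       # drop padding, back to HWC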
0.5 - max_object_area: 0.5 - min_mask_area: 0.07 - downsample_levels: 6 - num_variants_per_mask: 1 - rigidness_mode: 1 - max_foreground_coverage: 0.3 - max_foreground_intersection: 0.7 - max_mask_intersection: 0.1 - max_hidden_area: 0.1 - max_scale_change: 0.25 - horizontal_flip: true - max_vertical_shift: 0.2 - position_shuffle: true - transform_variant: distortions - dataloader_kwargs: - batch_size: ${data.batch_size} - shuffle: true - num_workers: ${data.num_workers} - val: - indir: ${location.data_root_dir}/val - img_suffix: .png - dataloader_kwargs: - batch_size: ${data.val_batch_size} - shuffle: false - num_workers: ${data.num_workers} - visual_test: - indir: ${location.data_root_dir}/korean_test - img_suffix: _input.png - pad_out_to_modulo: 32 - dataloader_kwargs: - batch_size: 1 - shuffle: false - num_workers: ${data.num_workers} -generator: - kind: ffc_resnet - input_nc: 4 - output_nc: 3 - ngf: 64 - n_downsampling: 3 - n_blocks: 18 - add_out_act: sigmoid - init_conv_kwargs: - ratio_gin: 0 - ratio_gout: 0 - enable_lfu: false - downsample_conv_kwargs: - ratio_gin: ${generator.init_conv_kwargs.ratio_gout} - ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin} - enable_lfu: false - resnet_conv_kwargs: - ratio_gin: 0.75 - ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin} - enable_lfu: false -discriminator: - kind: pix2pixhd_nlayer - input_nc: 3 - ndf: 64 - n_layers: 4 -evaluator: - kind: default - inpainted_key: inpainted - integral_kind: ssim_fid100_f1 -trainer: - kwargs: - gpus: -1 - accelerator: ddp - max_epochs: 200 - gradient_clip_val: 1 - log_gpu_memory: None - limit_train_batches: 25000 - val_check_interval: ${trainer.kwargs.limit_train_batches} - log_every_n_steps: 1000 - precision: 32 - terminate_on_nan: false - check_val_every_n_epoch: 1 - num_sanity_val_steps: 8 - limit_val_batches: 1000 - replace_sampler_ddp: false - checkpoint_kwargs: - verbose: true - save_top_k: 5 - save_last: true - period: 1 - monitor: val_ssim_fid100_f1_total_mean - mode: max diff --git a/modules/inpaint/lama/saicinpainting/__init__.py b/modules/inpaint/lama/saicinpainting/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/modules/inpaint/lama/saicinpainting/training/__init__.py b/modules/inpaint/lama/saicinpainting/training/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/modules/inpaint/lama/saicinpainting/training/data/__init__.py b/modules/inpaint/lama/saicinpainting/training/data/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/modules/inpaint/lama/saicinpainting/training/data/masks.py b/modules/inpaint/lama/saicinpainting/training/data/masks.py deleted file mode 100644 index 588efdf..0000000 --- a/modules/inpaint/lama/saicinpainting/training/data/masks.py +++ /dev/null @@ -1,367 +0,0 @@ -import math -import random -import hashlib -import logging -from enum import Enum - -import cv2 -import numpy as np - -# from ..evaluation.masks.mask import SegmentationMask -from ...utils import LinearRamp - -LOGGER = logging.getLogger(__name__) - - -class DrawMethod(Enum): - LINE = "line" - CIRCLE = "circle" - SQUARE = "square" - - -def make_random_irregular_mask( - shape, max_angle=4, max_len=60, max_width=20, min_times=0, max_times=10, draw_method=DrawMethod.LINE -): - draw_method = DrawMethod(draw_method) - - height, width = shape - mask = np.zeros((height, width), np.float32) - times = np.random.randint(min_times, max_times + 1) - for i in range(times): - start_x = np.random.randint(width) - start_y = np.random.randint(height) - 
for j in range(1 + np.random.randint(5)): - angle = 0.01 + np.random.randint(max_angle) - if i % 2 == 0: - angle = 2 * 3.1415926 - angle - length = 10 + np.random.randint(max_len) - brush_w = 5 + np.random.randint(max_width) - end_x = np.clip((start_x + length * np.sin(angle)).astype(np.int32), 0, width) - end_y = np.clip((start_y + length * np.cos(angle)).astype(np.int32), 0, height) - if draw_method == DrawMethod.LINE: - cv2.line(mask, (start_x, start_y), (end_x, end_y), 1.0, brush_w) - elif draw_method == DrawMethod.CIRCLE: - cv2.circle(mask, (start_x, start_y), radius=brush_w, color=1.0, thickness=-1) - elif draw_method == DrawMethod.SQUARE: - radius = brush_w // 2 - mask[start_y - radius : start_y + radius, start_x - radius : start_x + radius] = 1 - start_x, start_y = end_x, end_y - return mask[None, ...] - - -class RandomIrregularMaskGenerator: - def __init__( - self, - max_angle=4, - max_len=60, - max_width=20, - min_times=0, - max_times=10, - ramp_kwargs=None, - draw_method=DrawMethod.LINE, - ): - self.max_angle = max_angle - self.max_len = max_len - self.max_width = max_width - self.min_times = min_times - self.max_times = max_times - self.draw_method = draw_method - self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None - - def __call__(self, img, iter_i=None, raw_image=None): - coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1 - cur_max_len = int(max(1, self.max_len * coef)) - cur_max_width = int(max(1, self.max_width * coef)) - cur_max_times = int(self.min_times + 1 + (self.max_times - self.min_times) * coef) - return make_random_irregular_mask( - img.shape[1:], - max_angle=self.max_angle, - max_len=cur_max_len, - max_width=cur_max_width, - min_times=self.min_times, - max_times=cur_max_times, - draw_method=self.draw_method, - ) - - -def make_random_rectangle_mask(shape, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3): - height, width = shape - mask = np.zeros((height, width), np.float32) - bbox_max_size = min(bbox_max_size, height - margin * 2, width - margin * 2) - times = np.random.randint(min_times, max_times + 1) - for i in range(times): - box_width = np.random.randint(bbox_min_size, bbox_max_size) - box_height = np.random.randint(bbox_min_size, bbox_max_size) - start_x = np.random.randint(margin, width - margin - box_width + 1) - start_y = np.random.randint(margin, height - margin - box_height + 1) - mask[start_y : start_y + box_height, start_x : start_x + box_width] = 1 - return mask[None, ...] 
- - -class RandomRectangleMaskGenerator: - def __init__(self, margin=10, bbox_min_size=30, bbox_max_size=100, min_times=0, max_times=3, ramp_kwargs=None): - self.margin = margin - self.bbox_min_size = bbox_min_size - self.bbox_max_size = bbox_max_size - self.min_times = min_times - self.max_times = max_times - self.ramp = LinearRamp(**ramp_kwargs) if ramp_kwargs is not None else None - - def __call__(self, img, iter_i=None, raw_image=None): - coef = self.ramp(iter_i) if (self.ramp is not None) and (iter_i is not None) else 1 - cur_bbox_max_size = int(self.bbox_min_size + 1 + (self.bbox_max_size - self.bbox_min_size) * coef) - cur_max_times = int(self.min_times + (self.max_times - self.min_times) * coef) - return make_random_rectangle_mask( - img.shape[1:], - margin=self.margin, - bbox_min_size=self.bbox_min_size, - bbox_max_size=cur_bbox_max_size, - min_times=self.min_times, - max_times=cur_max_times, - ) - - -class RandomSegmentationMaskGenerator: - def __init__(self, **kwargs): - self.impl = None # will be instantiated in first call (effectively in subprocess) - self.kwargs = kwargs - - def __call__(self, img, iter_i=None, raw_image=None): - if self.impl is None: - self.impl = SegmentationMask(**self.kwargs) - - masks = self.impl.get_masks(np.transpose(img, (1, 2, 0))) - masks = [m for m in masks if len(np.unique(m)) > 1] - return np.random.choice(masks) - - -def make_random_superres_mask(shape, min_step=2, max_step=4, min_width=1, max_width=3): - height, width = shape - mask = np.zeros((height, width), np.float32) - step_x = np.random.randint(min_step, max_step + 1) - width_x = np.random.randint(min_width, min(step_x, max_width + 1)) - offset_x = np.random.randint(0, step_x) - - step_y = np.random.randint(min_step, max_step + 1) - width_y = np.random.randint(min_width, min(step_y, max_width + 1)) - offset_y = np.random.randint(0, step_y) - - for dy in range(width_y): - mask[offset_y + dy :: step_y] = 1 - for dx in range(width_x): - mask[:, offset_x + dx :: step_x] = 1 - return mask[None, ...] - - -class RandomSuperresMaskGenerator: - def __init__(self, **kwargs): - self.kwargs = kwargs - - def __call__(self, img, iter_i=None): - return make_random_superres_mask(img.shape[1:], **self.kwargs) - - -class DumbAreaMaskGenerator: - min_ratio = 0.1 - max_ratio = 0.35 - default_ratio = 0.225 - - def __init__(self, is_training): - # Parameters: - # is_training(bool): If true - random rectangular mask, if false - central square mask - self.is_training = is_training - - def _random_vector(self, dimension): - if self.is_training: - lower_limit = math.sqrt(self.min_ratio) - upper_limit = math.sqrt(self.max_ratio) - mask_side = round((random.random() * (upper_limit - lower_limit) + lower_limit) * dimension) - u = random.randint(0, dimension - mask_side - 1) - v = u + mask_side - else: - margin = (math.sqrt(self.default_ratio) / 2) * dimension - u = round(dimension / 2 - margin) - v = round(dimension / 2 + margin) - return u, v - - def __call__(self, img, iter_i=None, raw_image=None): - c, height, width = img.shape - mask = np.zeros((height, width), np.float32) - x1, x2 = self._random_vector(width) - y1, y2 = self._random_vector(height) - mask[x1:x2, y1:y2] = 1 - return mask[None, ...] 
- - -class OutpaintingMaskGenerator: - def __init__( - self, - min_padding_percent: float = 0.04, - max_padding_percent: int = 0.25, - left_padding_prob: float = 0.5, - top_padding_prob: float = 0.5, - right_padding_prob: float = 0.5, - bottom_padding_prob: float = 0.5, - is_fixed_randomness: bool = False, - ): - """ - is_fixed_randomness - get identical paddings for the same image if args are the same - """ - self.min_padding_percent = min_padding_percent - self.max_padding_percent = max_padding_percent - self.probs = [left_padding_prob, top_padding_prob, right_padding_prob, bottom_padding_prob] - self.is_fixed_randomness = is_fixed_randomness - - assert self.min_padding_percent <= self.max_padding_percent - assert self.max_padding_percent > 0 - assert ( - len([x for x in [self.min_padding_percent, self.max_padding_percent] if (x >= 0 and x <= 1)]) == 2 - ), f"Padding percentage should be in [0,1]" - assert sum(self.probs) > 0, f"At least one of the padding probs should be greater than 0 - {self.probs}" - assert ( - len([x for x in self.probs if (x >= 0) and (x <= 1)]) == 4 - ), f"At least one of padding probs is not in [0,1] - {self.probs}" - if len([x for x in self.probs if x > 0]) == 1: - LOGGER.warning( - f"Only one padding prob is greater than zero - {self.probs}. That means that the outpainting masks will be always on the same side" - ) - - def apply_padding(self, mask, coord): - mask[ - int(coord[0][0] * self.img_h) : int(coord[1][0] * self.img_h), - int(coord[0][1] * self.img_w) : int(coord[1][1] * self.img_w), - ] = 1 - return mask - - def get_padding(self, size): - n1 = int(self.min_padding_percent * size) - n2 = int(self.max_padding_percent * size) - return self.rnd.randint(n1, n2) / size - - @staticmethod - def _img2rs(img): - arr = np.ascontiguousarray(img.astype(np.uint8)) - str_hash = hashlib.sha1(arr).hexdigest() - res = hash(str_hash) % (2**32) - return res - - def __call__(self, img, iter_i=None, raw_image=None): - c, self.img_h, self.img_w = img.shape - mask = np.zeros((self.img_h, self.img_w), np.float32) - at_least_one_mask_applied = False - - if self.is_fixed_randomness: - assert raw_image is not None, f"Cant calculate hash on raw_image=None" - rs = self._img2rs(raw_image) - self.rnd = np.random.RandomState(rs) - else: - self.rnd = np.random - - coords = [ - [(0, 0), (1, self.get_padding(size=self.img_h))], - [(0, 0), (self.get_padding(size=self.img_w), 1)], - [(0, 1 - self.get_padding(size=self.img_h)), (1, 1)], - [(1 - self.get_padding(size=self.img_w), 0), (1, 1)], - ] - - for pp, coord in zip(self.probs, coords): - if self.rnd.random() < pp: - at_least_one_mask_applied = True - mask = self.apply_padding(mask=mask, coord=coord) - - if not at_least_one_mask_applied: - idx = self.rnd.choice(range(len(coords)), p=np.array(self.probs) / sum(self.probs)) - mask = self.apply_padding(mask=mask, coord=coords[idx]) - return mask[None, ...] 
- - -class MixedMaskGenerator: - def __init__( - self, - irregular_proba=1 / 3, - irregular_kwargs=None, - box_proba=1 / 3, - box_kwargs=None, - segm_proba=1 / 3, - segm_kwargs=None, - squares_proba=0, - squares_kwargs=None, - superres_proba=0, - superres_kwargs=None, - outpainting_proba=0, - outpainting_kwargs=None, - invert_proba=0, - ): - self.probas = [] - self.gens = [] - - if irregular_proba > 0: - self.probas.append(irregular_proba) - if irregular_kwargs is None: - irregular_kwargs = {} - else: - irregular_kwargs = dict(irregular_kwargs) - irregular_kwargs["draw_method"] = DrawMethod.LINE - self.gens.append(RandomIrregularMaskGenerator(**irregular_kwargs)) - - if box_proba > 0: - self.probas.append(box_proba) - if box_kwargs is None: - box_kwargs = {} - self.gens.append(RandomRectangleMaskGenerator(**box_kwargs)) - - if segm_proba > 0: - self.probas.append(segm_proba) - if segm_kwargs is None: - segm_kwargs = {} - self.gens.append(RandomSegmentationMaskGenerator(**segm_kwargs)) - - if squares_proba > 0: - self.probas.append(squares_proba) - if squares_kwargs is None: - squares_kwargs = {} - else: - squares_kwargs = dict(squares_kwargs) - squares_kwargs["draw_method"] = DrawMethod.SQUARE - self.gens.append(RandomIrregularMaskGenerator(**squares_kwargs)) - - if superres_proba > 0: - self.probas.append(superres_proba) - if superres_kwargs is None: - superres_kwargs = {} - self.gens.append(RandomSuperresMaskGenerator(**superres_kwargs)) - - if outpainting_proba > 0: - self.probas.append(outpainting_proba) - if outpainting_kwargs is None: - outpainting_kwargs = {} - self.gens.append(OutpaintingMaskGenerator(**outpainting_kwargs)) - - self.probas = np.array(self.probas, dtype="float32") - self.probas /= self.probas.sum() - self.invert_proba = invert_proba - - def __call__(self, img, iter_i=None, raw_image=None): - kind = np.random.choice(len(self.probas), p=self.probas) - gen = self.gens[kind] - result = gen(img, iter_i=iter_i, raw_image=raw_image) - if self.invert_proba > 0 and random.random() < self.invert_proba: - result = 1 - result - return result - - -def get_mask_generator(kind, kwargs): - if kind is None: - kind = "mixed" - if kwargs is None: - kwargs = {} - - if kind == "mixed": - cl = MixedMaskGenerator - elif kind == "outpainting": - cl = OutpaintingMaskGenerator - elif kind == "dumb": - cl = DumbAreaMaskGenerator - else: - raise NotImplementedError(f"No such generator kind = {kind}") - return cl(**kwargs) diff --git a/modules/inpaint/lama/saicinpainting/training/losses/__init__.py b/modules/inpaint/lama/saicinpainting/training/losses/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/modules/inpaint/lama/saicinpainting/training/losses/adversarial.py b/modules/inpaint/lama/saicinpainting/training/losses/adversarial.py deleted file mode 100644 index ccd3f17..0000000 --- a/modules/inpaint/lama/saicinpainting/training/losses/adversarial.py +++ /dev/null @@ -1,204 +0,0 @@ -from typing import Tuple, Dict, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class BaseAdversarialLoss: - def pre_generator_step( - self, real_batch: torch.Tensor, fake_batch: torch.Tensor, generator: nn.Module, discriminator: nn.Module - ): - """ - Prepare for generator step - :param real_batch: Tensor, a batch of real samples - :param fake_batch: Tensor, a batch of samples produced by generator - :param generator: - :param discriminator: - :return: None - """ - - def pre_discriminator_step( - self, real_batch: torch.Tensor, fake_batch: torch.Tensor, 
generator: nn.Module, discriminator: nn.Module - ): - """ - Prepare for discriminator step - :param real_batch: Tensor, a batch of real samples - :param fake_batch: Tensor, a batch of samples produced by generator - :param generator: - :param discriminator: - :return: None - """ - - def generator_loss( - self, - real_batch: torch.Tensor, - fake_batch: torch.Tensor, - discr_real_pred: torch.Tensor, - discr_fake_pred: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """ - Calculate generator loss - :param real_batch: Tensor, a batch of real samples - :param fake_batch: Tensor, a batch of samples produced by generator - :param discr_real_pred: Tensor, discriminator output for real_batch - :param discr_fake_pred: Tensor, discriminator output for fake_batch - :param mask: Tensor, actual mask, which was at input of generator when making fake_batch - :return: total generator loss along with some values that might be interesting to log - """ - raise NotImplemented() - - def discriminator_loss( - self, - real_batch: torch.Tensor, - fake_batch: torch.Tensor, - discr_real_pred: torch.Tensor, - discr_fake_pred: torch.Tensor, - mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """ - Calculate discriminator loss and call .backward() on it - :param real_batch: Tensor, a batch of real samples - :param fake_batch: Tensor, a batch of samples produced by generator - :param discr_real_pred: Tensor, discriminator output for real_batch - :param discr_fake_pred: Tensor, discriminator output for fake_batch - :param mask: Tensor, actual mask, which was at input of generator when making fake_batch - :return: total discriminator loss along with some values that might be interesting to log - """ - raise NotImplemented() - - def interpolate_mask(self, mask, shape): - assert mask is not None - assert self.allow_scale_mask or shape == mask.shape[-2:] - if shape != mask.shape[-2:] and self.allow_scale_mask: - if self.mask_scale_mode == "maxpool": - mask = F.adaptive_max_pool2d(mask, shape) - else: - mask = F.interpolate(mask, size=shape, mode=self.mask_scale_mode) - return mask - - -def make_r1_gp(discr_real_pred, real_batch): - if torch.is_grad_enabled(): - grad_real = torch.autograd.grad(outputs=discr_real_pred.sum(), inputs=real_batch, create_graph=True)[0] - grad_penalty = (grad_real.view(grad_real.shape[0], -1).norm(2, dim=1) ** 2).mean() - else: - grad_penalty = 0 - real_batch.requires_grad = False - - return grad_penalty - - -class NonSaturatingWithR1(BaseAdversarialLoss): - def __init__( - self, - gp_coef=5, - weight=1, - mask_as_fake_target=False, - allow_scale_mask=False, - mask_scale_mode="nearest", - extra_mask_weight_for_gen=0, - use_unmasked_for_gen=True, - use_unmasked_for_discr=True, - ): - self.gp_coef = gp_coef - self.weight = weight - # use for discr => use for gen; - # otherwise we teach only the discr to pay attention to very small difference - assert use_unmasked_for_gen or (not use_unmasked_for_discr) - # mask as target => use unmasked for discr: - # if we don't care about unmasked regions at all - # then it doesn't matter if the value of mask_as_fake_target is true or false - assert use_unmasked_for_discr or (not mask_as_fake_target) - self.use_unmasked_for_gen = use_unmasked_for_gen - self.use_unmasked_for_discr = use_unmasked_for_discr - self.mask_as_fake_target = mask_as_fake_target - self.allow_scale_mask = allow_scale_mask - self.mask_scale_mode = mask_scale_mode - 
self.extra_mask_weight_for_gen = extra_mask_weight_for_gen - - def generator_loss( - self, - real_batch: torch.Tensor, - fake_batch: torch.Tensor, - discr_real_pred: torch.Tensor, - discr_fake_pred: torch.Tensor, - mask=None, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - fake_loss = F.softplus(-discr_fake_pred) - if ( - self.mask_as_fake_target and self.extra_mask_weight_for_gen > 0 - ) or not self.use_unmasked_for_gen: # == if masked region should be treated differently - mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) - if not self.use_unmasked_for_gen: - fake_loss = fake_loss * mask - else: - pixel_weights = 1 + mask * self.extra_mask_weight_for_gen - fake_loss = fake_loss * pixel_weights - - return fake_loss.mean() * self.weight, dict() - - def pre_discriminator_step( - self, real_batch: torch.Tensor, fake_batch: torch.Tensor, generator: nn.Module, discriminator: nn.Module - ): - real_batch.requires_grad = True - - def discriminator_loss( - self, - real_batch: torch.Tensor, - fake_batch: torch.Tensor, - discr_real_pred: torch.Tensor, - discr_fake_pred: torch.Tensor, - mask=None, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - real_loss = F.softplus(-discr_real_pred) - grad_penalty = make_r1_gp(discr_real_pred, real_batch) * self.gp_coef - fake_loss = F.softplus(discr_fake_pred) - - if not self.use_unmasked_for_discr or self.mask_as_fake_target: - # == if masked region should be treated differently - mask = self.interpolate_mask(mask, discr_fake_pred.shape[-2:]) - # use_unmasked_for_discr=False only makes sense for fakes; - # for reals there is no difference beetween two regions - fake_loss = fake_loss * mask - if self.mask_as_fake_target: - fake_loss = fake_loss + (1 - mask) * F.softplus(-discr_fake_pred) - - sum_discr_loss = real_loss + grad_penalty + fake_loss - metrics = dict( - discr_real_out=discr_real_pred.mean(), discr_fake_out=discr_fake_pred.mean(), discr_real_gp=grad_penalty - ) - return sum_discr_loss.mean(), metrics - - -class BCELoss(BaseAdversarialLoss): - def __init__(self, weight): - self.weight = weight - self.bce_loss = nn.BCEWithLogitsLoss() - - def generator_loss(self, discr_fake_pred: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - real_mask_gt = torch.zeros(discr_fake_pred.shape).to(discr_fake_pred.device) - fake_loss = self.bce_loss(discr_fake_pred, real_mask_gt) * self.weight - return fake_loss, dict() - - def pre_discriminator_step( - self, real_batch: torch.Tensor, fake_batch: torch.Tensor, generator: nn.Module, discriminator: nn.Module - ): - real_batch.requires_grad = True - - def discriminator_loss( - self, mask: torch.Tensor, discr_real_pred: torch.Tensor, discr_fake_pred: torch.Tensor - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - real_mask_gt = torch.zeros(discr_real_pred.shape).to(discr_real_pred.device) - sum_discr_loss = (self.bce_loss(discr_real_pred, real_mask_gt) + self.bce_loss(discr_fake_pred, mask)) / 2 - metrics = dict(discr_real_out=discr_real_pred.mean(), discr_fake_out=discr_fake_pred.mean(), discr_real_gp=0) - return sum_discr_loss, metrics - - -def make_discrim_loss(kind, **kwargs): - if kind == "r1": - return NonSaturatingWithR1(**kwargs) - elif kind == "bce": - return BCELoss(**kwargs) - raise ValueError(f"Unknown adversarial loss kind {kind}") diff --git a/modules/inpaint/lama/saicinpainting/training/losses/constants.py b/modules/inpaint/lama/saicinpainting/training/losses/constants.py deleted file mode 100644 index 7797a21..0000000 --- 
a/modules/inpaint/lama/saicinpainting/training/losses/constants.py +++ /dev/null @@ -1,154 +0,0 @@ -weights = { - "ade20k": [ - 6.34517766497462, - 9.328358208955224, - 11.389521640091116, - 16.10305958132045, - 20.833333333333332, - 22.22222222222222, - 25.125628140703515, - 43.29004329004329, - 50.5050505050505, - 54.6448087431694, - 55.24861878453038, - 60.24096385542168, - 62.5, - 66.2251655629139, - 84.74576271186442, - 90.90909090909092, - 91.74311926605505, - 96.15384615384616, - 96.15384615384616, - 97.08737864077669, - 102.04081632653062, - 135.13513513513513, - 149.2537313432836, - 153.84615384615384, - 163.93442622950818, - 166.66666666666666, - 188.67924528301887, - 192.30769230769232, - 217.3913043478261, - 227.27272727272725, - 227.27272727272725, - 227.27272727272725, - 303.03030303030306, - 322.5806451612903, - 333.3333333333333, - 370.3703703703703, - 384.61538461538464, - 416.6666666666667, - 416.6666666666667, - 434.7826086956522, - 434.7826086956522, - 454.5454545454545, - 454.5454545454545, - 500.0, - 526.3157894736842, - 526.3157894736842, - 555.5555555555555, - 555.5555555555555, - 555.5555555555555, - 555.5555555555555, - 555.5555555555555, - 555.5555555555555, - 555.5555555555555, - 588.2352941176471, - 588.2352941176471, - 588.2352941176471, - 588.2352941176471, - 588.2352941176471, - 666.6666666666666, - 666.6666666666666, - 666.6666666666666, - 666.6666666666666, - 714.2857142857143, - 714.2857142857143, - 714.2857142857143, - 714.2857142857143, - 714.2857142857143, - 769.2307692307693, - 769.2307692307693, - 769.2307692307693, - 833.3333333333334, - 833.3333333333334, - 833.3333333333334, - 833.3333333333334, - 909.090909090909, - 1000.0, - 1111.111111111111, - 1111.111111111111, - 1111.111111111111, - 1111.111111111111, - 1111.111111111111, - 1250.0, - 1250.0, - 1250.0, - 1250.0, - 1250.0, - 1428.5714285714287, - 1428.5714285714287, - 1428.5714285714287, - 1428.5714285714287, - 1428.5714285714287, - 1428.5714285714287, - 1428.5714285714287, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 1666.6666666666667, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2000.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 2500.0, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 3333.3333333333335, - 5000.0, - 5000.0, - 5000.0, - ] -} diff --git a/modules/inpaint/lama/saicinpainting/training/losses/distance_weighting.py b/modules/inpaint/lama/saicinpainting/training/losses/distance_weighting.py deleted file mode 100644 index bca6c5a..0000000 --- a/modules/inpaint/lama/saicinpainting/training/losses/distance_weighting.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from .perceptual import IMAGENET_STD, IMAGENET_MEAN - - -def dummy_distance_weighter(real_img, pred_img, mask): - return mask - - -def get_gauss_kernel(kernel_size, width_factor=1): - coords = torch.stack(torch.meshgrid(torch.arange(kernel_size), 
torch.arange(kernel_size)), dim=0).float() - diff = torch.exp(-((coords - kernel_size // 2) ** 2).sum(0) / kernel_size / width_factor) - diff /= diff.sum() - return diff - - -class BlurMask(nn.Module): - def __init__(self, kernel_size=5, width_factor=1): - super().__init__() - self.filter = nn.Conv2d(1, 1, kernel_size, padding=kernel_size // 2, padding_mode="replicate", bias=False) - self.filter.weight.data.copy_(get_gauss_kernel(kernel_size, width_factor=width_factor)) - - def forward(self, real_img, pred_img, mask): - with torch.no_grad(): - result = self.filter(mask) * mask - return result - - -class EmulatedEDTMask(nn.Module): - def __init__(self, dilate_kernel_size=5, blur_kernel_size=5, width_factor=1): - super().__init__() - self.dilate_filter = nn.Conv2d( - 1, 1, dilate_kernel_size, padding=dilate_kernel_size // 2, padding_mode="replicate", bias=False - ) - self.dilate_filter.weight.data.copy_( - torch.ones(1, 1, dilate_kernel_size, dilate_kernel_size, dtype=torch.float) - ) - self.blur_filter = nn.Conv2d( - 1, 1, blur_kernel_size, padding=blur_kernel_size // 2, padding_mode="replicate", bias=False - ) - self.blur_filter.weight.data.copy_(get_gauss_kernel(blur_kernel_size, width_factor=width_factor)) - - def forward(self, real_img, pred_img, mask): - with torch.no_grad(): - known_mask = 1 - mask - dilated_known_mask = (self.dilate_filter(known_mask) > 1).float() - result = self.blur_filter(1 - dilated_known_mask) * mask - return result - - -class PropagatePerceptualSim(nn.Module): - def __init__(self, level=2, max_iters=10, temperature=500, erode_mask_size=3): - super().__init__() - vgg = torchvision.models.vgg19(pretrained=True).features - vgg_avg_pooling = [] - - for weights in vgg.parameters(): - weights.requires_grad = False - - cur_level_i = 0 - for module in vgg.modules(): - if module.__class__.__name__ == "Sequential": - continue - elif module.__class__.__name__ == "MaxPool2d": - vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) - else: - vgg_avg_pooling.append(module) - if module.__class__.__name__ == "ReLU": - cur_level_i += 1 - if cur_level_i == level: - break - - self.features = nn.Sequential(*vgg_avg_pooling) - - self.max_iters = max_iters - self.temperature = temperature - self.do_erode = erode_mask_size > 0 - if self.do_erode: - self.erode_mask = nn.Conv2d(1, 1, erode_mask_size, padding=erode_mask_size // 2, bias=False) - self.erode_mask.weight.data.fill_(1) - - def forward(self, real_img, pred_img, mask): - with torch.no_grad(): - real_img = (real_img - IMAGENET_MEAN.to(real_img)) / IMAGENET_STD.to(real_img) - real_feats = self.features(real_img) - - vertical_sim = torch.exp( - -(real_feats[:, :, 1:] - real_feats[:, :, :-1]).pow(2).sum(1, keepdim=True) / self.temperature - ) - horizontal_sim = torch.exp( - -(real_feats[:, :, :, 1:] - real_feats[:, :, :, :-1]).pow(2).sum(1, keepdim=True) / self.temperature - ) - - mask_scaled = F.interpolate(mask, size=real_feats.shape[-2:], mode="bilinear", align_corners=False) - if self.do_erode: - mask_scaled = (self.erode_mask(mask_scaled) > 1).float() - - cur_knowness = 1 - mask_scaled - - for iter_i in range(self.max_iters): - new_top_knowness = F.pad(cur_knowness[:, :, :-1] * vertical_sim, (0, 0, 1, 0), mode="replicate") - new_bottom_knowness = F.pad(cur_knowness[:, :, 1:] * vertical_sim, (0, 0, 0, 1), mode="replicate") - - new_left_knowness = F.pad(cur_knowness[:, :, :, :-1] * horizontal_sim, (1, 0, 0, 0), mode="replicate") - new_right_knowness = F.pad(cur_knowness[:, :, :, 1:] * horizontal_sim, (0, 1, 
0, 0), mode="replicate") - - new_knowness = ( - torch.stack([new_top_knowness, new_bottom_knowness, new_left_knowness, new_right_knowness], dim=0) - .max(0) - .values - ) - - cur_knowness = torch.max(cur_knowness, new_knowness) - - cur_knowness = F.interpolate(cur_knowness, size=mask.shape[-2:], mode="bilinear") - result = torch.min(mask, 1 - cur_knowness) - - return result - - -def make_mask_distance_weighter(kind="none", **kwargs): - if kind == "none": - return dummy_distance_weighter - if kind == "blur": - return BlurMask(**kwargs) - if kind == "edt": - return EmulatedEDTMask(**kwargs) - if kind == "pps": - return PropagatePerceptualSim(**kwargs) - raise ValueError(f"Unknown mask distance weighter kind {kind}") diff --git a/modules/inpaint/lama/saicinpainting/training/losses/feature_matching.py b/modules/inpaint/lama/saicinpainting/training/losses/feature_matching.py deleted file mode 100644 index c019895..0000000 --- a/modules/inpaint/lama/saicinpainting/training/losses/feature_matching.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import List - -import torch -import torch.nn.functional as F - - -def masked_l2_loss(pred, target, mask, weight_known, weight_missing): - per_pixel_l2 = F.mse_loss(pred, target, reduction='none') - pixel_weights = mask * weight_missing + (1 - mask) * weight_known - return (pixel_weights * per_pixel_l2).mean() - - -def masked_l1_loss(pred, target, mask, weight_known, weight_missing): - per_pixel_l1 = F.l1_loss(pred, target, reduction='none') - pixel_weights = mask * weight_missing + (1 - mask) * weight_known - return (pixel_weights * per_pixel_l1).mean() - - -def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None): - if mask is None: - res = torch.stack([F.mse_loss(fake_feat, target_feat) - for fake_feat, target_feat in zip(fake_features, target_features)]).mean() - else: - res = 0 - norm = 0 - for fake_feat, target_feat in zip(fake_features, target_features): - cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) - error_weights = 1 - cur_mask - cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() - res = res + cur_val - norm += 1 - res = res / norm - return res diff --git a/modules/inpaint/lama/saicinpainting/training/losses/perceptual.py b/modules/inpaint/lama/saicinpainting/training/losses/perceptual.py deleted file mode 100644 index 235d320..0000000 --- a/modules/inpaint/lama/saicinpainting/training/losses/perceptual.py +++ /dev/null @@ -1,84 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision - -from ...utils import check_and_warn_input_range - - -IMAGENET_MEAN = torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None] -IMAGENET_STD = torch.FloatTensor([0.229, 0.224, 0.225])[None, :, None, None] - - -class PerceptualLoss(nn.Module): - def __init__(self, normalize_inputs=True): - super(PerceptualLoss, self).__init__() - - self.normalize_inputs = normalize_inputs - self.mean_ = IMAGENET_MEAN - self.std_ = IMAGENET_STD - - vgg = torchvision.models.vgg19(pretrained=True).features - vgg_avg_pooling = [] - - for weights in vgg.parameters(): - weights.requires_grad = False - - for module in vgg.modules(): - if module.__class__.__name__ == "Sequential": - continue - elif module.__class__.__name__ == "MaxPool2d": - vgg_avg_pooling.append(nn.AvgPool2d(kernel_size=2, stride=2, padding=0)) - else: - vgg_avg_pooling.append(module) - - self.vgg = nn.Sequential(*vgg_avg_pooling) - - def 
do_normalize_inputs(self, x): - return (x - self.mean_.to(x.device)) / self.std_.to(x.device) - - def partial_losses(self, input, target, mask=None): - check_and_warn_input_range(target, 0, 1, "PerceptualLoss target in partial_losses") - - # we expect input and target to be in [0, 1] range - losses = [] - - if self.normalize_inputs: - features_input = self.do_normalize_inputs(input) - features_target = self.do_normalize_inputs(target) - else: - features_input = input - features_target = target - - for layer in self.vgg[:30]: - features_input = layer(features_input) - features_target = layer(features_target) - - if layer.__class__.__name__ == "ReLU": - loss = F.mse_loss(features_input, features_target, reduction="none") - - if mask is not None: - cur_mask = F.interpolate( - mask, size=features_input.shape[-2:], mode="bilinear", align_corners=False - ) - loss = loss * (1 - cur_mask) - - loss = loss.mean(dim=tuple(range(1, len(loss.shape)))) - losses.append(loss) - - return losses - - def forward(self, input, target, mask=None): - losses = self.partial_losses(input, target, mask=mask) - return torch.stack(losses).sum(dim=0) - - def get_global_features(self, input): - check_and_warn_input_range(input, 0, 1, "PerceptualLoss input in get_global_features") - - if self.normalize_inputs: - features_input = self.do_normalize_inputs(input) - else: - features_input = input - - features_input = self.vgg(features_input) - return features_input diff --git a/modules/inpaint/lama/saicinpainting/training/losses/segmentation.py b/modules/inpaint/lama/saicinpainting/training/losses/segmentation.py deleted file mode 100644 index f9bb4c9..0000000 --- a/modules/inpaint/lama/saicinpainting/training/losses/segmentation.py +++ /dev/null @@ -1,43 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .constants import weights as constant_weights - - -class CrossEntropy2d(nn.Module): - def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs): - """ - weight (Tensor, optional): a manual rescaling weight given to each class. 
- If given, has to be a Tensor of size "nclasses" - """ - super(CrossEntropy2d, self).__init__() - self.reduction = reduction - self.ignore_label = ignore_label - self.weights = weights - if self.weights is not None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.weights = torch.FloatTensor(constant_weights[weights]).to(device) - - def forward(self, predict, target): - """ - Args: - predict:(n, c, h, w) - target:(n, 1, h, w) - """ - target = target.long() - assert not target.requires_grad - assert predict.dim() == 4, "{0}".format(predict.size()) - assert target.dim() == 4, "{0}".format(target.size()) - assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) - assert target.size(1) == 1, "{0}".format(target.size(1)) - assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2)) - assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3)) - target = target.squeeze(1) - n, c, h, w = predict.size() - target_mask = (target >= 0) * (target != self.ignore_label) - target = target[target_mask] - predict = predict.transpose(1, 2).transpose(2, 3).contiguous() - predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) - loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction) - return loss diff --git a/modules/inpaint/lama/saicinpainting/training/losses/style_loss.py b/modules/inpaint/lama/saicinpainting/training/losses/style_loss.py deleted file mode 100644 index 06d0db7..0000000 --- a/modules/inpaint/lama/saicinpainting/training/losses/style_loss.py +++ /dev/null @@ -1,150 +0,0 @@ -import torch -import torch.nn as nn -import torchvision.models as models - - -class PerceptualLoss(nn.Module): - r""" - Perceptual loss, VGG-based - https://arxiv.org/abs/1603.08155 - https://github.com/dxyang/StyleTransfer/blob/master/utils.py - """ - - def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): - super(PerceptualLoss, self).__init__() - self.add_module("vgg", VGG19()) - self.criterion = torch.nn.L1Loss() - self.weights = weights - - def __call__(self, x, y): - # Compute features - x_vgg, y_vgg = self.vgg(x), self.vgg(y) - - content_loss = 0.0 - content_loss += self.weights[0] * self.criterion(x_vgg["relu1_1"], y_vgg["relu1_1"]) - content_loss += self.weights[1] * self.criterion(x_vgg["relu2_1"], y_vgg["relu2_1"]) - content_loss += self.weights[2] * self.criterion(x_vgg["relu3_1"], y_vgg["relu3_1"]) - content_loss += self.weights[3] * self.criterion(x_vgg["relu4_1"], y_vgg["relu4_1"]) - content_loss += self.weights[4] * self.criterion(x_vgg["relu5_1"], y_vgg["relu5_1"]) - - return content_loss - - -class VGG19(torch.nn.Module): - def __init__(self): - super(VGG19, self).__init__() - features = models.vgg19(pretrained=True).features - self.relu1_1 = torch.nn.Sequential() - self.relu1_2 = torch.nn.Sequential() - - self.relu2_1 = torch.nn.Sequential() - self.relu2_2 = torch.nn.Sequential() - - self.relu3_1 = torch.nn.Sequential() - self.relu3_2 = torch.nn.Sequential() - self.relu3_3 = torch.nn.Sequential() - self.relu3_4 = torch.nn.Sequential() - - self.relu4_1 = torch.nn.Sequential() - self.relu4_2 = torch.nn.Sequential() - self.relu4_3 = torch.nn.Sequential() - self.relu4_4 = torch.nn.Sequential() - - self.relu5_1 = torch.nn.Sequential() - self.relu5_2 = torch.nn.Sequential() - self.relu5_3 = torch.nn.Sequential() - self.relu5_4 = torch.nn.Sequential() - - for x in range(2): - self.relu1_1.add_module(str(x), 
features[x]) - - for x in range(2, 4): - self.relu1_2.add_module(str(x), features[x]) - - for x in range(4, 7): - self.relu2_1.add_module(str(x), features[x]) - - for x in range(7, 9): - self.relu2_2.add_module(str(x), features[x]) - - for x in range(9, 12): - self.relu3_1.add_module(str(x), features[x]) - - for x in range(12, 14): - self.relu3_2.add_module(str(x), features[x]) - - for x in range(14, 16): - self.relu3_2.add_module(str(x), features[x]) - - for x in range(16, 18): - self.relu3_4.add_module(str(x), features[x]) - - for x in range(18, 21): - self.relu4_1.add_module(str(x), features[x]) - - for x in range(21, 23): - self.relu4_2.add_module(str(x), features[x]) - - for x in range(23, 25): - self.relu4_3.add_module(str(x), features[x]) - - for x in range(25, 27): - self.relu4_4.add_module(str(x), features[x]) - - for x in range(27, 30): - self.relu5_1.add_module(str(x), features[x]) - - for x in range(30, 32): - self.relu5_2.add_module(str(x), features[x]) - - for x in range(32, 34): - self.relu5_3.add_module(str(x), features[x]) - - for x in range(34, 36): - self.relu5_4.add_module(str(x), features[x]) - - # don't need the gradients, just want the features - for param in self.parameters(): - param.requires_grad = False - - def forward(self, x): - relu1_1 = self.relu1_1(x) - relu1_2 = self.relu1_2(relu1_1) - - relu2_1 = self.relu2_1(relu1_2) - relu2_2 = self.relu2_2(relu2_1) - - relu3_1 = self.relu3_1(relu2_2) - relu3_2 = self.relu3_2(relu3_1) - relu3_3 = self.relu3_3(relu3_2) - relu3_4 = self.relu3_4(relu3_3) - - relu4_1 = self.relu4_1(relu3_4) - relu4_2 = self.relu4_2(relu4_1) - relu4_3 = self.relu4_3(relu4_2) - relu4_4 = self.relu4_4(relu4_3) - - relu5_1 = self.relu5_1(relu4_4) - relu5_2 = self.relu5_2(relu5_1) - relu5_3 = self.relu5_3(relu5_2) - relu5_4 = self.relu5_4(relu5_3) - - out = { - "relu1_1": relu1_1, - "relu1_2": relu1_2, - "relu2_1": relu2_1, - "relu2_2": relu2_2, - "relu3_1": relu3_1, - "relu3_2": relu3_2, - "relu3_3": relu3_3, - "relu3_4": relu3_4, - "relu4_1": relu4_1, - "relu4_2": relu4_2, - "relu4_3": relu4_3, - "relu4_4": relu4_4, - "relu5_1": relu5_1, - "relu5_2": relu5_2, - "relu5_3": relu5_3, - "relu5_4": relu5_4, - } - return out diff --git a/modules/inpaint/lama/saicinpainting/training/modules/__init__.py b/modules/inpaint/lama/saicinpainting/training/modules/__init__.py deleted file mode 100644 index c919374..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/__init__.py +++ /dev/null @@ -1,36 +0,0 @@ -import logging - -from ..modules.ffc import FFCResNetGenerator -from ..modules.pix2pixhd import ( - GlobalGenerator, - MultiDilatedGlobalGenerator, - NLayerDiscriminator, - MultidilatedNLayerDiscriminator, -) - - -def make_generator(config, kind, **kwargs): - logging.info(f"Make generator {kind}") - - if kind == "pix2pixhd_multidilated": - return MultiDilatedGlobalGenerator(**kwargs) - - if kind == "pix2pixhd_global": - return GlobalGenerator(**kwargs) - - if kind == "ffc_resnet": - return FFCResNetGenerator(**kwargs) - - raise ValueError(f"Unknown generator kind {kind}") - - -def make_discriminator(kind, **kwargs): - logging.info(f"Make discriminator {kind}") - - if kind == "pix2pixhd_nlayer_multidilated": - return MultidilatedNLayerDiscriminator(**kwargs) - - if kind == "pix2pixhd_nlayer": - return NLayerDiscriminator(**kwargs) - - raise ValueError(f"Unknown discriminator kind {kind}") diff --git a/modules/inpaint/lama/saicinpainting/training/modules/base.py b/modules/inpaint/lama/saicinpainting/training/modules/base.py deleted file mode 
100644 index 921fc1b..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/base.py +++ /dev/null @@ -1,96 +0,0 @@ -import abc -from typing import Tuple, List - -import torch -import torch.nn as nn - -from .depthwise_sep_conv import DepthWiseSeperableConv -from .multidilated_conv import MultidilatedConv - - -class BaseDiscriminator(nn.Module): - @abc.abstractmethod - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: - """ - Predict scores and get intermediate activations. Useful for feature matching loss - :return tuple (scores, list of intermediate activations) - """ - raise NotImplemented() - - -def get_conv_block_ctor(kind="default"): - if not isinstance(kind, str): - return kind - if kind == "default": - return nn.Conv2d - if kind == "depthwise": - return DepthWiseSeperableConv - if kind == "multidilated": - return MultidilatedConv - raise ValueError(f"Unknown convolutional block kind {kind}") - - -def get_norm_layer(kind="bn"): - if not isinstance(kind, str): - return kind - if kind == "bn": - return nn.BatchNorm2d - if kind == "in": - return nn.InstanceNorm2d - raise ValueError(f"Unknown norm block kind {kind}") - - -def get_activation(kind="tanh"): - if kind == "tanh": - return nn.Tanh() - if kind == "sigmoid": - return nn.Sigmoid() - if kind is False: - return nn.Identity() - raise ValueError(f"Unknown activation kind {kind}") - - -class SimpleMultiStepGenerator(nn.Module): - def __init__(self, steps: List[nn.Module]): - super().__init__() - self.steps = nn.ModuleList(steps) - - def forward(self, x): - cur_in = x - outs = [] - for step in self.steps: - cur_out = step(cur_in) - outs.append(cur_out) - cur_in = torch.cat((cur_in, cur_out), dim=1) - return torch.cat(outs[::-1], dim=1) - - -def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features): - if kind == "convtranspose": - return [ - nn.ConvTranspose2d( - min(max_features, ngf * mult), - min(max_features, int(ngf * mult / 2)), - kernel_size=3, - stride=2, - padding=1, - output_padding=1, - ), - norm_layer(min(max_features, int(ngf * mult / 2))), - activation, - ] - elif kind == "bilinear": - return [ - nn.Upsample(scale_factor=2, mode="bilinear"), - DepthWiseSeperableConv( - min(max_features, ngf * mult), - min(max_features, int(ngf * mult / 2)), - kernel_size=3, - stride=1, - padding=1, - ), - norm_layer(min(max_features, int(ngf * mult / 2))), - activation, - ] - else: - raise Exception(f"Invalid deconv kind: {kind}") diff --git a/modules/inpaint/lama/saicinpainting/training/modules/depthwise_sep_conv.py b/modules/inpaint/lama/saicinpainting/training/modules/depthwise_sep_conv.py deleted file mode 100644 index ce1b13a..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/depthwise_sep_conv.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch -import torch.nn as nn - - -class DepthWiseSeperableConv(nn.Module): - def __init__(self, in_dim, out_dim, *args, **kwargs): - super().__init__() - if "groups" in kwargs: - # ignoring groups for Depthwise Sep Conv - del kwargs["groups"] - - self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) - self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) - - def forward(self, x): - out = self.depthwise(x) - out = self.pointwise(out) - return out diff --git a/modules/inpaint/lama/saicinpainting/training/modules/fake_fakes.py b/modules/inpaint/lama/saicinpainting/training/modules/fake_fakes.py deleted file mode 100644 index a5e2881..0000000 --- 
a/modules/inpaint/lama/saicinpainting/training/modules/fake_fakes.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch -from kornia.constants import SamplePadding -from kornia.augmentation import RandomAffine, CenterCrop - - -class FakeFakesGenerator: - def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2): - self.grad_aug = RandomAffine( - degrees=360, translate=0.2, padding_mode=SamplePadding.REFLECTION, keepdim=False, p=1 - ) - self.img_aug = RandomAffine( - degrees=img_aug_degree, - translate=img_aug_translate, - padding_mode=SamplePadding.REFLECTION, - keepdim=True, - p=1, - ) - self.aug_proba = aug_proba - - def __call__(self, input_images, masks): - blend_masks = self._fill_masks_with_gradient(masks) - blend_target = self._make_blend_target(input_images) - result = input_images * (1 - blend_masks) + blend_target * blend_masks - return result, blend_masks - - def _make_blend_target(self, input_images): - batch_size = input_images.shape[0] - permuted = input_images[torch.randperm(batch_size)] - augmented = self.img_aug(input_images) - is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float() - result = augmented * is_aug + permuted * (1 - is_aug) - return result - - def _fill_masks_with_gradient(self, masks): - batch_size, _, height, width = masks.shape - grad = ( - torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) - .view(1, 1, 1, -1) - .expand(batch_size, 1, height * 2, width * 2) - ) - grad = self.grad_aug(grad) - grad = CenterCrop((height, width))(grad) - grad *= masks - - grad_for_min = grad + (1 - masks) * 10 - grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None] - grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6 - grad.clamp_(min=0, max=1) - - return grad diff --git a/modules/inpaint/lama/saicinpainting/training/modules/ffc.py b/modules/inpaint/lama/saicinpainting/training/modules/ffc.py deleted file mode 100644 index b796e59..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/ffc.py +++ /dev/null @@ -1,589 +0,0 @@ -# Fast Fourier Convolution NeurIPS 2020 -# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py -# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .base import get_activation, BaseDiscriminator -from .spatial_transform import LearnableSpatialTransformWrapper -from .squeeze_excitation import SELayer - - -class FFCSE_block(nn.Module): - def __init__(self, channels, ratio_g): - super(FFCSE_block, self).__init__() - in_cg = int(channels * ratio_g) - in_cl = channels - in_cg - r = 16 - - self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) - self.conv1 = nn.Conv2d(channels, channels // r, kernel_size=1, bias=True) - self.relu1 = nn.ReLU(inplace=True) - self.conv_a2l = None if in_cl == 0 else nn.Conv2d(channels // r, in_cl, kernel_size=1, bias=True) - self.conv_a2g = None if in_cg == 0 else nn.Conv2d(channels // r, in_cg, kernel_size=1, bias=True) - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - x = x if type(x) is tuple else (x, 0) - id_l, id_g = x - - x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1) - x = self.avgpool(x) - x = self.relu1(self.conv1(x)) - - x_l = 0 if self.conv_a2l is None else id_l * self.sigmoid(self.conv_a2l(x)) - x_g = 0 if self.conv_a2g is None else id_g * self.sigmoid(self.conv_a2g(x)) - 
return x_l, x_g - - -class FourierUnit(nn.Module): - def __init__( - self, - in_channels, - out_channels, - groups=1, - spatial_scale_factor=None, - spatial_scale_mode="bilinear", - spectral_pos_encoding=False, - use_se=False, - se_kwargs=None, - ffc3d=False, - fft_norm="ortho", - ): - # bn_layer not used - super(FourierUnit, self).__init__() - self.groups = groups - - self.conv_layer = torch.nn.Conv2d( - in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), - out_channels=out_channels * 2, - kernel_size=1, - stride=1, - padding=0, - groups=self.groups, - bias=False, - ) - self.bn = torch.nn.BatchNorm2d(out_channels * 2) - self.relu = torch.nn.ReLU(inplace=True) - - # squeeze and excitation block - self.use_se = use_se - if use_se: - if se_kwargs is None: - se_kwargs = {} - self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) - - self.spatial_scale_factor = spatial_scale_factor - self.spatial_scale_mode = spatial_scale_mode - self.spectral_pos_encoding = spectral_pos_encoding - self.ffc3d = ffc3d - self.fft_norm = fft_norm - - def forward(self, x): - batch = x.shape[0] - - if self.spatial_scale_factor is not None: - orig_size = x.shape[-2:] - x = F.interpolate( - x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False - ) - - r_size = x.size() - # (batch, c, h, w/2+1, 2) - fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) - ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) - ffted = torch.stack((ffted.real, ffted.imag), dim=-1) - ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) - ffted = ffted.view( - ( - batch, - -1, - ) - + ffted.size()[3:] - ) - - if self.spectral_pos_encoding: - height, width = ffted.shape[-2:] - coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) - coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) - ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) - - if self.use_se: - ffted = self.se(ffted) - - ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) - ffted = self.relu(self.bn(ffted)) - - ffted = ( - ffted.view( - ( - batch, - -1, - 2, - ) - + ffted.size()[2:] - ) - .permute(0, 1, 3, 4, 2) - .contiguous() - ) # (batch,c, t, h, w/2+1, 2) - ffted = torch.complex(ffted[..., 0], ffted[..., 1]) - - ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] - output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) - - if self.spatial_scale_factor is not None: - output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) - - return output - - -class SpectralTransform(nn.Module): - def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs): - # bn_layer not used - super(SpectralTransform, self).__init__() - self.enable_lfu = enable_lfu - if stride == 2: - self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) - else: - self.downsample = nn.Identity() - - self.stride = stride - self.conv1 = nn.Sequential( - nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False), - nn.BatchNorm2d(out_channels // 2), - nn.ReLU(inplace=True), - ) - self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups, **fu_kwargs) - if self.enable_lfu: - self.lfu = FourierUnit(out_channels // 2, out_channels // 2, groups) - self.conv2 = torch.nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) - - def forward(self, 
x): - x = self.downsample(x) - x = self.conv1(x) - output = self.fu(x) - - if self.enable_lfu: - n, c, h, w = x.shape - split_no = 2 - split_s = h // split_no - xs = torch.cat(torch.split(x[:, : c // 4], split_s, dim=-2), dim=1).contiguous() - xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous() - xs = self.lfu(xs) - xs = xs.repeat(1, 1, split_no, split_no).contiguous() - else: - xs = 0 - - output = self.conv2(x + output + xs) - - return output - - -class FFC(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - ratio_gin, - ratio_gout, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=False, - enable_lfu=True, - padding_type="reflect", - gated=False, - **spectral_kwargs - ): - super(FFC, self).__init__() - - assert stride == 1 or stride == 2, "Stride should be 1 or 2." - self.stride = stride - - in_cg = int(in_channels * ratio_gin) - in_cl = in_channels - in_cg - out_cg = int(out_channels * ratio_gout) - out_cl = out_channels - out_cg - # groups_g = 1 if groups == 1 else int(groups * ratio_gout) - # groups_l = 1 if groups == 1 else groups - groups_g - - self.ratio_gin = ratio_gin - self.ratio_gout = ratio_gout - self.global_in_num = in_cg - - module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d - self.convl2l = module( - in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type - ) - module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d - self.convl2g = module( - in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type - ) - module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d - self.convg2l = module( - in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type - ) - module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform - self.convg2g = module(in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) - - self.gated = gated - module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d - self.gate = module(in_channels, 2, 1) - - def forward(self, x): - x_l, x_g = x if type(x) is tuple else (x, 0) - out_xl, out_xg = 0, 0 - - if self.gated: - total_input_parts = [x_l] - if torch.is_tensor(x_g): - total_input_parts.append(x_g) - total_input = torch.cat(total_input_parts, dim=1) - - gates = torch.sigmoid(self.gate(total_input)) - g2l_gate, l2g_gate = gates.chunk(2, dim=1) - else: - g2l_gate, l2g_gate = 1, 1 - - if self.ratio_gout != 1: - out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate - if self.ratio_gout != 0: - out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) - - return out_xl, out_xg - - -class FFC_BN_ACT(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - ratio_gin, - ratio_gout, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=False, - norm_layer=nn.BatchNorm2d, - activation_layer=nn.Identity, - padding_type="reflect", - enable_lfu=True, - **kwargs - ): - super(FFC_BN_ACT, self).__init__() - self.ffc = FFC( - in_channels, - out_channels, - kernel_size, - ratio_gin, - ratio_gout, - stride, - padding, - dilation, - groups, - bias, - enable_lfu, - padding_type=padding_type, - **kwargs - ) - lnorm = nn.Identity if ratio_gout == 1 else norm_layer - gnorm = nn.Identity if ratio_gout == 0 else norm_layer - global_channels = int(out_channels * ratio_gout) - self.bn_l = lnorm(out_channels - global_channels) - self.bn_g = gnorm(global_channels) - - lact = nn.Identity if 
ratio_gout == 1 else activation_layer - gact = nn.Identity if ratio_gout == 0 else activation_layer - self.act_l = lact(inplace=True) - self.act_g = gact(inplace=True) - - def forward(self, x): - x_l, x_g = self.ffc(x) - x_l = self.act_l(self.bn_l(x_l)) - x_g = self.act_g(self.bn_g(x_g)) - return x_l, x_g - - -class FFCResnetBlock(nn.Module): - def __init__( - self, - dim, - padding_type, - norm_layer, - activation_layer=nn.ReLU, - dilation=1, - spatial_transform_kwargs=None, - inline=False, - **conv_kwargs - ): - super().__init__() - self.conv1 = FFC_BN_ACT( - dim, - dim, - kernel_size=3, - padding=dilation, - dilation=dilation, - norm_layer=norm_layer, - activation_layer=activation_layer, - padding_type=padding_type, - **conv_kwargs - ) - self.conv2 = FFC_BN_ACT( - dim, - dim, - kernel_size=3, - padding=dilation, - dilation=dilation, - norm_layer=norm_layer, - activation_layer=activation_layer, - padding_type=padding_type, - **conv_kwargs - ) - if spatial_transform_kwargs is not None: - self.conv1 = LearnableSpatialTransformWrapper(self.conv1, **spatial_transform_kwargs) - self.conv2 = LearnableSpatialTransformWrapper(self.conv2, **spatial_transform_kwargs) - self.inline = inline - - def forward(self, x): - if self.inline: - x_l, x_g = x[:, : -self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num :] - else: - x_l, x_g = x if type(x) is tuple else (x, 0) - - id_l, id_g = x_l, x_g - - x_l, x_g = self.conv1((x_l, x_g)) - x_l, x_g = self.conv2((x_l, x_g)) - - x_l, x_g = id_l + x_l, id_g + x_g - out = x_l, x_g - if self.inline: - out = torch.cat(out, dim=1) - return out - - -class ConcatTupleLayer(nn.Module): - def forward(self, x): - assert isinstance(x, tuple) - x_l, x_g = x - assert torch.is_tensor(x_l) or torch.is_tensor(x_g) - if not torch.is_tensor(x_g): - return x_l - return torch.cat(x, dim=1) - - -class FFCResNetGenerator(nn.Module): - def __init__( - self, - input_nc, - output_nc, - ngf=64, - n_downsampling=3, - n_blocks=9, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - activation_layer=nn.ReLU, - up_norm_layer=nn.BatchNorm2d, - up_activation=nn.ReLU(True), - init_conv_kwargs={}, - downsample_conv_kwargs={}, - resnet_conv_kwargs={}, - spatial_transform_layers=None, - spatial_transform_kwargs={}, - add_out_act=True, - max_features=1024, - out_ffc=False, - out_ffc_kwargs={}, - ): - assert n_blocks >= 0 - super().__init__() - - model = [ - nn.ReflectionPad2d(3), - FFC_BN_ACT( - input_nc, - ngf, - kernel_size=7, - padding=0, - norm_layer=norm_layer, - activation_layer=activation_layer, - **init_conv_kwargs - ), - ] - - ### downsample - for i in range(n_downsampling): - mult = 2**i - if i == n_downsampling - 1: - cur_conv_kwargs = dict(downsample_conv_kwargs) - cur_conv_kwargs["ratio_gout"] = resnet_conv_kwargs.get("ratio_gin", 0) - else: - cur_conv_kwargs = downsample_conv_kwargs - model += [ - FFC_BN_ACT( - min(max_features, ngf * mult), - min(max_features, ngf * mult * 2), - kernel_size=3, - stride=2, - padding=1, - norm_layer=norm_layer, - activation_layer=activation_layer, - **cur_conv_kwargs - ) - ] - - mult = 2**n_downsampling - feats_num_bottleneck = min(max_features, ngf * mult) - - ### resnet blocks - for i in range(n_blocks): - cur_resblock = FFCResnetBlock( - feats_num_bottleneck, - padding_type=padding_type, - activation_layer=activation_layer, - norm_layer=norm_layer, - **resnet_conv_kwargs - ) - if spatial_transform_layers is not None and i in spatial_transform_layers: - cur_resblock = LearnableSpatialTransformWrapper(cur_resblock, 
**spatial_transform_kwargs) - model += [cur_resblock] - - model += [ConcatTupleLayer()] - - ### upsample - for i in range(n_downsampling): - mult = 2 ** (n_downsampling - i) - model += [ - nn.ConvTranspose2d( - min(max_features, ngf * mult), - min(max_features, int(ngf * mult / 2)), - kernel_size=3, - stride=2, - padding=1, - output_padding=1, - ), - up_norm_layer(min(max_features, int(ngf * mult / 2))), - up_activation, - ] - - if out_ffc: - model += [ - FFCResnetBlock( - ngf, - padding_type=padding_type, - activation_layer=activation_layer, - norm_layer=norm_layer, - inline=True, - **out_ffc_kwargs - ) - ] - - model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] - if add_out_act: - model.append(get_activation("tanh" if add_out_act is True else add_out_act)) - self.model = nn.Sequential(*model) - - def forward(self, input): - return self.model(input) - - -class FFCNLayerDiscriminator(BaseDiscriminator): - def __init__( - self, - input_nc, - ndf=64, - n_layers=3, - norm_layer=nn.BatchNorm2d, - max_features=512, - init_conv_kwargs={}, - conv_kwargs={}, - ): - super().__init__() - self.n_layers = n_layers - - def _act_ctor(inplace=True): - return nn.LeakyReLU(negative_slope=0.2, inplace=inplace) - - kw = 3 - padw = int(np.ceil((kw - 1.0) / 2)) - sequence = [ - [ - FFC_BN_ACT( - input_nc, - ndf, - kernel_size=kw, - padding=padw, - norm_layer=norm_layer, - activation_layer=_act_ctor, - **init_conv_kwargs - ) - ] - ] - - nf = ndf - for n in range(1, n_layers): - nf_prev = nf - nf = min(nf * 2, max_features) - - cur_model = [ - FFC_BN_ACT( - nf_prev, - nf, - kernel_size=kw, - stride=2, - padding=padw, - norm_layer=norm_layer, - activation_layer=_act_ctor, - **conv_kwargs - ) - ] - sequence.append(cur_model) - - nf_prev = nf - nf = min(nf * 2, 512) - - cur_model = [ - FFC_BN_ACT( - nf_prev, - nf, - kernel_size=kw, - stride=1, - padding=padw, - norm_layer=norm_layer, - activation_layer=lambda *args, **kwargs: nn.LeakyReLU(*args, negative_slope=0.2, **kwargs), - **conv_kwargs - ), - ConcatTupleLayer(), - ] - sequence.append(cur_model) - - sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] - - for n in range(len(sequence)): - setattr(self, "model" + str(n), nn.Sequential(*sequence[n])) - - def get_all_activations(self, x): - res = [x] - for n in range(self.n_layers + 2): - model = getattr(self, "model" + str(n)) - res.append(model(res[-1])) - return res[1:] - - def forward(self, x): - act = self.get_all_activations(x) - feats = [] - for out in act[:-1]: - if isinstance(out, tuple): - if torch.is_tensor(out[1]): - out = torch.cat(out, dim=1) - else: - out = out[0] - feats.append(out) - return act[-1], feats diff --git a/modules/inpaint/lama/saicinpainting/training/modules/multidilated_conv.py b/modules/inpaint/lama/saicinpainting/training/modules/multidilated_conv.py deleted file mode 100644 index c0e91ae..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/multidilated_conv.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -import torch.nn as nn -import random - -from .depthwise_sep_conv import DepthWiseSeperableConv - - -class MultidilatedConv(nn.Module): - def __init__( - self, - in_dim, - out_dim, - kernel_size, - dilation_num=3, - comb_mode="sum", - equal_dim=True, - shared_weights=False, - padding=1, - min_dilation=1, - shuffle_in_channels=False, - use_depthwise=False, - **kwargs - ): - super().__init__() - convs = [] - self.equal_dim = equal_dim - assert comb_mode in ("cat_out", "sum", "cat_in", "cat_both"), comb_mode - if 
comb_mode in ("cat_out", "cat_both"): - self.cat_out = True - if equal_dim: - assert out_dim % dilation_num == 0 - out_dims = [out_dim // dilation_num] * dilation_num - self.index = sum( - [[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], [] - ) - else: - out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] - out_dims.append(out_dim - sum(out_dims)) - index = [] - starts = [0] + out_dims[:-1] - lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)] - for i in range(out_dims[-1]): - for j in range(dilation_num): - index += list(range(starts[j], starts[j] + lengths[j])) - starts[j] += lengths[j] - self.index = index - assert len(index) == out_dim - self.out_dims = out_dims - else: - self.cat_out = False - self.out_dims = [out_dim] * dilation_num - - if comb_mode in ("cat_in", "cat_both"): - if equal_dim: - assert in_dim % dilation_num == 0 - in_dims = [in_dim // dilation_num] * dilation_num - else: - in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] - in_dims.append(in_dim - sum(in_dims)) - self.in_dims = in_dims - self.cat_in = True - else: - self.cat_in = False - self.in_dims = [in_dim] * dilation_num - - conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d - dilation = min_dilation - for i in range(dilation_num): - if isinstance(padding, int): - cur_padding = padding * dilation - else: - cur_padding = padding[i] - convs.append( - conv_type( - self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs - ) - ) - if i > 0 and shared_weights: - convs[-1].weight = convs[0].weight - convs[-1].bias = convs[0].bias - dilation *= 2 - self.convs = nn.ModuleList(convs) - - self.shuffle_in_channels = shuffle_in_channels - if self.shuffle_in_channels: - # shuffle list as shuffling of tensors is nondeterministic - in_channels_permute = list(range(in_dim)) - random.shuffle(in_channels_permute) - # save as buffer so it is saved and loaded with checkpoint - self.register_buffer("in_channels_permute", torch.tensor(in_channels_permute)) - - def forward(self, x): - if self.shuffle_in_channels: - x = x[:, self.in_channels_permute] - - outs = [] - if self.cat_in: - if self.equal_dim: - x = x.chunk(len(self.convs), dim=1) - else: - new_x = [] - start = 0 - for dim in self.in_dims: - new_x.append(x[:, start : start + dim]) - start += dim - x = new_x - for i, conv in enumerate(self.convs): - if self.cat_in: - input = x[i] - else: - input = x - outs.append(conv(input)) - if self.cat_out: - out = torch.cat(outs, dim=1)[:, self.index] - else: - out = sum(outs) - return out diff --git a/modules/inpaint/lama/saicinpainting/training/modules/multiscale.py b/modules/inpaint/lama/saicinpainting/training/modules/multiscale.py deleted file mode 100644 index a9dbdd3..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/multiscale.py +++ /dev/null @@ -1,338 +0,0 @@ -from typing import List, Tuple, Union, Optional - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .base import get_conv_block_ctor, get_activation -from .pix2pixhd import ResnetBlock - - -class ResNetHead(nn.Module): - def __init__( - self, - input_nc, - ngf=64, - n_downsampling=3, - n_blocks=9, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - conv_kind="default", - activation=nn.ReLU(True), - ): - assert n_blocks >= 0 - super(ResNetHead, self).__init__() - - conv_layer = get_conv_block_ctor(conv_kind) - - model = [ - nn.ReflectionPad2d(3), - conv_layer(input_nc, ngf, 
kernel_size=7, padding=0), - norm_layer(ngf), - activation, - ] - - ### downsample - for i in range(n_downsampling): - mult = 2**i - model += [ - conv_layer(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), - norm_layer(ngf * mult * 2), - activation, - ] - - mult = 2**n_downsampling - - ### resnet blocks - for i in range(n_blocks): - model += [ - ResnetBlock( - ngf * mult, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - conv_kind=conv_kind, - ) - ] - - self.model = nn.Sequential(*model) - - def forward(self, input): - return self.model(input) - - -class ResNetTail(nn.Module): - def __init__( - self, - output_nc, - ngf=64, - n_downsampling=3, - n_blocks=9, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - conv_kind="default", - activation=nn.ReLU(True), - up_norm_layer=nn.BatchNorm2d, - up_activation=nn.ReLU(True), - add_out_act=False, - out_extra_layers_n=0, - add_in_proj=None, - ): - assert n_blocks >= 0 - super(ResNetTail, self).__init__() - - mult = 2**n_downsampling - - model = [] - - if add_in_proj is not None: - model.append(nn.Conv2d(add_in_proj, ngf * mult, kernel_size=1)) - - ### resnet blocks - for i in range(n_blocks): - model += [ - ResnetBlock( - ngf * mult, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - conv_kind=conv_kind, - ) - ] - - ### upsample - for i in range(n_downsampling): - mult = 2 ** (n_downsampling - i) - model += [ - nn.ConvTranspose2d( - ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1 - ), - up_norm_layer(int(ngf * mult / 2)), - up_activation, - ] - self.model = nn.Sequential(*model) - - out_layers = [] - for _ in range(out_extra_layers_n): - out_layers += [nn.Conv2d(ngf, ngf, kernel_size=1, padding=0), up_norm_layer(ngf), up_activation] - out_layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] - - if add_out_act: - out_layers.append(get_activation("tanh" if add_out_act is True else add_out_act)) - - self.out_proj = nn.Sequential(*out_layers) - - def forward(self, input, return_last_act=False): - features = self.model(input) - out = self.out_proj(features) - if return_last_act: - return out, features - else: - return out - - -class MultiscaleResNet(nn.Module): - def __init__( - self, - input_nc, - output_nc, - ngf=64, - n_downsampling=2, - n_blocks_head=2, - n_blocks_tail=6, - n_scales=3, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - conv_kind="default", - activation=nn.ReLU(True), - up_norm_layer=nn.BatchNorm2d, - up_activation=nn.ReLU(True), - add_out_act=False, - out_extra_layers_n=0, - out_cumulative=False, - return_only_hr=False, - ): - super().__init__() - - self.heads = nn.ModuleList( - [ - ResNetHead( - input_nc, - ngf=ngf, - n_downsampling=n_downsampling, - n_blocks=n_blocks_head, - norm_layer=norm_layer, - padding_type=padding_type, - conv_kind=conv_kind, - activation=activation, - ) - for i in range(n_scales) - ] - ) - tail_in_feats = ngf * (2**n_downsampling) + ngf - self.tails = nn.ModuleList( - [ - ResNetTail( - output_nc, - ngf=ngf, - n_downsampling=n_downsampling, - n_blocks=n_blocks_tail, - norm_layer=norm_layer, - padding_type=padding_type, - conv_kind=conv_kind, - activation=activation, - up_norm_layer=up_norm_layer, - up_activation=up_activation, - add_out_act=add_out_act, - out_extra_layers_n=out_extra_layers_n, - add_in_proj=None if (i == n_scales - 1) else tail_in_feats, - ) - for i in range(n_scales) - ] - ) - - self.out_cumulative = out_cumulative - 
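# out_cumulative: when True, forward() builds each finer-scale output as a residual added to
# the upsampled prediction of the previous, coarser scale.
# return_only_hr: when True, forward() returns only the highest-resolution output instead of
# the full list of outputs ordered from HR to LR.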
self.return_only_hr = return_only_hr - - @property - def num_scales(self): - return len(self.heads) - - def forward( - self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None - ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ - :param ms_inputs: List of inputs of different resolutions from HR to LR - :param smallest_scales_num: int or None, number of smallest scales to take at input - :return: Depending on return_only_hr: - True: Only the most HR output - False: List of outputs of different resolutions from HR to LR - """ - if smallest_scales_num is None: - assert len(self.heads) == len(ms_inputs), (len(self.heads), len(ms_inputs), smallest_scales_num) - smallest_scales_num = len(self.heads) - else: - assert smallest_scales_num == len(ms_inputs) <= len(self.heads), ( - len(self.heads), - len(ms_inputs), - smallest_scales_num, - ) - - cur_heads = self.heads[-smallest_scales_num:] - ms_features = [cur_head(cur_inp) for cur_head, cur_inp in zip(cur_heads, ms_inputs)] - - all_outputs = [] - prev_tail_features = None - for i in range(len(ms_features)): - scale_i = -i - 1 - - cur_tail_input = ms_features[-i - 1] - if prev_tail_features is not None: - if prev_tail_features.shape != cur_tail_input.shape: - prev_tail_features = F.interpolate( - prev_tail_features, size=cur_tail_input.shape[2:], mode="bilinear", align_corners=False - ) - cur_tail_input = torch.cat((cur_tail_input, prev_tail_features), dim=1) - - cur_out, cur_tail_feats = self.tails[scale_i](cur_tail_input, return_last_act=True) - - prev_tail_features = cur_tail_feats - all_outputs.append(cur_out) - - if self.out_cumulative: - all_outputs_cum = [all_outputs[0]] - for i in range(1, len(ms_features)): - cur_out = all_outputs[i] - cur_out_cum = cur_out + F.interpolate( - all_outputs_cum[-1], size=cur_out.shape[2:], mode="bilinear", align_corners=False - ) - all_outputs_cum.append(cur_out_cum) - all_outputs = all_outputs_cum - - if self.return_only_hr: - return all_outputs[-1] - else: - return all_outputs[::-1] - - -class MultiscaleDiscriminatorSimple(nn.Module): - def __init__(self, ms_impl): - super().__init__() - self.ms_impl = nn.ModuleList(ms_impl) - - @property - def num_scales(self): - return len(self.ms_impl) - - def forward( - self, ms_inputs: List[torch.Tensor], smallest_scales_num: Optional[int] = None - ) -> List[Tuple[torch.Tensor, List[torch.Tensor]]]: - """ - :param ms_inputs: List of inputs of different resolutions from HR to LR - :param smallest_scales_num: int or None, number of smallest scales to take at input - :return: List of pairs (prediction, features) for different resolutions from HR to LR - """ - if smallest_scales_num is None: - assert len(self.ms_impl) == len(ms_inputs), (len(self.ms_impl), len(ms_inputs), smallest_scales_num) - smallest_scales_num = len(self.heads) - else: - assert smallest_scales_num == len(ms_inputs) <= len(self.ms_impl), ( - len(self.ms_impl), - len(ms_inputs), - smallest_scales_num, - ) - - return [cur_discr(cur_input) for cur_discr, cur_input in zip(self.ms_impl[-smallest_scales_num:], ms_inputs)] - - -class SingleToMultiScaleInputMixin: - def forward(self, x: torch.Tensor) -> List: - orig_height, orig_width = x.shape[2:] - factors = [2**i for i in range(self.num_scales)] - ms_inputs = [ - F.interpolate(x, size=(orig_height // f, orig_width // f), mode="bilinear", align_corners=False) - for f in factors - ] - return super().forward(ms_inputs) - - -class GeneratorMultiToSingleOutputMixin: - def forward(self, x): - return super().forward(x)[0] - - -class 
DiscriminatorMultiToSingleOutputMixin: - def forward(self, x): - out_feat_tuples = super().forward(x) - return out_feat_tuples[0][0], [f for _, flist in out_feat_tuples for f in flist] - - -class DiscriminatorMultiToSingleOutputStackedMixin: - def __init__(self, *args, return_feats_only_levels=None, **kwargs): - super().__init__(*args, **kwargs) - self.return_feats_only_levels = return_feats_only_levels - - def forward(self, x): - out_feat_tuples = super().forward(x) - outs = [out for out, _ in out_feat_tuples] - scaled_outs = [outs[0]] + [ - F.interpolate(cur_out, size=outs[0].shape[-2:], mode="bilinear", align_corners=False) - for cur_out in outs[1:] - ] - out = torch.cat(scaled_outs, dim=1) - if self.return_feats_only_levels is not None: - feat_lists = [out_feat_tuples[i][1] for i in self.return_feats_only_levels] - else: - feat_lists = [flist for _, flist in out_feat_tuples] - feats = [f for flist in feat_lists for f in flist] - return out, feats - - -class MultiscaleDiscrSingleInput( - SingleToMultiScaleInputMixin, DiscriminatorMultiToSingleOutputStackedMixin, MultiscaleDiscriminatorSimple -): - pass - - -class MultiscaleResNetSingle(GeneratorMultiToSingleOutputMixin, SingleToMultiScaleInputMixin, MultiscaleResNet): - pass diff --git a/modules/inpaint/lama/saicinpainting/training/modules/pix2pixhd.py b/modules/inpaint/lama/saicinpainting/training/modules/pix2pixhd.py deleted file mode 100644 index 394ed6c..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/pix2pixhd.py +++ /dev/null @@ -1,893 +0,0 @@ -# original: https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py -import collections -from functools import partial -import functools -import logging -from collections import defaultdict - -import numpy as np -import torch.nn as nn - -from .base import BaseDiscriminator, deconv_factory, get_conv_block_ctor, get_norm_layer, get_activation -from .ffc import FFCResnetBlock -from .multidilated_conv import MultidilatedConv - - -class DotDict(defaultdict): - # https://stackoverflow.com/questions/2352181/how-to-use-a-dot-to-access-members-of-dictionary - """dot.notation access to dictionary attributes""" - __getattr__ = defaultdict.get - __setattr__ = defaultdict.__setitem__ - __delattr__ = defaultdict.__delitem__ - - -class Identity(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return x - - -class ResnetBlock(nn.Module): - def __init__( - self, - dim, - padding_type, - norm_layer, - activation=nn.ReLU(True), - use_dropout=False, - conv_kind="default", - dilation=1, - in_dim=None, - groups=1, - second_dilation=None, - ): - super(ResnetBlock, self).__init__() - self.in_dim = in_dim - self.dim = dim - if second_dilation is None: - second_dilation = dilation - self.conv_block = self.build_conv_block( - dim, - padding_type, - norm_layer, - activation, - use_dropout, - conv_kind=conv_kind, - dilation=dilation, - in_dim=in_dim, - groups=groups, - second_dilation=second_dilation, - ) - - if self.in_dim is not None: - self.input_conv = nn.Conv2d(in_dim, dim, 1) - - self.out_channnels = dim - - def build_conv_block( - self, - dim, - padding_type, - norm_layer, - activation, - use_dropout, - conv_kind="default", - dilation=1, - in_dim=None, - groups=1, - second_dilation=1, - ): - conv_layer = get_conv_block_ctor(conv_kind) - - conv_block = [] - p = 0 - if padding_type == "reflect": - conv_block += [nn.ReflectionPad2d(dilation)] - elif padding_type == "replicate": - conv_block += [nn.ReplicationPad2d(dilation)] - elif padding_type == 
"zero": - p = dilation - else: - raise NotImplementedError("padding [%s] is not implemented" % padding_type) - - if in_dim is None: - in_dim = dim - - conv_block += [ - conv_layer(in_dim, dim, kernel_size=3, padding=p, dilation=dilation), - norm_layer(dim), - activation, - ] - if use_dropout: - conv_block += [nn.Dropout(0.5)] - - p = 0 - if padding_type == "reflect": - conv_block += [nn.ReflectionPad2d(second_dilation)] - elif padding_type == "replicate": - conv_block += [nn.ReplicationPad2d(second_dilation)] - elif padding_type == "zero": - p = second_dilation - else: - raise NotImplementedError("padding [%s] is not implemented" % padding_type) - conv_block += [ - conv_layer(dim, dim, kernel_size=3, padding=p, dilation=second_dilation, groups=groups), - norm_layer(dim), - ] - - return nn.Sequential(*conv_block) - - def forward(self, x): - x_before = x - if self.in_dim is not None: - x = self.input_conv(x) - out = x + self.conv_block(x_before) - return out - - -class ResnetBlock5x5(nn.Module): - def __init__( - self, - dim, - padding_type, - norm_layer, - activation=nn.ReLU(True), - use_dropout=False, - conv_kind="default", - dilation=1, - in_dim=None, - groups=1, - second_dilation=None, - ): - super(ResnetBlock5x5, self).__init__() - self.in_dim = in_dim - self.dim = dim - if second_dilation is None: - second_dilation = dilation - self.conv_block = self.build_conv_block( - dim, - padding_type, - norm_layer, - activation, - use_dropout, - conv_kind=conv_kind, - dilation=dilation, - in_dim=in_dim, - groups=groups, - second_dilation=second_dilation, - ) - - if self.in_dim is not None: - self.input_conv = nn.Conv2d(in_dim, dim, 1) - - self.out_channnels = dim - - def build_conv_block( - self, - dim, - padding_type, - norm_layer, - activation, - use_dropout, - conv_kind="default", - dilation=1, - in_dim=None, - groups=1, - second_dilation=1, - ): - conv_layer = get_conv_block_ctor(conv_kind) - - conv_block = [] - p = 0 - if padding_type == "reflect": - conv_block += [nn.ReflectionPad2d(dilation * 2)] - elif padding_type == "replicate": - conv_block += [nn.ReplicationPad2d(dilation * 2)] - elif padding_type == "zero": - p = dilation * 2 - else: - raise NotImplementedError("padding [%s] is not implemented" % padding_type) - - if in_dim is None: - in_dim = dim - - conv_block += [ - conv_layer(in_dim, dim, kernel_size=5, padding=p, dilation=dilation), - norm_layer(dim), - activation, - ] - if use_dropout: - conv_block += [nn.Dropout(0.5)] - - p = 0 - if padding_type == "reflect": - conv_block += [nn.ReflectionPad2d(second_dilation * 2)] - elif padding_type == "replicate": - conv_block += [nn.ReplicationPad2d(second_dilation * 2)] - elif padding_type == "zero": - p = second_dilation * 2 - else: - raise NotImplementedError("padding [%s] is not implemented" % padding_type) - conv_block += [ - conv_layer(dim, dim, kernel_size=5, padding=p, dilation=second_dilation, groups=groups), - norm_layer(dim), - ] - - return nn.Sequential(*conv_block) - - def forward(self, x): - x_before = x - if self.in_dim is not None: - x = self.input_conv(x) - out = x + self.conv_block(x_before) - return out - - -class MultidilatedResnetBlock(nn.Module): - def __init__(self, dim, padding_type, conv_layer, norm_layer, activation=nn.ReLU(True), use_dropout=False): - super().__init__() - self.conv_block = self.build_conv_block(dim, padding_type, conv_layer, norm_layer, activation, use_dropout) - - def build_conv_block(self, dim, padding_type, conv_layer, norm_layer, activation, use_dropout, dilation=1): - conv_block = [] - 
conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type), norm_layer(dim), activation] - if use_dropout: - conv_block += [nn.Dropout(0.5)] - - conv_block += [conv_layer(dim, dim, kernel_size=3, padding_mode=padding_type), norm_layer(dim)] - - return nn.Sequential(*conv_block) - - def forward(self, x): - out = x + self.conv_block(x) - return out - - -class MultiDilatedGlobalGenerator(nn.Module): - def __init__( - self, - input_nc, - output_nc, - ngf=64, - n_downsampling=3, - n_blocks=3, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - conv_kind="default", - deconv_kind="convtranspose", - activation=nn.ReLU(True), - up_norm_layer=nn.BatchNorm2d, - affine=None, - up_activation=nn.ReLU(True), - add_out_act=True, - max_features=1024, - multidilation_kwargs={}, - ffc_positions=None, - ffc_kwargs={}, - ): - assert n_blocks >= 0 - super().__init__() - - conv_layer = get_conv_block_ctor(conv_kind) - resnet_conv_layer = functools.partial(get_conv_block_ctor("multidilated"), **multidilation_kwargs) - norm_layer = get_norm_layer(norm_layer) - if affine is not None: - norm_layer = partial(norm_layer, affine=affine) - up_norm_layer = get_norm_layer(up_norm_layer) - if affine is not None: - up_norm_layer = partial(up_norm_layer, affine=affine) - - model = [ - nn.ReflectionPad2d(3), - conv_layer(input_nc, ngf, kernel_size=7, padding=0), - norm_layer(ngf), - activation, - ] - - identity = Identity() - ### downsample - for i in range(n_downsampling): - mult = 2**i - - model += [ - conv_layer( - min(max_features, ngf * mult), - min(max_features, ngf * mult * 2), - kernel_size=3, - stride=2, - padding=1, - ), - norm_layer(min(max_features, ngf * mult * 2)), - activation, - ] - - mult = 2**n_downsampling - feats_num_bottleneck = min(max_features, ngf * mult) - - ### resnet blocks - for i in range(n_blocks): - if ffc_positions is not None and i in ffc_positions: - model += [ - FFCResnetBlock( - feats_num_bottleneck, - padding_type, - norm_layer, - activation_layer=nn.ReLU, - inline=True, - **ffc_kwargs, - ) - ] - model += [ - MultidilatedResnetBlock( - feats_num_bottleneck, - padding_type=padding_type, - conv_layer=resnet_conv_layer, - activation=activation, - norm_layer=norm_layer, - ) - ] - - ### upsample - for i in range(n_downsampling): - mult = 2 ** (n_downsampling - i) - model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features) - model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] - if add_out_act: - model.append(get_activation("tanh" if add_out_act is True else add_out_act)) - self.model = nn.Sequential(*model) - - def forward(self, input): - return self.model(input) - - -class ConfigGlobalGenerator(nn.Module): - def __init__( - self, - input_nc, - output_nc, - ngf=64, - n_downsampling=3, - n_blocks=3, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - conv_kind="default", - deconv_kind="convtranspose", - activation=nn.ReLU(True), - up_norm_layer=nn.BatchNorm2d, - affine=None, - up_activation=nn.ReLU(True), - add_out_act=True, - max_features=1024, - manual_block_spec=[], - resnet_block_kind="multidilatedresnetblock", - resnet_conv_kind="multidilated", - resnet_dilation=1, - multidilation_kwargs={}, - ): - assert n_blocks >= 0 - super().__init__() - - conv_layer = get_conv_block_ctor(conv_kind) - resnet_conv_layer = functools.partial(get_conv_block_ctor(resnet_conv_kind), **multidilation_kwargs) - norm_layer = get_norm_layer(norm_layer) - if affine is not None: - norm_layer = partial(norm_layer, 
affine=affine) - up_norm_layer = get_norm_layer(up_norm_layer) - if affine is not None: - up_norm_layer = partial(up_norm_layer, affine=affine) - - model = [ - nn.ReflectionPad2d(3), - conv_layer(input_nc, ngf, kernel_size=7, padding=0), - norm_layer(ngf), - activation, - ] - - identity = Identity() - - ### downsample - for i in range(n_downsampling): - mult = 2**i - model += [ - conv_layer( - min(max_features, ngf * mult), - min(max_features, ngf * mult * 2), - kernel_size=3, - stride=2, - padding=1, - ), - norm_layer(min(max_features, ngf * mult * 2)), - activation, - ] - - mult = 2**n_downsampling - feats_num_bottleneck = min(max_features, ngf * mult) - - if len(manual_block_spec) == 0: - manual_block_spec = [DotDict(lambda: None, {"n_blocks": n_blocks, "use_default": True})] - - ### resnet blocks - for block_spec in manual_block_spec: - - def make_and_add_blocks(model, block_spec): - block_spec = DotDict(lambda: None, block_spec) - if not block_spec.use_default: - resnet_conv_layer = functools.partial( - get_conv_block_ctor(block_spec.resnet_conv_kind), **block_spec.multidilation_kwargs - ) - resnet_conv_kind = block_spec.resnet_conv_kind - resnet_block_kind = block_spec.resnet_block_kind - if block_spec.resnet_dilation is not None: - resnet_dilation = block_spec.resnet_dilation - for i in range(block_spec.n_blocks): - if resnet_block_kind == "multidilatedresnetblock": - model += [ - MultidilatedResnetBlock( - feats_num_bottleneck, - padding_type=padding_type, - conv_layer=resnet_conv_layer, - activation=activation, - norm_layer=norm_layer, - ) - ] - if resnet_block_kind == "resnetblock": - model += [ - ResnetBlock( - ngf * mult, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - conv_kind=resnet_conv_kind, - ) - ] - if resnet_block_kind == "resnetblock5x5": - model += [ - ResnetBlock5x5( - ngf * mult, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - conv_kind=resnet_conv_kind, - ) - ] - if resnet_block_kind == "resnetblockdwdil": - model += [ - ResnetBlock( - ngf * mult, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - conv_kind=resnet_conv_kind, - dilation=resnet_dilation, - second_dilation=resnet_dilation, - ) - ] - - make_and_add_blocks(model, block_spec) - - ### upsample - for i in range(n_downsampling): - mult = 2 ** (n_downsampling - i) - model += deconv_factory(deconv_kind, ngf, mult, up_norm_layer, up_activation, max_features) - model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] - if add_out_act: - model.append(get_activation("tanh" if add_out_act is True else add_out_act)) - self.model = nn.Sequential(*model) - - def forward(self, input): - return self.model(input) - - -def make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs): - blocks = [] - for i in range(dilated_blocks_n): - if dilation_block_kind == "simple": - blocks.append(ResnetBlock(**dilated_block_kwargs, dilation=2 ** (i + 1))) - elif dilation_block_kind == "multi": - blocks.append(MultidilatedResnetBlock(**dilated_block_kwargs)) - else: - raise ValueError(f'dilation_block_kind could not be "{dilation_block_kind}"') - return blocks - - -class GlobalGenerator(nn.Module): - def __init__( - self, - input_nc, - output_nc, - ngf=64, - n_downsampling=3, - n_blocks=9, - norm_layer=nn.BatchNorm2d, - padding_type="reflect", - conv_kind="default", - activation=nn.ReLU(True), - up_norm_layer=nn.BatchNorm2d, - affine=None, - up_activation=nn.ReLU(True), - dilated_blocks_n=0, - 
dilated_blocks_n_start=0, - dilated_blocks_n_middle=0, - add_out_act=True, - max_features=1024, - is_resblock_depthwise=False, - ffc_positions=None, - ffc_kwargs={}, - dilation=1, - second_dilation=None, - dilation_block_kind="simple", - multidilation_kwargs={}, - ): - assert n_blocks >= 0 - super().__init__() - - conv_layer = get_conv_block_ctor(conv_kind) - norm_layer = get_norm_layer(norm_layer) - if affine is not None: - norm_layer = partial(norm_layer, affine=affine) - up_norm_layer = get_norm_layer(up_norm_layer) - if affine is not None: - up_norm_layer = partial(up_norm_layer, affine=affine) - - if ffc_positions is not None: - ffc_positions = collections.Counter(ffc_positions) - - model = [ - nn.ReflectionPad2d(3), - conv_layer(input_nc, ngf, kernel_size=7, padding=0), - norm_layer(ngf), - activation, - ] - - identity = Identity() - ### downsample - for i in range(n_downsampling): - mult = 2**i - - model += [ - conv_layer( - min(max_features, ngf * mult), - min(max_features, ngf * mult * 2), - kernel_size=3, - stride=2, - padding=1, - ), - norm_layer(min(max_features, ngf * mult * 2)), - activation, - ] - - mult = 2**n_downsampling - feats_num_bottleneck = min(max_features, ngf * mult) - - dilated_block_kwargs = dict( - dim=feats_num_bottleneck, padding_type=padding_type, activation=activation, norm_layer=norm_layer - ) - if dilation_block_kind == "simple": - dilated_block_kwargs["conv_kind"] = conv_kind - elif dilation_block_kind == "multi": - dilated_block_kwargs["conv_layer"] = functools.partial( - get_conv_block_ctor("multidilated"), **multidilation_kwargs - ) - - # dilated blocks at the start of the bottleneck sausage - if dilated_blocks_n_start is not None and dilated_blocks_n_start > 0: - model += make_dil_blocks(dilated_blocks_n_start, dilation_block_kind, dilated_block_kwargs) - - # resnet blocks - for i in range(n_blocks): - # dilated blocks at the middle of the bottleneck sausage - if i == n_blocks // 2 and dilated_blocks_n_middle is not None and dilated_blocks_n_middle > 0: - model += make_dil_blocks(dilated_blocks_n_middle, dilation_block_kind, dilated_block_kwargs) - - if ffc_positions is not None and i in ffc_positions: - for _ in range(ffc_positions[i]): # same position can occur more than once - model += [ - FFCResnetBlock( - feats_num_bottleneck, - padding_type, - norm_layer, - activation_layer=nn.ReLU, - inline=True, - **ffc_kwargs, - ) - ] - - if is_resblock_depthwise: - resblock_groups = feats_num_bottleneck - else: - resblock_groups = 1 - - model += [ - ResnetBlock( - feats_num_bottleneck, - padding_type=padding_type, - activation=activation, - norm_layer=norm_layer, - conv_kind=conv_kind, - groups=resblock_groups, - dilation=dilation, - second_dilation=second_dilation, - ) - ] - - # dilated blocks at the end of the bottleneck sausage - if dilated_blocks_n is not None and dilated_blocks_n > 0: - model += make_dil_blocks(dilated_blocks_n, dilation_block_kind, dilated_block_kwargs) - - # upsample - for i in range(n_downsampling): - mult = 2 ** (n_downsampling - i) - model += [ - nn.ConvTranspose2d( - min(max_features, ngf * mult), - min(max_features, int(ngf * mult / 2)), - kernel_size=3, - stride=2, - padding=1, - output_padding=1, - ), - up_norm_layer(min(max_features, int(ngf * mult / 2))), - up_activation, - ] - model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] - if add_out_act: - model.append(get_activation("tanh" if add_out_act is True else add_out_act)) - self.model = nn.Sequential(*model) - - def forward(self, input): - 
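# The whole generator is one nn.Sequential pipeline: reflection padding and an input conv,
# strided downsampling convs, a bottleneck of ResNet / dilated / optional FFC blocks, then
# transposed-conv upsampling and a final 7x7 output conv (optionally followed by tanh).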
return self.model(input) - - -class GlobalGeneratorGated(GlobalGenerator): - def __init__(self, *args, **kwargs): - real_kwargs = dict(conv_kind="gated_bn_relu", activation=nn.Identity(), norm_layer=nn.Identity) - real_kwargs.update(kwargs) - super().__init__(*args, **real_kwargs) - - -class GlobalGeneratorFromSuperChannels(nn.Module): - def __init__( - self, - input_nc, - output_nc, - n_downsampling, - n_blocks, - super_channels, - norm_layer="bn", - padding_type="reflect", - add_out_act=True, - ): - super().__init__() - self.n_downsampling = n_downsampling - norm_layer = get_norm_layer(norm_layer) - if type(norm_layer) == functools.partial: - use_bias = norm_layer.func == nn.InstanceNorm2d - else: - use_bias = norm_layer == nn.InstanceNorm2d - - channels = self.convert_super_channels(super_channels) - self.channels = channels - - model = [ - nn.ReflectionPad2d(3), - nn.Conv2d(input_nc, channels[0], kernel_size=7, padding=0, bias=use_bias), - norm_layer(channels[0]), - nn.ReLU(True), - ] - - for i in range(n_downsampling): # add downsampling layers - mult = 2**i - model += [ - nn.Conv2d(channels[0 + i], channels[1 + i], kernel_size=3, stride=2, padding=1, bias=use_bias), - norm_layer(channels[1 + i]), - nn.ReLU(True), - ] - - mult = 2**n_downsampling - - n_blocks1 = n_blocks // 3 - n_blocks2 = n_blocks1 - n_blocks3 = n_blocks - n_blocks1 - n_blocks2 - - for i in range(n_blocks1): - c = n_downsampling - dim = channels[c] - model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer)] - - for i in range(n_blocks2): - c = n_downsampling + 1 - dim = channels[c] - kwargs = {} - if i == 0: - kwargs = {"in_dim": channels[c - 1]} - model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)] - - for i in range(n_blocks3): - c = n_downsampling + 2 - dim = channels[c] - kwargs = {} - if i == 0: - kwargs = {"in_dim": channels[c - 1]} - model += [ResnetBlock(dim, padding_type=padding_type, norm_layer=norm_layer, **kwargs)] - - for i in range(n_downsampling): # add upsampling layers - mult = 2 ** (n_downsampling - i) - model += [ - nn.ConvTranspose2d( - channels[n_downsampling + 3 + i], - channels[n_downsampling + 3 + i + 1], - kernel_size=3, - stride=2, - padding=1, - output_padding=1, - bias=use_bias, - ), - norm_layer(channels[n_downsampling + 3 + i + 1]), - nn.ReLU(True), - ] - model += [nn.ReflectionPad2d(3)] - model += [nn.Conv2d(channels[2 * n_downsampling + 3], output_nc, kernel_size=7, padding=0)] - - if add_out_act: - model.append(get_activation("tanh" if add_out_act is True else add_out_act)) - self.model = nn.Sequential(*model) - - def convert_super_channels(self, super_channels): - n_downsampling = self.n_downsampling - result = [] - cnt = 0 - - if n_downsampling == 2: - N1 = 10 - elif n_downsampling == 3: - N1 = 13 - else: - raise NotImplementedError - - for i in range(0, N1): - if i in [1, 4, 7, 10]: - channel = super_channels[cnt] * (2**cnt) - config = {"channel": channel} - result.append(channel) - logging.info(f"Downsample channels {result[-1]}") - cnt += 1 - - for i in range(3): - for counter, j in enumerate(range(N1 + i * 3, N1 + 3 + i * 3)): - if len(super_channels) == 6: - channel = super_channels[3] * 4 - else: - channel = super_channels[i + 3] * 4 - config = {"channel": channel} - if counter == 0: - result.append(channel) - logging.info(f"Bottleneck channels {result[-1]}") - cnt = 2 - - for i in range(N1 + 9, N1 + 21): - if i in [22, 25, 28]: - cnt -= 1 - if len(super_channels) == 6: - channel = super_channels[5 - cnt] * (2**cnt) - 
else: - channel = super_channels[7 - cnt] * (2**cnt) - result.append(int(channel)) - logging.info(f"Upsample channels {result[-1]}") - return result - - def forward(self, input): - return self.model(input) - - -# Defines the PatchGAN discriminator with the specified arguments. -class NLayerDiscriminator(BaseDiscriminator): - def __init__( - self, - input_nc, - ndf=64, - n_layers=3, - norm_layer=nn.BatchNorm2d, - ): - super().__init__() - self.n_layers = n_layers - - kw = 4 - padw = int(np.ceil((kw - 1.0) / 2)) - sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] - - nf = ndf - for n in range(1, n_layers): - nf_prev = nf - nf = min(nf * 2, 512) - - cur_model = [] - cur_model += [ - nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), - norm_layer(nf), - nn.LeakyReLU(0.2, True), - ] - sequence.append(cur_model) - - nf_prev = nf - nf = min(nf * 2, 512) - - cur_model = [] - cur_model += [ - nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), - norm_layer(nf), - nn.LeakyReLU(0.2, True), - ] - sequence.append(cur_model) - - sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] - - for n in range(len(sequence)): - setattr(self, "model" + str(n), nn.Sequential(*sequence[n])) - - def get_all_activations(self, x): - res = [x] - for n in range(self.n_layers + 2): - model = getattr(self, "model" + str(n)) - res.append(model(res[-1])) - return res[1:] - - def forward(self, x): - act = self.get_all_activations(x) - return act[-1], act[:-1] - - -class MultidilatedNLayerDiscriminator(BaseDiscriminator): - def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, multidilation_kwargs={}): - super().__init__() - self.n_layers = n_layers - - kw = 4 - padw = int(np.ceil((kw - 1.0) / 2)) - sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] - - nf = ndf - for n in range(1, n_layers): - nf_prev = nf - nf = min(nf * 2, 512) - - cur_model = [] - cur_model += [ - MultidilatedConv(nf_prev, nf, kernel_size=kw, stride=2, padding=[2, 3], **multidilation_kwargs), - norm_layer(nf), - nn.LeakyReLU(0.2, True), - ] - sequence.append(cur_model) - - nf_prev = nf - nf = min(nf * 2, 512) - - cur_model = [] - cur_model += [ - nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), - norm_layer(nf), - nn.LeakyReLU(0.2, True), - ] - sequence.append(cur_model) - - sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] - - for n in range(len(sequence)): - setattr(self, "model" + str(n), nn.Sequential(*sequence[n])) - - def get_all_activations(self, x): - res = [x] - for n in range(self.n_layers + 2): - model = getattr(self, "model" + str(n)) - res.append(model(res[-1])) - return res[1:] - - def forward(self, x): - act = self.get_all_activations(x) - return act[-1], act[:-1] - - -class NLayerDiscriminatorAsGen(NLayerDiscriminator): - def forward(self, x): - return super().forward(x)[0] diff --git a/modules/inpaint/lama/saicinpainting/training/modules/spatial_transform.py b/modules/inpaint/lama/saicinpainting/training/modules/spatial_transform.py deleted file mode 100644 index 37cd292..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/spatial_transform.py +++ /dev/null @@ -1,49 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from kornia.geometry.transform import rotate - - -class LearnableSpatialTransformWrapper(nn.Module): - def __init__(self, impl, pad_coef=0.5, angle_init_range=80, 
train_angle=True): - super().__init__() - self.impl = impl - self.angle = torch.rand(1) * angle_init_range - if train_angle: - self.angle = nn.Parameter(self.angle, requires_grad=True) - self.pad_coef = pad_coef - - def forward(self, x): - if torch.is_tensor(x): - return self.inverse_transform(self.impl(self.transform(x)), x) - elif isinstance(x, tuple): - x_trans = tuple(self.transform(elem) for elem in x) - y_trans = self.impl(x_trans) - return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)) - else: - raise ValueError(f"Unexpected input type {type(x)}") - - def transform(self, x): - height, width = x.shape[2:] - pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) - x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode="reflect") - x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) - return x_padded_rotated - - def inverse_transform(self, y_padded_rotated, orig_x): - height, width = orig_x.shape[2:] - pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) - - y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) - y_height, y_width = y_padded.shape[2:] - y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w] - return y - - -if __name__ == "__main__": - layer = LearnableSpatialTransformWrapper(nn.Identity()) - x = torch.arange(2 * 3 * 15 * 15).view(2, 3, 15, 15).float() - y = layer(x) - assert x.shape == y.shape - assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1]) - print("all ok") diff --git a/modules/inpaint/lama/saicinpainting/training/modules/squeeze_excitation.py b/modules/inpaint/lama/saicinpainting/training/modules/squeeze_excitation.py deleted file mode 100644 index 330bf6f..0000000 --- a/modules/inpaint/lama/saicinpainting/training/modules/squeeze_excitation.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch.nn as nn - - -class SELayer(nn.Module): - def __init__(self, channel, reduction=16): - super(SELayer, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc = nn.Sequential( - nn.Linear(channel, channel // reduction, bias=False), - nn.ReLU(inplace=True), - nn.Linear(channel // reduction, channel, bias=False), - nn.Sigmoid(), - ) - - def forward(self, x): - b, c, _, _ = x.size() - y = self.avg_pool(x).view(b, c) - y = self.fc(y).view(b, c, 1, 1) - res = x * y.expand_as(x) - return res diff --git a/modules/inpaint/lama/saicinpainting/training/trainers/__init__.py b/modules/inpaint/lama/saicinpainting/training/trainers/__init__.py deleted file mode 100644 index 952f439..0000000 --- a/modules/inpaint/lama/saicinpainting/training/trainers/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging -import torch - -from .default import DefaultInpaintingTrainingModule - - -def get_training_model_class(kind): - if kind == "default": - return DefaultInpaintingTrainingModule - - raise ValueError(f"Unknown trainer module {kind}") - - -def make_training_model(config): - kind = config.training_model.kind - kwargs = dict(config.training_model) - kwargs.pop("kind") - kwargs["use_ddp"] = config.trainer.kwargs.get("accelerator", None) == "ddp" - - logging.info(f"Make training model {kind}") - - cls = get_training_model_class(kind) - return cls(config, **kwargs) - - -def load_checkpoint(train_config, path, map_location="cuda", strict=True): - model: torch.nn.Module = make_training_model(train_config) - state = torch.load(path, map_location=map_location) - model.load_state_dict(state["state_dict"], strict=strict) - 
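# The checkpoint file holds the full Lightning training state; only the weights stored under
# state["state_dict"] are restored above, after which on_load_checkpoint() gives the module a
# chance to pick up any extra state it needs before being returned.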
model.on_load_checkpoint(state) - return model diff --git a/modules/inpaint/lama/saicinpainting/training/trainers/base.py b/modules/inpaint/lama/saicinpainting/training/trainers/base.py deleted file mode 100644 index 5c692ca..0000000 --- a/modules/inpaint/lama/saicinpainting/training/trainers/base.py +++ /dev/null @@ -1,316 +0,0 @@ -import copy -import logging -from typing import Dict, Tuple - -import pandas as pd -import pytorch_lightning as ptl -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import DistributedSampler - -# from saicinpainting.evaluation import make_evaluator -# from saicinpainting.training.data.datasets import make_default_train_dataloader, make_default_val_dataloader -# from saicinpainting.training.losses.adversarial import make_discrim_loss -# from saicinpainting.training.losses.perceptual import PerceptualLoss, ResNetPL -from ..modules import make_generator # , make_discriminator - -# from saicinpainting.training.visualizers import make_visualizer -from ...utils import add_prefix_to_keys, average_dicts, set_requires_grad, flatten_dict, get_has_ddp_rank - -LOGGER = logging.getLogger(__name__) - - -def make_optimizer(parameters, kind="adamw", **kwargs): - if kind == "adam": - optimizer_class = torch.optim.Adam - elif kind == "adamw": - optimizer_class = torch.optim.AdamW - else: - raise ValueError(f"Unknown optimizer kind {kind}") - return optimizer_class(parameters, **kwargs) - - -def update_running_average(result: nn.Module, new_iterate_model: nn.Module, decay=0.999): - with torch.no_grad(): - res_params = dict(result.named_parameters()) - new_params = dict(new_iterate_model.named_parameters()) - - for k in res_params.keys(): - res_params[k].data.mul_(decay).add_(new_params[k].data, alpha=1 - decay) - - -def make_multiscale_noise(base_tensor, scales=6, scale_mode="bilinear"): - batch_size, _, height, width = base_tensor.shape - cur_height, cur_width = height, width - result = [] - align_corners = False if scale_mode in ("bilinear", "bicubic") else None - for _ in range(scales): - cur_sample = torch.randn(batch_size, 1, cur_height, cur_width, device=base_tensor.device) - cur_sample_scaled = F.interpolate( - cur_sample, size=(height, width), mode=scale_mode, align_corners=align_corners - ) - result.append(cur_sample_scaled) - cur_height //= 2 - cur_width //= 2 - return torch.cat(result, dim=1) - - -class BaseInpaintingTrainingModule(ptl.LightningModule): - def __init__( - self, - config, - use_ddp, - *args, - predict_only=False, - visualize_each_iters=100, - average_generator=False, - generator_avg_beta=0.999, - average_generator_start_step=30000, - average_generator_period=10, - store_discr_outputs_for_vis=False, - **kwargs, - ): - super().__init__(*args, **kwargs) - LOGGER.info("BaseInpaintingTrainingModule init called") - - self.config = config - - self.generator = make_generator(config, **self.config.generator) - self.use_ddp = use_ddp - - # if not get_has_ddp_rank(): - # LOGGER.info(f"Generator\n{self.generator}") - - # if not predict_only: - # self.save_hyperparameters(self.config) - # self.discriminator = make_discriminator(**self.config.discriminator) - # self.adversarial_loss = make_discrim_loss(**self.config.losses.adversarial) - # self.visualizer = make_visualizer(**self.config.visualizer) - # self.val_evaluator = make_evaluator(**self.config.evaluator) - # self.test_evaluator = make_evaluator(**self.config.evaluator) - - # if not get_has_ddp_rank(): - # LOGGER.info(f"Discriminator\n{self.discriminator}") - - # 
extra_val = self.config.data.get("extra_val", ()) - # if extra_val: - # self.extra_val_titles = list(extra_val) - # self.extra_evaluators = nn.ModuleDict({k: make_evaluator(**self.config.evaluator) for k in extra_val}) - # else: - # self.extra_evaluators = {} - - # self.average_generator = average_generator - # self.generator_avg_beta = generator_avg_beta - # self.average_generator_start_step = average_generator_start_step - # self.average_generator_period = average_generator_period - # self.generator_average = None - # self.last_generator_averaging_step = -1 - # self.store_discr_outputs_for_vis = store_discr_outputs_for_vis - - # if self.config.losses.get("l1", {"weight_known": 0})["weight_known"] > 0: - # self.loss_l1 = nn.L1Loss(reduction="none") - - # if self.config.losses.get("mse", {"weight": 0})["weight"] > 0: - # self.loss_mse = nn.MSELoss(reduction="none") - - # if self.config.losses.perceptual.weight > 0: - # self.loss_pl = PerceptualLoss() - - # if self.config.losses.get("resnet_pl", {"weight": 0})["weight"] > 0: - # self.loss_resnet_pl = ResNetPL(**self.config.losses.resnet_pl) - # else: - # self.loss_resnet_pl = None - - self.visualize_each_iters = visualize_each_iters - LOGGER.info("BaseInpaintingTrainingModule init done") - - def configure_optimizers(self): - discriminator_params = list(self.discriminator.parameters()) - return [ - dict(optimizer=make_optimizer(self.generator.parameters(), **self.config.optimizers.generator)), - dict(optimizer=make_optimizer(discriminator_params, **self.config.optimizers.discriminator)), - ] - - def train_dataloader(self): - kwargs = dict(self.config.data.train) - if self.use_ddp: - kwargs["ddp_kwargs"] = dict( - num_replicas=self.trainer.num_nodes * self.trainer.num_processes, - rank=self.trainer.global_rank, - shuffle=True, - ) - dataloader = make_default_train_dataloader(**self.config.data.train) - return dataloader - - def val_dataloader(self): - res = [make_default_val_dataloader(**self.config.data.val)] - - if self.config.data.visual_test is not None: - res = res + [make_default_val_dataloader(**self.config.data.visual_test)] - else: - res = res + res - - extra_val = self.config.data.get("extra_val", ()) - if extra_val: - res += [make_default_val_dataloader(**extra_val[k]) for k in self.extra_val_titles] - - return res - - def training_step(self, batch, batch_idx, optimizer_idx=None): - self._is_training_step = True - return self._do_step(batch, batch_idx, mode="train", optimizer_idx=optimizer_idx) - - def validation_step(self, batch, batch_idx, dataloader_idx): - extra_val_key = None - if dataloader_idx == 0: - mode = "val" - elif dataloader_idx == 1: - mode = "test" - else: - mode = "extra_val" - extra_val_key = self.extra_val_titles[dataloader_idx - 2] - self._is_training_step = False - return self._do_step(batch, batch_idx, mode=mode, extra_val_key=extra_val_key) - - def training_step_end(self, batch_parts_outputs): - if ( - self.training - and self.average_generator - and self.global_step >= self.average_generator_start_step - and self.global_step >= self.last_generator_averaging_step + self.average_generator_period - ): - if self.generator_average is None: - self.generator_average = copy.deepcopy(self.generator) - else: - update_running_average(self.generator_average, self.generator, decay=self.generator_avg_beta) - self.last_generator_averaging_step = self.global_step - - full_loss = ( - batch_parts_outputs["loss"].mean() - if torch.is_tensor(batch_parts_outputs["loss"]) # loss is not tensor when no discriminator used - else 
torch.tensor(batch_parts_outputs["loss"]).float().requires_grad_(True) - ) - log_info = {k: v.mean() for k, v in batch_parts_outputs["log_info"].items()} - self.log_dict(log_info, on_step=True, on_epoch=False) - return full_loss - - def validation_epoch_end(self, outputs): - outputs = [step_out for out_group in outputs for step_out in out_group] - averaged_logs = average_dicts(step_out["log_info"] for step_out in outputs) - self.log_dict({k: v.mean() for k, v in averaged_logs.items()}) - - pd.set_option("display.max_columns", 500) - pd.set_option("display.width", 1000) - - # standard validation - val_evaluator_states = [s["val_evaluator_state"] for s in outputs if "val_evaluator_state" in s] - val_evaluator_res = self.val_evaluator.evaluation_end(states=val_evaluator_states) - val_evaluator_res_df = pd.DataFrame(val_evaluator_res).stack(1).unstack(0) - val_evaluator_res_df.dropna(axis=1, how="all", inplace=True) - LOGGER.info( - f"Validation metrics after epoch #{self.current_epoch}, " - f"total {self.global_step} iterations:\n{val_evaluator_res_df}" - ) - - for k, v in flatten_dict(val_evaluator_res).items(): - self.log(f"val_{k}", v) - - # standard visual test - test_evaluator_states = [s["test_evaluator_state"] for s in outputs if "test_evaluator_state" in s] - test_evaluator_res = self.test_evaluator.evaluation_end(states=test_evaluator_states) - test_evaluator_res_df = pd.DataFrame(test_evaluator_res).stack(1).unstack(0) - test_evaluator_res_df.dropna(axis=1, how="all", inplace=True) - LOGGER.info( - f"Test metrics after epoch #{self.current_epoch}, " - f"total {self.global_step} iterations:\n{test_evaluator_res_df}" - ) - - for k, v in flatten_dict(test_evaluator_res).items(): - self.log(f"test_{k}", v) - - # extra validations - if self.extra_evaluators: - for cur_eval_title, cur_evaluator in self.extra_evaluators.items(): - cur_state_key = f"extra_val_{cur_eval_title}_evaluator_state" - cur_states = [s[cur_state_key] for s in outputs if cur_state_key in s] - cur_evaluator_res = cur_evaluator.evaluation_end(states=cur_states) - cur_evaluator_res_df = pd.DataFrame(cur_evaluator_res).stack(1).unstack(0) - cur_evaluator_res_df.dropna(axis=1, how="all", inplace=True) - LOGGER.info( - f"Extra val {cur_eval_title} metrics after epoch #{self.current_epoch}, " - f"total {self.global_step} iterations:\n{cur_evaluator_res_df}" - ) - for k, v in flatten_dict(cur_evaluator_res).items(): - self.log(f"extra_val_{cur_eval_title}_{k}", v) - - def _do_step(self, batch, batch_idx, mode="train", optimizer_idx=None, extra_val_key=None): - if optimizer_idx == 0: # step for generator - set_requires_grad(self.generator, True) - set_requires_grad(self.discriminator, False) - elif optimizer_idx == 1: # step for discriminator - set_requires_grad(self.generator, False) - set_requires_grad(self.discriminator, True) - - batch = self(batch) - - total_loss = 0 - metrics = {} - - if optimizer_idx is None or optimizer_idx == 0: # step for generator - total_loss, metrics = self.generator_loss(batch) - - elif optimizer_idx is None or optimizer_idx == 1: # step for discriminator - if self.config.losses.adversarial.weight > 0: - total_loss, metrics = self.discriminator_loss(batch) - - if self.get_ddp_rank() in (None, 0) and (batch_idx % self.visualize_each_iters == 0 or mode == "test"): - if self.config.losses.adversarial.weight > 0: - if self.store_discr_outputs_for_vis: - with torch.no_grad(): - self.store_discr_outputs(batch) - vis_suffix = f"_{mode}" - if mode == "extra_val": - vis_suffix += f"_{extra_val_key}" - 
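# Visualization runs only on the main rank (or single-process runs) and only every
# `visualize_each_iters` batches, or always in test mode; at this point the batch dict already
# carries "image", "predicted_image" and "inpainted", so it can be handed to self.visualizer
# together with the mode-specific suffix.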
self.visualizer(self.current_epoch, batch_idx, batch, suffix=vis_suffix) - - metrics_prefix = f"{mode}_" - if mode == "extra_val": - metrics_prefix += f"{extra_val_key}_" - result = dict(loss=total_loss, log_info=add_prefix_to_keys(metrics, metrics_prefix)) - if mode == "val": - result["val_evaluator_state"] = self.val_evaluator.process_batch(batch) - elif mode == "test": - result["test_evaluator_state"] = self.test_evaluator.process_batch(batch) - elif mode == "extra_val": - result[f"extra_val_{extra_val_key}_evaluator_state"] = self.extra_evaluators[extra_val_key].process_batch( - batch - ) - - return result - - def get_current_generator(self, no_average=False): - if not no_average and not self.training and self.average_generator and self.generator_average is not None: - return self.generator_average - return self.generator - - def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: - """Pass data through generator and obtain at leas 'predicted_image' and 'inpainted' keys""" - raise NotImplementedError() - - def generator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - raise NotImplementedError() - - def discriminator_loss(self, batch) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - raise NotImplementedError() - - def store_discr_outputs(self, batch): - out_size = batch["image"].shape[2:] - discr_real_out, _ = self.discriminator(batch["image"]) - discr_fake_out, _ = self.discriminator(batch["predicted_image"]) - batch["discr_output_real"] = F.interpolate(discr_real_out, size=out_size, mode="nearest") - batch["discr_output_fake"] = F.interpolate(discr_fake_out, size=out_size, mode="nearest") - batch["discr_output_diff"] = batch["discr_output_real"] - batch["discr_output_fake"] - - def get_ddp_rank(self): - return self.trainer.global_rank if (self.trainer.num_nodes * self.trainer.num_processes) > 1 else None diff --git a/modules/inpaint/lama/saicinpainting/training/trainers/default.py b/modules/inpaint/lama/saicinpainting/training/trainers/default.py deleted file mode 100644 index 7320402..0000000 --- a/modules/inpaint/lama/saicinpainting/training/trainers/default.py +++ /dev/null @@ -1,230 +0,0 @@ -import logging -import random -import torch -import torch.nn.functional as F -from omegaconf import OmegaConf - -from ..losses.distance_weighting import make_mask_distance_weighter -from ..losses.feature_matching import feature_matching_loss, masked_l1_loss -from ..modules.fake_fakes import FakeFakesGenerator -from .base import BaseInpaintingTrainingModule, make_multiscale_noise -from ...utils import add_prefix_to_keys, get_ramp - -LOGGER = logging.getLogger(__name__) - - -def ceil_modulo(x, mod): - if x % mod == 0: - return x - return (x // mod + 1) * mod - - -def make_constant_area_crop_params(img_height, img_width, min_size=128, max_size=512, area=256 * 256, round_to_mod=16): - min_size = min(img_height, img_width, min_size) - max_size = min(img_height, img_width, max_size) - if random.random() < 0.5: - out_height = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod)) - out_width = min(max_size, ceil_modulo(area // out_height, round_to_mod)) - else: - out_width = min(max_size, ceil_modulo(random.randint(min_size, max_size), round_to_mod)) - out_height = min(max_size, ceil_modulo(area // out_width, round_to_mod)) - - start_y = random.randint(0, img_height - out_height) - start_x = random.randint(0, img_width - out_width) - return (start_y, start_x, out_height, out_width) - - -def make_constant_area_crop_batch(batch, 
**kwargs): - crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params( - img_height=batch["image"].shape[2], img_width=batch["image"].shape[3], **kwargs - ) - batch["image"] = batch["image"][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width] - batch["mask"] = batch["mask"][:, :, crop_y : crop_y + crop_height, crop_x : crop_x + crop_width] - return batch - - -class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule): - def __init__( - self, - *args, - concat_mask=True, - rescale_scheduler_kwargs=None, - image_to_discriminator="predicted_image", - add_noise_kwargs=None, - noise_fill_hole=False, - const_area_crop_kwargs=None, - distance_weighter_kwargs=None, - distance_weighted_mask_for_discr=False, - fake_fakes_proba=0, - fake_fakes_generator_kwargs=None, - **kwargs - ): - super().__init__(*args, **kwargs) - self.concat_mask = concat_mask - self.rescale_size_getter = ( - get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None - ) - self.image_to_discriminator = image_to_discriminator - self.add_noise_kwargs = add_noise_kwargs - self.noise_fill_hole = noise_fill_hole - self.const_area_crop_kwargs = const_area_crop_kwargs - self.refine_mask_for_losses = ( - make_mask_distance_weighter(**distance_weighter_kwargs) if distance_weighter_kwargs is not None else None - ) - self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr - - self.fake_fakes_proba = fake_fakes_proba - if self.fake_fakes_proba > 1e-3: - self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {})) - - def forward(self, batch): - if self.training and self.rescale_size_getter is not None: - cur_size = self.rescale_size_getter(self.global_step) - batch["image"] = F.interpolate(batch["image"], size=cur_size, mode="bilinear", align_corners=False) - batch["mask"] = F.interpolate(batch["mask"], size=cur_size, mode="nearest") - - if self.training and self.const_area_crop_kwargs is not None: - batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs) - - img = batch["image"] - mask = batch["mask"] - - masked_img = img * (1 - mask) - - if self.add_noise_kwargs is not None: - noise = make_multiscale_noise(masked_img, **self.add_noise_kwargs) - if self.noise_fill_hole: - masked_img = masked_img + mask * noise[:, : masked_img.shape[1]] - masked_img = torch.cat([masked_img, noise], dim=1) - - if self.concat_mask: - masked_img = torch.cat([masked_img, mask], dim=1) - - batch["predicted_image"] = self.generator(masked_img) - batch["inpainted"] = mask * batch["predicted_image"] + (1 - mask) * batch["image"] - - if self.fake_fakes_proba > 1e-3: - if self.training and torch.rand(1).item() < self.fake_fakes_proba: - batch["fake_fakes"], batch["fake_fakes_masks"] = self.fake_fakes_gen(img, mask) - batch["use_fake_fakes"] = True - else: - batch["fake_fakes"] = torch.zeros_like(img) - batch["fake_fakes_masks"] = torch.zeros_like(mask) - batch["use_fake_fakes"] = False - - batch["mask_for_losses"] = ( - self.refine_mask_for_losses(img, batch["predicted_image"], mask) - if self.refine_mask_for_losses is not None and self.training - else mask - ) - - return batch - - def generator_loss(self, batch): - img = batch["image"] - predicted_img = batch[self.image_to_discriminator] - original_mask = batch["mask"] - supervised_mask = batch["mask_for_losses"] - - # L1 - l1_value = masked_l1_loss( - predicted_img, - img, - supervised_mask, - self.config.losses.l1.weight_known, - self.config.losses.l1.weight_missing, - ) - - total_loss = 
l1_value - metrics = dict(gen_l1=l1_value) - - # vgg-based perceptual loss - if self.config.losses.perceptual.weight > 0: - pl_value = ( - self.loss_pl(predicted_img, img, mask=supervised_mask).sum() * self.config.losses.perceptual.weight - ) - total_loss = total_loss + pl_value - metrics["gen_pl"] = pl_value - - # discriminator - # adversarial_loss calls backward by itself - mask_for_discr = supervised_mask if self.distance_weighted_mask_for_discr else original_mask - self.adversarial_loss.pre_generator_step( - real_batch=img, fake_batch=predicted_img, generator=self.generator, discriminator=self.discriminator - ) - discr_real_pred, discr_real_features = self.discriminator(img) - discr_fake_pred, discr_fake_features = self.discriminator(predicted_img) - adv_gen_loss, adv_metrics = self.adversarial_loss.generator_loss( - real_batch=img, - fake_batch=predicted_img, - discr_real_pred=discr_real_pred, - discr_fake_pred=discr_fake_pred, - mask=mask_for_discr, - ) - total_loss = total_loss + adv_gen_loss - metrics["gen_adv"] = adv_gen_loss - metrics.update(add_prefix_to_keys(adv_metrics, "adv_")) - - # feature matching - if self.config.losses.feature_matching.weight > 0: - need_mask_in_fm = OmegaConf.to_container(self.config.losses.feature_matching).get("pass_mask", False) - mask_for_fm = supervised_mask if need_mask_in_fm else None - fm_value = ( - feature_matching_loss(discr_fake_features, discr_real_features, mask=mask_for_fm) - * self.config.losses.feature_matching.weight - ) - total_loss = total_loss + fm_value - metrics["gen_fm"] = fm_value - - if self.loss_resnet_pl is not None: - resnet_pl_value = self.loss_resnet_pl(predicted_img, img) - total_loss = total_loss + resnet_pl_value - metrics["gen_resnet_pl"] = resnet_pl_value - - return total_loss, metrics - - def discriminator_loss(self, batch): - total_loss = 0 - metrics = {} - - predicted_img = batch[self.image_to_discriminator].detach() - self.adversarial_loss.pre_discriminator_step( - real_batch=batch["image"], - fake_batch=predicted_img, - generator=self.generator, - discriminator=self.discriminator, - ) - discr_real_pred, discr_real_features = self.discriminator(batch["image"]) - discr_fake_pred, discr_fake_features = self.discriminator(predicted_img) - adv_discr_loss, adv_metrics = self.adversarial_loss.discriminator_loss( - real_batch=batch["image"], - fake_batch=predicted_img, - discr_real_pred=discr_real_pred, - discr_fake_pred=discr_fake_pred, - mask=batch["mask"], - ) - total_loss = total_loss + adv_discr_loss - metrics["discr_adv"] = adv_discr_loss - metrics.update(add_prefix_to_keys(adv_metrics, "adv_")) - - if batch.get("use_fake_fakes", False): - fake_fakes = batch["fake_fakes"] - self.adversarial_loss.pre_discriminator_step( - real_batch=batch["image"], - fake_batch=fake_fakes, - generator=self.generator, - discriminator=self.discriminator, - ) - discr_fake_fakes_pred, _ = self.discriminator(fake_fakes) - fake_fakes_adv_discr_loss, fake_fakes_adv_metrics = self.adversarial_loss.discriminator_loss( - real_batch=batch["image"], - fake_batch=fake_fakes, - discr_real_pred=discr_real_pred, - discr_fake_pred=discr_fake_fakes_pred, - mask=batch["mask"], - ) - total_loss = total_loss + fake_fakes_adv_discr_loss - metrics["discr_adv_fake_fakes"] = fake_fakes_adv_discr_loss - metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, "adv_")) - - return total_loss, metrics diff --git a/modules/inpaint/lama/saicinpainting/training/visualizers/__init__.py b/modules/inpaint/lama/saicinpainting/training/visualizers/__init__.py deleted 
file mode 100644 index 4770d1f..0000000 --- a/modules/inpaint/lama/saicinpainting/training/visualizers/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -import logging - -from saicinpainting.training.visualizers.directory import DirectoryVisualizer -from saicinpainting.training.visualizers.noop import NoopVisualizer - - -def make_visualizer(kind, **kwargs): - logging.info(f'Make visualizer {kind}') - - if kind == 'directory': - return DirectoryVisualizer(**kwargs) - if kind == 'noop': - return NoopVisualizer() - - raise ValueError(f'Unknown visualizer kind {kind}') diff --git a/modules/inpaint/lama/saicinpainting/training/visualizers/base.py b/modules/inpaint/lama/saicinpainting/training/visualizers/base.py deleted file mode 100644 index 01978f1..0000000 --- a/modules/inpaint/lama/saicinpainting/training/visualizers/base.py +++ /dev/null @@ -1,75 +0,0 @@ -import abc -from typing import Dict, List - -import numpy as np -import torch -import skimage.color as color -from skimage.segmentation import mark_boundaries - -from . import colors - -COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation - - -class BaseVisualizer: - @abc.abstractmethod - def __call__(self, epoch_i, batch_i, batch, suffix="", rank=None): - """ - Take a batch, make an image from it and visualize - """ - raise NotImplementedError() - - -def visualize_mask_and_images( - images_dict: Dict[str, np.ndarray], - keys: List[str], - last_without_mask=True, - rescale_keys=None, - mask_only_first=None, - black_mask=False, -) -> np.ndarray: - mask = images_dict["mask"] > 0.5 - result = [] - for i, k in enumerate(keys): - img = images_dict[k] - img = np.transpose(img, (1, 2, 0)) - - if rescale_keys is not None and k in rescale_keys: - img = img - img.min() - img /= img.max() + 1e-5 - if len(img.shape) == 2: - img = np.expand_dims(img, 2) - - if img.shape[2] == 1: - img = np.repeat(img, 3, axis=2) - elif img.shape[2] > 3: - img_classes = img.argmax(2) - img = color.label2rgb(img_classes, colors=COLORS) - - if mask_only_first: - need_mark_boundaries = i == 0 - else: - need_mark_boundaries = i < len(keys) - 1 or not last_without_mask - - if need_mark_boundaries: - if black_mask: - img = img * (1 - mask[0][..., None]) - img = mark_boundaries(img, mask[0], color=(1.0, 0.0, 0.0), outline_color=(1.0, 1.0, 1.0), mode="thick") - result.append(img) - return np.concatenate(result, axis=1) - - -def visualize_mask_and_images_batch( - batch: Dict[str, torch.Tensor], keys: List[str], max_items=10, last_without_mask=True, rescale_keys=None -) -> np.ndarray: - batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items() if k in keys or k == "mask"} - - batch_size = next(iter(batch.values())).shape[0] - items_to_vis = min(batch_size, max_items) - result = [] - for i in range(items_to_vis): - cur_dct = {k: tens[i] for k, tens in batch.items()} - result.append( - visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask, rescale_keys=rescale_keys) - ) - return np.concatenate(result, axis=0) diff --git a/modules/inpaint/lama/saicinpainting/training/visualizers/colors.py b/modules/inpaint/lama/saicinpainting/training/visualizers/colors.py deleted file mode 100644 index 740940c..0000000 --- a/modules/inpaint/lama/saicinpainting/training/visualizers/colors.py +++ /dev/null @@ -1,95 +0,0 @@ -import random -import colorsys - -import numpy as np -import matplotlib - -matplotlib.use("agg") -import matplotlib.pyplot as plt -from matplotlib.colors import LinearSegmentedColormap - - -def generate_colors(nlabels, 
type="bright", first_color_black=False, last_color_black=True, verbose=False): - # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib - """ - Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks - :param nlabels: Number of labels (size of colormap) - :param type: 'bright' for strong colors, 'soft' for pastel colors - :param first_color_black: Option to use first color as black, True or False - :param last_color_black: Option to use last color as black, True or False - :param verbose: Prints the number of labels and shows the colormap. True or False - :return: colormap for matplotlib - """ - if type not in ("bright", "soft"): - print('Please choose "bright" or "soft" for type') - return - - if verbose: - print("Number of labels: " + str(nlabels)) - - # Generate color map for bright colors, based on hsv - if type == "bright": - randHSVcolors = [ - ( - np.random.uniform(low=0.0, high=1), - np.random.uniform(low=0.2, high=1), - np.random.uniform(low=0.9, high=1), - ) - for i in range(nlabels) - ] - - # Convert HSV list to RGB - randRGBcolors = [] - for HSVcolor in randHSVcolors: - randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) - - if first_color_black: - randRGBcolors[0] = [0, 0, 0] - - if last_color_black: - randRGBcolors[-1] = [0, 0, 0] - - random_colormap = LinearSegmentedColormap.from_list("new_map", randRGBcolors, N=nlabels) - - # Generate soft pastel colors, by limiting the RGB spectrum - if type == "soft": - low = 0.6 - high = 0.95 - randRGBcolors = [ - ( - np.random.uniform(low=low, high=high), - np.random.uniform(low=low, high=high), - np.random.uniform(low=low, high=high), - ) - for i in range(nlabels) - ] - - if first_color_black: - randRGBcolors[0] = [0, 0, 0] - - if last_color_black: - randRGBcolors[-1] = [0, 0, 0] - random_colormap = LinearSegmentedColormap.from_list("new_map", randRGBcolors, N=nlabels) - - # Display colorbar - if verbose: - from matplotlib import colors, colorbar - from matplotlib import pyplot as plt - - fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) - - bounds = np.linspace(0, nlabels, nlabels + 1) - norm = colors.BoundaryNorm(bounds, nlabels) - - cb = colorbar.ColorbarBase( - ax, - cmap=random_colormap, - norm=norm, - spacing="proportional", - ticks=None, - boundaries=bounds, - format="%1i", - orientation="horizontal", - ) - - return randRGBcolors, random_colormap diff --git a/modules/inpaint/lama/saicinpainting/training/visualizers/directory.py b/modules/inpaint/lama/saicinpainting/training/visualizers/directory.py deleted file mode 100644 index 79b2af8..0000000 --- a/modules/inpaint/lama/saicinpainting/training/visualizers/directory.py +++ /dev/null @@ -1,41 +0,0 @@ -import os - -import cv2 -import numpy as np - -from .base import BaseVisualizer, visualize_mask_and_images_batch -from ...utils import check_and_warn_input_range - - -class DirectoryVisualizer(BaseVisualizer): - DEFAULT_KEY_ORDER = "image predicted_image inpainted".split(" ") - - def __init__( - self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10, last_without_mask=True, rescale_keys=None - ): - self.outdir = outdir - os.makedirs(self.outdir, exist_ok=True) - self.key_order = key_order - self.max_items_in_batch = max_items_in_batch - self.last_without_mask = last_without_mask - self.rescale_keys = rescale_keys - - def __call__(self, epoch_i, batch_i, batch, suffix="", rank=None): - check_and_warn_input_range(batch["image"], 0, 1, "DirectoryVisualizer target image") - 
vis_img = visualize_mask_and_images_batch( - batch, - self.key_order, - max_items=self.max_items_in_batch, - last_without_mask=self.last_without_mask, - rescale_keys=self.rescale_keys, - ) - - vis_img = np.clip(vis_img * 255, 0, 255).astype("uint8") - - curoutdir = os.path.join(self.outdir, f"epoch{epoch_i:04d}{suffix}") - os.makedirs(curoutdir, exist_ok=True) - rank_suffix = f"_r{rank}" if rank is not None else "" - out_fname = os.path.join(curoutdir, f"batch{batch_i:07d}{rank_suffix}.jpg") - - vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) - cv2.imwrite(out_fname, vis_img) diff --git a/modules/inpaint/lama/saicinpainting/training/visualizers/noop.py b/modules/inpaint/lama/saicinpainting/training/visualizers/noop.py deleted file mode 100644 index 763939b..0000000 --- a/modules/inpaint/lama/saicinpainting/training/visualizers/noop.py +++ /dev/null @@ -1,9 +0,0 @@ -from .base import BaseVisualizer - - -class NoopVisualizer(BaseVisualizer): - def __init__(self, *args, **kwargs): - pass - - def __call__(self, epoch_i, batch_i, batch, suffix="", rank=None): - pass diff --git a/modules/inpaint/lama/saicinpainting/utils.py b/modules/inpaint/lama/saicinpainting/utils.py deleted file mode 100644 index 8b18b49..0000000 --- a/modules/inpaint/lama/saicinpainting/utils.py +++ /dev/null @@ -1,172 +0,0 @@ -import bisect -import functools -import logging -import numbers -import os -import sys -import traceback -import warnings - -import torch -from pytorch_lightning import seed_everything - -LOGGER = logging.getLogger(__name__) - - -def check_and_warn_input_range(tensor, min_value, max_value, name): - actual_min = tensor.min() - actual_max = tensor.max() - if actual_min < min_value or actual_max > max_value: - warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}") - - -def sum_dict_with_prefix(target, cur_dict, prefix, default=0): - for k, v in cur_dict.items(): - target_key = prefix + k - target[target_key] = target.get(target_key, default) + v - - -def average_dicts(dict_list): - result = {} - norm = 1e-3 - for dct in dict_list: - sum_dict_with_prefix(result, dct, "") - norm += 1 - for k in list(result): - result[k] /= norm - return result - - -def add_prefix_to_keys(dct, prefix): - return {prefix + k: v for k, v in dct.items()} - - -def set_requires_grad(module, value): - for param in module.parameters(): - param.requires_grad = value - - -def flatten_dict(dct): - result = {} - for k, v in dct.items(): - if isinstance(k, tuple): - k = "_".join(k) - if isinstance(v, dict): - for sub_k, sub_v in flatten_dict(v).items(): - result[f"{k}_{sub_k}"] = sub_v - else: - result[k] = v - return result - - -class LinearRamp: - def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0): - self.start_value = start_value - self.end_value = end_value - self.start_iter = start_iter - self.end_iter = end_iter - - def __call__(self, i): - if i < self.start_iter: - return self.start_value - if i >= self.end_iter: - return self.end_value - part = (i - self.start_iter) / (self.end_iter - self.start_iter) - return self.start_value * (1 - part) + self.end_value * part - - -class LadderRamp: - def __init__(self, start_iters, values): - self.start_iters = start_iters - self.values = values - assert len(values) == len(start_iters) + 1, (len(values), len(start_iters)) - - def __call__(self, i): - segment_i = bisect.bisect_right(self.start_iters, i) - return self.values[segment_i] - - -def get_ramp(kind="ladder", **kwargs): - if kind == "linear": - return 
LinearRamp(**kwargs) - if kind == "ladder": - return LadderRamp(**kwargs) - raise ValueError(f"Unexpected ramp kind: {kind}") - - -def print_traceback_handler(sig, frame): - LOGGER.warning(f"Received signal {sig}") - bt = "".join(traceback.format_stack()) - LOGGER.warning(f"Requested stack trace:\n{bt}") - - -def handle_deterministic_config(config): - seed = dict(config).get("seed", None) - if seed is None: - return False - - seed_everything(seed) - return True - - -def get_shape(t): - if torch.is_tensor(t): - return tuple(t.shape) - elif isinstance(t, dict): - return {n: get_shape(q) for n, q in t.items()} - elif isinstance(t, (list, tuple)): - return [get_shape(q) for q in t] - elif isinstance(t, numbers.Number): - return type(t) - else: - raise ValueError("unexpected type {}".format(type(t))) - - -def get_has_ddp_rank(): - master_port = os.environ.get("MASTER_PORT", None) - node_rank = os.environ.get("NODE_RANK", None) - local_rank = os.environ.get("LOCAL_RANK", None) - world_size = os.environ.get("WORLD_SIZE", None) - has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None - return has_rank - - -def handle_ddp_subprocess(): - def main_decorator(main_func): - @functools.wraps(main_func) - def new_main(*args, **kwargs): - # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE - parent_cwd = os.environ.get("TRAINING_PARENT_WORK_DIR", None) - has_parent = parent_cwd is not None - has_rank = get_has_ddp_rank() - assert has_parent == has_rank, f"Inconsistent state: has_parent={has_parent}, has_rank={has_rank}" - - if has_parent: - # we are in the worker - sys.argv.extend( - [ - f"hydra.run.dir={parent_cwd}", - # 'hydra/hydra_logging=disabled', - # 'hydra/job_logging=disabled' - ] - ) - # do nothing if this is a top-level process - # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization - - main_func(*args, **kwargs) - - return new_main - - return main_decorator - - -def handle_ddp_parent_process(): - parent_cwd = os.environ.get("TRAINING_PARENT_WORK_DIR", None) - has_parent = parent_cwd is not None - has_rank = get_has_ddp_rank() - assert has_parent == has_rank, f"Inconsistent state: has_parent={has_parent}, has_rank={has_rank}" - - if parent_cwd is None: - os.environ["TRAINING_PARENT_WORK_DIR"] = os.getcwd() - - return has_parent diff --git a/pyproject.toml b/pyproject.toml index 506d51e..7db53e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [project] name = "comfyui-art-venture" description = "A comprehensive set of custom nodes for ComfyUI, focusing on utilities for image processing, JSON manipulation, model operations and working with object via URLs" -version = "1.0.0" +version = "1.0.1" license = "LICENSE" -dependencies = ["timm==0.6.13", "transformers", "fairscale", "pycocoevalcap", "opencv-python", "qrcode[pil]", "pytorch_lightning", "kornia", "pydantic", "segment_anything", "omegaconf", "boto3>=1.34.101"] +dependencies = ["timm==0.6.13", "transformers", "fairscale", "pycocoevalcap", "opencv-python", "qrcode[pil]", "pytorch_lightning", "kornia", "pydantic", "segment_anything", "boto3>=1.34.101"] [project.urls] Repository = "https://github.com/sipherxyz/comfyui-art-venture"
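
For reference, the compositing step that the deleted DefaultInpaintingTrainingModule.forward performed around the generator is small: blank out the masked pixels, optionally concatenate the mask as an extra input channel, and paste the prediction back only where the mask is set. Below is a minimal, self-contained sketch of that path under those assumptions; composite_inpaint and the stand-in generator are illustrative names, not code from this repository.

    import torch

    def composite_inpaint(image, mask, generator):
        # image: (B, 3, H, W) in [0, 1]; mask: (B, 1, H, W) with 1 marking pixels to fill
        masked_img = image * (1 - mask)                   # blank out the holes
        gen_input = torch.cat([masked_img, mask], dim=1)  # concat_mask=True path: mask as extra channel
        predicted = generator(gen_input)                  # (B, 3, H, W) prediction
        # keep known pixels from the source, take generated content only inside the mask
        return mask * predicted + (1 - mask) * image

    # toy check with a stand-in "generator" that simply drops the mask channel
    img = torch.rand(1, 3, 64, 64)
    msk = (torch.rand(1, 1, 64, 64) > 0.9).float()
    out = composite_inpaint(img, msk, lambda x: x[:, :3])
    assert out.shape == (1, 3, 64, 64)

The removed losses, visualizers, and Lightning utilities only served to train and inspect that generator, which is what allows the whole saicinpainting training package, and with it the omegaconf dependency in pyproject.toml, to be dropped in this patch.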