diff --git a/.gitignore b/.gitignore index 362e9c257..481271499 100644 --- a/.gitignore +++ b/.gitignore @@ -43,7 +43,6 @@ cache !package.json # all dynamic stuff -/repositories/**/* /extensions/**/* /outputs/**/* /embeddings/**/* diff --git a/.gitmodules b/.gitmodules index 000ad0e26..58a8cd7d7 100644 --- a/.gitmodules +++ b/.gitmodules @@ -35,3 +35,4 @@ [submodule "modules/k-diffusion"] path = modules/k-diffusion url = https://github.com/crowsonkb/k-diffusion + ignore = dirty diff --git a/configs/v2-1-stable-unclip-h-inference.yaml b/configs/v2-1-stable-unclip-h-inference.yaml new file mode 100644 index 000000000..1bd0c64d3 --- /dev/null +++ b/configs/v2-1-stable-unclip-h-inference.yaml @@ -0,0 +1,80 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion + params: + embedding_dropout: 0.25 + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 96 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn-adm + scale_factor: 0.18215 + monitor: val/loss_simple_ema + use_ema: False + + embedder_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder + + noise_aug_config: + target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation + params: + timestep_dim: 1024 + noise_schedule_config: + timesteps: 1000 + beta_schedule: squaredcos_cap_v2 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + num_classes: "sequential" + adm_in_channels: 2048 + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" diff --git a/configs/v2-1-stable-unclip-l-inference.yaml b/configs/v2-1-stable-unclip-l-inference.yaml new file mode 100644 index 000000000..335fd61f3 --- /dev/null +++ b/configs/v2-1-stable-unclip-l-inference.yaml @@ -0,0 +1,83 @@ +model: + base_learning_rate: 1.0e-04 + target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion + params: + embedding_dropout: 0.25 + parameterization: "v" + linear_start: 0.00085 + linear_end: 0.0120 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 96 + channels: 4 + cond_stage_trainable: false + conditioning_key: crossattn-adm + scale_factor: 0.18215 + monitor: val/loss_simple_ema + use_ema: False + + embedder_config: + target: ldm.modules.encoders.modules.ClipImageEmbedder + params: + model: "ViT-L/14" + + noise_aug_config: + target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation + params: + clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th" + timestep_dim: 768 + noise_schedule_config: + timesteps: 1000 
+ beta_schedule: squaredcos_cap_v2 + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + num_classes: "sequential" + adm_in_channels: 1536 + use_checkpoint: True + image_size: 32 # unused + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" \ No newline at end of file diff --git a/configs/v2-midas-inference.yaml b/configs/v2-midas-inference.yaml new file mode 100644 index 000000000..f20c30f61 --- /dev/null +++ b/configs/v2-midas-inference.yaml @@ -0,0 +1,74 @@ +model: + base_learning_rate: 5.0e-07 + target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion + params: + linear_start: 0.00085 + linear_end: 0.0120 + num_timesteps_cond: 1 + log_every_t: 200 + timesteps: 1000 + first_stage_key: "jpg" + cond_stage_key: "txt" + image_size: 64 + channels: 4 + cond_stage_trainable: false + conditioning_key: hybrid + scale_factor: 0.18215 + monitor: val/loss_simple_ema + finetune_keys: null + use_ema: False + + depth_stage_config: + target: ldm.modules.midas.api.MiDaSInference + params: + model_type: "dpt_hybrid" + + unet_config: + target: ldm.modules.diffusionmodules.openaimodel.UNetModel + params: + use_checkpoint: True + image_size: 32 # unused + in_channels: 5 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2, 1 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4, 4 ] + num_head_channels: 64 # need to fix for flash-attn + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: 1 + context_dim: 1024 + legacy: False + + first_stage_config: + target: ldm.models.autoencoder.AutoencoderKL + params: + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + #attn_type: "vanilla-xformers" + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + cond_stage_config: + target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder + params: + freeze: True + layer: "penultimate" + + diff --git a/installer.py b/installer.py index 205f16cee..b5fcd8335 100644 --- a/installer.py +++ b/installer.py @@ -591,6 +591,7 @@ def install_packages(): # clone required repositories def install_repositories(): + """ if args.profile: pr = cProfile.Profile() pr.enable() @@ -615,6 +616,7 @@ def d(name): clone(blip_repo, d('BLIP'), blip_commit) if args.profile: print_profile(pr, 'Repositories') + """ # run extension installer diff --git a/modules/paths.py b/modules/paths.py index 0fd7b87b5..d3cce3f4a 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -17,22 +17,13 @@ # data_path = cmd_opts_pre.data sys.path.insert(0, script_path) -# search for directory of stable diffusion in following places -sd_path 
= None -possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)] -for possible_sd_path in possible_sd_paths: - if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): - sd_path = os.path.abspath(possible_sd_path) - break - -assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" - +sd_path = os.path.join(script_path, 'repositories') path_dirs = [ - (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), - (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), - (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), - (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), + (sd_path, 'ldm', 'ldm', []), + (sd_path, 'taming', 'Taming Transformers', []), + (os.path.join(sd_path, 'blip'), 'models/blip.py', 'BLIP', []), + (os.path.join(sd_path, 'codeformer'), 'inference_codeformer.py', 'CodeFormer', []), + (os.path.join('modules', 'k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), ] paths = {} diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 163390ca6..3df76caa5 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -4,10 +4,10 @@ from modules import paths, sd_disable_initialization, devices -sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") +sd_repo_configs_path = 'configs' config_default = paths.sd_default_config -config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") -config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") +config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference-512-base.yaml") +config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-768-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") diff --git a/repositories/blip b/repositories/blip new file mode 160000 index 000000000..3a29b7410 --- /dev/null +++ b/repositories/blip @@ -0,0 +1 @@ +Subproject commit 3a29b7410476bf5f2ba0955827390eb6ea1f4f9d diff --git a/repositories/codeformer/.gitignore b/repositories/codeformer/.gitignore new file mode 100644 index 000000000..15f690d16 --- /dev/null +++ b/repositories/codeformer/.gitignore @@ -0,0 +1,128 @@ +.vscode + +# ignored files +version.py + +# ignored files with suffix +*.html +# *.png +# *.jpeg +# *.jpg +*.pt +*.gif +*.pth +*.dat +*.zip + +# template + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# project +results/ +dlib/ +*_old* + diff --git a/repositories/codeformer/LICENSE b/repositories/codeformer/LICENSE new file mode 100644 index 000000000..44bf750a2 --- /dev/null +++ b/repositories/codeformer/LICENSE @@ -0,0 +1,35 @@ +S-Lab License 1.0 + +Copyright 2022 S-Lab + +Redistribution and use for non-commercial purpose in source and +binary forms, with or without modification, are permitted provided +that the following conditions are met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +In the event that redistribution and/or use for commercial purpose in +source or binary forms, with or without modification is required, +please contact the contributor(s) of the work. \ No newline at end of file diff --git a/repositories/codeformer/README.md b/repositories/codeformer/README.md new file mode 100644 index 000000000..f75785393 --- /dev/null +++ b/repositories/codeformer/README.md @@ -0,0 +1,149 @@ +
+
+## Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)
+
+[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
+
+[![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) ![visitors](https://visitor-badge-sczhou.glitch.me/badge?page_id=sczhou/CodeFormer)
+
+[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+S-Lab, Nanyang Technological University
+
+:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
+**[News]**: :whale: *Due to copyright issues, we have to delay the release of the training code (expected by the end of this year). Please star and stay tuned for our future updates!*
+
+### Update
+- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
+- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out the online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer)
+- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out the online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
+- **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
+- **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
+- **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
+- **2022.07.29**: Integrate new face detectors of `['RetinaFace'(default), 'YOLOv5']`.
+- **2022.07.17**: Add Colab demo of CodeFormer.
+- **2022.07.16**: Release inference code for face restoration. :blush:
+- **2022.06.21**: This repo is created.
+
+### TODO
+- [ ] Add checkpoint for face inpainting
+- [ ] Add checkpoint for face colorization
+- [ ] Add training code and config files
+- [x] ~~Add background image enhancement~~
+
+#### :panda_face: Try Enhancing Old Photos / Fixing AI-arts
+[](https://imgsli.com/MTI3NTE2) [](https://imgsli.com/MTI3NTE1) [](https://imgsli.com/MTI3NTIw)
+
+#### Face Restoration
+
+#### Face Color Enhancement and Restoration
+
+#### Face Inpainting
+
+### Dependencies and Installation
+
+- PyTorch >= 1.7.1
+- CUDA >= 10.1
+- Other required packages in `requirements.txt`
+```
+# git clone this repository
+git clone https://github.com/sczhou/CodeFormer
+cd CodeFormer
+
+# create new anaconda env
+conda create -n codeformer python=3.8 -y
+conda activate codeformer
+
+# install python dependencies
+pip3 install -r requirements.txt
+python basicsr/setup.py develop
+```
+
+### Quick Inference
+
+#### Download Pre-trained Models:
+Download the facelib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can download the pretrained models manually OR by running the following command.
+```
+python scripts/download_pretrained_models.py facelib
+```
+
+Download the CodeFormer pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can download the pretrained models manually OR by running the following command.
+```
+python scripts/download_pretrained_models.py CodeFormer
+```
+
+#### Prepare Testing Data:
+You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, put them in the `inputs/cropped_faces` folder.
+
+#### Testing on Face Restoration:
+[Note] If you want to compare against CodeFormer in your paper, please run the command below with `--has_aligned` (for cropped and aligned faces): the whole-image command involves face-background fusion that may damage hair texture on the boundary, which leads to an unfair comparison.
+
+🧑🏻 Face Restoration (cropped and aligned face)
+```
+# For cropped and aligned faces
+python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
+```
+
+:framed_picture: Whole Image Enhancement
+```
+# For whole image
+# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
+# Add '--face_upsample' to further upsample the restored face with Real-ESRGAN
+python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
+```
+
+:clapper: Video Enhancement
+```
+# For Windows/Mac users, please install ffmpeg first
+conda install -c conda-forge ffmpeg
+```
+```
+# For video clips
+# video path should end with '.mp4'|'.mov'|'.avi'
+python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
+```
+
+The fidelity weight *w* lies in [0, 1]. Generally, a smaller *w* tends to produce a higher-quality result, while a larger *w* yields a higher-fidelity result.
+
+The results will be saved in the `results` folder.
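+
+As a rough illustration, the flags above can also be driven from Python to compare several fidelity weights in one go; the input folder and the weight values below are only examples, not required settings:
+
+```python
+# sweep_fidelity.py - run the documented CLI with a few values of -w
+import subprocess
+
+input_path = "inputs/TestWhole"        # default folder for whole-image tests, per the README
+for w in (0.5, 0.7, 1.0):              # smaller w -> higher quality, larger w -> higher fidelity
+    subprocess.run(
+        ["python", "inference_codeformer.py",
+         "-w", str(w),
+         "--bg_upsampler", "realesrgan",   # optional background enhancement
+         "--face_upsample",                # optional face upsampling
+         "--input_path", input_path],
+        check=True,
+    )
+# each run writes its outputs into the `results` folder, as described above
+```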
+ +### Citation +If our work is useful for your research, please consider citing: + + @inproceedings{zhou2022codeformer, + author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change}, + title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer}, + booktitle = {NeurIPS}, + year = {2022} + } + +### License + +This project is licensed under NTU S-Lab License 1.0. Redistribution and use should follow this license. + +### Acknowledgement + +This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some codes are brought from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome works. + +### Contact +If you have any question, please feel free to reach me out at `shangchenzhou@gmail.com`. diff --git a/repositories/codeformer/basicsr/VERSION b/repositories/codeformer/basicsr/VERSION new file mode 100644 index 000000000..1892b9267 --- /dev/null +++ b/repositories/codeformer/basicsr/VERSION @@ -0,0 +1 @@ +1.3.2 diff --git a/repositories/codeformer/basicsr/__init__.py b/repositories/codeformer/basicsr/__init__.py new file mode 100644 index 000000000..c7ffcccd7 --- /dev/null +++ b/repositories/codeformer/basicsr/__init__.py @@ -0,0 +1,11 @@ +# https://github.com/xinntao/BasicSR +# flake8: noqa +from .archs import * +from .data import * +from .losses import * +from .metrics import * +from .models import * +from .ops import * +from .train import * +from .utils import * +from .version import __gitsha__, __version__ diff --git a/repositories/codeformer/basicsr/archs/__init__.py b/repositories/codeformer/basicsr/archs/__init__.py new file mode 100644 index 000000000..cfb1e4d7b --- /dev/null +++ b/repositories/codeformer/basicsr/archs/__init__.py @@ -0,0 +1,25 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import ARCH_REGISTRY + +__all__ = ['build_network'] + +# automatically scan and import arch modules for registry +# scan all the files under the 'archs' folder and collect files ending with +# '_arch.py' +arch_folder = osp.dirname(osp.abspath(__file__)) +arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] +# import all the arch modules +_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] + + +def build_network(opt): + opt = deepcopy(opt) + network_type = opt.pop('type') + net = ARCH_REGISTRY.get(network_type)(**opt) + logger = get_root_logger() + logger.info(f'Network [{net.__class__.__name__}] is created.') + return net diff --git a/repositories/codeformer/basicsr/archs/arcface_arch.py b/repositories/codeformer/basicsr/archs/arcface_arch.py new file mode 100644 index 000000000..fe5afb7bd --- /dev/null +++ b/repositories/codeformer/basicsr/archs/arcface_arch.py @@ -0,0 +1,245 @@ +import torch.nn as nn +from basicsr.utils.registry import ARCH_REGISTRY + + +def conv3x3(inplanes, outplanes, stride=1): + """A simple wrapper for 3x3 convolution with padding. + + Args: + inplanes (int): Channel number of inputs. + outplanes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. 
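+
+    Returns:
+        nn.Conv2d: The 3x3 convolution layer (padding 1, no bias).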
+ """ + return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + """Basic residual block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class IRBlock(nn.Module): + """Improved residual block (IR Block) used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. + """ + expansion = 1 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): + super(IRBlock, self).__init__() + self.bn0 = nn.BatchNorm2d(inplanes) + self.conv1 = conv3x3(inplanes, inplanes) + self.bn1 = nn.BatchNorm2d(inplanes) + self.prelu = nn.PReLU() + self.conv2 = conv3x3(inplanes, planes, stride) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + self.use_se = use_se + if self.use_se: + self.se = SEBlock(planes) + + def forward(self, x): + residual = x + out = self.bn0(x) + out = self.conv1(out) + out = self.bn1(out) + out = self.prelu(out) + + out = self.conv2(out) + out = self.bn2(out) + if self.use_se: + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.prelu(out) + + return out + + +class Bottleneck(nn.Module): + """Bottleneck block used in the ResNetArcFace architecture. + + Args: + inplanes (int): Channel number of inputs. + planes (int): Channel number of outputs. + stride (int): Stride in convolution. Default: 1. + downsample (nn.Module): The downsample module. Default: None. 
+ """ + expansion = 4 # output channel expansion ratio + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBlock(nn.Module): + """The squeeze-and-excitation block (SEBlock) used in the IRBlock. + + Args: + channel (int): Channel number of inputs. + reduction (int): Channel reduction ration. Default: 16. + """ + + def __init__(self, channel, reduction=16): + super(SEBlock, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), + nn.Sigmoid()) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + return x * y + + +@ARCH_REGISTRY.register() +class ResNetArcFace(nn.Module): + """ArcFace with ResNet architectures. + + Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. + + Args: + block (str): Block used in the ArcFace architecture. + layers (tuple(int)): Block numbers in each layer. + use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
+ """ + + def __init__(self, block, layers, use_se=True): + if block == 'IRBlock': + block = IRBlock + self.inplanes = 64 + self.use_se = use_se + super(ResNetArcFace, self).__init__() + + self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.prelu = nn.PReLU() + self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.bn4 = nn.BatchNorm2d(512) + self.dropout = nn.Dropout() + self.fc5 = nn.Linear(512 * 8 * 8, 512) + self.bn5 = nn.BatchNorm1d(512) + + # initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.xavier_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, num_blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) + self.inplanes = planes + for _ in range(1, num_blocks): + layers.append(block(self.inplanes, planes, use_se=self.use_se)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn4(x) + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.fc5(x) + x = self.bn5(x) + + return x \ No newline at end of file diff --git a/repositories/codeformer/basicsr/archs/arch_util.py b/repositories/codeformer/basicsr/archs/arch_util.py new file mode 100644 index 000000000..bad45ab34 --- /dev/null +++ b/repositories/codeformer/basicsr/archs/arch_util.py @@ -0,0 +1,318 @@ +import collections.abc +import math +import torch +import torchvision +import warnings +from distutils.version import LooseVersion +from itertools import repeat +from torch import nn as nn +from torch.nn import functional as F +from torch.nn import init as init +from torch.nn.modules.batchnorm import _BatchNorm + +from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv +from basicsr.utils import get_root_logger + + +@torch.no_grad() +def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): + """Initialize network weights. + + Args: + module_list (list[nn.Module] | nn.Module): Modules to be initialized. + scale (float): Scale initialized weights, especially for residual + blocks. Default: 1. + bias_fill (float): The value to fill bias. Default: 0 + kwargs (dict): Other arguments for initialization function. 
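+
+    Note:
+        Runs under @torch.no_grad() and initializes the weights in place; nothing is returned.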
+ """ + if not isinstance(module_list, list): + module_list = [module_list] + for module in module_list: + for m in module.modules(): + if isinstance(m, nn.Conv2d): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, nn.Linear): + init.kaiming_normal_(m.weight, **kwargs) + m.weight.data *= scale + if m.bias is not None: + m.bias.data.fill_(bias_fill) + elif isinstance(m, _BatchNorm): + init.constant_(m.weight, 1) + if m.bias is not None: + m.bias.data.fill_(bias_fill) + + +def make_layer(basic_block, num_basic_block, **kwarg): + """Make layers by stacking the same blocks. + + Args: + basic_block (nn.module): nn.module class for basic block. + num_basic_block (int): number of blocks. + + Returns: + nn.Sequential: Stacked blocks in nn.Sequential. + """ + layers = [] + for _ in range(num_basic_block): + layers.append(basic_block(**kwarg)) + return nn.Sequential(*layers) + + +class ResidualBlockNoBN(nn.Module): + """Residual block without BN. + + It has a style of: + ---Conv-ReLU-Conv-+- + |________________| + + Args: + num_feat (int): Channel number of intermediate features. + Default: 64. + res_scale (float): Residual scale. Default: 1. + pytorch_init (bool): If set to True, use pytorch default init, + otherwise, use default_init_weights. Default: False. + """ + + def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): + super(ResidualBlockNoBN, self).__init__() + self.res_scale = res_scale + self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) + self.relu = nn.ReLU(inplace=True) + + if not pytorch_init: + default_init_weights([self.conv1, self.conv2], 0.1) + + def forward(self, x): + identity = x + out = self.conv2(self.relu(self.conv1(x))) + return identity + out * self.res_scale + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True): + """Warp an image or feature map with optical flow. + + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2), normal value. + interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. + padding_mode (str): 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Before pytorch 1.3, the default value is + align_corners=True. After pytorch 1.3, the default value is + align_corners=False. Here, we use the True as default. + + Returns: + Tensor: Warped image or feature map. 
+ """ + assert x.size()[-2:] == flow.size()[1:3] + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x)) + grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 + grid.requires_grad = False + + vgrid = grid + flow + # scale grid to [-1,1] + vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 + vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 + vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) + output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners) + + # TODO, what if align_corners=False + return output + + +def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False): + """Resize a flow according to ratio or shape. + + Args: + flow (Tensor): Precomputed flow. shape [N, 2, H, W]. + size_type (str): 'ratio' or 'shape'. + sizes (list[int | float]): the ratio for resizing or the final output + shape. + 1) The order of ratio should be [ratio_h, ratio_w]. For + downsampling, the ratio should be smaller than 1.0 (i.e., ratio + < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., + ratio > 1.0). + 2) The order of output_size should be [out_h, out_w]. + interp_mode (str): The mode of interpolation for resizing. + Default: 'bilinear'. + align_corners (bool): Whether align corners. Default: False. + + Returns: + Tensor: Resized flow. + """ + _, _, flow_h, flow_w = flow.size() + if size_type == 'ratio': + output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) + elif size_type == 'shape': + output_h, output_w = sizes[0], sizes[1] + else: + raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.') + + input_flow = flow.clone() + ratio_h = output_h / flow_h + ratio_w = output_w / flow_w + input_flow[:, 0, :, :] *= ratio_w + input_flow[:, 1, :, :] *= ratio_h + resized_flow = F.interpolate( + input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners) + return resized_flow + + +# TODO: may write a cpp file +def pixel_unshuffle(x, scale): + """ Pixel unshuffle. + + Args: + x (Tensor): Input feature with shape (b, c, hh, hw). + scale (int): Downsample ratio. + + Returns: + Tensor: the pixel unshuffled feature. + """ + b, c, hh, hw = x.size() + out_channel = c * (scale**2) + assert hh % scale == 0 and hw % scale == 0 + h = hh // scale + w = hw // scale + x_view = x.view(b, c, h, scale, w, scale) + return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) + + +class DCNv2Pack(ModulatedDeformConvPack): + """Modulated deformable conv for deformable alignment. + + Different from the official DCNv2Pack, which generates offsets and masks + from the preceding features, this DCNv2Pack takes another different + features to generate offsets and masks. + + Ref: + Delving Deep into Deformable Alignment in Video Super-Resolution. 
+ """ + + def forward(self, x, feat): + out = self.conv_offset(feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + offset_absmean = torch.mean(torch.abs(offset)) + if offset_absmean > 50: + logger = get_root_logger() + logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.') + + if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'): + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding, + self.dilation, mask) + else: + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, + self.dilation, self.groups, self.deformable_groups) + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. ' + 'The distribution of values may be incorrect.', + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + up = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [low, up], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * up - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py + + The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +# From PyTorch +def _ntuple(n): + + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + + return parse + + +to_1tuple = _ntuple(1) +to_2tuple = _ntuple(2) +to_3tuple = _ntuple(3) +to_4tuple = _ntuple(4) +to_ntuple = _ntuple \ No newline at end of file diff --git a/repositories/codeformer/basicsr/archs/codeformer_arch.py b/repositories/codeformer/basicsr/archs/codeformer_arch.py new file mode 100644 index 000000000..4d0d8027c --- /dev/null +++ b/repositories/codeformer/basicsr/archs/codeformer_arch.py @@ -0,0 +1,276 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from basicsr.archs.vqgan_arch import * +from basicsr.utils import get_root_logger +from basicsr.utils.registry import ARCH_REGISTRY + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
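+    The output has shape (B, 2 * num_pos_feats, H, W) and matches the spatial size of the input feature map.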
+ """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + 
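        # The fidelity weight w (0 = best quality, 1 = best fidelity, per the README) scales how
+        # strongly the encoder (degraded-input) features modulate the decoder features
+        # through this SFT-style scale/shift.
+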
residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator']): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, 
shape=[x.shape[0],16,16,256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat \ No newline at end of file diff --git a/repositories/codeformer/basicsr/archs/rrdbnet_arch.py b/repositories/codeformer/basicsr/archs/rrdbnet_arch.py new file mode 100644 index 000000000..49a2d6c20 --- /dev/null +++ b/repositories/codeformer/basicsr/archs/rrdbnet_arch.py @@ -0,0 +1,119 @@ +import torch +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.utils.registry import ARCH_REGISTRY +from .arch_util import default_init_weights, make_layer, pixel_unshuffle + + +class ResidualDenseBlock(nn.Module): + """Residual Dense Block. + + Used in RRDB block in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat=64, num_grow_ch=32): + super(ResidualDenseBlock, self).__init__() + self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1) + self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1) + self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + # initialization + default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) + + def forward(self, x): + x1 = self.lrelu(self.conv1(x)) + x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) + x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) + x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) + x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) + # Emperically, we use 0.2 to scale the residual for better performance + return x5 * 0.2 + x + + +class RRDB(nn.Module): + """Residual in Residual Dense Block. + + Used in RRDB-Net in ESRGAN. + + Args: + num_feat (int): Channel number of intermediate features. + num_grow_ch (int): Channels for each growth. + """ + + def __init__(self, num_feat, num_grow_ch=32): + super(RRDB, self).__init__() + self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch) + self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch) + + def forward(self, x): + out = self.rdb1(x) + out = self.rdb2(out) + out = self.rdb3(out) + # Emperically, we use 0.2 to scale the residual for better performance + return out * 0.2 + x + + +@ARCH_REGISTRY.register() +class RRDBNet(nn.Module): + """Networks consisting of Residual in Residual Dense Block, which is used + in ESRGAN. + + ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks. + + We extend ESRGAN for scale x2 and scale x1. + Note: This is one option for scale 1, scale 2 in RRDBNet. 
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size + and enlarge the channel size before feeding inputs into the main ESRGAN architecture. + + Args: + num_in_ch (int): Channel number of inputs. + num_out_ch (int): Channel number of outputs. + num_feat (int): Channel number of intermediate features. + Default: 64 + num_block (int): Block number in the trunk network. Defaults: 23 + num_grow_ch (int): Channels for each growth. Default: 32. + """ + + def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32): + super(RRDBNet, self).__init__() + self.scale = scale + if scale == 2: + num_in_ch = num_in_ch * 4 + elif scale == 1: + num_in_ch = num_in_ch * 16 + self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch) + self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + # upsample + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x): + if self.scale == 2: + feat = pixel_unshuffle(x, scale=2) + elif self.scale == 1: + feat = pixel_unshuffle(x, scale=4) + else: + feat = x + feat = self.conv_first(feat) + body_feat = self.conv_body(self.body(feat)) + feat = feat + body_feat + # upsample + feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest'))) + feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest'))) + out = self.conv_last(self.lrelu(self.conv_hr(feat))) + return out \ No newline at end of file diff --git a/repositories/codeformer/basicsr/archs/vgg_arch.py b/repositories/codeformer/basicsr/archs/vgg_arch.py new file mode 100644 index 000000000..23bb0103c --- /dev/null +++ b/repositories/codeformer/basicsr/archs/vgg_arch.py @@ -0,0 +1,161 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +from basicsr.utils.registry import ARCH_REGISTRY + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 
'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + + +@ARCH_REGISTRY.register() +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. + """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. 
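+
+        Note:
+            The features are returned as a dict keyed by the names in `layer_name_list`.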
+ """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/repositories/codeformer/basicsr/archs/vqgan_arch.py b/repositories/codeformer/basicsr/archs/vqgan_arch.py new file mode 100644 index 000000000..5ac692633 --- /dev/null +++ b/repositories/codeformer/basicsr/archs/vqgan_arch.py @@ -0,0 +1,434 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from basicsr.utils import get_root_logger +from basicsr.utils.registry import ARCH_REGISTRY + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + # min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape 
back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * 
(int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + 
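# Illustrative sketch, not part of the upstream diff, of the resolution
# bookkeeping shared by Encoder and Generator above: the encoder halves the
# spatial size len(ch_mult) - 1 times (attention is inserted wherever curr_res
# appears in attn_resolutions), and the generator starts from the mirrored latent
# resolution. The config values below are hypothetical.
img_size = 256
ch_mult = (1, 2, 2, 4)

curr_res = img_size
for i in range(len(ch_mult)):
    if i != len(ch_mult) - 1:
        curr_res //= 2            # one Downsample block per stage except the last
print(curr_res)                   # 32: the latent feature map is 32x32

assert curr_res == img_size // 2 ** (len(ch_mult) - 1)   # Generator's starting resolution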
self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) \ No newline at end of file diff --git a/repositories/codeformer/basicsr/data/__init__.py b/repositories/codeformer/basicsr/data/__init__.py new file mode 100644 index 000000000..c6adb4bb6 --- /dev/null +++ b/repositories/codeformer/basicsr/data/__init__.py @@ -0,0 +1,100 @@ +import importlib +import numpy as np +import random +import torch +import torch.utils.data +from copy import deepcopy +from functools import partial +from os import path as osp + +from basicsr.data.prefetch_dataloader import PrefetchDataLoader +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.dist_util 
import get_dist_info +from basicsr.utils.registry import DATASET_REGISTRY + +__all__ = ['build_dataset', 'build_dataloader'] + +# automatically scan and import dataset modules for registry +# scan all the files under the data folder with '_dataset' in file names +data_folder = osp.dirname(osp.abspath(__file__)) +dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] +# import all the dataset modules +_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] + + +def build_dataset(dataset_opt): + """Build dataset from options. + + Args: + dataset_opt (dict): Configuration for dataset. It must constain: + name (str): Dataset name. + type (str): Dataset type. + """ + dataset_opt = deepcopy(dataset_opt) + dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) + logger = get_root_logger() + logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') + return dataset + + +def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): + """Build dataloader. + + Args: + dataset (torch.utils.data.Dataset): Dataset. + dataset_opt (dict): Dataset options. It contains the following keys: + phase (str): 'train' or 'val'. + num_worker_per_gpu (int): Number of workers for each GPU. + batch_size_per_gpu (int): Training batch size for each GPU. + num_gpu (int): Number of GPUs. Used only in the train phase. + Default: 1. + dist (bool): Whether in distributed training. Used only in the train + phase. Default: False. + sampler (torch.utils.data.sampler): Data sampler. Default: None. + seed (int | None): Seed. Default: None + """ + phase = dataset_opt['phase'] + rank, _ = get_dist_info() + if phase == 'train': + if dist: # distributed training + batch_size = dataset_opt['batch_size_per_gpu'] + num_workers = dataset_opt['num_worker_per_gpu'] + else: # non-distributed training + multiplier = 1 if num_gpu == 0 else num_gpu + batch_size = dataset_opt['batch_size_per_gpu'] * multiplier + num_workers = dataset_opt['num_worker_per_gpu'] * multiplier + dataloader_args = dict( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=sampler, + drop_last=True) + if sampler is None: + dataloader_args['shuffle'] = True + dataloader_args['worker_init_fn'] = partial( + worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None + elif phase in ['val', 'test']: # validation + dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) + else: + raise ValueError(f'Wrong dataset phase: {phase}. 
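# Hypothetical options dict, not part of the upstream diff, listing the keys
# build_dataset/build_dataloader above read for the training phase. The dataset
# type name is a placeholder; it must be registered in DATASET_REGISTRY.
dataset_opt = {
    'phase': 'train',
    'name': 'demo',
    'type': 'SomeRegisteredDataset',   # placeholder registry key
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
    'pin_memory': True,
    'prefetch_mode': None,             # None, 'cpu' (PrefetchDataLoader) or 'cuda'
}
# loader = build_dataloader(build_dataset(dataset_opt), dataset_opt, num_gpu=1, seed=0)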
' "Supported ones are 'train', 'val' and 'test'.") + + dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) + + prefetch_mode = dataset_opt.get('prefetch_mode') + if prefetch_mode == 'cpu': # CPUPrefetcher + num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) + logger = get_root_logger() + logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}') + return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) + else: + # prefetch_mode=None: Normal dataloader + # prefetch_mode='cuda': dataloader for CUDAPrefetcher + return torch.utils.data.DataLoader(**dataloader_args) + + +def worker_init_fn(worker_id, num_workers, rank, seed): + # Set the worker seed to num_workers * rank + worker_id + seed + worker_seed = num_workers * rank + worker_id + seed + np.random.seed(worker_seed) + random.seed(worker_seed) diff --git a/repositories/codeformer/basicsr/data/data_sampler.py b/repositories/codeformer/basicsr/data/data_sampler.py new file mode 100644 index 000000000..575452d9f --- /dev/null +++ b/repositories/codeformer/basicsr/data/data_sampler.py @@ -0,0 +1,48 @@ +import math +import torch +from torch.utils.data.sampler import Sampler + + +class EnlargedSampler(Sampler): + """Sampler that restricts data loading to a subset of the dataset. + + Modified from torch.utils.data.distributed.DistributedSampler + Support enlarging the dataset for iteration-based training, for saving + time when restart the dataloader after each epoch + + Args: + dataset (torch.utils.data.Dataset): Dataset used for sampling. + num_replicas (int | None): Number of processes participating in + the training. It is usually the world_size. + rank (int | None): Rank of the current process within num_replicas. + ratio (int): Enlarging ratio. Default: 1. + """ + + def __init__(self, dataset, num_replicas, rank, ratio=1): + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(self.total_size, generator=g).tolist() + + dataset_size = len(self.dataset) + indices = [v % dataset_size for v in indices] + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/repositories/codeformer/basicsr/data/data_util.py b/repositories/codeformer/basicsr/data/data_util.py new file mode 100644 index 000000000..63b1bce8e --- /dev/null +++ b/repositories/codeformer/basicsr/data/data_util.py @@ -0,0 +1,305 @@ +import cv2 +import numpy as np +import torch +from os import path as osp +from torch.nn import functional as F + +from basicsr.data.transforms import mod_crop +from basicsr.utils import img2tensor, scandir + + +def read_img_seq(path, require_mod_crop=False, scale=1): + """Read a sequence of images from a given folder path. + + Args: + path (list[str] | str): List of image paths or image folder path. + require_mod_crop (bool): Require mod crop for each image. + Default: False. + scale (int): Scale factor for mod_crop. Default: 1. + + Returns: + Tensor: size (t, c, h, w), RGB, [0, 1]. 
+ """ + if isinstance(path, list): + img_paths = path + else: + img_paths = sorted(list(scandir(path, full_path=True))) + imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths] + if require_mod_crop: + imgs = [mod_crop(img, scale) for img in imgs] + imgs = img2tensor(imgs, bgr2rgb=True, float32=True) + imgs = torch.stack(imgs, dim=0) + return imgs + + +def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'): + """Generate an index list for reading `num_frames` frames from a sequence + of images. + + Args: + crt_idx (int): Current center index. + max_frame_num (int): Max number of the sequence of images (from 1). + num_frames (int): Reading num_frames frames. + padding (str): Padding mode, one of + 'replicate' | 'reflection' | 'reflection_circle' | 'circle' + Examples: current_idx = 0, num_frames = 5 + The generated frame indices under different padding mode: + replicate: [0, 0, 0, 1, 2] + reflection: [2, 1, 0, 1, 2] + reflection_circle: [4, 3, 0, 1, 2] + circle: [3, 4, 0, 1, 2] + + Returns: + list[int]: A list of indices. + """ + assert num_frames % 2 == 1, 'num_frames should be an odd number.' + assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.' + + max_frame_num = max_frame_num - 1 # start from 0 + num_pad = num_frames // 2 + + indices = [] + for i in range(crt_idx - num_pad, crt_idx + num_pad + 1): + if i < 0: + if padding == 'replicate': + pad_idx = 0 + elif padding == 'reflection': + pad_idx = -i + elif padding == 'reflection_circle': + pad_idx = crt_idx + num_pad - i + else: + pad_idx = num_frames + i + elif i > max_frame_num: + if padding == 'replicate': + pad_idx = max_frame_num + elif padding == 'reflection': + pad_idx = max_frame_num * 2 - i + elif padding == 'reflection_circle': + pad_idx = (crt_idx - num_pad) - (i - max_frame_num) + else: + pad_idx = i - num_frames + else: + pad_idx = i + indices.append(pad_idx) + return indices + + +def paired_paths_from_lmdb(folders, keys): + """Generate paired paths from lmdb files. + + Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is: + + lq.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records + 1)image name (with extension), + 2)image shape, + 3)compression level, separated by a white space. + Example: `baboon.png (120,125,3) 1` + + We use the image name without extension as the lmdb key. + Note that we use the same key for the corresponding lq and gt images. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + Note that this key is different from lmdb keys. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. 
' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')): + raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb ' + f'formats. But received {input_key}: {input_folder}; ' + f'{gt_key}: {gt_folder}') + # ensure that the two meta_info files are the same + with open(osp.join(input_folder, 'meta_info.txt')) as fin: + input_lmdb_keys = [line.split('.')[0] for line in fin] + with open(osp.join(gt_folder, 'meta_info.txt')) as fin: + gt_lmdb_keys = [line.split('.')[0] for line in fin] + if set(input_lmdb_keys) != set(gt_lmdb_keys): + raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.') + else: + paths = [] + for lmdb_key in sorted(input_lmdb_keys): + paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)])) + return paths + + +def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl): + """Generate paired paths from an meta information file. + + Each line in the meta information file contains the image names and + image shape (usually for gt), separated by a white space. + + Example of an meta information file: + ``` + 0001_s001.png (480,480,3) + 0001_s002.png (480,480,3) + ``` + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + meta_info_file (str): Path to the meta information file. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + with open(meta_info_file, 'r') as fin: + gt_names = [line.split(' ')[0] for line in fin] + + paths = [] + for gt_name in gt_names: + basename, ext = osp.splitext(osp.basename(gt_name)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + gt_path = osp.join(gt_folder, gt_name) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paired_paths_from_folder(folders, keys, filename_tmpl): + """Generate paired paths from folders. + + Args: + folders (list[str]): A list of folder path. The order of list should + be [input_folder, gt_folder]. + keys (list[str]): A list of keys identifying folders. The order should + be in consistent with folders, e.g., ['lq', 'gt']. + filename_tmpl (str): Template for each filename. Note that the + template excludes the file extension. Usually the filename_tmpl is + for files in the input folder. + + Returns: + list[str]: Returned path list. + """ + assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. ' + f'But got {len(folders)}') + assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. 
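# Brief sketch, not part of the upstream diff, of how filename_tmpl is applied in
# the paired path helpers above: the template is formatted with the GT basename
# (extension stripped) to derive the input/LQ filename. The values are examples.
filename_tmpl = '{}_x4'
basename, ext = 'baboon', '.png'
input_name = f'{filename_tmpl.format(basename)}{ext}'
print(input_name)   # baboon_x4.png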
' f'But got {len(keys)}') + input_folder, gt_folder = folders + input_key, gt_key = keys + + input_paths = list(scandir(input_folder)) + gt_paths = list(scandir(gt_folder)) + assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: ' + f'{len(input_paths)}, {len(gt_paths)}.') + paths = [] + for gt_path in gt_paths: + basename, ext = osp.splitext(osp.basename(gt_path)) + input_name = f'{filename_tmpl.format(basename)}{ext}' + input_path = osp.join(input_folder, input_name) + assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.') + gt_path = osp.join(gt_folder, gt_path) + paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)])) + return paths + + +def paths_from_folder(folder): + """Generate paths from folder. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + + paths = list(scandir(folder)) + paths = [osp.join(folder, path) for path in paths] + return paths + + +def paths_from_lmdb(folder): + """Generate paths from lmdb. + + Args: + folder (str): Folder path. + + Returns: + list[str]: Returned path list. + """ + if not folder.endswith('.lmdb'): + raise ValueError(f'Folder {folder}folder should in lmdb format.') + with open(osp.join(folder, 'meta_info.txt')) as fin: + paths = [line.split('.')[0] for line in fin] + return paths + + +def generate_gaussian_kernel(kernel_size=13, sigma=1.6): + """Generate Gaussian kernel used in `duf_downsample`. + + Args: + kernel_size (int): Kernel size. Default: 13. + sigma (float): Sigma of the Gaussian kernel. Default: 1.6. + + Returns: + np.array: The Gaussian kernel. + """ + from scipy.ndimage import filters as filters + kernel = np.zeros((kernel_size, kernel_size)) + # set element at the middle to one, a dirac delta + kernel[kernel_size // 2, kernel_size // 2] = 1 + # gaussian-smooth the dirac, resulting in a gaussian filter + return filters.gaussian_filter(kernel, sigma) + + +def duf_downsample(x, kernel_size=13, scale=4): + """Downsamping with Gaussian kernel used in the DUF official code. + + Args: + x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w). + kernel_size (int): Kernel size. Default: 13. + scale (int): Downsampling factor. Supported scale: (2, 3, 4). + Default: 4. + + Returns: + Tensor: DUF downsampled frames. + """ + assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.' + + squeeze_flag = False + if x.ndim == 4: + squeeze_flag = True + x = x.unsqueeze(0) + b, t, c, h, w = x.size() + x = x.view(-1, 1, h, w) + pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2 + x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect') + + gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale) + gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0) + x = F.conv2d(x, gaussian_filter, stride=scale) + x = x[:, :, 2:-2, 2:-2] + x = x.view(b, t, c, x.size(2), x.size(3)) + if squeeze_flag: + x = x.squeeze(0) + return x diff --git a/repositories/codeformer/basicsr/data/prefetch_dataloader.py b/repositories/codeformer/basicsr/data/prefetch_dataloader.py new file mode 100644 index 000000000..508842505 --- /dev/null +++ b/repositories/codeformer/basicsr/data/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. 
+ + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. + """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/repositories/codeformer/basicsr/data/transforms.py b/repositories/codeformer/basicsr/data/transforms.py new file mode 100644 index 000000000..aead9dc73 --- /dev/null +++ b/repositories/codeformer/basicsr/data/transforms.py @@ -0,0 +1,165 @@ +import cv2 +import random + + +def mod_crop(img, scale): + """Mod crop images, used during testing. + + Args: + img (ndarray): Input image. + scale (int): Scale factor. + + Returns: + ndarray: Result image. + """ + img = img.copy() + if img.ndim in (2, 3): + h, w = img.shape[0], img.shape[1] + h_remainder, w_remainder = h % scale, w % scale + img = img[:h - h_remainder, :w - w_remainder, ...] + else: + raise ValueError(f'Wrong img ndim: {img.ndim}.') + return img + + +def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path): + """Paired random crop. + + It crops lists of lq and gt images with corresponding locations. + + Args: + img_gts (list[ndarray] | ndarray): GT images. 
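# Illustrative usage sketch, not part of the upstream diff, of the prefetcher
# pattern above: CPUPrefetcher wraps an ordinary DataLoader and returns None once
# the epoch is exhausted, so a training loop is a plain while over .next().
import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=4)
prefetcher = CPUPrefetcher(loader)

batch = prefetcher.next()
while batch is not None:
    # ... training step on `batch` ...
    batch = prefetcher.next()
prefetcher.reset()   # rewind for the next epoch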
Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + img_lqs (list[ndarray] | ndarray): LQ images. Note that all images + should have the same shape. If the input is an ndarray, it will + be transformed to a list containing itself. + gt_patch_size (int): GT patch size. + scale (int): Scale factor. + gt_path (str): Path to ground-truth. + + Returns: + list[ndarray] | ndarray: GT images and LQ images. If returned results + only have one element, just return ndarray. + """ + + if not isinstance(img_gts, list): + img_gts = [img_gts] + if not isinstance(img_lqs, list): + img_lqs = [img_lqs] + + h_lq, w_lq, _ = img_lqs[0].shape + h_gt, w_gt, _ = img_gts[0].shape + lq_patch_size = gt_patch_size // scale + + if h_gt != h_lq * scale or w_gt != w_lq * scale: + raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', + f'multiplication of LQ ({h_lq}, {w_lq}).') + if h_lq < lq_patch_size or w_lq < lq_patch_size: + raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' + f'({lq_patch_size}, {lq_patch_size}). ' + f'Please remove {gt_path}.') + + # randomly choose top and left coordinates for lq patch + top = random.randint(0, h_lq - lq_patch_size) + left = random.randint(0, w_lq - lq_patch_size) + + # crop lq patch + img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] + + # crop corresponding gt patch + top_gt, left_gt = int(top * scale), int(left * scale) + img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] + if len(img_gts) == 1: + img_gts = img_gts[0] + if len(img_lqs) == 1: + img_lqs = img_lqs[0] + return img_gts, img_lqs + + +def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): + """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). + + We use vertical flip and transpose for rotation implementation. + All the images in the list use the same augmentation. + + Args: + imgs (list[ndarray] | ndarray): Images to be augmented. If the input + is an ndarray, it will be transformed to a list. + hflip (bool): Horizontal flip. Default: True. + rotation (bool): Ratotation. Default: True. + flows (list[ndarray]: Flows to be augmented. If the input is an + ndarray, it will be transformed to a list. + Dimension is (h, w, 2). Default: None. + return_status (bool): Return the status of flip and rotation. + Default: False. + + Returns: + list[ndarray] | ndarray: Augmented images and flows. If returned + results only have one element, just return ndarray. 
+ + """ + hflip = hflip and random.random() < 0.5 + vflip = rotation and random.random() < 0.5 + rot90 = rotation and random.random() < 0.5 + + def _augment(img): + if hflip: # horizontal + cv2.flip(img, 1, img) + if vflip: # vertical + cv2.flip(img, 0, img) + if rot90: + img = img.transpose(1, 0, 2) + return img + + def _augment_flow(flow): + if hflip: # horizontal + cv2.flip(flow, 1, flow) + flow[:, :, 0] *= -1 + if vflip: # vertical + cv2.flip(flow, 0, flow) + flow[:, :, 1] *= -1 + if rot90: + flow = flow.transpose(1, 0, 2) + flow = flow[:, :, [1, 0]] + return flow + + if not isinstance(imgs, list): + imgs = [imgs] + imgs = [_augment(img) for img in imgs] + if len(imgs) == 1: + imgs = imgs[0] + + if flows is not None: + if not isinstance(flows, list): + flows = [flows] + flows = [_augment_flow(flow) for flow in flows] + if len(flows) == 1: + flows = flows[0] + return imgs, flows + else: + if return_status: + return imgs, (hflip, vflip, rot90) + else: + return imgs + + +def img_rotate(img, angle, center=None, scale=1.0): + """Rotate image. + + Args: + img (ndarray): Image to be rotated. + angle (float): Rotation angle in degrees. Positive values mean + counter-clockwise rotation. + center (tuple[int]): Rotation center. If the center is None, + initialize it as the center of the image. Default: None. + scale (float): Isotropic scale factor. Default: 1.0. + """ + (h, w) = img.shape[:2] + + if center is None: + center = (w // 2, h // 2) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + rotated_img = cv2.warpAffine(img, matrix, (w, h)) + return rotated_img diff --git a/repositories/codeformer/basicsr/losses/__init__.py b/repositories/codeformer/basicsr/losses/__init__.py new file mode 100644 index 000000000..2b184e74c --- /dev/null +++ b/repositories/codeformer/basicsr/losses/__init__.py @@ -0,0 +1,26 @@ +from copy import deepcopy + +from basicsr.utils import get_root_logger +from basicsr.utils.registry import LOSS_REGISTRY +from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize, + gradient_penalty_loss, r1_penalty) + +__all__ = [ + 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', + 'r1_penalty', 'g_path_regularize' +] + + +def build_loss(opt): + """Build loss from options. + + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + loss_type = opt.pop('type') + loss = LOSS_REGISTRY.get(loss_type)(**opt) + logger = get_root_logger() + logger.info(f'Loss [{loss.__class__.__name__}] is created.') + return loss diff --git a/repositories/codeformer/basicsr/losses/loss_util.py b/repositories/codeformer/basicsr/losses/loss_util.py new file mode 100644 index 000000000..744eeb46d --- /dev/null +++ b/repositories/codeformer/basicsr/losses/loss_util.py @@ -0,0 +1,95 @@ +import functools +from torch.nn import functional as F + + +def reduce_loss(loss, reduction): + """Reduce loss as specified. + + Args: + loss (Tensor): Elementwise loss tensor. + reduction (str): Options are 'none', 'mean' and 'sum'. + + Returns: + Tensor: Reduced loss tensor. + """ + reduction_enum = F._Reduction.get_enum(reduction) + # none: 0, elementwise_mean:1, sum: 2 + if reduction_enum == 0: + return loss + elif reduction_enum == 1: + return loss.mean() + else: + return loss.sum() + + +def weight_reduce_loss(loss, weight=None, reduction='mean'): + """Apply element-wise weight and reduce loss. + + Args: + loss (Tensor): Element-wise loss. 
+ weight (Tensor): Element-wise weights. Default: None. + reduction (str): Same as built-in losses of PyTorch. Options are + 'none', 'mean' and 'sum'. Default: 'mean'. + + Returns: + Tensor: Loss values. + """ + # if weight is specified, apply element-wise weight + if weight is not None: + assert weight.dim() == loss.dim() + assert weight.size(1) == 1 or weight.size(1) == loss.size(1) + loss = loss * weight + + # if weight is not specified or reduction is sum, just reduce the loss + if weight is None or reduction == 'sum': + loss = reduce_loss(loss, reduction) + # if reduction is mean, then compute mean over weight region + elif reduction == 'mean': + if weight.size(1) > 1: + weight = weight.sum() + else: + weight = weight.sum() * loss.size(1) + loss = loss.sum() / weight + + return loss + + +def weighted_loss(loss_func): + """Create a weighted version of a given loss function. + + To use this decorator, the loss function must have the signature like + `loss_func(pred, target, **kwargs)`. The function only needs to compute + element-wise loss without any reduction. This decorator will add weight + and reduction arguments to the function. The decorated function will have + the signature like `loss_func(pred, target, weight=None, reduction='mean', + **kwargs)`. + + :Example: + + >>> import torch + >>> @weighted_loss + >>> def l1_loss(pred, target): + >>> return (pred - target).abs() + + >>> pred = torch.Tensor([0, 2, 3]) + >>> target = torch.Tensor([1, 1, 1]) + >>> weight = torch.Tensor([1, 0, 1]) + + >>> l1_loss(pred, target) + tensor(1.3333) + >>> l1_loss(pred, target, weight) + tensor(1.5000) + >>> l1_loss(pred, target, reduction='none') + tensor([1., 1., 2.]) + >>> l1_loss(pred, target, weight, reduction='sum') + tensor(3.) + """ + + @functools.wraps(loss_func) + def wrapper(pred, target, weight=None, reduction='mean', **kwargs): + # get element-wise loss + loss = loss_func(pred, target, **kwargs) + loss = weight_reduce_loss(loss, weight, reduction) + return loss + + return wrapper diff --git a/repositories/codeformer/basicsr/losses/losses.py b/repositories/codeformer/basicsr/losses/losses.py new file mode 100644 index 000000000..1bcf272cf --- /dev/null +++ b/repositories/codeformer/basicsr/losses/losses.py @@ -0,0 +1,455 @@ +import math +import lpips +import torch +from torch import autograd as autograd +from torch import nn as nn +from torch.nn import functional as F + +from basicsr.archs.vgg_arch import VGGFeatureExtractor +from basicsr.utils.registry import LOSS_REGISTRY +from .loss_util import weighted_loss + +_reduction_modes = ['none', 'mean', 'sum'] + + +@weighted_loss +def l1_loss(pred, target): + return F.l1_loss(pred, target, reduction='none') + + +@weighted_loss +def mse_loss(pred, target): + return F.mse_loss(pred, target, reduction='none') + + +@weighted_loss +def charbonnier_loss(pred, target, eps=1e-12): + return torch.sqrt((pred - target)**2 + eps) + + +@LOSS_REGISTRY.register() +class L1Loss(nn.Module): + """L1 (mean absolute error, MAE) loss. + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(L1Loss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. 
' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class MSELoss(nn.Module): + """MSE (L2) loss. + + Args: + loss_weight (float): Loss weight for MSE loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + """ + + def __init__(self, loss_weight=1.0, reduction='mean'): + super(MSELoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class CharbonnierLoss(nn.Module): + """Charbonnier loss (one variant of Robust L1Loss, a differentiable + variant of L1Loss). + + Described in "Deep Laplacian Pyramid Networks for Fast and Accurate + Super-Resolution". + + Args: + loss_weight (float): Loss weight for L1 loss. Default: 1.0. + reduction (str): Specifies the reduction to apply to the output. + Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. + eps (float): A value used to control the curvature near zero. + Default: 1e-12. + """ + + def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): + super(CharbonnierLoss, self).__init__() + if reduction not in ['none', 'mean', 'sum']: + raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') + + self.loss_weight = loss_weight + self.reduction = reduction + self.eps = eps + + def forward(self, pred, target, weight=None, **kwargs): + """ + Args: + pred (Tensor): of shape (N, C, H, W). Predicted tensor. + target (Tensor): of shape (N, C, H, W). Ground truth tensor. + weight (Tensor, optional): of shape (N, C, H, W). Element-wise + weights. Default: None. + """ + return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) + + +@LOSS_REGISTRY.register() +class WeightedTVLoss(L1Loss): + """Weighted TV loss. + + Args: + loss_weight (float): Loss weight. Default: 1.0. + """ + + def __init__(self, loss_weight=1.0): + super(WeightedTVLoss, self).__init__(loss_weight=loss_weight) + + def forward(self, pred, weight=None): + y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=weight[:, :, :-1, :]) + x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=weight[:, :, :, :-1]) + + loss = x_diff + y_diff + + return loss + + +@LOSS_REGISTRY.register() +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. 
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculting losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'mse': + self.criterion = torch.nn.MSELoss(reduction='mean') + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. 
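# Illustrative usage sketch, not part of the upstream diff, of PerceptualLoss
# above with the single-layer configuration from its docstring. Building it
# instantiates the VGG19 feature extractor (weights are loaded or downloaded).
import torch
from basicsr.losses.losses import PerceptualLoss

criterion = PerceptualLoss(layer_weights={'conv5_4': 1.0}, vgg_type='vgg19',
                           perceptual_weight=1.0, style_weight=0.)
pred = torch.rand(1, 3, 128, 128)
gt = torch.rand(1, 3, 128, 128)
percep_loss, style_loss = criterion(pred, gt)   # style_loss is None since style_weight == 0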
+ """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + + +@LOSS_REGISTRY.register() +class LPIPSLoss(nn.Module): + def __init__(self, + loss_weight=1.0, + use_input_norm=True, + range_norm=False,): + super(LPIPSLoss, self).__init__() + self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() + self.loss_weight = loss_weight + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, pred, target): + if self.range_norm: + pred = (pred + 1) / 2 + target = (target + 1) / 2 + if self.use_input_norm: + pred = (pred - self.mean) / self.std + target = (target - self.mean) / self.std + lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) + return self.loss_weight * lpips_loss.mean() + + +@LOSS_REGISTRY.register() +class GANLoss(nn.Module): + """Define GAN loss. + + Args: + gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. + real_label_val (float): The value for real label. Default: 1.0. + fake_label_val (float): The value for fake label. Default: 0.0. + loss_weight (float): Loss weight. Default: 1.0. + Note that loss_weight is only for generators; and it is always 1.0 + for discriminators. + """ + + def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): + super(GANLoss, self).__init__() + self.gan_type = gan_type + self.loss_weight = loss_weight + self.real_label_val = real_label_val + self.fake_label_val = fake_label_val + + if self.gan_type == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif self.gan_type == 'lsgan': + self.loss = nn.MSELoss() + elif self.gan_type == 'wgan': + self.loss = self._wgan_loss + elif self.gan_type == 'wgan_softplus': + self.loss = self._wgan_softplus_loss + elif self.gan_type == 'hinge': + self.loss = nn.ReLU() + else: + raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') + + def _wgan_loss(self, input, target): + """wgan loss. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return -input.mean() if target else input.mean() + + def _wgan_softplus_loss(self, input, target): + """wgan loss with soft plus. softplus is a smooth approximation to the + ReLU function. + + In StyleGAN2, it is called: + Logistic loss for discriminator; + Non-saturating loss for generator. + + Args: + input (Tensor): Input tensor. + target (bool): Target label. + + Returns: + Tensor: wgan loss. + """ + return F.softplus(-input).mean() if target else F.softplus(input).mean() + + def get_target_label(self, input, target_is_real): + """Get target label. + + Args: + input (Tensor): Input tensor. + target_is_real (bool): Whether the target is real or fake. + + Returns: + (bool | Tensor): Target tensor. Return bool for wgan, otherwise, + return Tensor. + """ + + if self.gan_type in ['wgan', 'wgan_softplus']: + return target_is_real + target_val = (self.real_label_val if target_is_real else self.fake_label_val) + return input.new_ones(input.size()) * target_val + + def forward(self, input, target_is_real, is_disc=False): + """ + Args: + input (Tensor): The input for the loss module, i.e., the network + prediction. 
+ target_is_real (bool): Whether the targe is real or fake. + is_disc (bool): Whether the loss for discriminators or not. + Default: False. + + Returns: + Tensor: GAN loss value. + """ + if self.gan_type == 'hinge': + if is_disc: # for discriminators in hinge-gan + input = -input if target_is_real else input + loss = self.loss(1 + input).mean() + else: # for generators in hinge-gan + loss = -input.mean() + else: # other gan types + target_label = self.get_target_label(input, target_is_real) + loss = self.loss(input, target_label) + + # loss_weight is always 1.0 for discriminators + return loss if is_disc else loss * self.loss_weight + + +def r1_penalty(real_pred, real_img): + """R1 regularization for discriminator. The core idea is to + penalize the gradient on real data alone: when the + generator distribution produces the true data distribution + and the discriminator is equal to 0 on the data manifold, the + gradient penalty ensures that the discriminator cannot create + a non-zero gradient orthogonal to the data manifold without + suffering a loss in the GAN game. + + Ref: + Eq. 9 in Which training methods for GANs do actually converge. + """ + grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] + grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() + return grad_penalty + + +def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): + noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) + grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] + path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) + + path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) + + path_penalty = (path_lengths - path_mean).pow(2).mean() + + return path_penalty, path_lengths.detach().mean(), path_mean.detach() + + +def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): + """Calculate gradient penalty for wgan-gp. + + Args: + discriminator (nn.Module): Network for the discriminator. + real_data (Tensor): Real input data. + fake_data (Tensor): Fake input data. + weight (Tensor): Weight tensor. Default: None. + + Returns: + Tensor: A tensor for gradient penalty. + """ + + batch_size = real_data.size(0) + alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) + + # interpolate between real_data and fake_data + interpolates = alpha * real_data + (1. - alpha) * fake_data + interpolates = autograd.Variable(interpolates, requires_grad=True) + + disc_interpolates = discriminator(interpolates) + gradients = autograd.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones_like(disc_interpolates), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + if weight is not None: + gradients = gradients * weight + + gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() + if weight is not None: + gradients_penalty /= torch.mean(weight) + + return gradients_penalty diff --git a/repositories/codeformer/basicsr/metrics/__init__.py b/repositories/codeformer/basicsr/metrics/__init__.py new file mode 100644 index 000000000..19d55cc83 --- /dev/null +++ b/repositories/codeformer/basicsr/metrics/__init__.py @@ -0,0 +1,19 @@ +from copy import deepcopy + +from basicsr.utils.registry import METRIC_REGISTRY +from .psnr_ssim import calculate_psnr, calculate_ssim + +__all__ = ['calculate_psnr', 'calculate_ssim'] + + +def calculate_metric(data, opt): + """Calculate metric from data and options. 
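# Illustrative usage sketch, not part of the upstream diff, of GANLoss above in
# the hinge setting: discriminators receive the raw loss (is_disc=True), while the
# generator loss is scaled by loss_weight.
import torch
from basicsr.losses.losses import GANLoss

criterion = GANLoss(gan_type='hinge', loss_weight=0.1)
fake_pred = torch.randn(4, 1)

d_loss_fake = criterion(fake_pred, target_is_real=False, is_disc=True)   # relu(1 + fake_pred).mean()
g_loss = criterion(fake_pred, target_is_real=True, is_disc=False)        # -fake_pred.mean() * 0.1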
+ + Args: + opt (dict): Configuration. It must constain: + type (str): Model type. + """ + opt = deepcopy(opt) + metric_type = opt.pop('type') + metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) + return metric diff --git a/repositories/codeformer/basicsr/metrics/metric_util.py b/repositories/codeformer/basicsr/metrics/metric_util.py new file mode 100644 index 000000000..4d18f0f78 --- /dev/null +++ b/repositories/codeformer/basicsr/metrics/metric_util.py @@ -0,0 +1,45 @@ +import numpy as np + +from basicsr.utils.matlab_functions import bgr2ycbcr + + +def reorder_image(img, input_order='HWC'): + """Reorder images to 'HWC' order. + + If the input_order is (h, w), return (h, w, 1); + If the input_order is (c, h, w), return (h, w, c); + If the input_order is (h, w, c), return as it is. + + Args: + img (ndarray): Input image. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + If the input image shape is (h, w), input_order will not have + effects. Default: 'HWC'. + + Returns: + ndarray: reordered image. + """ + + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'") + if len(img.shape) == 2: + img = img[..., None] + if input_order == 'CHW': + img = img.transpose(1, 2, 0) + return img + + +def to_y_channel(img): + """Change to Y channel of YCbCr. + + Args: + img (ndarray): Images with range [0, 255]. + + Returns: + (ndarray): Images with range [0, 255] (float type) without round. + """ + img = img.astype(np.float32) / 255. + if img.ndim == 3 and img.shape[2] == 3: + img = bgr2ycbcr(img, y_only=True) + img = img[..., None] + return img * 255. diff --git a/repositories/codeformer/basicsr/metrics/psnr_ssim.py b/repositories/codeformer/basicsr/metrics/psnr_ssim.py new file mode 100644 index 000000000..bbd950699 --- /dev/null +++ b/repositories/codeformer/basicsr/metrics/psnr_ssim.py @@ -0,0 +1,128 @@ +import cv2 +import numpy as np + +from basicsr.metrics.metric_util import reorder_image, to_y_channel +from basicsr.utils.registry import METRIC_REGISTRY + + +@METRIC_REGISTRY.register() +def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the PSNR calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] + + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. 
/ np.sqrt(mse)) + + +def _ssim(img1, img2): + """Calculate SSIM (structural similarity) for one channel images. + + It is called by func:`calculate_ssim`. + + Args: + img1 (ndarray): Images with range [0, 255] with order 'HWC'. + img2 (ndarray): Images with range [0, 255] with order 'HWC'. + + Returns: + float: ssim result. + """ + + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +@METRIC_REGISTRY.register() +def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False): + """Calculate SSIM (structural similarity). + + Ref: + Image quality assessment: From error visibility to structural similarity + + The results are the same as that of the official released MATLAB code in + https://ece.uwaterloo.ca/~z70wang/research/ssim/. + + For three-channel images, SSIM is calculated for each channel and then + averaged. + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edge of an image. These + pixels are not involved in the SSIM calculation. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'HWC'. + test_y_channel (bool): Test on Y channel of YCbCr. Default: False. + + Returns: + float: ssim result. + """ + + assert img1.shape == img2.shape, (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + if input_order not in ['HWC', 'CHW']: + raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"') + img1 = reorder_image(img1, input_order=input_order) + img2 = reorder_image(img2, input_order=input_order) + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + if crop_border != 0: + img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] + img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 
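As a quick sanity check of the PSNR path above: two images in the [0, 255] range that differ by a constant offset of 10 have MSE = 100, so PSNR = 20 * log10(255 / 10), roughly 28.13 dB. A minimal sketch using the entry points exported by basicsr.metrics (the image contents are made up):

import numpy as np
from basicsr.metrics import calculate_psnr, calculate_ssim

img1 = np.full((32, 32, 3), 100.0)   # HWC image, range [0, 255]
img2 = np.full((32, 32, 3), 110.0)   # constant offset of 10  ->  MSE = 100

print(calculate_psnr(img1, img2, crop_border=0))   # ~28.13 dB
print(calculate_ssim(img1, img2, crop_border=0))   # ~0.995 for these flat images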
+ + if test_y_channel: + img1 = to_y_channel(img1) + img2 = to_y_channel(img2) + + ssims = [] + for i in range(img1.shape[2]): + ssims.append(_ssim(img1[..., i], img2[..., i])) + return np.array(ssims).mean() diff --git a/repositories/codeformer/basicsr/models/__init__.py b/repositories/codeformer/basicsr/models/__init__.py new file mode 100644 index 000000000..00bde45f0 --- /dev/null +++ b/repositories/codeformer/basicsr/models/__init__.py @@ -0,0 +1,30 @@ +import importlib +from copy import deepcopy +from os import path as osp + +from basicsr.utils import get_root_logger, scandir +from basicsr.utils.registry import MODEL_REGISTRY + +__all__ = ['build_model'] + +# automatically scan and import model modules for registry +# scan all the files under the 'models' folder and collect files ending with +# '_model.py' +model_folder = osp.dirname(osp.abspath(__file__)) +model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] +# import all the model modules +_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] + + +def build_model(opt): + """Build model from options. + + Args: + opt (dict): Configuration. It must constain: + model_type (str): Model type. + """ + opt = deepcopy(opt) + model = MODEL_REGISTRY.get(opt['model_type'])(opt) + logger = get_root_logger() + logger.info(f'Model [{model.__class__.__name__}] is created.') + return model diff --git a/repositories/.placeholder b/repositories/codeformer/basicsr/ops/__init__.py similarity index 100% rename from repositories/.placeholder rename to repositories/codeformer/basicsr/ops/__init__.py diff --git a/repositories/codeformer/basicsr/ops/dcn/__init__.py b/repositories/codeformer/basicsr/ops/dcn/__init__.py new file mode 100644 index 000000000..32e3592f8 --- /dev/null +++ b/repositories/codeformer/basicsr/ops/dcn/__init__.py @@ -0,0 +1,7 @@ +from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, + modulated_deform_conv) + +__all__ = [ + 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', + 'modulated_deform_conv' +] diff --git a/repositories/codeformer/basicsr/ops/dcn/deform_conv.py b/repositories/codeformer/basicsr/ops/dcn/deform_conv.py new file mode 100644 index 000000000..734154f9e --- /dev/null +++ b/repositories/codeformer/basicsr/ops/dcn/deform_conv.py @@ -0,0 +1,377 @@ +import math +import torch +from torch import nn as nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn import functional as F +from torch.nn.modules.utils import _pair, _single + +try: + from . 
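A minimal sketch of the registry mechanism that build_model above relies on. DummyModel is a hypothetical stand-in registered only for this example; real models are the *_model.py classes picked up by the automatic scan, and their option dicts carry far more keys than shown here.

from basicsr.models import build_model
from basicsr.utils.registry import MODEL_REGISTRY

@MODEL_REGISTRY.register()
class DummyModel:                  # hypothetical model, registered just for this example
    def __init__(self, opt):
        self.opt = opt

model = build_model({'model_type': 'DummyModel', 'scale': 4})
assert isinstance(model, DummyModel) and model.opt['scale'] == 4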
import deform_conv_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + deform_conv_ext = load( + 'deform_conv', + sources=[ + os.path.join(module_path, 'src', 'deform_conv_ext.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'), + os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'), + ], + ) + + +class DeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64): + if input is not None and input.dim() != 4: + raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.') + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + deform_conv_ext.deform_conv_forward(input, weight, + offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError + else: + cur_im2col_step = min(ctx.im2col_step, input.shape[0]) + assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize' + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input, + grad_offset, weight, ctx.bufs_[0], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1], + ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups, + ctx.deformable_groups, cur_im2col_step) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight, + ctx.bufs_[0], ctx.bufs_[1], weight.size(3), + weight.size(2), ctx.stride[1], ctx.stride[0], + ctx.padding[1], ctx.padding[0], ctx.dilation[1], + ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1, + cur_im2col_step) + + return (grad_input, grad_offset, grad_weight, None, None, None, None, None) + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, 
output_size))})') + return output_size + + +class ModulatedDeformConvFunction(Function): + + @staticmethod + def forward(ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad \ + or input.requires_grad: + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output, + ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1], + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, + grad_output, weight.shape[2], weight.shape[3], ctx.stride, + ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation, + ctx.groups, ctx.deformable_groups, ctx.with_bias) + if not ctx.with_bias: + grad_bias = None + + return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1 + width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = DeformConvFunction.apply +modulated_deform_conv = ModulatedDeformConvFunction.apply + + +class DeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False): + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, \ + f'in_channels {in_channels} is not divisible by groups {groups}' + assert out_channels % groups == 0, \ + f'out_channels {out_channels} is not divisible ' \ + f'by groups {groups}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)) + + self.reset_parameters() + + def 
reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + + def forward(self, x, offset): + # To fix an assert error in deform_conv_cuda.cpp:128 + # input image is smaller than kernel + input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]) + if input_pad: + pad_h = max(self.kernel_size[0] - x.size(2), 0) + pad_w = max(self.kernel_size[1] - x.size(3), 0) + x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous() + out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + if input_pad: + out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous() + return out + + +class DeformConvPack(DeformConv): + """A Deformable Conv Encapsulation that acts as normal Conv layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(DeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + offset = self.conv_offset(x) + return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, + self.deformable_groups) + + +class ModulatedDeformConv(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True): + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x, offset, mask): + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) + + +class ModulatedDeformConvPack(ModulatedDeformConv): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers. 
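A small usage sketch for the DeformConvPack defined above: it is intended as a drop-in replacement for nn.Conv2d that predicts its own sampling offsets from the input via conv_offset. The compiled deform_conv_ext extension and a CUDA device are assumed (the CPU path raises NotImplementedError), and the shapes below are arbitrary.

import torch
from basicsr.ops.dcn import DeformConvPack

if torch.cuda.is_available():      # the op has no CPU implementation
    layer = DeformConvPack(16, 32, kernel_size=3, padding=1).cuda()
    x = torch.randn(2, 16, 24, 24, device='cuda')
    y = layer(x)                   # offsets are produced internally by layer.conv_offset
    print(y.shape)                 # torch.Size([2, 32, 24, 24])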
+ + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int or tuple[int]): Same as nn.Conv2d. + padding (int or tuple[int]): Same as nn.Conv2d. + dilation (int or tuple[int]): Same as nn.Conv2d. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. + """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super(ModulatedDeformConvPack, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=_pair(self.stride), + padding=_pair(self.padding), + dilation=_pair(self.dilation), + bias=True) + self.init_weights() + + def init_weights(self): + super(ModulatedDeformConvPack, self).init_weights() + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x): + out = self.conv_offset(x) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation, + self.groups, self.deformable_groups) diff --git a/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp new file mode 100644 index 000000000..5d9424908 --- /dev/null +++ b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp @@ -0,0 +1,685 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor data_col); + +void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset, + const int channels, const int height, const int width, + const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const int channels, const int height, + const int width, const int ksize_h, const int ksize_w, const int pad_h, + const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor data_col); + +void 
modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, + const at::Tensor data_mask, const int batch_size, const int channels, + const int height_im, const int width_im, const int height_col, + const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, + const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, + const int width_im, const int height_col, const int width_col, + const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, const int dilation_h, + const int dilation_w, const int deformable_group, at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput, + at::Tensor weight, int kH, int kW, int dH, int dW, int padH, + int padW, int dilationH, int dilationW, int group, + int deformable_group) { + TORCH_CHECK(weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK(kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", kH, + kW); + + TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH, + kW, weight.size(2), weight.size(3)); + + TORCH_CHECK(dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", dH, dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK(nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). 
Output size is too small", + nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight, + outputWidth); + + TORCH_CHECK(input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, input.size(1)); + + TORCH_CHECK((inputHeight >= kH && inputWidth >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, outputWidth, offset.size(2), offset.size(3)); + + TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, gradOutput->size(dimf)); + + TORCH_CHECK((gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, outputWidth, gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane, + outputHeight, outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + at::Tensor output_buffer = + at::zeros({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), group, output_buffer.size(1) / group, + output_buffer.size(2), output_buffer.size(3)}); + + for (int elt = 0; elt 
< batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), output_buffer.size(4)}); + + output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step, outputHeight, outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW, + dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, 
+ outputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), group, gradOutput.size(1) / group, + gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)}); + + deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane, + inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, + dilationH, dilationW, im2col_step, deformable_group, + gradOffset[elt]); + + deformable_col2im(columns, offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH, + padW, dilationH, dilationW, group, deformable_group); + at::DeviceGuard guard(input.device()); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight 
= + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step, + nOutputPlane, outputHeight, outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step, + outputHeight, outputWidth}); + gradOutputBuffer.copy_(gradOutput); + gradOutputBuffer = + gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, + im2col_step * outputHeight, outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize / im2col_step, im2col_step, nInputPlane, + inputHeight, inputWidth}); + offset = + offset.view({batchSize / im2col_step, im2col_step, + deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight, + inputWidth, kH, kW, padH, padW, dH, dW, dilationH, + dilationW, im2col_step, deformable_group, columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = + gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_(gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), 1.0, scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int 
channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = + at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view({output.size(0), group, output.size(1) / group, + output.size(2), output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + // divide into group + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view({output.size(0), output.size(1) * output.size(2), + output.size(3), output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * 
(kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = + grad_output.view({grad_output.size(0), group, grad_output.size(1) / group, + grad_output.size(2), grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view({group, weight.size(0) / group, weight.size(1), + weight.size(2), weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_(weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), 0.0f, 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view({weight.size(0) * weight.size(1), weight.size(2), + weight.size(3), weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, input[b], offset[b], mask[b], 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, + stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], offset[b], mask[b], 1, channels, height, width, height_out, + width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view({group, grad_weight.size(0) / group, + grad_weight.size(1), grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1), + grad_output.size(2), grad_output.size(3), + grad_output.size(4)}); +} diff --git a/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu new file mode 100644 index 000000000..98752dccf --- /dev/null +++ b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu @@ -0,0 +1,867 @@ +/*! 
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
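The kernels below all build on the sampling rule described in this header: each kernel tap is read at a fractional, offset-shifted location via bilinear interpolation, and the sampled values are packed into a column matrix for the GEMM calls seen in deform_conv_cuda.cpp. A small NumPy illustration of that sampling rule (illustrative only, not part of the CUDA sources; out-of-range neighbours contribute zero, matching the bounds checks in the kernels):

import numpy as np

def bilinear_sample(img, h, w):
    """Sample a single-channel image at fractional coordinates (h, w)."""
    H, W = img.shape
    h0, w0 = int(np.floor(h)), int(np.floor(w))
    lh, lw = h - h0, w - w0
    val = 0.0
    for dy, wy in ((0, 1.0 - lh), (1, lh)):
        for dx, wx in ((0, 1.0 - lw), (1, lw)):
            y, x = h0 + dy, w0 + dx
            if 0 <= y < H and 0 <= x < W:   # zero contribution outside the image
                val += wy * wx * img[y, x]
    return val

img = np.arange(16, dtype=np.float64).reshape(4, 4)
print(bilinear_sample(img, 1.5, 2.25))  # 8.25: halfway between rows 1 and 2, a quarter past column 2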
+ * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) +{ + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +template +__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + 
weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const scalar_t map_h = i * dilation_h + offset_h; + //const scalar_t map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + +void deformable_im2col( + const at::Tensor data_im, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, const int ksize_w, 
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, const int parallel_imgs, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + channel_per_deformable_group, parallel_imgs, channels, deformable_group, + height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, const scalar_t *data_col, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t 
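+            // Scatter the gradient to the integer neighbours lying within unit
+            // distance of the fractional sampling point: get_gradient_weight
+            // returns the bilinear weight of (cur_h + dy, cur_w + dx), and the
+            // atomicAdd below is needed because several columns can map onto
+            // the same input pixel.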
weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +void deformable_col2im( + const at::Tensor data_col, const at::Tensor data_offset, const int channels, + const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, + at::Tensor grad_im) +{ + + // todo: make sure parallel_imgs is passed in correctly + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, channels, height, width, ksize_h, + ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + +template +__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col, + const scalar_t *data_im, const scalar_t *data_offset, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, scalar_t *grad_offset) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * + batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * + kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % 
kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + +void deformable_col2im_coord( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, + const int channels, const int height, const int width, const int ksize_h, + const int ksize_w, const int pad_h, const int pad_w, const int stride_h, + const int stride_w, const int dilation_h, const int dilation_w, + const int parallel_imgs, const int deformable_group, at::Tensor grad_offset) +{ + + int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs; + int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, channels, height, width, + ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group, + height_col, width_col, grad_offset_); + })); +} + +template +__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width, + const int height, const int width, scalar_t h, scalar_t w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= 
height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w, + const int height, const int width, const scalar_t *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col 
* stride_w - pad_w; + + scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int 
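+    // Offset layout: each deformable group stores 2 * kernel_h * kernel_w maps of
+    // size (height_col, width_col); channel 2 * (i * kernel_w + j) holds the vertical
+    // (h) offset and channel 2 * (i * kernel_w + j) + 1 the horizontal (w) offset of
+    // tap (i, j). The modulation mask adds one kernel_h * kernel_w map per group.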
data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const scalar_t *data_col, const scalar_t *data_im, + const scalar_t *data_offset, const scalar_t *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + scalar_t *grad_offset, scalar_t *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + 
h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, at::Tensor grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + 
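+      // AT_DISPATCH_FLOATING_TYPES_AND_HALF instantiates the lambda below for
+      // float, double and at::Half, and runs the instantiation that matches the
+      // runtime dtype of data_col (bound to scalar_t).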
data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, at::Tensor grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t *data_col_ = data_col.data_ptr(); + const scalar_t *data_im_ = data_im.data_ptr(); + const scalar_t *data_offset_ = data_offset.data_ptr(); + const scalar_t *data_mask_ = data_mask.data_ptr(); + scalar_t *grad_offset_ = grad_offset.data_ptr(); + scalar_t *grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset_, grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} diff --git a/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp new file mode 100644 index 000000000..41c6df6f7 --- /dev/null +++ b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp @@ -0,0 +1,164 @@ +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c + +#include +#include + +#include +#include + +#define WITH_CUDA // always use cuda +#ifdef WITH_CUDA +int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int 
dilationH, int group, + int deformable_group, int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias); +#endif + +int deform_conv_forward(at::Tensor input, at::Tensor weight, + at::Tensor offset, at::Tensor output, + at::Tensor columns, at::Tensor ones, int kW, + int kH, int dW, int dH, int padW, int padH, + int dilationW, int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_forward_cuda(input, weight, offset, output, columns, + ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, + deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_input(at::Tensor input, at::Tensor offset, + at::Tensor gradOutput, at::Tensor gradInput, + at::Tensor gradOffset, at::Tensor weight, + at::Tensor columns, int kW, int kH, int dW, + int dH, int padW, int padH, int dilationW, + int dilationH, int group, + int deformable_group, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_input_cuda(input, offset, gradOutput, + gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, + dilationW, dilationH, group, deformable_group, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +int deform_conv_backward_parameters( + at::Tensor input, at::Tensor offset, at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, + int padW, int padH, int dilationW, int dilationH, int group, + int deformable_group, float scale, int im2col_step) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return deform_conv_backward_parameters_cuda(input, offset, gradOutput, + gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, + dilationH, group, deformable_group, scale, im2col_step); +#else + AT_ERROR("deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_forward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor 
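+    // Device dispatcher exposed to Python: if the input lives on a CUDA device the
+    // call is forwarded to modulated_deform_conv_cuda_forward, otherwise an error is
+    // raised because no CPU implementation is provided.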
mask, at::Tensor output, at::Tensor columns, + int kernel_h, int kernel_w, const int stride_h, const int stride_w, + const int pad_h, const int pad_w, const int dilation_h, + const int dilation_w, const int group, const int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_forward(input, weight, bias, ones, + offset, mask, output, columns, kernel_h, kernel_w, stride_h, + stride_w, pad_h, pad_w, dilation_h, dilation_w, group, + deformable_group, with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + +void modulated_deform_conv_backward( + at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, + at::Tensor offset, at::Tensor mask, at::Tensor columns, + at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, + at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, + int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, + int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, + const bool with_bias) { + if (input.device().is_cuda()) { +#ifdef WITH_CUDA + return modulated_deform_conv_cuda_backward(input, weight, bias, ones, + offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, + grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, + pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, + with_bias); +#else + AT_ERROR("modulated deform conv is not compiled with GPU support"); +#endif + } + AT_ERROR("modulated deform conv is not implemented on CPU"); +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("deform_conv_forward", &deform_conv_forward, + "deform forward"); + m.def("deform_conv_backward_input", &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def("deform_conv_backward_parameters", + &deform_conv_backward_parameters, + "deform_conv_backward_parameters"); + m.def("modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated deform conv forward"); + m.def("modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated deform conv backward"); +} diff --git a/repositories/codeformer/basicsr/ops/fused_act/__init__.py b/repositories/codeformer/basicsr/ops/fused_act/__init__.py new file mode 100644 index 000000000..241dc0754 --- /dev/null +++ b/repositories/codeformer/basicsr/ops/fused_act/__init__.py @@ -0,0 +1,3 @@ +from .fused_act import FusedLeakyReLU, fused_leaky_relu + +__all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] diff --git a/repositories/codeformer/basicsr/ops/fused_act/fused_act.py b/repositories/codeformer/basicsr/ops/fused_act/fused_act.py new file mode 100644 index 000000000..588f815e5 --- /dev/null +++ b/repositories/codeformer/basicsr/ops/fused_act/fused_act.py @@ -0,0 +1,89 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 + +import torch +from torch import nn +from torch.autograd import Function + +try: + from . 
import fused_act_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + fused_act_ext = load( + 'fused', + sources=[ + os.path.join(module_path, 'src', 'fused_bias_act.cpp'), + os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), + ], + ) + + +class FusedLeakyReLUFunctionBackward(Function): + + @staticmethod + def forward(ctx, grad_output, out, negative_slope, scale): + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + empty = grad_output.new_empty(0) + + grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) + + dim = [0] + + if grad_input.ndim > 2: + dim += list(range(2, grad_input.ndim)) + + grad_bias = grad_input.sum(dim).detach() + + return grad_input, grad_bias + + @staticmethod + def backward(ctx, gradgrad_input, gradgrad_bias): + out, = ctx.saved_tensors + gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, + ctx.scale) + + return gradgrad_out, None, None, None + + +class FusedLeakyReLUFunction(Function): + + @staticmethod + def forward(ctx, input, bias, negative_slope, scale): + empty = input.new_empty(0) + out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) + ctx.save_for_backward(out) + ctx.negative_slope = negative_slope + ctx.scale = scale + + return out + + @staticmethod + def backward(ctx, grad_output): + out, = ctx.saved_tensors + + grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) + + return grad_input, grad_bias, None, None + + +class FusedLeakyReLU(nn.Module): + + def __init__(self, channel, negative_slope=0.2, scale=2**0.5): + super().__init__() + + self.bias = nn.Parameter(torch.zeros(channel)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): + return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) diff --git a/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp new file mode 100644 index 000000000..85ed0a79f --- /dev/null +++ b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp @@ -0,0 +1,26 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp +#include + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor fused_bias_act(const torch::Tensor& input, + const torch::Tensor& bias, + const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + CHECK_CUDA(input); + CHECK_CUDA(bias); + + return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); +} diff --git a/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu 
b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu new file mode 100644 index 000000000..54c7ff53c --- /dev/null +++ b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu @@ -0,0 +1,100 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. +// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + + +template +static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, + int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { + int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; + + scalar_t zero = 0.0; + + for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { + scalar_t x = p_x[xi]; + + if (use_bias) { + x += p_b[(xi / step_b) % size_b]; + } + + scalar_t ref = use_ref ? p_ref[xi] : zero; + + scalar_t y; + + switch (act * 10 + grad) { + default: + case 10: y = x; break; + case 11: y = x; break; + case 12: y = 0.0; break; + + case 30: y = (x > 0.0) ? x : x * alpha; break; + case 31: y = (ref > 0.0) ? x : x * alpha; break; + case 32: y = 0.0; break; + } + + out[xi] = y * scale; + } +} + + +torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, + int act, int grad, float alpha, float scale) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + auto x = input.contiguous(); + auto b = bias.contiguous(); + auto ref = refer.contiguous(); + + int use_bias = b.numel() ? 1 : 0; + int use_ref = ref.numel() ? 
1 : 0; + + int size_x = x.numel(); + int size_b = b.numel(); + int step_b = 1; + + for (int i = 1 + 1; i < x.dim(); i++) { + step_b *= x.size(i); + } + + int loop_x = 4; + int block_size = 4 * 32; + int grid_size = (size_x - 1) / (loop_x * block_size) + 1; + + auto y = torch::empty_like(x); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { + fused_bias_act_kernel<<>>( + y.data_ptr(), + x.data_ptr(), + b.data_ptr(), + ref.data_ptr(), + act, + grad, + alpha, + scale, + loop_x, + size_x, + step_b, + size_b, + use_bias, + use_ref + ); + }); + + return y; +} diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/__init__.py b/repositories/codeformer/basicsr/ops/upfirdn2d/__init__.py new file mode 100644 index 000000000..397e85bea --- /dev/null +++ b/repositories/codeformer/basicsr/ops/upfirdn2d/__init__.py @@ -0,0 +1,3 @@ +from .upfirdn2d import upfirdn2d + +__all__ = ['upfirdn2d'] diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp new file mode 100644 index 000000000..43d0b6783 --- /dev/null +++ b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp @@ -0,0 +1,24 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp +#include + + +torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1); + +#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, + int up_x, int up_y, int down_x, int down_y, + int pad_x0, int pad_x1, int pad_y0, int pad_y1) { + CHECK_CUDA(input); + CHECK_CUDA(kernel); + + return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); +} diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu new file mode 100644 index 000000000..8870063ba --- /dev/null +++ b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu @@ -0,0 +1,370 @@ +// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu +// Copyright (c) 2019, NVIDIA Corporation. All rights reserved. +// +// This work is made available under the Nvidia Source Code License-NC. 
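+// upfirdn2d = UPsample by zero insertion, apply an FIR filter, DOWNsample, fused
+// into a single pass. Small up/down factors with small filters are handled by the
+// tiled shared-memory kernel below; all other configurations fall back to
+// upfirdn2d_kernel_large.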
+// To view a copy of this license, visit +// https://nvlabs.github.io/stylegan2/license.html + +#include + +#include +#include +#include +#include + +#include +#include + +static __host__ __device__ __forceinline__ int floor_div(int a, int b) { + int c = a / b; + + if (c * b > a) { + c--; + } + + return c; +} + +struct UpFirDn2DKernelParams { + int up_x; + int up_y; + int down_x; + int down_y; + int pad_x0; + int pad_x1; + int pad_y0; + int pad_y1; + + int major_dim; + int in_h; + int in_w; + int minor_dim; + int kernel_h; + int kernel_w; + int out_h; + int out_w; + int loop_major; + int loop_x; +}; + +template +__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; + int out_y = minor_idx / p.minor_dim; + minor_idx -= out_y * p.minor_dim; + int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; + int major_idx_base = blockIdx.z * p.loop_major; + + if (out_x_base >= p.out_w || out_y >= p.out_h || + major_idx_base >= p.major_dim) { + return; + } + + int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; + int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); + int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; + int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major && major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, out_x = out_x_base; + loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { + int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; + int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); + int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; + int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; + + const scalar_t *x_p = + &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + + minor_idx]; + const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; + int x_px = p.minor_dim; + int k_px = -p.up_x; + int x_py = p.in_w * p.minor_dim; + int k_py = -p.up_y * p.kernel_w; + + scalar_t v = 0.0f; + + for (int y = 0; y < h; y++) { + for (int x = 0; x < w; x++) { + v += static_cast(*x_p) * static_cast(*k_p); + x_p += x_px; + k_p += k_px; + } + + x_p += x_py - w * x_px; + k_p += k_py - w * k_px; + } + + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } +} + +template +__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, + const scalar_t *kernel, + const UpFirDn2DKernelParams p) { + const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; + const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; + + __shared__ volatile float sk[kernel_h][kernel_w]; + __shared__ volatile float sx[tile_in_h][tile_in_w]; + + int minor_idx = blockIdx.x; + int tile_out_y = minor_idx / p.minor_dim; + minor_idx -= tile_out_y * p.minor_dim; + tile_out_y *= tile_out_h; + int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; + int major_idx_base = blockIdx.z * p.loop_major; + + if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | + major_idx_base >= p.major_dim) { + return; + } + + for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; + tap_idx += blockDim.x) { + int ky = tap_idx / kernel_w; + int kx = tap_idx - ky * kernel_w; + scalar_t v = 0.0; + + if (kx < p.kernel_w & ky < p.kernel_h) { + v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w 
+ (p.kernel_w - 1 - kx)]; + } + + sk[ky][kx] = v; + } + + for (int loop_major = 0, major_idx = major_idx_base; + loop_major < p.loop_major & major_idx < p.major_dim; + loop_major++, major_idx++) { + for (int loop_x = 0, tile_out_x = tile_out_x_base; + loop_x < p.loop_x & tile_out_x < p.out_w; + loop_x++, tile_out_x += tile_out_w) { + int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; + int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; + int tile_in_x = floor_div(tile_mid_x, up_x); + int tile_in_y = floor_div(tile_mid_y, up_y); + + __syncthreads(); + + for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; + in_idx += blockDim.x) { + int rel_in_y = in_idx / tile_in_w; + int rel_in_x = in_idx - rel_in_y * tile_in_w; + int in_x = rel_in_x + tile_in_x; + int in_y = rel_in_y + tile_in_y; + + scalar_t v = 0.0; + + if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { + v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * + p.minor_dim + + minor_idx]; + } + + sx[rel_in_y][rel_in_x] = v; + } + + __syncthreads(); + for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; + out_idx += blockDim.x) { + int rel_out_y = out_idx / tile_out_w; + int rel_out_x = out_idx - rel_out_y * tile_out_w; + int out_x = rel_out_x + tile_out_x; + int out_y = rel_out_y + tile_out_y; + + int mid_x = tile_mid_x + rel_out_x * down_x; + int mid_y = tile_mid_y + rel_out_y * down_y; + int in_x = floor_div(mid_x, up_x); + int in_y = floor_div(mid_y, up_y); + int rel_in_x = in_x - tile_in_x; + int rel_in_y = in_y - tile_in_y; + int kernel_x = (in_x + 1) * up_x - mid_x - 1; + int kernel_y = (in_y + 1) * up_y - mid_y - 1; + + scalar_t v = 0.0; + +#pragma unroll + for (int y = 0; y < kernel_h / up_y; y++) +#pragma unroll + for (int x = 0; x < kernel_w / up_x; x++) + v += sx[rel_in_y + y][rel_in_x + x] * + sk[kernel_y + y * up_y][kernel_x + x * up_x]; + + if (out_x < p.out_w & out_y < p.out_h) { + out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + + minor_idx] = v; + } + } + } + } +} + +torch::Tensor upfirdn2d_op(const torch::Tensor &input, + const torch::Tensor &kernel, int up_x, int up_y, + int down_x, int down_y, int pad_x0, int pad_x1, + int pad_y0, int pad_y1) { + int curDevice = -1; + cudaGetDevice(&curDevice); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); + + UpFirDn2DKernelParams p; + + auto x = input.contiguous(); + auto k = kernel.contiguous(); + + p.major_dim = x.size(0); + p.in_h = x.size(1); + p.in_w = x.size(2); + p.minor_dim = x.size(3); + p.kernel_h = k.size(0); + p.kernel_w = k.size(1); + p.up_x = up_x; + p.up_y = up_y; + p.down_x = down_x; + p.down_y = down_y; + p.pad_x0 = pad_x0; + p.pad_x1 = pad_x1; + p.pad_y0 = pad_y0; + p.pad_y1 = pad_y1; + + p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / + p.down_y; + p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / + p.down_x; + + auto out = + at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); + + int mode = -1; + + int tile_out_h = -1; + int tile_out_w = -1; + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 1; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 3 && p.kernel_w <= 3) { + mode = 2; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 3; + 
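+    // Modes 1-6 bind common (up, down, kernel size) combinations to template
+    // specializations of the tiled kernel; mode 3 covers 2x upsampling with filters
+    // up to 4x4 and uses 16x64 output tiles. Anything unmatched leaves mode at -1
+    // and is handled by the generic upfirdn2d_kernel_large.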
tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 4; + tile_out_h = 16; + tile_out_w = 64; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 4 && p.kernel_w <= 4) { + mode = 5; + tile_out_h = 8; + tile_out_w = 32; + } + + if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && + p.kernel_h <= 2 && p.kernel_w <= 2) { + mode = 6; + tile_out_h = 8; + tile_out_w = 32; + } + + dim3 block_size; + dim3 grid_size; + + if (tile_out_h > 0 && tile_out_w > 0) { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 1; + block_size = dim3(32 * 8, 1, 1); + grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, + (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } else { + p.loop_major = (p.major_dim - 1) / 16384 + 1; + p.loop_x = 4; + block_size = dim3(4, 32, 1); + grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, + (p.out_w - 1) / (p.loop_x * block_size.y) + 1, + (p.major_dim - 1) / p.loop_major + 1); + } + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { + switch (mode) { + case 1: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 2: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 3: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 4: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 5: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + case 6: + upfirdn2d_kernel + <<>>(out.data_ptr(), + x.data_ptr(), + k.data_ptr(), p); + + break; + + default: + upfirdn2d_kernel_large<<>>( + out.data_ptr(), x.data_ptr(), + k.data_ptr(), p); + } + }); + + return out; +} diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py b/repositories/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py new file mode 100644 index 000000000..667f96e1d --- /dev/null +++ b/repositories/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py @@ -0,0 +1,186 @@ +# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 + +import torch +from torch.autograd import Function +from torch.nn import functional as F + +try: + from . 
import upfirdn2d_ext +except ImportError: + import os + BASICSR_JIT = os.getenv('BASICSR_JIT') + if BASICSR_JIT == 'True': + from torch.utils.cpp_extension import load + module_path = os.path.dirname(__file__) + upfirdn2d_ext = load( + 'upfirdn2d', + sources=[ + os.path.join(module_path, 'src', 'upfirdn2d.cpp'), + os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), + ], + ) + + +class UpFirDn2dBackward(Function): + + @staticmethod + def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): + + up_x, up_y = up + down_x, down_y = down + g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad + + grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) + + grad_input = upfirdn2d_ext.upfirdn2d( + grad_output, + grad_kernel, + down_x, + down_y, + up_x, + up_y, + g_pad_x0, + g_pad_x1, + g_pad_y0, + g_pad_y1, + ) + grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) + + ctx.save_for_backward(kernel) + + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + ctx.up_x = up_x + ctx.up_y = up_y + ctx.down_x = down_x + ctx.down_y = down_y + ctx.pad_x0 = pad_x0 + ctx.pad_x1 = pad_x1 + ctx.pad_y0 = pad_y0 + ctx.pad_y1 = pad_y1 + ctx.in_size = in_size + ctx.out_size = out_size + + return grad_input + + @staticmethod + def backward(ctx, gradgrad_input): + kernel, = ctx.saved_tensors + + gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) + + gradgrad_out = upfirdn2d_ext.upfirdn2d( + gradgrad_input, + kernel, + ctx.up_x, + ctx.up_y, + ctx.down_x, + ctx.down_y, + ctx.pad_x0, + ctx.pad_x1, + ctx.pad_y0, + ctx.pad_y1, + ) + # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], + # ctx.out_size[1], ctx.in_size[3]) + gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) + + return gradgrad_out, None, None, None, None, None, None, None, None + + +class UpFirDn2d(Function): + + @staticmethod + def forward(ctx, input, kernel, up, down, pad): + up_x, up_y = up + down_x, down_y = down + pad_x0, pad_x1, pad_y0, pad_y1 = pad + + kernel_h, kernel_w = kernel.shape + batch, channel, in_h, in_w = input.shape + ctx.in_size = input.shape + + input = input.reshape(-1, in_h, in_w, 1) + + ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + ctx.out_size = (out_h, out_w) + + ctx.up = (up_x, up_y) + ctx.down = (down_x, down_y) + ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) + + g_pad_x0 = kernel_w - pad_x0 - 1 + g_pad_y0 = kernel_h - pad_y0 - 1 + g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 + g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 + + ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) + + out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) + # out = out.view(major, out_h, out_w, minor) + out = out.view(-1, channel, out_h, out_w) + + return out + + @staticmethod + def backward(ctx, grad_output): + kernel, grad_kernel = ctx.saved_tensors + + grad_input = UpFirDn2dBackward.apply( + grad_output, + kernel, + grad_kernel, + ctx.up, + ctx.down, + ctx.pad, + ctx.g_pad, + ctx.in_size, + ctx.out_size, + ) + + return grad_input, None, None, None, None + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + if input.device.type == 'cpu': + out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + else: + out = UpFirDn2d.apply(input, 
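+        # The autograd Function works with per-axis factors and explicit 4-sided
+        # padding, so the scalar up/down arguments are duplicated into (x, y) pairs
+        # and the 2-tuple pad is expanded to (pad_x0, pad_x1, pad_y0, pad_y1).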
kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) + + return out + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, channel, in_h, in_w = input.shape + input = input.reshape(-1, in_h, in_w, 1) + + _, in_h, in_w, minor = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, in_h, 1, in_w, 1, minor) + out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) + out = out.view(-1, in_h * up_y, in_w * up_x, minor) + + out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] + + out = out.permute(0, 3, 1, 2) + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape( + -1, + minor, + in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, + ) + out = out.permute(0, 2, 3, 1) + out = out[:, ::down_y, ::down_x, :] + + out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 + out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 + + return out.view(-1, channel, out_h, out_w) diff --git a/repositories/codeformer/basicsr/setup.py b/repositories/codeformer/basicsr/setup.py new file mode 100644 index 000000000..382a2aa10 --- /dev/null +++ b/repositories/codeformer/basicsr/setup.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python + +from setuptools import find_packages, setup + +import os +import subprocess +import sys +import time +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension + +version_file = './basicsr/version.py' + + +def readme(): + with open('README.md', encoding='utf-8') as f: + content = f.read() + return content + + +def get_git_hash(): + + def _minimal_ext_cmd(cmd): + # construct minimal environment + env = {} + for k in ['SYSTEMROOT', 'PATH', 'HOME']: + v = os.environ.get(k) + if v is not None: + env[k] = v + # LANGUAGE is used on win32 + env['LANGUAGE'] = 'C' + env['LANG'] = 'C' + env['LC_ALL'] = 'C' + out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + return out + + try: + out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) + sha = out.strip().decode('ascii') + except OSError: + sha = 'unknown' + + return sha + + +def get_hash(): + if os.path.exists('.git'): + sha = get_git_hash()[:7] + elif os.path.exists(version_file): + try: + from version import __version__ + sha = __version__.split('+')[-1] + except ImportError: + raise ImportError('Unable to get git version') + else: + sha = 'unknown' + + return sha + + +def write_version_py(): + content = """# GENERATED VERSION FILE +# TIME: {} +__version__ = '{}' +__gitsha__ = '{}' +version_info = ({}) +""" + sha = get_hash() + with open('./basicsr/VERSION', 'r') as f: + SHORT_VERSION = f.read().strip() + VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) + + version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) + with open(version_file, 'w') as f: + f.write(version_file_str) + + +def get_version(): + with open(version_file, 'r') as f: + exec(compile(f.read(), version_file, 'exec')) + return locals()['__version__'] + + +def make_cuda_ext(name, module, sources, sources_cuda=None): + if sources_cuda is None: + sources_cuda = [] + define_macros = [] + extra_compile_args = {'cxx': 
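+    # No extra C++ flags by default; nvcc flags are only added when the CUDA
+    # extensions are built (GPU available or FORCE_CUDA=1). The -D__CUDA_NO_HALF*
+    # macros below disable CUDA's built-in half-precision operators, which would
+    # otherwise clash with PyTorch's own at::Half overloads.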
[]} + + if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': + define_macros += [('WITH_CUDA', None)] + extension = CUDAExtension + extra_compile_args['nvcc'] = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + sources += sources_cuda + else: + print(f'Compiling {name} without CUDA') + extension = CppExtension + + return extension( + name=f'{module}.{name}', + sources=[os.path.join(*module.split('.'), p) for p in sources], + define_macros=define_macros, + extra_compile_args=extra_compile_args) + + +def get_requirements(filename='requirements.txt'): + with open(os.path.join('.', filename), 'r') as f: + requires = [line.replace('\n', '') for line in f.readlines()] + return requires + + +if __name__ == '__main__': + if '--cuda_ext' in sys.argv: + ext_modules = [ + make_cuda_ext( + name='deform_conv_ext', + module='ops.dcn', + sources=['src/deform_conv_ext.cpp'], + sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), + make_cuda_ext( + name='fused_act_ext', + module='ops.fused_act', + sources=['src/fused_bias_act.cpp'], + sources_cuda=['src/fused_bias_act_kernel.cu']), + make_cuda_ext( + name='upfirdn2d_ext', + module='ops.upfirdn2d', + sources=['src/upfirdn2d.cpp'], + sources_cuda=['src/upfirdn2d_kernel.cu']), + ] + sys.argv.remove('--cuda_ext') + else: + ext_modules = [] + + write_version_py() + setup( + name='basicsr', + version=get_version(), + description='Open Source Image and Video Super-Resolution Toolbox', + long_description=readme(), + long_description_content_type='text/markdown', + author='Xintao Wang', + author_email='xintao.wang@outlook.com', + keywords='computer vision, restoration, super resolution', + url='https://github.com/xinntao/BasicSR', + include_package_data=True, + packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), + classifiers=[ + 'Development Status :: 4 - Beta', + 'License :: OSI Approved :: Apache Software License', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + ], + license='Apache License 2.0', + setup_requires=['cython', 'numpy'], + install_requires=get_requirements(), + ext_modules=ext_modules, + cmdclass={'build_ext': BuildExtension}, + zip_safe=False) diff --git a/repositories/codeformer/basicsr/train.py b/repositories/codeformer/basicsr/train.py new file mode 100644 index 000000000..a01c0dfcc --- /dev/null +++ b/repositories/codeformer/basicsr/train.py @@ -0,0 +1,225 @@ +import argparse +import datetime +import logging +import math +import copy +import random +import time +import torch +from os import path as osp + +from basicsr.data import build_dataloader, build_dataset +from basicsr.data.data_sampler import EnlargedSampler +from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher +from basicsr.models import build_model +from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger, + init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed) +from basicsr.utils.dist_util import get_dist_info, init_dist +from basicsr.utils.options import dict2str, parse + +import warnings +# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. 
+warnings.filterwarnings("ignore", category=UserWarning) + +def parse_options(root_path, is_train=True): + parser = argparse.ArgumentParser() + parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') + parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + opt = parse(args.opt, root_path, is_train=is_train) + + # distributed settings + if args.launcher == 'none': + opt['dist'] = False + print('Disable distributed.', flush=True) + else: + opt['dist'] = True + if args.launcher == 'slurm' and 'dist_params' in opt: + init_dist(args.launcher, **opt['dist_params']) + else: + init_dist(args.launcher) + + opt['rank'], opt['world_size'] = get_dist_info() + + # random seed + seed = opt.get('manual_seed') + if seed is None: + seed = random.randint(1, 10000) + opt['manual_seed'] = seed + set_random_seed(seed + opt['rank']) + + return opt + + +def init_loggers(opt): + log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log") + logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) + logger.info(get_env_info()) + logger.info(dict2str(opt)) + + # initialize wandb logger before tensorboard logger to allow proper sync: + if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None): + assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb') + init_wandb_logger(opt) + tb_logger = None + if opt['logger'].get('use_tb_logger'): + tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) + return logger, tb_logger + + +def create_train_val_dataloader(opt, logger): + # create train and val dataloaders + train_loader, val_loader = None, None + for phase, dataset_opt in opt['datasets'].items(): + if phase == 'train': + dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) + train_set = build_dataset(dataset_opt) + train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio) + train_loader = build_dataloader( + train_set, + dataset_opt, + num_gpu=opt['num_gpu'], + dist=opt['dist'], + sampler=train_sampler, + seed=opt['manual_seed']) + + num_iter_per_epoch = math.ceil( + len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) + total_iters = int(opt['train']['total_iter']) + total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) + logger.info('Training statistics:' + f'\n\tNumber of train images: {len(train_set)}' + f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' + f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' + f'\n\tWorld size (gpu number): {opt["world_size"]}' + f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' + f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') + + elif phase == 'val': + val_set = build_dataset(dataset_opt) + val_loader = build_dataloader( + val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) + logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}') + else: + raise ValueError(f'Dataset phase {phase} is not recognized.') + + return train_loader, train_sampler, val_loader, total_epochs, total_iters + + +def train_pipeline(root_path): + # parse options, set distributed setting, set ramdom seed + opt = parse_options(root_path, is_train=True) + + 
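+    # Illustrative launch examples (the option-file path below is a placeholder,
+    # not a file shipped with this patch):
+    #   single process : python basicsr/train.py -opt options/train_example.yml
+    #   distributed    : python -m torch.distributed.launch --nproc_per_node=4 \
+    #                        basicsr/train.py -opt options/train_example.yml --launcher pytorch
+    # The option YAML is expected to define at least: name, datasets (a 'train'
+    # phase with batch_size_per_gpu), train.total_iter, logger.print_freq,
+    # logger.save_checkpoint_freq, and the path section used below.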
torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.deterministic = True + + # load resume states if necessary + if opt['path'].get('resume_state'): + device_id = torch.cuda.current_device() + resume_state = torch.load( + opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id)) + else: + resume_state = None + + # mkdir for experiments and logger + if resume_state is None: + make_exp_dirs(opt) + if opt['logger'].get('use_tb_logger') and opt['rank'] == 0: + mkdir_and_rename(osp.join('tb_logger', opt['name'])) + + # initialize loggers + logger, tb_logger = init_loggers(opt) + + # create train and validation dataloaders + result = create_train_val_dataloader(opt, logger) + train_loader, train_sampler, val_loader, total_epochs, total_iters = result + + # create model + if resume_state: # resume training + check_resume(opt, resume_state['iter']) + model = build_model(opt) + model.resume_training(resume_state) # handle optimizers and schedulers + logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.") + start_epoch = resume_state['epoch'] + current_iter = resume_state['iter'] + else: + model = build_model(opt) + start_epoch = 0 + current_iter = 0 + + # create message logger (formatted outputs) + msg_logger = MessageLogger(opt, current_iter, tb_logger) + + # dataloader prefetcher + prefetch_mode = opt['datasets']['train'].get('prefetch_mode') + if prefetch_mode is None or prefetch_mode == 'cpu': + prefetcher = CPUPrefetcher(train_loader) + elif prefetch_mode == 'cuda': + prefetcher = CUDAPrefetcher(train_loader, opt) + logger.info(f'Use {prefetch_mode} prefetch dataloader') + if opt['datasets']['train'].get('pin_memory') is not True: + raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') + else: + raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.' 
"Supported ones are: None, 'cuda', 'cpu'.") + + # training + logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}') + data_time, iter_time = time.time(), time.time() + start_time = time.time() + + for epoch in range(start_epoch, total_epochs + 1): + train_sampler.set_epoch(epoch) + prefetcher.reset() + train_data = prefetcher.next() + + while train_data is not None: + data_time = time.time() - data_time + + current_iter += 1 + if current_iter > total_iters: + break + # update learning rate + model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) + # training + model.feed_data(train_data) + model.optimize_parameters(current_iter) + iter_time = time.time() - iter_time + # log + if current_iter % opt['logger']['print_freq'] == 0: + log_vars = {'epoch': epoch, 'iter': current_iter} + log_vars.update({'lrs': model.get_current_learning_rate()}) + log_vars.update({'time': iter_time, 'data_time': data_time}) + log_vars.update(model.get_current_log()) + msg_logger(log_vars) + + # save models and training states + if current_iter % opt['logger']['save_checkpoint_freq'] == 0: + logger.info('Saving models and training states.') + model.save(epoch, current_iter) + + # validation + if opt.get('val') is not None and opt['datasets'].get('val') is not None \ + and (current_iter % opt['val']['val_freq'] == 0): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + + data_time = time.time() + iter_time = time.time() + train_data = prefetcher.next() + # end of iter + + # end of epoch + + consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time))) + logger.info(f'End of training. Time consumed: {consumed_time}') + logger.info('Save the latest model.') + model.save(epoch=-1, current_iter=-1) # -1 stands for the latest + if opt.get('val') is not None and opt['datasets'].get('val'): + model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img']) + if tb_logger: + tb_logger.close() + + +if __name__ == '__main__': + root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) + train_pipeline(root_path) diff --git a/repositories/codeformer/basicsr/utils/__init__.py b/repositories/codeformer/basicsr/utils/__init__.py new file mode 100644 index 000000000..5fcc1d540 --- /dev/null +++ b/repositories/codeformer/basicsr/utils/__init__.py @@ -0,0 +1,29 @@ +from .file_client import FileClient +from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img +from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger +from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt + +__all__ = [ + # file_client.py + 'FileClient', + # img_util.py + 'img2tensor', + 'tensor2img', + 'imfrombytes', + 'imwrite', + 'crop_border', + # logger.py + 'MessageLogger', + 'init_tb_logger', + 'init_wandb_logger', + 'get_root_logger', + 'get_env_info', + # misc.py + 'set_random_seed', + 'get_time_str', + 'mkdir_and_rename', + 'make_exp_dirs', + 'scandir', + 'check_resume', + 'sizeof_fmt' +] diff --git a/repositories/codeformer/basicsr/utils/dist_util.py b/repositories/codeformer/basicsr/utils/dist_util.py new file mode 100644 index 000000000..0fab887b2 --- /dev/null +++ b/repositories/codeformer/basicsr/utils/dist_util.py @@ -0,0 +1,82 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 +import functools +import os +import subprocess +import torch 
+import torch.distributed as dist +import torch.multiprocessing as mp + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def master_only(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank, _ = get_dist_info() + if rank == 0: + return func(*args, **kwargs) + + return wrapper diff --git a/repositories/codeformer/basicsr/utils/download_util.py b/repositories/codeformer/basicsr/utils/download_util.py new file mode 100644 index 000000000..2a2679157 --- /dev/null +++ b/repositories/codeformer/basicsr/utils/download_util.py @@ -0,0 +1,95 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +from .misc import sizeof_fmt + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. 
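+    Example (illustrative; the file id below is a made-up placeholder):
+        >>> download_file_from_google_drive('0B1234567890abcdef', 'weights/model.pth')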
+ """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/repositories/codeformer/basicsr/utils/file_client.py b/repositories/codeformer/basicsr/utils/file_client.py new file mode 100644 index 000000000..7f38d9796 --- /dev/null +++ b/repositories/codeformer/basicsr/utils/file_client.py @@ -0,0 +1,167 @@ +# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. 
+ + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. + sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. 
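+        Example (illustrative; the lmdb path and key are placeholders):
+            >>> backend = LmdbBackend(db_paths=['gt.lmdb'], client_keys=['gt'])
+            >>> buf = backend.get('000001', 'gt')  # raw bytes stored under that key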
+ """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. + + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/repositories/codeformer/basicsr/utils/img_util.py b/repositories/codeformer/basicsr/utils/img_util.py new file mode 100644 index 000000000..5aba82ce0 --- /dev/null +++ b/repositories/codeformer/basicsr/utils/img_util.py @@ -0,0 +1,171 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. 
+ """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. 
+ """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] + \ No newline at end of file diff --git a/repositories/codeformer/basicsr/utils/lmdb_util.py b/repositories/codeformer/basicsr/utils/lmdb_util.py new file mode 100644 index 000000000..e0a10f60f --- /dev/null +++ b/repositories/codeformer/basicsr/utils/lmdb_util.py @@ -0,0 +1,196 @@ +import cv2 +import lmdb +import sys +from multiprocessing import Pool +from os import path as osp +from tqdm import tqdm + + +def make_lmdb_from_imgs(data_path, + lmdb_path, + img_path_list, + keys, + batch=5000, + compress_level=1, + multiprocessing_read=False, + n_thread=40, + map_size=None): + """Make lmdb from images. + + Contents of lmdb. The file structure is: + example.lmdb + ├── data.mdb + ├── lock.mdb + ├── meta_info.txt + + The data.mdb and lock.mdb are standard lmdb files and you can refer to + https://lmdb.readthedocs.io/en/release/ for more details. + + The meta_info.txt is a specified txt file to record the meta information + of our datasets. It will be automatically created when preparing + datasets by our provided dataset tools. + Each line in the txt file records 1)image name (with extension), + 2)image shape, and 3)compression level, separated by a white space. + + For example, the meta information could be: + `000_00000000.png (720,1280,3) 1`, which means: + 1) image name (with extension): 000_00000000.png; + 2) image shape: (720,1280,3); + 3) compression level: 1 + + We use the image name without extension as the lmdb key. + + If `multiprocessing_read` is True, it will read all the images to memory + using multiprocessing. Thus, your server needs to have enough memory. + + Args: + data_path (str): Data path for reading images. + lmdb_path (str): Lmdb save path. + img_path_list (str): Image path list. + keys (str): Used for lmdb keys. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + multiprocessing_read (bool): Whether use multiprocessing to read all + the images to memory. Default: False. + n_thread (int): For multiprocessing. + map_size (int | None): Map size for lmdb env. If None, use the + estimated size from images. Default: None + """ + + assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' + f'but got {len(img_path_list)} and {len(keys)}') + print(f'Create lmdb for {data_path}, save to {lmdb_path}...') + print(f'Totoal images: {len(img_path_list)}') + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + if multiprocessing_read: + # read all the images to memory (multiprocessing) + dataset = {} # use dict to keep the order for multiprocessing + shapes = {} + print(f'Read images with multiprocessing, #thread: {n_thread} ...') + pbar = tqdm(total=len(img_path_list), unit='image') + + def callback(arg): + """get the image data and update pbar.""" + key, dataset[key], shapes[key] = arg + pbar.update(1) + pbar.set_description(f'Read {key}') + + pool = Pool(n_thread) + for path, key in zip(img_path_list, keys): + pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) + pool.close() + pool.join() + pbar.close() + print(f'Finish reading {len(img_path_list)} images.') + + # create lmdb environment + if map_size is None: + # obtain data size for one image + img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + data_size_per_img = img_byte.nbytes + print('Data size per image is: ', data_size_per_img) + data_size = data_size_per_img * len(img_path_list) + map_size = data_size * 10 + + env = lmdb.open(lmdb_path, map_size=map_size) + + # write data to lmdb + pbar = tqdm(total=len(img_path_list), unit='chunk') + txn = env.begin(write=True) + txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + for idx, (path, key) in enumerate(zip(img_path_list, keys)): + pbar.update(1) + pbar.set_description(f'Write {key}') + key_byte = key.encode('ascii') + if multiprocessing_read: + img_byte = dataset[key] + h, w, c = shapes[key] + else: + _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) + h, w, c = img_shape + + txn.put(key_byte, img_byte) + # write meta information + txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') + if idx % batch == 0: + txn.commit() + txn = env.begin(write=True) + pbar.close() + txn.commit() + env.close() + txt_file.close() + print('\nFinish writing lmdb.') + + +def read_img_worker(path, key, compress_level): + """Read image worker. + + Args: + path (str): Image path. + key (str): Image key. + compress_level (int): Compress level when encoding images. + + Returns: + str: Image key. + byte: Image byte. + tuple[int]: Image shape. + """ + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img.ndim == 2: + h, w = img.shape + c = 1 + else: + h, w, c = img.shape + _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) + return (key, img_byte, (h, w, c)) + + +class LmdbMaker(): + """LMDB Maker. + + Args: + lmdb_path (str): Lmdb save path. + map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. + batch (int): After processing batch images, lmdb commits. + Default: 5000. + compress_level (int): Compress level when encoding images. Default: 1. + """ + + def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): + if not lmdb_path.endswith('.lmdb'): + raise ValueError("lmdb_path must end with '.lmdb'.") + if osp.exists(lmdb_path): + print(f'Folder {lmdb_path} already exists. 
Exit.') + sys.exit(1) + + self.lmdb_path = lmdb_path + self.batch = batch + self.compress_level = compress_level + self.env = lmdb.open(lmdb_path, map_size=map_size) + self.txn = self.env.begin(write=True) + self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') + self.counter = 0 + + def put(self, img_byte, key, img_shape): + self.counter += 1 + key_byte = key.encode('ascii') + self.txn.put(key_byte, img_byte) + # write meta information + h, w, c = img_shape + self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') + if self.counter % self.batch == 0: + self.txn.commit() + self.txn = self.env.begin(write=True) + + def close(self): + self.txn.commit() + self.env.close() + self.txt_file.close() diff --git a/repositories/codeformer/basicsr/utils/logger.py b/repositories/codeformer/basicsr/utils/logger.py new file mode 100644 index 000000000..9714bf59c --- /dev/null +++ b/repositories/codeformer/basicsr/utils/logger.py @@ -0,0 +1,169 @@ +import datetime +import logging +import time + +from .dist_util import get_dist_info, master_only + +initialized_logger = {} + + +class MessageLogger(): + """Message logger for printing. + Args: + opt (dict): Config. It contains the following keys: + name (str): Exp name. + logger (dict): Contains 'print_freq' (str) for logger interval. + train (dict): Contains 'total_iter' (int) for total iters. + use_tb_logger (bool): Use tensorboard logger. + start_iter (int): Start iter. Default: 1. + tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. + """ + + def __init__(self, opt, start_iter=1, tb_logger=None): + self.exp_name = opt['name'] + self.interval = opt['logger']['print_freq'] + self.start_iter = start_iter + self.max_iters = opt['train']['total_iter'] + self.use_tb_logger = opt['logger']['use_tb_logger'] + self.tb_logger = tb_logger + self.start_time = time.time() + self.logger = get_root_logger() + + @master_only + def __call__(self, log_vars): + """Format logging message. + Args: + log_vars (dict): It contains the following keys: + epoch (int): Epoch number. + iter (int): Current iter. + lrs (list): List for learning rates. + time (float): Iter time. + data_time (float): Data time for each iter. 
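+        Example (illustrative; the loss key 'l_pix' is a placeholder and depends
+        on the model):
+            >>> msg_logger({'epoch': 1, 'iter': 100, 'lrs': [1e-4],
+            ...             'time': 0.5, 'data_time': 0.1, 'l_pix': 0.02})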
+ """ + # epoch, iter, learning rates + epoch = log_vars.pop('epoch') + current_iter = log_vars.pop('iter') + lrs = log_vars.pop('lrs') + + message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') + for v in lrs: + message += f'{v:.3e},' + message += ')] ' + + # time and estimated time + if 'time' in log_vars.keys(): + iter_time = log_vars.pop('time') + data_time = log_vars.pop('data_time') + + total_time = time.time() - self.start_time + time_sec_avg = total_time / (current_iter - self.start_iter + 1) + eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + message += f'[eta: {eta_str}, ' + message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' + + # other items, especially losses + for k, v in log_vars.items(): + message += f'{k}: {v:.4e} ' + # tensorboard logger + if self.use_tb_logger: + if k.startswith('l_'): + self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) + else: + self.tb_logger.add_scalar(k, v, current_iter) + self.logger.info(message) + + +@master_only +def init_tb_logger(log_dir): + from torch.utils.tensorboard import SummaryWriter + tb_logger = SummaryWriter(log_dir=log_dir) + return tb_logger + + +@master_only +def init_wandb_logger(opt): + """We now only use wandb to sync tensorboard log.""" + import wandb + logger = logging.getLogger('basicsr') + + project = opt['logger']['wandb']['project'] + resume_id = opt['logger']['wandb'].get('resume_id') + if resume_id: + wandb_id = resume_id + resume = 'allow' + logger.warning(f'Resume wandb logger with id={wandb_id}.') + else: + wandb_id = wandb.util.generate_id() + resume = 'never' + + wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) + + logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') + + +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + elif log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +def get_env_info(): + """Get environment information. + Currently, only log the software version. 
+ """ + import torch + import torchvision + + from basicsr.version import __version__ + msg = r""" + ____ _ _____ ____ + / __ ) ____ _ _____ (_)_____/ ___/ / __ \ + / __ |/ __ `// ___// // ___/\__ \ / /_/ / + / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ + /_____/ \__,_//____//_/ \___//____//_/ |_| + ______ __ __ __ __ + / ____/____ ____ ____/ / / / __ __ _____ / /__ / / + / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / + / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ + \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) + """ + msg += ('\nVersion Information: ' + f'\n\tBasicSR: {__version__}' + f'\n\tPyTorch: {torch.__version__}' + f'\n\tTorchVision: {torchvision.__version__}') + return msg \ No newline at end of file diff --git a/repositories/codeformer/basicsr/utils/matlab_functions.py b/repositories/codeformer/basicsr/utils/matlab_functions.py new file mode 100644 index 000000000..c6ce1004a --- /dev/null +++ b/repositories/codeformer/basicsr/utils/matlab_functions.py @@ -0,0 +1,347 @@ +import math +import numpy as np +import torch + + +def cubic(x): + """cubic function used for calculate_weights_indices.""" + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5 * absx3 - 2.5 * absx2 + 1) * ( + (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) * + (absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + """Calculate weights and indices, used for imresize function. + + Args: + in_length (int): Input length. + out_length (int): Output length. + scale (float): Scale factor. + kernel_width (int): Kernel width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + """ + + if (scale < 1) and antialiasing: + # Use a modified kernel (larger kernel width) to simultaneously + # interpolate and antialias + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5 + scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + p = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( + out_length, p) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices + + # apply cubic kernel + if (scale < 1) and antialiasing: + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, p) + + # If a column in weights is all zero, get rid of it. only consider the + # first and last column. 
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, p - 2) + weights = weights.narrow(1, 1, p - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, p - 2) + weights = weights.narrow(1, 0, p - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +@torch.no_grad() +def imresize(img, scale, antialiasing=True): + """imresize function same as MATLAB. + + It now only supports bicubic. + The same scale applies for both height and width. + + Args: + img (Tensor | Numpy array): + Tensor: Input image with shape (c, h, w), [0, 1] range. + Numpy: Input image with shape (h, w, c), [0, 1] range. + scale (float): Scale factor. The same scale applies for both height + and width. + antialisaing (bool): Whether to apply anti-aliasing when downsampling. + Default: True. + + Returns: + Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round. + """ + if type(img).__module__ == np.__name__: # numpy type + numpy_type = True + img = torch.from_numpy(img.transpose(2, 0, 1)).float() + else: + numpy_type = False + + in_c, in_h, in_w = img.size() + out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale) + kernel_width = 4 + kernel = 'cubic' + + # get weights and indices + weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width, + antialiasing) + weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width, + antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) + img_aug.narrow(1, sym_len_hs, in_h).copy_(img) + + sym_patch = img[:, :sym_len_hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_he:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_c, out_h, in_w) + kernel_width = weights_h.size(1) + for i in range(out_h): + idx = int(indices_h[i][0]) + for j in range(in_c): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) + out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_we:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_c, out_h, out_w) + kernel_width = weights_w.size(1) + for i in range(out_w): + idx = int(indices_w[i][0]) + for j in range(in_c): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) + + if numpy_type: + out_2 = out_2.numpy().transpose(1, 2, 
0) + return out_2 + + +def rgb2ycbcr(img, y_only=False): + """Convert a RGB image to YCbCr image. + + This function produces the same results as Matlab's `rgb2ycbcr` function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 + else: + out_img = np.matmul( + img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def bgr2ycbcr(img, y_only=False): + """Convert a BGR image to YCbCr image. + + The bgr version of rgb2ycbcr. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + y_only (bool): Whether to only return Y channel. Default: False. + + Returns: + ndarray: The converted YCbCr image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) + if y_only: + out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 + else: + out_img = np.matmul( + img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2rgb(img): + """Convert a YCbCr image to RGB image. + + This function produces the same results as Matlab's ycbcr2rgb function. + It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted RGB image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def ycbcr2bgr(img): + """Convert a YCbCr image to BGR image. + + The bgr version of ycbcr2rgb. 
+ It implements the ITU-R BT.601 conversion for standard-definition + television. See more details in + https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. + + It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. + In OpenCV, it implements a JPEG conversion. See more details in + https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + ndarray: The converted BGR image. The output image has the same type + and range as input image. + """ + img_type = img.dtype + img = _convert_input_type_range(img) * 255 + out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], + [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 + out_img = _convert_output_type_range(out_img, img_type) + return out_img + + +def _convert_input_type_range(img): + """Convert the type and range of the input image. + + It converts the input image to np.float32 type and range of [0, 1]. + It is mainly used for pre-processing the input image in colorspace + convertion functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The input image. It accepts: + 1. np.uint8 type with range [0, 255]; + 2. np.float32 type with range [0, 1]. + + Returns: + (ndarray): The converted image with type of np.float32 and range of + [0, 1]. + """ + img_type = img.dtype + img = img.astype(np.float32) + if img_type == np.float32: + pass + elif img_type == np.uint8: + img /= 255. + else: + raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}') + return img + + +def _convert_output_type_range(img, dst_type): + """Convert the type and range of the image according to dst_type. + + It converts the image to desired type and range. If `dst_type` is np.uint8, + images will be converted to np.uint8 type with range [0, 255]. If + `dst_type` is np.float32, it converts the image to np.float32 type with + range [0, 1]. + It is mainly used for post-processing images in colorspace convertion + functions such as rgb2ycbcr and ycbcr2rgb. + + Args: + img (ndarray): The image to be converted with np.float32 type and + range [0, 255]. + dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it + converts the image to np.uint8 type with range [0, 255]. If + dst_type is np.float32, it converts the image to np.float32 type + with range [0, 1]. + + Returns: + (ndarray): The converted image with desired type and range. + """ + if dst_type not in (np.uint8, np.float32): + raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}') + if dst_type == np.uint8: + img = img.round() + else: + img /= 255. + return img.astype(dst_type) diff --git a/repositories/codeformer/basicsr/utils/misc.py b/repositories/codeformer/basicsr/utils/misc.py new file mode 100644 index 000000000..3b444ff3b --- /dev/null +++ b/repositories/codeformer/basicsr/utils/misc.py @@ -0,0 +1,134 @@ +import numpy as np +import os +import random +import time +import torch +from os import path as osp + +from .dist_util import master_only +from .logger import get_root_logger + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def mkdir_and_rename(path): + """mkdirs. 
If path exists, rename it with timestamp and create a new one. + + Args: + path (str): Folder path. + """ + if osp.exists(path): + new_name = path + '_archived_' + get_time_str() + print(f'Path already exists. Rename it to {new_name}', flush=True) + os.rename(path, new_name) + os.makedirs(path, exist_ok=True) + + +@master_only +def make_exp_dirs(opt): + """Make dirs for experiments.""" + path_opt = opt['path'].copy() + if opt['is_train']: + mkdir_and_rename(path_opt.pop('experiments_root')) + else: + mkdir_and_rename(path_opt.pop('results_root')) + for key, path in path_opt.items(): + if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key): + os.makedirs(path, exist_ok=True) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def check_resume(opt, resume_iter): + """Check resume states and pretrain_network paths. + + Args: + opt (dict): Options. + resume_iter (int): Resume iteration. + """ + logger = get_root_logger() + if opt['path']['resume_state']: + # get all the networks + networks = [key for key in opt.keys() if key.startswith('network_')] + flag_pretrain = False + for network in networks: + if opt['path'].get(f'pretrain_{network}') is not None: + flag_pretrain = True + if flag_pretrain: + logger.warning('pretrain_network path will be ignored during resuming.') + # set pretrained model paths + for network in networks: + name = f'pretrain_{network}' + basename = network.replace('network_', '') + if opt['path'].get('ignore_resume_networks') is None or (basename + not in opt['path']['ignore_resume_networks']): + opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') + logger.info(f"Set {name} to {opt['path'][name]}") + + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. 
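+    Example (illustrative):
+        >>> sizeof_fmt(1048576)
+        '1.0 MB'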
+ """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' diff --git a/repositories/codeformer/basicsr/utils/options.py b/repositories/codeformer/basicsr/utils/options.py new file mode 100644 index 000000000..db490e4aa --- /dev/null +++ b/repositories/codeformer/basicsr/utils/options.py @@ -0,0 +1,108 @@ +import yaml +import time +from collections import OrderedDict +from os import path as osp +from basicsr.utils.misc import get_time_str + +def ordered_yaml(): + """Support OrderedDict for yaml. + + Returns: + yaml Loader and Dumper. + """ + try: + from yaml import CDumper as Dumper + from yaml import CLoader as Loader + except ImportError: + from yaml import Dumper, Loader + + _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG + + def dict_representer(dumper, data): + return dumper.represent_dict(data.items()) + + def dict_constructor(loader, node): + return OrderedDict(loader.construct_pairs(node)) + + Dumper.add_representer(OrderedDict, dict_representer) + Loader.add_constructor(_mapping_tag, dict_constructor) + return Loader, Dumper + + +def parse(opt_path, root_path, is_train=True): + """Parse option file. + + Args: + opt_path (str): Option file path. + is_train (str): Indicate whether in training or not. Default: True. + + Returns: + (dict): Options. + """ + with open(opt_path, mode='r') as f: + Loader, _ = ordered_yaml() + opt = yaml.load(f, Loader=Loader) + + opt['is_train'] = is_train + + # opt['name'] = f"{get_time_str()}_{opt['name']}" + if opt['path'].get('resume_state', None): # Shangchen added + resume_state_path = opt['path'].get('resume_state') + opt['name'] = resume_state_path.split("/")[-3] + else: + opt['name'] = f"{get_time_str()}_{opt['name']}" + + + # datasets + for phase, dataset in opt['datasets'].items(): + # for several datasets, e.g., test_1, test_2 + phase = phase.split('_')[0] + dataset['phase'] = phase + if 'scale' in opt: + dataset['scale'] = opt['scale'] + if dataset.get('dataroot_gt') is not None: + dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) + if dataset.get('dataroot_lq') is not None: + dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) + + # paths + for key, val in opt['path'].items(): + if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): + opt['path'][key] = osp.expanduser(val) + + if is_train: + experiments_root = osp.join(root_path, 'experiments', opt['name']) + opt['path']['experiments_root'] = experiments_root + opt['path']['models'] = osp.join(experiments_root, 'models') + opt['path']['training_states'] = osp.join(experiments_root, 'training_states') + opt['path']['log'] = experiments_root + opt['path']['visualization'] = osp.join(experiments_root, 'visualization') + + else: # test + results_root = osp.join(root_path, 'results', opt['name']) + opt['path']['results_root'] = results_root + opt['path']['log'] = results_root + opt['path']['visualization'] = osp.join(results_root, 'visualization') + + return opt + + +def dict2str(opt, indent_level=1): + """dict to string for printing options. + + Args: + opt (dict): Option dict. + indent_level (int): Indent level. Default: 1. + + Return: + (str): Option string for printing. 
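A minimal sketch of how the helpers in this options module fit together (the option file path and its contents are hypothetical):

    import yaml
    from basicsr.utils.options import ordered_yaml, dict2str   # module added in this diff

    loader, _ = ordered_yaml()
    with open('options/example.yml', 'r') as f:                 # hypothetical option file
        opt = yaml.load(f, Loader=loader)                       # nested OrderedDicts, key order preserved
    print(dict2str(opt))                                        # indented, human-readable option tree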
+ """ + msg = '\n' + for k, v in opt.items(): + if isinstance(v, dict): + msg += ' ' * (indent_level * 2) + k + ':[' + msg += dict2str(v, indent_level + 1) + msg += ' ' * (indent_level * 2) + ']\n' + else: + msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' + return msg diff --git a/repositories/codeformer/basicsr/utils/realesrgan_utils.py b/repositories/codeformer/basicsr/utils/realesrgan_utils.py new file mode 100644 index 000000000..6b7a8b460 --- /dev/null +++ b/repositories/codeformer/basicsr/utils/realesrgan_utils.py @@ -0,0 +1,299 @@ +import cv2 +import math +import numpy as np +import os +import queue +import threading +import torch +from basicsr.utils.download_util import load_file_from_url +from torch.nn import functional as F + +# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + + +class RealESRGANer(): + """A helper class for upsampling images with RealESRGAN. + + Args: + scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. + model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). + model (nn.Module): The defined network. Default: None. + tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop + input images into tiles, and then process each of them. Finally, they will be merged into one image. + 0 denotes for do not use tile. Default: 0. + tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. + pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. + half (float): Whether to use half precision during inference. Default: False. + """ + + def __init__(self, + scale, + model_path, + model=None, + tile=0, + tile_pad=10, + pre_pad=10, + half=False, + device=None, + gpu_id=None): + self.scale = scale + self.tile_size = tile + self.tile_pad = tile_pad + self.pre_pad = pre_pad + self.mod_scale = None + self.half = half + + # initialize model + if gpu_id: + self.device = torch.device( + f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device + # if the model_path starts with https, it will first download models to the folder: realesrgan/weights + if model_path.startswith('https://'): + model_path = load_file_from_url( + url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None) + loadnet = torch.load(model_path, map_location=torch.device('cpu')) + # prefer to use params_ema + if 'params_ema' in loadnet: + keyname = 'params_ema' + else: + keyname = 'params' + model.load_state_dict(loadnet[keyname], strict=True) + model.eval() + self.model = model.to(self.device) + if self.half: + self.model = self.model.half() + + def pre_process(self, img): + """Pre-process, such as pre-pad and mod pad, so that the images can be divisible + """ + img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() + self.img = img.unsqueeze(0).to(self.device) + if self.half: + self.img = self.img.half() + + # pre_pad + if self.pre_pad != 0: + self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') + # mod pad for divisible borders + if self.scale == 2: + self.mod_scale = 2 + elif self.scale == 1: + self.mod_scale = 4 + if self.mod_scale is not None: + self.mod_pad_h, self.mod_pad_w = 0, 0 + _, _, h, w = self.img.size() + if (h % self.mod_scale != 0): + self.mod_pad_h = (self.mod_scale - h % 
self.mod_scale) + if (w % self.mod_scale != 0): + self.mod_pad_w = (self.mod_scale - w % self.mod_scale) + self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') + + def process(self): + # model inference + self.output = self.model(self.img) + + def tile_process(self): + """It will first crop input images to tiles, and then process each tile. + Finally, all the processed tiles are merged into one images. + + Modified from: https://github.com/ata4/esrgan-launcher + """ + batch, channel, height, width = self.img.shape + output_height = height * self.scale + output_width = width * self.scale + output_shape = (batch, channel, output_height, output_width) + + # start with black image + self.output = self.img.new_zeros(output_shape) + tiles_x = math.ceil(width / self.tile_size) + tiles_y = math.ceil(height / self.tile_size) + + # loop over all tiles + for y in range(tiles_y): + for x in range(tiles_x): + # extract tile from input image + ofs_x = x * self.tile_size + ofs_y = y * self.tile_size + # input tile area on total image + input_start_x = ofs_x + input_end_x = min(ofs_x + self.tile_size, width) + input_start_y = ofs_y + input_end_y = min(ofs_y + self.tile_size, height) + + # input tile area on total image with padding + input_start_x_pad = max(input_start_x - self.tile_pad, 0) + input_end_x_pad = min(input_end_x + self.tile_pad, width) + input_start_y_pad = max(input_start_y - self.tile_pad, 0) + input_end_y_pad = min(input_end_y + self.tile_pad, height) + + # input tile dimensions + input_tile_width = input_end_x - input_start_x + input_tile_height = input_end_y - input_start_y + tile_idx = y * tiles_x + x + 1 + input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] + + # upscale tile + try: + with torch.no_grad(): + output_tile = self.model(input_tile) + except RuntimeError as error: + print('Error', error) + # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') + + # output tile area on total image + output_start_x = input_start_x * self.scale + output_end_x = input_end_x * self.scale + output_start_y = input_start_y * self.scale + output_end_y = input_end_y * self.scale + + # output tile area without padding + output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale + output_end_x_tile = output_start_x_tile + input_tile_width * self.scale + output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale + output_end_y_tile = output_start_y_tile + input_tile_height * self.scale + + # put tile into output image + self.output[:, :, output_start_y:output_end_y, + output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, + output_start_x_tile:output_end_x_tile] + + def post_process(self): + # remove extra pad + if self.mod_scale is not None: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] + # remove prepad + if self.pre_pad != 0: + _, _, h, w = self.output.size() + self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] + return self.output + + @torch.no_grad() + def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): + h_input, w_input = img.shape[0:2] + # img: numpy + img = img.astype(np.float32) + if np.max(img) > 256: # 16-bit image + max_range = 65535 + print('\tInput is a 16-bit image') + else: + max_range = 255 + img = img / max_range + if len(img.shape) == 2: # gray image + img_mode = 'L' + img = cv2.cvtColor(img, 
cv2.COLOR_GRAY2RGB) + elif img.shape[2] == 4: # RGBA image with alpha channel + img_mode = 'RGBA' + alpha = img[:, :, 3] + img = img[:, :, 0:3] + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + if alpha_upsampler == 'realesrgan': + alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) + else: + img_mode = 'RGB' + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # ------------------- process image (without the alpha channel) ------------------- # + try: + with torch.no_grad(): + self.pre_process(img) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_img_t = self.post_process() + output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) + if img_mode == 'L': + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) + del output_img_t + torch.cuda.empty_cache() + except RuntimeError as error: + print(f"Failed inference for RealESRGAN: {error}") + + # ------------------- process the alpha channel if necessary ------------------- # + if img_mode == 'RGBA': + if alpha_upsampler == 'realesrgan': + self.pre_process(alpha) + if self.tile_size > 0: + self.tile_process() + else: + self.process() + output_alpha = self.post_process() + output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() + output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) + output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) + else: # use the cv2 resize for alpha channel + h, w = alpha.shape[0:2] + output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) + + # merge the alpha channel + output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) + output_img[:, :, 3] = output_alpha + + # ------------------------------ return ------------------------------ # + if max_range == 65535: # 16-bit image + output = (output_img * 65535.0).round().astype(np.uint16) + else: + output = (output_img * 255.0).round().astype(np.uint8) + + if outscale is not None and outscale != float(self.scale): + output = cv2.resize( + output, ( + int(w_input * outscale), + int(h_input * outscale), + ), interpolation=cv2.INTER_LANCZOS4) + + return output, img_mode + + +class PrefetchReader(threading.Thread): + """Prefetch images. + + Args: + img_list (list[str]): A image list of image paths to be read. + num_prefetch_queue (int): Number of prefetch queue. 
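A usage sketch for the RealESRGANer helper defined above; the RRDBNet import and the model URL follow the upstream Real-ESRGAN setup and are assumptions, not something this diff pins down:

    import cv2
    from basicsr.archs.rrdbnet_arch import RRDBNet                    # assumed to ship with the vendored basicsr
    from basicsr.utils.realesrgan_utils import RealESRGANer

    model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
    upsampler = RealESRGANer(
        scale=4,
        model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
        model=model,
        tile=400,        # crop into 400x400 tiles so large inputs fit in GPU memory
        tile_pad=10,
        half=True)

    img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)   # hypothetical input
    output, img_mode = upsampler.enhance(img, outscale=4)
    cv2.imwrite('output.png', output)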
+ """ + + def __init__(self, img_list, num_prefetch_queue): + super().__init__() + self.que = queue.Queue(num_prefetch_queue) + self.img_list = img_list + + def run(self): + for img_path in self.img_list: + img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) + self.que.put(img) + + self.que.put(None) + + def __next__(self): + next_item = self.que.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class IOConsumer(threading.Thread): + + def __init__(self, opt, que, qid): + super().__init__() + self._queue = que + self.qid = qid + self.opt = opt + + def run(self): + while True: + msg = self._queue.get() + if isinstance(msg, str) and msg == 'quit': + break + + output = msg['output'] + save_path = msg['save_path'] + cv2.imwrite(save_path, output) + print(f'IO worker {self.qid} is done.') \ No newline at end of file diff --git a/repositories/codeformer/basicsr/utils/registry.py b/repositories/codeformer/basicsr/utils/registry.py new file mode 100644 index 000000000..655753b3b --- /dev/null +++ b/repositories/codeformer/basicsr/utils/registry.py @@ -0,0 +1,82 @@ +# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 + + +class Registry(): + """ + The registry that provides name -> object mapping, to support third-party + users' custom modules. + + To create a registry (e.g. a backbone registry): + + .. code-block:: python + + BACKBONE_REGISTRY = Registry('BACKBONE') + + To register an object: + + .. code-block:: python + + @BACKBONE_REGISTRY.register() + class MyBackbone(): + ... + + Or: + + .. code-block:: python + + BACKBONE_REGISTRY.register(MyBackbone) + """ + + def __init__(self, name): + """ + Args: + name (str): the name of this registry + """ + self._name = name + self._obj_map = {} + + def _do_register(self, name, obj): + assert (name not in self._obj_map), (f"An object named '{name}' was already registered " + f"in '{self._name}' registry!") + self._obj_map[name] = obj + + def register(self, obj=None): + """ + Register the given object under the the name `obj.__name__`. + Can be used as either a decorator or not. + See docstring of this class for usage. 
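A short self-contained illustration of the registry pattern described here (the registry name and the registered class are made up for the example):

    from basicsr.utils.registry import Registry   # module added in this diff

    EXAMPLE_REGISTRY = Registry('example')

    @EXAMPLE_REGISTRY.register()                  # decorator form: registers under the class name
    class TinyArch:
        pass

    EXAMPLE_REGISTRY.register(dict)               # plain-call form also works

    assert 'TinyArch' in EXAMPLE_REGISTRY
    model = EXAMPLE_REGISTRY.get('TinyArch')()    # raises KeyError for unknown names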
+ """ + if obj is None: + # used as a decorator + def deco(func_or_class): + name = func_or_class.__name__ + self._do_register(name, func_or_class) + return func_or_class + + return deco + + # used as a function call + name = obj.__name__ + self._do_register(name, obj) + + def get(self, name): + ret = self._obj_map.get(name) + if ret is None: + raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") + return ret + + def __contains__(self, name): + return name in self._obj_map + + def __iter__(self): + return iter(self._obj_map.items()) + + def keys(self): + return self._obj_map.keys() + + +DATASET_REGISTRY = Registry('dataset') +ARCH_REGISTRY = Registry('arch') +MODEL_REGISTRY = Registry('model') +LOSS_REGISTRY = Registry('loss') +METRIC_REGISTRY = Registry('metric') diff --git a/repositories/codeformer/basicsr/utils/video_util.py b/repositories/codeformer/basicsr/utils/video_util.py new file mode 100644 index 000000000..20a2ff14c --- /dev/null +++ b/repositories/codeformer/basicsr/utils/video_util.py @@ -0,0 +1,125 @@ +''' +The code is modified from the Real-ESRGAN: +https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py + +''' +import cv2 +import sys +import numpy as np + +try: + import ffmpeg +except ImportError: + import pip + pip.main(['install', '--user', 'ffmpeg-python']) + import ffmpeg + +def get_video_meta_info(video_path): + ret = {} + probe = ffmpeg.probe(video_path) + video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video'] + has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams']) + ret['width'] = video_streams[0]['width'] + ret['height'] = video_streams[0]['height'] + ret['fps'] = eval(video_streams[0]['avg_frame_rate']) + ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None + ret['nb_frames'] = int(video_streams[0]['nb_frames']) + return ret + +class VideoReader: + def __init__(self, video_path): + self.paths = [] # for image&folder type + self.audio = None + try: + self.stream_reader = ( + ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24', + loglevel='error').run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + except FileNotFoundError: + print('Please install ffmpeg (not ffmpeg-python) by running\n', + '\t$ conda install -c conda-forge ffmpeg') + sys.exit(0) + + meta = get_video_meta_info(video_path) + self.width = meta['width'] + self.height = meta['height'] + self.input_fps = meta['fps'] + self.audio = meta['audio'] + self.nb_frames = meta['nb_frames'] + + self.idx = 0 + + def get_resolution(self): + return self.height, self.width + + def get_fps(self): + if self.input_fps is not None: + return self.input_fps + return 24 + + def get_audio(self): + return self.audio + + def __len__(self): + return self.nb_frames + + def get_frame_from_stream(self): + img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel + if not img_bytes: + return None + img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3]) + return img + + def get_frame_from_list(self): + if self.idx >= self.nb_frames: + return None + img = cv2.imread(self.paths[self.idx]) + self.idx += 1 + return img + + def get_frame(self): + return self.get_frame_from_stream() + + + def close(self): + self.stream_reader.stdin.close() + self.stream_reader.wait() + + +class VideoWriter: + def __init__(self, video_save_path, height, width, fps, audio): + if height > 2160: + print('You are generating video that is 
larger than 4K, which will be very slow due to IO speed.', + 'We highly recommend to decrease the outscale(aka, -s).') + if audio is not None: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + audio, + video_save_path, + pix_fmt='yuv420p', + vcodec='libx264', + loglevel='error', + acodec='copy').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + else: + self.stream_writer = ( + ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}', + framerate=fps).output( + video_save_path, pix_fmt='yuv420p', vcodec='libx264', + loglevel='error').overwrite_output().run_async( + pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg')) + + def write_frame(self, frame): + try: + frame = frame.astype(np.uint8).tobytes() + self.stream_writer.stdin.write(frame) + except BrokenPipeError: + print('Please re-install ffmpeg and libx264 by running\n', + '\t$ conda install -c conda-forge ffmpeg\n', + '\t$ conda install -c conda-forge x264') + sys.exit(0) + + def close(self): + self.stream_writer.stdin.close() + self.stream_writer.wait() \ No newline at end of file diff --git a/repositories/codeformer/facelib/detection/__init__.py b/repositories/codeformer/facelib/detection/__init__.py new file mode 100644 index 000000000..5d1f8fc21 --- /dev/null +++ b/repositories/codeformer/facelib/detection/__init__.py @@ -0,0 +1,100 @@ +import os +import torch +from torch import nn +from copy import deepcopy + +from facelib.utils import load_file_from_url +from facelib.utils import download_pretrained_models +from facelib.detection.yolov5face.models.common import Conv + +from .retinaface.retinaface import RetinaFace +from .yolov5face.face_detector import YoloDetector + + +def init_detection_model(model_name, half=False, device='cuda'): + if 'retinaface' in model_name: + model = init_retinaface_model(model_name, half, device) + elif 'YOLOv5' in model_name: + model = init_yolov5face_model(model_name, device) + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + return model + + +def init_retinaface_model(model_name, half=False, device='cuda'): + if model_name == 'retinaface_resnet50': + model = RetinaFace(network_name='resnet50', half=half) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth' + elif model_name == 'retinaface_mobile0.25': + model = RetinaFace(network_name='mobile0.25', half=half) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + # remove unnecessary 'module.' 
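# (illustrative note) checkpoints saved from an nn.DataParallel wrapper prefix every
# state-dict key with 'module.'; the loop below rewrites such keys so the weights load
# into a plain RetinaFace instance. An equivalent one-liner would be:
#   load_net = {k[7:] if k.startswith('module.') else k: v for k, v in load_net.items()}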
+ for k, v in deepcopy(load_net).items(): + if k.startswith('module.'): + load_net[k[7:]] = v + load_net.pop(k) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + + return model + + +def init_yolov5face_model(model_name, device='cuda'): + if model_name == 'YOLOv5l': + model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth' + elif model_name == 'YOLOv5n': + model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.detector.load_state_dict(load_net, strict=True) + model.detector.eval() + model.detector = model.detector.to(device).float() + + for m in model.detector.modules(): + if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: + m.inplace = True # pytorch 1.7.0 compatibility + elif isinstance(m, Conv): + m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + + return model + + +# Download from Google Drive +# def init_yolov5face_model(model_name, device='cuda'): +# if model_name == 'YOLOv5l': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device) +# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'} +# elif model_name == 'YOLOv5n': +# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device) +# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'} +# else: +# raise NotImplementedError(f'{model_name} is not implemented.') + +# model_path = os.path.join('weights/facelib', list(f_id.keys())[0]) +# if not os.path.exists(model_path): +# download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib') + +# load_net = torch.load(model_path, map_location=lambda storage, loc: storage) +# model.detector.load_state_dict(load_net, strict=True) +# model.detector.eval() +# model.detector = model.detector.to(device).float() + +# for m in model.detector.modules(): +# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]: +# m.inplace = True # pytorch 1.7.0 compatibility +# elif isinstance(m, Conv): +# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility + +# return model \ No newline at end of file diff --git a/repositories/codeformer/facelib/detection/align_trans.py b/repositories/codeformer/facelib/detection/align_trans.py new file mode 100644 index 000000000..07f1eb365 --- /dev/null +++ b/repositories/codeformer/facelib/detection/align_trans.py @@ -0,0 +1,219 @@ +import cv2 +import numpy as np + +from .matlab_cp2tform import get_similarity_transform_for_cv2 + +# reference facial points, a list of coordinates (x,y) +REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278], + [33.54930115, 92.3655014], [62.72990036, 92.20410156]] + +DEFAULT_CROP_SIZE = (96, 112) + + +class FaceWarpException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), 
default_square=False): + """ + Function: + ---------- + get reference 5 key points according to crop settings: + 0. Set default crop_size: + if default_square: + crop_size = (112, 112) + else: + crop_size = (96, 112) + 1. Pad the crop_size by inner_padding_factor in each side; + 2. Resize crop_size into (output_size - outer_padding*2), + pad into output_size with outer_padding; + 3. Output reference_5point; + Parameters: + ---------- + @output_size: (w, h) or None + size of aligned face image + @inner_padding_factor: (w_factor, h_factor) + padding factor for inner (w, h) + @outer_padding: (w_pad, h_pad) + each row is a pair of coordinates (x, y) + @default_square: True or False + if True: + default crop_size = (112, 112) + else: + default crop_size = (96, 112); + !!! make sure, if output_size is not None: + (output_size - outer_padding) + = some_scale * (default crop_size * (1.0 + + inner_padding_factor)) + Returns: + ---------- + @reference_5point: 5x2 np.array + each row is a pair of transformed coordinates (x, y) + """ + + tmp_5pts = np.array(REFERENCE_FACIAL_POINTS) + tmp_crop_size = np.array(DEFAULT_CROP_SIZE) + + # 0) make the inner region a square + if default_square: + size_diff = max(tmp_crop_size) - tmp_crop_size + tmp_5pts += size_diff / 2 + tmp_crop_size += size_diff + + if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]): + + return tmp_5pts + + if (inner_padding_factor == 0 and outer_padding == (0, 0)): + if output_size is None: + return tmp_5pts + else: + raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size)) + + # check output size + if not (0 <= inner_padding_factor <= 1.0): + raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)') + + if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None): + output_size = tmp_crop_size * \ + (1 + inner_padding_factor * 2).astype(np.int32) + output_size += np.array(outer_padding) + if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]): + raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])') + + # 1) pad the inner region according inner_padding_factor + if inner_padding_factor > 0: + size_diff = tmp_crop_size * inner_padding_factor * 2 + tmp_5pts += size_diff / 2 + tmp_crop_size += np.round(size_diff).astype(np.int32) + + # 2) resize the padded inner region + size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2 + + if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]: + raise FaceWarpException('Must have (output_size - outer_padding)' + '= some_scale * (crop_size * (1.0 + inner_padding_factor)') + + scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0] + tmp_5pts = tmp_5pts * scale_factor + # size_diff = tmp_crop_size * (scale_factor - min(scale_factor)) + # tmp_5pts = tmp_5pts + size_diff / 2 + tmp_crop_size = size_bf_outer_pad + + # 3) add outer_padding to make output_size + reference_5point = tmp_5pts + np.array(outer_padding) + tmp_crop_size = output_size + + return reference_5point + + +def get_affine_transform_matrix(src_pts, dst_pts): + """ + Function: + ---------- + get affine transform matrix 'tfm' from src_pts to dst_pts + Parameters: + ---------- + @src_pts: Kx2 np.array + source points matrix, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points matrix, each row is a pair of coordinates (x, y) + Returns: 
+ ---------- + @tfm: 2x3 np.array + transform matrix from src_pts to dst_pts + """ + + tfm = np.float32([[1, 0, 0], [0, 1, 0]]) + n_pts = src_pts.shape[0] + ones = np.ones((n_pts, 1), src_pts.dtype) + src_pts_ = np.hstack([src_pts, ones]) + dst_pts_ = np.hstack([dst_pts, ones]) + + A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_) + + if rank == 3: + tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]]) + elif rank == 2: + tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]]) + + return tfm + + +def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='smilarity'): + """ + Function: + ---------- + apply affine transform 'trans' to uv + Parameters: + ---------- + @src_img: 3x3 np.array + input image + @facial_pts: could be + 1)a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + @reference_pts: could be + 1) a list of K coordinates (x,y) + or + 2) Kx2 or 2xK np.array + each row or col is a pair of coordinates (x, y) + or + 3) None + if None, use default reference facial points + @crop_size: (w, h) + output face image size + @align_type: transform type, could be one of + 1) 'similarity': use similarity transform + 2) 'cv2_affine': use the first 3 points to do affine transform, + by calling cv2.getAffineTransform() + 3) 'affine': use all points to do affine transform + Returns: + ---------- + @face_img: output face image with size (w, h) = @crop_size + """ + + if reference_pts is None: + if crop_size[0] == 96 and crop_size[1] == 112: + reference_pts = REFERENCE_FACIAL_POINTS + else: + default_square = False + inner_padding_factor = 0 + outer_padding = (0, 0) + output_size = crop_size + + reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding, + default_square) + + ref_pts = np.float32(reference_pts) + ref_pts_shp = ref_pts.shape + if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2: + raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2') + + if ref_pts_shp[0] == 2: + ref_pts = ref_pts.T + + src_pts = np.float32(facial_pts) + src_pts_shp = src_pts.shape + if max(src_pts_shp) < 3 or min(src_pts_shp) != 2: + raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2') + + if src_pts_shp[0] == 2: + src_pts = src_pts.T + + if src_pts.shape != ref_pts.shape: + raise FaceWarpException('facial_pts and reference_pts must have the same shape') + + if align_type == 'cv2_affine': + tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3]) + elif align_type == 'affine': + tfm = get_affine_transform_matrix(src_pts, ref_pts) + else: + tfm = get_similarity_transform_for_cv2(src_pts, ref_pts) + + face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1])) + + return face_img diff --git a/repositories/codeformer/facelib/detection/matlab_cp2tform.py b/repositories/codeformer/facelib/detection/matlab_cp2tform.py new file mode 100644 index 000000000..b2a8b54a9 --- /dev/null +++ b/repositories/codeformer/facelib/detection/matlab_cp2tform.py @@ -0,0 +1,317 @@ +import numpy as np +from numpy.linalg import inv, lstsq +from numpy.linalg import matrix_rank as rank +from numpy.linalg import norm + + +class MatlabCp2tormException(Exception): + + def __str__(self): + return 'In File {}:{}'.format(__file__, super.__str__(self)) + + +def tformfwd(trans, uv): + """ + Function: + ---------- + apply affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 
np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of transformed coordinates (x, y) + """ + uv = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy = np.dot(uv, trans) + xy = xy[:, 0:-1] + return xy + + +def tforminv(trans, uv): + """ + Function: + ---------- + apply the inverse of affine transform 'trans' to uv + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix + @uv: Kx2 np.array + each row is a pair of coordinates (x, y) + + Returns: + ---------- + @xy: Kx2 np.array + each row is a pair of inverse-transformed coordinates (x, y) + """ + Tinv = inv(trans) + xy = tformfwd(Tinv, uv) + return xy + + +def findNonreflectiveSimilarity(uv, xy, options=None): + options = {'K': 2} + + K = options['K'] + M = xy.shape[0] + x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + + tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1)))) + tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1)))) + X = np.vstack((tmp1, tmp2)) + + u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector + v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector + U = np.vstack((u, v)) + + # We know that X * r = U + if rank(X) >= 2 * K: + r, _, _, _ = lstsq(X, U, rcond=-1) + r = np.squeeze(r) + else: + raise Exception('cp2tform:twoUniquePointsReq') + sc = r[0] + ss = r[1] + tx = r[2] + ty = r[3] + + Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]]) + T = inv(Tinv) + T[:, 2] = np.array([0, 0, 1]) + + return T, Tinv + + +def findSimilarity(uv, xy, options=None): + options = {'K': 2} + + # uv = np.array(uv) + # xy = np.array(xy) + + # Solve for trans1 + trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options) + + # Solve for trans2 + + # manually reflect the xy data across the Y-axis + xyR = xy + xyR[:, 0] = -1 * xyR[:, 0] + + trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options) + + # manually reflect the tform to undo the reflection done on xyR + TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + trans2 = np.dot(trans2r, TreflectY) + + # Figure out if trans1 or trans2 is better + xy1 = tformfwd(trans1, uv) + norm1 = norm(xy1 - xy) + + xy2 = tformfwd(trans2, uv) + norm2 = norm(xy2 - xy) + + if norm1 <= norm2: + return trans1, trans1_inv + else: + trans2_inv = inv(trans2) + return trans2, trans2_inv + + +def get_similarity_transform(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'trans': + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y, 1] = [u, v, 1] * trans + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + @reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + trans_inv: 3x3 np.array + inverse of trans, transform matrix from xy to uv + """ + + if reflective: + trans, trans_inv = findSimilarity(src_pts, dst_pts) + else: + trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts) + + return trans, trans_inv + + +def cvt_tform_mat_for_cv2(trans): + """ + Function: + ---------- + Convert Transform Matrix 'trans' into 'cv2_trans' which could be + directly used by 
cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @trans: 3x3 np.array + transform matrix from uv to xy + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + cv2_trans = trans[:, 0:2].T + + return cv2_trans + + +def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True): + """ + Function: + ---------- + Find Similarity Transform Matrix 'cv2_trans' which could be + directly used by cv2.warpAffine(): + u = src_pts[:, 0] + v = src_pts[:, 1] + x = dst_pts[:, 0] + y = dst_pts[:, 1] + [x, y].T = cv_trans * [u, v, 1].T + + Parameters: + ---------- + @src_pts: Kx2 np.array + source points, each row is a pair of coordinates (x, y) + @dst_pts: Kx2 np.array + destination points, each row is a pair of transformed + coordinates (x, y) + reflective: True or False + if True: + use reflective similarity transform + else: + use non-reflective similarity transform + + Returns: + ---------- + @cv2_trans: 2x3 np.array + transform matrix from src_pts to dst_pts, could be directly used + for cv2.warpAffine() + """ + trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective) + cv2_trans = cvt_tform_mat_for_cv2(trans) + + return cv2_trans + + +if __name__ == '__main__': + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + # In Matlab, run: + # + # uv = [u'; v']; + # xy = [x'; y']; + # tform_sim=cp2tform(uv,xy,'similarity'); + # + # trans = tform_sim.tdata.T + # ans = + # -0.0764 -1.6190 0 + # 1.6190 -0.0764 0 + # -3.2156 0.0290 1.0000 + # trans_inv = tform_sim.tdata.Tinv + # ans = + # + # -0.0291 0.6163 0 + # -0.6163 -0.0291 0 + # -0.0756 1.9826 1.0000 + # xy_m=tformfwd(tform_sim, u,v) + # + # xy_m = + # + # -3.2156 0.0290 + # 1.1833 -9.9143 + # 5.0323 2.8853 + # uv_m=tforminv(tform_sim, x,y) + # + # uv_m = + # + # 0.5698 1.3953 + # 6.0872 2.2733 + # -2.6570 4.3314 + """ + u = [0, 6, -2] + v = [0, 3, 5] + x = [-1, 0, 4] + y = [-1, -10, 4] + + uv = np.array((u, v)).T + xy = np.array((x, y)).T + + print('\n--->uv:') + print(uv) + print('\n--->xy:') + print(xy) + + trans, trans_inv = get_similarity_transform(uv, xy) + + print('\n--->trans matrix:') + print(trans) + + print('\n--->trans_inv matrix:') + print(trans_inv) + + print('\n---> apply transform to uv') + print('\nxy_m = uv_augmented * trans') + uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1)))) + xy_m = np.dot(uv_aug, trans) + print(xy_m) + + print('\nxy_m = tformfwd(trans, uv)') + xy_m = tformfwd(trans, uv) + print(xy_m) + + print('\n---> apply inverse transform to xy') + print('\nuv_m = xy_augmented * trans_inv') + xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1)))) + uv_m = np.dot(xy_aug, trans_inv) + print(uv_m) + + print('\nuv_m = tformfwd(trans_inv, xy)') + uv_m = tformfwd(trans_inv, xy) + print(uv_m) + + uv_m = tforminv(trans, xy) + print('\nuv_m = tforminv(trans, xy)') + print(uv_m) diff --git a/repositories/codeformer/facelib/detection/retinaface/retinaface.py b/repositories/codeformer/facelib/detection/retinaface/retinaface.py new file mode 100644 index 000000000..02593556d --- /dev/null +++ b/repositories/codeformer/facelib/detection/retinaface/retinaface.py @@ -0,0 +1,370 @@ +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from PIL import Image +from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter 
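To make the intent of matlab_cp2tform.py above concrete, a minimal alignment sketch (the landmark coordinates and image path are made up; this mirrors what warp_and_crop_face in align_trans.py does on its default similarity branch):

    import cv2
    import numpy as np
    from facelib.detection.align_trans import REFERENCE_FACIAL_POINTS
    from facelib.detection.matlab_cp2tform import get_similarity_transform_for_cv2

    src_img = cv2.imread('face.jpg')                      # hypothetical input image, BGR
    src_pts = np.float32([[38, 52], [74, 50], [56, 72],
                          [42, 92], [70, 90]])            # detected 5-point landmarks (made up)
    ref_pts = np.float32(REFERENCE_FACIAL_POINTS)         # canonical template for a 96x112 crop

    tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)   # 2x3 affine matrix
    face_chip = cv2.warpAffine(src_img, tfm, (96, 112))        # aligned 96x112 face crop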
+ +from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face +from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head +from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm, + py_cpu_nms) + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +def generate_config(network_name): + + cfg_mnet = { + 'name': 'mobilenet0.25', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 32, + 'ngpu': 1, + 'epoch': 250, + 'decay1': 190, + 'decay2': 220, + 'image_size': 640, + 'return_layers': { + 'stage1': 1, + 'stage2': 2, + 'stage3': 3 + }, + 'in_channel': 32, + 'out_channel': 64 + } + + cfg_re50 = { + 'name': 'Resnet50', + 'min_sizes': [[16, 32], [64, 128], [256, 512]], + 'steps': [8, 16, 32], + 'variance': [0.1, 0.2], + 'clip': False, + 'loc_weight': 2.0, + 'gpu_train': True, + 'batch_size': 24, + 'ngpu': 4, + 'epoch': 100, + 'decay1': 70, + 'decay2': 90, + 'image_size': 840, + 'return_layers': { + 'layer2': 1, + 'layer3': 2, + 'layer4': 3 + }, + 'in_channel': 256, + 'out_channel': 256 + } + + if network_name == 'mobile0.25': + return cfg_mnet + elif network_name == 'resnet50': + return cfg_re50 + else: + raise NotImplementedError(f'network_name={network_name}') + + +class RetinaFace(nn.Module): + + def __init__(self, network_name='resnet50', half=False, phase='test'): + super(RetinaFace, self).__init__() + self.half_inference = half + cfg = generate_config(network_name) + self.backbone = cfg['name'] + + self.model_name = f'retinaface_{network_name}' + self.cfg = cfg + self.phase = phase + self.target_size, self.max_size = 1600, 2150 + self.resize, self.scale, self.scale1 = 1., None, None + self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device) + self.reference = get_reference_facial_points(default_square=True) + # Build network. 
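# (illustrative note) cfg['return_layers'] maps backbone sub-module names to the three
# pyramid levels consumed by the FPN below, e.g. for the Resnet50 config:
#   {'layer2': 1, 'layer3': 2, 'layer4': 3}
# IntermediateLayerGetter then returns an OrderedDict holding exactly those three feature maps.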
+ backbone = None + if cfg['name'] == 'mobilenet0.25': + backbone = MobileNetV1() + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + elif cfg['name'] == 'Resnet50': + import torchvision.models as models + backbone = models.resnet50(pretrained=False) + self.body = IntermediateLayerGetter(backbone, cfg['return_layers']) + + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list, out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) + + self.to(device) + self.eval() + if self.half_inference: + self.half() + + def forward(self, inputs): + out = self.body(inputs) + + if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50': + out = list(out.values()) + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) + classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1) + tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)] + ldm_regressions = (torch.cat(tmp, dim=1)) + + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output + + def __detect_faces(self, inputs): + # get scale + height, width = inputs.shape[2:] + self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device) + tmp = [width, height, width, height, width, height, width, height, width, height] + self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device) + + # forawrd + inputs = inputs.to(device) + if self.half_inference: + inputs = inputs.half() + loc, conf, landmarks = self(inputs) + + # get priorbox + priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:]) + priors = priorbox.forward().to(device) + + return loc, conf, landmarks, priors + + # single image detection + def transform(self, image, use_origin_size): + # convert to opencv format + if isinstance(image, Image.Image): + image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) + image = image.astype(np.float32) + + # testing scale + im_size_min = np.min(image.shape[0:2]) + im_size_max = np.max(image.shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + + # convert to torch.tensor format + # image -= (104, 117, 123) + image = image.transpose(2, 0, 1) + image = torch.from_numpy(image).unsqueeze(0) + + return image, resize + + def detect_faces( + self, + image, + conf_threshold=0.8, + nms_threshold=0.4, + use_origin_size=True, + ): + """ + Params: + imgs: BGR 
image + """ + image, self.resize = self.transform(image, use_origin_size) + image = image.to(device) + if self.half_inference: + image = image.half() + image = image - self.mean_tensor + + loc, conf, landmarks, priors = self.__detect_faces(image) + + boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance']) + boxes = boxes * self.scale / self.resize + boxes = boxes.cpu().numpy() + + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + + landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance']) + landmarks = landmarks * self.scale1 / self.resize + landmarks = landmarks.cpu().numpy() + + # ignore low scores + inds = np.where(scores > conf_threshold)[0] + boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds] + + # sort + order = scores.argsort()[::-1] + boxes, landmarks, scores = boxes[order], landmarks[order], scores[order] + + # do NMS + bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep] + # self.t['forward_pass'].toc() + # print(self.t['forward_pass'].average_time) + # import sys + # sys.stdout.flush() + return np.concatenate((bounding_boxes, landmarks), axis=1) + + def __align_multi(self, image, boxes, landmarks, limit=None): + + if len(boxes) < 1: + return [], [] + + if limit: + boxes = boxes[:limit] + landmarks = landmarks[:limit] + + faces = [] + for landmark in landmarks: + facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)] + + warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112)) + faces.append(warped_face) + + return np.concatenate((boxes, landmarks), axis=1), faces + + def align_multi(self, img, conf_threshold=0.8, limit=None): + + rlt = self.detect_faces(img, conf_threshold=conf_threshold) + boxes, landmarks = rlt[:, 0:5], rlt[:, 5:] + + return self.__align_multi(img, boxes, landmarks, limit) + + # batched detection + def batched_transform(self, frames, use_origin_size): + """ + Arguments: + frames: a list of PIL.Image, or torch.Tensor(shape=[n, h, w, c], + type=np.float32, BGR format). + use_origin_size: whether to use origin size. 
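A usage sketch for the single-image path above (init_detection_model is the factory added earlier in this diff; the image path is hypothetical):

    import cv2
    from facelib.detection import init_detection_model

    detector = init_detection_model('retinaface_resnet50', half=False, device='cuda')
    img = cv2.imread('photo.jpg')                             # BGR, uint8
    dets = detector.detect_faces(img, conf_threshold=0.8)     # (n, 15): x1, y1, x2, y2, score, then 10 landmark coords
    boxes_landms, faces = detector.align_multi(img, conf_threshold=0.8, limit=1)   # faces: 112x112 aligned crops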
+ """ + from_PIL = True if isinstance(frames[0], Image.Image) else False + + # convert to opencv format + if from_PIL: + frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames] + frames = np.asarray(frames, dtype=np.float32) + + # testing scale + im_size_min = np.min(frames[0].shape[0:2]) + im_size_max = np.max(frames[0].shape[0:2]) + resize = float(self.target_size) / float(im_size_min) + + # prevent bigger axis from being more than max_size + if np.round(resize * im_size_max) > self.max_size: + resize = float(self.max_size) / float(im_size_max) + resize = 1 if use_origin_size else resize + + # resize + if resize != 1: + if not from_PIL: + frames = F.interpolate(frames, scale_factor=resize) + else: + frames = [ + cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR) + for frame in frames + ] + + # convert to torch.tensor format + if not from_PIL: + frames = frames.transpose(1, 2).transpose(1, 3).contiguous() + else: + frames = frames.transpose((0, 3, 1, 2)) + frames = torch.from_numpy(frames) + + return frames, resize + + def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True): + """ + Arguments: + frames: a list of PIL.Image, or np.array(shape=[n, h, w, c], + type=np.uint8, BGR format). + conf_threshold: confidence threshold. + nms_threshold: nms threshold. + use_origin_size: whether to use origin size. + Returns: + final_bounding_boxes: list of np.array ([n_boxes, 5], + type=np.float32). + final_landmarks: list of np.array ([n_boxes, 10], type=np.float32). + """ + # self.t['forward_pass'].tic() + frames, self.resize = self.batched_transform(frames, use_origin_size) + frames = frames.to(device) + frames = frames - self.mean_tensor + + b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames) + + final_bounding_boxes, final_landmarks = [], [] + + # decode + priors = priors.unsqueeze(0) + b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize + b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize + b_conf = b_conf[:, :, 1] + + # index for selection + b_indice = b_conf > conf_threshold + + # concat + b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float() + + for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice): + + # ignore low scores + pred, landm = pred[inds, :], landm[inds, :] + if pred.shape[0] == 0: + final_bounding_boxes.append(np.array([], dtype=np.float32)) + final_landmarks.append(np.array([], dtype=np.float32)) + continue + + # sort + # order = score.argsort(descending=True) + # box, landm, score = box[order], landm[order], score[order] + + # to CPU + bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy() + + # NMS + keep = py_cpu_nms(bounding_boxes, nms_threshold) + bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep] + + # append + final_bounding_boxes.append(bounding_boxes) + final_landmarks.append(landmarks) + # self.t['forward_pass'].toc(average=True) + # self.batch_time += self.t['forward_pass'].diff + # self.total_frame += len(frames) + # print(self.batch_time / self.total_frame) + + return final_bounding_boxes, final_landmarks diff --git a/repositories/codeformer/facelib/detection/retinaface/retinaface_net.py b/repositories/codeformer/facelib/detection/retinaface/retinaface_net.py new file mode 100644 index 000000000..ab6aa82d3 --- /dev/null +++ b/repositories/codeformer/facelib/detection/retinaface/retinaface_net.py @@ -0,0 +1,196 
@@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def conv_bn(inp, oup, stride=1, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_bn_no_relu(inp, oup, stride): + return nn.Sequential( + nn.Conv2d(inp, oup, 3, stride, 1, bias=False), + nn.BatchNorm2d(oup), + ) + + +def conv_bn1X1(inp, oup, stride, leaky=0): + return nn.Sequential( + nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True)) + + +def conv_dw(inp, oup, stride, leaky=0.1): + return nn.Sequential( + nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False), + nn.BatchNorm2d(inp), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + nn.Conv2d(inp, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + nn.LeakyReLU(negative_slope=leaky, inplace=True), + ) + + +class SSH(nn.Module): + + def __init__(self, in_channel, out_channel): + super(SSH, self).__init__() + assert out_channel % 4 == 0 + leaky = 0 + if (out_channel <= 64): + leaky = 0.1 + self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1) + + self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky) + self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky) + self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1) + + def forward(self, input): + conv3X3 = self.conv3X3(input) + + conv5X5_1 = self.conv5X5_1(input) + conv5X5 = self.conv5X5_2(conv5X5_1) + + conv7X7_2 = self.conv7X7_2(conv5X5_1) + conv7X7 = self.conv7x7_3(conv7X7_2) + + out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1) + out = F.relu(out) + return out + + +class FPN(nn.Module): + + def __init__(self, in_channels_list, out_channels): + super(FPN, self).__init__() + leaky = 0 + if (out_channels <= 64): + leaky = 0.1 + self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky) + self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky) + self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky) + + self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky) + self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky) + + def forward(self, input): + # names = list(input.keys()) + # input = list(input.values()) + + output1 = self.output1(input[0]) + output2 = self.output2(input[1]) + output3 = self.output3(input[2]) + + up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest') + output2 = output2 + up3 + output2 = self.merge2(output2) + + up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest') + output1 = output1 + up2 + output1 = self.merge1(output1) + + out = [output1, output2, output3] + return out + + +class MobileNetV1(nn.Module): + + def __init__(self): + super(MobileNetV1, self).__init__() + self.stage1 = nn.Sequential( + conv_bn(3, 8, 2, leaky=0.1), # 3 + conv_dw(8, 16, 1), # 7 + conv_dw(16, 32, 2), # 11 + conv_dw(32, 32, 1), # 19 + conv_dw(32, 64, 2), # 27 + conv_dw(64, 64, 1), # 43 + ) + self.stage2 = nn.Sequential( + conv_dw(64, 128, 2), # 43 + 16 = 59 + conv_dw(128, 128, 1), # 59 + 32 = 91 + conv_dw(128, 128, 1), # 91 + 32 = 123 + conv_dw(128, 128, 1), # 123 + 32 = 155 + conv_dw(128, 128, 1), # 155 + 32 = 187 + conv_dw(128, 128, 1), # 187 + 32 = 219 + ) + self.stage3 = nn.Sequential( 
+ conv_dw(128, 256, 2), # 219 +3 2 = 241 + conv_dw(256, 256, 1), # 241 + 64 = 301 + ) + self.avg = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(256, 1000) + + def forward(self, x): + x = self.stage1(x) + x = self.stage2(x) + x = self.stage3(x) + x = self.avg(x) + # x = self.model(x) + x = x.view(-1, 256) + x = self.fc(x) + return x + + +class ClassHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(ClassHead, self).__init__() + self.num_anchors = num_anchors + self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 2) + + +class BboxHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(BboxHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 4) + + +class LandmarkHead(nn.Module): + + def __init__(self, inchannels=512, num_anchors=3): + super(LandmarkHead, self).__init__() + self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0) + + def forward(self, x): + out = self.conv1x1(x) + out = out.permute(0, 2, 3, 1).contiguous() + + return out.view(out.shape[0], -1, 10) + + +def make_class_head(fpn_num=3, inchannels=64, anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels, anchor_num)) + return classhead + + +def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels, anchor_num)) + return bboxhead + + +def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2): + landmarkhead = nn.ModuleList() + for i in range(fpn_num): + landmarkhead.append(LandmarkHead(inchannels, anchor_num)) + return landmarkhead diff --git a/repositories/codeformer/facelib/detection/retinaface/retinaface_utils.py b/repositories/codeformer/facelib/detection/retinaface/retinaface_utils.py new file mode 100644 index 000000000..8c3577577 --- /dev/null +++ b/repositories/codeformer/facelib/detection/retinaface/retinaface_utils.py @@ -0,0 +1,421 @@ +import numpy as np +import torch +import torchvision +from itertools import product as product +from math import ceil + + +class PriorBox(object): + + def __init__(self, cfg, image_size=None, phase='train'): + super(PriorBox, self).__init__() + self.min_sizes = cfg['min_sizes'] + self.steps = cfg['steps'] + self.clip = cfg['clip'] + self.image_size = image_size + self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps] + self.name = 's' + + def forward(self): + anchors = [] + for k, f in enumerate(self.feature_maps): + min_sizes = self.min_sizes[k] + for i, j in product(range(f[0]), range(f[1])): + for min_size in min_sizes: + s_kx = min_size / self.image_size[1] + s_ky = min_size / self.image_size[0] + dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]] + dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]] + for cy, cx in product(dense_cy, dense_cx): + anchors += [cx, cy, s_kx, s_ky] + + # back to torch land + output = torch.Tensor(anchors).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output + + +def py_cpu_nms(dets, thresh): + """Pure Python NMS 
baseline.""" + keep = torchvision.ops.nms( + boxes=torch.Tensor(dets[:, :4]), + scores=torch.Tensor(dets[:, 4]), + iou_threshold=thresh, + ) + + return list(keep) + + +def point_form(boxes): + """ Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + ( + boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:] / 2), + 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat( + (boxes[:, 2:] + boxes[:, :2]) / 2, # cx, cy + boxes[:, 2:] - boxes[:, :2], + 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. 
+ E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def matrix_iou(a, b): + """ + return iou of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + area_b = np.prod(b[:, 2:] - b[:, :2], axis=1) + return area_i / (area_a[:, np.newaxis] + area_b - area_i) + + +def matrix_iof(a, b): + """ + return iof of a and b, numpy version for data augenmentation + """ + lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) + rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) + + area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2) + area_a = np.prod(a[:, 2:] - a[:, :2], axis=1) + return area_i / np.maximum(area_a[:, np.newaxis], 1) + + +def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when matching boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, 4]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + landms: (tensor) Ground truth landms, Shape [num_obj, 10]. + loc_t: (tensor) Tensor to be filled w/ encoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + landm_t: (tensor) Tensor to be filled w/ encoded landm targets. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location 2)confidence + 3)landm preds. 
+ """ + # jaccard index + overlaps = jaccard(truths, point_form(priors)) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = encode(matches, priors, variances) + + matches_landm = landms[best_truth_idx] + landm = encode_landm(matches_landm, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + landm_t[idx] = landm + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def encode_landm(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 10]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + encoded landm (tensor), Shape: [num_priors, 10] + """ + + # dist b/t match center and prior's center + matched = torch.reshape(matched, (matched.size(0), 5, 2)) + priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) + priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2) + g_cxcy = matched[:, :, :2] - priors[:, :, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, :, 2:]) + # g_cxcy /= priors[:, :, 2:] + g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1) + # return target for smooth_l1_loss + return g_cxcy + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + tmp = ( + priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:], + priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:], + ) + landms = torch.cat(tmp, dim=1) + return landms + + +def batched_decode(b_loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + b_loc (tensor): location predictions for loc layers, + Shape: [num_batches,num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + boxes = ( + priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]), + ) + boxes = torch.cat(boxes, dim=2) + + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes + + +def batched_decode_landm(pre, priors, variances): + """Decode landm from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + pre (tensor): landm predictions for loc layers, + Shape: [num_batches,num_priors,10] + priors (tensor): Prior boxes in center-offset form. + Shape: [1,num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded landm predictions + """ + landms = ( + priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:], + priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:], + ) + landms = torch.cat(landms, dim=2) + return landms + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = torch.Tensor(scores.size(0)).fill_(0).long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. 
after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/repositories/codeformer/facelib/detection/yolov5face/__init__.py b/repositories/codeformer/facelib/detection/yolov5face/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/codeformer/facelib/detection/yolov5face/face_detector.py b/repositories/codeformer/facelib/detection/yolov5face/face_detector.py new file mode 100644 index 000000000..79fdba0c9 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/face_detector.py @@ -0,0 +1,142 @@ +import copy +import os +from pathlib import Path + +import cv2 +import numpy as np +import torch +from torch import nn + +from facelib.detection.yolov5face.models.common import Conv +from facelib.detection.yolov5face.models.yolo import Model +from facelib.detection.yolov5face.utils.datasets import letterbox +from facelib.detection.yolov5face.utils.general import ( + check_img_size, + non_max_suppression_face, + scale_coords, + scale_coords_landmarks, +) + +IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9, 0) + + +def isListempty(inList): + if isinstance(inList, list): # Is a list + return all(map(isListempty, inList)) + return False # Not a list + +class YoloDetector: + def __init__( + self, + config_name, + min_face=10, + target_size=None, + device='cuda', + ): + """ + config_name: name of .yaml config with network configuration from models/ folder. + min_face : minimal face size in pixels. + target_size : target size of smaller image axis (choose lower for faster work). e.g. 480, 720, 1080. + None for original resolution. + """ + self._class_path = Path(__file__).parent.absolute() + self.target_size = target_size + self.min_face = min_face + self.detector = Model(cfg=config_name) + self.device = device + + + def _preprocess(self, imgs): + """ + Preprocessing image before passing through the network. Resize and conversion to torch tensor. + """ + pp_imgs = [] + for img in imgs: + h0, w0 = img.shape[:2] # orig hw + if self.target_size: + r = self.target_size / min(h0, w0) # resize image to img_size + if r < 1: + img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR) + + imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size + img = letterbox(img, new_shape=imgsz)[0] + pp_imgs.append(img) + pp_imgs = np.array(pp_imgs) + pp_imgs = pp_imgs.transpose(0, 3, 1, 2) + pp_imgs = torch.from_numpy(pp_imgs).to(self.device) + pp_imgs = pp_imgs.float() # uint8 to fp16/32 + return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0 + + def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres): + """ + Postprocessing of raw pytorch model output. + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). 
+ """ + bboxes = [[] for _ in range(len(origimgs))] + landmarks = [[] for _ in range(len(origimgs))] + + pred = non_max_suppression_face(pred, conf_thres, iou_thres) + + for image_id, origimg in enumerate(origimgs): + img_shape = origimg.shape + image_height, image_width = img_shape[:2] + gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh + gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks + det = pred[image_id].cpu() + scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round() + scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round() + + for j in range(det.size()[0]): + box = (det[j, :4].view(1, 4) / gn).view(-1).tolist() + box = list( + map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height]) + ) + if box[3] - box[1] < self.min_face: + continue + lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist() + lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)])) + lm = [lm[i : i + 2] for i in range(0, len(lm), 2)] + bboxes[image_id].append(box) + landmarks[image_id].append(lm) + return bboxes, landmarks + + def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5): + """ + Get bbox coordinates and keypoints of faces on original image. + Params: + imgs: image or list of images to detect faces on with BGR order (convert to RGB order for inference) + conf_thres: confidence threshold for each prediction + iou_thres: threshold for NMS (filter of intersecting bboxes) + Returns: + bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2. + points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners). + """ + # Pass input images through face detector + images = imgs if isinstance(imgs, list) else [imgs] + images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images] + origimgs = copy.deepcopy(images) + + images = self._preprocess(images) + + if IS_HIGH_VERSION: + with torch.inference_mode(): # for pytorch>=1.9 + pred = self.detector(images)[0] + else: + with torch.no_grad(): # for pytorch<1.9 + pred = self.detector(images)[0] + + bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres) + + # return bboxes, points + if not isListempty(points): + bboxes = np.array(bboxes).reshape(-1,4) + points = np.array(points).reshape(-1,10) + padding = bboxes[:,0].reshape(-1,1) + return np.concatenate((bboxes, padding, points), axis=1) + else: + return None + + def __call__(self, *args): + return self.predict(*args) diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/__init__.py b/repositories/codeformer/facelib/detection/yolov5face/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/common.py b/repositories/codeformer/facelib/detection/yolov5face/models/common.py new file mode 100644 index 000000000..497a00444 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/models/common.py @@ -0,0 +1,299 @@ +# This file contains modules common to various models + +import math + +import numpy as np +import torch +from torch import nn + +from facelib.detection.yolov5face.utils.datasets import letterbox +from facelib.detection.yolov5face.utils.general import ( + make_divisible, + non_max_suppression, + scale_coords, + xyxy2xywh, +) + + +def autopad(k, p=None): # kernel, padding + # Pad to 'same' + if p is None: + p = k // 2 if 
isinstance(k, int) else [x // 2 for x in k] # auto-pad + return p + + +def channel_shuffle(x, groups): + batchsize, num_channels, height, width = x.data.size() + channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc") + + # reshape + x = x.view(batchsize, groups, channels_per_group, height, width) + x = torch.transpose(x, 1, 2).contiguous() + + # flatten + return x.view(batchsize, -1, height, width) + + +def DWConv(c1, c2, k=1, s=1, act=True): + # Depthwise convolution + return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act) + + +class Conv(nn.Module): + # Standard convolution + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity()) + + def forward(self, x): + return self.act(self.bn(self.conv(x))) + + def fuseforward(self, x): + return self.act(self.conv(x)) + + +class StemBlock(nn.Module): + def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True): + super().__init__() + self.stem_1 = Conv(c1, c2, k, s, p, g, act) + self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0) + self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1) + self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True) + self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0) + + def forward(self, x): + stem_1_out = self.stem_1(x) + stem_2a_out = self.stem_2a(stem_1_out) + stem_2b_out = self.stem_2b(stem_2a_out) + stem_2p_out = self.stem_2p(stem_1_out) + return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1)) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_, c2, 3, 1, g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class BottleneckCSP(nn.Module): + # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False) + self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False) + self.cv4 = Conv(2 * c_, c2, 1, 1) + self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3) + self.act = nn.LeakyReLU(0.1, inplace=True) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + y1 = self.cv3(self.m(self.cv1(x))) + y2 = self.cv2(x) + return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1)))) + + +class C3(nn.Module): + # CSP Bottleneck with 3 convolutions + def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c1, c_, 1, 1) + self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2) + self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) + + def forward(self, x): + return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1)) + + +class ShuffleV2Block(nn.Module): + def __init__(self, inp, oup, stride): + super().__init__() + + if not 1 <= stride <= 3: + 
raise ValueError("illegal stride value") + self.stride = stride + + branch_features = oup // 2 + + if self.stride > 1: + self.branch1 = nn.Sequential( + self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(inp), + nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + else: + self.branch1 = nn.Sequential() + + self.branch2 = nn.Sequential( + nn.Conv2d( + inp if (self.stride > 1) else branch_features, + branch_features, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), + nn.BatchNorm2d(branch_features), + nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), + nn.BatchNorm2d(branch_features), + nn.SiLU(), + ) + + @staticmethod + def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): + return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) + + def forward(self, x): + if self.stride == 1: + x1, x2 = x.chunk(2, dim=1) + out = torch.cat((x1, self.branch2(x2)), dim=1) + else: + out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) + out = channel_shuffle(out, 2) + return out + + +class SPP(nn.Module): + # Spatial pyramid pooling layer used in YOLOv3-SPP + def __init__(self, c1, c2, k=(5, 9, 13)): + super().__init__() + c_ = c1 // 2 # hidden channels + self.cv1 = Conv(c1, c_, 1, 1) + self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) + + def forward(self, x): + x = self.cv1(x) + return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) + + +class Focus(nn.Module): + # Focus wh information into c-space + def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups + super().__init__() + self.conv = Conv(c1 * 4, c2, k, s, p, g, act) + + def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2) + return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1)) + + +class Concat(nn.Module): + # Concatenate a list of tensors along dimension + def __init__(self, dimension=1): + super().__init__() + self.d = dimension + + def forward(self, x): + return torch.cat(x, self.d) + + +class NMS(nn.Module): + # Non-Maximum Suppression (NMS) module + conf = 0.25 # confidence threshold + iou = 0.45 # IoU threshold + classes = None # (optional list) filter by class + + def forward(self, x): + return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + + +class AutoShape(nn.Module): + # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS + img_size = 640 # inference size (pixels) + conf = 0.25 # NMS confidence threshold + iou = 0.45 # NMS IoU threshold + classes = None # (optional list) filter by class + + def __init__(self, model): + super().__init__() + self.model = model.eval() + + def autoshape(self): + print("autoShape already enabled, skipping... ") # model already converted to model.autoshape() + return self + + def forward(self, imgs, size=640, augment=False, profile=False): + # Inference from various sources. 
For height=720, width=1280, RGB images example inputs are: + # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3) + # PIL: = Image.open('image.jpg') # HWC x(720,1280,3) + # numpy: = np.zeros((720,1280,3)) # HWC + # torch: = torch.zeros(16,3,720,1280) # BCHW + # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images + + p = next(self.model.parameters()) # for device and type + if isinstance(imgs, torch.Tensor): # torch + return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference + + # Pre-process + n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images + shape0, shape1 = [], [] # image and inference shapes + for i, im in enumerate(imgs): + im = np.array(im) # to numpy + if im.shape[0] < 5: # image in CHW + im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) + im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input + s = im.shape[:2] # HWC + shape0.append(s) # image shape + g = size / max(s) # gain + shape1.append([y * g for y in s]) + imgs[i] = im # update + shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape + x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad + x = np.stack(x, 0) if n > 1 else x[0][None] # stack + x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW + x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32 + + # Inference + with torch.no_grad(): + y = self.model(x, augment, profile)[0] # forward + y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS + + # Post-process + for i in range(n): + scale_coords(shape1, y[i][:, :4], shape0[i]) + + return Detections(imgs, y, self.names) + + +class Detections: + # detections class for YOLOv5 inference results + def __init__(self, imgs, pred, names=None): + super().__init__() + d = pred[0].device # device + gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations + self.imgs = imgs # list of images as numpy arrays + self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls) + self.names = names # class names + self.xyxy = pred # xyxy pixels + self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels + self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized + self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized + self.n = len(self.pred) + + def __len__(self): + return self.n + + def tolist(self): + # return a list of Detections objects, i.e. 
'for result in results.tolist():' + x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)] + for d in x: + for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]: + setattr(d, k, getattr(d, k)[0]) # pop out of list + return x diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/experimental.py b/repositories/codeformer/facelib/detection/yolov5face/models/experimental.py new file mode 100644 index 000000000..37ba4c442 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/models/experimental.py @@ -0,0 +1,45 @@ +# # This file contains experimental modules + +import numpy as np +import torch +from torch import nn + +from facelib.detection.yolov5face.models.common import Conv + + +class CrossConv(nn.Module): + # Cross Convolution Downsample + def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False): + # ch_in, ch_out, kernel, stride, groups, expansion, shortcut + super().__init__() + c_ = int(c2 * e) # hidden channels + self.cv1 = Conv(c1, c_, (1, k), (1, s)) + self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g) + self.add = shortcut and c1 == c2 + + def forward(self, x): + return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) + + +class MixConv2d(nn.Module): + # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595 + def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True): + super().__init__() + groups = len(k) + if equal_ch: # equal c_ per group + i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices + c_ = [(i == g).sum() for g in range(groups)] # intermediate channels + else: # equal weight.numel() per group + b = [c2] + [0] * groups + a = np.eye(groups + 1, groups, k=-1) + a -= np.roll(a, 1, axis=1) + a *= np.array(k) ** 2 + a[0] = 1 + c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b + + self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)]) + self.bn = nn.BatchNorm2d(c2) + self.act = nn.LeakyReLU(0.1, inplace=True) + + def forward(self, x): + return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1))) diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/yolo.py b/repositories/codeformer/facelib/detection/yolov5face/models/yolo.py new file mode 100644 index 000000000..70845d972 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/models/yolo.py @@ -0,0 +1,235 @@ +import math +from copy import deepcopy +from pathlib import Path + +import torch +import yaml # for torch hub +from torch import nn + +from facelib.detection.yolov5face.models.common import ( + C3, + NMS, + SPP, + AutoShape, + Bottleneck, + BottleneckCSP, + Concat, + Conv, + DWConv, + Focus, + ShuffleV2Block, + StemBlock, +) +from facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d +from facelib.detection.yolov5face.utils.autoanchor import check_anchor_order +from facelib.detection.yolov5face.utils.general import make_divisible +from facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn + + +class Detect(nn.Module): + stride = None # strides computed during build + export = False # onnx export + + def __init__(self, nc=80, anchors=(), ch=()): # detection layer + super().__init__() + self.nc = nc # number of classes + self.no = nc + 5 + 10 # number of outputs per anchor + + self.nl = len(anchors) # number of detection layers + self.na = len(anchors[0]) // 2 # number of anchors + self.grid = [torch.zeros(1)] * self.nl # init grid + a = 
torch.tensor(anchors).float().view(self.nl, -1, 2) + self.register_buffer("anchors", a) # shape(nl,na,2) + self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2) + self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv + + def forward(self, x): + z = [] # inference output + if self.export: + for i in range(self.nl): + x[i] = self.m[i](x[i]) + return x + for i in range(self.nl): + x[i] = self.m[i](x[i]) # conv + bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85) + x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() + + if not self.training: # inference + if self.grid[i].shape[2:4] != x[i].shape[2:4]: + self.grid[i] = self._make_grid(nx, ny).to(x[i].device) + + y = torch.full_like(x[i], 0) + y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid() + y[..., 5:15] = x[i][..., 5:15] + + y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy + y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh + + y[..., 5:7] = ( + y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x1 y1 + y[..., 7:9] = ( + y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x2 y2 + y[..., 9:11] = ( + y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x3 y3 + y[..., 11:13] = ( + y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x4 y4 + y[..., 13:15] = ( + y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i] + ) # landmark x5 y5 + + z.append(y.view(bs, -1, self.no)) + + return x if self.training else (torch.cat(z, 1), x) + + @staticmethod + def _make_grid(nx=20, ny=20): + # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10 + yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)]) + return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float() + + +class Model(nn.Module): + def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes + super().__init__() + self.yaml_file = Path(cfg).name + with Path(cfg).open(encoding="utf8") as f: + self.yaml = yaml.safe_load(f) # model dict + + # Define model + ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels + if nc and nc != self.yaml["nc"]: + self.yaml["nc"] = nc # override yaml value + + self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist + self.names = [str(i) for i in range(self.yaml["nc"])] # default names + + # Build strides, anchors + m = self.model[-1] # Detect() + if isinstance(m, Detect): + s = 128 # 2x min stride + m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward + m.anchors /= m.stride.view(-1, 1, 1) + check_anchor_order(m) + self.stride = m.stride + self._initialize_biases() # only run once + + def forward(self, x): + return self.forward_once(x) # single-scale inference, train + + def forward_once(self, x): + y = [] # outputs + for m in self.model: + if m.f != -1: # if not from previous layer + x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers + + x = m(x) # run + y.append(x if m.i in self.save else None) # save output + + return x + + def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency + # 
https://arxiv.org/abs/1708.02002 section 3.3 + m = self.model[-1] # Detect() module + for mi, s in zip(m.m, m.stride): # from + b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85) + b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image) + b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls + mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True) + + def _print_biases(self): + m = self.model[-1] # Detect() module + for mi in m.m: # from + b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85) + print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean())) + + def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers + print("Fusing layers... ") + for m in self.model.modules(): + if isinstance(m, Conv) and hasattr(m, "bn"): + m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv + delattr(m, "bn") # remove batchnorm + m.forward = m.fuseforward # update forward + elif type(m) is nn.Upsample: + m.recompute_scale_factor = None # torch 1.11.0 compatibility + return self + + def nms(self, mode=True): # add or remove NMS module + present = isinstance(self.model[-1], NMS) # last layer is NMS + if mode and not present: + print("Adding NMS... ") + m = NMS() # module + m.f = -1 # from + m.i = self.model[-1].i + 1 # index + self.model.add_module(name=str(m.i), module=m) # add + self.eval() + elif not mode and present: + print("Removing NMS... ") + self.model = self.model[:-1] # remove + return self + + def autoshape(self): # add autoShape module + print("Adding autoShape... ") + m = AutoShape(self) # wrap model + copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes + return m + + +def parse_model(d, ch): # model_dict, input_channels(3) + anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"] + na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors + no = na * (nc + 5) # number of outputs = anchors * (classes + 5) + + layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out + for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args + m = eval(m) if isinstance(m, str) else m # eval strings + for j, a in enumerate(args): + try: + args[j] = eval(a) if isinstance(a, str) else a # eval strings + except: + pass + + n = max(round(n * gd), 1) if n > 1 else n # depth gain + if m in [ + Conv, + Bottleneck, + SPP, + DWConv, + MixConv2d, + Focus, + CrossConv, + BottleneckCSP, + C3, + ShuffleV2Block, + StemBlock, + ]: + c1, c2 = ch[f], args[0] + + c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 + + args = [c1, c2, *args[1:]] + if m in [BottleneckCSP, C3]: + args.insert(2, n) + n = 1 + elif m is nn.BatchNorm2d: + args = [ch[f]] + elif m is Concat: + c2 = sum(ch[-1 if x == -1 else x + 1] for x in f) + elif m is Detect: + args.append([ch[x + 1] for x in f]) + if isinstance(args[1], int): # number of anchors + args[1] = [list(range(args[1] * 2))] * len(f) + else: + c2 = ch[f] + + m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module + t = str(m)[8:-2].replace("__main__.", "") # module type + np = sum(x.numel() for x in m_.parameters()) # number params + m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params + save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist + layers.append(m_) + ch.append(c2) + return nn.Sequential(*layers), 
sorted(save) diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/yolov5l.yaml b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5l.yaml new file mode 100644 index 000000000..0532b0e22 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5l.yaml @@ -0,0 +1,47 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2 + [-1, 3, C3, [128]], + [-1, 1, Conv, [256, 3, 2]], # 2-P3/8 + [-1, 9, C3, [256]], + [-1, 1, Conv, [512, 3, 2]], # 4-P4/16 + [-1, 9, C3, [512]], + [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32 + [-1, 1, SPP, [1024, [3,5,7]]], + [-1, 3, C3, [1024, False]], # 8 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [512, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 5], 1, Concat, [1]], # cat backbone P4 + [-1, 3, C3, [512, False]], # 12 + + [-1, 1, Conv, [256, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 3], 1, Concat, [1]], # cat backbone P3 + [-1, 3, C3, [256, False]], # 16 (P3/8-small) + + [-1, 1, Conv, [256, 3, 2]], + [[-1, 13], 1, Concat, [1]], # cat head P4 + [-1, 3, C3, [512, False]], # 19 (P4/16-medium) + + [-1, 1, Conv, [512, 3, 2]], + [[-1, 9], 1, Concat, [1]], # cat head P5 + [-1, 3, C3, [1024, False]], # 22 (P5/32-large) + + [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] \ No newline at end of file diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/yolov5n.yaml b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5n.yaml new file mode 100644 index 000000000..caba6bed6 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5n.yaml @@ -0,0 +1,45 @@ +# parameters +nc: 1 # number of classes +depth_multiple: 1.0 # model depth multiple +width_multiple: 1.0 # layer channel multiple + +# anchors +anchors: + - [4,5, 8,10, 13,16] # P3/8 + - [23,29, 43,55, 73,105] # P4/16 + - [146,217, 231,300, 335,433] # P5/32 + +# YOLOv5 backbone +backbone: + # [from, number, module, args] + [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4 + [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8 + [-1, 3, ShuffleV2Block, [128, 1]], # 2 + [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16 + [-1, 7, ShuffleV2Block, [256, 1]], # 4 + [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32 + [-1, 3, ShuffleV2Block, [512, 1]], # 6 + ] + +# YOLOv5 head +head: + [[-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 4], 1, Concat, [1]], # cat backbone P4 + [-1, 1, C3, [128, False]], # 10 + + [-1, 1, Conv, [128, 1, 1]], + [-1, 1, nn.Upsample, [None, 2, 'nearest']], + [[-1, 2], 1, Concat, [1]], # cat backbone P3 + [-1, 1, C3, [128, False]], # 14 (P3/8-small) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 11], 1, Concat, [1]], # cat head P4 + [-1, 1, C3, [128, False]], # 17 (P4/16-medium) + + [-1, 1, Conv, [128, 3, 2]], + [[-1, 7], 1, Concat, [1]], # cat head P5 + [-1, 1, C3, [128, False]], # 20 (P5/32-large) + + [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) + ] diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/__init__.py b/repositories/codeformer/facelib/detection/yolov5face/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/repositories/codeformer/facelib/detection/yolov5face/utils/autoanchor.py b/repositories/codeformer/facelib/detection/yolov5face/utils/autoanchor.py new file mode 100644 index 000000000..a4eba3e94 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/utils/autoanchor.py @@ -0,0 +1,12 @@ +# Auto-anchor utils + + +def check_anchor_order(m): + # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary + a = m.anchor_grid.prod(-1).view(-1) # anchor area + da = a[-1] - a[0] # delta a + ds = m.stride[-1] - m.stride[0] # delta s + if da.sign() != ds.sign(): # same order + print("Reversing anchor order") + m.anchors[:] = m.anchors.flip(0) + m.anchor_grid[:] = m.anchor_grid.flip(0) diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py b/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py new file mode 100755 index 000000000..e672b136f --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py @@ -0,0 +1,35 @@ +import cv2 +import numpy as np + + +def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True): + # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232 + shape = img.shape[:2] # current shape [height, width] + if isinstance(new_shape, int): + new_shape = (new_shape, new_shape) + + # Scale ratio (new / old) + r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) + if not scaleup: # only scale down, do not scale up (for better test mAP) + r = min(r, 1.0) + + # Compute padding + ratio = r, r # width, height ratios + new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) + dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding + if auto: # minimum rectangle + dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding + elif scale_fill: # stretch + dw, dh = 0.0, 0.0 + new_unpad = (new_shape[1], new_shape[0]) + ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios + + dw /= 2 # divide padding into 2 sides + dh /= 2 + + if shape[::-1] != new_unpad: # resize + img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR) + top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) + left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) + img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border + return img, ratio, (dw, dh) diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py b/repositories/codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py new file mode 100644 index 000000000..4b8b63134 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py @@ -0,0 +1,5 @@ +import torch +import sys +sys.path.insert(0,'./facelib/detection/yolov5face') +model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model'] +torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth') \ No newline at end of file diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/general.py b/repositories/codeformer/facelib/detection/yolov5face/utils/general.py new file mode 100755 index 000000000..1c8e14f56 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/utils/general.py @@ -0,0 +1,271 @@ +import math +import time + +import numpy as np +import torch +import torchvision + + +def check_img_size(img_size, s=32): + # Verify img_size is a multiple of 
stride s + new_size = make_divisible(img_size, int(s)) # ceil gs-multiple + # if new_size != img_size: + # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}") + return new_size + + +def make_divisible(x, divisor): + # Returns x evenly divisible by divisor + return math.ceil(x / divisor) * divisor + + +def xyxy2xywh(x): + # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center + y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center + y[:, 2] = x[:, 2] - x[:, 0] # width + y[:, 3] = x[:, 3] - x[:, 1] # height + return y + + +def xywh2xyxy(x): + # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2]] -= pad[0] # x padding + coords[:, [1, 3]] -= pad[1] # y padding + coords[:, :4] /= gain + clip_coords(coords, img0_shape) + return coords + + +def clip_coords(boxes, img_shape): + # Clip bounding xyxy bounding boxes to image shape (height, width) + boxes[:, 0].clamp_(0, img_shape[1]) # x1 + boxes[:, 1].clamp_(0, img_shape[0]) # y1 + boxes[:, 2].clamp_(0, img_shape[1]) # x2 + boxes[:, 3].clamp_(0, img_shape[0]) # y2 + + +def box_iou(box1, box2): + # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py + """ + Return intersection-over-union (Jaccard index) of boxes. + Both sets of boxes are expected to be in (x1, y1, x2, y2) format. 
+ Arguments: + box1 (Tensor[N, 4]) + box2 (Tensor[M, 4]) + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + + def box_area(box): + return (box[2] - box[0]) * (box[3] - box[1]) + + area1 = box_area(box1.T) + area2 = box_area(box2.T) + + inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2) + return inter / (area1[:, None] + area2 - inter) + + +def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 15 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + # Apply constraints + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label = labels[xi] + v = torch.zeros((len(label), nc + 15), device=x.device) + v[:, :4] = label[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, landmarks, cls) + if multi_label: + i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 15:].max(1, keepdim=True) + x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # If none remain process next image + n = x.shape[0] # number of boxes + if not n: + continue + + # Batched NMS + c = x[:, 15:16] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + break # time limit exceeded + + return output + + +def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()): + """Performs Non-Maximum Suppression (NMS) on inference results + + Returns: + detections with shape: nx6 (x1, y1, x2, y2, conf, cls) + """ + + nc = prediction.shape[2] - 5 # number of classes + xc = prediction[..., 4] > conf_thres # candidates + + # Settings + # (pixels) 
maximum box width and height + max_wh = 4096 + time_limit = 10.0 # seconds to quit after + redundant = True # require redundant detections + multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img) + merge = False # use merge-NMS + + t = time.time() + output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0] + for xi, x in enumerate(prediction): # image index, image inference + x = x[xc[xi]] # confidence + + # Cat apriori labels if autolabelling + if labels and len(labels[xi]): + label_id = labels[xi] + v = torch.zeros((len(label_id), nc + 5), device=x.device) + v[:, :4] = label_id[:, 1:5] # box + v[:, 4] = 1.0 # conf + v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls + x = torch.cat((x, v), 0) + + # If none remain process next image + if not x.shape[0]: + continue + + # Compute conf + x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf + + # Box (center x, center y, width, height) to (x1, y1, x2, y2) + box = xywh2xyxy(x[:, :4]) + + # Detections matrix nx6 (xyxy, conf, cls) + if multi_label: + i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1) + else: # best class only + conf, j = x[:, 5:].max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] + + # Filter by class + if classes is not None: + x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + + x = x[x[:, 4].argsort(descending=True)] # sort by confidence + + # Batched NMS + c = x[:, 5:6] * (0 if agnostic else max_wh) # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores + i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS + if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean) + # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4) + iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix + weights = iou * scores[None] # box weights + x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes + if redundant: + i = i[iou.sum(1) > 1] # require redundancy + + output[xi] = x[i] + if (time.time() - t) > time_limit: + print(f"WARNING: NMS time limit {time_limit}s exceeded") + break # time limit exceeded + + return output + + +def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None): + # Rescale coords (xyxy) from img1_shape to img0_shape + if ratio_pad is None: # calculate from img0_shape + gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new + pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding + else: + gain = ratio_pad[0][0] + pad = ratio_pad[1] + + coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding + coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding + coords[:, :10] /= gain + coords[:, 0].clamp_(0, img0_shape[1]) # x1 + coords[:, 1].clamp_(0, img0_shape[0]) # y1 + coords[:, 2].clamp_(0, img0_shape[1]) # x2 + coords[:, 3].clamp_(0, img0_shape[0]) # y2 + coords[:, 4].clamp_(0, img0_shape[1]) # x3 + coords[:, 5].clamp_(0, img0_shape[0]) # y3 + coords[:, 6].clamp_(0, img0_shape[1]) # x4 + coords[:, 7].clamp_(0, img0_shape[0]) # y4 + coords[:, 8].clamp_(0, img0_shape[1]) # x5 + coords[:, 9].clamp_(0, img0_shape[0]) # y5 + return coords diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/torch_utils.py b/repositories/codeformer/facelib/detection/yolov5face/utils/torch_utils.py 
new file mode 100644 index 000000000..af2d06587 --- /dev/null +++ b/repositories/codeformer/facelib/detection/yolov5face/utils/torch_utils.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + + +def fuse_conv_and_bn(conv, bn): + # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/ + fusedconv = ( + nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + # prepare filters + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var))) + fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size())) + + # prepare spatial bias + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps)) + fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn) + + return fusedconv + + +def copy_attr(a, b, include=(), exclude=()): + # Copy attributes from b to a, options to only include [...] and to exclude [...] + for k, v in b.__dict__.items(): + if (include and k not in include) or k.startswith("_") or k in exclude: + continue + + setattr(a, k, v) diff --git a/repositories/codeformer/facelib/parsing/__init__.py b/repositories/codeformer/facelib/parsing/__init__.py new file mode 100644 index 000000000..72656e4b5 --- /dev/null +++ b/repositories/codeformer/facelib/parsing/__init__.py @@ -0,0 +1,23 @@ +import torch + +from facelib.utils import load_file_from_url +from .bisenet import BiSeNet +from .parsenet import ParseNet + + +def init_parsing_model(model_name='bisenet', half=False, device='cuda'): + if model_name == 'bisenet': + model = BiSeNet(num_class=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth' + elif model_name == 'parsenet': + model = ParseNet(in_size=512, out_size=512, parsing_ch=19) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' + else: + raise NotImplementedError(f'{model_name} is not implemented.') + + model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None) + load_net = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(load_net, strict=True) + model.eval() + model = model.to(device) + return model diff --git a/repositories/codeformer/facelib/parsing/bisenet.py b/repositories/codeformer/facelib/parsing/bisenet.py new file mode 100644 index 000000000..3898cab76 --- /dev/null +++ b/repositories/codeformer/facelib/parsing/bisenet.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import ResNet18 + + +class ConvBNReLU(nn.Module): + + def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1): + super(ConvBNReLU, self).__init__() + self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm2d(out_chan) + + def forward(self, x): + x = self.conv(x) + x = F.relu(self.bn(x)) + return x + + +class BiSeNetOutput(nn.Module): + + def __init__(self, in_chan, mid_chan, num_class): + super(BiSeNetOutput, self).__init__() + self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1) + self.conv_out = nn.Conv2d(mid_chan, num_class, 
kernel_size=1, bias=False) + + def forward(self, x): + feat = self.conv(x) + out = self.conv_out(feat) + return out, feat + + +class AttentionRefinementModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(AttentionRefinementModule, self).__init__() + self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1) + self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False) + self.bn_atten = nn.BatchNorm2d(out_chan) + self.sigmoid_atten = nn.Sigmoid() + + def forward(self, x): + feat = self.conv(x) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv_atten(atten) + atten = self.bn_atten(atten) + atten = self.sigmoid_atten(atten) + out = torch.mul(feat, atten) + return out + + +class ContextPath(nn.Module): + + def __init__(self): + super(ContextPath, self).__init__() + self.resnet = ResNet18() + self.arm16 = AttentionRefinementModule(256, 128) + self.arm32 = AttentionRefinementModule(512, 128) + self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1) + self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0) + + def forward(self, x): + feat8, feat16, feat32 = self.resnet(x) + h8, w8 = feat8.size()[2:] + h16, w16 = feat16.size()[2:] + h32, w32 = feat32.size()[2:] + + avg = F.avg_pool2d(feat32, feat32.size()[2:]) + avg = self.conv_avg(avg) + avg_up = F.interpolate(avg, (h32, w32), mode='nearest') + + feat32_arm = self.arm32(feat32) + feat32_sum = feat32_arm + avg_up + feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest') + feat32_up = self.conv_head32(feat32_up) + + feat16_arm = self.arm16(feat16) + feat16_sum = feat16_arm + feat32_up + feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest') + feat16_up = self.conv_head16(feat16_up) + + return feat8, feat16_up, feat32_up # x8, x8, x16 + + +class FeatureFusionModule(nn.Module): + + def __init__(self, in_chan, out_chan): + super(FeatureFusionModule, self).__init__() + self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0) + self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False) + self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False) + self.relu = nn.ReLU(inplace=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, fsp, fcp): + fcat = torch.cat([fsp, fcp], dim=1) + feat = self.convblk(fcat) + atten = F.avg_pool2d(feat, feat.size()[2:]) + atten = self.conv1(atten) + atten = self.relu(atten) + atten = self.conv2(atten) + atten = self.sigmoid(atten) + feat_atten = torch.mul(feat, atten) + feat_out = feat_atten + feat + return feat_out + + +class BiSeNet(nn.Module): + + def __init__(self, num_class): + super(BiSeNet, self).__init__() + self.cp = ContextPath() + self.ffm = FeatureFusionModule(256, 256) + self.conv_out = BiSeNetOutput(256, 256, num_class) + self.conv_out16 = BiSeNetOutput(128, 64, num_class) + self.conv_out32 = BiSeNetOutput(128, 64, num_class) + + def forward(self, x, return_feat=False): + h, w = x.size()[2:] + feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature + feat_sp = feat_res8 # replace spatial path feature with res3b1 feature + feat_fuse = self.ffm(feat_sp, feat_cp8) + + out, feat = self.conv_out(feat_fuse) + out16, feat16 = self.conv_out16(feat_cp8) + out32, feat32 = self.conv_out32(feat_cp16) + + out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True) + out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True) + out32 = 
F.interpolate(out32, (h, w), mode='bilinear', align_corners=True) + + if return_feat: + feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True) + feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True) + feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True) + return out, out16, out32, feat, feat16, feat32 + else: + return out, out16, out32 diff --git a/repositories/codeformer/facelib/parsing/parsenet.py b/repositories/codeformer/facelib/parsing/parsenet.py new file mode 100644 index 000000000..e178ebe43 --- /dev/null +++ b/repositories/codeformer/facelib/parsing/parsenet.py @@ -0,0 +1,194 @@ +"""Modified from https://github.com/chaofengc/PSFRGAN +""" +import numpy as np +import torch.nn as nn +from torch.nn import functional as F + + +class NormLayer(nn.Module): + """Normalization Layers. + + Args: + channels: input channels, for batch norm and instance norm. + input_size: input shape without batch size, for layer norm. + """ + + def __init__(self, channels, normalize_shape=None, norm_type='bn'): + super(NormLayer, self).__init__() + norm_type = norm_type.lower() + self.norm_type = norm_type + if norm_type == 'bn': + self.norm = nn.BatchNorm2d(channels, affine=True) + elif norm_type == 'in': + self.norm = nn.InstanceNorm2d(channels, affine=False) + elif norm_type == 'gn': + self.norm = nn.GroupNorm(32, channels, affine=True) + elif norm_type == 'pixel': + self.norm = lambda x: F.normalize(x, p=2, dim=1) + elif norm_type == 'layer': + self.norm = nn.LayerNorm(normalize_shape) + elif norm_type == 'none': + self.norm = lambda x: x * 1.0 + else: + assert 1 == 0, f'Norm type {norm_type} not support.' + + def forward(self, x, ref=None): + if self.norm_type == 'spade': + return self.norm(x, ref) + else: + return self.norm(x) + + +class ReluLayer(nn.Module): + """Relu Layer. + + Args: + relu type: type of relu layer, candidates are + - ReLU + - LeakyReLU: default relu slope 0.2 + - PRelu + - SELU + - none: direct pass + """ + + def __init__(self, channels, relu_type='relu'): + super(ReluLayer, self).__init__() + relu_type = relu_type.lower() + if relu_type == 'relu': + self.func = nn.ReLU(True) + elif relu_type == 'leakyrelu': + self.func = nn.LeakyReLU(0.2, inplace=True) + elif relu_type == 'prelu': + self.func = nn.PReLU(channels) + elif relu_type == 'selu': + self.func = nn.SELU(True) + elif relu_type == 'none': + self.func = lambda x: x * 1.0 + else: + assert 1 == 0, f'Relu type {relu_type} not support.' + + def forward(self, x): + return self.func(x) + + +class ConvLayer(nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + scale='none', + norm_type='none', + relu_type='none', + use_pad=True, + bias=True): + super(ConvLayer, self).__init__() + self.use_pad = use_pad + self.norm_type = norm_type + if norm_type in ['bn']: + bias = False + + stride = 2 if scale == 'down' else 1 + + self.scale_func = lambda x: x + if scale == 'up': + self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest') + + self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) 
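BiSeNet above returns three 19-class logit maps, each upsampled back to the input resolution. A small shape check with randomly initialized weights, assuming the facelib package added by this patch is importable:

import torch
from facelib.parsing.bisenet import BiSeNet

net = BiSeNet(num_class=19).eval()
with torch.no_grad():
    out, out16, out32 = net(torch.randn(1, 3, 512, 512))
print(out.shape, out16.shape, out32.shape)  # each torch.Size([1, 19, 512, 512])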
/ 2))) + self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias) + + self.relu = ReluLayer(out_channels, relu_type) + self.norm = NormLayer(out_channels, norm_type=norm_type) + + def forward(self, x): + out = self.scale_func(x) + if self.use_pad: + out = self.reflection_pad(out) + out = self.conv2d(out) + out = self.norm(out) + out = self.relu(out) + return out + + +class ResidualBlock(nn.Module): + """ + Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html + """ + + def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'): + super(ResidualBlock, self).__init__() + + if scale == 'none' and c_in == c_out: + self.shortcut_func = lambda x: x + else: + self.shortcut_func = ConvLayer(c_in, c_out, 3, scale) + + scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']} + scale_conf = scale_config_dict[scale] + + self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type) + self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none') + + def forward(self, x): + identity = self.shortcut_func(x) + + res = self.conv1(x) + res = self.conv2(res) + return identity + res + + +class ParseNet(nn.Module): + + def __init__(self, + in_size=128, + out_size=128, + min_feat_size=32, + base_ch=64, + parsing_ch=19, + res_depth=10, + relu_type='LeakyReLU', + norm_type='bn', + ch_range=[32, 256]): + super().__init__() + self.res_depth = res_depth + act_args = {'norm_type': norm_type, 'relu_type': relu_type} + min_ch, max_ch = ch_range + + ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731 + min_feat_size = min(in_size, min_feat_size) + + down_steps = int(np.log2(in_size // min_feat_size)) + up_steps = int(np.log2(out_size // min_feat_size)) + + # =============== define encoder-body-decoder ==================== + self.encoder = [] + self.encoder.append(ConvLayer(3, base_ch, 3, 1)) + head_ch = base_ch + for i in range(down_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2) + self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args)) + head_ch = head_ch * 2 + + self.body = [] + for i in range(res_depth): + self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args)) + + self.decoder = [] + for i in range(up_steps): + cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2) + self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args)) + head_ch = head_ch // 2 + + self.encoder = nn.Sequential(*self.encoder) + self.body = nn.Sequential(*self.body) + self.decoder = nn.Sequential(*self.decoder) + self.out_img_conv = ConvLayer(ch_clip(head_ch), 3) + self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch) + + def forward(self, x): + feat = self.encoder(x) + x = feat + self.body(feat) + x = self.decoder(x) + out_img = self.out_img_conv(x) + out_mask = self.out_mask_conv(x) + return out_mask, out_img diff --git a/repositories/codeformer/facelib/parsing/resnet.py b/repositories/codeformer/facelib/parsing/resnet.py new file mode 100644 index 000000000..fec8e82cf --- /dev/null +++ b/repositories/codeformer/facelib/parsing/resnet.py @@ -0,0 +1,69 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + + def __init__(self, in_chan, out_chan, stride=1): + super(BasicBlock, 
self).__init__() + self.conv1 = conv3x3(in_chan, out_chan, stride) + self.bn1 = nn.BatchNorm2d(out_chan) + self.conv2 = conv3x3(out_chan, out_chan) + self.bn2 = nn.BatchNorm2d(out_chan) + self.relu = nn.ReLU(inplace=True) + self.downsample = None + if in_chan != out_chan or stride != 1: + self.downsample = nn.Sequential( + nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(out_chan), + ) + + def forward(self, x): + residual = self.conv1(x) + residual = F.relu(self.bn1(residual)) + residual = self.conv2(residual) + residual = self.bn2(residual) + + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(x) + + out = shortcut + residual + out = self.relu(out) + return out + + +def create_layer_basic(in_chan, out_chan, bnum, stride=1): + layers = [BasicBlock(in_chan, out_chan, stride=stride)] + for i in range(bnum - 1): + layers.append(BasicBlock(out_chan, out_chan, stride=1)) + return nn.Sequential(*layers) + + +class ResNet18(nn.Module): + + def __init__(self): + super(ResNet18, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1) + self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2) + self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2) + self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(self.bn1(x)) + x = self.maxpool(x) + + x = self.layer1(x) + feat8 = self.layer2(x) # 1/8 + feat16 = self.layer3(feat8) # 1/16 + feat32 = self.layer4(feat16) # 1/32 + return feat8, feat16, feat32 diff --git a/repositories/codeformer/facelib/utils/__init__.py b/repositories/codeformer/facelib/utils/__init__.py new file mode 100644 index 000000000..f03b1c2ba --- /dev/null +++ b/repositories/codeformer/facelib/utils/__init__.py @@ -0,0 +1,7 @@ +from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back +from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir + +__all__ = [ + 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url', + 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir' +] diff --git a/repositories/codeformer/facelib/utils/face_restoration_helper.py b/repositories/codeformer/facelib/utils/face_restoration_helper.py new file mode 100644 index 000000000..5d3fb8f3b --- /dev/null +++ b/repositories/codeformer/facelib/utils/face_restoration_helper.py @@ -0,0 +1,460 @@ +import cv2 +import numpy as np +import os +import torch +from torchvision.transforms.functional import normalize + +from facelib.detection import init_detection_model +from facelib.parsing import init_parsing_model +from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray + + +def get_largest_face(det_faces, h, w): + + def get_location(val, length): + if val < 0: + return 0 + elif val > length: + return length + else: + return val + + face_areas = [] + for det_face in det_faces: + left = get_location(det_face[0], w) + right = get_location(det_face[2], w) + top = get_location(det_face[1], h) + bottom = get_location(det_face[3], h) + face_area = (right - left) * (bottom - top) + face_areas.append(face_area) + largest_idx = face_areas.index(max(face_areas)) + return det_faces[largest_idx], largest_idx + + +def get_center_face(det_faces, h=0, w=0, 
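get_largest_face above clips each detection to the image and keeps the one with the largest area; the companion get_center_face keeps the detection closest to the image center. A tiny usage sketch, assuming the facelib package added by this patch is importable and boxes given as [x1, y1, x2, y2, score]:

from facelib.utils.face_restoration_helper import get_largest_face

det_faces = [[10, 10, 60, 60, 0.98],   # 50 x 50 face
             [0, 0, 200, 200, 0.95]]   # 200 x 200 face
face, idx = get_largest_face(det_faces, h=480, w=640)
print(idx, face)  # 1 [0, 0, 200, 200, 0.95]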
center=None): + if center is not None: + center = np.array(center) + else: + center = np.array([w / 2, h / 2]) + center_dist = [] + for det_face in det_faces: + face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2]) + dist = np.linalg.norm(face_center - center) + center_dist.append(dist) + center_idx = center_dist.index(min(center_dist)) + return det_faces[center_idx], center_idx + + +class FaceRestoreHelper(object): + """Helper for the face restoration pipeline (base class).""" + + def __init__(self, + upscale_factor, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + template_3points=False, + pad_blur=False, + use_parse=False, + device=None): + self.template_3points = template_3points # improve robustness + self.upscale_factor = int(upscale_factor) + # the cropped face ratio based on the square face + self.crop_ratio = crop_ratio # (h, w) + assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1' + self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) + + if self.template_3points: + self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) + else: + # standard 5 landmarks for FFHQ faces with 512 x 512 + # facexlib + self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935], + [201.26117, 371.41043], [313.08905, 371.15118]]) + + # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 + # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], + # [198.22603, 372.82502], [313.91018, 372.75659]]) + + + self.face_template = self.face_template * (face_size / 512.0) + if self.crop_ratio[0] > 1: + self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 + if self.crop_ratio[1] > 1: + self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 + self.save_ext = save_ext + self.pad_blur = pad_blur + if self.pad_blur is True: + self.template_3points = False + + self.all_landmarks_5 = [] + self.det_faces = [] + self.affine_matrices = [] + self.inverse_affine_matrices = [] + self.cropped_faces = [] + self.restored_faces = [] + self.pad_input_imgs = [] + + if device is None: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + else: + self.device = device + + # init face detection model + self.face_det = init_detection_model(det_model, half=False, device=self.device) + + # init face parsing model + self.use_parse = use_parse + self.face_parse = init_parsing_model(model_name='parsenet', device=self.device) + + def set_upscale_factor(self, upscale_factor): + self.upscale_factor = upscale_factor + + def read_image(self, img): + """img can be image path or cv2 loaded image.""" + # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255] + if isinstance(img, str): + img = cv2.imread(img) + + if np.max(img) > 256: # 16-bit image + img = img / 65535 * 255 + if len(img.shape) == 2: # gray image + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + elif img.shape[2] == 4: # BGRA image with alpha channel + img = img[:, :, 0:3] + + self.input_img = img + self.is_gray = is_gray(img, threshold=5) + if self.is_gray: + print('Grayscale input: True') + + if min(self.input_img.shape[:2])<512: + f = 512.0/min(self.input_img.shape[:2]) + self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR) + + def get_face_landmarks_5(self, + 
only_keep_largest=False, + only_center_face=False, + resize=None, + blur_ratio=0.01, + eye_dist_threshold=None): + if resize is None: + scale = 1 + input_img = self.input_img + else: + h, w = self.input_img.shape[0:2] + scale = resize / min(h, w) + scale = max(1, scale) # always scale up + h, w = int(h * scale), int(w * scale) + interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR + input_img = cv2.resize(self.input_img, (w, h), interpolation=interp) + + with torch.no_grad(): + bboxes = self.face_det.detect_faces(input_img) + + if bboxes is None or bboxes.shape[0] == 0: + return 0 + else: + bboxes = bboxes / scale + + for bbox in bboxes: + # remove faces with too small eye distance: side faces or too small faces + eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]]) + if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold): + continue + + if self.template_3points: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)]) + else: + landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)]) + self.all_landmarks_5.append(landmark) + self.det_faces.append(bbox[0:5]) + + if len(self.det_faces) == 0: + return 0 + if only_keep_largest: + h, w, _ = self.input_img.shape + self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]] + elif only_center_face: + h, w, _ = self.input_img.shape + self.det_faces, center_idx = get_center_face(self.det_faces, h, w) + self.all_landmarks_5 = [self.all_landmarks_5[center_idx]] + + # pad blurry images + if self.pad_blur: + self.pad_input_imgs = [] + for landmarks in self.all_landmarks_5: + # get landmarks + eye_left = landmarks[0, :] + eye_right = landmarks[1, :] + eye_avg = (eye_left + eye_right) * 0.5 + mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1.5 + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + border = max(int(np.rint(qsize * 0.1)), 3) + + # get pad + # pad: (width_left, height_top, width_right, height_bottom) + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = [ + max(-pad[0] + border, 1), + max(-pad[1] + border, 1), + max(pad[2] - self.input_img.shape[0] + border, 1), + max(pad[3] - self.input_img.shape[1] + border, 1) + ] + + if max(pad) > 1: + # pad image + pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + # modify landmark coords + landmarks[:, 0] += pad[0] + landmarks[:, 1] += pad[1] + # blur pad images + h, w, _ = pad_img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / 
pad[3])) + blur = int(qsize * blur_ratio) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur)) + # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0) + + pad_img = pad_img.astype('float32') + pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0) + pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255] + self.pad_input_imgs.append(pad_img) + else: + self.pad_input_imgs.append(np.copy(self.input_img)) + + return len(self.all_landmarks_5) + + def align_warp_face(self, save_cropped_path=None, border_mode='constant'): + """Align and warp faces with face template. + """ + if self.pad_blur: + assert len(self.pad_input_imgs) == len( + self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}' + for idx, landmark in enumerate(self.all_landmarks_5): + # use 5 landmarks to get affine matrix + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0] + self.affine_matrices.append(affine_matrix) + # warp and crop faces + if border_mode == 'constant': + border_mode = cv2.BORDER_CONSTANT + elif border_mode == 'reflect101': + border_mode = cv2.BORDER_REFLECT101 + elif border_mode == 'reflect': + border_mode = cv2.BORDER_REFLECT + if self.pad_blur: + input_img = self.pad_input_imgs[idx] + else: + input_img = self.input_img + cropped_face = cv2.warpAffine( + input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray + self.cropped_faces.append(cropped_face) + # save the cropped face + if save_cropped_path is not None: + path = os.path.splitext(save_cropped_path)[0] + save_path = f'{path}_{idx:02d}.{self.save_ext}' + imwrite(cropped_face, save_path) + + def get_inverse_affine(self, save_inverse_affine_path=None): + """Get inverse affine matrix.""" + for idx, affine_matrix in enumerate(self.affine_matrices): + inverse_affine = cv2.invertAffineTransform(affine_matrix) + inverse_affine *= self.upscale_factor + self.inverse_affine_matrices.append(inverse_affine) + # save inverse affine matrices + if save_inverse_affine_path is not None: + path, _ = os.path.splitext(save_inverse_affine_path) + save_path = f'{path}_{idx:02d}.pth' + torch.save(inverse_affine, save_path) + + + def add_restored_face(self, face): + if self.is_gray: + face = bgr2gray(face) # convert img into grayscale + self.restored_faces.append(face) + + + def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None): + h, w, _ = self.input_img.shape + h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor) + + if upsample_img is None: + # simply resize the background + # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR) + else: + upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4) + + assert len(self.restored_faces) == len( + self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.') + + inv_mask_borders = [] + for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices): + if face_upsampler is not None: + restored_face = face_upsampler.enhance(restored_face, 
outscale=self.upscale_factor)[0] + inverse_affine /= self.upscale_factor + inverse_affine[:, 2] *= self.upscale_factor + face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor) + else: + # Add an offset to inverse affine matrix, for more precise back alignment + if self.upscale_factor > 1: + extra_offset = 0.5 * self.upscale_factor + else: + extra_offset = 0 + inverse_affine[:, 2] += extra_offset + face_size = self.face_size + inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up)) + + # if draw_box or not self.use_parse: # use square parse maps + # mask = np.ones(face_size, dtype=np.float32) + # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # # remove the black borders + # inv_mask_erosion = cv2.erode( + # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + # pasted_face = inv_mask_erosion[:, :, None] * inv_restored + # total_face_area = np.sum(inv_mask_erosion) # // 3 + # # add border + # if draw_box: + # h, w = face_size + # mask_border = np.ones((h, w, 3), dtype=np.float32) + # border = int(1400/np.sqrt(total_face_area)) + # mask_border[border:h-border, border:w-border,:] = 0 + # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + # inv_mask_borders.append(inv_mask_border) + # if not self.use_parse: + # # compute the fusion edge based on the area of face + # w_edge = int(total_face_area**0.5) // 20 + # erosion_radius = w_edge * 2 + # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + # blur_size = w_edge * 2 + # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + # if len(upsample_img.shape) == 2: # upsample_img is gray image + # upsample_img = upsample_img[:, :, None] + # inv_soft_mask = inv_soft_mask[:, :, None] + + # always use square mask + mask = np.ones(face_size, dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) + # remove the black borders + inv_mask_erosion = cv2.erode( + inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) + pasted_face = inv_mask_erosion[:, :, None] * inv_restored + total_face_area = np.sum(inv_mask_erosion) # // 3 + # add border + if draw_box: + h, w = face_size + mask_border = np.ones((h, w, 3), dtype=np.float32) + border = int(1400/np.sqrt(total_face_area)) + mask_border[border:h-border, border:w-border,:] = 0 + inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up)) + inv_mask_borders.append(inv_mask_border) + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + if len(upsample_img.shape) == 2: # upsample_img is gray image + upsample_img = upsample_img[:, :, None] + inv_soft_mask = inv_soft_mask[:, :, None] + + # parse mask + if self.use_parse: + # inference + face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR) + face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True) + normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + face_input = torch.unsqueeze(face_input, 0).to(self.device) + with torch.no_grad(): + out = self.face_parse(face_input)[0] + out = 
out.argmax(dim=1).squeeze().cpu().numpy() + + parse_mask = np.zeros(out.shape) + MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0] + for idx, color in enumerate(MASK_COLORMAP): + parse_mask[out == idx] = color + # blur the mask + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11) + # remove the black borders + thres = 10 + parse_mask[:thres, :] = 0 + parse_mask[-thres:, :] = 0 + parse_mask[:, :thres] = 0 + parse_mask[:, -thres:] = 0 + parse_mask = parse_mask / 255. + + parse_mask = cv2.resize(parse_mask, face_size) + parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3) + inv_soft_parse_mask = parse_mask[:, :, None] + # pasted_face = inv_restored + fuse_mask = (inv_soft_parse_mask<inv_soft_mask).astype('int') + inv_soft_mask = inv_soft_parse_mask*fuse_mask + inv_soft_mask*(1-fuse_mask) + + if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4: # alpha channel + alpha = upsample_img[:, :, 3:] + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3] + upsample_img = np.concatenate((upsample_img, alpha), axis=2) + else: + upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img + + if np.max(upsample_img) > 256: # 16-bit image + upsample_img = upsample_img.astype(np.uint16) + else: + upsample_img = upsample_img.astype(np.uint8) + + # draw bounding box + if draw_box: + # upsample_input_img = cv2.resize(input_img, (w_up, h_up)) + img_color = np.ones([*upsample_img.shape], dtype=np.float32) + img_color[:,:,0] = 0 + img_color[:,:,1] = 255 + img_color[:,:,2] = 0 + for inv_mask_border in inv_mask_borders: + upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img + # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img + + if save_path is not None: + path = os.path.splitext(save_path)[0] + save_path = f'{path}.{self.save_ext}' + imwrite(upsample_img, save_path) + return upsample_img + + def clean_all(self): + self.all_landmarks_5 = [] + self.restored_faces = [] + self.affine_matrices = [] + self.cropped_faces = [] + self.inverse_affine_matrices = [] + self.det_faces = [] + self.pad_input_imgs = [] \ No newline at end of file diff --git a/repositories/codeformer/facelib/utils/face_utils.py b/repositories/codeformer/facelib/utils/face_utils.py new file mode 100644 index 000000000..f1474a2a4 --- /dev/null +++ b/repositories/codeformer/facelib/utils/face_utils.py @@ -0,0 +1,248 @@ +import cv2 +import numpy as np +import torch + + +def compute_increased_bbox(bbox, increase_area, preserve_aspect=True): + left, top, right, bot = bbox + width = right - left + height = bot - top + + if preserve_aspect: + width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width)) + height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height)) + else: + width_increase = height_increase = increase_area + left = int(left - width_increase * width) + top = int(top - height_increase * height) + right = int(right + width_increase * width) + bot = int(bot + height_increase * height) + return (left, top, right, bot) + + +def get_valid_bboxes(bboxes, h, w): + left = max(bboxes[0], 0) + top = max(bboxes[1], 0) + right = min(bboxes[2], w) + bottom = min(bboxes[3], h) + return (left, top, right, bottom) + + +def align_crop_face_landmarks(img, + landmarks, + output_size, + transform_size=None, + enable_padding=True, + return_inverse_affine=False, + shrink_ratio=(1, 1)): + """Align and crop face with landmarks. + + The output_size and transform_size are based on width. The height is + adjusted based on shrink_ratio_h/shring_ration_w. + + Modified from: + https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py + + Args: + img (Numpy array): Input image. + landmarks (Numpy array): 5 or 68 or 98 landmarks. + output_size (int): Output face size.
+ transform_size (ing): Transform size. Usually the four time of + output_size. + enable_padding (float): Default: True. + shrink_ratio (float | tuple[float] | list[float]): Shring the whole + face for height and width (crop larger area). Default: (1, 1). + + Returns: + (Numpy array): Cropped face. + """ + lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5 + + if isinstance(shrink_ratio, (float, int)): + shrink_ratio = (shrink_ratio, shrink_ratio) + if transform_size is None: + transform_size = output_size * 4 + + # Parse landmarks + lm = np.array(landmarks) + if lm.shape[0] == 5 and lm_type == 'retinaface_5': + eye_left = lm[0] + eye_right = lm[1] + mouth_avg = (lm[3] + lm[4]) * 0.5 + elif lm.shape[0] == 5 and lm_type == 'dlib_5': + lm_eye_left = lm[2:4] + lm_eye_right = lm[0:2] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = lm[4] + elif lm.shape[0] == 68: + lm_eye_left = lm[36:42] + lm_eye_right = lm[42:48] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[48] + lm[54]) * 0.5 + elif lm.shape[0] == 98: + lm_eye_left = lm[60:68] + lm_eye_right = lm[68:76] + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + mouth_avg = (lm[76] + lm[82]) * 0.5 + + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + eye_to_mouth = mouth_avg - eye_avg + + # Get the oriented crop rectangle + # x: half width of the oriented crop rectangle + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise + # norm with the hypotenuse: get the direction + x /= np.hypot(*x) # get the hypotenuse of a right triangle + rect_scale = 1 # TODO: you can edit it to get larger rect + x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale) + # y: half height of the oriented crop rectangle + y = np.flipud(x) * [-1, 1] + + x *= shrink_ratio[1] # width + y *= shrink_ratio[0] # height + + # c: center + c = eye_avg + eye_to_mouth * 0.1 + # quad: (left_top, left_bottom, right_bottom, right_top) + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + # qsize: side length of the square + qsize = np.hypot(*x) * 2 + + quad_ori = np.copy(quad) + # Shrink, for large face + # TODO: do we really need shrink + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + h, w = img.shape[0:2] + rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink))) + img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA) + quad /= shrink + qsize /= shrink + + # Crop + h, w = img.shape[0:2] + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h)) + if crop[2] - crop[0] < w or crop[3] - crop[1] < h: + img = img[crop[1]:crop[3], crop[0]:crop[2], :] + quad -= crop[0:2] + + # Pad + # pad: (width_left, height_top, width_right, height_bottom) + h, w = img.shape[0:2] + pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))), + int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad(img, 
((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') + h, w = img.shape[0:2] + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), + 1.0 - np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = int(qsize * 0.02) + if blur % 2 == 0: + blur += 1 + blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur)) + + img = img.astype('float32') + img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = np.clip(img, 0, 255) # float32, [0, 255] + quad += pad[:2] + + # Transform use cv2 + h_ratio = shrink_ratio[0] / shrink_ratio[1] + dst_h, dst_w = int(transform_size * h_ratio), transform_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0] + cropped_face = cv2.warpAffine( + img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray + + if output_size < transform_size: + cropped_face = cv2.resize( + cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR) + + if return_inverse_affine: + dst_h, dst_w = int(output_size * h_ratio), output_size + template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]]) + # use cv2.LMEDS method for the equivalence to skimage transform + # ref: https://blog.csdn.net/yichxi/article/details/115827338 + affine_matrix = cv2.estimateAffinePartial2D( + quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0] + inverse_affine = cv2.invertAffineTransform(affine_matrix) + else: + inverse_affine = None + return cropped_face, inverse_affine + + +def paste_face_back(img, face, inverse_affine): + h, w = img.shape[0:2] + face_h, face_w = face.shape[0:2] + inv_restored = cv2.warpAffine(face, inverse_affine, (w, h)) + mask = np.ones((face_h, face_w, 3), dtype=np.float32) + inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h)) + # remove the black borders + inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8)) + inv_restored_remove_border = inv_mask_erosion * inv_restored + total_face_area = np.sum(inv_mask_erosion) // 3 + # compute the fusion edge based on the area of face + w_edge = int(total_face_area**0.5) // 20 + erosion_radius = w_edge * 2 + inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) + blur_size = w_edge * 2 + inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) + img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img + # float32, [0, 255] + return img + + +if __name__ == '__main__': + import os + + from facelib.detection import init_detection_model + from facelib.utils.face_restoration_helper import get_largest_face + + img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png' + img_name = os.splitext(os.path.basename(img_path))[0] + + # initialize model + det_net = init_detection_model('retinaface_resnet50', half=False) + img_ori = cv2.imread(img_path) + h, w = img_ori.shape[0:2] + # if larger than 800, scale it + scale = max(h / 800, w / 800) + if scale > 1: + img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR) + + with torch.no_grad(): + bboxes = det_net.detect_faces(img, 0.97) + if 
scale > 1: + bboxes *= scale # the score is incorrect + bboxes = get_largest_face(bboxes, h, w)[0] + + landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)]) + + cropped_face, inverse_affine = align_crop_face_landmarks( + img_ori, + landmarks, + output_size=512, + transform_size=None, + enable_padding=True, + return_inverse_affine=True, + shrink_ratio=(1, 1)) + + cv2.imwrite(f'tmp/{img_name}_cropeed_face.png', cropped_face) + img = paste_face_back(img_ori, cropped_face, inverse_affine) + cv2.imwrite(f'tmp/{img_name}_back.png', img) diff --git a/repositories/codeformer/facelib/utils/misc.py b/repositories/codeformer/facelib/utils/misc.py new file mode 100644 index 000000000..7f5c95506 --- /dev/null +++ b/repositories/codeformer/facelib/utils/misc.py @@ -0,0 +1,174 @@ +import cv2 +import os +import os.path as osp +import numpy as np +from PIL import Image +import torch +from torch.hub import download_url_to_file, get_dir +from urllib.parse import urlparse +# from basicsr.utils.download_util import download_file_from_google_drive + +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +def download_pretrained_models(file_ids, save_path_root): + import gdown + + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + file_url = 'https://drive.google.com/uc?id='+file_id + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') + if user_response.lower() == 'y': + print(f'Covering {file_name} to {save_path}') + gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == 'n': + print(f'Skipping {file_name}') + else: + raise ValueError('Wrong input. Only accepts Y/N.') + else: + print(f'Downloading {file_name} to {save_path}') + gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. + """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. 
+ """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + """ + if model_dir is None: + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + Returns: + A generator for all the interested files with relative paths. + """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) + + +def is_gray(img, threshold=10): + img = Image.fromarray(img) + if len(img.getbands()) == 1: + return True + img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16) + img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16) + img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16) + diff1 = (img1 - img2).var() + diff2 = (img2 - img3).var() + diff3 = (img3 - img1).var() + diff_sum = (diff1 + diff2 + diff3) / 3.0 + if diff_sum <= threshold: + return True + else: + return False + +def rgb2gray(img, out_channel=3): + r, g, b = img[:,:,0], img[:,:,1], img[:,:,2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + if out_channel == 3: + gray = gray[:,:,np.newaxis].repeat(3, axis=2) + return gray + +def bgr2gray(img, out_channel=3): + b, g, r = img[:,:,0], img[:,:,1], img[:,:,2] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + if out_channel == 3: + gray = gray[:,:,np.newaxis].repeat(3, axis=2) + return gray diff --git a/repositories/codeformer/inference_codeformer.py b/repositories/codeformer/inference_codeformer.py new file mode 100644 index 000000000..cde1094af --- /dev/null +++ b/repositories/codeformer/inference_codeformer.py @@ -0,0 +1,269 @@ +import os +import cv2 
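is_gray above treats an image as grayscale when the per-channel differences have variance at or below the threshold, and bgr2gray/rgb2gray build a BT.601 luma image replicated to three channels. A short usage sketch, assuming the facelib package added by this patch is importable:

import numpy as np
from facelib.utils.misc import is_gray, bgr2gray

# a BGR image whose three channels are identical is detected as grayscale
luma = np.random.randint(0, 256, size=(64, 64, 1), dtype=np.uint8)
img = np.repeat(luma, 3, axis=2)
print(is_gray(img, threshold=10))  # True

gray3 = bgr2gray(img)  # 0.2989*R + 0.5870*G + 0.1140*B, kept as 3 identical channels
print(gray3.shape)     # (64, 64, 3)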
+import argparse +import glob +import torch +from torchvision.transforms.functional import normalize +from basicsr.utils import imwrite, img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url +from facelib.utils.face_restoration_helper import FaceRestoreHelper +from facelib.utils.misc import is_gray +import torch.nn.functional as F + +from basicsr.utils.registry import ARCH_REGISTRY + +pretrain_model_url = { + 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', +} + +def set_realesrgan(): + from basicsr.archs.rrdbnet_arch import RRDBNet + from basicsr.utils.realesrgan_utils import RealESRGANer + + cuda_is_available = torch.cuda.is_available() + half = True if cuda_is_available else False + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + upsampler = RealESRGANer( + scale=2, + model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth", + model=model, + tile=args.bg_tile, + tile_pad=40, + pre_pad=0, + half=half, # need to set False in CPU mode + ) + + if not cuda_is_available: # CPU + import warnings + warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.' + 'The unoptimized RealESRGAN is slow on CPU. ' + 'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.', + category=RuntimeWarning) + return upsampler + +if __name__ == '__main__': + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + parser = argparse.ArgumentParser() + + parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs', + help='Input image, video or folder. Default: inputs/whole_imgs') + parser.add_argument('-o', '--output_path', type=str, default=None, + help='Output folder. Default: results/_') + parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5, + help='Balance the quality and fidelity. Default: 0.5') + parser.add_argument('-s', '--upscale', type=int, default=2, + help='The final upsampling scale of the image. Default: 2') + parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False') + parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False') + parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False') + # large det_model: 'YOLOv5l', 'retinaface_resnet50' + # small det_model: 'YOLOv5n', 'retinaface_mobile0.25' + parser.add_argument('--detection_model', type=str, default='retinaface_resnet50', + help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. \ + Default: retinaface_resnet50') + parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan') + parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False') + parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400') + parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None') + parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. 
Default: None') + + args = parser.parse_args() + + # ------------------------ input & output ------------------------ + w = args.fidelity_weight + input_video = False + if args.input_path.endswith(('jpg', 'png')): # input single img path + input_img_list = [args.input_path] + result_root = f'results/test_img_{w}' + elif args.input_path.endswith(('mp4', 'mov', 'avi')): # input video path + from basicsr.utils.video_util import VideoReader, VideoWriter + input_img_list = [] + vidreader = VideoReader(args.input_path) + image = vidreader.get_frame() + while image is not None: + input_img_list.append(image) + image = vidreader.get_frame() + audio = vidreader.get_audio() + fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps + video_name = os.path.basename(args.input_path)[:-4] + result_root = f'results/{video_name}_{w}' + input_video = True + vidreader.close() + else: # input img folder + if args.input_path.endswith('/'): # solve when path ends with / + args.input_path = args.input_path[:-1] + # scan all the jpg and png images + input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jp][pn]g'))) + result_root = f'results/{os.path.basename(args.input_path)}_{w}' + + if not args.output_path is None: # set output path + result_root = args.output_path + + test_img_num = len(input_img_list) + if test_img_num == 0: + raise FileNotFoundError("\nInput file is not found.") + + # ------------------ set up background upsampler ------------------ + if args.bg_upsampler == 'realesrgan': + bg_upsampler = set_realesrgan() + else: + bg_upsampler = None + + # ------------------ set up face upsampler ------------------ + if args.face_upsample: + face_upsampler = None + # if bg_upsampler is not None: + # face_upsampler = bg_upsampler + # else: + # face_upsampler = set_realesrgan() + else: + face_upsampler = None + + # ------------------ set up CodeFormer restorer ------------------- + net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, + connect_list=['32', '64', '128', '256']).to(device) + + # ckpt_path = 'weights/CodeFormer/codeformer.pth' + ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], + model_dir='weights/CodeFormer', progress=True, file_name=None) + checkpoint = torch.load(ckpt_path)['params_ema'] + net.load_state_dict(checkpoint) + net.eval() + + # ------------------ set up FaceRestoreHelper ------------------- + # large det_model: 'YOLOv5l', 'retinaface_resnet50' + # small det_model: 'YOLOv5n', 'retinaface_mobile0.25' + if not args.has_aligned: + print(f'Face detection model: {args.detection_model}') + if bg_upsampler is not None: + print(f'Background upsampling: True, Face upsampling: {args.face_upsample}') + else: + print(f'Background upsampling: False, Face upsampling: {args.face_upsample}') + + face_helper = FaceRestoreHelper( + args.upscale, + face_size=512, + crop_ratio=(1, 1), + det_model = args.detection_model, + save_ext='png', + use_parse=True, + device=device) + + # -------------------- start to processing --------------------- + for i, img_path in enumerate(input_img_list): + # clean all the intermediate results to process the next image + face_helper.clean_all() + + if isinstance(img_path, str): + img_name = os.path.basename(img_path) + basename, ext = os.path.splitext(img_name) + print(f'[{i+1}/{test_img_num}] Processing: {img_name}') + img = cv2.imread(img_path, cv2.IMREAD_COLOR) + else: # for video processing + basename = str(i).zfill(6) + img_name = f'{video_name}_{basename}' if 
input_video else basename + print(f'[{i+1}/{test_img_num}] Processing: {img_name}') + img = img_path + + if args.has_aligned: + # the input faces are already cropped and aligned + img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) + face_helper.is_gray = is_gray(img, threshold=5) + if face_helper.is_gray: + print('Grayscale input: True') + face_helper.cropped_faces = [img] + else: + face_helper.read_image(img) + # get face landmarks for each face + num_det_faces = face_helper.get_face_landmarks_5( + only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5) + print(f'\tdetect {num_det_faces} faces') + # align and warp each face + face_helper.align_warp_face() + + # face restoration for each cropped face + for idx, cropped_face in enumerate(face_helper.cropped_faces): + # prepare data + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + with torch.no_grad(): + output = net(cropped_face_t, w=w, adain=True)[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + torch.cuda.empty_cache() + except Exception as error: + print(f'\tFailed inference for CodeFormer: {error}') + restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) + + restored_face = restored_face.astype('uint8') + face_helper.add_restored_face(restored_face) + + # paste_back + if not args.has_aligned: + # upsample the background + if bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0] + else: + bg_img = None + face_helper.get_inverse_affine(None) + # paste each restored face to the input image + if args.face_upsample and face_upsampler is not None: + restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler) + else: + restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box) + + # save faces + for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)): + # save cropped face + if not args.has_aligned: + save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png') + imwrite(cropped_face, save_crop_path) + # save restored face + if args.has_aligned: + save_face_name = f'{basename}.png' + else: + save_face_name = f'{basename}_{idx:02d}.png' + if args.suffix is not None: + save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png' + save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name) + imwrite(restored_face, save_restore_path) + + # save restored img + if not args.has_aligned and restored_img is not None: + if args.suffix is not None: + basename = f'{basename}_{args.suffix}' + save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png') + imwrite(restored_img, save_restore_path) + + # save enhanced video + if input_video: + print('Video Saving...') + # load images + video_frames = [] + img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g'))) + for img_path in img_list: + img = cv2.imread(img_path) + video_frames.append(img) + # write images to video + height, width = video_frames[0].shape[:2] + if args.suffix is not None: + video_name = f'{video_name}_{args.suffix}.png' + save_restore_path = os.path.join(result_root, 
f'{video_name}.mp4') + vidwriter = VideoWriter(save_restore_path, height, width, fps, audio) + + for f in video_frames: + vidwriter.write_frame(f) + vidwriter.close() + + print(f'\nAll results are saved in {result_root}') diff --git a/repositories/codeformer/requirements.txt b/repositories/codeformer/requirements.txt new file mode 100644 index 000000000..7e1950a06 --- /dev/null +++ b/repositories/codeformer/requirements.txt @@ -0,0 +1,17 @@ +addict +future +lmdb +numpy +opencv-python +Pillow +pyyaml +requests +scikit-image +scipy +tb-nightly +torch>=1.7.1 +torchvision +tqdm +yapf +lpips +gdown # supports downloading the large file from Google Drive \ No newline at end of file diff --git a/repositories/codeformer/scripts/crop_align_face.py b/repositories/codeformer/scripts/crop_align_face.py new file mode 100755 index 000000000..31e66266a --- /dev/null +++ b/repositories/codeformer/scripts/crop_align_face.py @@ -0,0 +1,192 @@ +""" +brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset) +author: lzhbrian (https://lzhbrian.me) +link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5 +date: 2020.1.5 +note: code is heavily borrowed from + https://github.com/NVlabs/ffhq-dataset + http://dlib.net/face_landmark_detection.py.html +requirements: + conda install Pillow numpy scipy + conda install -c conda-forge dlib + # download face landmark model from: + # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 +""" + +import cv2 +import dlib +import glob +import numpy as np +import os +import PIL +import PIL.Image +import scipy +import scipy.ndimage +import sys +import argparse + +# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 +predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat') + + +def get_landmark(filepath, only_keep_largest=True): + """get landmark with dlib + :return: np.array shape=(68, 2) + """ + detector = dlib.get_frontal_face_detector() + + img = dlib.load_rgb_image(filepath) + dets = detector(img, 1) + + # Shangchen modified + print("Number of faces detected: {}".format(len(dets))) + if only_keep_largest: + print('Detect several faces and only keep the largest.') + face_areas = [] + for k, d in enumerate(dets): + face_area = (d.right() - d.left()) * (d.bottom() - d.top()) + face_areas.append(face_area) + + largest_idx = face_areas.index(max(face_areas)) + d = dets[largest_idx] + shape = predictor(img, d) + print("Part 0: {}, Part 1: {} ...".format( + shape.part(0), shape.part(1))) + else: + for k, d in enumerate(dets): + print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format( + k, d.left(), d.top(), d.right(), d.bottom())) + # Get the landmarks/parts for the face in box d. 
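# For orientation, the 68-point dlib layout consumed here (it matches the slices in align_face below):
# 0-16 jaw, 17-21 left eyebrow, 22-26 right eyebrow, 27-30 nose bridge, 31-35 nostrils,
# 36-41 left eye, 42-47 right eye, 48-59 outer mouth, 60-67 inner mouth.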
+ shape = predictor(img, d) + print("Part 0: {}, Part 1: {} ...".format( + shape.part(0), shape.part(1))) + + t = list(shape.parts()) + a = [] + for tt in t: + a.append([tt.x, tt.y]) + lm = np.array(a) + # lm is a shape=(68,2) np.array + return lm + +def align_face(filepath, out_path): + """ + :param filepath: str + :return: PIL Image + """ + try: + lm = get_landmark(filepath) + except: + print('No landmark ...') + return + + lm_chin = lm[0:17] # left-right + lm_eyebrow_left = lm[17:22] # left-right + lm_eyebrow_right = lm[22:27] # left-right + lm_nose = lm[27:31] # top-down + lm_nostrils = lm[31:36] # top-down + lm_eye_left = lm[36:42] # left-clockwise + lm_eye_right = lm[42:48] # left-clockwise + lm_mouth_outer = lm[48:60] # left-clockwise + lm_mouth_inner = lm[60:68] # left-clockwise + + # Calculate auxiliary vectors. + eye_left = np.mean(lm_eye_left, axis=0) + eye_right = np.mean(lm_eye_right, axis=0) + eye_avg = (eye_left + eye_right) * 0.5 + eye_to_eye = eye_right - eye_left + mouth_left = lm_mouth_outer[0] + mouth_right = lm_mouth_outer[6] + mouth_avg = (mouth_left + mouth_right) * 0.5 + eye_to_mouth = mouth_avg - eye_avg + + # Choose oriented crop rectangle. + x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] + x /= np.hypot(*x) + x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) + y = np.flipud(x) * [-1, 1] + c = eye_avg + eye_to_mouth * 0.1 + quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) + qsize = np.hypot(*x) * 2 + + # read image + img = PIL.Image.open(filepath) + + output_size = 512 + transform_size = 4096 + enable_padding = False + + # Shrink. + shrink = int(np.floor(qsize / output_size * 0.5)) + if shrink > 1: + rsize = (int(np.rint(float(img.size[0]) / shrink)), + int(np.rint(float(img.size[1]) / shrink))) + img = img.resize(rsize, PIL.Image.ANTIALIAS) + quad /= shrink + qsize /= shrink + + # Crop. + border = max(int(np.rint(qsize * 0.1)), 3) + crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) + crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), + min(crop[2] + border, + img.size[0]), min(crop[3] + border, img.size[1])) + if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: + img = img.crop(crop) + quad -= crop[0:2] + + # Pad. 
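+    # (Padding step) the block below measures how far the oriented quad sticks out
+    # past the image borders; when enable_padding is set (it is False by default in
+    # this script) and the overflow exceeds the border margin, the image is
+    # reflect-padded, the padded rim is feathered with a Gaussian blur and blended
+    # towards the median colour, and the quad is shifted into the padded image.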
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), + int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) + pad = (max(-pad[0] + border, + 0), max(-pad[1] + border, + 0), max(pad[2] - img.size[0] + border, + 0), max(pad[3] - img.size[1] + border, 0)) + if enable_padding and max(pad) > border - 4: + pad = np.maximum(pad, int(np.rint(qsize * 0.3))) + img = np.pad( + np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), + 'reflect') + h, w, _ = img.shape + y, x, _ = np.ogrid[:h, :w, :1] + mask = np.maximum( + 1.0 - + np.minimum(np.float32(x) / pad[0], + np.float32(w - 1 - x) / pad[2]), 1.0 - + np.minimum(np.float32(y) / pad[1], + np.float32(h - 1 - y) / pad[3])) + blur = qsize * 0.02 + img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - + img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) + img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) + img = PIL.Image.fromarray( + np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB') + quad += pad[:2] + + img = img.transform((transform_size, transform_size), PIL.Image.QUAD, + (quad + 0.5).flatten(), PIL.Image.BILINEAR) + + if output_size < transform_size: + img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) + + # Save aligned image. + print('saveing: ', out_path) + img.save(out_path) + + return img, np.max(quad[:, 0]) - np.min(quad[:, 0]) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--in_dir', type=str, default='./inputs/whole_imgs') + parser.add_argument('--out_dir', type=str, default='./inputs/cropped_faces') + args = parser.parse_args() + + img_list = sorted(glob.glob(f'{args.in_dir}/*.png')) + img_list = sorted(img_list) + + for in_path in img_list: + out_path = os.path.join(args.out_dir, in_path.split("/")[-1]) + out_path = out_path.replace('.jpg', '.png') + size_ = align_face(in_path, out_path) \ No newline at end of file diff --git a/repositories/codeformer/scripts/download_pretrained_models.py b/repositories/codeformer/scripts/download_pretrained_models.py new file mode 100644 index 000000000..daa6e8ca1 --- /dev/null +++ b/repositories/codeformer/scripts/download_pretrained_models.py @@ -0,0 +1,40 @@ +import argparse +import os +from os import path as osp + +from basicsr.utils.download_util import load_file_from_url + + +def download_pretrained_models(method, file_urls): + save_path_root = f'./weights/{method}' + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_url in file_urls.items(): + save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument( + 'method', + type=str, + help=("Options: 'CodeFormer' 'facelib'. 
Set to 'all' to download all the models.")) + args = parser.parse_args() + + file_urls = { + 'CodeFormer': { + 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' + }, + 'facelib': { + # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth', + 'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth', + 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth' + } + } + + if args.method == 'all': + for method in file_urls.keys(): + download_pretrained_models(method, file_urls[method]) + else: + download_pretrained_models(args.method, file_urls[args.method]) \ No newline at end of file diff --git a/repositories/codeformer/scripts/download_pretrained_models_from_gdrive.py b/repositories/codeformer/scripts/download_pretrained_models_from_gdrive.py new file mode 100644 index 000000000..7df5be6fc --- /dev/null +++ b/repositories/codeformer/scripts/download_pretrained_models_from_gdrive.py @@ -0,0 +1,60 @@ +import argparse +import os +from os import path as osp + +# from basicsr.utils.download_util import download_file_from_google_drive +import gdown + + +def download_pretrained_models(method, file_ids): + save_path_root = f'./weights/{method}' + os.makedirs(save_path_root, exist_ok=True) + + for file_name, file_id in file_ids.items(): + file_url = 'https://drive.google.com/uc?id='+file_id + save_path = osp.abspath(osp.join(save_path_root, file_name)) + if osp.exists(save_path): + user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n') + if user_response.lower() == 'y': + print(f'Covering {file_name} to {save_path}') + gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + elif user_response.lower() == 'n': + print(f'Skipping {file_name}') + else: + raise ValueError('Wrong input. Only accepts Y/N.') + else: + print(f'Downloading {file_name} to {save_path}') + gdown.download(file_url, save_path, quiet=False) + # download_file_from_google_drive(file_id, save_path) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument( + 'method', + type=str, + help=("Options: 'CodeFormer' 'facelib'. 
Set to 'all' to download all the models.")) + args = parser.parse_args() + + # file name: file id + # 'dlib': { + # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX', + # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg', + # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq' + # } + file_ids = { + 'CodeFormer': { + 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB' + }, + 'facelib': { + 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV', + 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK' + } + } + + if args.method == 'all': + for method in file_ids.keys(): + download_pretrained_models(method, file_ids[method]) + else: + download_pretrained_models(args.method, file_ids[args.method]) \ No newline at end of file diff --git a/repositories/codeformer/web-demos/hugging_face/app.py b/repositories/codeformer/web-demos/hugging_face/app.py new file mode 100644 index 000000000..7da0fc947 --- /dev/null +++ b/repositories/codeformer/web-demos/hugging_face/app.py @@ -0,0 +1,280 @@ +""" +This file is used for deploying hugging face demo: +https://huggingface.co/spaces/sczhou/CodeFormer +""" + +import sys +sys.path.append('CodeFormer') +import os +import cv2 +import torch +import torch.nn.functional as F +import gradio as gr + +from torchvision.transforms.functional import normalize + +from basicsr.utils import imwrite, img2tensor, tensor2img +from basicsr.utils.download_util import load_file_from_url +from facelib.utils.face_restoration_helper import FaceRestoreHelper +from facelib.utils.misc import is_gray +from basicsr.archs.rrdbnet_arch import RRDBNet +from basicsr.utils.realesrgan_utils import RealESRGANer + +from basicsr.utils.registry import ARCH_REGISTRY + + +os.system("pip freeze") + +pretrain_model_url = { + 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth', + 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth', + 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth', + 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth' +} +# download weights +if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'): + load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None) +if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'): + load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None) +if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'): + load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None) +if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'): + load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None) + +# download images +torch.hub.download_url_to_file( + 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png', + '01.png') +torch.hub.download_url_to_file( + 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg', + '02.jpg') +torch.hub.download_url_to_file( + 
'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg', + '03.jpg') +torch.hub.download_url_to_file( + 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg', + '04.jpg') +torch.hub.download_url_to_file( + 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg', + '05.jpg') + +def imread(img_path): + img = cv2.imread(img_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + +# set enhancer with RealESRGAN +def set_realesrgan(): + half = True if torch.cuda.is_available() else False + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + upsampler = RealESRGANer( + scale=2, + model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth", + model=model, + tile=400, + tile_pad=40, + pre_pad=0, + half=half, + ) + return upsampler + +upsampler = set_realesrgan() +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +codeformer_net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], +).to(device) +ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth" +checkpoint = torch.load(ckpt_path)["params_ema"] +codeformer_net.load_state_dict(checkpoint) +codeformer_net.eval() + +os.makedirs('output', exist_ok=True) + +def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity): + """Run a single prediction on the model""" + try: # global try + # take the default setting for the demo + has_aligned = False + only_center_face = False + draw_box = False + detection_model = "retinaface_resnet50" + print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity) + + img = cv2.imread(str(image), cv2.IMREAD_COLOR) + print('\timage size:', img.shape) + + upscale = int(upscale) # convert type to int + if upscale > 4: # avoid memory exceeded due to too large upscale + upscale = 4 + if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution + upscale = 2 + if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution + upscale = 1 + background_enhance = False + face_upsample = False + + face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model=detection_model, + save_ext="png", + use_parse=True, + device=device, + ) + bg_upsampler = upsampler if background_enhance else None + face_upsampler = upsampler if face_upsample else None + + if has_aligned: + # the input faces are already cropped and aligned + img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) + face_helper.is_gray = is_gray(img, threshold=5) + if face_helper.is_gray: + print('\tgrayscale input: True') + face_helper.cropped_faces = [img] + else: + face_helper.read_image(img) + # get face landmarks for each face + num_det_faces = face_helper.get_face_landmarks_5( + only_center_face=only_center_face, resize=640, eye_dist_threshold=5 + ) + print(f'\tdetect {num_det_faces} faces') + # align and warp each face + face_helper.align_warp_face() + + # face restoration for each cropped face + for idx, cropped_face in enumerate(face_helper.cropped_faces): + # prepare data + cropped_face_t = img2tensor( + cropped_face / 255.0, bgr2rgb=True, float32=True + ) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = 
cropped_face_t.unsqueeze(0).to(device) + + try: + with torch.no_grad(): + output = codeformer_net( + cropped_face_t, w=codeformer_fidelity, adain=True + )[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + torch.cuda.empty_cache() + except RuntimeError as error: + print(f"Failed inference for CodeFormer: {error}") + restored_face = tensor2img( + cropped_face_t, rgb2bgr=True, min_max=(-1, 1) + ) + + restored_face = restored_face.astype("uint8") + face_helper.add_restored_face(restored_face) + + # paste_back + if not has_aligned: + # upsample the background + if bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] + else: + bg_img = None + face_helper.get_inverse_affine(None) + # paste each restored face to the input image + if face_upsample and face_upsampler is not None: + restored_img = face_helper.paste_faces_to_input_image( + upsample_img=bg_img, + draw_box=draw_box, + face_upsampler=face_upsampler, + ) + else: + restored_img = face_helper.paste_faces_to_input_image( + upsample_img=bg_img, draw_box=draw_box + ) + + # save restored img + save_path = f'output/out.png' + imwrite(restored_img, str(save_path)) + + restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB) + return restored_img, save_path + except Exception as error: + print('Global exception', error) + return None, None + + +title = "CodeFormer: Robust Face Restoration and Enhancement Network" +description = r"""
CodeFormer logo
+Official Gradio demo for Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022).
+🔥 CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.
+🤗 Try CodeFormer for improved stable-diffusion generation!
+""" +article = r""" +If CodeFormer is helpful, please help to ⭐ the Github Repo. Thanks! +[![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer) + +--- + +📝 **Citation** + +If our work is useful for your research, please consider citing: +```bibtex +@inproceedings{zhou2022codeformer, + author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change}, + title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer}, + booktitle = {NeurIPS}, + year = {2022} +} +``` + +📋 **License** + +This project is licensed under S-Lab License 1.0. +Redistribution and use for non-commercial purposes should follow this license. + +📧 **Contact** + +If you have any questions, please feel free to reach me out at shangchenzhou@gmail.com. + +
+ 🤗 Find Me: + Twitter Follow + Github Follow +
+ +
visitors
+""" + +demo = gr.Interface( + inference, [ + gr.inputs.Image(type="filepath", label="Input"), + gr.inputs.Checkbox(default=True, label="Background_Enhance"), + gr.inputs.Checkbox(default=True, label="Face_Upsample"), + gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"), + gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)') + ], [ + gr.outputs.Image(type="numpy", label="Output"), + gr.outputs.File(label="Download the output") + ], + title=title, + description=description, + article=article, + examples=[ + ['01.png', True, True, 2, 0.7], + ['02.jpg', True, True, 2, 0.7], + ['03.jpg', True, True, 2, 0.7], + ['04.jpg', True, True, 2, 0.1], + ['05.jpg', True, True, 2, 0.1] + ] + ) + +demo.queue(concurrency_count=2) +demo.launch() \ No newline at end of file diff --git a/repositories/codeformer/web-demos/replicate/cog.yaml b/repositories/codeformer/web-demos/replicate/cog.yaml new file mode 100644 index 000000000..3f4589690 --- /dev/null +++ b/repositories/codeformer/web-demos/replicate/cog.yaml @@ -0,0 +1,30 @@ +""" +This file is used for deploying replicate demo: +https://replicate.com/sczhou/codeformer +""" + +build: + gpu: true + cuda: "11.3" + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "ipython==8.4.0" + - "future==0.18.2" + - "lmdb==1.3.0" + - "scikit-image==0.19.3" + - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" + - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" + - "scipy==1.9.0" + - "gdown==4.5.1" + - "pyyaml==6.0" + - "tb-nightly==2.11.0a20220906" + - "tqdm==4.64.1" + - "yapf==0.32.0" + - "lpips==0.1.4" + - "Pillow==9.2.0" + - "opencv-python==4.6.0.66" + +predict: "predict.py:Predictor" diff --git a/repositories/codeformer/web-demos/replicate/predict.py b/repositories/codeformer/web-demos/replicate/predict.py new file mode 100644 index 000000000..61935e9e7 --- /dev/null +++ b/repositories/codeformer/web-demos/replicate/predict.py @@ -0,0 +1,189 @@ +""" +This file is used for deploying replicate demo: +https://replicate.com/sczhou/codeformer +running: cog predict -i image=@inputs/whole_imgs/04.jpg -i codeformer_fidelity=0.5 -i upscale=2 +push: cog push r8.im/sczhou/codeformer +""" + +import tempfile +import cv2 +import torch +from torchvision.transforms.functional import normalize +try: + from cog import BasePredictor, Input, Path +except Exception: + print('please install cog package') + +from basicsr.utils import imwrite, img2tensor, tensor2img +from basicsr.archs.rrdbnet_arch import RRDBNet +from basicsr.utils.realesrgan_utils import RealESRGANer +from basicsr.utils.registry import ARCH_REGISTRY +from facelib.utils.face_restoration_helper import FaceRestoreHelper + + +class Predictor(BasePredictor): + def setup(self): + """Load the model into memory to make running multiple predictions efficient""" + self.device = "cuda:0" + self.upsampler = set_realesrgan() + self.net = ARCH_REGISTRY.get("CodeFormer")( + dim_embd=512, + codebook_size=1024, + n_head=8, + n_layers=9, + connect_list=["32", "64", "128", "256"], + ).to(self.device) + ckpt_path = "weights/CodeFormer/codeformer.pth" + checkpoint = torch.load(ckpt_path)[ + "params_ema" + ] # update file permission if cannot load + self.net.load_state_dict(checkpoint) + self.net.eval() + + def predict( + self, + image: Path = Input(description="Input image"), + codeformer_fidelity: float = Input( + default=0.5, + ge=0, + le=1, + 
description="Balance the quality (lower number) and fidelity (higher number).", + ), + background_enhance: bool = Input( + description="Enhance background image with Real-ESRGAN", default=True + ), + face_upsample: bool = Input( + description="Upsample restored faces for high-resolution AI-created images", + default=True, + ), + upscale: int = Input( + description="The final upsampling scale of the image", + default=2, + ), + ) -> Path: + """Run a single prediction on the model""" + + # take the default setting for the demo + has_aligned = False + only_center_face = False + draw_box = False + detection_model = "retinaface_resnet50" + + self.face_helper = FaceRestoreHelper( + upscale, + face_size=512, + crop_ratio=(1, 1), + det_model=detection_model, + save_ext="png", + use_parse=True, + device=self.device, + ) + + bg_upsampler = self.upsampler if background_enhance else None + face_upsampler = self.upsampler if face_upsample else None + + img = cv2.imread(str(image), cv2.IMREAD_COLOR) + + if has_aligned: + # the input faces are already cropped and aligned + img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) + self.face_helper.cropped_faces = [img] + else: + self.face_helper.read_image(img) + # get face landmarks for each face + num_det_faces = self.face_helper.get_face_landmarks_5( + only_center_face=only_center_face, resize=640, eye_dist_threshold=5 + ) + print(f"\tdetect {num_det_faces} faces") + # align and warp each face + self.face_helper.align_warp_face() + + # face restoration for each cropped face + for idx, cropped_face in enumerate(self.face_helper.cropped_faces): + # prepare data + cropped_face_t = img2tensor( + cropped_face / 255.0, bgr2rgb=True, float32=True + ) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) + + try: + with torch.no_grad(): + output = self.net( + cropped_face_t, w=codeformer_fidelity, adain=True + )[0] + restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) + del output + torch.cuda.empty_cache() + except Exception as error: + print(f"\tFailed inference for CodeFormer: {error}") + restored_face = tensor2img( + cropped_face_t, rgb2bgr=True, min_max=(-1, 1) + ) + + restored_face = restored_face.astype("uint8") + self.face_helper.add_restored_face(restored_face) + + # paste_back + if not has_aligned: + # upsample the background + if bg_upsampler is not None: + # Now only support RealESRGAN for upsampling background + bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] + else: + bg_img = None + self.face_helper.get_inverse_affine(None) + # paste each restored face to the input image + if face_upsample and face_upsampler is not None: + restored_img = self.face_helper.paste_faces_to_input_image( + upsample_img=bg_img, + draw_box=draw_box, + face_upsampler=face_upsampler, + ) + else: + restored_img = self.face_helper.paste_faces_to_input_image( + upsample_img=bg_img, draw_box=draw_box + ) + + # save restored img + out_path = Path(tempfile.mkdtemp()) / 'output.png' + imwrite(restored_img, str(out_path)) + + return out_path + + +def imread(img_path): + img = cv2.imread(img_path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img + + +def set_realesrgan(): + if not torch.cuda.is_available(): # CPU + import warnings + + warnings.warn( + "The unoptimized RealESRGAN is slow on CPU. We do not use it. 
" + "If you really want to use it, please modify the corresponding codes.", + category=RuntimeWarning, + ) + upsampler = None + else: + model = RRDBNet( + num_in_ch=3, + num_out_ch=3, + num_feat=64, + num_block=23, + num_grow_ch=32, + scale=2, + ) + upsampler = RealESRGANer( + scale=2, + model_path="./weights/realesrgan/RealESRGAN_x2plus.pth", + model=model, + tile=400, + tile_pad=40, + pre_pad=0, + half=True, + ) + return upsampler diff --git a/repositories/codeformer/weights/CodeFormer/.gitkeep b/repositories/codeformer/weights/CodeFormer/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/codeformer/weights/README.md b/repositories/codeformer/weights/README.md new file mode 100644 index 000000000..67ad334bd --- /dev/null +++ b/repositories/codeformer/weights/README.md @@ -0,0 +1,3 @@ +# Weights + +Put the downloaded pre-trained models to this folder. \ No newline at end of file diff --git a/repositories/codeformer/weights/facelib/.gitkeep b/repositories/codeformer/weights/facelib/.gitkeep new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/README.md b/repositories/ldm/README.md new file mode 100644 index 000000000..2dfddca33 --- /dev/null +++ b/repositories/ldm/README.md @@ -0,0 +1,302 @@ +# Stable Diffusion Version 2 +![t2i](assets/stable-samples/txt2img/768/merged-0006.png) +![t2i](assets/stable-samples/txt2img/768/merged-0002.png) +![t2i](assets/stable-samples/txt2img/768/merged-0005.png) + +This repository contains [Stable Diffusion](https://github.com/CompVis/stable-diffusion) models trained from scratch and will be continuously updated with +new checkpoints. The following list provides an overview of all currently available models. More coming soon. + +## News + + +**March 24, 2023** + +*Stable UnCLIP 2.1* + +- New stable diffusion finetune (_Stable unCLIP 2.1_, [Hugging Face](https://huggingface.co/stabilityai/)) at 768x768 resolution, based on SD2.1-768. This model allows for image variations and mixing operations as described in [*Hierarchical Text-Conditional Image Generation with CLIP Latents*](https://arxiv.org/abs/2204.06125), and, thanks to its modularity, can be combined with other models such as [KARLO](https://github.com/kakaobrain/karlo). Comes in two variants: [*Stable unCLIP-L*](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/blob/main/sd21-unclip-l.ckpt) and [*Stable unCLIP-H*](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/blob/main/sd21-unclip-h.ckpt), which are conditioned on CLIP ViT-L and ViT-H image embeddings, respectively. Instructions are available [here](doc/UNCLIP.MD). + +- A public demo of SD-unCLIP is already available at [clipdrop.co/stable-diffusion-reimagine](https://clipdrop.co/stable-diffusion-reimagine) + + +**December 7, 2022** + +*Version 2.1* + +- New stable diffusion model (_Stable Diffusion 2.1-v_, [Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-2-1)) at 768x768 resolution and (_Stable Diffusion 2.1-base_, [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base)) at 512x512 resolution, both based on the same number of parameters and architecture as 2.0 and fine-tuned on 2.0, on a less restrictive NSFW filtering of the [LAION-5B](https://laion.ai/blog/laion-5b/) dataset. +Per default, the attention operation of the model is evaluated at full precision when `xformers` is not installed. 
To enable fp16 (which can cause numerical instabilities with the vanilla attention module on the v2.1 model) , run your script with `ATTN_PRECISION=fp16 python ` + +**November 24, 2022** + +*Version 2.0* + +- New stable diffusion model (_Stable Diffusion 2.0-v_) at 768x768 resolution. Same number of parameters in the U-Net as 1.5, but uses [OpenCLIP-ViT/H](https://github.com/mlfoundations/open_clip) as the text encoder and is trained from scratch. _SD 2.0-v_ is a so-called [v-prediction](https://arxiv.org/abs/2202.00512) model. +- The above model is finetuned from _SD 2.0-base_, which was trained as a standard noise-prediction model on 512x512 images and is also made available. +- Added a [x4 upscaling latent text-guided diffusion model](#image-upscaling-with-stable-diffusion). +- New [depth-guided stable diffusion model](#depth-conditional-stable-diffusion), finetuned from _SD 2.0-base_. The model is conditioned on monocular depth estimates inferred via [MiDaS](https://github.com/isl-org/MiDaS) and can be used for structure-preserving img2img and shape-conditional synthesis. + + ![d2i](assets/stable-samples/depth2img/depth2img01.png) +- A [text-guided inpainting model](#image-inpainting-with-stable-diffusion), finetuned from SD _2.0-base_. + +We follow the [original repository](https://github.com/CompVis/stable-diffusion) and provide basic inference scripts to sample from the models. + +________________ +*The original Stable Diffusion model was created in a collaboration with [CompVis](https://arxiv.org/abs/2202.00512) and [RunwayML](https://runwayml.com/) and builds upon the work:* + +[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://ommer-lab.com/research/latent-diffusion-models/)
+[Robin Rombach](https://github.com/rromb)\*, +[Andreas Blattmann](https://github.com/ablattmann)\*, +[Dominik Lorenz](https://github.com/qp-qp)\, +[Patrick Esser](https://github.com/pesser), +[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
+_[CVPR '22 Oral](https://openaccess.thecvf.com/content/CVPR2022/html/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.html) | +[GitHub](https://github.com/CompVis/latent-diffusion) | [arXiv](https://arxiv.org/abs/2112.10752) | [Project page](https://ommer-lab.com/research/latent-diffusion-models/)_ + +and [many others](#shout-outs). + +Stable Diffusion is a latent text-to-image diffusion model. +________________________________ + +## Requirements + +You can update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running + +``` +conda install pytorch==1.12.1 torchvision==0.13.1 -c pytorch +pip install transformers==4.19.2 diffusers invisible-watermark +pip install -e . +``` +#### xformers efficient attention +For more efficiency and speed on GPUs, +we highly recommended installing the [xformers](https://github.com/facebookresearch/xformers) +library. + +Tested on A100 with CUDA 11.4. +Installation needs a somewhat recent version of nvcc and gcc/g++, obtain those, e.g., via +```commandline +export CUDA_HOME=/usr/local/cuda-11.4 +conda install -c nvidia/label/cuda-11.4.0 cuda-nvcc +conda install -c conda-forge gcc +conda install -c conda-forge gxx_linux-64==9.5.0 +``` + +Then, run the following (compiling takes up to 30 min). + +```commandline +cd .. +git clone https://github.com/facebookresearch/xformers.git +cd xformers +git submodule update --init --recursive +pip install -r requirements.txt +pip install -e . +cd ../stablediffusion +``` +Upon successful installation, the code will automatically default to [memory efficient attention](https://github.com/facebookresearch/xformers) +for the self- and cross-attention layers in the U-Net and autoencoder. + +## General Disclaimer +Stable Diffusion models are general text-to-image diffusion models and therefore mirror biases and (mis-)conceptions that are present +in their training data. Although efforts were made to reduce the inclusion of explicit pornographic material, **we do not recommend using the provided weights for services or products without additional safety mechanisms and considerations. +The weights are research artifacts and should be treated as such.** +Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](https://huggingface.co/stabilityai/stable-diffusion-2). +The weights are available via [the StabilityAI organization at Hugging Face](https://huggingface.co/StabilityAI) under the [CreativeML Open RAIL++-M License](LICENSE-MODEL). + + + +## Stable Diffusion v2 + +Stable Diffusion v2 refers to a specific configuration of the model +architecture that uses a downsampling-factor 8 autoencoder with an 865M UNet +and OpenCLIP ViT-H/14 text encoder for the diffusion model. The _SD 2-v_ model produces 768x768 px outputs. + +Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, +5.0, 6.0, 7.0, 8.0) and 50 DDIM sampling steps show the relative improvements of the checkpoints: + +![sd evaluation results](assets/model-variants.jpg) + + + +### Text-to-Image +![txt2img-stable2](assets/stable-samples/txt2img/merged-0003.png) +![txt2img-stable2](assets/stable-samples/txt2img/merged-0001.png) + +Stable Diffusion 2 is a latent diffusion model conditioned on the penultimate text embeddings of a CLIP ViT-H/14 text encoder. +We provide a [reference script for sampling](#reference-sampling-script). 
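+
+The released checkpoints can also be sampled through the `diffusers` package listed in the requirements above; the snippet below is only a minimal sketch rather than part of this repository's scripts (the `stabilityai/stable-diffusion-2-1` model id is the Hugging Face page linked above, and the fp16 weights assume a CUDA device):
+
+```python
+import torch
+from diffusers import StableDiffusionPipeline
+
+# load SD 2.1 from the Hugging Face hub; use torch_dtype=torch.float32 on CPU
+pipe = StableDiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
+).to("cuda")
+
+image = pipe(
+    "a professional photograph of an astronaut riding a horse",
+    height=768, width=768, num_inference_steps=50,
+).images[0]
+image.save("astronaut.png")
+```
+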
+#### Reference Sampling Script + +This script incorporates an [invisible watermarking](https://github.com/ShieldMnt/invisible-watermark) of the outputs, to help viewers [identify the images as machine-generated](scripts/tests/test_watermark.py). +We provide the configs for the _SD2-v_ (768px) and _SD2-base_ (512px) model. + +First, download the weights for [_SD2.1-v_](https://huggingface.co/stabilityai/stable-diffusion-2-1) and [_SD2.1-base_](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). + +To sample from the _SD2.1-v_ model, run the following: + +``` +python scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt --config configs/stable-diffusion/v2-inference-v.yaml --H 768 --W 768 +``` +or try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/stabilityai/stable-diffusion). + +To sample from the base model, use +``` +python scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt --config +``` + +By default, this uses the [DDIM sampler](https://arxiv.org/abs/2010.02502), and renders images of size 768x768 (which it was trained on) in 50 steps. +Empirically, the v-models can be sampled with higher guidance scales. + +Note: The inference config for all model versions is designed to be used with EMA-only checkpoints. +For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from +non-EMA to EMA weights. + +#### Enable Intel® Extension for PyTorch* optimizations in Text-to-Image script + +If you're planning on running Text-to-Image on Intel® CPU, try to sample an image with TorchScript and Intel® Extension for PyTorch* optimizations. Intel® Extension for PyTorch* extends PyTorch by enabling up-to-date features optimizations for an extra performance boost on Intel® hardware. It can optimize memory layout of the operators to Channel Last memory format, which is generally beneficial for Intel CPUs, take advantage of the most advanced instruction set available on a machine, optimize operators and many more. + +**Prerequisites** + +Before running the script, make sure you have all needed libraries installed. (the optimization was checked on `Ubuntu 20.04`). Install [jemalloc](https://github.com/jemalloc/jemalloc), [numactl](https://linux.die.net/man/8/numactl), Intel® OpenMP and Intel® Extension for PyTorch*. + +```bash +apt-get install numactl libjemalloc-dev +pip install intel-openmp +pip install intel_extension_for_pytorch -f https://software.intel.com/ipex-whl-stable +``` + +To sample from the _SD2.1-v_ model with TorchScript+IPEX optimizations, run the following. Remember to specify desired number of instances you want to run the program on ([more](https://github.com/intel/intel-extension-for-pytorch/blob/master/intel_extension_for_pytorch/cpu/launch.py#L48)). 
+ +``` +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt --config configs/stable-diffusion/intel/v2-inference-v-fp32.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex +``` + +To sample from the base model with IPEX optimizations, use + +``` +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt --config configs/stable-diffusion/intel/v2-inference-fp32.yaml --n_samples 1 --n_iter 4 --precision full --device cpu --torchscript --ipex +``` + +If you're using a CPU that supports `bfloat16`, consider sample from the model with bfloat16 enabled for a performance boost, like so + +```bash +# SD2.1-v +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt --config configs/stable-diffusion/intel/v2-inference-v-bf16.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex --bf16 +# SD2.1-base +MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt --config configs/stable-diffusion/intel/v2-inference-bf16.yaml --precision full --device cpu --torchscript --ipex --bf16 +``` + +### Image Modification with Stable Diffusion + +![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png) +#### Depth-Conditional Stable Diffusion + +To augment the well-established [img2img](https://github.com/CompVis/stable-diffusion#image-modification-with-stable-diffusion) functionality of Stable Diffusion, we provide a _shape-preserving_ stable diffusion model. + + +Note that the original method for image modification introduces significant semantic changes w.r.t. the initial image. +If that is not desired, download our [depth-conditional stable diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-depth) model and the `dpt_hybrid` MiDaS [model weights](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), place the latter in a folder `midas_models` and sample via +``` +python scripts/gradio/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml +``` + +or + +``` +streamlit run scripts/streamlit/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml +``` + +This method can be used on the samples of the base model itself. +For example, take [this sample](assets/stable-samples/depth2img/old_man.png) generated by an anonymous discord user. +Using the [gradio](https://gradio.app) or [streamlit](https://streamlit.io/) script `depth2img.py`, the MiDaS model first infers a monocular depth estimate given this input, +and the diffusion model is then conditioned on the (relative) depth output. + +
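+
+For reference, the data-side preparation of the depth conditioning is handled by the small `AddMiDaS` helper added in `ldm/data/util.py` later in this diff; a minimal sketch of its use is shown below (the `"dpt_hybrid"` model-type string is an assumption based on the `dpt_hybrid` MiDaS weights mentioned above):
+
+```python
+import torch
+from ldm.data.util import AddMiDaS
+
+# builds only the MiDaS preprocessing transform; no depth weights are loaded here
+midas = AddMiDaS(model_type="dpt_hybrid")
+
+# AddMiDaS expects sample['jpg'] as an HWC tensor in [-1, 1]
+sample = {"jpg": torch.rand(512, 512, 3) * 2.0 - 1.0}
+sample = midas(sample)
+print(sample["midas_in"].shape)  # resized/normalized array fed to the MiDaS depth estimator
+```
+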

+ depth2image
+ +

+ +This model is particularly useful for a photorealistic style; see the [examples](assets/stable-samples/depth2img). +For a maximum strength of 1.0, the model removes all pixel-based information and only relies on the text prompt and the inferred monocular depth estimate. + +![depth2img-stable3](assets/stable-samples/depth2img/merged-0005.png) + +#### Classic Img2Img + +For running the "classic" img2img, use +``` +python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img --strength 0.8 --ckpt +``` +and adapt the checkpoint and config paths accordingly. + +### Image Upscaling with Stable Diffusion +![upscaling-x4](assets/stable-samples/upscaling/merged-dog.png) +After [downloading the weights](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler), run +``` +python scripts/gradio/superresolution.py configs/stable-diffusion/x4-upscaling.yaml +``` + +or + +``` +streamlit run scripts/streamlit/superresolution.py -- configs/stable-diffusion/x4-upscaling.yaml +``` + +for a Gradio or Streamlit demo of the text-guided x4 superresolution model. +This model can be used both on real inputs and on synthesized examples. For the latter, we recommend setting a higher +`noise_level`, e.g. `noise_level=100`. + +### Image Inpainting with Stable Diffusion + +![inpainting-stable2](assets/stable-inpainting/merged-leopards.png) + +[Download the SD 2.0-inpainting checkpoint](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) and run + +``` +python scripts/gradio/inpainting.py configs/stable-diffusion/v2-inpainting-inference.yaml +``` + +or + +``` +streamlit run scripts/streamlit/inpainting.py -- configs/stable-diffusion/v2-inpainting-inference.yaml +``` + +for a Gradio or Streamlit demo of the inpainting model. +This scripts adds invisible watermarking to the demo in the [RunwayML](https://github.com/runwayml/stable-diffusion/blob/main/scripts/inpaint_st.py) repository, but both should work interchangeably with the checkpoints/configs. + + + +## Shout-Outs +- Thanks to [Hugging Face](https://huggingface.co/) and in particular [Apolinário](https://github.com/apolinario) for support with our model releases! +- Stable Diffusion would not be possible without [LAION](https://laion.ai/) and their efforts to create open, large-scale datasets. +- The [DeepFloyd team](https://twitter.com/deepfloydai) at Stability AI, for creating the subset of [LAION-5B](https://laion.ai/blog/laion-5b/) dataset used to train the model. +- Stable Diffusion 2.0 uses [OpenCLIP](https://laion.ai/blog/large-openclip/), trained by [Romain Beaumont](https://github.com/rom1504). +- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion) +and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch). +Thanks for open-sourcing! +- [CompVis](https://github.com/CompVis/stable-diffusion) initial stable diffusion release +- [Patrick](https://github.com/pesser)'s [implementation](https://github.com/runwayml/stable-diffusion/blob/main/scripts/inpaint_st.py) of the streamlit demo for inpainting. +- `img2img` is an application of [SDEdit](https://arxiv.org/abs/2108.01073) by [Chenlin Meng](https://cs.stanford.edu/~chenlin/) from the [Stanford AI Lab](https://cs.stanford.edu/~ermon/website/). 
+- [Kat's implementation]((https://github.com/CompVis/latent-diffusion/pull/51)) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler, and [more](https://github.com/crowsonkb/k-diffusion). +- [DPMSolver](https://arxiv.org/abs/2206.00927) [integration](https://github.com/CompVis/stable-diffusion/pull/440) by [Cheng Lu](https://github.com/LuChengTHU). +- Facebook's [xformers](https://github.com/facebookresearch/xformers) for efficient attention computation. +- [MiDaS](https://github.com/isl-org/MiDaS) for monocular depth estimation. + + +## License + +The code in this repository is released under the MIT License. + +The weights are available via [the StabilityAI organization at Hugging Face](https://huggingface.co/StabilityAI), and released under the [CreativeML Open RAIL++-M License](LICENSE-MODEL) License. + +## BibTeX + +``` +@misc{rombach2021highresolution, + title={High-Resolution Image Synthesis with Latent Diffusion Models}, + author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer}, + year={2021}, + eprint={2112.10752}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + + diff --git a/repositories/ldm/data/__init__.py b/repositories/ldm/data/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/data/util.py b/repositories/ldm/data/util.py new file mode 100644 index 000000000..5b60ceb23 --- /dev/null +++ b/repositories/ldm/data/util.py @@ -0,0 +1,24 @@ +import torch + +from ldm.modules.midas.api import load_midas_transform + + +class AddMiDaS(object): + def __init__(self, model_type): + super().__init__() + self.transform = load_midas_transform(model_type) + + def pt2np(self, x): + x = ((x + 1.0) * .5).detach().cpu().numpy() + return x + + def np2pt(self, x): + x = torch.from_numpy(x) * 2 - 1. 
+ return x + + def __call__(self, sample): + # sample['jpg'] is tensor hwc in [-1, 1] at this point + x = self.pt2np(sample['jpg']) + x = self.transform({"image": x})["image"] + sample['midas_in'] = x + return sample \ No newline at end of file diff --git a/repositories/ldm/models/autoencoder.py b/repositories/ldm/models/autoencoder.py new file mode 100644 index 000000000..d12254999 --- /dev/null +++ b/repositories/ldm/models/autoencoder.py @@ -0,0 +1,219 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from ldm.modules.diffusionmodules.model import Encoder, Decoder +from ldm.modules.distributions.distributions import DiagonalGaussianDistribution + +from ldm.util import instantiate_from_config +from ldm.modules.ema import LitEma + + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ema_decay=None, + learn_logvar=False + ): + super().__init__() + self.learn_logvar = learn_logvar + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + self.use_ema = ema_decay is not None + if self.use_ema: + self.ema_decay = ema_decay + assert 0. < ema_decay < 1. 
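+            # LitEma tracks an exponential moving average of the module parameters;
+            # ema_scope() below temporarily swaps the EMA weights in (and restores the
+            # training weights afterwards) for validation and image logging.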
+ self.model_ema = LitEma(self, decay=ema_decay) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, postfix=""): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val"+postfix) + + self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + ae_params_list 
= list(self.encoder.parameters()) + list(self.decoder.parameters()) + list( + self.quant_conv.parameters()) + list(self.post_quant_conv.parameters()) + if self.learn_logvar: + print(f"{self.__class__.__name__}: Learning logvar") + ae_params_list.append(self.loss.logvar) + opt_ae = torch.optim.Adam(ae_params_list, + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + if log_ema or self.use_ema: + with self.ema_scope(): + xrec_ema, posterior_ema = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec_ema.shape[1] > 3 + xrec_ema = self.to_rgb(xrec_ema) + log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample())) + log["reconstructions_ema"] = xrec_ema + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + diff --git a/repositories/ldm/models/diffusion/__init__.py b/repositories/ldm/models/diffusion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/models/diffusion/ddim.py b/repositories/ldm/models/diffusion/ddim.py new file mode 100644 index 000000000..c6cfd5712 --- /dev/null +++ b/repositories/ldm/models/diffusion/ddim.py @@ -0,0 +1,337 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = device + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != self.device: + attr = attr.to(self.device) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: 
x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ dynamic_threshold=None, + ucg_schedule=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ucg_schedule=ucg_schedule + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None, + ucg_schedule=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + if ucg_schedule is not None: + assert len(ucg_schedule) == len(time_range) + unconditional_guidance_scale = ucg_schedule[i] + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + model_output = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [torch.cat([ + unconditional_conditioning[k][i], + c[k][i]]) for i in range(len(c[k]))] + else: + c_in[k] = torch.cat([ + unconditional_conditioning[k], + c[k]]) + elif isinstance(c, list): + c_in = list() + assert isinstance(unconditional_conditioning, list) + for i in range(len(c)): + c_in.append(torch.cat([unconditional_conditioning[i], c[i]])) + else: + c_in = torch.cat([unconditional_conditioning, c]) + model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond) + + if self.model.parameterization == "v": + e_t = self.model.predict_eps_from_z_and_v(x, t, model_output) + else: + e_t = model_output + + if score_corrector is not None: + assert self.model.parameterization == "eps", 'not implemented' + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + if self.model.parameterization != "v": + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + else: + pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output) + + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + + if dynamic_threshold is not None: + raise 
NotImplementedError() + + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + @torch.no_grad() + def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None, + unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None): + num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0] + + assert t_enc <= num_reference_steps + num_steps = t_enc + + if use_original_steps: + alphas_next = self.alphas_cumprod[:num_steps] + alphas = self.alphas_cumprod_prev[:num_steps] + else: + alphas_next = self.ddim_alphas[:num_steps] + alphas = torch.tensor(self.ddim_alphas_prev[:num_steps]) + + x_next = x0 + intermediates = [] + inter_steps = [] + for i in tqdm(range(num_steps), desc='Encoding Image'): + t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long) + if unconditional_guidance_scale == 1.: + noise_pred = self.model.apply_model(x_next, t, c) + else: + assert unconditional_conditioning is not None + e_t_uncond, noise_pred = torch.chunk( + self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)), + torch.cat((unconditional_conditioning, c))), 2) + noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond) + + xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next + weighted_noise_pred = alphas_next[i].sqrt() * ( + (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred + x_next = xt_weighted + weighted_noise_pred + if return_intermediates and i % ( + num_steps // return_intermediates) == 0 and i < num_steps - 1: + intermediates.append(x_next) + inter_steps.append(i) + elif return_intermediates and i >= num_steps - 2: + intermediates.append(x_next) + inter_steps.append(i) + if callback: callback(i) + + out = {'x_encoded': x_next, 'intermediate_steps': inter_steps} + if return_intermediates: + out.update({'intermediates': intermediates}) + return x_next, out + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) + + @torch.no_grad() + def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, + use_original_steps=False, callback=None): + + timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + print(f"Running DDIM Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='Decoding image', total=total_steps) + x_dec = x_latent + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, 
device=x_latent.device, dtype=torch.long) + x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + if callback: callback(i) + return x_dec \ No newline at end of file diff --git a/repositories/ldm/models/diffusion/ddpm.py b/repositories/ldm/models/diffusion/ddpm.py new file mode 100644 index 000000000..3350c032f --- /dev/null +++ b/repositories/ldm/models/diffusion/ddpm.py @@ -0,0 +1,1873 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" + +import torch +import torch.nn as nn +import numpy as np +import pytorch_lightning as pl +from torch.optim.lr_scheduler import LambdaLR +from einops import rearrange, repeat +from contextlib import contextmanager, nullcontext +from functools import partial +import itertools +from tqdm import tqdm +from torchvision.utils import make_grid +from pytorch_lightning.utilities.distributed import rank_zero_only +from omegaconf import ListConfig + +from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config +from ldm.modules.ema import LitEma +from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution +from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL +from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like +from ldm.models.diffusion.ddim import DDIMSampler + + +__conditioning_keys__ = {'concat': 'c_concat', + 'crossattn': 'c_crossattn', + 'adm': 'y'} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DDPM(pl.LightningModule): + # classic DDPM with Gaussian diffusion, in image space + def __init__(self, + unet_config, + timesteps=1000, + beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + image_size=256, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0., + v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1., + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0., + make_it_fit=False, + ucg_training=None, + reset_ema=False, + reset_num_ema_updates=False, + ): + super().__init__() + assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"' + self.parameterization = parameterization + print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + self.image_size = image_size # 
try conv? + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + self.make_it_fit = make_it_fit + if reset_ema: assert exists(ckpt_path) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet) + if reset_ema: + assert self.use_ema + print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps, + linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s) + + self.loss_type = loss_type + + self.learn_logvar = learn_logvar + self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,)) + if self.learn_logvar: + self.logvar = nn.Parameter(self.logvar, requires_grad=True) + + self.ucg_training = ucg_training or dict() + if self.ucg_training: + self.ucg_prng = np.random.RandomState() + + def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if exists(given_betas): + betas = given_betas + else: + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / ( + 1. - alphas_cumprod) + self.v_posterior * betas + # above: equal to 1. / (1. / (1. 
- alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + if self.parameterization == "eps": + lvlb_weights = self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + elif self.parameterization == "x0": + lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod)) + elif self.parameterization == "v": + lvlb_weights = torch.ones_like(self.betas ** 2 / ( + 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))) + else: + raise NotImplementedError("mu not supported") + lvlb_weights[0] = lvlb_weights[1] + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.model.parameters()) + self.model_ema.copy_to(self.model) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.model.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + @torch.no_grad() + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + if self.make_it_fit: + n_params = len([name for name, _ in + itertools.chain(self.named_parameters(), + self.named_buffers())]) + for name, param in tqdm( + itertools.chain(self.named_parameters(), + self.named_buffers()), + desc="Fitting old weights to new weights", + total=n_params + ): + if not name in sd: + continue + old_shape = sd[name].shape + new_shape = param.shape + assert len(old_shape) == len(new_shape) + if len(new_shape) > 2: + # we only modify first two axes + assert new_shape[2:] == old_shape[2:] + # assumes first axis corresponds to output dim + if not new_shape == old_shape: + new_param = param.clone() + old_param = sd[name] + if len(new_shape) == 1: + for i in range(new_param.shape[0]): + new_param[i] = old_param[i % old_shape[0]] + elif len(new_shape) >= 2: + for i in range(new_param.shape[0]): + for j in range(new_param.shape[1]): + new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]] + + n_used_old = torch.ones(old_shape[1]) + for j in range(new_param.shape[1]): + n_used_old[j % old_shape[1]] += 1 + n_used_new = torch.zeros(new_shape[1]) + for j in range(new_param.shape[1]): + n_used_new[j] = n_used_old[j % old_shape[1]] + + n_used_new = n_used_new[None, :] + while len(n_used_new.shape) < len(new_shape): + n_used_new = n_used_new.unsqueeze(-1) + new_param /= n_used_new + + sd[name] = new_param + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} 
with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys:\n {missing}") + if len(unexpected) > 0: + print(f"\nUnexpected Keys:\n {unexpected}") + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start) + variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def predict_start_from_z_and_v(self, x_t, t, v): + # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + def predict_eps_from_z_and_v(self, x_t, t, v): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_out = self.model(x, t) + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + if clip_denoised: + x_recon.clamp_(-1., 1.) 
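+ # q_posterior() below applies the closed-form DDPM posterior q(x_{t-1} | x_t, x_0):
+ #   mean = posterior_mean_coef1[t] * x_0 + posterior_mean_coef2[t] * x_t
+ #   var  = posterior_variance[t] = (1 - v_posterior) * beta_tilde_t + v_posterior * beta_t
+ # with every coefficient precomputed as a buffer in register_schedule()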
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_loop(self, shape, return_intermediates=False): + device = self.betas.device + b = shape[0] + img = torch.randn(shape, device=device) + intermediates = [img] + for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long), + clip_denoised=self.clip_denoised) + if i % self.log_every_t == 0 or i == self.num_timesteps - 1: + intermediates.append(img) + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, batch_size=16, return_intermediates=False): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop((batch_size, channels, image_size, image_size), + return_intermediates=return_intermediates) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def get_v(self, x, noise, t): + return ( + extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x + ) + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def p_losses(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = self.model(x_noisy, t) + + loss_dict = {} + if self.parameterization == "eps": + target = noise + elif self.parameterization == "x0": + target = x_start + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported") + + loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3]) + + log_prefix = 'train' if self.training else 'val' + + loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()}) + loss_simple = loss.mean() * self.l_simple_weight + + loss_vlb = (self.lvlb_weights[t] * loss).mean() + loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + self.original_elbo_weight * loss_vlb + + loss_dict.update({f'{log_prefix}/loss': loss}) + + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size + # assert h == img_size and w == img_size, f'height and width of image must be {img_size}' + t 
= torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + return self.p_losses(x, t, *args, **kwargs) + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = rearrange(x, 'b h w c -> b c h w') + x = x.to(memory_format=torch.contiguous_format).float() + return x + + def shared_step(self, batch): + x = self.get_input(batch, self.first_stage_key) + loss, loss_dict = self(x) + return loss, loss_dict + + def training_step(self, batch, batch_idx): + for k in self.ucg_training: + p = self.ucg_training[k]["p"] + val = self.ucg_training[k]["val"] + if val is None: + val = "" + for i in range(len(batch[k])): + if self.ucg_prng.choice(2, p=[1 - p, p]): + batch[k][i] = val + + loss, loss_dict = self.shared_step(batch) + + self.log_dict(loss_dict, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + + self.log("global_step", self.global_step, + prog_bar=True, logger=True, on_step=True, on_epoch=False) + + if self.use_scheduler: + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False) + + return loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + _, loss_dict_no_ema = self.shared_step(batch) + with self.ema_scope(): + _, loss_dict_ema = self.shared_step(batch) + loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema} + self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True) + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self.model) + + def _get_rows_from_list(self, samples): + n_imgs_per_row = len(samples) + denoise_grid = rearrange(samples, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs): + log = dict() + x = self.get_input(batch, self.first_stage_key) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + x = x.to(self.device)[:N] + log["inputs"] = x + + # get diffusion row + diffusion_row = list() + x_start = x[:n_row] + + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(x_start) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + diffusion_row.append(x_noisy) + + log["diffusion_row"] = self._get_rows_from_list(diffusion_row) + + if sample: + # get denoise row + with self.ema_scope("Plotting"): + samples, denoise_row = self.sample(batch_size=N, return_intermediates=True) + + log["samples"] = samples + log["denoise_row"] = self._get_rows_from_list(denoise_row) + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.learn_logvar: + params = params + [self.logvar] + opt = torch.optim.AdamW(params, lr=lr) + return opt + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__(self, + first_stage_config, + cond_stage_config, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + 
concat_mode=True, + cond_stage_forward=None, + conditioning_key=None, + scale_factor=1.0, + scale_by_std=False, + force_null_conditioning=False, + *args, **kwargs): + self.force_null_conditioning = force_null_conditioning + self.num_timesteps_cond = default(num_timesteps_cond, 1) + self.scale_by_std = scale_by_std + assert self.num_timesteps_cond <= kwargs['timesteps'] + # for backwards compatibility after implementation of DiffusionWrapper + if conditioning_key is None: + conditioning_key = 'concat' if concat_mode else 'crossattn' + if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning: + conditioning_key = None + ckpt_path = kwargs.pop("ckpt_path", None) + reset_ema = kwargs.pop("reset_ema", False) + reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False) + ignore_keys = kwargs.pop("ignore_keys", []) + super().__init__(conditioning_key=conditioning_key, *args, **kwargs) + self.concat_mode = concat_mode + self.cond_stage_trainable = cond_stage_trainable + self.cond_stage_key = cond_stage_key + try: + self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1 + except: + self.num_downs = 0 + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + self.instantiate_first_stage(first_stage_config) + self.instantiate_cond_stage(cond_stage_config) + self.cond_stage_forward = cond_stage_forward + self.clip_denoised = False + self.bbox_tokenizer = None + + self.restarted_from_ckpt = False + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys) + self.restarted_from_ckpt = True + if reset_ema: + assert self.use_ema + print( + f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.") + self.model_ema = LitEma(self.model) + if reset_num_ema_updates: + print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ") + assert self.use_ema + self.model_ema.reset_num_updates() + + def make_cond_schedule(self, ): + self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) + ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() + self.cond_ids[:self.num_timesteps_cond] = ids + + @rank_zero_only + @torch.no_grad() + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # only for very first batch + if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x = super().get_input(batch, self.first_stage_key) + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + del self.scale_factor + self.register_buffer('scale_factor', 1. 
/ z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + print("### USING STD-RESCALING ###") + + def register_schedule(self, + given_betas=None, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s) + + self.shorten_cond_schedule = self.num_timesteps_cond > 1 + if self.shorten_cond_schedule: + self.make_cond_schedule() + + def instantiate_first_stage(self, config): + model = instantiate_from_config(config) + self.first_stage_model = model.eval() + self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + def instantiate_cond_stage(self, config): + if not self.cond_stage_trainable: + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__": + print(f"Training {self.__class__.__name__} as an unconditional model.") + self.cond_stage_model = None + # self.be_unconditional = True + else: + model = instantiate_from_config(config) + self.cond_stage_model = model.eval() + self.cond_stage_model.train = disabled_train + for param in self.cond_stage_model.parameters(): + param.requires_grad = False + else: + assert config != '__is_first_stage__' + assert config != '__is_unconditional__' + model = instantiate_from_config(config) + self.cond_stage_model = model + + def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False): + denoise_row = [] + for zd in tqdm(samples, desc=desc): + denoise_row.append(self.decode_first_stage(zd.to(self.device), + force_not_quantize=force_no_decoder_quantization)) + n_imgs_per_row = len(denoise_row) + denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W + denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w') + denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w') + denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row) + return denoise_grid + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + return self.scale_factor * z + + def get_learned_conditioning(self, c): + if self.cond_stage_forward is None: + if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode): + c = self.cond_stage_model.encode(c) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + else: + c = self.cond_stage_model(c) + else: + assert hasattr(self.cond_stage_model, self.cond_stage_forward) + c = getattr(self.cond_stage_model, self.cond_stage_forward)(c) + return c + + def meshgrid(self, h, w): + y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1) + x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1) + + arr = torch.cat([y, x], dim=-1) + return arr + + def delta_border(self, h, w): + """ + :param h: height + :param w: width + :return: normalized distance to image border, + wtith min distance = 0 at border and max dist = 0.5 at image center + """ + lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2) + arr = self.meshgrid(h, w) / lower_right_corner + dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0] + dist_right_down = 
torch.min(1 - arr, dim=-1, keepdims=True)[0] + edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0] + return edge_dist + + def get_weighting(self, h, w, Ly, Lx, device): + weighting = self.delta_border(h, w) + weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) + weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) + + if self.split_input_params["tie_braker"]: + L_weighting = self.delta_border(Ly, Lx) + L_weighting = torch.clip(L_weighting, + self.split_input_params["clip_min_tie_weight"], + self.split_input_params["clip_max_tie_weight"]) + + L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device) + weighting = weighting * L_weighting + return weighting + + def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code + """ + :param x: img of size (bs, c, h, w) + :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1]) + """ + bs, nc, h, w = x.shape + + # number of crops in image + Ly = (h - kernel_size[0]) // stride[0] + 1 + Lx = (w - kernel_size[1]) // stride[1] + 1 + + if uf == 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params) + + weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx)) + + elif uf > 1 and df == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf), + dilation=1, padding=0, + stride=(stride[0] * uf, stride[1] * uf)) + fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx)) + + elif df > 1 and uf == 1: + fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride) + unfold = torch.nn.Unfold(**fold_params) + + fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df), + dilation=1, padding=0, + stride=(stride[0] // df, stride[1] // df)) + fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2) + + weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype) + normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap + weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx)) + + else: + raise NotImplementedError + + return fold, unfold, normalization, weighting + + @torch.no_grad() + def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False, + cond_key=None, return_original_cond=False, bs=None, return_x=False): + x = super().get_input(batch, k) + if bs is not None: + x = x[:bs] + x = x.to(self.device) + encoder_posterior = self.encode_first_stage(x) + z = self.get_first_stage_encoding(encoder_posterior).detach() + + if self.model.conditioning_key is not None and not self.force_null_conditioning: 
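+ # conditional path: fetch the raw conditioning from the batch under cond_key and,
+ # unless the cond stage is trainable (or force_c_encode is set), encode it right
+ # away via get_learned_conditioning()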
+ if cond_key is None: + cond_key = self.cond_stage_key + if cond_key != self.first_stage_key: + if cond_key in ['caption', 'coordinates_bbox', "txt"]: + xc = batch[cond_key] + elif cond_key in ['class_label', 'cls']: + xc = batch + else: + xc = super().get_input(batch, cond_key).to(self.device) + else: + xc = x + if not self.cond_stage_trainable or force_c_encode: + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + c = self.get_learned_conditioning(xc.to(self.device)) + else: + c = xc + if bs is not None: + c = c[:bs] + + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + ckey = __conditioning_keys__[self.model.conditioning_key] + c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y} + + else: + c = None + xc = None + if self.use_positional_encodings: + pos_x, pos_y = self.compute_latent_shifts(batch) + c = {'pos_x': pos_x, 'pos_y': pos_y} + out = [z, c] + if return_first_stage_outputs: + xrec = self.decode_first_stage(z) + out.extend([x, xrec]) + if return_x: + out.extend([x]) + if return_original_cond: + out.append(xc) + return out + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, 'b h w c -> b c h w').contiguous() + + z = 1. / self.scale_factor * z + return self.first_stage_model.decode(z) + + @torch.no_grad() + def encode_first_stage(self, x): + return self.first_stage_model.encode(x) + + def shared_step(self, batch, **kwargs): + x, c = self.get_input(batch, self.first_stage_key) + loss = self(x, c) + return loss + + def forward(self, x, c, *args, **kwargs): + t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long() + if self.model.conditioning_key is not None: + assert c is not None + if self.cond_stage_trainable: + c = self.get_learned_conditioning(c) + if self.shorten_cond_schedule: # TODO: drop this option + tc = self.cond_ids[t].to(self.device) + c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float())) + return self.p_losses(x, c, t, *args, **kwargs) + + def apply_model(self, x_noisy, t, cond, return_ids=False): + if isinstance(cond, dict): + # hybrid case, cond is expected to be a dict + pass + else: + if not isinstance(cond, list): + cond = [cond] + key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn' + cond = {key: cond} + + x_recon = self.model(x_noisy, t, **cond) + + if isinstance(x_recon, tuple) and not return_ids: + return x_recon[0] + else: + return x_recon + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \ + extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. 
+ """ + batch_size = x_start.shape[0] + t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0) + return mean_flat(kl_prior) / np.log(2.0) + + def p_losses(self, x_start, cond, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + model_output = self.apply_model(x_noisy, t, cond) + + loss_dict = {} + prefix = 'train' if self.training else 'val' + + if self.parameterization == "x0": + target = x_start + elif self.parameterization == "eps": + target = noise + elif self.parameterization == "v": + target = self.get_v(x_start, noise, t) + else: + raise NotImplementedError() + + loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3]) + loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()}) + + logvar_t = self.logvar[t].to(self.device) + loss = loss_simple / torch.exp(logvar_t) + logvar_t + # loss = loss_simple / torch.exp(self.logvar) + self.logvar + if self.learn_logvar: + loss_dict.update({f'{prefix}/loss_gamma': loss.mean()}) + loss_dict.update({'logvar': self.logvar.data.mean()}) + + loss = self.l_simple_weight * loss.mean() + + loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3)) + loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + loss += (self.original_elbo_weight * loss_vlb) + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False, + return_x0=False, score_corrector=None, corrector_kwargs=None): + t_in = t + model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids) + + if score_corrector is not None: + assert self.parameterization == "eps" + model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs) + + if return_codebook_ids: + model_out, logits = model_out + + if self.parameterization == "eps": + x_recon = self.predict_start_from_noise(x, t=t, noise=model_out) + elif self.parameterization == "x0": + x_recon = model_out + else: + raise NotImplementedError() + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
+ if quantize_denoised: + x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon) + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + if return_codebook_ids: + return model_mean, posterior_variance, posterior_log_variance, logits + elif return_x0: + return model_mean, posterior_variance, posterior_log_variance, x_recon + else: + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False, + return_codebook_ids=False, quantize_denoised=False, return_x0=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None): + b, *_, device = *x.shape, x.device + outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, + return_codebook_ids=return_codebook_ids, + quantize_denoised=quantize_denoised, + return_x0=return_x0, + score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if return_codebook_ids: + raise DeprecationWarning("Support dropped.") + model_mean, _, model_log_variance, logits = outputs + elif return_x0: + model_mean, _, model_log_variance, x0 = outputs + else: + model_mean, _, model_log_variance = outputs + + noise = noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + + if return_codebook_ids: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1) + if return_x0: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0 + else: + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False, + img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0., + score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None, + log_every_t=None): + if not log_every_t: + log_every_t = self.log_every_t + timesteps = self.num_timesteps + if batch_size is not None: + b = batch_size if batch_size is not None else shape[0] + shape = [batch_size] + list(shape) + else: + b = batch_size = shape[0] + if x_T is None: + img = torch.randn(shape, device=self.device) + else: + img = x_T + intermediates = [] + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation', + total=timesteps) if verbose else reversed( + range(0, timesteps)) + if type(temperature) == float: + temperature = [temperature] * timesteps + + for i in iterator: + ts = torch.full((b,), i, device=self.device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img, x0_partial = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised, return_x0=True, + temperature=temperature[i], noise_dropout=noise_dropout, + 
score_corrector=score_corrector, corrector_kwargs=corrector_kwargs) + if mask is not None: + assert x0 is not None + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(x0_partial) + if callback: callback(i) + if img_callback: img_callback(img, i) + return img, intermediates + + @torch.no_grad() + def p_sample_loop(self, cond, shape, return_intermediates=False, + x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, start_T=None, + log_every_t=None): + + if not log_every_t: + log_every_t = self.log_every_t + device = self.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + intermediates = [img] + if timesteps is None: + timesteps = self.num_timesteps + + if start_T is not None: + timesteps = min(timesteps, start_T) + iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed( + range(0, timesteps)) + + if mask is not None: + assert x0 is not None + assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match + + for i in iterator: + ts = torch.full((b,), i, device=device, dtype=torch.long) + if self.shorten_cond_schedule: + assert self.model.conditioning_key != 'hybrid' + tc = self.cond_ids[ts].to(cond.device) + cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond)) + + img = self.p_sample(img, cond, ts, + clip_denoised=self.clip_denoised, + quantize_denoised=quantize_denoised) + if mask is not None: + img_orig = self.q_sample(x0, ts) + img = img_orig * mask + (1. - mask) * img + + if i % log_every_t == 0 or i == timesteps - 1: + intermediates.append(img) + if callback: callback(i) + if img_callback: img_callback(img, i) + + if return_intermediates: + return img, intermediates + return img + + @torch.no_grad() + def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None, + verbose=True, timesteps=None, quantize_denoised=False, + mask=None, x0=None, shape=None, **kwargs): + if shape is None: + shape = (batch_size, self.channels, self.image_size, self.image_size) + if cond is not None: + if isinstance(cond, dict): + cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else + list(map(lambda x: x[:batch_size], cond[key])) for key in cond} + else: + cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size] + return self.p_sample_loop(cond, + shape, + return_intermediates=return_intermediates, x_T=x_T, + verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised, + mask=mask, x0=x0) + + @torch.no_grad() + def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs): + if ddim: + ddim_sampler = DDIMSampler(self) + shape = (self.channels, self.image_size, self.image_size) + samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, + shape, cond, verbose=False, **kwargs) + + else: + samples, intermediates = self.sample(cond=cond, batch_size=batch_size, + return_intermediates=True, **kwargs) + + return samples, intermediates + + @torch.no_grad() + def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + if 
self.cond_stage_key in ["class_label", "cls"]: + xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device) + return self.get_learned_conditioning(xc) + else: + raise NotImplementedError("todo") + if isinstance(c, list): # in case the encoder gives us a list + for i in range(len(c)): + c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device) + else: + c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device) + return c + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, + return_first_stage_outputs=True, + force_c_encode=True, + return_original_cond=True, + bs=N) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', "cls"]: + try: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + except KeyError: + # probably no "human_label" in batch + pass + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance( + self.first_stage_model, IdentityFirstStage): + # also display when quantizing x0 while sampling + with ema_scope("Plotting Quantized Denoised"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + quantize_denoised=True) + # samples, 
z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True, + # quantize_denoised=True) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_x0_quantized"] = x_samples + + if unconditional_guidance_scale > 1.0: + uc = self.get_unconditional_conditioning(N, unconditional_guidance_label) + if self.model.conditioning_key == "crossattn-adm": + uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if inpaint: + # make a simple center square + b, h, w = z.shape[0], z.shape[2], z.shape[3] + mask = torch.ones(N, h, w).to(self.device) + # zeros will be filled in + mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0. + mask = mask[:, None, ...] + with ema_scope("Plotting Inpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_inpainting"] = x_samples + log["mask"] = mask + + # outpaint + mask = 1. - mask + with ema_scope("Plotting Outpaint"): + samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta, + ddim_steps=ddim_steps, x0=z[:N], mask=mask) + x_samples = self.decode_first_stage(samples.to(self.device)) + log["samples_outpainting"] = x_samples + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + if return_keys: + if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0: + return log + else: + return {key: log[key] for key in return_keys} + return log + + def configure_optimizers(self): + lr = self.learning_rate + params = list(self.model.parameters()) + if self.cond_stage_trainable: + print(f"{self.__class__.__name__}: Also optimizing conditioner params!") + params = params + list(self.cond_stage_model.parameters()) + if self.learn_logvar: + print('Diffusion model optimizing logvar') + params.append(self.logvar) + opt = torch.optim.AdamW(params, lr=lr) + if self.use_scheduler: + assert 'target' in self.scheduler_config + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }] + return [opt], scheduler + return opt + + @torch.no_grad() + def to_rgb(self, x): + x = x.float() + if not hasattr(self, "colorize"): + self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2. * (x - x.min()) / (x.max() - x.min()) - 1. 
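+ # min-max rescale the randomly colorized map to [-1, 1] so it matches the value
+ # range of the other logged images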
+ return x + + +class DiffusionWrapper(pl.LightningModule): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False) + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm'] + + def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None): + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == 'concat': + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == 'crossattn': + if not self.sequential_cross_attn: + cc = torch.cat(c_crossattn, 1) + else: + cc = c_crossattn + if hasattr(self, "scripted_diffusion_model"): + # TorchScript changes names of the arguments + # with argument cc defined as context=cc scripted model will produce + # an error: RuntimeError: forward() is missing value for argument 'argument_3'. + out = self.scripted_diffusion_model(x, t, cc) + else: + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == 'hybrid': + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif self.conditioning_key == 'hybrid-adm': + assert c_adm is not None + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc, y=c_adm) + elif self.conditioning_key == 'crossattn-adm': + assert c_adm is not None + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc, y=c_adm) + elif self.conditioning_key == 'adm': + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class LatentUpscaleDiffusion(LatentDiffusion): + def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs): + super().__init__(*args, **kwargs) + # assumes that neither the cond_stage nor the low_scale_model contain trainable params + assert not self.cond_stage_trainable + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + self.noise_level_key = noise_level_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False): + if not log_mode: + z, c = super().get_input(batch, k, force_c_encode=True, bs=bs) + else: + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + x_low = batch[self.low_scale_key][:bs] + x_low = rearrange(x_low, 'b h w c -> b c h w') + x_low = x_low.to(memory_format=torch.contiguous_format).float() + zx, noise_level = self.low_scale_model(x_low) + if self.noise_level_key is not None: + # get noise level from batch instead, e.g. 
when extracting a custom noise level for bsr + raise NotImplementedError('TODO') + + all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level} + if log_mode: + # TODO: maybe disable if too expensive + x_low_rec = self.low_scale_model.decode(zx) + return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level + return z, all_conds + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True, + unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N, + log_mode=True) + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + log["x_lr"] = x_low + log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label) + # TODO explore better "unconditional" choices for the other keys + # maybe guide away from empty text label and highest noise level and maximally degraded zx? + uc = dict() + for k in c: + if k == "c_crossattn": + assert isinstance(c[k], list) and len(c[k]) == 1 + uc[k] = [uc_tmp] + elif k == "c_adm": # todo: only run with text-based guidance? 
+ assert isinstance(c[k], torch.Tensor) + #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level + uc[k] = c[k] + elif isinstance(c[k], list): + uc[k] = [c[k][i] for i in range(len(c[k]))] + else: + uc[k] = c[k] + + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + if plot_progressive_rows: + with ema_scope("Plotting Progressives"): + img, progressives = self.progressive_denoising(c, + shape=(self.channels, self.image_size, self.image_size), + batch_size=N) + prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation") + log["progressive_row"] = prog_row + + return log + + +class LatentFinetuneDiffusion(LatentDiffusion): + """ + Basis for different finetunas, such as inpainting or depth2image + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys: tuple, + finetune_keys=("model.diffusion_model.input_blocks.0.0.weight", + "model_ema.diffusion_modelinput_blocks00weight" + ), + keep_finetune_dims=4, + # if model was trained without concat mode before and we would like to keep these channels + c_concat_log_start=None, # to log reconstruction of c_concat codes + c_concat_log_end=None, + *args, **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", list()) + super().__init__(*args, **kwargs) + self.finetune_keys = finetune_keys + self.concat_keys = concat_keys + self.keep_dims = keep_finetune_dims + self.c_concat_log_start = c_concat_log_start + self.c_concat_log_end = c_concat_log_end + if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint' + if exists(ckpt_path): + self.init_from_ckpt(ckpt_path, ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + + # make it explicit, finetune by including extra input channels + if exists(self.finetune_keys) and k in self.finetune_keys: + new_entry = None + for name, param in self.named_parameters(): + if name in self.finetune_keys: + print( + f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only") + new_entry = torch.zeros_like(param) # zero init + assert exists(new_entry), 'did not find matching parameter to modify' + new_entry[:, :self.keep_dims, ...] 
= sd[k] + sd[k] = new_entry + + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None, + quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True, + plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None, + use_ema_scope=True, + **kwargs): + ema_scope = self.ema_scope if use_ema_scope else nullcontext + use_ddim = ddim_steps is not None + + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True) + c_cat, c = c["c_concat"][0], c["c_crossattn"][0] + N = min(x.shape[0], N) + n_row = min(x.shape[0], n_row) + log["inputs"] = x + log["reconstruction"] = xrec + if self.model.conditioning_key is not None: + if hasattr(self.cond_stage_model, "decode"): + xc = self.cond_stage_model.decode(c) + log["conditioning"] = xc + elif self.cond_stage_key in ["caption", "txt"]: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + elif self.cond_stage_key in ['class_label', 'cls']: + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25) + log['conditioning'] = xc + elif isimage(xc): + log["conditioning"] = xc + if ismap(xc): + log["original_conditioning"] = self.to_rgb(xc) + + if not (self.c_concat_log_start is None and self.c_concat_log_end is None): + log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end]) + + if plot_diffusion_rows: + # get diffusion row + diffusion_row = list() + z_start = z[:n_row] + for t in range(self.num_timesteps): + if t % self.log_every_t == 0 or t == self.num_timesteps - 1: + t = repeat(torch.tensor([t]), '1 -> b', b=n_row) + t = t.to(self.device).long() + noise = torch.randn_like(z_start) + z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise) + diffusion_row.append(self.decode_first_stage(z_noisy)) + + diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W + diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w') + diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w') + diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0]) + log["diffusion_row"] = diffusion_grid + + if sample: + # get denoise row + with ema_scope("Sampling"): + samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta) + # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True) + x_samples = self.decode_first_stage(samples) + log["samples"] = x_samples + if plot_denoise_rows: + denoise_grid = self._get_denoise_row_from_list(z_denoise_row) + log["denoise_row"] = denoise_grid + + if unconditional_guidance_scale > 1.0: + uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label) + uc_cat = c_cat + uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]} + with ema_scope("Sampling with classifier-free guidance"): + samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": 
[c]}, + batch_size=N, ddim=use_ddim, + ddim_steps=ddim_steps, eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_full, + ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + + return log + + +class LatentInpaintDiffusion(LatentFinetuneDiffusion): + """ + can either run as pure inpainting model (only concat mode) or with mixed conditionings, + e.g. mask as concat and text via cross-attn. + To disable finetuning mode, set finetune_keys to None + """ + + def __init__(self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, **kwargs + ): + super().__init__(concat_keys, *args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + c_cat = list() + for ck in self.concat_keys: + cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + bchw = z.shape + if ck != self.masked_image_key: + cc = torch.nn.functional.interpolate(cc, size=bchw[-2:]) + else: + cc = self.get_first_stage_encoding(self.encode_first_stage(cc)) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs) + log["masked_image"] = rearrange(args[0]["masked_image"], + 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float() + return log + + +class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion): + """ + condition on monocular depth estimation + """ + + def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.depth_model = instantiate_from_config(depth_stage_config) + self.depth_stage_key = concat_keys[0] + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + c_cat = list() + for ck in self.concat_keys: + cc = batch[ck] + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + cc = self.depth_model(cc) + cc = torch.nn.functional.interpolate( + cc, + size=z.shape[2:], + mode="bicubic", + align_corners=False, + ) + + depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3], + keepdim=True) + cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1. 
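(A minimal illustration, separate from the patched file itself: the per-sample normalization above rescales each predicted depth map into roughly [-1, 1] before it is concatenated to the latent. The random tensor below is a stand-in for real MiDaS output.)

import torch

def normalize_depth(depth: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # depth: (B, 1, H, W) raw depth predictions; output lies in roughly [-1, 1]
    d_min = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    d_max = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    # same form as the statement above; eps guards against a constant depth map
    return 2. * (depth - d_min) / (d_max - d_min + eps) - 1.

fake_depth = torch.rand(2, 1, 64, 64) * 10.   # placeholder depth prediction
out = normalize_depth(fake_depth)
print(out.min().item(), out.max().item())     # close to -1 and 1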
+ c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + depth = self.depth_model(args[0][self.depth_stage_key]) + depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \ + torch.amax(depth, dim=[1, 2, 3], keepdim=True) + log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1. + return log + + +class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion): + """ + condition on low-res image (and optionally on some spatial noise augmentation) + """ + def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None, + low_scale_config=None, low_scale_key=None, *args, **kwargs): + super().__init__(concat_keys=concat_keys, *args, **kwargs) + self.reshuffle_patch_size = reshuffle_patch_size + self.low_scale_model = None + if low_scale_config is not None: + print("Initializing a low-scale model") + assert exists(low_scale_key) + self.instantiate_low_stage(low_scale_config) + self.low_scale_key = low_scale_key + + def instantiate_low_stage(self, config): + model = instantiate_from_config(config) + self.low_scale_model = model.eval() + self.low_scale_model.train = disabled_train + for param in self.low_scale_model.parameters(): + param.requires_grad = False + + @torch.no_grad() + def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False): + # note: restricted to non-trainable encoders currently + assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft' + z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True, + force_c_encode=True, return_original_cond=True, bs=bs) + + assert exists(self.concat_keys) + assert len(self.concat_keys) == 1 + # optionally make spatial noise_level here + c_cat = list() + noise_level = None + for ck in self.concat_keys: + cc = batch[ck] + cc = rearrange(cc, 'b h w c -> b c h w') + if exists(self.reshuffle_patch_size): + assert isinstance(self.reshuffle_patch_size, int) + cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w', + p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size) + if bs is not None: + cc = cc[:bs] + cc = cc.to(self.device) + if exists(self.low_scale_model) and ck == self.low_scale_key: + cc, noise_level = self.low_scale_model(cc) + c_cat.append(cc) + c_cat = torch.cat(c_cat, dim=1) + if exists(noise_level): + all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level} + else: + all_conds = {"c_concat": [c_cat], "c_crossattn": [c]} + if return_first_stage_outputs: + return z, all_conds, x, xrec, xc + return z, all_conds + + @torch.no_grad() + def log_images(self, *args, **kwargs): + log = super().log_images(*args, **kwargs) + log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w') + return log + + +class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion): + def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5, + freeze_embedder=True, noise_aug_config=None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.embed_key = embedding_key + self.embedding_dropout = embedding_dropout + self._init_embedder(embedder_config, freeze_embedder) + self._init_noise_aug(noise_aug_config) + + def _init_embedder(self, config, freeze=True): + embedder = instantiate_from_config(config) + if freeze: + self.embedder = 
embedder.eval() + self.embedder.train = disabled_train + for param in self.embedder.parameters(): + param.requires_grad = False + + def _init_noise_aug(self, config): + if config is not None: + # use the KARLO schedule for noise augmentation on CLIP image embeddings + noise_augmentor = instantiate_from_config(config) + assert isinstance(noise_augmentor, nn.Module) + noise_augmentor = noise_augmentor.eval() + noise_augmentor.train = disabled_train + self.noise_augmentor = noise_augmentor + else: + self.noise_augmentor = None + + def get_input(self, batch, k, cond_key=None, bs=None, **kwargs): + outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs) + z, c = outputs[0], outputs[1] + img = batch[self.embed_key][:bs] + img = rearrange(img, 'b h w c -> b c h w') + c_adm = self.embedder(img) + if self.noise_augmentor is not None: + c_adm, noise_level_emb = self.noise_augmentor(c_adm) + # assume this gives embeddings of noise levels + c_adm = torch.cat((c_adm, noise_level_emb), 1) + if self.training: + c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0], + device=c_adm.device)[:, None]) * c_adm + all_conds = {"c_crossattn": [c], "c_adm": c_adm} + noutputs = [z, all_conds] + noutputs.extend(outputs[2:]) + return noutputs + + @torch.no_grad() + def log_images(self, batch, N=8, n_row=4, **kwargs): + log = dict() + z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True, + return_original_cond=True) + log["inputs"] = x + log["reconstruction"] = xrec + assert self.model.conditioning_key is not None + assert self.cond_stage_key in ["caption", "txt"] + xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25) + log["conditioning"] = xc + uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', '')) + unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.) + + uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]} + ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext + with ema_scope(f"Sampling"): + samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True, + ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.), + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc_, ) + x_samples_cfg = self.decode_first_stage(samples_cfg) + log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg + return log diff --git a/repositories/ldm/models/diffusion/dpm_solver/__init__.py b/repositories/ldm/models/diffusion/dpm_solver/__init__.py new file mode 100644 index 000000000..7427f38c0 --- /dev/null +++ b/repositories/ldm/models/diffusion/dpm_solver/__init__.py @@ -0,0 +1 @@ +from .sampler import DPMSolverSampler \ No newline at end of file diff --git a/repositories/ldm/models/diffusion/dpm_solver/dpm_solver.py b/repositories/ldm/models/diffusion/dpm_solver/dpm_solver.py new file mode 100644 index 000000000..da8d41f9c --- /dev/null +++ b/repositories/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -0,0 +1,1163 @@ +import torch +import torch.nn.functional as F +import math +from tqdm import tqdm + + +class NoiseScheduleVP: + def __init__( + self, + schedule='discrete', + betas=None, + alphas_cumprod=None, + continuous_beta_0=0.1, + continuous_beta_1=20., + ): + """Create a wrapper class for the forward SDE (VP type). + *** + Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. 
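(A minimal sketch, separate from the file itself, of the piecewise-linear interpolation mentioned above: a toy linear beta ramp is turned into the cumulative log-alpha table that the 'discrete' schedule stores, and numpy's interp stands in for the module's own interpolation helper.)

import numpy as np

N = 1000
betas = np.linspace(1e-4, 2e-2, N)                 # toy beta schedule, illustration only
log_alphas = 0.5 * np.cumsum(np.log(1. - betas))   # log(alpha_{t_i}), as built in __init__
t_grid = np.arange(1, N + 1) / N                   # t_i = (i + 1) / N

def log_alpha(t):
    # piecewise-linear interpolation of log(alpha_t) at a continuous t in (0, 1]
    return np.interp(t, t_grid, log_alphas)

la = log_alpha(0.3)
alpha_t = np.exp(la)
sigma_t = np.sqrt(1. - np.exp(2. * la))            # marginal std
lambda_t = la - np.log(sigma_t)                    # half-logSNR, defined just below
print(alpha_t, sigma_t, lambda_t)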
+ We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. + *** + The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). + We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). + Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: + log_alpha_t = self.marginal_log_mean_coeff(t) + sigma_t = self.marginal_std(t) + lambda_t = self.marginal_lambda(t) + Moreover, as lambda(t) is an invertible function, we also support its inverse function: + t = self.inverse_lambda(lambda_t) + =============================================================== + We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). + 1. For discrete-time DPMs: + For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: + t_i = (i + 1) / N + e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. + We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. + Args: + betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) + alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) + Note that we always have alphas_cumprod = cumprod(betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. + **Important**: Please pay special attention for the args for `alphas_cumprod`: + The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that + q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). + Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have + alpha_{t_n} = \sqrt{\hat{alpha_n}}, + and + log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). + 2. For continuous-time DPMs: + We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise + schedule are the default settings in DDPM and improved-DDPM: + Args: + beta_min: A `float` number. The smallest beta for the linear schedule. + beta_max: A `float` number. The largest beta for the linear schedule. + cosine_s: A `float` number. The hyperparameter in the cosine schedule. + cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. + T: A `float` number. The ending time of the forward process. + =============================================================== + Args: + schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, + 'linear' or 'cosine' for continuous-time DPMs. + Returns: + A wrapper object of the forward SDE (VP type). + + =============================================================== + Example: + # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', betas=betas) + # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): + >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + # For continuous-time DPMs (VPSDE), linear schedule: + >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) + """ + + if schedule not in ['discrete', 'linear', 'cosine']: + raise ValueError( + "Unsupported noise schedule {}. 
The schedule needs to be 'discrete' or 'linear' or 'cosine'".format( + schedule)) + + self.schedule = schedule + if schedule == 'discrete': + if betas is not None: + log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) + else: + assert alphas_cumprod is not None + log_alphas = 0.5 * torch.log(alphas_cumprod) + self.total_N = len(log_alphas) + self.T = 1. + self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)) + self.log_alpha_array = log_alphas.reshape((1, -1,)) + else: + self.total_N = 1000 + self.beta_0 = continuous_beta_0 + self.beta_1 = continuous_beta_1 + self.cosine_s = 0.008 + self.cosine_beta_max = 999. + self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * ( + 1. + self.cosine_s) / math.pi - self.cosine_s + self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) + self.schedule = schedule + if schedule == 'cosine': + # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. + # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. + self.T = 0.9946 + else: + self.T = 1. + + def marginal_log_mean_coeff(self, t): + """ + Compute log(alpha_t) of a given continuous-time label t in [0, T]. + """ + if self.schedule == 'discrete': + return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), + self.log_alpha_array.to(t.device)).reshape((-1)) + elif self.schedule == 'linear': + return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 + elif self.schedule == 'cosine': + log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.)) + log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 + return log_alpha_t + + def marginal_alpha(self, t): + """ + Compute alpha_t of a given continuous-time label t in [0, T]. + """ + return torch.exp(self.marginal_log_mean_coeff(t)) + + def marginal_std(self, t): + """ + Compute sigma_t of a given continuous-time label t in [0, T]. + """ + return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) + + def marginal_lambda(self, t): + """ + Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. + """ + log_mean_coeff = self.marginal_log_mean_coeff(t) + log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) + return log_mean_coeff - log_std + + def inverse_lambda(self, lamb): + """ + Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. + """ + if self.schedule == 'linear': + tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + Delta = self.beta_0 ** 2 + tmp + return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) + elif self.schedule == 'discrete': + log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) + t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), + torch.flip(self.t_array.to(lamb.device), [1])) + return t.reshape((-1,)) + else: + log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) + t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * ( + 1. 
+ self.cosine_s) / math.pi - self.cosine_s + t = t_fn(log_alpha) + return t + + +def model_wrapper( + model, + noise_schedule, + model_type="noise", + model_kwargs={}, + guidance_type="uncond", + condition=None, + unconditional_condition=None, + guidance_scale=1., + classifier_fn=None, + classifier_kwargs={}, +): + """Create a wrapper function for the noise prediction model. + DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to + firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. + We support four types of the diffusion model by setting `model_type`: + 1. "noise": noise prediction model. (Trained by predicting noise). + 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). + 3. "v": velocity prediction model. (Trained by predicting the velocity). + The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. + [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." + arXiv preprint arXiv:2202.00512 (2022). + [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." + arXiv preprint arXiv:2210.02303 (2022). + + 4. "score": marginal score function. (Trained by denoising score matching). + Note that the score function and the noise prediction model follows a simple relationship: + ``` + noise(x_t, t) = -sigma_t * score(x_t, t) + ``` + We support three types of guided sampling by DPMs by setting `guidance_type`: + 1. "uncond": unconditional sampling by DPMs. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. + The input `model` has the following format: + `` + model(x, t_input, **model_kwargs) -> noise | x_start | v | score + `` + The input `classifier_fn` has the following format: + `` + classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) + `` + [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," + in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. + 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. + The input `model` has the following format: + `` + model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score + `` + And if cond == `unconditional_condition`, the model output is the unconditional DPM output. + [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." + arXiv preprint arXiv:2207.12598 (2022). + + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) + or continuous-time labels (i.e. epsilon to T). + We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: + `` + def model_fn(x, t_continuous) -> noise: + t_input = get_model_input_time(t_continuous) + return noise_pred(model, x, t_input, **model_kwargs) + `` + where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. + =============================================================== + Args: + model: A diffusion model with the corresponding format described above. + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + model_type: A `str`. The parameterization type of the diffusion model. 
+ "noise" or "x_start" or "v" or "score". + model_kwargs: A `dict`. A dict for the other inputs of the model function. + guidance_type: A `str`. The type of the guidance for sampling. + "uncond" or "classifier" or "classifier-free". + condition: A pytorch tensor. The condition for the guided sampling. + Only used for "classifier" or "classifier-free" guidance type. + unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. + Only used for "classifier-free" guidance type. + guidance_scale: A `float`. The scale for the guided sampling. + classifier_fn: A classifier function. Only used for the classifier guidance. + classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. + Returns: + A noise prediction model that accepts the noised data and the continuous time as the inputs. + """ + + def get_model_input_time(t_continuous): + """ + Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. + For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. + For continuous-time DPMs, we just use `t_continuous`. + """ + if noise_schedule.schedule == 'discrete': + return (t_continuous - 1. / noise_schedule.total_N) * 1000. + else: + return t_continuous + + def noise_pred_fn(x, t_continuous, cond=None): + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + t_input = get_model_input_time(t_continuous) + if cond is None: + output = model(x, t_input, **model_kwargs) + else: + output = model(x, t_input, cond, **model_kwargs) + if model_type == "noise": + return output + elif model_type == "x_start": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims) + elif model_type == "v": + alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x + elif model_type == "score": + sigma_t = noise_schedule.marginal_std(t_continuous) + dims = x.dim() + return -expand_dims(sigma_t, dims) * output + + def cond_grad_fn(x, t_input): + """ + Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t). + """ + with torch.enable_grad(): + x_in = x.detach().requires_grad_(True) + log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs) + return torch.autograd.grad(log_prob.sum(), x_in)[0] + + def model_fn(x, t_continuous): + """ + The noise predicition model function that is used for DPM-Solver. + """ + if t_continuous.reshape((-1,)).shape[0] == 1: + t_continuous = t_continuous.expand((x.shape[0])) + if guidance_type == "uncond": + return noise_pred_fn(x, t_continuous) + elif guidance_type == "classifier": + assert classifier_fn is not None + t_input = get_model_input_time(t_continuous) + cond_grad = cond_grad_fn(x, t_input) + sigma_t = noise_schedule.marginal_std(t_continuous) + noise = noise_pred_fn(x, t_continuous) + return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad + elif guidance_type == "classifier-free": + if guidance_scale == 1. 
or unconditional_condition is None: + return noise_pred_fn(x, t_continuous, cond=condition) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t_continuous] * 2) + if isinstance(condition, dict): + assert isinstance(unconditional_condition, dict) + c_in = dict() + for k in condition: + if isinstance(condition[k], list): + c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))] + else: + c_in[k] = torch.cat([unconditional_condition[k], condition[k]]) + else: + c_in = torch.cat([unconditional_condition, condition]) + noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2) + return noise_uncond + guidance_scale * (noise - noise_uncond) + + assert model_type in ["noise", "x_start", "v"] + assert guidance_type in ["uncond", "classifier", "classifier-free"] + return model_fn + + +class DPM_Solver: + def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.): + """Construct a DPM-Solver. + We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0"). + If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver). + If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++). + In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True. + The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales. + Args: + model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): + `` + def model_fn(x, t_continuous): + return noise + `` + noise_schedule: A noise schedule object, such as NoiseScheduleVP. + predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model. + thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1]. + max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding. + + [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. + """ + self.model = model_fn + self.noise_schedule = noise_schedule + self.predict_x0 = predict_x0 + self.thresholding = thresholding + self.max_val = max_val + + def noise_prediction_fn(self, x, t): + """ + Return the noise prediction model. + """ + return self.model(x, t) + + def data_prediction_fn(self, x, t): + """ + Return the data prediction model (with thresholding). + """ + noise = self.noise_prediction_fn(x, t) + dims = x.dim() + alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) + x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims) + if self.thresholding: + p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. + s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s + return x0 + + def model_fn(self, x, t): + """ + Convert the model to the noise prediction model or the data prediction model. 
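(A minimal numeric sketch, separate from the file itself, of the conversions that model_wrapper and data_prediction_fn rely on. Everything follows from x_t = alpha_t * x_0 + sigma_t * eps; the scalar coefficients and the standard v = alpha_t * eps - sigma_t * x_0 convention are assumptions for illustration.)

import torch

torch.manual_seed(0)
alpha_t, sigma_t = 0.8, 0.6              # any point with alpha_t**2 + sigma_t**2 == 1
x0 = torch.randn(2, 3, 8, 8)             # clean data
eps = torch.randn_like(x0)               # noise
x_t = alpha_t * x0 + sigma_t * eps       # forward marginal q(x_t | x_0)

# "x_start" output -> noise, as in noise_pred_fn for model_type == "x_start"
eps_from_x0 = (x_t - alpha_t * x0) / sigma_t
# "v" output -> noise, as in noise_pred_fn for model_type == "v"
v = alpha_t * eps - sigma_t * x0
eps_from_v = alpha_t * v + sigma_t * x_t
# noise -> data prediction, as in data_prediction_fn (thresholding disabled)
x0_from_eps = (x_t - sigma_t * eps) / alpha_t

for recovered, target in [(eps_from_x0, eps), (eps_from_v, eps), (x0_from_eps, x0)]:
    assert torch.allclose(recovered, target, atol=1e-5)
print("parameterization round-trips agree")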
+ """ + if self.predict_x0: + return self.data_prediction_fn(x, t) + else: + return self.noise_prediction_fn(x, t) + + def get_time_steps(self, skip_type, t_T, t_0, N, device): + """Compute the intermediate time steps for sampling. + Args: + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + N: A `int`. The total number of the spacing of the time steps. + device: A torch device. + Returns: + A pytorch tensor of the time steps, with the shape (N + 1,). + """ + if skip_type == 'logSNR': + lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) + lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) + logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) + return self.noise_schedule.inverse_lambda(logSNR_steps) + elif skip_type == 'time_uniform': + return torch.linspace(t_T, t_0, N + 1).to(device) + elif skip_type == 'time_quadratic': + t_order = 2 + t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device) + return t + else: + raise ValueError( + "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) + + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): + """ + Get the order of each step for sampling by the singlestep DPM-Solver. + We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". + Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: + - If order == 1: + We take `steps` of DPM-Solver-1 (i.e. DDIM). + - If order == 2: + - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of DPM-Solver-2. + - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If order == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. + ============================================ + Args: + order: A `int`. The max order for the solver (2 or 3). + steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) + - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. + Returns: + orders: A list of the solver order of each step. 
+ """ + if order == 3: + K = steps // 3 + 1 + if steps % 3 == 0: + orders = [3, ] * (K - 2) + [2, 1] + elif steps % 3 == 1: + orders = [3, ] * (K - 1) + [1] + else: + orders = [3, ] * (K - 1) + [2] + elif order == 2: + if steps % 2 == 0: + K = steps // 2 + orders = [2, ] * K + else: + K = steps // 2 + 1 + orders = [2, ] * (K - 1) + [1] + elif order == 1: + K = 1 + orders = [1, ] * steps + else: + raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[ + torch.cumsum(torch.tensor([0, ] + orders)).to(device)] + return timesteps_outer, orders + + def denoise_to_zero_fn(self, x, s): + """ + Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. + """ + return self.data_prediction_fn(x, s) + + def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): + """ + DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + if self.predict_x0: + phi_1 = torch.expm1(-h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + else: + phi_1 = torch.expm1(h) + if model_s is None: + model_s = self.model_fn(x, s) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + ) + if return_intermediate: + return x_t, {'model_s': model_s} + else: + return x_t + + def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, + solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-2 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the second-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. 
The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 0.5 + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + s1 = ns.inverse_lambda(lambda_s1) + log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff( + s1), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) + alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_1 = torch.expm1(-h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * ( + model_s1 - model_s) + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_1 = torch.expm1(h) + + if model_s is None: + model_s = self.model_fn(x, s) + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s) + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s) + ) + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1} + else: + return x_t + + def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None, + return_intermediate=False, solver_type='dpm_solver'): + """ + Singlestep solver DPM-Solver-3 from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + r1: A `float`. The hyperparameter of the third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + model_s: A pytorch tensor. The model function evaluated at time `s`. + If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. + model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). + If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. 
We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + if r1 is None: + r1 = 1. / 3. + if r2 is None: + r2 = 2. / 3. + ns = self.noise_schedule + dims = x.dim() + lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) + h = lambda_t - lambda_s + lambda_s1 = lambda_s + r1 * h + lambda_s2 = lambda_s + r2 * h + s1 = ns.inverse_lambda(lambda_s1) + s2 = ns.inverse_lambda(lambda_s2) + log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff( + s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) + sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std( + s2), ns.marginal_std(t) + alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) + + if self.predict_x0: + phi_11 = torch.expm1(-r1 * h) + phi_12 = torch.expm1(-r2 * h) + phi_1 = torch.expm1(-h) + phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. + phi_2 = phi_1 / h + 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(sigma_s1 / sigma_s, dims) * x + - expand_dims(alpha_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(sigma_s2 / sigma_s, dims) * x + - expand_dims(alpha_s2 * phi_12, dims) * model_s + + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. * (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(sigma_t / sigma_s, dims) * x + - expand_dims(alpha_t * phi_1, dims) * model_s + + expand_dims(alpha_t * phi_2, dims) * D1 + - expand_dims(alpha_t * phi_3, dims) * D2 + ) + else: + phi_11 = torch.expm1(r1 * h) + phi_12 = torch.expm1(r2 * h) + phi_1 = torch.expm1(h) + phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. + phi_2 = phi_1 / h - 1. + phi_3 = phi_2 / h - 0.5 + + if model_s is None: + model_s = self.model_fn(x, s) + if model_s1 is None: + x_s1 = ( + expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x + - expand_dims(sigma_s1 * phi_11, dims) * model_s + ) + model_s1 = self.model_fn(x_s1, s1) + x_s2 = ( + expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x + - expand_dims(sigma_s2 * phi_12, dims) * model_s + - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s) + ) + model_s2 = self.model_fn(x_s2, s2) + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s) + ) + elif solver_type == 'taylor': + D1_0 = (1. / r1) * (model_s1 - model_s) + D1_1 = (1. / r2) * (model_s2 - model_s) + D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) + D2 = 2. 
* (D1_1 - D1_0) / (r2 - r1) + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x + - expand_dims(sigma_t * phi_1, dims) * model_s + - expand_dims(sigma_t * phi_2, dims) * D1 + - expand_dims(sigma_t * phi_3, dims) * D2 + ) + + if return_intermediate: + return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} + else: + return x_t + + def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"): + """ + Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if solver_type not in ['dpm_solver', 'taylor']: + raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type)) + ns = self.noise_schedule + dims = x.dim() + model_prev_1, model_prev_0 = model_prev_list + t_prev_1, t_prev_0 = t_prev_list + lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda( + t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0 = h_0 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + if self.predict_x0: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0 + ) + else: + if solver_type == 'dpm_solver': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0 + ) + elif solver_type == 'taylor': + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0 + ) + return x_t + + def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'): + """ + Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. + t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. 
We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + ns = self.noise_schedule + dims = x.dim() + model_prev_2, model_prev_1, model_prev_0 = model_prev_list + t_prev_2, t_prev_1, t_prev_0 = t_prev_list + lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda( + t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) + log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) + sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) + alpha_t = torch.exp(log_alpha_t) + + h_1 = lambda_prev_1 - lambda_prev_2 + h_0 = lambda_prev_0 - lambda_prev_1 + h = lambda_t - lambda_prev_0 + r0, r1 = h_0 / h, h_1 / h + D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1) + D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2) + D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1) + D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1) + if self.predict_x0: + x_t = ( + expand_dims(sigma_t / sigma_prev_0, dims) * x + - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0 + + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1 + - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2 + ) + else: + x_t = ( + expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x + - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0 + - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1 + - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2 + ) + return x_t + + def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None, + r2=None): + """ + Singlestep DPM-Solver with the order `order` from time `s` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + s: A pytorch tensor. The starting time, with the shape (x.shape[0],). + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + r1: A `float`. The hyperparameter of the second-order or third-order solver. + r2: A `float`. The hyperparameter of the third-order solver. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) + elif order == 2: + return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1) + elif order == 3: + return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, + solver_type=solver_type, r1=r1, r2=r2) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'): + """ + Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. + Args: + x: A pytorch tensor. The initial value at time `s`. + model_prev_list: A list of pytorch tensor. The previous computed model values. 
+ t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (x.shape[0],) + t: A pytorch tensor. The ending time, with the shape (x.shape[0],). + order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_t: A pytorch tensor. The approximated solution at time `t`. + """ + if order == 1: + return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) + elif order == 2: + return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + elif order == 3: + return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) + else: + raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) + + def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, + solver_type='dpm_solver'): + """ + The adaptive step size solver based on singlestep DPM-Solver. + Args: + x: A pytorch tensor. The initial value at time `t_T`. + order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + h_init: A `float`. The initial step size (for logSNR). + atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. + rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. + theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. + t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the + current time and `t_0` is less than `t_err`. The default setting is 1e-5. + solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers. + The type slightly impacts the performance. We recommend to use 'dpm_solver' type. + Returns: + x_0: A pytorch tensor. The approximated solution at time `t_0`. + [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. + """ + ns = self.noise_schedule + s = t_T * torch.ones((x.shape[0],)).to(x) + lambda_s = ns.marginal_lambda(s) + lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) + h = h_init * torch.ones_like(s).to(x) + x_prev = x + nfe = 0 + if order == 2: + r1 = 0.5 + lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + solver_type=solver_type, + **kwargs) + elif order == 3: + r1, r2 = 1. / 3., 2. / 3. 
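+ # For intuition, assuming a solver built as `dpm_solver = DPM_Solver(model_fn, ns)` and
+ # `x_T` drawn from a standard normal: the second-order step below serves as the embedded
+ # lower-order estimate, the third-order step reuses its intermediate model outputs via
+ # `lower_noise_kwargs`, and r1 = 1/3, r2 = 2/3 space the two intermediate evaluations
+ # evenly in logSNR. A direct call would look like
+ # >>> x_0 = dpm_solver.dpm_solver_adaptive(x_T, order=3, t_T=1., t_0=1e-3)
+ # though this path is normally reached via `sample(..., method='adaptive')`.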
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, + return_intermediate=True, + solver_type=solver_type) + higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, + solver_type=solver_type, + **kwargs) + else: + raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) + while torch.abs((s - t_0)).mean() > t_err: + t = ns.inverse_lambda(lambda_s + h) + x_lower, lower_noise_kwargs = lower_update(x, s, t) + x_higher = higher_update(x, s, t, **lower_noise_kwargs) + delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) + norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) + E = norm_fn((x_higher - x_lower) / delta).max() + if torch.all(E <= 1.): + x = x_higher + s = t + x_prev = x_lower + lambda_s = ns.marginal_lambda(s) + h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) + nfe += order + print('adaptive solver nfe', nfe) + return x + + def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, + ): + """ + Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. + ===================================================== + We support the following algorithms for both noise prediction model and data prediction model: + - 'singlestep': + Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver. + We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps). + The total number of function evaluations (NFE) == `steps`. + Given a fixed NFE == `steps`, the sampling procedure is: + - If `order` == 1: + - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling. + - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2. + - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If `order` == 3: + - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. + - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1. + - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1. + - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2. + - 'multistep': + Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`. + We initialize the first `order` values by lower order multistep solvers. + Given a fixed NFE == `steps`, the sampling procedure is: + Denote K = steps. + - If `order` == 1: + - We use K steps of DPM-Solver-1 (i.e. DDIM). + - If `order` == 2: + - We firstly use 1 step of DPM-Solver-1, then use (K - 1) step of multistep DPM-Solver-2. + - If `order` == 3: + - We firstly use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) step of multistep DPM-Solver-3. + - 'singlestep_fixed': + Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3). 
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE. + - 'adaptive': + Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper). + We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`. + You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computatation costs + (NFE) and the sample quality. + - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2. + - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3. + ===================================================== + Some advices for choosing the algorithm: + - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs: + Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3, + skip_type='time_uniform', method='singlestep') + - For **guided sampling with large guidance scale** by DPMs: + Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`. + e.g. + >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True) + >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2, + skip_type='time_uniform', method='multistep') + We support three types of `skip_type`: + - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images** + - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**. + - 'time_quadratic': quadratic time for the time steps. + ===================================================== + Args: + x: A pytorch tensor. The initial value at time `t_start` + e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution. + steps: A `int`. The total number of function evaluations (NFE). + t_start: A `float`. The starting time of the sampling. + If `T` is None, we use self.noise_schedule.T (default is 1.0). + t_end: A `float`. The ending time of the sampling. + If `t_end` is None, we use 1. / self.noise_schedule.total_N. + e.g. if total_N == 1000, we have `t_end` == 1e-3. + For discrete-time DPMs: + - We recommend `t_end` == 1. / self.noise_schedule.total_N. + For continuous-time DPMs: + - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15. + order: A `int`. The order of DPM-Solver. + skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. + method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. 
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. + solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. + atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. + Returns: + x_end: A pytorch tensor. The approximated solution at time `t_end`. + """ + t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end + t_T = self.noise_schedule.T if t_start is None else t_start + device = x.device + if method == 'adaptive': + with torch.no_grad(): + x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, + solver_type=solver_type) + elif method == 'multistep': + assert steps >= order + timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + assert timesteps.shape[0] - 1 == steps + with torch.no_grad(): + vec_t = timesteps[0].expand((x.shape[0])) + model_prev_list = [self.model_fn(x, vec_t)] + t_prev_list = [vec_t] + # Init the first `order` values by lower order multistep DPM-Solver. + for init_order in tqdm(range(1, order), desc="DPM init order"): + vec_t = timesteps[init_order].expand(x.shape[0]) + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order, + solver_type=solver_type) + model_prev_list.append(self.model_fn(x, vec_t)) + t_prev_list.append(vec_t) + # Compute the remaining values by `order`-th order multistep DPM-Solver. + for step in tqdm(range(order, steps + 1), desc="DPM multistep"): + vec_t = timesteps[step].expand(x.shape[0]) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, + solver_type=solver_type) + for i in range(order - 1): + t_prev_list[i] = t_prev_list[i + 1] + model_prev_list[i] = model_prev_list[i + 1] + t_prev_list[-1] = vec_t + # We do not need to evaluate the final model value. 
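+ # (For instance, with steps=20 and order=2: one DPM-Solver-1 update initializes the
+ # history, the remaining 19 updates are multistep DPM-Solver-2, and the model is
+ # evaluated at every timestep except the last one, i.e. exactly steps = 20 NFE.)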
+ if step < steps: + model_prev_list[-1] = self.model_fn(x, vec_t) + elif method in ['singlestep', 'singlestep_fixed']: + if method == 'singlestep': + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, + skip_type=skip_type, + t_T=t_T, t_0=t_0, + device=device) + elif method == 'singlestep_fixed': + K = steps // order + orders = [order, ] * K + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), + N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) + return x + + +############################################################# +# other utility functions +############################################################# + +def interpolate_fn(x, xp, yp): + """ + A piecewise linear function y = f(x), using xp and yp as keypoints. + We implement f(x) in a differentiable way (i.e. applicable for autograd). + The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) + Args: + x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). + xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. + yp: PyTorch tensor with shape [C, K]. + Returns: + The function values f(x), with shape [N, C]. + """ + N, K = x.shape[0], xp.shape[1] + all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2) + sorted_all_x, x_indices = torch.sort(all_x, dim=2) + x_idx = torch.argmin(x_indices, dim=2) + cand_start_idx = x_idx - 1 + start_idx = torch.where( + torch.eq(x_idx, 0), + torch.tensor(1, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1) + start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2) + end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2) + start_idx2 = torch.where( + torch.eq(x_idx, 0), + torch.tensor(0, device=x.device), + torch.where( + torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx, + ), + ) + y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1) + start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2) + end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2) + cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x) + return cand + + +def expand_dims(v, dims): + """ + Expand the tensor `v` to the dim `dims`. + Args: + `v`: a PyTorch tensor with shape [N]. + `dim`: a `int`. + Returns: + a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. 
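+ Example (doctest-style, assuming `torch` is imported):
+ >>> expand_dims(torch.tensor([1., 2., 3.]), 4).shape
+ torch.Size([3, 1, 1, 1])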
+ """ + return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file diff --git a/repositories/ldm/models/diffusion/dpm_solver/sampler.py b/repositories/ldm/models/diffusion/dpm_solver/sampler.py new file mode 100644 index 000000000..e4d0d0a38 --- /dev/null +++ b/repositories/ldm/models/diffusion/dpm_solver/sampler.py @@ -0,0 +1,96 @@ +"""SAMPLING ONLY.""" +import torch + +from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver + +MODEL_TYPES = { + "eps": "noise", + "v": "v" +} + + +class DPMSolverSampler(object): + def __init__(self, model, device=torch.device("cuda"), **kwargs): + super().__init__() + self.model = model + self.device = device + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != self.device: + attr = attr.to(self.device) + setattr(self, name, attr) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + if isinstance(ctmp, torch.Tensor): + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}") + else: + if isinstance(conditioning, torch.Tensor): + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + + print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=MODEL_TYPES[self.model.parameterization], + guidance_type="classifier-free", + condition=conditioning, + unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, + lower_order_final=True) + + return x.to(device), None diff --git a/repositories/ldm/models/diffusion/plms.py b/repositories/ldm/models/diffusion/plms.py new file mode 100644 index 000000000..9d31b3994 --- /dev/null +++ b/repositories/ldm/models/diffusion/plms.py @@ -0,0 +1,245 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like +from 
ldm.models.diffusion.sampling_util import norm_thresholding + + +class PLMSSampler(object): + def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + self.device = device + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != self.device: + attr = attr.to(self.device) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): + if ddim_eta != 0: + raise ValueError('ddim_eta must be 0 for PLMS') + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ dynamic_threshold=None, + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + dynamic_threshold=dynamic_threshold, + ) + return samples, intermediates + + @torch.no_grad() + def plms_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, + dynamic_threshold=None): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + print(f"Running PLMS Sampling with {total_steps} timesteps") + + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) + old_eps = [] + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + old_eps=old_eps, t_next=ts_next, + dynamic_threshold=dynamic_threshold) + img, pred_x0, e_t = outs + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, + dynamic_threshold=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + if dynamic_threshold is not None: + pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3rd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4th order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t diff --git a/repositories/ldm/models/diffusion/sampling_util.py b/repositories/ldm/models/diffusion/sampling_util.py new file mode 100644 index 000000000..7eff02be6 --- /dev/null +++ b/repositories/ldm/models/diffusion/sampling_util.py @@ -0,0 +1,22 @@ +import torch +import numpy as np + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions. + From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + +def norm_thresholding(x0, value): + s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) + return x0 * (value / s) + + +def spatial_norm_thresholding(x0, value): + # b c h w + s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) + return x0 * (value / s) \ No newline at end of file diff --git a/repositories/ldm/modules/attention.py b/repositories/ldm/modules/attention.py new file mode 100644 index 000000000..509cd8737 --- /dev/null +++ b/repositories/ldm/modules/attention.py @@ -0,0 +1,341 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat +from typing import Optional, Any + +from ldm.modules.diffusionmodules.util import checkpoint + + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + +# CrossAttn precision handling +import os +_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward
class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out =
default(dim_out, dim) + project_in = nn.Sequential( + nn.Linear(dim, inner_dim), + nn.GELU() + ) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential( + project_in, + nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = rearrange(q, 'b c h w -> b (h w) c') + k = rearrange(k, 'b c h w -> b c (h w)') + w_ = torch.einsum('bij,bjk->bik', q, k) + + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, 'b c h w -> b c (h w)') + w_ = rearrange(w_, 'b i j -> b j i') + h_ = torch.einsum('bij,bjk->bik', v, w_) + h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) + h_ = self.proj_out(h_) + + return x+h_ + + +class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head ** -0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), + nn.Dropout(dropout) + ) + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + + del q, k + + if exists(mask): + mask = rearrange(mask, 'b ... 
-> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class MemoryEfficientCrossAttention(nn.Module): + # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): + super().__init__() + print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " + f"{heads} heads.") + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.heads = heads + self.dim_head = dim_head + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + self.attention_op: Optional[Any] = None + + def forward(self, x, context=None, mask=None): + q = self.to_q(x) + context = default(context, x) + k = self.to_k(context) + v = self.to_v(context) + + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], self.heads, self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b * self.heads, t.shape[1], self.dim_head) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, self.heads, out.shape[1], self.dim_head) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], self.heads * self.dim_head) + ) + return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + ATTENTION_MODES = { + "softmax": CrossAttention, # vanilla attention + "softmax-xformers": MemoryEfficientCrossAttention + } + def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax" + assert attn_mode in self.ATTENTION_MODES + attn_cls = self.ATTENTION_MODES[attn_mode] + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, + context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, + heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. 
+ First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + def __init__(self, in_channels, n_heads, d_head, + depth=1, dropout=0., context_dim=None, + disable_self_attn=False, use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + if not use_linear: + self.proj_in = nn.Conv2d(in_channels, + inner_dim, + kernel_size=1, + stride=1, + padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], + disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) + for d in range(depth)] + ) + if not use_linear: + self.proj_out = zero_module(nn.Conv2d(inner_dim, + in_channels, + kernel_size=1, + stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + diff --git a/repositories/ldm/modules/diffusionmodules/__init__.py b/repositories/ldm/modules/diffusionmodules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/diffusionmodules/model.py b/repositories/ldm/modules/diffusionmodules/model.py new file mode 100644 index 000000000..b089eebbe --- /dev/null +++ b/repositories/ldm/modules/diffusionmodules/model.py @@ -0,0 +1,852 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from typing import Optional, Any + +from ldm.modules.attention import MemoryEfficientCrossAttention + +try: + import xformers + import xformers.ops + XFORMERS_IS_AVAILBLE = True +except: + XFORMERS_IS_AVAILBLE = False + print("No module 'xformers'. Proceeding without it.") + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
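+ For example, `get_timestep_embedding(torch.arange(4), 128)` returns a [4, 128] tensor
+ whose first 64 channels are sines and last 64 are cosines of each timestep scaled by
+ geometrically spaced frequencies.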
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = 
torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + +class MemoryEfficientAttnBlock(nn.Module): + """ + Uses xformers efficient implementation, + see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 + Note: this is a single-head self-attention operation + """ + # + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.attention_op: Optional[Any] = None + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + B, C, H, W = q.shape + q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) + + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(B, t.shape[1], 1, C) + .permute(0, 2, 1, 3) + .reshape(B * 1, t.shape[1], C) + .contiguous(), + (q, k, v), + ) + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) + + out = ( + out.unsqueeze(0) + .reshape(B, 1, out.shape[1], C) + .permute(0, 2, 1, 3) + .reshape(B, out.shape[1], C) + ) + out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) + out = self.proj_out(out) + return x+out + + +class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): + def forward(self, x, context=None, mask=None): + b, c, h, w = x.shape + x = rearrange(x, 'b c h w -> b (h w) c') + out = super().forward(x, context=context, mask=mask) + out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) + return x + out + + +def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): + assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' + if XFORMERS_IS_AVAILBLE and attn_type == "vanilla": + attn_type = "vanilla-xformers" + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + assert attn_kwargs is None + return AttnBlock(in_channels) + elif attn_type == "vanilla-xformers": + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") + return MemoryEfficientAttnBlock(in_channels) + elif attn_type == "memory-efficient-cross-attn": + attn_kwargs["query_dim"] = in_channels + return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise NotImplementedError() + + +class Model(nn.Module): +
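+ """UNet-style DDPM backbone: a downsampling path, a middle block and an upsampling
+ path joined by skip connections. Timestep conditioning is optional (`use_timestep`);
+ if a `context` tensor is given to `forward`, it is concatenated to the input along
+ the channel axis."""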
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h 
= self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, 
tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def 
forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class 
MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x diff --git a/repositories/ldm/modules/diffusionmodules/openaimodel.py b/repositories/ldm/modules/diffusionmodules/openaimodel.py new file mode 100644 index 000000000..cc3875c63 --- /dev/null +++ b/repositories/ldm/modules/diffusionmodules/openaimodel.py @@ -0,0 +1,807 @@ +from abc import abstractmethod +import math + +import numpy as np +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from ldm.modules.diffusionmodules.util import ( + checkpoint, + conv_nd, + linear, + avg_pool_nd, + zero_module, + normalization, + timestep_embedding, +) +from ldm.modules.attention import SpatialTransformer +from ldm.util import exists + + +# dummy replace +def convert_module_to_f16(x): + pass + +def convert_module_to_f32(x): + pass + + +## go +class AttentionPool2d(nn.Module): + """ + Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py + """ + + def __init__( + self, + spacial_dim: int, + embed_dim: int, + num_heads_channels: int, + output_dim: int = None, + ): + super().__init__() + self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5) + self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) + self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) + 
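# pools the flattened spatial features: forward() prepends the mean token and returns its attended, projected value; one head per num_heads_channels channels +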
self.num_heads = embed_dim // num_heads_channels + self.attention = QKVAttention(self.num_heads) + + def forward(self, x): + b, c, *_spatial = x.shape + x = x.reshape(b, c, -1) # NC(HW) + x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) + x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) + x = self.qkv_proj(x) + x = self.attention(x) + x = self.c_proj(x) + return x[:, :, 0] + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, context=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, SpatialTransformer): + x = layer(x, context) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + +class TransposedUpsample(nn.Module): + 'Learned 2x upsampling without padding' + def __init__(self, channels, out_channels=None, ks=5): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2) + + def forward(self,x): + return self.up(x) + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=padding + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. 
+ :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + + def _forward(self, x, emb): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + #return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, c, *spatial) + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += th.DoubleTensor([matmul_ops]) + + +class QKVAttentionLegacy(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", q * scale, k * scale + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. 
+ """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) + a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +class Timestep(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, t): + return timestep_embedding(t, self.dim) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. + """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + use_checkpoint=False, + use_fp16=False, + use_bf16=False, + num_heads=-1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + use_spatial_transformer=False, # custom transformer support + transformer_depth=1, # custom transformer support + context_dim=None, # custom transformer support + n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model + legacy=True, + disable_self_attentions=None, + num_attention_blocks=None, + disable_middle_self_attn=False, + use_linear_in_transformer=False, + adm_in_channels=None, + ): + super().__init__() + if use_spatial_transformer: + assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' + + if context_dim is not None: + assert use_spatial_transformer, 'Fool!! 
You forgot to use the spatial transformer for your cross-attention conditioning...' + from omegaconf.listconfig import ListConfig + if type(context_dim) == ListConfig: + context_dim = list(context_dim) + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + if num_heads == -1: + assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' + + if num_head_channels == -1: + assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' + + self.image_size = image_size + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + if isinstance(num_res_blocks, int): + self.num_res_blocks = len(channel_mult) * [num_res_blocks] + else: + if len(num_res_blocks) != len(channel_mult): + raise ValueError("provide num_res_blocks either as an int (globally constant) or " + "as a list/tuple (per-level) with the same length as channel_mult") + self.num_res_blocks = num_res_blocks + if disable_self_attentions is not None: + # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not + assert len(disable_self_attentions) == len(channel_mult) + if num_attention_blocks is not None: + assert len(num_attention_blocks) == len(self.num_res_blocks) + assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) + print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " + f"This option has LESS priority than attention_resolutions {attention_resolutions}, " + f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " + f"attention will still not be set.") + + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.dtype = th.bfloat16 if use_bf16 else self.dtype + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.predict_codebook_ids = n_embed is not None + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.num_classes is not None: + if isinstance(self.num_classes, int): + self.label_emb = nn.Embedding(num_classes, time_embed_dim) + elif self.num_classes == "continuous": + print("setting up linear c_adm embedding layer") + self.label_emb = nn.Linear(1, time_embed_dim) + elif self.num_classes == "sequential": + assert adm_in_channels is not None + self.label_emb = nn.Sequential( + nn.Sequential( + linear(adm_in_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + ) + else: + raise ValueError() + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for nr in range(self.num_res_blocks[level]): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + if 
num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(self.num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=model_channels * mult, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = model_channels * mult + if ds in attention_resolutions: + if num_head_channels == -1: + dim_head = ch // num_heads + else: + num_heads = ch // num_head_channels + dim_head = num_head_channels + if legacy: + #num_heads = 1 + dim_head = ch // num_heads if use_spatial_transformer else num_head_channels + if exists(disable_self_attentions): + disabled_sa = disable_self_attentions[level] + else: + disabled_sa = False + + if not exists(num_attention_blocks) or i < num_attention_blocks[level]: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=dim_head, + 
use_new_attention_order=use_new_attention_order, + ) if not use_spatial_transformer else SpatialTransformer( + ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, + disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint + ) + ) + if level and i == self.num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), + ) + if self.predict_codebook_ids: + self.id_predictor = nn.Sequential( + normalization(ch), + conv_nd(dims, model_channels, n_embed, 1), + #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits + ) + + def convert_to_fp16(self): + """ + Convert the torso of the model to float16. + """ + self.input_blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + self.output_blocks.apply(convert_module_to_f16) + + def convert_to_fp32(self): + """ + Convert the torso of the model to float32. + """ + self.input_blocks.apply(convert_module_to_f32) + self.middle_block.apply(convert_module_to_f32) + self.output_blocks.apply(convert_module_to_f32) + + def forward(self, x, timesteps=None, context=None, y=None,**kwargs): + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x.type(self.dtype) + for module in self.input_blocks: + h = module(h, emb, context) + hs.append(h) + h = self.middle_block(h, emb, context) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, context) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/repositories/ldm/modules/diffusionmodules/upscaling.py b/repositories/ldm/modules/diffusionmodules/upscaling.py new file mode 100644 index 000000000..038166620 --- /dev/null +++ b/repositories/ldm/modules/diffusionmodules/upscaling.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial + +from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule +from ldm.util import default + + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, beta_schedule="linear", timesteps=1000, + linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, + cosine_s=cosine_s) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/repositories/ldm/modules/diffusionmodules/util.py b/repositories/ldm/modules/diffusionmodules/util.py new file mode 100644 index 000000000..daf35da7b --- /dev/null +++ b/repositories/ldm/modules/diffusionmodules/util.py @@ -0,0 +1,278 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! 
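+ # usage sketch (illustrative only; both helpers are defined below in this module):
+ #   betas = make_beta_schedule("linear", 1000)            # per-step betas as a numpy array
+ #   emb = timestep_embedding(torch.arange(4), dim=320)    # [4, 320] sinusoidal embeddings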
+ + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "squaredcos_cap_v2": # used for karlo prior + # return early + return betas_for_alpha_bar( + n_timestep, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. 
+ """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), + "dtype": torch.get_autocast_gpu_dtype(), + "cache_enabled": torch.is_autocast_cache_enabled()} + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(), \ + torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/repositories/ldm/modules/distributions/__init__.py b/repositories/ldm/modules/distributions/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/distributions/distributions.py b/repositories/ldm/modules/distributions/distributions.py new file mode 100644 index 000000000..f2b8ef901 --- /dev/null +++ b/repositories/ldm/modules/distributions/distributions.py @@ -0,0 +1,92 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 
2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). + logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/repositories/ldm/modules/ema.py b/repositories/ldm/modules/ema.py new file mode 100644 index 000000000..bded25019 --- /dev/null +++ b/repositories/ldm/modules/ema.py @@ -0,0 +1,80 @@ +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError('Decay must be between 0 and 1') + + self.m_name2s_name = {} + self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates + else torch.tensor(-1, dtype=torch.int)) + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace('.', '') + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def reset_num_updates(self): + del self.num_updates + self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int)) + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.no_grad(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert not key in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + assert not key in self.m_name2s_name + + def store(self, parameters): + """ + Save the current parameters for restoring later. 
+ Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/repositories/ldm/modules/encoders/__init__.py b/repositories/ldm/modules/encoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/encoders/modules.py b/repositories/ldm/modules/encoders/modules.py new file mode 100644 index 000000000..523a7d853 --- /dev/null +++ b/repositories/ldm/modules/encoders/modules.py @@ -0,0 +1,350 @@ +import torch +import torch.nn as nn +import kornia +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip +from ldm.util import default, count_params, autocast + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, + freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? 
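+ # freeze() puts the transformer in eval mode and disables gradients, so the encoder acts as a fixed text feature extractor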
+ if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + # self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + +class ClipImageEmbedder(nn.Module): + def __init__( + self, + model, + jit=False, + device='cuda' if torch.cuda.is_available() else 'cpu', + antialias=True, + ucg_rate=0. + ): + super().__init__() + from clip import load as load_clip + self.model, _ = load_clip(name=model, device=device, jit=jit) + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # re-normalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def forward(self, x, no_dropout=False): + # x is assumed to be in range [-1,1] + out = self.model.encode_image(self.preprocess(x)) + out = out.to(x.dtype) + if self.ucg_rate > 0. and not no_dropout: + out = torch.bernoulli((1. 
- self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out + return out + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + # "pooled", + "last", + "penultimate" + ] + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenOpenCLIPImageEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP vision transformer encoder for images + """ + + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="pooled", antialias=True, ucg_rate=0.): + super().__init__() + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), + pretrained=version, ) + del model.transformer + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "penultimate": + raise NotImplementedError() + self.layer_idx = 1 + + self.antialias = antialias + + self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) + self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) + self.ucg_rate = ucg_rate + + def preprocess(self, x): + # normalize to [0,1] + x = kornia.geometry.resize(x, (224, 224), + interpolation='bicubic', align_corners=True, + antialias=self.antialias) + x = (x + 1.) / 2. + # renormalize according to clip + x = kornia.enhance.normalize(x, self.mean, self.std) + return x + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + @autocast + def forward(self, image, no_dropout=False): + z = self.encode_with_vision_transformer(image) + if self.ucg_rate > 0. and not no_dropout: + z = torch.bernoulli((1. 
- self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z + return z + + def encode_with_vision_transformer(self, img): + img = self.preprocess(img) + x = self.model.visual(img) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " + f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + +from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation +from ldm.modules.diffusionmodules.openaimodel import Timestep + + +class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): + def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs): + super().__init__(*args, **kwargs) + if clip_stats_path is None: + clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim) + else: + clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu") + self.register_buffer("data_mean", clip_mean[None, :], persistent=False) + self.register_buffer("data_std", clip_std[None, :], persistent=False) + self.time_embed = Timestep(timestep_dim) + + def scale(self, x): + # re-normalize to centered mean and unit variance + x = (x - self.data_mean) * 1. 
/ self.data_std + return x + + def unscale(self, x): + # back to original data stats + x = (x * self.data_std) + self.data_mean + return x + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + x = self.scale(x) + z = self.q_sample(x, noise_level) + z = self.unscale(z) + noise_level = self.time_embed(noise_level) + return z, noise_level diff --git a/repositories/ldm/modules/image_degradation/__init__.py b/repositories/ldm/modules/image_degradation/__init__.py new file mode 100644 index 000000000..7836cada8 --- /dev/null +++ b/repositories/ldm/modules/image_degradation/__init__.py @@ -0,0 +1,2 @@ +from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr +from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light diff --git a/repositories/ldm/modules/image_degradation/bsrgan.py b/repositories/ldm/modules/image_degradation/bsrgan.py new file mode 100644 index 000000000..32ef56169 --- /dev/null +++ b/repositories/ldm/modules/image_degradation/bsrgan.py @@ -0,0 +1,730 @@ +# -*- coding: utf-8 -*- +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel. 
+ Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = 
sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random()) + img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. 
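+ # note (added comment): the clip/round above first quantizes the image to 255 levels; vals below is a
+ # pseudo photon-count scale drawn as 10**U(2, 4), i.e. roughly 1e2..1e4, and since the Poisson sample
+ # is divided by vals again, larger vals gives weaker shot noise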
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(30, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
# nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + elif i == 1: + image = add_blur(image, sf=sf) + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
# nearest downsampling + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + example = {"image":image} + return example + + +# TODO incase there is a pickle error one needs to replace a += x with a = a + x in add_speckle_noise etc... +def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None): + """ + This is an extended degradation model by combining + the degradation models of BSRGAN and Real-ESRGAN + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + use_shuffle: the degradation shuffle + use_sharp: sharpening the img + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + if use_sharp: + img = add_sharpening(img) + hq = img.copy() + + if random.random() < shuffle_prob: + shuffle_order = random.sample(range(13), 13) + else: + shuffle_order = list(range(13)) + # local shuffle for noise, JPEG is always the last one + shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) + shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) + + poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 + + for i in shuffle_order: + if i == 0: + img = add_blur(img, sf=sf) + elif i == 1: + img = add_resize(img, sf=sf) + elif i == 2: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 3: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 4: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 5: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + elif i == 6: + img = add_JPEG_noise(img) + elif i == 7: + img = add_blur(img, sf=sf) + elif i == 8: + img = add_resize(img, sf=sf) + elif i == 9: + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) + elif i == 10: + if random.random() < poisson_prob: + img = add_Poisson_noise(img) + elif i == 11: + if random.random() < speckle_prob: + img = add_speckle_noise(img) + elif i == 12: + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + else: + print('check the shuffle!') + + # resize to desired size + img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])), + interpolation=random.choice([1, 2, 3])) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf, lq_patchsize) + 
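+ # at this point img is the degraded low-quality patch (lq_patchsize x lq_patchsize x C) and hq the
+ # aligned high-quality patch (lq_patchsize*sf x lq_patchsize*sf x C), both in [0, 1]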
+ return img, hq + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + print(img) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') + +
diff --git a/repositories/ldm/modules/image_degradation/bsrgan_light.py b/repositories/ldm/modules/image_degradation/bsrgan_light.py new file mode 100644 index 000000000..808c7f882 --- /dev/null +++ b/repositories/ldm/modules/image_degradation/bsrgan_light.py @@ -0,0 +1,651 @@ +# -*- coding: utf-8 -*- +import numpy as np +import cv2 +import torch + +from functools import partial +import random +from scipy import ndimage +import scipy +import scipy.stats as ss +from scipy.interpolate import interp2d +from scipy.linalg import orth +import albumentations + +import ldm.modules.image_degradation.utils_image as util + +""" +# -------------------------------------------- +# Super-Resolution +# -------------------------------------------- +# +# Kai Zhang (cskaizhang@gmail.com) +# https://github.com/cszn +# From 2019/03--2021/08 +# -------------------------------------------- +""" + +def modcrop_np(img, sf): + ''' + Args: + img: numpy image, WxH or WxHxC + sf: scale factor + Return: + cropped image + ''' + w, h = img.shape[:2] + im = np.copy(img) + return im[:w - w % sf, :h - h % sf, ...] + + +""" +# -------------------------------------------- +# anisotropic Gaussian kernels +# -------------------------------------------- +""" + + +def analytic_kernel(k): + """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" + k_size = k.shape[0] + # Calculate the big kernels size + big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) + # Loop over the small kernel to fill the big one + for r in range(k_size): + for c in range(k_size): + big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k + # Crop the edges of the big kernel to ignore very small values and increase run time of SR + crop = k_size // 2 + cropped_big_k = big_k[crop:-crop, crop:-crop] + # Normalize to 1 + return cropped_big_k / cropped_big_k.sum() + + +def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): + """ generate an anisotropic Gaussian kernel + Args: + ksize : e.g., 15, kernel size + theta : [0, pi], rotation angle range + l1 : [0.1,50], scaling of eigenvalues + l2 : [0.1,l1], scaling of eigenvalues + If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns: + k : kernel + """ + + v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) + V = np.array([[v[0], v[1]], [v[1], -v[0]]]) + D = np.array([[l1, 0], [0, l2]]) + Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) + k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) + + return k + + +def gm_blur_kernel(mean, cov, size=15): + center = size / 2.0 + 0.5 + k = np.zeros([size, size]) + for y in range(size): + for x in range(size): + cy = y - center + 1 + cx = x - center + 1 + k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) + + k = k / np.sum(k) + return k + + +def shift_pixel(x, sf, upper_left=True): + """shift pixel for super-resolution with different scale factors + Args: + x: WxHxC or WxH + sf: scale factor + upper_left: shift direction + """ + h, w = x.shape[:2] + shift = (sf - 1) * 0.5 + xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) + if upper_left: + x1 = xv + shift + y1 = yv + shift + else: + x1 = xv - shift + y1 = yv - shift + + x1 = np.clip(x1, 0, w - 1) + y1 = np.clip(y1, 0, h - 1) + + if x.ndim == 2: + x = interp2d(xv, yv, x)(x1, y1) + if x.ndim == 3: + for i in range(x.shape[-1]): + x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) + + return x + + +def blur(x, k): + ''' + x: image, NxcxHxW + k: kernel, Nx1xhxw + ''' + n, c = x.shape[:2] + p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2 + x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') + k = k.repeat(1, c, 1, 1) + k = k.view(-1, 1, k.shape[2], k.shape[3]) + x = x.view(1, -1, x.shape[2], x.shape[3]) + x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c) + x = x.view(n, c, x.shape[2], x.shape[3]) + + return x + + +def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): + """" + # modified version of https://github.com/assafshocher/BlindSR_dataset_generator + # Kai Zhang + # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var + # max_var = 2.5 * sf + """ + # Set random eigen-vals (lambdas) and angle (theta) for COV matrix + lambda_1 = min_var + np.random.rand() * (max_var - min_var) + lambda_2 = min_var + np.random.rand() * (max_var - min_var) + theta = np.random.rand() * np.pi # random theta + noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 + + # Set COV matrix using Lambdas and Theta + LAMBDA = np.diag([lambda_1, lambda_2]) + Q = np.array([[np.cos(theta), -np.sin(theta)], + [np.sin(theta), np.cos(theta)]]) + SIGMA = Q @ LAMBDA @ Q.T + INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] + + # Set expectation position (shifting kernel for aligned image) + MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) + MU = MU[None, None, :, None] + + # Create meshgrid for Gaussian + [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) + Z = np.stack([X, Y], 2)[:, :, :, None] + + # Calcualte Gaussian for every pixel of the kernel + ZZ = Z - MU + ZZ_t = ZZ.transpose(0, 1, 3, 2) + raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) + + # shift the kernel so it will be centered + # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) + + # Normalize the kernel and return + # kernel = raw_kernel_centered / np.sum(raw_kernel_centered) + kernel = raw_kernel / np.sum(raw_kernel) + return kernel + + +def fspecial_gaussian(hsize, sigma): + hsize = [hsize, hsize] + siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0] + std = 
sigma + [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1)) + arg = -(x * x + y * y) / (2 * std * std) + h = np.exp(arg) + h[h < scipy.finfo(float).eps * h.max()] = 0 + sumh = h.sum() + if sumh != 0: + h = h / sumh + return h + + +def fspecial_laplacian(alpha): + alpha = max([0, min([alpha, 1])]) + h1 = alpha / (alpha + 1) + h2 = (1 - alpha) / (alpha + 1) + h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]] + h = np.array(h) + return h + + +def fspecial(filter_type, *args, **kwargs): + ''' + python code from: + https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py + ''' + if filter_type == 'gaussian': + return fspecial_gaussian(*args, **kwargs) + if filter_type == 'laplacian': + return fspecial_laplacian(*args, **kwargs) + + +""" +# -------------------------------------------- +# degradation models +# -------------------------------------------- +""" + + +def bicubic_degradation(x, sf=3): + ''' + Args: + x: HxWxC image, [0, 1] + sf: down-scale factor + Return: + bicubicly downsampled LR image + ''' + x = util.imresize_np(x, scale=1 / sf) + return x + + +def srmd_degradation(x, k, sf=3): + ''' blur + bicubic downsampling + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2018learning, + title={Learning a single convolutional super-resolution network for multiple degradations}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={3262--3271}, + year={2018} + } + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' + x = bicubic_degradation(x, sf=sf) + return x + + +def dpsr_degradation(x, k, sf=3): + ''' bicubic downsampling + blur + Args: + x: HxWxC image, [0, 1] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + Reference: + @inproceedings{zhang2019deep, + title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, + author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, + booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, + pages={1671--1681}, + year={2019} + } + ''' + x = bicubic_degradation(x, sf=sf) + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + return x + + +def classical_degradation(x, k, sf=3): + ''' blur + downsampling + Args: + x: HxWxC image, [0, 1]/[0, 255] + k: hxw, double + sf: down-scale factor + Return: + downsampled LR image + ''' + x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') + # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) + st = 0 + return x[st::sf, st::sf, ...] + + +def add_sharpening(img, weight=0.5, radius=50, threshold=10): + """USM sharpening. borrowed from real-ESRGAN + Input image: I; Blurry image: B. + 1. K = I + weight * (I - B) + 2. Mask = 1 if abs(I - B) > threshold, else: 0 + 3. Blur mask: + 4. Out = Mask * K + (1 - Mask) * I + Args: + img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. + weight (float): Sharp weight. Default: 1. + radius (float): Kernel size of Gaussian blur. Default: 50. 
+ threshold (int): + """ + if radius % 2 == 0: + radius += 1 + blur = cv2.GaussianBlur(img, (radius, radius), 0) + residual = img - blur + mask = np.abs(residual) * 255 > threshold + mask = mask.astype('float32') + soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) + + K = img + weight * residual + K = np.clip(K, 0, 1) + return soft_mask * K + (1 - soft_mask) * img + + +def add_blur(img, sf=4): + wd2 = 4.0 + sf + wd = 2.0 + 0.2 * sf + + wd2 = wd2/4 + wd = wd/4 + + if random.random() < 0.5: + l1 = wd2 * random.random() + l2 = wd2 * random.random() + k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2) + else: + k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random()) + img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror') + + return img + + +def add_resize(img, sf=4): + rnum = np.random.rand() + if rnum > 0.8: # up + sf1 = random.uniform(1, 2) + elif rnum < 0.7: # down + sf1 = random.uniform(0.5 / sf, 1) + else: + sf1 = 1.0 + img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + return img + + +# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): +# noise_level = random.randint(noise_level1, noise_level2) +# rnum = np.random.rand() +# if rnum > 0.6: # add color Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) +# elif rnum < 0.4: # add grayscale Gaussian noise +# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) +# else: # add noise +# L = noise_level2 / 255. +# D = np.diag(np.random.rand(3)) +# U = orth(np.random.rand(3, 3)) +# conv = np.dot(np.dot(np.transpose(U), D), U) +# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) +# img = np.clip(img, 0.0, 1.0) +# return img + +def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + rnum = np.random.rand() + if rnum > 0.6: # add color Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: # add grayscale Gaussian noise + img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: # add noise + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_speckle_noise(img, noise_level1=2, noise_level2=25): + noise_level = random.randint(noise_level1, noise_level2) + img = np.clip(img, 0.0, 1.0) + rnum = random.random() + if rnum > 0.6: + img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32) + elif rnum < 0.4: + img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32) + else: + L = noise_level2 / 255. + D = np.diag(np.random.rand(3)) + U = orth(np.random.rand(3, 3)) + conv = np.dot(np.dot(np.transpose(U), D), U) + img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32) + img = np.clip(img, 0.0, 1.0) + return img + + +def add_Poisson_noise(img): + img = np.clip((img * 255.0).round(), 0, 255) / 255. 
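+ # same shot-noise model as in bsrgan.py: the inline [2, 4] refers to the exponent range, so vals
+ # spans roughly 1e2..1e4 and larger values produce progressively weaker Poisson noise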
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4] + if random.random() < 0.5: + img = np.random.poisson(img * vals).astype(np.float32) / vals + else: + img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114]) + img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255. + noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray + img += noise_gray[:, :, np.newaxis] + img = np.clip(img, 0.0, 1.0) + return img + + +def add_JPEG_noise(img): + quality_factor = random.randint(80, 95) + img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR) + result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor]) + img = cv2.imdecode(encimg, 1) + img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB) + return img + + +def random_crop(lq, hq, sf=4, lq_patchsize=64): + h, w = lq.shape[:2] + rnd_h = random.randint(0, h - lq_patchsize) + rnd_w = random.randint(0, w - lq_patchsize) + lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :] + + rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf) + hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :] + return lq, hq + + +def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + img: HXWXC, [0, 1], its size should be large than (lq_patchsizexsf)x(lq_patchsizexsf) + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = img.shape[:2] + img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = img.shape[:2] + + if h < lq_patchsize * sf or w < lq_patchsize * sf: + raise ValueError(f'img size ({h1}X{w1}) is too small!') + + hq = img.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + img = util.imresize_np(img, 1 / 2, True) + img = np.clip(img, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + img = add_blur(img, sf=sf) + + elif i == 1: + img = add_blur(img, sf=sf) + + elif i == 2: + a, b = img.shape[1], img.shape[0] + # downsample2 + if random.random() < 0.75: + sf1 = random.uniform(1, 2 * sf) + img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror') + img = img[0::sf, 0::sf, ...] 
# nearest downsampling + img = np.clip(img, 0.0, 1.0) + + elif i == 3: + # downsample3 + img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + img = np.clip(img, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + img = add_JPEG_noise(img) + + elif i == 6: + # add processed camera sensor noise + if random.random() < isp_prob and isp_model is not None: + with torch.no_grad(): + img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + img = add_JPEG_noise(img) + + # random crop + img, hq = random_crop(img, hq, sf_ori, lq_patchsize) + + return img, hq + + +# todo no isp_model? +def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False): + """ + This is the degradation model of BSRGAN from the paper + "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution" + ---------- + sf: scale factor + isp_model: camera ISP model + Returns + ------- + img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1] + hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1] + """ + image = util.uint2single(image) + isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25 + sf_ori = sf + + h1, w1 = image.shape[:2] + image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop + h, w = image.shape[:2] + + hq = image.copy() + + if sf == 4 and random.random() < scale2_prob: # downsample1 + if np.random.rand() < 0.5: + image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + image = util.imresize_np(image, 1 / 2, True) + image = np.clip(image, 0.0, 1.0) + sf = 2 + + shuffle_order = random.sample(range(7), 7) + idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3) + if idx1 > idx2: # keep downsample3 last + shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1] + + for i in shuffle_order: + + if i == 0: + image = add_blur(image, sf=sf) + + # elif i == 1: + # image = add_blur(image, sf=sf) + + if i == 0: + pass + + elif i == 2: + a, b = image.shape[1], image.shape[0] + # downsample2 + if random.random() < 0.8: + sf1 = random.uniform(1, 2 * sf) + image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])), + interpolation=random.choice([1, 2, 3])) + else: + k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf)) + k_shifted = shift_pixel(k, sf) + k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel + image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror') + image = image[0::sf, 0::sf, ...] 
# nearest downsampling + + image = np.clip(image, 0.0, 1.0) + + elif i == 3: + # downsample3 + image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3])) + image = np.clip(image, 0.0, 1.0) + + elif i == 4: + # add Gaussian noise + image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2) + + elif i == 5: + # add JPEG noise + if random.random() < jpeg_prob: + image = add_JPEG_noise(image) + # + # elif i == 6: + # # add processed camera sensor noise + # if random.random() < isp_prob and isp_model is not None: + # with torch.no_grad(): + # img, hq = isp_model.forward(img.copy(), hq) + + # add final JPEG compression noise + image = add_JPEG_noise(image) + image = util.single2uint(image) + if up: + image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then + example = {"image": image} + return example + + + + +if __name__ == '__main__': + print("hey") + img = util.imread_uint('utils/test.png', 3) + img = img[:448, :448] + h = img.shape[0] // 4 + print("resizing to", h) + sf = 4 + deg_fn = partial(degradation_bsrgan_variant, sf=sf) + for i in range(20): + print(i) + img_hq = img + img_lq = deg_fn(img)["image"] + img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq) + print(img_lq) + img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"] + print(img_lq.shape) + print("bicubic", img_lq_bicubic.shape) + print(img_hq.shape) + lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), + (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])), + interpolation=0) + img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1) + util.imsave(img_concat, str(i) + '.png') diff --git a/repositories/ldm/modules/image_degradation/utils/test.png b/repositories/ldm/modules/image_degradation/utils/test.png new file mode 100644 index 000000000..4249b43de Binary files /dev/null and b/repositories/ldm/modules/image_degradation/utils/test.png differ diff --git a/repositories/ldm/modules/image_degradation/utils_image.py b/repositories/ldm/modules/image_degradation/utils_image.py new file mode 100644 index 000000000..0175f155a --- /dev/null +++ b/repositories/ldm/modules/image_degradation/utils_image.py @@ -0,0 +1,916 @@ +import os +import math +import random +import numpy as np +import torch +import cv2 +from torchvision.utils import make_grid +from datetime import datetime +#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py + + +os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" + + +''' +# -------------------------------------------- +# Kai Zhang (github: https://github.com/cszn) +# 03/Mar/2019 +# -------------------------------------------- +# https://github.com/twhui/SRGAN-pyTorch +# https://github.com/xinntao/BasicSR +# -------------------------------------------- +''' + + +IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] + + +def is_image_file(filename): + return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) + + +def get_timestamp(): + return datetime.now().strftime('%y%m%d-%H%M%S') + + +def imshow(x, title=None, cbar=False, figsize=None): + plt.figure(figsize=figsize) + plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') + if title: + 
plt.title(title) + if cbar: + plt.colorbar() + plt.show() + + +def surf(Z, cmap='rainbow', figsize=None): + plt.figure(figsize=figsize) + ax3 = plt.axes(projection='3d') + + w, h = Z.shape[:2] + xx = np.arange(0,w,1) + yy = np.arange(0,h,1) + X, Y = np.meshgrid(xx, yy) + ax3.plot_surface(X,Y,Z,cmap=cmap) + #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) + plt.show() + + +''' +# -------------------------------------------- +# get image pathes +# -------------------------------------------- +''' + + +def get_image_paths(dataroot): + paths = None # return None if dataroot is None + if dataroot is not None: + paths = sorted(_get_paths_from_images(dataroot)) + return paths + + +def _get_paths_from_images(path): + assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) + images = [] + for dirpath, _, fnames in sorted(os.walk(path)): + for fname in sorted(fnames): + if is_image_file(fname): + img_path = os.path.join(dirpath, fname) + images.append(img_path) + assert images, '{:s} has no valid image file'.format(path) + return images + + +''' +# -------------------------------------------- +# split large images into small images +# -------------------------------------------- +''' + + +def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): + w, h = img.shape[:2] + patches = [] + if w > p_max and h > p_max: + w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) + h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) + w1.append(w-p_size) + h1.append(h-p_size) +# print(w1) +# print(h1) + for i in w1: + for j in h1: + patches.append(img[i:i+p_size, j:j+p_size,:]) + else: + patches.append(img) + + return patches + + +def imssave(imgs, img_path): + """ + imgs: list, N images of size WxHxC + """ + img_name, ext = os.path.splitext(os.path.basename(img_path)) + + for i, img in enumerate(imgs): + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png') + cv2.imwrite(new_path, img) + + +def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000): + """ + split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size), + and save them into taget_dataroot; only the images with larger size than (p_max)x(p_max) + will be splitted. + Args: + original_dataroot: + taget_dataroot: + p_size: size of small images + p_overlap: patch size in training is a good choice + p_max: images with smaller size than (p_max)x(p_max) keep unchanged. + """ + paths = get_image_paths(original_dataroot) + for img_path in paths: + # img_name, ext = os.path.splitext(os.path.basename(img_path)) + img = imread_uint(img_path, n_channels=n_channels) + patches = patches_from_image(img, p_size, p_overlap, p_max) + imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path))) + #if original_dataroot == taget_dataroot: + #del img_path + +''' +# -------------------------------------------- +# makedir +# -------------------------------------------- +''' + + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + + +def mkdirs(paths): + if isinstance(paths, str): + mkdir(paths) + else: + for path in paths: + mkdir(path) + + +def mkdir_and_rename(path): + if os.path.exists(path): + new_name = path + '_archived_' + get_timestamp() + print('Path already exists. 
Rename it to [{:s}]'.format(new_name)) + os.rename(path, new_name) + os.makedirs(path) + + +''' +# -------------------------------------------- +# read image from path +# opencv is fast, but read BGR numpy image +# -------------------------------------------- +''' + + +# -------------------------------------------- +# get uint8 image of size HxWxn_channles (RGB) +# -------------------------------------------- +def imread_uint(path, n_channels=3): + # input: path + # output: HxWx3(RGB or GGG), or HxWx1 (G) + if n_channels == 1: + img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE + img = np.expand_dims(img, axis=2) # HxWx1 + elif n_channels == 3: + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG + else: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB + return img + + +# -------------------------------------------- +# matlab's imwrite +# -------------------------------------------- +def imsave(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + +def imwrite(img, img_path): + img = np.squeeze(img) + if img.ndim == 3: + img = img[:, :, [2, 1, 0]] + cv2.imwrite(img_path, img) + + + +# -------------------------------------------- +# get single image of size HxWxn_channles (BGR) +# -------------------------------------------- +def read_img(path): + # read image by cv2 + # return: Numpy float32, HWC, BGR, [0,1] + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE + img = img.astype(np.float32) / 255. + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + # some images have 4 channels + if img.shape[2] > 3: + img = img[:, :, :3] + return img + + +''' +# -------------------------------------------- +# image format conversion +# -------------------------------------------- +# numpy(single) <---> numpy(unit) +# numpy(single) <---> tensor +# numpy(unit) <---> tensor +# -------------------------------------------- +''' + + +# -------------------------------------------- +# numpy(single) [0, 1] <---> numpy(unit) +# -------------------------------------------- + + +def uint2single(img): + + return np.float32(img/255.) + + +def single2uint(img): + + return np.uint8((img.clip(0, 1)*255.).round()) + + +def uint162single(img): + + return np.float32(img/65535.) + + +def single2uint16(img): + + return np.uint16((img.clip(0, 1)*65535.).round()) + + +# -------------------------------------------- +# numpy(unit) (HxWxC or HxW) <---> tensor +# -------------------------------------------- + + +# convert uint to 4-dimensional torch tensor +def uint2tensor4(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0) + + +# convert uint to 3-dimensional torch tensor +def uint2tensor3(img): + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.) 
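+ # illustrative round trip with the helpers in this file: an HxWxC uint8 image becomes a 1xCxHxW
+ # float tensor in [0, 1] via uint2tensor4(img), and tensor2uint(t) below maps a tensor back to
+ # HxWxC uint8 by clamping to [0, 1] and rescaling to 0..255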
+ + +# convert 2/3/4-dimensional torch tensor to uint +def tensor2uint(img): + img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + return np.uint8((img*255.0).round()) + + +# -------------------------------------------- +# numpy(single) (HxWxC) <---> tensor +# -------------------------------------------- + + +# convert single (HxWxC) to 3-dimensional torch tensor +def single2tensor3(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float() + + +# convert single (HxWxC) to 4-dimensional torch tensor +def single2tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0) + + +# convert torch tensor to single +def tensor2single(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + + return img + +# convert torch tensor to single +def tensor2single3(img): + img = img.data.squeeze().float().cpu().numpy() + if img.ndim == 3: + img = np.transpose(img, (1, 2, 0)) + elif img.ndim == 2: + img = np.expand_dims(img, axis=2) + return img + + +def single2tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0) + + +def single32tensor5(img): + return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0) + + +def single42tensor4(img): + return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float() + + +# from skimage.io import imread, imsave +def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): + ''' + Converts a torch Tensor into an image Numpy array of BGR channel order + Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order + Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) + ''' + tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp + tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] + n_dim = tensor.dim() + if n_dim == 4: + n_img = len(tensor) + img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 3: + img_np = tensor.numpy() + img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR + elif n_dim == 2: + img_np = tensor.numpy() + else: + raise TypeError( + 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) + if out_type == np.uint8: + img_np = (img_np * 255.0).round() + # Important. Unlike matlab, numpy.unit8() WILL NOT round by default. + return img_np.astype(out_type) + + +''' +# -------------------------------------------- +# Augmentation, flipe and/or rotate +# -------------------------------------------- +# The following two are enough. 
+# (1) augmet_img: numpy image of WxHxC or WxH +# (2) augment_img_tensor4: tensor image 1xCxWxH +# -------------------------------------------- +''' + + +def augment_img(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return np.flipud(np.rot90(img)) + elif mode == 2: + return np.flipud(img) + elif mode == 3: + return np.rot90(img, k=3) + elif mode == 4: + return np.flipud(np.rot90(img, k=2)) + elif mode == 5: + return np.rot90(img) + elif mode == 6: + return np.rot90(img, k=2) + elif mode == 7: + return np.flipud(np.rot90(img, k=3)) + + +def augment_img_tensor4(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + if mode == 0: + return img + elif mode == 1: + return img.rot90(1, [2, 3]).flip([2]) + elif mode == 2: + return img.flip([2]) + elif mode == 3: + return img.rot90(3, [2, 3]) + elif mode == 4: + return img.rot90(2, [2, 3]).flip([2]) + elif mode == 5: + return img.rot90(1, [2, 3]) + elif mode == 6: + return img.rot90(2, [2, 3]) + elif mode == 7: + return img.rot90(3, [2, 3]).flip([2]) + + +def augment_img_tensor(img, mode=0): + '''Kai Zhang (github: https://github.com/cszn) + ''' + img_size = img.size() + img_np = img.data.cpu().numpy() + if len(img_size) == 3: + img_np = np.transpose(img_np, (1, 2, 0)) + elif len(img_size) == 4: + img_np = np.transpose(img_np, (2, 3, 1, 0)) + img_np = augment_img(img_np, mode=mode) + img_tensor = torch.from_numpy(np.ascontiguousarray(img_np)) + if len(img_size) == 3: + img_tensor = img_tensor.permute(2, 0, 1) + elif len(img_size) == 4: + img_tensor = img_tensor.permute(3, 2, 0, 1) + + return img_tensor.type_as(img) + + +def augment_img_np3(img, mode=0): + if mode == 0: + return img + elif mode == 1: + return img.transpose(1, 0, 2) + elif mode == 2: + return img[::-1, :, :] + elif mode == 3: + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 4: + return img[:, ::-1, :] + elif mode == 5: + img = img[:, ::-1, :] + img = img.transpose(1, 0, 2) + return img + elif mode == 6: + img = img[:, ::-1, :] + img = img[::-1, :, :] + return img + elif mode == 7: + img = img[:, ::-1, :] + img = img[::-1, :, :] + img = img.transpose(1, 0, 2) + return img + + +def augment_imgs(img_list, hflip=True, rot=True): + # horizontal flip OR rotate + hflip = hflip and random.random() < 0.5 + vflip = rot and random.random() < 0.5 + rot90 = rot and random.random() < 0.5 + + def _augment(img): + if hflip: + img = img[:, ::-1, :] + if vflip: + img = img[::-1, :, :] + if rot90: + img = img.transpose(1, 0, 2) + return img + + return [_augment(img) for img in img_list] + + +''' +# -------------------------------------------- +# modcrop and shave +# -------------------------------------------- +''' + + +def modcrop(img_in, scale): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + if img.ndim == 2: + H, W = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r] + elif img.ndim == 3: + H, W, C = img.shape + H_r, W_r = H % scale, W % scale + img = img[:H - H_r, :W - W_r, :] + else: + raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) + return img + + +def shave(img_in, border=0): + # img_in: Numpy, HWC or HW + img = np.copy(img_in) + h, w = img.shape[:2] + img = img[border:h-border, border:w-border] + return img + + +''' +# -------------------------------------------- +# image processing process on numpy image +# channel_convert(in_c, tar_type, img_list): +# rgb2ycbcr(img, only_y=True): +# bgr2ycbcr(img, only_y=True): +# 
ycbcr2rgb(img): +# -------------------------------------------- +''' + + +def rgb2ycbcr(img, only_y=True): + '''same as matlab rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], + [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def ycbcr2rgb(img): + '''same as matlab ycbcr2rgb + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], + [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def bgr2ycbcr(img, only_y=True): + '''bgr version of rgb2ycbcr + only_y: only return Y channel + Input: + uint8, [0, 255] + float, [0, 1] + ''' + in_img_type = img.dtype + img.astype(np.float32) + if in_img_type != np.uint8: + img *= 255. + # convert + if only_y: + rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 + else: + rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], + [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] + if in_img_type == np.uint8: + rlt = rlt.round() + else: + rlt /= 255. + return rlt.astype(in_img_type) + + +def channel_convert(in_c, tar_type, img_list): + # conversion among BGR, gray and y + if in_c == 3 and tar_type == 'gray': # BGR to gray + gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] + return [np.expand_dims(img, axis=2) for img in gray_list] + elif in_c == 3 and tar_type == 'y': # BGR to y + y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] + return [np.expand_dims(img, axis=2) for img in y_list] + elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR + return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] + else: + return img_list + + +''' +# -------------------------------------------- +# metric, PSNR and SSIM +# -------------------------------------------- +''' + + +# -------------------------------------------- +# PSNR +# -------------------------------------------- +def calculate_psnr(img1, img2, border=0): + # img1 and img2 have range [0, 255] + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = img2[border:h-border, border:w-border] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20 * math.log10(255.0 / math.sqrt(mse)) + + +# -------------------------------------------- +# SSIM +# -------------------------------------------- +def calculate_ssim(img1, img2, border=0): + '''calculate SSIM + the same outputs as MATLAB's + img1, img2: [0, 255] + ''' + #img1 = img1.squeeze() + #img2 = img2.squeeze() + if not img1.shape == img2.shape: + raise ValueError('Input images must have the same dimensions.') + h, w = img1.shape[:2] + img1 = img1[border:h-border, border:w-border] + img2 = 
img2[border:h-border, border:w-border] + + if img1.ndim == 2: + return ssim(img1, img2) + elif img1.ndim == 3: + if img1.shape[2] == 3: + ssims = [] + for i in range(3): + ssims.append(ssim(img1[:,:,i], img2[:,:,i])) + return np.array(ssims).mean() + elif img1.shape[2] == 1: + return ssim(np.squeeze(img1), np.squeeze(img2)) + else: + raise ValueError('Wrong input image dimensions.') + + +def ssim(img1, img2): + C1 = (0.01 * 255)**2 + C2 = (0.03 * 255)**2 + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + kernel = cv2.getGaussianKernel(11, 1.5) + window = np.outer(kernel, kernel.transpose()) + + mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid + mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] + mu1_sq = mu1**2 + mu2_sq = mu2**2 + mu1_mu2 = mu1 * mu2 + sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq + sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq + sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * + (sigma1_sq + sigma2_sq + C2)) + return ssim_map.mean() + + +''' +# -------------------------------------------- +# matlab's bicubic imresize (numpy and torch) [0, 1] +# -------------------------------------------- +''' + + +# matlab 'imresize' function, now only support 'bicubic' +def cubic(x): + absx = torch.abs(x) + absx2 = absx**2 + absx3 = absx**3 + return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \ + (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx)) + + +def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): + if (scale < 1) and (antialiasing): + # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width + kernel_width = kernel_width / scale + + # Output-space coordinates + x = torch.linspace(1, out_length, out_length) + + # Input-space coordinates. Calculate the inverse mapping such that 0.5 + # in output space maps to 0.5 in input space, and 0.5+scale in output + # space maps to 1.5 in input space. + u = x / scale + 0.5 * (1 - 1 / scale) + + # What is the left-most pixel that can be involved in the computation? + left = torch.floor(u - kernel_width / 2) + + # What is the maximum number of pixels that can be involved in the + # computation? Note: it's OK to use an extra pixel here; if the + # corresponding weights are all zero, it will be eliminated at the end + # of this function. + P = math.ceil(kernel_width) + 2 + + # The indices of the input pixels involved in computing the k-th output + # pixel are in row k of the indices matrix. + indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( + 1, P).expand(out_length, P) + + # The weights used to compute the k-th output pixel are in row k of the + # weights matrix. + distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices + # apply cubic kernel + if (scale < 1) and (antialiasing): + weights = scale * cubic(distance_to_center * scale) + else: + weights = cubic(distance_to_center) + # Normalize the weights matrix so that each row sums to 1. + weights_sum = torch.sum(weights, 1).view(out_length, 1) + weights = weights / weights_sum.expand(out_length, P) + + # If a column in weights is all zero, get rid of it. only consider the first and last column. 
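+    # Trim the (possibly unneeded) first/last weight column that the extra pixel in P may have
+    # introduced, then record how far the sampling indices run past the image border on each side;
+    # callers mirror-pad the input by sym_len_s / sym_len_e before gathering, so no index is ever
+    # out of range.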
+ weights_zero_tmp = torch.sum((weights == 0), 0) + if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): + indices = indices.narrow(1, 1, P - 2) + weights = weights.narrow(1, 1, P - 2) + if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): + indices = indices.narrow(1, 0, P - 2) + weights = weights.narrow(1, 0, P - 2) + weights = weights.contiguous() + indices = indices.contiguous() + sym_len_s = -indices.min() + 1 + sym_len_e = indices.max() - in_length + indices = indices + sym_len_s - 1 + return weights, indices, int(sym_len_s), int(sym_len_e) + + +# -------------------------------------------- +# imresize for tensor image [0, 1] +# -------------------------------------------- +def imresize(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: pytorch tensor, CHW or HW [0,1] + # output: CHW or HW [0,1] w/o round + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(0) + in_C, in_H, in_W = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) + img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:, :sym_len_Hs, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[:, -sym_len_He:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(in_C, out_H, in_W) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) + out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :, :sym_len_Ws] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, :, -sym_len_We:] + inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(2, inv_idx) + out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(in_C, out_H, out_W) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + return out_2 + + +# -------------------------------------------- +# imresize for numpy image [0, 1] +# -------------------------------------------- +def 
imresize_np(img, scale, antialiasing=True): + # Now the scale should be the same for H and W + # input: img: Numpy, HWC or HW [0,1] + # output: HWC or HW [0,1] w/o round + img = torch.from_numpy(img) + need_squeeze = True if img.dim() == 2 else False + if need_squeeze: + img.unsqueeze_(2) + + in_H, in_W, in_C = img.size() + out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) + kernel_width = 4 + kernel = 'cubic' + + # Return the desired dimension order for performing the resize. The + # strategy is to perform the resize first along the dimension with the + # smallest scale factor. + # Now we do not support this. + + # get weights and indices + weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( + in_H, out_H, scale, kernel, kernel_width, antialiasing) + weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( + in_W, out_W, scale, kernel, kernel_width, antialiasing) + # process H dimension + # symmetric copying + img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) + img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) + + sym_patch = img[:sym_len_Hs, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) + + sym_patch = img[-sym_len_He:, :, :] + inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(0, inv_idx) + img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) + + out_1 = torch.FloatTensor(out_H, in_W, in_C) + kernel_width = weights_H.size(1) + for i in range(out_H): + idx = int(indices_H[i][0]) + for j in range(out_C): + out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) + + # process W dimension + # symmetric copying + out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) + out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) + + sym_patch = out_1[:, :sym_len_Ws, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) + + sym_patch = out_1[:, -sym_len_We:, :] + inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() + sym_patch_inv = sym_patch.index_select(1, inv_idx) + out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) + + out_2 = torch.FloatTensor(out_H, out_W, in_C) + kernel_width = weights_W.size(1) + for i in range(out_W): + idx = int(indices_W[i][0]) + for j in range(out_C): + out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) + if need_squeeze: + out_2.squeeze_() + + return out_2.numpy() + + +if __name__ == '__main__': + print('---') +# img = imread_uint('test.bmp', 3) +# img = uint2single(img) +# img_bicubic = imresize_np(img, 1/4) \ No newline at end of file diff --git a/repositories/ldm/modules/karlo/__init__.py b/repositories/ldm/modules/karlo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/karlo/diffusers_pipeline.py b/repositories/ldm/modules/karlo/diffusers_pipeline.py new file mode 100644 index 000000000..07f72b35a --- /dev/null +++ b/repositories/ldm/modules/karlo/diffusers_pipeline.py @@ -0,0 +1,512 @@ +# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import List, Optional, Tuple, Union + +import torch +from torch.nn import functional as F + +from transformers import CLIPTextModelWithProjection, CLIPTokenizer +from transformers.models.clip.modeling_clip import CLIPTextModelOutput + +from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel +from ...pipelines import DiffusionPipeline, ImagePipelineOutput +from ...schedulers import UnCLIPScheduler +from ...utils import is_accelerate_available, logging, randn_tensor +from .text_proj import UnCLIPTextProjModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class UnCLIPPipeline(DiffusionPipeline): + """ + Pipeline for text-to-image generation using unCLIP + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + Args: + text_encoder ([`CLIPTextModelWithProjection`]): + Frozen text-encoder. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + prior ([`PriorTransformer`]): + The canonincal unCLIP prior to approximate the image embedding from the text embedding. + text_proj ([`UnCLIPTextProjModel`]): + Utility class to prepare and combine the embeddings before they are passed to the decoder. + decoder ([`UNet2DConditionModel`]): + The decoder to invert the image embedding into an image. + super_res_first ([`UNet2DModel`]): + Super resolution unet. Used in all but the last step of the super resolution diffusion process. + super_res_last ([`UNet2DModel`]): + Super resolution unet. Used in the last step of the super resolution diffusion process. + prior_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the prior denoising process. Just a modified DDPMScheduler. + decoder_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the decoder denoising process. Just a modified DDPMScheduler. + super_res_scheduler ([`UnCLIPScheduler`]): + Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler. 
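+    At generation time the stages run in sequence: the prior maps the text embedding to a CLIP
+    image embedding, the decoder turns that embedding into a low-resolution image, and the two
+    super-resolution unets upscale it to the final output.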
+ """ + + prior: PriorTransformer + decoder: UNet2DConditionModel + text_proj: UnCLIPTextProjModel + text_encoder: CLIPTextModelWithProjection + tokenizer: CLIPTokenizer + super_res_first: UNet2DModel + super_res_last: UNet2DModel + + prior_scheduler: UnCLIPScheduler + decoder_scheduler: UnCLIPScheduler + super_res_scheduler: UnCLIPScheduler + + def __init__( + self, + prior: PriorTransformer, + decoder: UNet2DConditionModel, + text_encoder: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + text_proj: UnCLIPTextProjModel, + super_res_first: UNet2DModel, + super_res_last: UNet2DModel, + prior_scheduler: UnCLIPScheduler, + decoder_scheduler: UnCLIPScheduler, + super_res_scheduler: UnCLIPScheduler, + ): + super().__init__() + + self.register_modules( + prior=prior, + decoder=decoder, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_proj=text_proj, + super_res_first=super_res_first, + super_res_last=super_res_last, + prior_scheduler=prior_scheduler, + decoder_scheduler=decoder_scheduler, + super_res_scheduler=super_res_scheduler, + ) + + def prepare_latents(self, shape, dtype, device, generator, latents, scheduler): + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + latents = latents * scheduler.init_noise_sigma + return latents + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + ): + if text_model_output is None: + batch_size = len(prompt) if isinstance(prompt, list) else 1 + # get prompt text embeddings + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_mask = text_inputs.attention_mask.bool().to(device) + + if text_input_ids.shape[-1] > self.tokenizer.model_max_length: + removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] + + text_encoder_output = self.text_encoder(text_input_ids.to(device)) + + text_embeddings = text_encoder_output.text_embeds + text_encoder_hidden_states = text_encoder_output.last_hidden_state + + else: + batch_size = text_model_output[0].shape[0] + text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1] + text_mask = text_attention_mask + + text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0) + text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + if do_classifier_free_guidance: + uncond_tokens = [""] * batch_size + + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + uncond_text_mask = uncond_input.attention_mask.bool().to(device) + uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device)) + + 
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds + uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state + + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + + seq_len = uncond_embeddings.shape[1] + uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt) + uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len) + + seq_len = uncond_text_encoder_hidden_states.shape[1] + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1) + uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0) + + # done duplicates + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states]) + + text_mask = torch.cat([uncond_text_mask, text_mask]) + + return text_embeddings, text_encoder_hidden_states, text_mask + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's + models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only + when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list + models = [ + self.decoder, + self.text_proj, + self.text_encoder, + self.super_res_first, + self.super_res_last, + ] + for cpu_offloaded_model in models: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. 
+ """ + if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"): + return self.device + for module in self.decoder.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + @torch.no_grad() + def __call__( + self, + prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: int = 1, + prior_num_inference_steps: int = 25, + decoder_num_inference_steps: int = 25, + super_res_num_inference_steps: int = 7, + generator: Optional[torch.Generator] = None, + prior_latents: Optional[torch.FloatTensor] = None, + decoder_latents: Optional[torch.FloatTensor] = None, + super_res_latents: Optional[torch.FloatTensor] = None, + text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None, + text_attention_mask: Optional[torch.Tensor] = None, + prior_guidance_scale: float = 4.0, + decoder_guidance_scale: float = 8.0, + output_type: Optional[str] = "pil", + return_dict: bool = True, + ): + """ + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. This can only be left undefined if + `text_model_output` and `text_attention_mask` is passed. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + prior_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the prior. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + decoder_num_inference_steps (`int`, *optional*, defaults to 25): + The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality + image at the expense of slower inference. + super_res_num_inference_steps (`int`, *optional*, defaults to 7): + The number of denoising steps for super resolution. More denoising steps usually lead to a higher + quality image at the expense of slower inference. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*): + Pre-generated noisy latents to be used as inputs for the prior. + decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*): + Pre-generated noisy latents to be used as inputs for the decoder. + prior_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + decoder_guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. 
of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + text_model_output (`CLIPTextModelOutput`, *optional*): + Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs + can be passed for tasks like text embedding interpolations. Make sure to also pass + `text_attention_mask` in this case. `prompt` can the be left to `None`. + text_attention_mask (`torch.Tensor`, *optional*): + Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention + masks are necessary when passing `text_model_output`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. + """ + if prompt is not None: + if isinstance(prompt, str): + batch_size = 1 + elif isinstance(prompt, list): + batch_size = len(prompt) + else: + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + else: + batch_size = text_model_output[0].shape[0] + + device = self._execution_device + + batch_size = batch_size * num_images_per_prompt + + do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0 + + text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt( + prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask + ) + + # prior + + self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device) + prior_timesteps_tensor = self.prior_scheduler.timesteps + + embedding_dim = self.prior.config.embedding_dim + + prior_latents = self.prepare_latents( + (batch_size, embedding_dim), + text_embeddings.dtype, + device, + generator, + prior_latents, + self.prior_scheduler, + ) + + for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents + + predicted_image_embedding = self.prior( + latent_model_input, + timestep=t, + proj_embedding=text_embeddings, + encoder_hidden_states=text_encoder_hidden_states, + attention_mask=text_mask, + ).predicted_image_embedding + + if do_classifier_free_guidance: + predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2) + predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * ( + predicted_image_embedding_text - predicted_image_embedding_uncond + ) + + if i + 1 == prior_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = prior_timesteps_tensor[i + 1] + + prior_latents = self.prior_scheduler.step( + predicted_image_embedding, + timestep=t, + sample=prior_latents, + generator=generator, + prev_timestep=prev_timestep, + ).prev_sample + + prior_latents = self.prior.post_process_latents(prior_latents) + + image_embeddings = prior_latents + + # done prior + + # decoder + + text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj( + image_embeddings=image_embeddings, + text_embeddings=text_embeddings, + 
text_encoder_hidden_states=text_encoder_hidden_states, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + + decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1) + + self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device) + decoder_timesteps_tensor = self.decoder_scheduler.timesteps + + num_channels_latents = self.decoder.in_channels + height = self.decoder.sample_size + width = self.decoder.sample_size + + decoder_latents = self.prepare_latents( + (batch_size, num_channels_latents, height, width), + text_encoder_hidden_states.dtype, + device, + generator, + decoder_latents, + self.decoder_scheduler, + ) + + for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents + + noise_pred = self.decoder( + sample=latent_model_input, + timestep=t, + encoder_hidden_states=text_encoder_hidden_states, + class_labels=additive_clip_time_embeddings, + attention_mask=decoder_text_mask, + ).sample + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1) + noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1) + noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = torch.cat([noise_pred, predicted_variance], dim=1) + + if i + 1 == decoder_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = decoder_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + decoder_latents = self.decoder_scheduler.step( + noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + decoder_latents = decoder_latents.clamp(-1, 1) + + image_small = decoder_latents + + # done decoder + + # super res + + self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device) + super_res_timesteps_tensor = self.super_res_scheduler.timesteps + + channels = self.super_res_first.in_channels // 2 + height = self.super_res_first.sample_size + width = self.super_res_first.sample_size + + super_res_latents = self.prepare_latents( + (batch_size, channels, height, width), + image_small.dtype, + device, + generator, + super_res_latents, + self.super_res_scheduler, + ) + + interpolate_antialias = {} + if "antialias" in inspect.signature(F.interpolate).parameters: + interpolate_antialias["antialias"] = True + + image_upscaled = F.interpolate( + image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias + ) + + for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)): + # no classifier free guidance + + if i == super_res_timesteps_tensor.shape[0] - 1: + unet = self.super_res_last + else: + unet = self.super_res_first + + latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1) + + noise_pred = unet( + sample=latent_model_input, + timestep=t, + ).sample + + if i + 1 == super_res_timesteps_tensor.shape[0]: + prev_timestep = None + else: + prev_timestep = super_res_timesteps_tensor[i + 1] + + # compute the previous noisy sample x_t -> x_t-1 + super_res_latents = self.super_res_scheduler.step( + noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator + ).prev_sample + + image = 
super_res_latents + # done super res + + # post processing + + image = image * 0.5 + 0.5 + image = image.clamp(0, 1) + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + + if output_type == "pil": + image = self.numpy_to_pil(image) + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file diff --git a/repositories/ldm/modules/karlo/kakao/__init__.py b/repositories/ldm/modules/karlo/kakao/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/karlo/kakao/models/__init__.py b/repositories/ldm/modules/karlo/kakao/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/karlo/kakao/models/clip.py b/repositories/ldm/modules/karlo/kakao/models/clip.py new file mode 100644 index 000000000..961d81502 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/models/clip.py @@ -0,0 +1,182 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ +# ------------------------------------------------------------------------------------ +# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/) +# ------------------------------------------------------------------------------------ + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import clip + +from clip.model import CLIP, convert_weights +from clip.simple_tokenizer import SimpleTokenizer, default_bpe + + +"""===== Monkey-Patching original CLIP for JIT compile =====""" + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = F.layer_norm( + x.type(torch.float32), + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ) + return ret.type(orig_type) + + +clip.model.LayerNorm = LayerNorm +delattr(clip.model.CLIP, "forward") + +"""===== End of Monkey-Patching =====""" + + +class CustomizedCLIP(CLIP): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.jit.export + def encode_image(self, image): + return self.visual(image) + + @torch.jit.export + def encode_text(self, text): + # re-define this function to return unpooled text features + + x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding.type(self.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x).type(self.dtype) + + x_seq = x + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x_out, x_seq + + @torch.jit.ignore + def forward(self, image, text): + super().forward(image, text) + + @classmethod + def load_from_checkpoint(cls, ckpt_path: str): + state_dict = torch.load(ckpt_path, map_location="cpu").state_dict() + + vit = "visual.proj" in state_dict + if vit: + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [ + k + for k in state_dict.keys() + if k.startswith("visual.") and k.endswith(".attn.in_proj_weight") + ] + ) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round( + 
(state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5 + ) + image_resolution = vision_patch_size * grid_size + else: + counts: list = [ + len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith(f"visual.layer{b}") + ) + ) + for b in [1, 2, 3, 4] + ] + vision_layers = tuple(counts) + vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] + output_width = round( + (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5 + ) + vision_patch_size = None + assert ( + output_width**2 + 1 + == state_dict["visual.attnpool.positional_embedding"].shape[0] + ) + image_resolution = output_width * 32 + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len( + set( + k.split(".")[2] + for k in state_dict + if k.startswith("transformer.resblocks") + ) + ) + + model = cls( + embed_dim, + image_resolution, + vision_layers, + vision_width, + vision_patch_size, + context_length, + vocab_size, + transformer_width, + transformer_heads, + transformer_layers, + ) + + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in state_dict: + del state_dict[key] + + convert_weights(model) + model.load_state_dict(state_dict) + model.eval() + model.float() + return model + + +class CustomizedTokenizer(SimpleTokenizer): + def __init__(self): + super().__init__(bpe_path=default_bpe()) + + self.sot_token = self.encoder["<|startoftext|>"] + self.eot_token = self.encoder["<|endoftext|>"] + + def padded_tokens_and_mask(self, texts, text_ctx): + assert isinstance(texts, list) and all( + isinstance(elem, str) for elem in texts + ), "texts should be a list of strings" + + all_tokens = [ + [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts + ] + + mask = [ + [True] * min(text_ctx, len(tokens)) + + [False] * max(text_ctx - len(tokens), 0) + for tokens in all_tokens + ] + mask = torch.tensor(mask, dtype=torch.bool) + result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int) + for i, tokens in enumerate(all_tokens): + if len(tokens) > text_ctx: + tokens = tokens[:text_ctx] + tokens[-1] = self.eot_token + result[i, : len(tokens)] = torch.tensor(tokens) + + return result, mask diff --git a/repositories/ldm/modules/karlo/kakao/models/decoder_model.py b/repositories/ldm/modules/karlo/kakao/models/decoder_model.py new file mode 100644 index 000000000..84e96c9b2 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/models/decoder_model.py @@ -0,0 +1,193 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import copy +import torch + +from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion +from ldm.modules.karlo.kakao.modules.unet import PLMImUNet + + +class Text2ImProgressiveModel(torch.nn.Module): + """ + A decoder that generates 64x64px images based on the text prompt. + + :param config: yaml config to define the decoder. + :param tokenizer: tokenizer used in clip. 
+ """ + + def __init__( + self, + config, + tokenizer, + ): + super().__init__() + + self._conf = config + self._model_conf = config.model.hparams + self._diffusion_kwargs = dict( + steps=config.diffusion.steps, + learn_sigma=config.diffusion.learn_sigma, + sigma_small=config.diffusion.sigma_small, + noise_schedule=config.diffusion.noise_schedule, + use_kl=config.diffusion.use_kl, + predict_xstart=config.diffusion.predict_xstart, + rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, + timestep_respacing=config.diffusion.timestep_respacing, + ) + self._tokenizer = tokenizer + + self.model = self.create_plm_dec_model() + + cf_token, cf_mask = self.set_cf_text_tensor() + self.register_buffer("cf_token", cf_token, persistent=False) + self.register_buffer("cf_mask", cf_mask, persistent=False) + + @classmethod + def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True): + ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + model = cls(config, tokenizer) + model.load_state_dict(ckpt, strict=strict) + return model + + def create_plm_dec_model(self): + image_size = self._model_conf.image_size + if self._model_conf.channel_mult == "": + if image_size == 256: + channel_mult = (1, 1, 2, 2, 4, 4) + elif image_size == 128: + channel_mult = (1, 1, 2, 3, 4) + elif image_size == 64: + channel_mult = (1, 2, 3, 4) + else: + raise ValueError(f"unsupported image size: {image_size}") + else: + channel_mult = tuple( + int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",") + ) + assert 2 ** (len(channel_mult) + 2) == image_size + + attention_ds = [] + for res in self._model_conf.attention_resolutions.split(","): + attention_ds.append(image_size // int(res)) + + return PLMImUNet( + text_ctx=self._model_conf.text_ctx, + xf_width=self._model_conf.xf_width, + in_channels=3, + model_channels=self._model_conf.num_channels, + out_channels=6 if self._model_conf.learn_sigma else 3, + num_res_blocks=self._model_conf.num_res_blocks, + attention_resolutions=tuple(attention_ds), + dropout=self._model_conf.dropout, + channel_mult=channel_mult, + num_heads=self._model_conf.num_heads, + num_head_channels=self._model_conf.num_head_channels, + num_heads_upsample=self._model_conf.num_heads_upsample, + use_scale_shift_norm=self._model_conf.use_scale_shift_norm, + resblock_updown=self._model_conf.resblock_updown, + clip_dim=self._model_conf.clip_dim, + clip_emb_mult=self._model_conf.clip_emb_mult, + clip_emb_type=self._model_conf.clip_emb_type, + clip_emb_drop=self._model_conf.clip_emb_drop, + ) + + def set_cf_text_tensor(self): + return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) + + def get_sample_fn(self, timestep_respacing): + use_ddim = timestep_respacing.startswith(("ddim", "fast")) + + diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) + diffusion_kwargs.update(timestep_respacing=timestep_respacing) + diffusion = create_gaussian_diffusion(**diffusion_kwargs) + sample_fn = ( + diffusion.ddim_sample_loop_progressive + if use_ddim + else diffusion.p_sample_loop_progressive + ) + + return sample_fn + + def forward( + self, + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat=None, + cf_guidance_scales=None, + timestep_respacing=None, + ): + # cfg should be enabled in inference + assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) + assert img_feat is not None + + bsz = txt_feat.shape[0] + img_sz = self._model_conf.image_size + + def guided_model_fn(x_t, ts, **kwargs): + half = x_t[: len(x_t) // 2] + combined = 
torch.cat([half, half], dim=0) + model_out = self.model(combined, ts, **kwargs) + eps, rest = model_out[:, :3], model_out[:, 3:] + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * ( + cond_eps - uncond_eps + ) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + cf_feat = self.model.cf_param.unsqueeze(0) + cf_feat = cf_feat.expand(bsz // 2, -1) + feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0) + + cond = { + "y": feat, + "txt_feat": txt_feat, + "txt_feat_seq": txt_feat_seq, + "mask": mask, + } + sample_fn = self.get_sample_fn(timestep_respacing) + sample_outputs = sample_fn( + guided_model_fn, + (bsz, 3, img_sz, img_sz), + noise=None, + device=txt_feat.device, + clip_denoised=True, + model_kwargs=cond, + ) + + for out in sample_outputs: + sample = out["sample"] + yield sample if cf_guidance_scales is None else sample[ + : sample.shape[0] // 2 + ] + + +class Text2ImModel(Text2ImProgressiveModel): + def forward( + self, + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat=None, + cf_guidance_scales=None, + timestep_respacing=None, + ): + last_out = None + for out in super().forward( + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat, + cf_guidance_scales, + timestep_respacing, + ): + last_out = out + return last_out diff --git a/repositories/ldm/modules/karlo/kakao/models/prior_model.py b/repositories/ldm/modules/karlo/kakao/models/prior_model.py new file mode 100644 index 000000000..03ef230d2 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/models/prior_model.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import copy +import torch + +from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion +from ldm.modules.karlo.kakao.modules.xf import PriorTransformer + + +class PriorDiffusionModel(torch.nn.Module): + """ + A prior that generates clip image feature based on the text prompt. + + :param config: yaml config to define the decoder. + :param tokenizer: tokenizer used in clip. + :param clip_mean: mean to normalize the clip image feature (zero-mean, unit variance). + :param clip_std: std to noramlize the clip image feature (zero-mean, unit variance). 
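+    The prior operates on normalized image embeddings; samples are de-normalized with
+    clip_mean / clip_std before being returned.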
+ """ + + def __init__(self, config, tokenizer, clip_mean, clip_std): + super().__init__() + + self._conf = config + self._model_conf = config.model.hparams + self._diffusion_kwargs = dict( + steps=config.diffusion.steps, + learn_sigma=config.diffusion.learn_sigma, + sigma_small=config.diffusion.sigma_small, + noise_schedule=config.diffusion.noise_schedule, + use_kl=config.diffusion.use_kl, + predict_xstart=config.diffusion.predict_xstart, + rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, + timestep_respacing=config.diffusion.timestep_respacing, + ) + self._tokenizer = tokenizer + + self.register_buffer("clip_mean", clip_mean[None, :], persistent=False) + self.register_buffer("clip_std", clip_std[None, :], persistent=False) + + causal_mask = self.get_causal_mask() + self.register_buffer("causal_mask", causal_mask, persistent=False) + + self.model = PriorTransformer( + text_ctx=self._model_conf.text_ctx, + xf_width=self._model_conf.xf_width, + xf_layers=self._model_conf.xf_layers, + xf_heads=self._model_conf.xf_heads, + xf_final_ln=self._model_conf.xf_final_ln, + clip_dim=self._model_conf.clip_dim, + ) + + cf_token, cf_mask = self.set_cf_text_tensor() + self.register_buffer("cf_token", cf_token, persistent=False) + self.register_buffer("cf_mask", cf_mask, persistent=False) + + @classmethod + def load_from_checkpoint( + cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True + ): + ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + model = cls(config, tokenizer, clip_mean, clip_std) + model.load_state_dict(ckpt, strict=strict) + return model + + def set_cf_text_tensor(self): + return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx) + + def get_sample_fn(self, timestep_respacing): + use_ddim = timestep_respacing.startswith(("ddim", "fast")) + + diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) + diffusion_kwargs.update(timestep_respacing=timestep_respacing) + diffusion = create_gaussian_diffusion(**diffusion_kwargs) + sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop + + return sample_fn + + def get_causal_mask(self): + seq_len = self._model_conf.text_ctx + 4 + mask = torch.empty(seq_len, seq_len) + mask.fill_(float("-inf")) + mask.triu_(1) + mask = mask[None, ...] 
+ return mask + + def forward( + self, + txt_feat, + txt_feat_seq, + mask, + cf_guidance_scales=None, + timestep_respacing=None, + denoised_fn=True, + ): + # cfg should be enabled in inference + assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0) + + bsz_ = txt_feat.shape[0] + bsz = bsz_ // 2 + + def guided_model_fn(x_t, ts, **kwargs): + half = x_t[: len(x_t) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.model(combined, ts, **kwargs) + eps, rest = ( + model_out[:, : int(x_t.shape[1])], + model_out[:, int(x_t.shape[1]) :], + ) + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * ( + cond_eps - uncond_eps + ) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + cond = { + "text_emb": txt_feat, + "text_enc": txt_feat_seq, + "mask": mask, + "causal_mask": self.causal_mask, + } + sample_fn = self.get_sample_fn(timestep_respacing) + sample = sample_fn( + guided_model_fn, + (bsz_, self.model.clip_dim), + noise=None, + device=txt_feat.device, + clip_denoised=False, + denoised_fn=lambda x: torch.clamp(x, -10, 10), + model_kwargs=cond, + ) + sample = (sample * self.clip_std) + self.clip_mean + + return sample[:bsz] diff --git a/repositories/ldm/modules/karlo/kakao/models/sr_256_1k.py b/repositories/ldm/modules/karlo/kakao/models/sr_256_1k.py new file mode 100644 index 000000000..1e874f6f1 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/models/sr_256_1k.py @@ -0,0 +1,10 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive + + +class SupRes256to1kProgressive(SupRes64to256Progressive): + pass # no difference currently diff --git a/repositories/ldm/modules/karlo/kakao/models/sr_64_256.py b/repositories/ldm/modules/karlo/kakao/models/sr_64_256.py new file mode 100644 index 000000000..32687afe3 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/models/sr_64_256.py @@ -0,0 +1,88 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. +# ------------------------------------------------------------------------------------ + +import copy +import torch + +from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel +from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion + + +class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module): + """ + ImprovedSR model fine-tunes the pretrained DDPM-based SR model by using adversarial and perceptual losses. + In specific, the low-resolution sample is iteratively recovered by 6 steps with the frozen pretrained SR model. + In the following additional one step, a seperate fine-tuned model recovers high-frequency details. + This approach greatly improves the fidelity of images of 256x256px, even with small number of reverse steps. 
+ """ + + def __init__(self, config): + super().__init__() + + self._config = config + self._diffusion_kwargs = dict( + steps=config.diffusion.steps, + learn_sigma=config.diffusion.learn_sigma, + sigma_small=config.diffusion.sigma_small, + noise_schedule=config.diffusion.noise_schedule, + use_kl=config.diffusion.use_kl, + predict_xstart=config.diffusion.predict_xstart, + rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas, + ) + + self.model_first_steps = SuperResUNetModel( + in_channels=3, # auto-changed to 6 inside the model + model_channels=config.model.hparams.channels, + out_channels=3, + num_res_blocks=config.model.hparams.depth, + attention_resolutions=(), # no attention + dropout=config.model.hparams.dropout, + channel_mult=config.model.hparams.channels_multiple, + resblock_updown=True, + use_middle_attention=False, + ) + self.model_last_step = SuperResUNetModel( + in_channels=3, # auto-changed to 6 inside the model + model_channels=config.model.hparams.channels, + out_channels=3, + num_res_blocks=config.model.hparams.depth, + attention_resolutions=(), # no attention + dropout=config.model.hparams.dropout, + channel_mult=config.model.hparams.channels_multiple, + resblock_updown=True, + use_middle_attention=False, + ) + + @classmethod + def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True): + ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"] + + model = cls(config) + model.load_state_dict(ckpt, strict=strict) + return model + + def get_sample_fn(self, timestep_respacing): + diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs) + diffusion_kwargs.update(timestep_respacing=timestep_respacing) + diffusion = create_gaussian_diffusion(**diffusion_kwargs) + return diffusion.p_sample_loop_progressive_for_improved_sr + + def forward(self, low_res, timestep_respacing="7", **kwargs): + assert ( + timestep_respacing == "7" + ), "different respacing method may work, but no guaranteed" + + sample_fn = self.get_sample_fn(timestep_respacing) + sample_outputs = sample_fn( + self.model_first_steps, + self.model_last_step, + shape=low_res.shape, + clip_denoised=True, + model_kwargs=dict(low_res=low_res), + **kwargs, + ) + for x in sample_outputs: + sample = x["sample"] + yield sample diff --git a/repositories/ldm/modules/karlo/kakao/modules/__init__.py b/repositories/ldm/modules/karlo/kakao/modules/__init__.py new file mode 100644 index 000000000..11d4358a6 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/modules/__init__.py @@ -0,0 +1,49 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + + +from .diffusion import gaussian_diffusion as gd +from .diffusion.respace import ( + SpacedDiffusion, + space_timesteps, +) + + +def create_gaussian_diffusion( + steps, + learn_sigma, + sigma_small, + noise_schedule, + use_kl, + predict_xstart, + rescale_learned_sigmas, + timestep_respacing, +): + betas = gd.get_named_beta_schedule(noise_schedule, steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if not timestep_respacing: + timestep_respacing = [steps] + + return SpacedDiffusion( + use_timesteps=space_timesteps(steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else 
gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type, + ) diff --git a/repositories/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py b/repositories/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py new file mode 100644 index 000000000..6a111aa09 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py @@ -0,0 +1,828 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +import enum +import math + +import numpy as np +import torch as th + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace( + beta_start, beta_end, warmup_time, dtype=np.float64 + ) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start**0.5, + beta_end**0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. 
+ :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +class GaussianDiffusion(th.nn.Module): + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + ): + super(GaussianDiffusion, self).__init__() + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. 
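Editorial note: a short numerical check (a sketch, not part of the patch) of the "squaredcos_cap_v2" schedule that betas_for_alpha_bar above produces, together with the cumulative alphas that the GaussianDiffusion constructor below precomputes from it.

import math
import numpy as np

def alpha_bar(t):
    # The same cosine alpha_bar used for "squaredcos_cap_v2".
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

T = 1000
betas = np.array(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)],
    dtype=np.float64,
)
alphas_cumprod = np.cumprod(1.0 - betas)
print(betas[0], betas[-1])   # tiny at t = 0, capped at 0.999 at the last step
print(alphas_cumprod[-1])    # effectively zero: x_T is almost pure noise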
+ betas = np.array(betas, dtype=np.float64) + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) + alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0) + assert alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) + sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod) + log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod) + sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) + sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + posterior_log_variance_clipped = np.log( + np.append(posterior_variance[1], posterior_variance[1:]) + ) + posterior_mean_coef1 = ( + betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + posterior_mean_coef2 = ( + (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod) + ) + + self.register_buffer("betas", th.from_numpy(betas), persistent=False) + self.register_buffer( + "alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False + ) + self.register_buffer( + "alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False + ) + self.register_buffer( + "alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False + ) + + self.register_buffer( + "sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + th.from_numpy(sqrt_one_minus_alphas_cumprod), + persistent=False, + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", + th.from_numpy(log_one_minus_alphas_cumprod), + persistent=False, + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", + th.from_numpy(sqrt_recip_alphas_cumprod), + persistent=False, + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + th.from_numpy(sqrt_recipm1_alphas_cumprod), + persistent=False, + ) + + self.register_buffer( + "posterior_variance", th.from_numpy(posterior_variance), persistent=False + ) + self.register_buffer( + "posterior_log_variance_clipped", + th.from_numpy(posterior_log_variance_clipped), + persistent=False, + ) + self.register_buffer( + "posterior_mean_coef1", + th.from_numpy(posterior_mean_coef1), + persistent=False, + ) + self.register_buffer( + "posterior_mean_coef2", + th.from_numpy(posterior_mean_coef2), + persistent=False, + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. 
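Editorial note: the buffers registered above exist so that q_sample, defined just below, can jump to any timestep in closed form. A self-contained sanity check of that formula; the linear toy schedule and constants here are placeholders.

import torch

T = 100
betas = torch.linspace(1e-3, 2e-1, T, dtype=torch.float64)  # arbitrary toy schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x_start, t, noise):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    a = alphas_cumprod[t].sqrt().view(-1, 1)
    s = (1.0 - alphas_cumprod[t]).sqrt().view(-1, 1)
    return a * x_start + s * noise

x0 = torch.randn(3, 8, dtype=torch.float64)
t = torch.tensor([0, 50, 99])
xt = q_sample(x0, t, torch.randn_like(x0))
print(xt.shape, alphas_cumprod[[0, 50, 99]])  # alpha_bar decays toward 0 as t grows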
+ """ + mean = ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + ) + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor( + self.log_one_minus_alphas_cumprod, t, x_start.shape + ) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + model_kwargs=None, + **ignore_kwargs, + ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, C = x.shape[:2] + assert t.shape == (B,) + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, C * 2, *x.shape[2:]) + model_output, model_var_values = th.split(model_output, C, dim=1) + if self.model_var_type == ModelVarType.LEARNED: + model_log_variance = model_var_values + model_variance = th.exp(model_log_variance) + else: + min_log = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x.shape + ) + max_log = _extract_into_tensor(th.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. 
+ frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + th.cat([self.posterior_variance[1][None], self.betas[1:]]), + th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + pred_xstart = process_xstart( + self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) + ) + model_mean = model_output + elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + model_mean, _, _ = self.q_posterior_mean_variance( + x_start=pred_xstart, x_t=x, t=t + ) + else: + raise NotImplementedError(self.model_mean_type) + + assert ( + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + ) + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = ( + p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + ) + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). 
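Editorial note: in the LEARNED_RANGE branch above, the model's extra output channels are mapped from [-1, 1] to a fraction and used to interpolate, in log space, between the clipped posterior variance and beta_t. A tiny sketch of that mapping; the two variances below are placeholders.

import torch

def interp_log_variance(model_var_values, min_log, max_log):
    # model_var_values in [-1, 1] -> frac in [0, 1]; interpolate in log space.
    frac = (model_var_values + 1) / 2
    return frac * max_log + (1 - frac) * min_log

min_log = torch.log(torch.tensor(1e-4))  # stand-in for posterior_log_variance_clipped[t]
max_log = torch.log(torch.tensor(2e-2))  # stand-in for log(betas[t])
v = torch.tensor([-1.0, 0.0, 1.0])
print(torch.exp(interp_log_variance(v, min_log, max_log)))
# -> roughly [1e-4, 1.4e-3, 2e-2]: the endpoints recover the two fixed variances.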
+ """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance( + x_start=out["pred_xstart"], x_t=x, t=t + ) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean( + cond_fn, out, x, t, model_kwargs=model_kwargs + ) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). 
+ Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for idx, i in enumerate(indices): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def p_sample_loop_progressive_for_improved_sr( + self, + model, + model_aux, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Modified version of p_sample_loop_progressive for sampling from the improved sr model + """ + + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for idx, i in enumerate(indices): + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model_aux if len(indices) - 1 == idx else model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. 
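Editorial note: ddim_sample above follows Equation 12 of the DDIM paper; with eta = 0 the update is fully deterministic. A compact sketch of that deterministic step from a predicted x_0 and epsilon (the alpha value below is a placeholder):

import torch

def ddim_step_eta0(pred_xstart, eps, alpha_bar_prev):
    # Deterministic DDIM update (eta = 0):
    #   x_{t-1} = sqrt(abar_{t-1}) * x0_hat + sqrt(1 - abar_{t-1}) * eps_hat
    return (
        pred_xstart * torch.sqrt(alpha_bar_prev)
        + torch.sqrt(1.0 - alpha_bar_prev) * eps
    )

x0_hat = torch.randn(2, 4)
eps_hat = torch.randn(2, 4)
abar_prev = torch.tensor(0.9)
print(ddim_step_eta0(x0_hat, eps_hat, abar_prev).shape)  # torch.Size([2, 4])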
+ """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_next) + + th.sqrt(1 - alpha_bar_next) * eps + ) + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 
+ """ + res = arr.to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/repositories/ldm/modules/karlo/kakao/modules/diffusion/respace.py b/repositories/ldm/modules/karlo/kakao/modules/diffusion/respace.py new file mode 100644 index 000000000..70c808f8b --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/modules/diffusion/respace.py @@ -0,0 +1,112 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + + +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + elif section_counts == "fast27": + steps = space_timesteps(num_timesteps, "10,10,3,2,2") + # Help reduce DDIM artifacts from noisiest timesteps. + steps.remove(num_timesteps - 1) + steps.add(num_timesteps - 3) + return steps + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. 
+ """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + timestep_map = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + timestep_map.append(i) + kwargs["betas"] = th.tensor(new_betas).numpy() + super().__init__(**kwargs) + self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False) + + def p_mean_variance(self, model, *args, **kwargs): + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + def wrapped(x, ts, **kwargs): + ts_cpu = ts.detach().to("cpu") + return model( + x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs + ) + + return wrapped diff --git a/repositories/ldm/modules/karlo/kakao/modules/nn.py b/repositories/ldm/modules/karlo/kakao/modules/nn.py new file mode 100644 index 000000000..2eef3f5a0 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/modules/nn.py @@ -0,0 +1,114 @@ +# ------------------------------------------------------------------------------------ +# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +import math + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + + +class GroupNorm32(nn.GroupNorm): + def __init__(self, num_groups, num_channels, swish, eps=1e-5): + super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps) + self.swish = swish + + def forward(self, x): + y = super().forward(x.float()).to(x.dtype) + if self.swish == 1.0: + y = F.silu(y) + elif self.swish: + y = y * F.sigmoid(y * float(self.swish)) + return y + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def normalization(channels, swish=0.0): + """ + Make a standard normalization layer, with an optional swish activation. + + :param channels: number of input channels. 
+ :return: an nn.Module for normalization. + """ + return GroupNorm32(num_channels=channels, num_groups=32, swish=swish) + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = th.exp( + -math.log(max_period) + * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device) + / half + ) + args = timesteps[:, None].float() * freqs[None] + embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) + if dim % 2: + embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) diff --git a/repositories/ldm/modules/karlo/kakao/modules/resample.py b/repositories/ldm/modules/karlo/kakao/modules/resample.py new file mode 100644 index 000000000..485421aa4 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/modules/resample.py @@ -0,0 +1,68 @@ +# ------------------------------------------------------------------------------------ +# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +from abc import abstractmethod + +import torch as th + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(th.nn.Module): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. 
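Editorial note: timestep_embedding above is the usual sinusoidal encoding. Below is a standalone copy of the even-dimension case for a quick shape and value check (a sketch, not an import from the patch):

import math
import torch

def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0, 10, 999]), dim=320)
print(emb.shape)                                          # torch.Size([3, 320])
print(torch.allclose(emb[0, :160], torch.ones(160)))      # True: cos(0) = 1 at t = 0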
+ """ + w = self.weights() + p = w / th.sum(w) + indices = p.multinomial(batch_size, replacement=True) + weights = 1 / (len(p) * p[indices]) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + super(UniformSampler, self).__init__() + self.diffusion = diffusion + self.register_buffer( + "_weights", th.ones([diffusion.num_timesteps]), persistent=False + ) + + def weights(self): + return self._weights diff --git a/repositories/ldm/modules/karlo/kakao/modules/unet.py b/repositories/ldm/modules/karlo/kakao/modules/unet.py new file mode 100644 index 000000000..c99d0b791 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/modules/unet.py @@ -0,0 +1,792 @@ +# ------------------------------------------------------------------------------------ +# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion) +# ------------------------------------------------------------------------------------ + +import math +from abc import abstractmethod + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .nn import ( + avg_pool_nd, + conv_nd, + linear, + normalization, + timestep_embedding, + zero_module, +) +from .xf import LayerNorm + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, encoder_out=None, mask=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, AttentionBlock): + x = layer(x, encoder_out, mask=mask) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + ) + else: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd( + dims, self.channels, self.out_channels, 3, stride=stride, padding=1 + ) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + + self.in_layers = nn.Sequential( + normalization(channels, swish=1.0), + nn.Identity(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + normalization( + self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0 + ), + nn.SiLU() if use_scale_shift_norm else nn.Identity(), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. 
+ """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = th.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class ResBlockNoTimeEmbedding(nn.Module): + """ + A residual block without time embedding + + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + dims=2, + use_checkpoint=False, + up=False, + down=False, + **kwargs, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + + self.in_layers = nn.Sequential( + normalization(channels, swish=1.0), + nn.Identity(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.out_layers = nn.Sequential( + normalization(self.out_channels, swish=1.0), + nn.Dropout(p=dropout), + zero_module( + conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1) + ), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1 + ) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + + def forward(self, x, emb=None): + """ + Apply the block to a Tensor, NOT conditioned on a timestep embedding. + + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + assert emb is None + + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + h = self.out_layers(h) + return self.skip_connection(x) + h + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + encoder_channels=None, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels, swish=0.0) + self.qkv = conv_nd(1, channels, channels * 3, 1) + self.attention = QKVAttention(self.num_heads) + + if encoder_channels is not None: + self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x, encoder_out=None, mask=None): + b, c, *spatial = x.shape + qkv = self.qkv(self.norm(x).view(b, c, -1)) + if encoder_out is not None: + encoder_out = self.encoder_kv(encoder_out) + h = self.attention(qkv, encoder_out, mask=mask) + else: + h = self.attention(qkv) + h = self.proj_out(h) + return x + h.reshape(b, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv, encoder_kv=None, mask=None): + """ + Apply QKV attention. + + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + if encoder_kv is not None: + assert encoder_kv.shape[1] == self.n_heads * ch * 2 + ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) + k = th.cat([ek, k], dim=-1) + v = th.cat([ev, v], dim=-1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = th.einsum("bct,bcs->bts", q * scale, k * scale) + if mask is not None: + mask = F.pad(mask, (0, length), value=0.0) + mask = ( + mask.unsqueeze(1) + .expand(-1, self.n_heads, -1) + .reshape(bs * self.n_heads, 1, -1) + ) + weight = weight + mask + weight = th.softmax(weight, dim=-1) + a = th.einsum("bts,bcs->bct", weight, v) + return a.reshape(bs, -1, length) + + +class UNetModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. + :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param clip_dim: dimension of clip feature. + :param num_classes: if specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. 
+ :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param encoder_channels: use to make the dimension of query and kv same in AttentionBlock. + :param use_time_embedding: use time embedding for condition. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + clip_dim=None, + use_checkpoint=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + use_middle_attention=True, + resblock_updown=False, + encoder_channels=None, + use_time_embedding=True, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.clip_dim = clip_dim + self.use_checkpoint = use_checkpoint + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.use_middle_attention = use_middle_attention + self.use_time_embedding = use_time_embedding + + if self.use_time_embedding: + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + if self.clip_dim is not None: + self.clip_emb = nn.Linear(clip_dim, time_embed_dim) + else: + time_embed_dim = None + + CustomResidualBlock = ( + ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding + ) + ch = input_ch = int(channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(mult * model_channels) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + 
dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + *( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + ), + ) + if self.use_middle_attention + else tuple(), # add AttentionBlock or not + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mult))[::-1]: + for i in range(num_res_blocks + 1): + ich = input_block_chans.pop() + layers = [ + CustomResidualBlock( + ch + ich, + time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = int(model_channels * mult) + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + ) + ) + if level and i == num_res_blocks: + out_ch = ch + layers.append( + CustomResidualBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch, swish=1.0), + nn.Identity(), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), + ) + + def forward(self, x, timesteps, y=None): + """ + Apply the model to an input batch. + + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + assert (y is not None) == ( + self.clip_dim is not None + ), "must specify y if and only if the model is clip-rep-conditional" + + hs = [] + if self.use_time_embedding: + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + if self.clip_dim is not None: + emb = emb + self.clip_emb(y) + else: + emb = None + + h = x + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb) + + return self.out(h) + + +class SuperResUNetModel(UNetModel): + """ + A UNetModel that performs super-resolution. + + Expects an extra kwarg `low_res` to condition on a low-resolution image. + Assumes that the shape of low-resolution and the input should be the same. + """ + + def __init__(self, *args, **kwargs): + if "in_channels" in kwargs: + kwargs = dict(kwargs) + kwargs["in_channels"] = kwargs["in_channels"] * 2 + else: + # Curse you, Python. Or really, just curse positional arguments :|. 
+ args = list(args) + args[1] = args[1] * 2 + super().__init__(*args, **kwargs) + + def forward(self, x, timesteps, low_res=None, **kwargs): + _, _, new_height, new_width = x.shape + assert new_height == low_res.shape[2] and new_width == low_res.shape[3] + + x = th.cat([x, low_res], dim=1) + return super().forward(x, timesteps, **kwargs) + + +class PLMImUNet(UNetModel): + """ + A UNetModel that conditions on text with a pretrained text encoder in CLIP. + + :param text_ctx: number of text tokens to expect. + :param xf_width: width of the transformer. + :param clip_emb_mult: #extra tokens by projecting clip text feature. + :param clip_emb_type: type of condition (here, we fix clip image feature). + :param clip_emb_drop: dropout rato of clip image feature for cfg. + """ + + def __init__( + self, + text_ctx, + xf_width, + *args, + clip_emb_mult=None, + clip_emb_type="image", + clip_emb_drop=0.0, + **kwargs, + ): + self.text_ctx = text_ctx + self.xf_width = xf_width + self.clip_emb_mult = clip_emb_mult + self.clip_emb_type = clip_emb_type + self.clip_emb_drop = clip_emb_drop + + if not xf_width: + super().__init__(*args, **kwargs, encoder_channels=None) + else: + super().__init__(*args, **kwargs, encoder_channels=xf_width) + + # Project text encoded feat seq from pre-trained text encoder in CLIP + self.text_seq_proj = nn.Sequential( + nn.Linear(self.clip_dim, xf_width), + LayerNorm(xf_width), + ) + # Project CLIP text feat + self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4) + + assert clip_emb_mult is not None + assert clip_emb_type == "image" + assert self.clip_dim is not None, "CLIP representation dim should be specified" + + self.clip_tok_proj = nn.Linear( + self.clip_dim, self.xf_width * self.clip_emb_mult + ) + if self.clip_emb_drop > 0: + self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32)) + + def proc_clip_emb_drop(self, feat): + if self.clip_emb_drop > 0: + bsz, feat_dim = feat.shape + assert ( + feat_dim == self.clip_dim + ), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}" + drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop + feat = th.where( + drop_idx[..., None], self.cf_param[None].type_as(feat), feat + ) + return feat + + def forward( + self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None + ): + bsz = x.shape[0] + hs = [] + emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + emb = emb + self.clip_emb(y) + + xf_out = self.text_seq_proj(txt_feat_seq) + xf_out = xf_out.permute(0, 2, 1) + emb = emb + self.text_feat_proj(txt_feat) + xf_out = th.cat( + [ + self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult), + xf_out, + ], + dim=2, + ) + mask = F.pad(mask, (self.clip_emb_mult, 0), value=True) + mask = th.where(mask, 0.0, float("-inf")) + + h = x + for module in self.input_blocks: + h = module(h, emb, xf_out, mask=mask) + hs.append(h) + h = self.middle_block(h, emb, xf_out, mask=mask) + for module in self.output_blocks: + h = th.cat([h, hs.pop()], dim=1) + h = module(h, emb, xf_out, mask=mask) + h = self.out(h) + + return h diff --git a/repositories/ldm/modules/karlo/kakao/modules/xf.py b/repositories/ldm/modules/karlo/kakao/modules/xf.py new file mode 100644 index 000000000..66d7d4a2f --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/modules/xf.py @@ -0,0 +1,231 @@ +# ------------------------------------------------------------------------------------ +# Adapted from the repos below: +# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion) +# (b) 
CLIP ViT (https://github.com/openai/CLIP/) +# ------------------------------------------------------------------------------------ + +import math + +import torch as th +import torch.nn as nn +import torch.nn.functional as F + +from .nn import timestep_embedding + + +def convert_module_to_f16(param): + """ + Convert primitive modules to float16. + """ + if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): + param.weight.data = param.weight.data.half() + if param.bias is not None: + param.bias.data = param.bias.data.half() + + +class LayerNorm(nn.LayerNorm): + """ + Implementation that supports fp16 inputs but fp32 gains/biases. + """ + + def forward(self, x: th.Tensor): + return super().forward(x.float()).to(x.dtype) + + +class MultiheadAttention(nn.Module): + def __init__(self, n_ctx, width, heads): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.heads = heads + self.c_qkv = nn.Linear(width, width * 3) + self.c_proj = nn.Linear(width, width) + self.attention = QKVMultiheadAttention(heads, n_ctx) + + def forward(self, x, mask=None): + x = self.c_qkv(x) + x = self.attention(x, mask=mask) + x = self.c_proj(x) + return x + + +class MLP(nn.Module): + def __init__(self, width): + super().__init__() + self.width = width + self.c_fc = nn.Linear(width, width * 4) + self.c_proj = nn.Linear(width * 4, width) + self.gelu = nn.GELU() + + def forward(self, x): + return self.c_proj(self.gelu(self.c_fc(x))) + + +class QKVMultiheadAttention(nn.Module): + def __init__(self, n_heads: int, n_ctx: int): + super().__init__() + self.n_heads = n_heads + self.n_ctx = n_ctx + + def forward(self, qkv, mask=None): + bs, n_ctx, width = qkv.shape + attn_ch = width // self.n_heads // 3 + scale = 1 / math.sqrt(math.sqrt(attn_ch)) + qkv = qkv.view(bs, n_ctx, self.n_heads, -1) + q, k, v = th.split(qkv, attn_ch, dim=-1) + weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale) + wdtype = weight.dtype + if mask is not None: + weight = weight + mask[:, None, ...] + weight = th.softmax(weight, dim=-1).type(wdtype) + return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) + + +class ResidualAttentionBlock(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + heads: int, + ): + super().__init__() + + self.attn = MultiheadAttention( + n_ctx, + width, + heads, + ) + self.ln_1 = LayerNorm(width) + self.mlp = MLP(width) + self.ln_2 = LayerNorm(width) + + def forward(self, x, mask=None): + x = x + self.attn(self.ln_1(x), mask=mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__( + self, + n_ctx: int, + width: int, + layers: int, + heads: int, + ): + super().__init__() + self.n_ctx = n_ctx + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList( + [ + ResidualAttentionBlock( + n_ctx, + width, + heads, + ) + for _ in range(layers) + ] + ) + + def forward(self, x, mask=None): + for block in self.resblocks: + x = block(x, mask=mask) + return x + + +class PriorTransformer(nn.Module): + """ + A Causal Transformer that conditions on CLIP text embedding, text. + + :param text_ctx: number of text tokens to expect. + :param xf_width: width of the transformer. + :param xf_layers: depth of the transformer. + :param xf_heads: heads in the transformer. + :param xf_final_ln: use a LayerNorm after the output layer. + :param clip_dim: dimension of clip feature. 
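+
+    In addition to the projected text encodings, four extra tokens are appended to the
+    sequence (the projected text embedding, the timestep embedding, the projected input x
+    and a learned prediction embedding), which is why the transformer context length is
+    text_ctx + ext_len with ext_len = 4; the output is read from the final token.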
+ """ + + def __init__( + self, + text_ctx, + xf_width, + xf_layers, + xf_heads, + xf_final_ln, + clip_dim, + ): + super().__init__() + + self.text_ctx = text_ctx + self.xf_width = xf_width + self.xf_layers = xf_layers + self.xf_heads = xf_heads + self.clip_dim = clip_dim + self.ext_len = 4 + + self.time_embed = nn.Sequential( + nn.Linear(xf_width, xf_width), + nn.SiLU(), + nn.Linear(xf_width, xf_width), + ) + self.text_enc_proj = nn.Linear(clip_dim, xf_width) + self.text_emb_proj = nn.Linear(clip_dim, xf_width) + self.clip_img_proj = nn.Linear(clip_dim, xf_width) + self.out_proj = nn.Linear(xf_width, clip_dim) + self.transformer = Transformer( + text_ctx + self.ext_len, + xf_width, + xf_layers, + xf_heads, + ) + if xf_final_ln: + self.final_ln = LayerNorm(xf_width) + else: + self.final_ln = None + + self.positional_embedding = nn.Parameter( + th.empty(1, text_ctx + self.ext_len, xf_width) + ) + self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width))) + + nn.init.normal_(self.prd_emb, std=0.01) + nn.init.normal_(self.positional_embedding, std=0.01) + + def forward( + self, + x, + timesteps, + text_emb=None, + text_enc=None, + mask=None, + causal_mask=None, + ): + bsz = x.shape[0] + mask = F.pad(mask, (0, self.ext_len), value=True) + + t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width)) + text_enc = self.text_enc_proj(text_enc) + text_emb = self.text_emb_proj(text_emb) + x = self.clip_img_proj(x) + + input_seq = [ + text_enc, + text_emb[:, None, :], + t_emb[:, None, :], + x[:, None, :], + self.prd_emb.to(x.dtype).expand(bsz, -1, -1), + ] + input = th.cat(input_seq, dim=1) + input = input + self.positional_embedding.to(input.dtype) + + mask = th.where(mask, 0.0, float("-inf")) + mask = (mask[:, None, :] + causal_mask).to(input.dtype) + + out = self.transformer(input, mask=mask) + if self.final_ln is not None: + out = self.final_ln(out) + + out = self.out_proj(out[:, -1]) + + return out diff --git a/repositories/ldm/modules/karlo/kakao/sampler.py b/repositories/ldm/modules/karlo/kakao/sampler.py new file mode 100644 index 000000000..b56bf2f20 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/sampler.py @@ -0,0 +1,272 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. + +# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15 +# ------------------------------------------------------------------------------------ + +from typing import Iterator + +import torch +import torchvision.transforms.functional as TVF +from torchvision.transforms import InterpolationMode + +from .template import BaseSampler, CKPT_PATH + + +class T2ISampler(BaseSampler): + """ + A sampler for text-to-image generation. + :param root_dir: directory for model checkpoints. 
+ :param sampling_type: ["default", "fast"] + """ + + def __init__( + self, + root_dir: str, + sampling_type: str = "default", + ): + super().__init__(root_dir, sampling_type) + + @classmethod + def from_pretrained( + cls, + root_dir: str, + clip_model_path: str, + clip_stat_path: str, + sampling_type: str = "default", + ): + + model = cls( + root_dir=root_dir, + sampling_type=sampling_type, + ) + model.load_clip(clip_model_path) + model.load_prior( + f"{CKPT_PATH['prior']}", + clip_stat_path=clip_stat_path, + prior_config="configs/karlo/prior_1B_vit_l.yaml" + ) + model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml") + model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml") + return model + + def preprocess( + self, + prompt: str, + bsz: int, + ): + """Setup prompts & cfg scales""" + prompts_batch = [prompt for _ in range(bsz)] + + prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) + prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") + + decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) + decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") + + """ Get CLIP text feature """ + clip_model = self._clip + tokenizer = self._tokenizer + max_txt_length = self._prior.model.text_ctx + + tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) + cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) + if not (cf_token.shape == tok.shape): + cf_token = cf_token.expand(tok.shape[0], -1) + cf_mask = cf_mask.expand(tok.shape[0], -1) + + tok = torch.cat([tok, cf_token], dim=0) + mask = torch.cat([mask, cf_mask], dim=0) + + tok, mask = tok.to(device="cuda"), mask.to(device="cuda") + txt_feat, txt_feat_seq = clip_model.encode_text(tok) + + return ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) + + def __call__( + self, + prompt: str, + bsz: int, + progressive_mode=None, + ) -> Iterator[torch.Tensor]: + assert progressive_mode in ("loop", "stage", "final") + with torch.no_grad(), torch.cuda.amp.autocast(): + ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) = self.preprocess( + prompt, + bsz, + ) + + """ Transform CLIP text feature into image feature """ + img_feat = self._prior( + txt_feat, + txt_feat_seq, + mask, + prior_cf_scales_batch, + timestep_respacing=self._prior_sm, + ) + + """ Generate 64x64px images """ + images_64_outputs = self._decoder( + txt_feat, + txt_feat_seq, + tok, + mask, + img_feat, + cf_guidance_scales=decoder_cf_scales_batch, + timestep_respacing=self._decoder_sm, + ) + + images_64 = None + for k, out in enumerate(images_64_outputs): + images_64 = out + if progressive_mode == "loop": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + if progressive_mode == "stage": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + + images_64 = torch.clamp(images_64, -1, 1) + + """ Upsample 64x64 to 256x256 """ + images_256 = TVF.resize( + images_64, + [256, 256], + interpolation=InterpolationMode.BICUBIC, + antialias=True, + ) + images_256_outputs = self._sr_64_256( + images_256, timestep_respacing=self._sr_sm + ) + + for k, out in enumerate(images_256_outputs): + images_256 = out + if progressive_mode == "loop": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + if progressive_mode == "stage": + yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0) + + yield 
torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0) + + +class PriorSampler(BaseSampler): + """ + A sampler for text-to-image generation, but only the prior. + :param root_dir: directory for model checkpoints. + :param sampling_type: ["default", "fast"] + """ + + def __init__( + self, + root_dir: str, + sampling_type: str = "default", + ): + super().__init__(root_dir, sampling_type) + + @classmethod + def from_pretrained( + cls, + root_dir: str, + clip_model_path: str, + clip_stat_path: str, + sampling_type: str = "default", + ): + model = cls( + root_dir=root_dir, + sampling_type=sampling_type, + ) + model.load_clip(clip_model_path) + model.load_prior( + f"{CKPT_PATH['prior']}", + clip_stat_path=clip_stat_path, + prior_config="configs/karlo/prior_1B_vit_l.yaml" + ) + return model + + def preprocess( + self, + prompt: str, + bsz: int, + ): + """Setup prompts & cfg scales""" + prompts_batch = [prompt for _ in range(bsz)] + + prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch) + prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda") + + decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch) + decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda") + + """ Get CLIP text feature """ + clip_model = self._clip + tokenizer = self._tokenizer + max_txt_length = self._prior.model.text_ctx + + tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length) + cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length) + if not (cf_token.shape == tok.shape): + cf_token = cf_token.expand(tok.shape[0], -1) + cf_mask = cf_mask.expand(tok.shape[0], -1) + + tok = torch.cat([tok, cf_token], dim=0) + mask = torch.cat([mask, cf_mask], dim=0) + + tok, mask = tok.to(device="cuda"), mask.to(device="cuda") + txt_feat, txt_feat_seq = clip_model.encode_text(tok) + + return ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) + + def __call__( + self, + prompt: str, + bsz: int, + progressive_mode=None, + ) -> Iterator[torch.Tensor]: + assert progressive_mode in ("loop", "stage", "final") + with torch.no_grad(), torch.cuda.amp.autocast(): + ( + prompts_batch, + prior_cf_scales_batch, + decoder_cf_scales_batch, + txt_feat, + txt_feat_seq, + tok, + mask, + ) = self.preprocess( + prompt, + bsz, + ) + + """ Transform CLIP text feature into image feature """ + img_feat = self._prior( + txt_feat, + txt_feat_seq, + mask, + prior_cf_scales_batch, + timestep_respacing=self._prior_sm, + ) + + yield img_feat diff --git a/repositories/ldm/modules/karlo/kakao/template.py b/repositories/ldm/modules/karlo/kakao/template.py new file mode 100644 index 000000000..949e80e67 --- /dev/null +++ b/repositories/ldm/modules/karlo/kakao/template.py @@ -0,0 +1,141 @@ +# ------------------------------------------------------------------------------------ +# Karlo-v1.0.alpha +# Copyright (c) 2022 KakaoBrain. All Rights Reserved. 
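+#
+# This module defines the sampling presets (SAMPLING_CONF), the expected checkpoint file
+# names (CKPT_PATH) and the BaseSampler that loads the CLIP text encoder, the prior, the
+# decoder and the 64->256 super-resolution model used by the samplers in sampler.py.
+#
+# Minimal usage sketch for the T2ISampler built on top of this class; the root_dir and the
+# CLIP checkpoint file name below are assumptions, not values defined in this diff:
+#
+#   from ldm.modules.karlo.kakao.sampler import T2ISampler
+#
+#   sampler = T2ISampler.from_pretrained(
+#       root_dir="checkpoints/karlo_models",   # assumed checkpoint directory
+#       clip_model_path="ViT-L-14.pt",         # assumed CLIP checkpoint name
+#       clip_stat_path="ViT-L-14_stats.th",
+#       sampling_type="default",
+#   )
+#   for images in sampler("a photo of a cat", bsz=1, progressive_mode="final"):
+#       pass  # 256x256 RGB tensors scaled to [0, 1]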
+# ------------------------------------------------------------------------------------ + +import os +import logging +import torch + +from omegaconf import OmegaConf + +from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer +from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel +from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel +from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel + + +SAMPLING_CONF = { + "default": { + "prior_sm": "25", + "prior_n_samples": 1, + "prior_cf_scale": 4.0, + "decoder_sm": "50", + "decoder_cf_scale": 8.0, + "sr_sm": "7", + }, + "fast": { + "prior_sm": "25", + "prior_n_samples": 1, + "prior_cf_scale": 4.0, + "decoder_sm": "25", + "decoder_cf_scale": 8.0, + "sr_sm": "7", + }, +} + +CKPT_PATH = { + "prior": "prior-ckpt-step=01000000-of-01000000.ckpt", + "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt", + "sr_256": "improved-sr-ckpt-step=1.2M.ckpt", +} + + +class BaseSampler: + _PRIOR_CLASS = PriorDiffusionModel + _DECODER_CLASS = Text2ImProgressiveModel + _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel + + def __init__( + self, + root_dir: str, + sampling_type: str = "fast", + ): + self._root_dir = root_dir + + sampling_type = SAMPLING_CONF[sampling_type] + self._prior_sm = sampling_type["prior_sm"] + self._prior_n_samples = sampling_type["prior_n_samples"] + self._prior_cf_scale = sampling_type["prior_cf_scale"] + + assert self._prior_n_samples == 1 + + self._decoder_sm = sampling_type["decoder_sm"] + self._decoder_cf_scale = sampling_type["decoder_cf_scale"] + + self._sr_sm = sampling_type["sr_sm"] + + def __repr__(self): + line = "" + line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n" + line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n" + line += f"SR(64->256), sampling method: {self._sr_sm}" + + return line + + def load_clip(self, clip_path: str): + clip = CustomizedCLIP.load_from_checkpoint( + os.path.join(self._root_dir, clip_path) + ) + clip = torch.jit.script(clip) + clip.cuda() + clip.eval() + + self._clip = clip + self._tokenizer = CustomizedTokenizer() + + def load_prior( + self, + ckpt_path: str, + clip_stat_path: str, + prior_config: str = "configs/prior_1B_vit_l.yaml" + ): + logging.info(f"Loading prior: {ckpt_path}") + + config = OmegaConf.load(prior_config) + clip_mean, clip_std = torch.load( + os.path.join(self._root_dir, clip_stat_path), map_location="cpu" + ) + + prior = self._PRIOR_CLASS.load_from_checkpoint( + config, + self._tokenizer, + clip_mean, + clip_std, + os.path.join(self._root_dir, ckpt_path), + strict=True, + ) + prior.cuda() + prior.eval() + logging.info("done.") + + self._prior = prior + + def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"): + logging.info(f"Loading decoder: {ckpt_path}") + + config = OmegaConf.load(decoder_config) + decoder = self._DECODER_CLASS.load_from_checkpoint( + config, + self._tokenizer, + os.path.join(self._root_dir, ckpt_path), + strict=True, + ) + decoder.cuda() + decoder.eval() + logging.info("done.") + + self._decoder = decoder + + def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"): + logging.info(f"Loading SR(64->256): {ckpt_path}") + + config = OmegaConf.load(sr_config) + sr = self._SR256_CLASS.load_from_checkpoint( + config, os.path.join(self._root_dir, ckpt_path), strict=True + ) + sr.cuda() + 
sr.eval() + logging.info("done.") + + self._sr_64_256 = sr \ No newline at end of file diff --git a/repositories/ldm/modules/midas/__init__.py b/repositories/ldm/modules/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/midas/api.py b/repositories/ldm/modules/midas/api.py new file mode 100644 index 000000000..b58ebbffd --- /dev/null +++ b/repositories/ldm/modules/midas/api.py @@ -0,0 +1,170 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from ldm.modules.midas.midas.dpt_depth import DPTDepthModel +from ldm.modules.midas.midas.midas_net import MidasNet +from ldm.modules.midas.midas.midas_net_custom import MidasNet_small +from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet + + +ISL_PATHS = { + "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt", + "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt", + "midas_v21": "", + "midas_v21_small": "", +} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 
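+        # the small variant pairs an efficientnet_lite3 backbone with a reduced 256x256 base resolution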
+ resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + # x in 0..1 as produced by calling self.transform on a 0..1 float64 numpy array + # NOTE: we expect that the correct transform has been called during dataloading. + with torch.no_grad(): + prediction = self.model(x) + prediction = torch.nn.functional.interpolate( + prediction.unsqueeze(1), + size=x.shape[2:], + mode="bicubic", + align_corners=False, + ) + assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3]) + return prediction + diff --git a/repositories/ldm/modules/midas/midas/__init__.py b/repositories/ldm/modules/midas/midas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/repositories/ldm/modules/midas/midas/base_model.py b/repositories/ldm/modules/midas/midas/base_model.py new file mode 100644 index 000000000..5cf430239 --- /dev/null +++ b/repositories/ldm/modules/midas/midas/base_model.py @@ -0,0 +1,16 @@ +import torch + + +class BaseModel(torch.nn.Module): + def load(self, path): + """Load model from file. 
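+
+        Checkpoints that store a full training state (i.e. contain an "optimizer" key)
+        are unwrapped to their "model" state dict before loading.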
+ + Args: + path (str): file path + """ + parameters = torch.load(path, map_location=torch.device('cpu')) + + if "optimizer" in parameters: + parameters = parameters["model"] + + self.load_state_dict(parameters) diff --git a/repositories/ldm/modules/midas/midas/blocks.py b/repositories/ldm/modules/midas/midas/blocks.py new file mode 100644 index 000000000..2145d18fa --- /dev/null +++ b/repositories/ldm/modules/midas/midas/blocks.py @@ -0,0 +1,342 @@ +import torch +import torch.nn as nn + +from .vit import ( + _make_pretrained_vitb_rn50_384, + _make_pretrained_vitl16_384, + _make_pretrained_vitb16_384, + forward_vit, +) + +def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): + if backbone == "vitl16_384": + pretrained = _make_pretrained_vitl16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [256, 512, 1024, 1024], features, groups=groups, expand=expand + ) # ViT-L/16 - 85.0% Top1 (backbone) + elif backbone == "vitb_rn50_384": + pretrained = _make_pretrained_vitb_rn50_384( + use_pretrained, + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) + scratch = _make_scratch( + [256, 512, 768, 768], features, groups=groups, expand=expand + ) # ViT-H/16 - 85.0% Top1 (backbone) + elif backbone == "vitb16_384": + pretrained = _make_pretrained_vitb16_384( + use_pretrained, hooks=hooks, use_readout=use_readout + ) + scratch = _make_scratch( + [96, 192, 384, 768], features, groups=groups, expand=expand + ) # ViT-B/16 - 84.6% Top1 (backbone) + elif backbone == "resnext101_wsl": + pretrained = _make_pretrained_resnext101_wsl(use_pretrained) + scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 + elif backbone == "efficientnet_lite3": + pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) + scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 + else: + print(f"Backbone '{backbone}' not implemented") + assert False + + return pretrained, scratch + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + out_shape4 = out_shape + if expand==True: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): + efficientnet = torch.hub.load( + "rwightman/gen-efficientnet-pytorch", + "tf_efficientnet_lite3", + pretrained=use_pretrained, + exportable=exportable + ) + return _make_efficientnet_backbone(efficientnet) + + +def _make_efficientnet_backbone(effnet): + pretrained = nn.Module() + + pretrained.layer1 = nn.Sequential( + effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] + ) + pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) + 
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) + pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) + + return pretrained + + +def _make_resnet_backbone(resnet): + pretrained = nn.Module() + pretrained.layer1 = nn.Sequential( + resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 + ) + + pretrained.layer2 = resnet.layer2 + pretrained.layer3 = resnet.layer3 + pretrained.layer4 = resnet.layer4 + + return pretrained + + +def _make_pretrained_resnext101_wsl(use_pretrained): + resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") + return _make_resnet_backbone(resnet) + + + +class Interpolate(nn.Module): + """Interpolation module. + """ + + def __init__(self, scale_factor, mode, align_corners=False): + """Init. + + Args: + scale_factor (float): scaling + mode (str): interpolation mode + """ + super(Interpolate, self).__init__() + + self.interp = nn.functional.interpolate + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: interpolated data + """ + + x = self.interp( + x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners + ) + + return x + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True + ) + + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + output += self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=True + ) + + return output + + + + +class ResidualConvUnit_custom(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. 
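+
+        Applies activation -> conv (-> optional batch norm) twice and adds the input back
+        through a quantization-friendly skip connection.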
+ + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + # return out + x + + +class FeatureFusionBlock_custom(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock_custom, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, *xs): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + # output += res + + output = self.resConfUnit2(output) + + output = nn.functional.interpolate( + output, scale_factor=2, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output + diff --git a/repositories/ldm/modules/midas/midas/dpt_depth.py b/repositories/ldm/modules/midas/midas/dpt_depth.py new file mode 100644 index 000000000..4e9aab5d2 --- /dev/null +++ b/repositories/ldm/modules/midas/midas/dpt_depth.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .base_model import BaseModel +from .blocks import ( + FeatureFusionBlock, + FeatureFusionBlock_custom, + Interpolate, + _make_encoder, + forward_vit, +) + + +def _make_fusion_block(features, use_bn): + return FeatureFusionBlock_custom( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + ) + + +class DPT(BaseModel): + def __init__( + self, + head, + features=256, + backbone="vitb_rn50_384", + readout="project", + channels_last=False, + use_bn=False, + ): + + super(DPT, self).__init__() + + self.channels_last = channels_last + + hooks = { + "vitb_rn50_384": [0, 1, 8, 11], + "vitb16_384": [2, 5, 8, 11], + "vitl16_384": [5, 11, 17, 23], + } + + # Instantiate backbone and reassemble blocks + self.pretrained, self.scratch = _make_encoder( + backbone, + features, + False, # Set to true of you want to train from scratch, uses ImageNet weights + groups=1, + expand=False, + exportable=False, + hooks=hooks[backbone], + use_readout=readout, + ) + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + self.scratch.output_conv = head + + + def forward(self, x): + if self.channels_last == True: + x.contiguous(memory_format=torch.channels_last) + + layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = 
self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return out + + +class DPTDepthModel(DPT): + def __init__(self, path=None, non_negative=True, **kwargs): + features = kwargs["features"] if "features" in kwargs else 256 + + head = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + super().__init__(head, **kwargs) + + if path is not None: + self.load(path) + + def forward(self, x): + return super().forward(x).squeeze(dim=1) + diff --git a/repositories/ldm/modules/midas/midas/midas_net.py b/repositories/ldm/modules/midas/midas/midas_net.py new file mode 100644 index 000000000..8a9549778 --- /dev/null +++ b/repositories/ldm/modules/midas/midas/midas_net.py @@ -0,0 +1,76 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, Interpolate, _make_encoder + + +class MidasNet(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=256, non_negative=True): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet, self).__init__() + + use_pretrained = False if path is None else True + + self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) + + self.scratch.refinenet4 = FeatureFusionBlock(features) + self.scratch.refinenet3 = FeatureFusionBlock(features) + self.scratch.refinenet2 = FeatureFusionBlock(features) + self.scratch.refinenet1 = FeatureFusionBlock(features) + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + ) + + if path: + self.load(path) + + def forward(self, x): + """Forward pass. 
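+
+        The input is encoded by the four backbone stages, each stage is projected by its
+        scratch layer, and the features are fused top-down by the refinenet blocks before
+        the output head predicts the depth map.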
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) diff --git a/repositories/ldm/modules/midas/midas/midas_net_custom.py b/repositories/ldm/modules/midas/midas/midas_net_custom.py new file mode 100644 index 000000000..50e4acb5e --- /dev/null +++ b/repositories/ldm/modules/midas/midas/midas_net_custom.py @@ -0,0 +1,128 @@ +"""MidashNet: Network for monocular depth estimation trained by mixing several datasets. +This file contains code that is adapted from +https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py +""" +import torch +import torch.nn as nn + +from .base_model import BaseModel +from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder + + +class MidasNet_small(BaseModel): + """Network for monocular depth estimation. + """ + + def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, + blocks={'expand': True}): + """Init. + + Args: + path (str, optional): Path to saved model. Defaults to None. + features (int, optional): Number of features. Defaults to 256. + backbone (str, optional): Backbone network for encoder. 
Defaults to resnet50 + """ + print("Loading weights: ", path) + + super(MidasNet_small, self).__init__() + + use_pretrained = False if path else True + + self.channels_last = channels_last + self.blocks = blocks + self.backbone = backbone + + self.groups = 1 + + features1=features + features2=features + features3=features + features4=features + self.expand = False + if "expand" in self.blocks and self.blocks['expand'] == True: + self.expand = True + features1=features + features2=features*2 + features3=features*4 + features4=features*8 + + self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) + + self.scratch.activation = nn.ReLU(False) + + self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) + self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) + + + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), + Interpolate(scale_factor=2, mode="bilinear"), + nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), + self.scratch.activation, + nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True) if non_negative else nn.Identity(), + nn.Identity(), + ) + + if path: + self.load(path) + + + def forward(self, x): + """Forward pass. 
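+
+        Same encode / project / fuse structure as MidasNet.forward, using the custom
+        fusion blocks configured in __init__.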
+ + Args: + x (tensor): input data (image) + + Returns: + tensor: depth + """ + if self.channels_last==True: + print("self.channels_last = ", self.channels_last) + x.contiguous(memory_format=torch.channels_last) + + + layer_1 = self.pretrained.layer1(x) + layer_2 = self.pretrained.layer2(layer_1) + layer_3 = self.pretrained.layer3(layer_2) + layer_4 = self.pretrained.layer4(layer_3) + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + + path_4 = self.scratch.refinenet4(layer_4_rn) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv(path_1) + + return torch.squeeze(out, dim=1) + + + +def fuse_model(m): + prev_previous_type = nn.Identity() + prev_previous_name = '' + previous_type = nn.Identity() + previous_name = '' + for name, module in m.named_modules(): + if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: + # print("FUSED ", prev_previous_name, previous_name, name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) + elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: + # print("FUSED ", prev_previous_name, previous_name) + torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) + # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: + # print("FUSED ", previous_name, name) + # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) + + prev_previous_type = previous_type + prev_previous_name = previous_name + previous_type = type(module) + previous_name = name \ No newline at end of file diff --git a/repositories/ldm/modules/midas/midas/transforms.py b/repositories/ldm/modules/midas/midas/transforms.py new file mode 100644 index 000000000..350cbc116 --- /dev/null +++ b/repositories/ldm/modules/midas/midas/transforms.py @@ -0,0 +1,234 @@ +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
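+
+        The output size is computed by get_size() and constrained to be a multiple of
+        'ensure_multiple_of'.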
+ + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if 
self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "disparity" in sample: + disparity = sample["disparity"].astype(np.float32) + sample["disparity"] = np.ascontiguousarray(disparity) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + return sample diff --git a/repositories/ldm/modules/midas/midas/vit.py b/repositories/ldm/modules/midas/midas/vit.py new file mode 100644 index 000000000..ea46b1be8 --- /dev/null +++ b/repositories/ldm/modules/midas/midas/vit.py @@ -0,0 +1,491 @@ +import torch +import torch.nn as nn +import timm +import types +import math +import torch.nn.functional as F + + +class Slice(nn.Module): + def __init__(self, start_index=1): + super(Slice, self).__init__() + self.start_index = start_index + + def forward(self, x): + return x[:, self.start_index :] + + +class AddReadout(nn.Module): + def __init__(self, start_index=1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x): + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +class ProjectReadout(nn.Module): + def __init__(self, in_features, start_index=1): + super(ProjectReadout, self).__init__() + self.start_index = start_index + + self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU()) + + def forward(self, x): + readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :]) + features = torch.cat((x[:, self.start_index :], readout), -1) + + return self.project(features) + + +class Transpose(nn.Module): + def __init__(self, dim0, dim1): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + x = x.transpose(self.dim0, self.dim1) + return x + + +def forward_vit(pretrained, x): + b, c, h, w = x.shape + + glob = pretrained.model.forward_flex(x) + + layer_1 = pretrained.activations["1"] + layer_2 = pretrained.activations["2"] + layer_3 = pretrained.activations["3"] + layer_4 = pretrained.activations["4"] + + layer_1 = pretrained.act_postprocess1[0:2](layer_1) + layer_2 = pretrained.act_postprocess2[0:2](layer_2) + layer_3 = pretrained.act_postprocess3[0:2](layer_3) + layer_4 = pretrained.act_postprocess4[0:2](layer_4) + + unflatten = nn.Sequential( + nn.Unflatten( + 2, + torch.Size( + [ + h // pretrained.model.patch_size[1], + w // 
pretrained.model.patch_size[0], + ] + ), + ) + ) + + if layer_1.ndim == 3: + layer_1 = unflatten(layer_1) + if layer_2.ndim == 3: + layer_2 = unflatten(layer_2) + if layer_3.ndim == 3: + layer_3 = unflatten(layer_3) + if layer_4.ndim == 3: + layer_4 = unflatten(layer_4) + + layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1) + layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2) + layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3) + layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4) + + return layer_1, layer_2, layer_3, layer_4 + + +def _resize_pos_embed(self, posemb, gs_h, gs_w): + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear") + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +def forward_flex(self, x): + b, c, h, w = x.shape + + pos_embed = self._resize_pos_embed( + self.pos_embed, h // self.patch_size[1], w // self.patch_size[0] + ) + + B = x.shape[0] + + if hasattr(self.patch_embed, "backbone"): + x = self.patch_embed.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + + if getattr(self, "dist_token", None) is not None: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + dist_token = self.dist_token.expand(B, -1, -1) + x = torch.cat((cls_tokens, dist_token, x), dim=1) + else: + cls_tokens = self.cls_token.expand( + B, -1, -1 + ) # stole cls_tokens impl from Phil Wang, thanks + x = torch.cat((cls_tokens, x), dim=1) + + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + + return x + + +activations = {} + + +def get_activation(name): + def hook(model, input, output): + activations[name] = output + + return hook + + +def get_readout_oper(vit_features, features, use_readout, start_index=1): + if use_readout == "ignore": + readout_oper = [Slice(start_index)] * len(features) + elif use_readout == "add": + readout_oper = [AddReadout(start_index)] * len(features) + elif use_readout == "project": + readout_oper = [ + ProjectReadout(vit_features, start_index) for out_feat in features + ] + else: + assert ( + False + ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'" + + return readout_oper + + +def _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + size=[384, 384], + hooks=[2, 5, 8, 11], + vit_features=768, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + pretrained.model = model + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + # 32, 48, 136, 384 + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + 
Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_large_patch16_384", pretrained=pretrained) + + hooks = [5, 11, 17, 23] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[256, 512, 1024, 1024], + hooks=hooks, + vit_features=1024, + use_readout=use_readout, + ) + + +def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout + ) + + +def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None): + model = timm.create_model( + "vit_deit_base_distilled_patch16_384", pretrained=pretrained + ) + + hooks = [2, 5, 8, 11] if hooks == None else hooks + return _make_vit_b16_backbone( + model, + features=[96, 192, 384, 768], + hooks=hooks, + use_readout=use_readout, + start_index=2, + ) + + +def _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=[0, 1, 8, 11], + vit_features=768, + use_vit_only=False, + use_readout="ignore", + start_index=1, +): + pretrained = nn.Module() + + 
pretrained.model = model + + if use_vit_only == True: + pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) + pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) + else: + pretrained.model.patch_embed.backbone.stages[0].register_forward_hook( + get_activation("1") + ) + pretrained.model.patch_embed.backbone.stages[1].register_forward_hook( + get_activation("2") + ) + + pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) + pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4")) + + pretrained.activations = activations + + readout_oper = get_readout_oper(vit_features, features, use_readout, start_index) + + if use_vit_only == True: + pretrained.act_postprocess1 = nn.Sequential( + readout_oper[0], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[0], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[0], + out_channels=features[0], + kernel_size=4, + stride=4, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + + pretrained.act_postprocess2 = nn.Sequential( + readout_oper[1], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[1], + kernel_size=1, + stride=1, + padding=0, + ), + nn.ConvTranspose2d( + in_channels=features[1], + out_channels=features[1], + kernel_size=2, + stride=2, + padding=0, + bias=True, + dilation=1, + groups=1, + ), + ) + else: + pretrained.act_postprocess1 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + pretrained.act_postprocess2 = nn.Sequential( + nn.Identity(), nn.Identity(), nn.Identity() + ) + + pretrained.act_postprocess3 = nn.Sequential( + readout_oper[2], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[2], + kernel_size=1, + stride=1, + padding=0, + ), + ) + + pretrained.act_postprocess4 = nn.Sequential( + readout_oper[3], + Transpose(1, 2), + nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])), + nn.Conv2d( + in_channels=vit_features, + out_channels=features[3], + kernel_size=1, + stride=1, + padding=0, + ), + nn.Conv2d( + in_channels=features[3], + out_channels=features[3], + kernel_size=3, + stride=2, + padding=1, + ), + ) + + pretrained.model.start_index = start_index + pretrained.model.patch_size = [16, 16] + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model) + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. 
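+    # _resize_pos_embed bilinearly interpolates the grid portion of the position embedding
+    # so that inputs whose resolution differs from the pretraining size can still be processed.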
+ pretrained.model._resize_pos_embed = types.MethodType( + _resize_pos_embed, pretrained.model + ) + + return pretrained + + +def _make_pretrained_vitb_rn50_384( + pretrained, use_readout="ignore", hooks=None, use_vit_only=False +): + model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained) + + hooks = [0, 1, 8, 11] if hooks == None else hooks + return _make_vit_b_rn50_backbone( + model, + features=[256, 512, 768, 768], + size=[384, 384], + hooks=hooks, + use_vit_only=use_vit_only, + use_readout=use_readout, + ) diff --git a/repositories/ldm/modules/midas/utils.py b/repositories/ldm/modules/midas/utils.py new file mode 100644 index 000000000..9a9d3b5b6 --- /dev/null +++ b/repositories/ldm/modules/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. 
+ + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/repositories/ldm/util.py b/repositories/ldm/util.py new file mode 100644 index 000000000..9ede259d5 --- /dev/null +++ b/repositories/ldm/util.py @@ -0,0 +1,207 @@ +import importlib + +import torch +from torch import optim +import numpy as np + +from inspect import isfunction +from PIL import Image, ImageDraw, ImageFont + + +def autocast(f): + def do_autocast(*args, **kwargs): + with torch.cuda.amp.autocast(enabled=True, + dtype=torch.get_autocast_gpu_dtype(), + cache_enabled=torch.is_autocast_cache_enabled()): + return f(*args, **kwargs) + + return do_autocast + + +def log_txt_as_img(wh, xc, size=10): + # wh a tuple of (width, height) + # xc a list of captions to plot + b = len(xc) + txts = list() + for bi in range(b): + txt = Image.new("RGB", wh, color="white") + draw = ImageDraw.Draw(txt) + font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) + nc = int(40 * (wh[0] / 256)) + lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) + + try: + draw.text((0, 0), lines, fill="black", font=font) + except UnicodeEncodeError: + print("Cant encode string for logging. Skipping.") + + txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 + txts.append(txt) + txts = np.stack(txts) + txts = torch.tensor(txts) + return txts + + +def ismap(x): + if not isinstance(x, torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] > 3) + + +def isimage(x): + if not isinstance(x,torch.Tensor): + return False + return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def mean_flat(tensor): + """ + https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + return total_params + + +def instantiate_from_config(config): + if not "target" in config: + if config == '__is_first_stage__': + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +class AdamWwithEMAandWings(optim.Optimizer): + # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 + def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using + weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code + ema_power=1., param_names=()): + """AdamW that saves EMA versions of the parameters.""" + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + if not 0.0 <= ema_decay <= 1.0: + raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, + ema_power=ema_power, param_names=param_names) + super().__init__(params, defaults) + + def __setstate__(self, state): + super().__setstate__(state) + for group in self.param_groups: + group.setdefault('amsgrad', False) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + ema_params_with_grad = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + ema_decay = group['ema_decay'] + ema_power = group['ema_power'] + + for p in group['params']: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError('AdamW does not support sparse gradients') + grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of parameter values + state['param_exp_avg'] = p.detach().float().clone() + + exp_avgs.append(state['exp_avg']) + exp_avg_sqs.append(state['exp_avg_sq']) + ema_params_with_grad.append(state['param_exp_avg']) + + if amsgrad: + max_exp_avg_sqs.append(state['max_exp_avg_sq']) + + # update the steps for each param group update + state['step'] += 1 + # record the step after step update + state_steps.append(state['step']) + + optim._functional.adamw(params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group['lr'], + weight_decay=group['weight_decay'], + eps=group['eps'], + maximize=False) + + cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power) + for param, ema_param in zip(params_with_grad, ema_params_with_grad): + ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay) + + return loss \ No newline at end of file diff --git a/repositories/taming/README.md b/repositories/taming/README.md new file mode 100644 index 000000000..d295fbf75 --- /dev/null +++ b/repositories/taming/README.md @@ -0,0 +1,410 @@ +# Taming Transformers for High-Resolution Image Synthesis +##### CVPR 2021 (Oral) +![teaser](assets/mountain.jpeg) + +[**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)
+[Patrick Esser](https://github.com/pesser)\*, +[Robin Rombach](https://github.com/rromb)\*, +[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
+\* equal contribution + +**tl;dr** We combine the efficiency of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer. + +![teaser](assets/teaser.png) +[arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/) + + +### News +#### 2022 +- More pretrained VQGANs (e.g. an f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion). +- Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis). +#### 2021 +- Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data). +- Included a bugfix for the quantizer. For backward compatibility it is + disabled by default (which corresponds to always training with `beta=1.0`). + Use `legacy=False` in the quantizer config to enable it. + Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)! +- Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog. +- Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN. +- Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf). +- Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`. +- Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization. + See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb). +- We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources). +- We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k). +- The streamlit demo now supports image completions. +- We now include a couple of examples from the D-RIN dataset so you can run the + [D-RIN demo](#d-rin) without preparing the dataset first. +- You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb). + +## Requirements +A suitable [conda](https://conda.io/) environment named `taming` can be created +and activated with: + +``` +conda env create -f environment.yaml +conda activate taming +``` +## Overview of pretrained models +The following table provides an overview of all models that are currently available. +FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity). +For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model.
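If you want to recompute comparable FID numbers for your own folders of samples, a minimal sketch with the torch-fidelity Python API could look like this (both directory paths are placeholders, not paths shipped with this repository):

```
import torch_fidelity

# Placeholder paths: a folder of generated samples and a folder of
# reference images; adjust both to your own data.
metrics = torch_fidelity.calculate_metrics(
    input1="outputs/my_samples",
    input2="data/reference_images",
    cuda=True,
    fid=True,
)
print(metrics["frechet_inception_distance"])
```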
+See the corresponding [colab +notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) +for a comparison and discussion of reconstruction capabilities. + +| Dataset | FID vs train | FID vs val | Link | Samples (256x256) | Comments +| ------------- | ------------- | ------------- |------------- | ------------- |------------- | +| FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) | [ffhq_samples](https://k00.fr/j626x093) | +| CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) | +| ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images) +| COCO-Stuff (f=16) | -- | 20.4 | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images) +| ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters | +| | | | || | +| FacesHQ (f=16) | -- | -- | [faceshq_transformer](https://k00.fr/qqfl2do8) +| S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/) +| D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5) +| | | | | || | +| VQGAN ImageNet (f=16), 1024 | 10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs. +| VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs. +| VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion). +| VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion) +| VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) | --- | Reconstruction-FIDs. +| | | | | || | +| DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs. + + +## Running pretrained models + +The commands below will start a streamlit demo which supports sampling at +different resolutions and image completions. To run a non-interactive version +of the sampling process, replace `streamlit run scripts/sample_conditional.py --` +by `python scripts/make_samples.py --outdir ` and +keep the remaining command line arguments. + +To sample from unconditional or class-conditional models, +run `python scripts/sample_fast.py -r `. +We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models, +respectively. + +### S-FLCKR +![teaser](assets/sunset_and_ocean.jpg) + +You can also [run this model in a Colab +notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb), +which includes all necessary steps to start sampling. 
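The model-specific walkthroughs in this and the following sections all come down to invocations of `scripts/sample_fast.py` or `scripts/sample_conditional.py`. If you want to script several such runs, for example a small top-k sweep, a sketch along these lines should work; it only uses flags documented in this README, the log directory is the class-conditional ImageNet transformer introduced below, and the top-k values are arbitrary examples:

```
import subprocess

# Sweep a few (arbitrary) top-k values for the class-conditional
# ImageNet transformer; every flag used here appears in this README.
logdir = "logs/2021-04-03T19-39-50_cin_transformer/"
for top_k in (100, 250, 600):
    subprocess.run(
        ["python", "scripts/sample_fast.py",
         "-r", logdir,
         "-n", "50", "-k", str(top_k), "-t", "1.0", "-p", "0.92",
         "--batch_size", "25"],
        check=True,
    )
```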
+ +Download the +[2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/) +folder and place it into `logs`. Then, run +``` +streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/ +``` + +### ImageNet +![teaser](assets/imagenet.png) + +Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv) +folder and place it into logs. Sampling from the class-conditional ImageNet +model does not require any data preparation. To produce 50 samples for each of +the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus +sampling and temperature t=1.0, run + +``` +python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 +``` + +To restrict the model to certain classes, provide them via the `--classes` argument, separated by +commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run + +``` +python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901 +``` +We recommended to experiment with the autoregressive decoding parameters (top-k, top-p and temperature) for best results. + +### FFHQ/CelebA-HQ + +Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and +[2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf) +folders and place them into logs. +Again, sampling from these unconditional models does not require any data preparation. +To produce 50000 samples, with k=250 for top-k sampling, +p=1.0 for nucleus sampling and temperature t=1.0, run + +``` +python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/ +``` +for FFHQ and + +``` +python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/ +``` +to sample from the CelebA-HQ model. +For both models it can be advantageous to vary the top-k/top-p parameters for sampling. + +### FacesHQ +![teaser](assets/faceshq.jpg) + +Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and +place it into `logs`. Follow the data preparation steps for +[CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run +``` +streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/ +``` + +### D-RIN +![teaser](assets/drin.jpg) + +Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and +place it into `logs`. To run the demo on a couple of example depth maps +included in the repository, run + +``` +streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}" +``` + +To run the demo on the complete validation set, first follow the data preparation steps for +[ImageNet](#imagenet) and then run +``` +streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ +``` + +### COCO +Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and +place it into `logs`. 
To run the demo on a couple of example segmentation maps +included in the repository, run + +``` +streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}" +``` + +### ADE20k +Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and +place it into `logs`. To run the demo on a couple of example segmentation maps +included in the repository, run + +``` +streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}" +``` + +## Scene Image Synthesis +![teaser](assets/scene_images_samples.svg) +Scene image generation based on bounding box conditionals as done in our CVPR2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see talk on [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). Supporting the datasets COCO and Open Images. + +### Training +Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images. +Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models. +Download the full COCO/OI datasets and adapt `data_path` in the same files, unless working with the 100 files provided for training and validation suits your needs already. + +Code can be run with +`python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,` +or +`python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,` + +### Sampling +Train a model as described above or download a pre-trained model: + - [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing) available that trained 100 epochs. On 256x256 pixels, FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-prepared for the task of generating high-resolution images, e.g. 512x512. + - [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO) allows for sampling on smaller GPUs (4 GB is enough for sampling 256x256 px images). Model was trained for 60 epochs with 10% soft loss, 90% hard loss. On 256x256 pixels, FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11. + - [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/) + - [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`) + +When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see ->Training). + +Scene image generation can be run with +`python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512` + + +## Training on custom data + +Training on your own dataset can be beneficial to get better tokens and hence better images for your domain. 
+Those are the steps to follow to make this work: +1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .` +1. put your .jpg files in a folder `your_folder` +2. create 2 text files a `xx_train.txt` and `xx_test.txt` that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`) +3. adapt `configs/custom_vqgan.yaml` to point to these 2 files +4. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to + train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU. + +## Data Preparation + +### ImageNet +The code will try to download (through [Academic +Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it +is used. However, since ImageNet is quite large, this requires a lot of disk +space and time. If you already have ImageNet on your disk, you can speed things +up by putting the data into +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to +`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one +of `train`/`validation`. It should have the following structure: + +``` +${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/ +├── n01440764 +│ ├── n01440764_10026.JPEG +│ ├── n01440764_10027.JPEG +│ ├── ... +├── n01443537 +│ ├── n01443537_10007.JPEG +│ ├── n01443537_10014.JPEG +│ ├── ... +├── ... +``` + +If you haven't extracted the data, you can also place +`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` / +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be +extracted into above structure without downloading it again. Note that this +will only happen if neither a folder +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file +`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them +if you want to force running the dataset preparation again. + +You will then need to prepare the depth data using +[MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink +`data/imagenet_depth` pointing to a folder with two subfolders `train` and +`val`, each mirroring the structure of the corresponding ImageNet folder +described above and containing a `png` file for each of ImageNet's `JPEG` +files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA +images. We provide the script `scripts/extract_depth.py` to generate this data. +**Please note** that this script uses [MiDaS via PyTorch +Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data, +the hub provided the [MiDaS +v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but now it +provides a v2.1 version. We haven't tested our models with depth maps obtained +via v2.1 and if you want to make sure that things work as expected, you must +adjust the script to make sure it explicitly uses +[v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)! + +### CelebA-HQ +Create a symlink `data/celebahq` pointing to a folder containing the `.npy` +files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN +repository](https://github.com/tkarras/progressive_growing_of_gans)). + +### FFHQ +Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained +from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset). 
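One detail of the ImageNet depth preparation above that is easy to miss: the PNGs store the raw `float32` MiDaS depth byte-wise in the four RGBA channels, so the values round-trip losslessly. The repository provides `scripts/extract_depth.py` for the actual pipeline; the helpers below are only a hypothetical illustration of that encoding (they assume native byte order and are not part of the repository):

```
import numpy as np
from PIL import Image

def save_depth_as_rgba(depth, path):
    # Reinterpret the float32 depth map byte-wise as an H x W x 4 uint8
    # array and store it as an RGBA PNG (lossless).
    depth = np.ascontiguousarray(depth.astype(np.float32))
    rgba = depth.view(np.uint8).reshape(*depth.shape, 4)
    Image.fromarray(rgba, mode="RGBA").save(path)

def load_depth_from_rgba(path):
    # Inverse: read the RGBA bytes back and reinterpret them as float32.
    rgba = np.array(Image.open(path).convert("RGBA"), dtype=np.uint8)
    return rgba.view(np.float32).reshape(rgba.shape[0], rgba.shape[1])
```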
+ +### S-FLCKR +Unfortunately, we are not allowed to distribute the images we collected for the +S-FLCKR dataset and can therefore only give a description how it was produced. +There are many resources on [collecting images from the +web](https://github.com/adrianmrit/flickrdatasets) to get started. +We collected sufficiently large images from [flickr](https://www.flickr.com) +(see `data/flickr_tags.txt` for a full list of tags used to find images) +and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network) +(see `data/subreddits.txt` for all subreddits that were used). +Overall, we collected 107625 images, and split them randomly into 96861 +training images and 10764 validation images. We then obtained segmentation +masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915) +trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch +reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an +example script for this process in `scripts/extract_segmentation.py`. + +### COCO +Create a symlink `data/coco` containing the images from the 2017 split in +`train2017` and `val2017`, and their annotations in `annotations`. Files can be +obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use +the [Stuff+thing PNG-style annotations on COCO 2017 +trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip) +annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which +should be placed under `data/cocostuffthings`. + +### ADE20k +Create a symlink `data/ade20k_root` containing the contents of +[ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip) +from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/). + +## Training models + +### FacesHQ + +Train a VQGAN with +``` +python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0, +``` + +Then, adjust the checkpoint path of the config key +`model.params.first_stage_config.params.ckpt_path` in +`configs/faceshq_transformer.yaml` (or download +[2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which +corresponds to the preconfigured checkpoint path), then run +``` +python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0, +``` + +### D-RIN + +Train a VQGAN on ImageNet with +``` +python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0, +``` + +or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac) +and place under `logs`. If you trained your own, adjust the path in the config +key `model.params.first_stage_config.params.ckpt_path` of +`configs/drin_transformer.yaml`. + +Train a VQGAN on Depth Maps of ImageNet with +``` +python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0, +``` + +or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i) +and place under `logs`. If you trained your own, adjust the path in the config +key `model.params.cond_stage_config.params.ckpt_path` of +`configs/drin_transformer.yaml`. 
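Each of the "adjust the checkpoint path of the config key ..." steps in this section is a one-line edit of a nested YAML key. A minimal sketch with PyYAML is shown below; the checkpoint filename is an assumed example of a `logs/<run>/checkpoints/last.ckpt` layout rather than something this README specifies, and note that rewriting the file this way does not preserve YAML comments:

```
import yaml

cfg_path = "configs/drin_transformer.yaml"
# Assumed example path; use the checkpoint of the VQGAN you downloaded or trained.
ckpt_path = "logs/2020-09-23T17-56-33_imagenet_vqgan/checkpoints/last.ckpt"

with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

# Same idea for model.params.cond_stage_config.params.ckpt_path.
cfg["model"]["params"]["first_stage_config"]["params"]["ckpt_path"] = ckpt_path

with open(cfg_path, "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```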
+ +To train the transformer, run +``` +python main.py --base configs/drin_transformer.yaml -t True --gpus 0, +``` + +## More Resources +### Comparing Different First Stage Models +The reconstruction and compression capabilities of different fist stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb). +In particular, the notebook compares two VQGANs with a downsampling factor of f=16 for each and codebook dimensionality of 1024 and 16384, +a VQGAN with f=8 and 8192 codebook entries and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which has f=8 and 8192 +codebook entries). +![firststages1](assets/first_stage_squirrels.png) +![firststages2](assets/first_stage_mushrooms.png) + +### Other +- A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg). +- A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about). +- A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY) +by [ayulockin](https://github.com/ayulockin). +- A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg). +- Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab. + +### Text-to-Image Optimization via CLIP +VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation +from scratch and image-to-image translation. We recommend the following notebooks/videos/resources: + + - [Advadnouns](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and corresponding LatentVision notebooks: https://www.patreon.com/patronizeme + - The [notebook]( https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings). + - A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available) + +![txt2img](assets/birddrawnbyachild.png) + +Text prompt: *'A bird drawn by a child'* + +## Shout-outs +Thanks to everyone who makes their code and models available. 
In particular, + +- The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion) +- The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT) +- The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity) + +## BibTeX + +``` +@misc{esser2020taming, + title={Taming Transformers for High-Resolution Image Synthesis}, + author={Patrick Esser and Robin Rombach and Björn Ommer}, + year={2020}, + eprint={2012.09841}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` diff --git a/repositories/taming/data/ade20k.py b/repositories/taming/data/ade20k.py new file mode 100644 index 000000000..366dae972 --- /dev/null +++ b/repositories/taming/data/ade20k.py @@ -0,0 +1,124 @@ +import os +import numpy as np +import cv2 +import albumentations +from PIL import Image +from torch.utils.data import Dataset + +from taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/ade20k_examples.txt", + data_root="data/ade20k_images", + segmentation_root="data/ade20k_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=151, shift_segmentation=False) + + +# With semantic map and scene label +class ADE20kBase(Dataset): + def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): + self.split = self.get_split() + self.n_labels = 151 # unknown + 150 + self.data_csv = {"train": "data/ade20k_train.txt", + "validation": "data/ade20k_test.txt"}[self.split] + self.data_root = "data/ade20k_root" + with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: + self.scene_categories = f.read().splitlines() + self.scene_categories = dict(line.split() for line in self.scene_categories) + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, "images", l) + for l in self.image_paths], + "relative_segmentation_path_": [l.replace(".jpg", ".png") + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.data_root, "annotations", + l.replace(".jpg", ".png")) + for l in self.image_paths], + "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] + for l in self.image_paths], + } + + size = None if size is not None and size<=0 else size + self.size = size + if crop_size is None: + self.crop_size = size if size is not None else None + else: + self.crop_size = crop_size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + + if crop_size is not None: + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + 
self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + segmentation = np.array(segmentation).astype(np.uint8) + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, mask=segmentation) + else: + processed = {"image": image, "mask": segmentation} + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class ADE20kTrain(ADE20kBase): + # default to random_crop=True + def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): + super().__init__(config=config, size=size, random_crop=random_crop, + interpolation=interpolation, crop_size=crop_size) + + def get_split(self): + return "train" + + +class ADE20kValidation(ADE20kBase): + def get_split(self): + return "validation" + + +if __name__ == "__main__": + dset = ADE20kValidation() + ex = dset[0] + for k in ["image", "scene_category", "segmentation"]: + print(type(ex[k])) + try: + print(ex[k].shape) + except: + print(ex[k]) diff --git a/repositories/taming/data/annotated_objects_coco.py b/repositories/taming/data/annotated_objects_coco.py new file mode 100644 index 000000000..af000ecd9 --- /dev/null +++ b/repositories/taming/data/annotated_objects_coco.py @@ -0,0 +1,139 @@ +import json +from itertools import chain +from pathlib import Path +from typing import Iterable, Dict, List, Callable, Any +from collections import defaultdict + +from tqdm import tqdm + +from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset +from taming.data.helper_types import Annotation, ImageDescription, Category + +COCO_PATH_STRUCTURE = { + 'train': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_train2017.json', + 'stuff_annotations': 'annotations/stuff_train2017.json', + 'files': 'train2017' + }, + 'validation': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_val2017.json', + 'stuff_annotations': 'annotations/stuff_val2017.json', + 'files': 'val2017' + } +} + + +def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: + return { + str(img['id']): ImageDescription( + id=img['id'], + license=img.get('license'), + file_name=img['file_name'], + coco_url=img['coco_url'], + original_size=(img['width'], img['height']), + date_captured=img.get('date_captured'), + flickr_url=img.get('flickr_url') + ) + for img in description_json + } + + +def load_categories(category_json: Iterable) -> Dict[str, Category]: + return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) + for cat in category_json if cat['name'] != 'other'} + + +def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], + category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: + annotations = 
defaultdict(list) + total = sum(len(a) for a in annotations_json) + for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): + image_id = str(ann['image_id']) + if image_id not in image_descriptions: + raise ValueError(f'image_id [{image_id}] has no image description.') + category_id = ann['category_id'] + try: + category_no = category_no_for_id(str(category_id)) + except KeyError: + continue + + width, height = image_descriptions[image_id].original_size + bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) + + annotations[image_id].append( + Annotation( + id=ann['id'], + area=bbox[2]*bbox[3], # use bbox area + is_group_of=ann['iscrowd'], + image_id=ann['image_id'], + bbox=bbox, + category_id=str(category_id), + category_no=category_no + ) + ) + return dict(annotations) + + +class AnnotatedObjectsCoco(AnnotatedObjectsDataset): + def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): + """ + @param data_path: is the path to the following folder structure: + coco/ + ├── annotations + │ ├── instances_train2017.json + │ ├── instances_val2017.json + │ ├── stuff_train2017.json + │ └── stuff_val2017.json + ├── train2017 + │ ├── 000000000009.jpg + │ ├── 000000000025.jpg + │ └── ... + ├── val2017 + │ ├── 000000000139.jpg + │ ├── 000000000285.jpg + │ └── ... + @param: split: one of 'train' or 'validation' + @param: desired image size (give square images) + """ + super().__init__(**kwargs) + self.use_things = use_things + self.use_stuff = use_stuff + + with open(self.paths['instances_annotations']) as f: + inst_data_json = json.load(f) + with open(self.paths['stuff_annotations']) as f: + stuff_data_json = json.load(f) + + category_jsons = [] + annotation_jsons = [] + if self.use_things: + category_jsons.append(inst_data_json['categories']) + annotation_jsons.append(inst_data_json['annotations']) + if self.use_stuff: + category_jsons.append(stuff_data_json['categories']) + annotation_jsons.append(stuff_data_json['annotations']) + + self.categories = load_categories(chain(*category_jsons)) + self.filter_categories() + self.setup_category_id_and_number() + + self.image_descriptions = load_image_descriptions(inst_data_json['images']) + annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) + self.annotations = self.filter_object_number(annotations, self.min_object_area, + self.min_objects_per_image, self.max_objects_per_image) + self.image_ids = list(self.annotations.keys()) + self.clean_up_annotations_and_image_descriptions() + + def get_path_structure(self) -> Dict[str, str]: + if self.split not in COCO_PATH_STRUCTURE: + raise ValueError(f'Split [{self.split} does not exist for COCO data.]') + return COCO_PATH_STRUCTURE[self.split] + + def get_image_path(self, image_id: str) -> Path: + return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + # noinspection PyProtectedMember + return self.image_descriptions[image_id]._asdict() diff --git a/repositories/taming/data/annotated_objects_dataset.py b/repositories/taming/data/annotated_objects_dataset.py new file mode 100644 index 000000000..53cc346a1 --- /dev/null +++ b/repositories/taming/data/annotated_objects_dataset.py @@ -0,0 +1,218 @@ +from pathlib import Path +from typing import Optional, List, Callable, Dict, Any, Union +import warnings + +import PIL.Image as pil_image +from torch import 
Tensor +from torch.utils.data import Dataset +from torchvision import transforms + +from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder +from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from taming.data.conditional_builder.utils import load_object_from_string +from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType +from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \ + Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor + + +class AnnotatedObjectsDataset(Dataset): + def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int, + min_object_area: float, min_objects_per_image: int, max_objects_per_image: int, + crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool, + encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "", + no_object_classes: Optional[int] = None): + self.data_path = data_path + self.split = split + self.keys = keys + self.target_image_size = target_image_size + self.min_object_area = min_object_area + self.min_objects_per_image = min_objects_per_image + self.max_objects_per_image = max_objects_per_image + self.crop_method = crop_method + self.random_flip = random_flip + self.no_tokens = no_tokens + self.use_group_parameter = use_group_parameter + self.encode_crop = encode_crop + + self.annotations = None + self.image_descriptions = None + self.categories = None + self.category_ids = None + self.category_number = None + self.image_ids = None + self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip) + self.paths = self.build_paths(self.data_path) + self._conditional_builders = None + self.category_allow_list = None + if category_allow_list_target: + allow_list = load_object_from_string(category_allow_list_target) + self.category_allow_list = {name for name, _ in allow_list} + self.category_mapping = {} + if category_mapping_target: + self.category_mapping = load_object_from_string(category_mapping_target) + self.no_object_classes = no_object_classes + + def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]: + top_level = Path(top_level) + sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()} + for path in sub_paths.values(): + if not path.exists(): + raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.') + return sub_paths + + @staticmethod + def load_image_from_disk(path: Path) -> Image: + return pil_image.open(path).convert('RGB') + + @staticmethod + def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool): + transform_functions = [] + if crop_method == 'none': + transform_functions.append(transforms.Resize((target_image_size, target_image_size))) + elif crop_method == 'center': + transform_functions.extend([ + transforms.Resize(target_image_size), + CenterCropReturnCoordinates(target_image_size) + ]) + elif crop_method == 'random-1d': + transform_functions.extend([ + transforms.Resize(target_image_size), + RandomCrop1dReturnCoordinates(target_image_size) + ]) + elif crop_method == 'random-2d': + transform_functions.extend([ + Random2dCropReturnCoordinates(target_image_size), + transforms.Resize(target_image_size) + ]) + elif crop_method is None: + 
return None + else: + raise ValueError(f'Received invalid crop method [{crop_method}].') + if random_flip: + transform_functions.append(RandomHorizontalFlipReturn()) + transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.)) + return transform_functions + + def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor): + crop_bbox = None + flipped = None + for t in self.transform_functions: + if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)): + crop_bbox, x = t(x) + elif isinstance(t, RandomHorizontalFlipReturn): + flipped, x = t(x) + else: + x = t(x) + return crop_bbox, flipped, x + + @property + def no_classes(self) -> int: + return self.no_object_classes if self.no_object_classes else len(self.categories) + + @property + def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder: + # cannot set this up in init because no_classes is only known after loading data in init of superclass + if self._conditional_builders is None: + self._conditional_builders = { + 'objects_center_points': ObjectsCenterPointsConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.encode_crop, + self.use_group_parameter, + getattr(self, 'use_additional_parameters', False) + ), + 'objects_bbox': ObjectsBoundingBoxConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.encode_crop, + self.use_group_parameter, + getattr(self, 'use_additional_parameters', False) + ) + } + return self._conditional_builders + + def filter_categories(self) -> None: + if self.category_allow_list: + self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list} + if self.category_mapping: + self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping} + + def setup_category_id_and_number(self) -> None: + self.category_ids = list(self.categories.keys()) + self.category_ids.sort() + if '/m/01s55n' in self.category_ids: + self.category_ids.remove('/m/01s55n') + self.category_ids.append('/m/01s55n') + self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)} + if self.category_allow_list is not None and self.category_mapping is None \ + and len(self.category_ids) != len(self.category_allow_list): + warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. 
' + 'Make sure all names in category_allow_list exist.') + + def clean_up_annotations_and_image_descriptions(self) -> None: + image_id_set = set(self.image_ids) + self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set} + self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set} + + @staticmethod + def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float, + min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]: + filtered = {} + for image_id, annotations in all_annotations.items(): + annotations_with_min_area = [a for a in annotations if a.area > min_object_area] + if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image: + filtered[image_id] = annotations_with_min_area + return filtered + + def __len__(self): + return len(self.image_ids) + + def __getitem__(self, n: int) -> Dict[str, Any]: + image_id = self.get_image_id(n) + sample = self.get_image_description(image_id) + sample['annotations'] = self.get_annotation(image_id) + + if 'image' in self.keys: + sample['image_path'] = str(self.get_image_path(image_id)) + sample['image'] = self.load_image_from_disk(sample['image_path']) + sample['image'] = convert_pil_to_tensor(sample['image']) + sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image']) + sample['image'] = sample['image'].permute(1, 2, 0) + + for conditional, builder in self.conditional_builders.items(): + if conditional in self.keys: + sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped']) + + if self.keys: + # only return specified keys + sample = {key: sample[key] for key in self.keys} + return sample + + def get_image_id(self, no: int) -> str: + return self.image_ids[no] + + def get_annotation(self, image_id: str) -> str: + return self.annotations[image_id] + + def get_textual_label_for_category_id(self, category_id: str) -> str: + return self.categories[category_id].name + + def get_textual_label_for_category_no(self, category_no: int) -> str: + return self.categories[self.get_category_id(category_no)].name + + def get_category_number(self, category_id: str) -> int: + return self.category_number[category_id] + + def get_category_id(self, category_no: int) -> str: + return self.category_ids[category_no] + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + raise NotImplementedError() + + def get_path_structure(self): + raise NotImplementedError + + def get_image_path(self, image_id: str) -> Path: + raise NotImplementedError diff --git a/repositories/taming/data/annotated_objects_open_images.py b/repositories/taming/data/annotated_objects_open_images.py new file mode 100644 index 000000000..aede6803d --- /dev/null +++ b/repositories/taming/data/annotated_objects_open_images.py @@ -0,0 +1,137 @@ +from collections import defaultdict +from csv import DictReader, reader as TupleReader +from pathlib import Path +from typing import Dict, List, Any +import warnings + +from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset +from taming.data.helper_types import Annotation, Category +from tqdm import tqdm + +OPEN_IMAGES_STRUCTURE = { + 'train': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'oidv6-train-annotations-bbox.csv', + 'file_list': 'train-images-boxable.csv', + 'files': 'train' + }, + 'validation': { + 'top_level': '', + 'class_descriptions': 
'class-descriptions-boxable.csv', + 'annotations': 'validation-annotations-bbox.csv', + 'file_list': 'validation-images.csv', + 'files': 'validation' + }, + 'test': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'test-annotations-bbox.csv', + 'file_list': 'test-images.csv', + 'files': 'test' + } +} + + +def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], + category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: + annotations: Dict[str, List[Annotation]] = defaultdict(list) + with open(descriptor_path) as file: + reader = DictReader(file) + for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): + width = float(row['XMax']) - float(row['XMin']) + height = float(row['YMax']) - float(row['YMin']) + area = width * height + category_id = row['LabelName'] + if category_id in category_mapping: + category_id = category_mapping[category_id] + if area >= min_object_area and category_id in category_no_for_id: + annotations[row['ImageID']].append( + Annotation( + id=i, + image_id=row['ImageID'], + source=row['Source'], + category_id=category_id, + category_no=category_no_for_id[category_id], + confidence=float(row['Confidence']), + bbox=(float(row['XMin']), float(row['YMin']), width, height), + area=area, + is_occluded=bool(int(row['IsOccluded'])), + is_truncated=bool(int(row['IsTruncated'])), + is_group_of=bool(int(row['IsGroupOf'])), + is_depiction=bool(int(row['IsDepiction'])), + is_inside=bool(int(row['IsInside'])) + ) + ) + if 'train' in str(descriptor_path) and i < 14000000: + warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') + return dict(annotations) + + +def load_image_ids(csv_path: Path) -> List[str]: + with open(csv_path) as file: + reader = DictReader(file) + return [row['image_name'] for row in reader] + + +def load_categories(csv_path: Path) -> Dict[str, Category]: + with open(csv_path) as file: + reader = TupleReader(file) + return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} + + +class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): + def __init__(self, use_additional_parameters: bool, **kwargs): + """ + @param data_path: is the path to the following folder structure: + open_images/ + │ oidv6-train-annotations-bbox.csv + ├── class-descriptions-boxable.csv + ├── oidv6-train-annotations-bbox.csv + ├── test + │ ├── 000026e7ee790996.jpg + │ ├── 000062a39995e348.jpg + │ └── ... + ├── test-annotations-bbox.csv + ├── test-images.csv + ├── train + │ ├── 000002b66c9c498e.jpg + │ ├── 000002b97e5471a0.jpg + │ └── ... + ├── train-images-boxable.csv + ├── validation + │ ├── 0001eeaf4aed83f9.jpg + │ ├── 0004886b7d043cfd.jpg + │ └── ... 
+ ├── validation-annotations-bbox.csv + └── validation-images.csv + @param: split: one of 'train', 'validation' or 'test' + @param: desired image size (returns square images) + """ + + super().__init__(**kwargs) + self.use_additional_parameters = use_additional_parameters + + self.categories = load_categories(self.paths['class_descriptions']) + self.filter_categories() + self.setup_category_id_and_number() + + self.image_descriptions = {} + annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, + self.category_number) + self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, + self.max_objects_per_image) + self.image_ids = list(self.annotations.keys()) + self.clean_up_annotations_and_image_descriptions() + + def get_path_structure(self) -> Dict[str, str]: + if self.split not in OPEN_IMAGES_STRUCTURE: + raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') + return OPEN_IMAGES_STRUCTURE[self.split] + + def get_image_path(self, image_id: str) -> Path: + return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + image_path = self.get_image_path(image_id) + return {'file_path': str(image_path), 'file_name': image_path.name} diff --git a/repositories/taming/data/base.py b/repositories/taming/data/base.py new file mode 100644 index 000000000..e21667df4 --- /dev/null +++ b/repositories/taming/data/base.py @@ -0,0 +1,70 @@ +import bisect +import numpy as np +import albumentations +from PIL import Image +from torch.utils.data import Dataset, ConcatDataset + + +class ConcatDatasetWithIndex(ConcatDataset): + """Modified from original pytorch code to return dataset idx""" + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx], dataset_idx + + +class ImagePaths(Dataset): + def __init__(self, paths, size=None, random_crop=False, labels=None): + self.size = size + self.random_crop = random_crop + + self.labels = dict() if labels is None else labels + self.labels["file_path_"] = paths + self._length = len(paths) + + if self.size is not None and self.size > 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) + if not self.random_crop: + self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) + self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return self._length + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def __getitem__(self, i): + example = dict() + example["image"] = self.preprocess_image(self.labels["file_path_"][i]) + for k in self.labels: + example[k] = self.labels[k][i] + return example + + +class NumpyPaths(ImagePaths): + def preprocess_image(self, image_path): + image = 
np.load(image_path).squeeze(0) # 3 x 1024 x 1024 + image = np.transpose(image, (1,2,0)) + image = Image.fromarray(image, mode="RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image diff --git a/repositories/taming/data/coco.py b/repositories/taming/data/coco.py new file mode 100644 index 000000000..2b2f78384 --- /dev/null +++ b/repositories/taming/data/coco.py @@ -0,0 +1,176 @@ +import os +import json +import albumentations +import numpy as np +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset + +from taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/coco_examples.txt", + data_root="data/coco_images", + segmentation_root="data/coco_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=183, shift_segmentation=True) + + +class CocoBase(Dataset): + """needed for (image, caption, segmentation) pairs""" + def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, + crop_size=None, force_no_crop=False, given_files=None): + self.split = self.get_split() + self.size = size + if crop_size is None: + self.crop_size = size + else: + self.crop_size = crop_size + + self.onehot = onehot_segmentation # return segmentation as rgb or one hot + self.stuffthing = use_stuffthing # include thing in segmentation + if self.onehot and not self.stuffthing: + raise NotImplemented("One hot mode is only supported for the " + "stuffthings version because labels are stored " + "a bit different.") + + data_json = datajson + with open(data_json) as json_file: + self.json_data = json.load(json_file) + self.img_id_to_captions = dict() + self.img_id_to_filepath = dict() + self.img_id_to_segmentation_filepath = dict() + + assert data_json.split("/")[-1] in ["captions_train2017.json", + "captions_val2017.json"] + if self.stuffthing: + self.segmentation_prefix = ( + "data/cocostuffthings/val2017" if + data_json.endswith("captions_val2017.json") else + "data/cocostuffthings/train2017") + else: + self.segmentation_prefix = ( + "data/coco/annotations/stuff_val2017_pixelmaps" if + data_json.endswith("captions_val2017.json") else + "data/coco/annotations/stuff_train2017_pixelmaps") + + imagedirs = self.json_data["images"] + self.labels = {"image_ids": list()} + for imgdir in tqdm(imagedirs, desc="ImgToPath"): + self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) + self.img_id_to_captions[imgdir["id"]] = list() + pngfilename = imgdir["file_name"].replace("jpg", "png") + self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( + self.segmentation_prefix, pngfilename) + if given_files is not None: + if pngfilename in given_files: + self.labels["image_ids"].append(imgdir["id"]) + else: + self.labels["image_ids"].append(imgdir["id"]) + + capdirs = self.json_data["annotations"] + for capdir in tqdm(capdirs, desc="ImgToCaptions"): + # there are in average 5 captions per image + self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) + + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) + if self.split=="validation": + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = 
albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = albumentations.Compose( + [self.rescaler, self.cropper], + additional_targets={"segmentation": "image"}) + if force_no_crop: + self.rescaler = albumentations.Resize(height=self.size, width=self.size) + self.preprocessor = albumentations.Compose( + [self.rescaler], + additional_targets={"segmentation": "image"}) + + def __len__(self): + return len(self.labels["image_ids"]) + + def preprocess_image(self, image_path, segmentation_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + + segmentation = Image.open(segmentation_path) + if not self.onehot and not segmentation.mode == "RGB": + segmentation = segmentation.convert("RGB") + segmentation = np.array(segmentation).astype(np.uint8) + if self.onehot: + assert self.stuffthing + # stored in caffe format: unlabeled==255. stuff and thing from + # 0-181. to be compatible with the labels in + # https://github.com/nightrome/cocostuff/blob/master/labels.txt + # we shift stuffthing one to the right and put unlabeled in zero + # as long as segmentation is uint8 shifting to right handles the + # latter too + assert segmentation.dtype == np.uint8 + segmentation = segmentation + 1 + + processed = self.preprocessor(image=image, segmentation=segmentation) + image, segmentation = processed["image"], processed["segmentation"] + image = (image / 127.5 - 1.0).astype(np.float32) + + if self.onehot: + assert segmentation.dtype == np.uint8 + # make it one hot + n_labels = 183 + flatseg = np.ravel(segmentation) + onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) + onehot[np.arange(flatseg.size), flatseg] = True + onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) + segmentation = onehot + else: + segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) + return image, segmentation + + def __getitem__(self, i): + img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] + seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] + image, segmentation = self.preprocess_image(img_path, seg_path) + captions = self.img_id_to_captions[self.labels["image_ids"][i]] + # randomly draw one of all available captions per image + caption = captions[np.random.randint(0, len(captions))] + example = {"image": image, + "caption": [str(caption[0])], + "segmentation": segmentation, + "img_path": img_path, + "seg_path": seg_path, + "filename_": img_path.split(os.sep)[-1] + } + return example + + +class CocoImagesAndCaptionsTrain(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): + super().__init__(size=size, + dataroot="data/coco/train2017", + datajson="data/coco/annotations/captions_train2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) + + def get_split(self): + return "train" + + +class CocoImagesAndCaptionsValidation(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None): + super().__init__(size=size, + dataroot="data/coco/val2017", + datajson="data/coco/annotations/captions_val2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, 
force_no_crop=force_no_crop, + given_files=given_files) + + def get_split(self): + return "validation" diff --git a/repositories/taming/data/conditional_builder/objects_bbox.py b/repositories/taming/data/conditional_builder/objects_bbox.py new file mode 100644 index 000000000..15881e76b --- /dev/null +++ b/repositories/taming/data/conditional_builder/objects_bbox.py @@ -0,0 +1,60 @@ +from itertools import cycle +from typing import List, Tuple, Callable, Optional + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + +from taming.data.helper_types import BoundingBox, Annotation +from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ + pad_list, get_plot_font_size, absolute_bbox + + +class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): + @property + def object_descriptor_length(self) -> int: + return 3 + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_triples = [ + (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) + for ann in annotations + ] + empty_triple = (self.none, self.none, self.none) + object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) + return object_triples + + def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + object_triples = grouper(conditional_list, 3) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) + for object_triple in object_triples if object_triple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + font = ImageFont.truetype( + "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", + size=get_plot_font_size(font_size, figure_size) + ) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): + annotation = self.representation_to_annotation(representation) + class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) + bbox = absolute_bbox(bbox, width, height) + draw.rectangle(bbox, outline=color, width=line_width) + draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. 
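The bounding-box builder above serializes every annotation as a triple of tokens: the object representation token followed by a token pair encoding the top-left and bottom-right corners of its box. Triples are padded with none-triples up to no_max_objects and, when encode_crop is set, two crop tokens are appended, so inverse_build can decode the sequence back into boxes up to the resolution of the sqrt(no_tokens) x sqrt(no_tokens) coordinate grid. A minimal usage sketch, assuming the taming package added by this diff is importable; the parameter values and the annotation are made up for illustration:

from taming.data.helper_types import Annotation
from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder

builder = ObjectsBoundingBoxConditionalBuilder(
    no_object_classes=183, no_max_objects=4, no_tokens=1024,
    encode_crop=True, use_group_parameter=False, use_additional_parameters=False)

# one object, given as a relative (x0, y0, w, h) box around the image centre
annotation = Annotation(area=0.25, image_id='example', bbox=(0.25, 0.25, 0.5, 0.5),
                        category_no=3, category_id='/m/01g317')

conditional = builder.build([annotation])      # LongTensor of length 4 * 3 + 2 = 14
objects, crop = builder.inverse_build(conditional)
# objects ~= [(3, (0.258, 0.258, 0.484, 0.484))]   -- the box, quantized to the 32 x 32 grid
# crop == (0.0, 0.0, 1.0, 1.0)                     -- FULL_CROP, since no crop was passed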
diff --git a/repositories/taming/data/conditional_builder/objects_center_points.py b/repositories/taming/data/conditional_builder/objects_center_points.py new file mode 100644 index 000000000..9a480329c --- /dev/null +++ b/repositories/taming/data/conditional_builder/objects_center_points.py @@ -0,0 +1,168 @@ +import math +import random +import warnings +from itertools import cycle +from typing import List, Optional, Tuple, Callable + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ + additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ + absolute_bbox, rescale_annotations +from taming.data.helper_types import BoundingBox, Annotation +from taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + + +class ObjectsCenterPointsConditionalBuilder: + def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, + use_group_parameter: bool, use_additional_parameters: bool): + self.no_object_classes = no_object_classes + self.no_max_objects = no_max_objects + self.no_tokens = no_tokens + self.encode_crop = encode_crop + self.no_sections = int(math.sqrt(self.no_tokens)) + self.use_group_parameter = use_group_parameter + self.use_additional_parameters = use_additional_parameters + + @property + def none(self) -> int: + return self.no_tokens - 1 + + @property + def object_descriptor_length(self) -> int: + return 2 + + @property + def embedding_dim(self) -> int: + extra_length = 2 if self.encode_crop else 0 + return self.no_max_objects * self.object_descriptor_length + extra_length + + def tokenize_coordinates(self, x: float, y: float) -> int: + """ + Express 2d coordinates with one number. + Example: assume self.no_tokens = 16, then no_sections = 4: + 0 0 0 0 + 0 0 # 0 + 0 0 0 0 + 0 0 0 x + Then the # position corresponds to token 6, the x position to token 15. 
+ @param x: float in [0, 1] + @param y: float in [0, 1] + @return: discrete tokenized coordinate + """ + x_discrete = int(round(x * (self.no_sections - 1))) + y_discrete = int(round(y * (self.no_sections - 1))) + return y_discrete * self.no_sections + x_discrete + + def coordinates_from_token(self, token: int) -> (float, float): + x = token % self.no_sections + y = token // self.no_sections + return x / (self.no_sections - 1), y / (self.no_sections - 1) + + def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: + x0, y0 = self.coordinates_from_token(token1) + x1, y1 = self.coordinates_from_token(token2) + return x0, y0, x1 - x0, y1 - y0 + + def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: + return self.tokenize_coordinates(bbox[0], bbox[1]), \ + self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) + + def inverse_build(self, conditional: LongTensor) \ + -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + table_of_content = grouper(conditional_list, self.object_descriptor_length) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_tuple[0], self.coordinates_from_token(object_tuple[1])) + for object_tuple in table_of_content if object_tuple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + circle_size = get_circle_size(figure_size) + font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', + size=get_plot_font_size(font_size, figure_size)) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): + x_abs, y_abs = x * width, y * height + ann = self.representation_to_annotation(representation) + label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) + ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] + draw.ellipse(ellipse_bbox, fill=color, width=0) + draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. 
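    # A worked example of the coordinate grid used by tokenize_coordinates and
    # coordinates_from_token above (values are illustrative): with no_tokens = 16
    # the grid has no_sections = int(sqrt(16)) = 4, so
    #   tokenize_coordinates(2/3, 1/3) -> 4 * round(1/3 * 3) + round(2/3 * 3) = 4 + 2 = 6
    #   tokenize_coordinates(1.0, 1.0) -> 4 * 3 + 3 = 15
    # which are exactly the '#' and 'x' cells of the docstring sketch, and
    #   coordinates_from_token(6) == (2/3, 1/3)
    # inverts the mapping up to the grid resolution.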
+ + def object_representation(self, annotation: Annotation) -> int: + modifier = 0 + if self.use_group_parameter: + modifier |= 1 * (annotation.is_group_of is True) + if self.use_additional_parameters: + modifier |= 2 * (annotation.is_occluded is True) + modifier |= 4 * (annotation.is_depiction is True) + modifier |= 8 * (annotation.is_inside is True) + return annotation.category_no + self.no_object_classes * modifier + + def representation_to_annotation(self, representation: int) -> Annotation: + category_no = representation % self.no_object_classes + modifier = representation // self.no_object_classes + # noinspection PyTypeChecker + return Annotation( + area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, + category_no=category_no, + is_group_of=bool((modifier & 1) * self.use_group_parameter), + is_occluded=bool((modifier & 2) * self.use_additional_parameters), + is_depiction=bool((modifier & 4) * self.use_additional_parameters), + is_inside=bool((modifier & 8) * self.use_additional_parameters) + ) + + def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: + return list(self.token_pair_from_bbox(crop_coordinates)) + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_tuples = [ + (self.object_representation(a), + self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) + for a in annotations + ] + empty_tuple = (self.none, self.none) + object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) + return object_tuples + + def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ + -> LongTensor: + if len(annotations) == 0: + warnings.warn('Did not receive any annotations.') + if len(annotations) > self.no_max_objects: + warnings.warn('Received more annotations than allowed.') + annotations = annotations[:self.no_max_objects] + + if not crop_coordinates: + crop_coordinates = FULL_CROP + + random.shuffle(annotations) + annotations = filter_annotations(annotations, crop_coordinates) + if self.encode_crop: + annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) + if horizontal_flip: + crop_coordinates = horizontally_flip_bbox(crop_coordinates) + extra = self._crop_encoder(crop_coordinates) + else: + annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) + extra = [] + + object_tuples = self._make_object_descriptors(annotations) + flattened = [token for tuple_ in object_tuples for token in tuple_] + extra + assert len(flattened) == self.embedding_dim + assert all(0 <= value < self.no_tokens for value in flattened) + return LongTensor(flattened) diff --git a/repositories/taming/data/conditional_builder/utils.py b/repositories/taming/data/conditional_builder/utils.py new file mode 100644 index 000000000..d0ee175f2 --- /dev/null +++ b/repositories/taming/data/conditional_builder/utils.py @@ -0,0 +1,105 @@ +import importlib +from typing import List, Any, Tuple, Optional + +from taming.data.helper_types import BoundingBox, Annotation + +# source: seaborn, color palette tab10 +COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), + (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] +BLACK = (0, 0, 0) +GRAY_75 = (63, 63, 63) +GRAY_50 = (127, 127, 127) +GRAY_25 = (191, 191, 191) +WHITE = (255, 255, 255) +FULL_CROP = (0., 0., 1., 1.) 
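# A worked example for the box helpers defined below (numbers are illustrative):
# all boxes are relative (x0, y0, w, h) tuples over the unit square, so with
#   rectangle1 = (0.0, 0.0, 0.5, 0.5) and rectangle2 = (0.25, 0.25, 0.5, 0.5)
# the boxes overlap by 0.25 in each direction and
#   intersection_area(rectangle1, rectangle2) == 0.25 * 0.25 == 0.0625
# while horizontally_flip_bbox((0.1, 0.2, 0.3, 0.4)) == (0.6, 0.2, 0.3, 0.4),
# i.e. the box is mirrored about the vertical centre line of the image.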
+ + +def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: + """ + Give intersection area of two rectangles. + @param rectangle1: (x0, y0, w, h) of first rectangle + @param rectangle2: (x0, y0, w, h) of second rectangle + """ + rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] + rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] + x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) + y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) + return x_overlap * y_overlap + + +def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: + return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] + + +def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: + bbox = relative_bbox + bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height + return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + + +def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: + return list_ + [pad_element for _ in range(pad_to_length - len(list_))] + + +def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ + List[Annotation]: + def clamp(x: float): + return max(min(x, 1.), 0.) + + def rescale_bbox(bbox: BoundingBox) -> BoundingBox: + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + if flip: + x0 = 1 - (x0 + w) + return x0, y0, w, h + + return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] + + +def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: + return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] + + +def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: + sl = slice(1) if short else slice(None) + string = '' + if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): + return string + if annotation.is_group_of: + string += 'group'[sl] + ',' + if annotation.is_occluded: + string += 'occluded'[sl] + ',' + if annotation.is_depiction: + string += 'depiction'[sl] + ',' + if annotation.is_inside: + string += 'inside'[sl] + return '(' + string.strip(",") + ')' + + +def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: + if font_size is None: + font_size = 10 + if max(figure_size) >= 256: + font_size = 12 + if max(figure_size) >= 512: + font_size = 15 + return font_size + + +def get_circle_size(figure_size: Tuple[int, int]) -> int: + circle_size = 2 + if max(figure_size) >= 256: + circle_size = 3 + if max(figure_size) >= 512: + circle_size = 4 + return circle_size + + +def load_object_from_string(object_string: str) -> Any: + """ + Source: https://stackoverflow.com/a/10773699 + """ + module_name, class_name = object_string.rsplit(".", 1) + return getattr(importlib.import_module(module_name), class_name) diff --git a/repositories/taming/data/custom.py b/repositories/taming/data/custom.py new file mode 100644 index 000000000..33f302a4b --- /dev/null +++ b/repositories/taming/data/custom.py @@ -0,0 +1,38 @@ +import os +import numpy as np +import albumentations +from 
torch.utils.data import Dataset + +from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex + + +class CustomBase(Dataset): + def __init__(self, *args, **kwargs): + super().__init__() + self.data = None + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + example = self.data[i] + return example + + + +class CustomTrain(CustomBase): + def __init__(self, size, training_images_list_file): + super().__init__() + with open(training_images_list_file, "r") as f: + paths = f.read().splitlines() + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + + +class CustomTest(CustomBase): + def __init__(self, size, test_images_list_file): + super().__init__() + with open(test_images_list_file, "r") as f: + paths = f.read().splitlines() + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + + diff --git a/repositories/taming/data/faceshq.py b/repositories/taming/data/faceshq.py new file mode 100644 index 000000000..6912d04b6 --- /dev/null +++ b/repositories/taming/data/faceshq.py @@ -0,0 +1,134 @@ +import os +import numpy as np +import albumentations +from torch.utils.data import Dataset + +from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex + + +class FacesBase(Dataset): + def __init__(self, *args, **kwargs): + super().__init__() + self.data = None + self.keys = None + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + example = self.data[i] + ex = {} + if self.keys is not None: + for k in self.keys: + ex[k] = example[k] + else: + ex = example + return ex + + +class CelebAHQTrain(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/celebahq" + with open("data/celebahqtrain.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = NumpyPaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class CelebAHQValidation(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/celebahq" + with open("data/celebahqvalidation.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = NumpyPaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FFHQTrain(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/ffhq" + with open("data/ffhqtrain.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FFHQValidation(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/ffhq" + with open("data/ffhqvalidation.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FacesHQTrain(Dataset): + # CelebAHQ [0] + FFHQ [1] + def __init__(self, size, keys=None, crop_size=None, coord=False): + d1 = CelebAHQTrain(size=size, keys=keys) + d2 = FFHQTrain(size=size, keys=keys) + self.data = ConcatDatasetWithIndex([d1, d2]) + self.coord = coord + if crop_size is not None: + self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + if self.coord: + self.cropper = albumentations.Compose([self.cropper], + additional_targets={"coord": "image"}) + + def __len__(self): + return 
len(self.data) + + def __getitem__(self, i): + ex, y = self.data[i] + if hasattr(self, "cropper"): + if not self.coord: + out = self.cropper(image=ex["image"]) + ex["image"] = out["image"] + else: + h,w,_ = ex["image"].shape + coord = np.arange(h*w).reshape(h,w,1)/(h*w) + out = self.cropper(image=ex["image"], coord=coord) + ex["image"] = out["image"] + ex["coord"] = out["coord"] + ex["class"] = y + return ex + + +class FacesHQValidation(Dataset): + # CelebAHQ [0] + FFHQ [1] + def __init__(self, size, keys=None, crop_size=None, coord=False): + d1 = CelebAHQValidation(size=size, keys=keys) + d2 = FFHQValidation(size=size, keys=keys) + self.data = ConcatDatasetWithIndex([d1, d2]) + self.coord = coord + if crop_size is not None: + self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + if self.coord: + self.cropper = albumentations.Compose([self.cropper], + additional_targets={"coord": "image"}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + ex, y = self.data[i] + if hasattr(self, "cropper"): + if not self.coord: + out = self.cropper(image=ex["image"]) + ex["image"] = out["image"] + else: + h,w,_ = ex["image"].shape + coord = np.arange(h*w).reshape(h,w,1)/(h*w) + out = self.cropper(image=ex["image"], coord=coord) + ex["image"] = out["image"] + ex["coord"] = out["coord"] + ex["class"] = y + return ex diff --git a/repositories/taming/data/helper_types.py b/repositories/taming/data/helper_types.py new file mode 100644 index 000000000..fb51e301d --- /dev/null +++ b/repositories/taming/data/helper_types.py @@ -0,0 +1,49 @@ +from typing import Dict, Tuple, Optional, NamedTuple, Union +from PIL.Image import Image as pil_image +from torch import Tensor + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +Image = Union[Tensor, pil_image] +BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h +CropMethodType = Literal['none', 'random', 'center', 'random-2d'] +SplitType = Literal['train', 'validation', 'test'] + + +class ImageDescription(NamedTuple): + id: int + file_name: str + original_size: Tuple[int, int] # w, h + url: Optional[str] = None + license: Optional[int] = None + coco_url: Optional[str] = None + date_captured: Optional[str] = None + flickr_url: Optional[str] = None + flickr_id: Optional[str] = None + coco_id: Optional[str] = None + + +class Category(NamedTuple): + id: str + super_category: Optional[str] + name: str + + +class Annotation(NamedTuple): + area: float + image_id: str + bbox: BoundingBox + category_no: int + category_id: str + id: Optional[int] = None + source: Optional[str] = None + confidence: Optional[float] = None + is_group_of: Optional[bool] = None + is_truncated: Optional[bool] = None + is_occluded: Optional[bool] = None + is_depiction: Optional[bool] = None + is_inside: Optional[bool] = None + segmentation: Optional[Dict] = None diff --git a/repositories/taming/data/image_transforms.py b/repositories/taming/data/image_transforms.py new file mode 100644 index 000000000..657ac3321 --- /dev/null +++ b/repositories/taming/data/image_transforms.py @@ -0,0 +1,132 @@ +import random +import warnings +from typing import Union + +import torch +from torch import Tensor +from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor +from torchvision.transforms.functional import _get_image_size as get_image_size + +from taming.data.helper_types import BoundingBox, Image + +pil_to_tensor = PILToTensor() + + +def 
convert_pil_to_tensor(image: Image) -> Tensor: + with warnings.catch_warnings(): + # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 + warnings.simplefilter("ignore") + return pil_to_tensor(image) + + +class RandomCrop1dReturnCoordinates(RandomCrop): + def forward(self, img: Image) -> (BoundingBox, Image): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + width, height = get_image_size(img) + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h + return bbox, F.crop(img, i, j, h, w) + + +class Random2dCropReturnCoordinates(torch.nn.Module): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + + def __init__(self, min_size: int): + super().__init__() + self.min_size = min_size + + def forward(self, img: Image) -> (BoundingBox, Image): + width, height = get_image_size(img) + max_size = min(width, height) + if max_size <= self.min_size: + size = max_size + else: + size = random.randint(self.min_size, max_size) + top = random.randint(0, height - size) + left = random.randint(0, width - size) + bbox = left / width, top / height, size / width, size / height + return bbox, F.crop(img, top, left, size, size) + + +class CenterCropReturnCoordinates(CenterCrop): + @staticmethod + def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: + if width > height: + w = height / width + h = 1.0 + x0 = 0.5 - w / 2 + y0 = 0. + else: + w = 1.0 + h = width / height + x0 = 0. + y0 = 0.5 - h / 2 + return x0, y0, w, h + + def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + width, height = get_image_size(img) + return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) + + +class RandomHorizontalFlipReturn(RandomHorizontalFlip): + def forward(self, img: Image) -> (bool, Image): + """ + Additionally to flipping, returns a boolean whether it was flipped or not. + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + flipped: whether the image was flipped or not + PIL Image or Tensor: Randomly flipped image. 
+ + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + if torch.rand(1) < self.p: + return True, F.hflip(img) + return False, img diff --git a/repositories/taming/data/imagenet.py b/repositories/taming/data/imagenet.py new file mode 100644 index 000000000..9a02ec44b --- /dev/null +++ b/repositories/taming/data/imagenet.py @@ -0,0 +1,558 @@ +import os, tarfile, glob, shutil +import yaml +import numpy as np +from tqdm import tqdm +from PIL import Image +import albumentations +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from taming.data.base import ImagePaths +from taming.util import download, retrieve +import taming.data.utils as bdu + + +def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"): + synsets = [] + with open(path_to_yaml) as f: + di2s = yaml.load(f) + for idx in indices: + synsets.append(str(di2s[idx])) + print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets))) + return synsets + + +def str_to_indices(string): + """Expects a string in the format '32-123, 256, 280-321'""" + assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string) + subs = string.split(",") + indices = [] + for sub in subs: + subsubs = sub.split("-") + assert len(subsubs) > 0 + if len(subsubs) == 1: + indices.append(int(subsubs[0])) + else: + rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))] + indices.extend(rang) + return sorted(indices) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in 
enumerate(unique_synsets)) + self.class_labels = [class_dict[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + self.data = ImagePaths(self.abspaths, + labels=labels, + size=retrieve(self.config, "size", default=0), + random_crop=self.random_crop) + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = 
at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +def get_preprocessor(size=None, random_crop=False, additional_targets=None, + crop_size=None): + if size is not None and size > 0: + transforms = list() + rescaler = albumentations.SmallestMaxSize(max_size = size) + transforms.append(rescaler) + if not random_crop: + cropper = albumentations.CenterCrop(height=size,width=size) + transforms.append(cropper) + else: + cropper = albumentations.RandomCrop(height=size,width=size) + transforms.append(cropper) + flipper = albumentations.HorizontalFlip() + transforms.append(flipper) + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + elif crop_size is not None and crop_size > 0: + if not random_crop: + cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + else: + cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + transforms = [cropper] + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + preprocessor = lambda **kwargs: kwargs + return preprocessor + + +def rgba_to_depth(x): + assert x.dtype == np.uint8 + assert len(x.shape) == 3 and x.shape[2] == 4 + y = x.copy() + y.dtype = np.float32 + y = y.reshape(x.shape[:2]) + return np.ascontiguousarray(y) + + +class BaseWithDepth(Dataset): + DEFAULT_DEPTH_ROOT="data/imagenet_depth" + + def __init__(self, config=None, size=None, random_crop=False, + crop_size=None, root=None): + self.config = config + self.base_dset = self.get_base_dset() + self.preprocessor = get_preprocessor( + size=size, + crop_size=crop_size, + random_crop=random_crop, + additional_targets={"depth": "image"}) + self.crop_size = crop_size + if self.crop_size is not None: + self.rescaler = albumentations.Compose( + [albumentations.SmallestMaxSize(max_size = self.crop_size)], + additional_targets={"depth": "image"}) + if root is not None: + self.DEFAULT_DEPTH_ROOT = root + + def __len__(self): + return len(self.base_dset) + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = self.base_dset[i] + e["depth"] = self.preprocess_depth(self.get_depth_path(e)) + # up if necessary + h,w,c = e["image"].shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + out = 
self.rescaler(image=e["image"], depth=e["depth"]) + e["image"] = out["image"] + e["depth"] = out["depth"] + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +class ImageNetTrainWithDepth(BaseWithDepth): + # default to random_crop=True + def __init__(self, random_crop=True, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(random_crop=random_crop, **kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetTrain() + else: + return ImageNetTrain({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid) + return fid + + +class ImageNetValidationWithDepth(BaseWithDepth): + def __init__(self, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(**kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetValidation() + else: + return ImageNetValidation({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid) + return fid + + +class RINTrainWithDepth(ImageNetTrainWithDepth): + def __init__(self, config=None, size=None, random_crop=True, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class RINValidationWithDepth(ImageNetValidationWithDepth): + def __init__(self, config=None, size=None, random_crop=False, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class DRINExamples(Dataset): + def __init__(self): + self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"}) + with open("data/drin_examples.txt", "r") as f: + relpaths = f.read().splitlines() + self.image_paths = [os.path.join("data/drin_images", + relpath) for relpath in relpaths] + self.depth_paths = [os.path.join("data/drin_depth", + relpath.replace(".JPEG", ".png")) for relpath in relpaths] + + def __len__(self): + return len(self.image_paths) + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = dict() + e["image"] = self.preprocess_image(self.image_paths[i]) + e["depth"] = self.preprocess_depth(self.depth_paths[i]) + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +def imscale(x, factor, keepshapes=False, keepmode="bicubic"): + if factor is None or factor==1: + return x + + dtype = x.dtype + assert dtype in [np.float32, np.float64] + assert x.min() >= -1 + assert x.max() <= 1 + + keepmode = {"nearest": Image.NEAREST, 
"bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC}[keepmode] + + lr = (x+1.0)*127.5 + lr = lr.clip(0,255).astype(np.uint8) + lr = Image.fromarray(lr) + + h, w, _ = x.shape + nh = h//factor + nw = w//factor + assert nh > 0 and nw > 0, (nh, nw) + + lr = lr.resize((nw,nh), Image.BICUBIC) + if keepshapes: + lr = lr.resize((w,h), keepmode) + lr = np.array(lr)/127.5-1.0 + lr = lr.astype(dtype) + + return lr + + +class ImageNetScale(Dataset): + def __init__(self, size=None, crop_size=None, random_crop=False, + up_factor=None, hr_factor=None, keep_mode="bicubic"): + self.base = self.get_base() + + self.size = size + self.crop_size = crop_size if crop_size is not None else self.size + self.random_crop = random_crop + self.up_factor = up_factor + self.hr_factor = hr_factor + self.keep_mode = keep_mode + + transforms = list() + + if self.size is not None and self.size > 0: + rescaler = albumentations.SmallestMaxSize(max_size = self.size) + self.rescaler = rescaler + transforms.append(rescaler) + + if self.crop_size is not None and self.crop_size > 0: + if len(transforms) == 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size) + + if not self.random_crop: + cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size) + else: + cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size) + transforms.append(cropper) + + if len(transforms) > 0: + if self.up_factor is not None: + additional_targets = {"lr": "image"} + else: + additional_targets = None + self.preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + # adjust resolution + image = imscale(image, self.hr_factor, keepshapes=False) + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + if self.up_factor is None: + image = self.preprocessor(image=image)["image"] + example["image"] = image + else: + lr = imscale(image, self.up_factor, keepshapes=True, + keepmode=self.keep_mode) + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + +class ImageNetScaleTrain(ImageNetScale): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + def get_base(self): + return ImageNetTrain() + +class ImageNetScaleValidation(ImageNetScale): + def get_base(self): + return ImageNetValidation() + + +from skimage.feature import canny +from skimage.color import rgb2gray + + +class ImageNetEdges(ImageNetScale): + def __init__(self, up_factor=1, **kwargs): + super().__init__(up_factor=1, **kwargs) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + + lr = canny(rgb2gray(image), sigma=2) + lr = lr.astype(np.float32) + lr = lr[:,:,None][:,:,[0,0,0]] + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + + +class ImageNetEdgesTrain(ImageNetEdges): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + 
def get_base(self): + return ImageNetTrain() + +class ImageNetEdgesValidation(ImageNetEdges): + def get_base(self): + return ImageNetValidation() diff --git a/repositories/taming/data/open_images_helper.py b/repositories/taming/data/open_images_helper.py new file mode 100644 index 000000000..8feb7c6e7 --- /dev/null +++ b/repositories/taming/data/open_images_helper.py @@ -0,0 +1,379 @@ +open_images_unify_categories_for_coco = { + '/m/03bt1vf': '/m/01g317', + '/m/04yx4': '/m/01g317', + '/m/05r655': '/m/01g317', + '/m/01bl7v': '/m/01g317', + '/m/0cnyhnx': '/m/01xq0k1', + '/m/01226z': '/m/018xm', + '/m/05ctyq': '/m/018xm', + '/m/058qzx': '/m/04ctx', + '/m/06pcq': '/m/0l515', + '/m/03m3pdh': '/m/02crq1', + '/m/046dlr': '/m/01x3z', + '/m/0h8mzrc': '/m/01x3z', +} + + +top_300_classes_plus_coco_compatibility = [ + ('Man', 1060962), + ('Clothing', 986610), + ('Tree', 748162), + ('Woman', 611896), + ('Person', 610294), + ('Human face', 442948), + ('Girl', 175399), + ('Building', 162147), + ('Car', 159135), + ('Plant', 155704), + ('Human body', 137073), + ('Flower', 133128), + ('Window', 127485), + ('Human arm', 118380), + ('House', 114365), + ('Wheel', 111684), + ('Suit', 99054), + ('Human hair', 98089), + ('Human head', 92763), + ('Chair', 88624), + ('Boy', 79849), + ('Table', 73699), + ('Jeans', 57200), + ('Tire', 55725), + ('Skyscraper', 53321), + ('Food', 52400), + ('Footwear', 50335), + ('Dress', 50236), + ('Human leg', 47124), + ('Toy', 46636), + ('Tower', 45605), + ('Boat', 43486), + ('Land vehicle', 40541), + ('Bicycle wheel', 34646), + ('Palm tree', 33729), + ('Fashion accessory', 32914), + ('Glasses', 31940), + ('Bicycle', 31409), + ('Furniture', 30656), + ('Sculpture', 29643), + ('Bottle', 27558), + ('Dog', 26980), + ('Snack', 26796), + ('Human hand', 26664), + ('Bird', 25791), + ('Book', 25415), + ('Guitar', 24386), + ('Jacket', 23998), + ('Poster', 22192), + ('Dessert', 21284), + ('Baked goods', 20657), + ('Drink', 19754), + ('Flag', 18588), + ('Houseplant', 18205), + ('Tableware', 17613), + ('Airplane', 17218), + ('Door', 17195), + ('Sports uniform', 17068), + ('Shelf', 16865), + ('Drum', 16612), + ('Vehicle', 16542), + ('Microphone', 15269), + ('Street light', 14957), + ('Cat', 14879), + ('Fruit', 13684), + ('Fast food', 13536), + ('Animal', 12932), + ('Vegetable', 12534), + ('Train', 12358), + ('Horse', 11948), + ('Flowerpot', 11728), + ('Motorcycle', 11621), + ('Fish', 11517), + ('Desk', 11405), + ('Helmet', 10996), + ('Truck', 10915), + ('Bus', 10695), + ('Hat', 10532), + ('Auto part', 10488), + ('Musical instrument', 10303), + ('Sunglasses', 10207), + ('Picture frame', 10096), + ('Sports equipment', 10015), + ('Shorts', 9999), + ('Wine glass', 9632), + ('Duck', 9242), + ('Wine', 9032), + ('Rose', 8781), + ('Tie', 8693), + ('Butterfly', 8436), + ('Beer', 7978), + ('Cabinetry', 7956), + ('Laptop', 7907), + ('Insect', 7497), + ('Goggles', 7363), + ('Shirt', 7098), + ('Dairy Product', 7021), + ('Marine invertebrates', 7014), + ('Cattle', 7006), + ('Trousers', 6903), + ('Van', 6843), + ('Billboard', 6777), + ('Balloon', 6367), + ('Human nose', 6103), + ('Tent', 6073), + ('Camera', 6014), + ('Doll', 6002), + ('Coat', 5951), + ('Mobile phone', 5758), + ('Swimwear', 5729), + ('Strawberry', 5691), + ('Stairs', 5643), + ('Goose', 5599), + ('Umbrella', 5536), + ('Cake', 5508), + ('Sun hat', 5475), + ('Bench', 5310), + ('Bookcase', 5163), + ('Bee', 5140), + ('Computer monitor', 5078), + ('Hiking equipment', 4983), + ('Office building', 4981), + ('Coffee cup', 4748), + ('Curtain', 4685), + 
('Plate', 4651), + ('Box', 4621), + ('Tomato', 4595), + ('Coffee table', 4529), + ('Office supplies', 4473), + ('Maple', 4416), + ('Muffin', 4365), + ('Cocktail', 4234), + ('Castle', 4197), + ('Couch', 4134), + ('Pumpkin', 3983), + ('Computer keyboard', 3960), + ('Human mouth', 3926), + ('Christmas tree', 3893), + ('Mushroom', 3883), + ('Swimming pool', 3809), + ('Pastry', 3799), + ('Lavender (Plant)', 3769), + ('Football helmet', 3732), + ('Bread', 3648), + ('Traffic sign', 3628), + ('Common sunflower', 3597), + ('Television', 3550), + ('Bed', 3525), + ('Cookie', 3485), + ('Fountain', 3484), + ('Paddle', 3447), + ('Bicycle helmet', 3429), + ('Porch', 3420), + ('Deer', 3387), + ('Fedora', 3339), + ('Canoe', 3338), + ('Carnivore', 3266), + ('Bowl', 3202), + ('Human eye', 3166), + ('Ball', 3118), + ('Pillow', 3077), + ('Salad', 3061), + ('Beetle', 3060), + ('Orange', 3050), + ('Drawer', 2958), + ('Platter', 2937), + ('Elephant', 2921), + ('Seafood', 2921), + ('Monkey', 2915), + ('Countertop', 2879), + ('Watercraft', 2831), + ('Helicopter', 2805), + ('Kitchen appliance', 2797), + ('Personal flotation device', 2781), + ('Swan', 2739), + ('Lamp', 2711), + ('Boot', 2695), + ('Bronze sculpture', 2693), + ('Chicken', 2677), + ('Taxi', 2643), + ('Juice', 2615), + ('Cowboy hat', 2604), + ('Apple', 2600), + ('Tin can', 2590), + ('Necklace', 2564), + ('Ice cream', 2560), + ('Human beard', 2539), + ('Coin', 2536), + ('Candle', 2515), + ('Cart', 2512), + ('High heels', 2441), + ('Weapon', 2433), + ('Handbag', 2406), + ('Penguin', 2396), + ('Rifle', 2352), + ('Violin', 2336), + ('Skull', 2304), + ('Lantern', 2285), + ('Scarf', 2269), + ('Saucer', 2225), + ('Sheep', 2215), + ('Vase', 2189), + ('Lily', 2180), + ('Mug', 2154), + ('Parrot', 2140), + ('Human ear', 2137), + ('Sandal', 2115), + ('Lizard', 2100), + ('Kitchen & dining room table', 2063), + ('Spider', 1977), + ('Coffee', 1974), + ('Goat', 1926), + ('Squirrel', 1922), + ('Cello', 1913), + ('Sushi', 1881), + ('Tortoise', 1876), + ('Pizza', 1870), + ('Studio couch', 1864), + ('Barrel', 1862), + ('Cosmetics', 1841), + ('Moths and butterflies', 1841), + ('Convenience store', 1817), + ('Watch', 1792), + ('Home appliance', 1786), + ('Harbor seal', 1780), + ('Luggage and bags', 1756), + ('Vehicle registration plate', 1754), + ('Shrimp', 1751), + ('Jellyfish', 1730), + ('French fries', 1723), + ('Egg (Food)', 1698), + ('Football', 1697), + ('Musical keyboard', 1683), + ('Falcon', 1674), + ('Candy', 1660), + ('Medical equipment', 1654), + ('Eagle', 1651), + ('Dinosaur', 1634), + ('Surfboard', 1630), + ('Tank', 1628), + ('Grape', 1624), + ('Lion', 1624), + ('Owl', 1622), + ('Ski', 1613), + ('Waste container', 1606), + ('Frog', 1591), + ('Sparrow', 1585), + ('Rabbit', 1581), + ('Pen', 1546), + ('Sea lion', 1537), + ('Spoon', 1521), + ('Sink', 1512), + ('Teddy bear', 1507), + ('Bull', 1495), + ('Sofa bed', 1490), + ('Dragonfly', 1479), + ('Brassiere', 1478), + ('Chest of drawers', 1472), + ('Aircraft', 1466), + ('Human foot', 1463), + ('Pig', 1455), + ('Fork', 1454), + ('Antelope', 1438), + ('Tripod', 1427), + ('Tool', 1424), + ('Cheese', 1422), + ('Lemon', 1397), + ('Hamburger', 1393), + ('Dolphin', 1390), + ('Mirror', 1390), + ('Marine mammal', 1387), + ('Giraffe', 1385), + ('Snake', 1368), + ('Gondola', 1364), + ('Wheelchair', 1360), + ('Piano', 1358), + ('Cupboard', 1348), + ('Banana', 1345), + ('Trumpet', 1335), + ('Lighthouse', 1333), + ('Invertebrate', 1317), + ('Carrot', 1268), + ('Sock', 1260), + ('Tiger', 1241), + ('Camel', 1224), + ('Parachute', 
1224), + ('Bathroom accessory', 1223), + ('Earrings', 1221), + ('Headphones', 1218), + ('Skirt', 1198), + ('Skateboard', 1190), + ('Sandwich', 1148), + ('Saxophone', 1141), + ('Goldfish', 1136), + ('Stool', 1104), + ('Traffic light', 1097), + ('Shellfish', 1081), + ('Backpack', 1079), + ('Sea turtle', 1078), + ('Cucumber', 1075), + ('Tea', 1051), + ('Toilet', 1047), + ('Roller skates', 1040), + ('Mule', 1039), + ('Bust', 1031), + ('Broccoli', 1030), + ('Crab', 1020), + ('Oyster', 1019), + ('Cannon', 1012), + ('Zebra', 1012), + ('French horn', 1008), + ('Grapefruit', 998), + ('Whiteboard', 997), + ('Zucchini', 997), + ('Crocodile', 992), + + ('Clock', 960), + ('Wall clock', 958), + + ('Doughnut', 869), + ('Snail', 868), + + ('Baseball glove', 859), + + ('Panda', 830), + ('Tennis racket', 830), + + ('Pear', 652), + + ('Bagel', 617), + ('Oven', 616), + ('Ladybug', 615), + ('Shark', 615), + ('Polar bear', 614), + ('Ostrich', 609), + + ('Hot dog', 473), + ('Microwave oven', 467), + ('Fire hydrant', 20), + ('Stop sign', 20), + ('Parking meter', 20), + ('Bear', 20), + ('Flying disc', 20), + ('Snowboard', 20), + ('Tennis ball', 20), + ('Kite', 20), + ('Baseball bat', 20), + ('Kitchen knife', 20), + ('Knife', 20), + ('Submarine sandwich', 20), + ('Computer mouse', 20), + ('Remote control', 20), + ('Toaster', 20), + ('Sink', 20), + ('Refrigerator', 20), + ('Alarm clock', 20), + ('Wall clock', 20), + ('Scissors', 20), + ('Hair dryer', 20), + ('Toothbrush', 20), + ('Suitcase', 20) +] diff --git a/repositories/taming/data/sflckr.py b/repositories/taming/data/sflckr.py new file mode 100644 index 000000000..91101be59 --- /dev/null +++ b/repositories/taming/data/sflckr.py @@ -0,0 +1,91 @@ +import os +import numpy as np +import cv2 +import albumentations +from PIL import Image +from torch.utils.data import Dataset + + +class SegmentationBase(Dataset): + def __init__(self, + data_csv, data_root, segmentation_root, + size=None, random_crop=False, interpolation="bicubic", + n_labels=182, shift_segmentation=False, + ): + self.n_labels = n_labels + self.shift_segmentation = shift_segmentation + self.data_csv = data_csv + self.data_root = data_root + self.segmentation_root = segmentation_root + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) + for l in self.image_paths] + } + + size = None if size is not None and size<=0 else size + self.size = size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in 
self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + assert segmentation.mode == "L", segmentation.mode + segmentation = np.array(segmentation).astype(np.uint8) + if self.shift_segmentation: + # used to support segmentations containing unlabeled==255 label + segmentation = segmentation+1 + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, + mask=segmentation + ) + else: + processed = {"image": image, + "mask": segmentation + } + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class Examples(SegmentationBase): + def __init__(self, size=None, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/sflckr_examples.txt", + data_root="data/sflckr_images", + segmentation_root="data/sflckr_segmentations", + size=size, random_crop=random_crop, interpolation=interpolation) diff --git a/repositories/taming/data/utils.py b/repositories/taming/data/utils.py new file mode 100644 index 000000000..2b3c3d53c --- /dev/null +++ b/repositories/taming/data/utils.py @@ -0,0 +1,169 @@ +import collections +import os +import tarfile +import urllib +import zipfile +from pathlib import Path + +import numpy as np +import torch +from taming.data.helper_types import Annotation +from torch._six import string_classes +from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format +from tqdm import tqdm + + +def unpack(path): + if path.endswith("tar.gz"): + with tarfile.open(path, "r:gz") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("tar"): + with tarfile.open(path, "r:") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("zip"): + with zipfile.ZipFile(path, "r") as f: + f.extractall(path=os.path.split(path)[0]) + else: + raise NotImplementedError( + "Unknown file extension: {}".format(os.path.splitext(path)[1]) + ) + + +def reporthook(bar): + """tqdm progress bar for downloads.""" + + def hook(b=1, bsize=1, tsize=None): + if tsize is not None: + bar.total = tsize + bar.update(b * bsize - bar.n) + + return hook + + +def get_root(name): + base = "data/" + root = os.path.join(base, name) + os.makedirs(root, exist_ok=True) + return root + + +def is_prepared(root): + return Path(root).joinpath(".ready").exists() + + +def mark_prepared(root): + Path(root).joinpath(".ready").touch() + + +def prompt_download(file_, source, target_dir, content_dir=None): + targetpath = os.path.join(target_dir, file_) + while not os.path.exists(targetpath): + if content_dir is not None and os.path.exists( + os.path.join(target_dir, content_dir) + ): + break + print( + "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) + ) + if content_dir is not None: + print( + "Or place its content into '{}'.".format( + os.path.join(target_dir, content_dir) + ) + ) + input("Press Enter when done...") + return targetpath + + +def download_url(file_, url, target_dir): + targetpath = os.path.join(target_dir, file_) + os.makedirs(target_dir, exist_ok=True) + with tqdm( + unit="B", unit_scale=True, 
unit_divisor=1024, miniters=1, desc=file_ + ) as bar: + urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) + return targetpath + + +def download_urls(urls, target_dir): + paths = dict() + for fname, url in urls.items(): + outpath = download_url(fname, url, target_dir) + paths[fname] = outpath + return paths + + +def quadratic_crop(x, bbox, alpha=1.0): + """bbox is xmin, ymin, xmax, ymax""" + im_h, im_w = x.shape[:2] + bbox = np.array(bbox, dtype=np.float32) + bbox = np.clip(bbox, 0, max(im_h, im_w)) + center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + l = int(alpha * max(w, h)) + l = max(l, 2) + + required_padding = -1 * min( + center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) + ) + required_padding = int(np.ceil(required_padding)) + if required_padding > 0: + padding = [ + [required_padding, required_padding], + [required_padding, required_padding], + ] + padding += [[0, 0]] * (len(x.shape) - 2) + x = np.pad(x, padding, "reflect") + center = center[0] + required_padding, center[1] + required_padding + xmin = int(center[0] - l / 2) + ymin = int(center[1] - l / 2) + return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) + + +def custom_collate(batch): + r"""source: pytorch 1.9.0, only one modification to original code """ + + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return custom_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: custom_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(custom_collate(samples) for samples in zip(*batch))) + if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added + return batch # added + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*batch) + return [custom_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/repositories/taming/lr_scheduler.py b/repositories/taming/lr_scheduler.py new file mode 100644 index 000000000..e598ed120 --- /dev/null +++ b/repositories/taming/lr_scheduler.py @@ -0,0 +1,34 @@ +import numpy as np + + +class 
LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0. + self.verbosity_interval = verbosity_interval + + def schedule(self, n): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + self.last_lr = lr + return lr + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( + 1 + np.cos(t * np.pi)) + self.last_lr = lr + return lr + + def __call__(self, n): + return self.schedule(n) + diff --git a/repositories/taming/models/cond_transformer.py b/repositories/taming/models/cond_transformer.py new file mode 100644 index 000000000..e4c63730f --- /dev/null +++ b/repositories/taming/models/cond_transformer.py @@ -0,0 +1,352 @@ +import os, math +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from main import instantiate_from_config +from taming.modules.util import SOSProvider + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class Net2NetTransformer(pl.LightningModule): + def __init__(self, + transformer_config, + first_stage_config, + cond_stage_config, + permuter_config=None, + ckpt_path=None, + ignore_keys=[], + first_stage_key="image", + cond_stage_key="depth", + downsample_cond_size=-1, + pkeep=1.0, + sos_token=0, + unconditional=False, + ): + super().__init__() + self.be_unconditional = unconditional + self.sos_token = sos_token + self.first_stage_key = first_stage_key + self.cond_stage_key = cond_stage_key + self.init_first_stage_from_ckpt(first_stage_config) + self.init_cond_stage_from_ckpt(cond_stage_config) + if permuter_config is None: + permuter_config = {"target": "taming.modules.transformer.permuter.Identity"} + self.permuter = instantiate_from_config(config=permuter_config) + self.transformer = instantiate_from_config(config=transformer_config) + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.downsample_cond_size = downsample_cond_size + self.pkeep = pkeep + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + for k in sd.keys(): + for ik in ignore_keys: + if k.startswith(ik): + self.print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def init_first_stage_from_ckpt(self, config): + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + self.first_stage_model = model + + def init_cond_stage_from_ckpt(self, config): + if config == "__is_first_stage__": + print("Using first stage also as cond stage.") + self.cond_stage_model = self.first_stage_model + elif config == "__is_unconditional__" or self.be_unconditional: + print(f"Using no cond stage. Assuming the training is intended to be unconditional. 
" + f"Prepending {self.sos_token} as a sos token.") + self.be_unconditional = True + self.cond_stage_key = self.first_stage_key + self.cond_stage_model = SOSProvider(self.sos_token) + else: + model = instantiate_from_config(config) + model = model.eval() + model.train = disabled_train + self.cond_stage_model = model + + def forward(self, x, c): + # one step to produce the logits + _, z_indices = self.encode_to_z(x) + _, c_indices = self.encode_to_c(c) + + if self.training and self.pkeep < 1.0: + mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape, + device=z_indices.device)) + mask = mask.round().to(dtype=torch.int64) + r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size) + a_indices = mask*z_indices+(1-mask)*r_indices + else: + a_indices = z_indices + + cz_indices = torch.cat((c_indices, a_indices), dim=1) + + # target includes all sequence elements (no need to handle first one + # differently because we are conditioning) + target = z_indices + # make the prediction + logits, _ = self.transformer(cz_indices[:, :-1]) + # cut off conditioning outputs - output i corresponds to p(z_i | z_{ -1: + c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size)) + quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c) + if len(indices.shape) > 2: + indices = indices.view(c.shape[0], -1) + return quant_c, indices + + @torch.no_grad() + def decode_to_img(self, index, zshape): + index = self.permuter(index, reverse=True) + bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) + quant_z = self.first_stage_model.quantize.get_codebook_entry( + index.reshape(-1), shape=bhwc) + x = self.first_stage_model.decode(quant_z) + return x + + @torch.no_grad() + def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs): + log = dict() + + N = 4 + if lr_interface: + x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8) + else: + x, c = self.get_xc(batch, N) + x = x.to(device=self.device) + c = c.to(device=self.device) + + quant_z, z_indices = self.encode_to_z(x) + quant_c, c_indices = self.encode_to_c(c) + + # create a "half"" sample + z_start_indices = z_indices[:,:z_indices.shape[1]//2] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1]-z_start_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample = self.decode_to_img(index_sample, quant_z.shape) + + # sample + z_start_indices = z_indices[:, :0] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1], + temperature=temperature if temperature is not None else 1.0, + sample=True, + top_k=top_k if top_k is not None else 100, + callback=callback if callback is not None else lambda k: None) + x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape) + + # det sample + z_start_indices = z_indices[:, :0] + index_sample = self.sample(z_start_indices, c_indices, + steps=z_indices.shape[1], + sample=False, + callback=callback if callback is not None else lambda k: None) + x_sample_det = self.decode_to_img(index_sample, quant_z.shape) + + # reconstruction + x_rec = self.decode_to_img(z_indices, quant_z.shape) + + log["inputs"] = x + log["reconstructions"] = x_rec + + if self.cond_stage_key in ["objects_bbox", "objects_center_points"]: + figure_size = (x_rec.shape[2], x_rec.shape[3]) + dataset = 
kwargs["pl_module"].trainer.datamodule.datasets["validation"] + label_for_category_no = dataset.get_textual_label_for_category_no + plotter = dataset.conditional_builders[self.cond_stage_key].plot + log["conditioning"] = torch.zeros_like(log["reconstructions"]) + for i in range(quant_c.shape[0]): + log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size) + log["conditioning_rec"] = log["conditioning"] + elif self.cond_stage_key != "image": + cond_rec = self.cond_stage_model.decode(quant_c) + if self.cond_stage_key == "segmentation": + # get image from segmentation mask + num_classes = cond_rec.shape[1] + + c = torch.argmax(c, dim=1, keepdim=True) + c = F.one_hot(c, num_classes=num_classes) + c = c.squeeze(1).permute(0, 3, 1, 2).float() + c = self.cond_stage_model.to_rgb(c) + + cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True) + cond_rec = F.one_hot(cond_rec, num_classes=num_classes) + cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float() + cond_rec = self.cond_stage_model.to_rgb(cond_rec) + log["conditioning_rec"] = cond_rec + log["conditioning"] = c + + log["samples_half"] = x_sample + log["samples_nopix"] = x_sample_nopix + log["samples_det"] = x_sample_det + return log + + def get_input(self, key, batch): + x = batch[key] + if len(x.shape) == 3: + x = x[..., None] + if len(x.shape) == 4: + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + if x.dtype == torch.double: + x = x.float() + return x + + def get_xc(self, batch, N=None): + x = self.get_input(self.first_stage_key, batch) + c = self.get_input(self.cond_stage_key, batch) + if N is not None: + x = x[:N] + c = c[:N] + return x, c + + def shared_step(self, batch, batch_idx): + x, c = self.get_xc(batch) + logits, target = self(x, c) + loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) + return loss + + def training_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + return loss + + def validation_step(self, batch, batch_idx): + loss = self.shared_step(batch, batch_idx) + self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + return loss + + def configure_optimizers(self): + """ + Following minGPT: + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. 
+ """ + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + whitelist_weight_modules = (torch.nn.Linear, ) + blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) + for mn, m in self.transformer.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + + if pn.endswith('bias'): + # all biases will not be decayed + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + # special case the position embedding parameter in the root GPT module as not decayed + no_decay.add('pos_emb') + + # validate that we considered every parameter + param_dict = {pn: p for pn, p in self.transformer.named_parameters()} + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) + assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ + % (str(param_dict.keys() - union_params), ) + + # create the pytorch optimizer object + optim_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95)) + return optimizer diff --git a/repositories/taming/models/dummy_cond_stage.py b/repositories/taming/models/dummy_cond_stage.py new file mode 100644 index 000000000..6e1993807 --- /dev/null +++ b/repositories/taming/models/dummy_cond_stage.py @@ -0,0 +1,22 @@ +from torch import Tensor + + +class DummyCondStage: + def __init__(self, conditional_key): + self.conditional_key = conditional_key + self.train = None + + def eval(self): + return self + + @staticmethod + def encode(c: Tensor): + return c, None, (None, None, c) + + @staticmethod + def decode(c: Tensor): + return c + + @staticmethod + def to_rgb(c: Tensor): + return c diff --git a/repositories/taming/models/vqgan.py b/repositories/taming/models/vqgan.py new file mode 100644 index 000000000..a6950baa5 --- /dev/null +++ b/repositories/taming/models/vqgan.py @@ -0,0 +1,404 @@ +import torch +import torch.nn.functional as F +import pytorch_lightning as pl + +from main import instantiate_from_config + +from taming.modules.diffusionmodules.model import Encoder, Decoder +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer +from taming.modules.vqvae.quantize import GumbelQuantize +from taming.modules.vqvae.quantize import EMAVectorQuantizer + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, sane_index_shape=sane_index_shape) + self.quant_conv = 
torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.image_key = image_key + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input): + quant, diff, _ = self.encode(input) + dec = self.decode(quant) + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) + return x.float() + + def training_step(self, batch, batch_idx, optimizer_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight 
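# --- Editor's sketch (not part of the diff): the VQModel round trip at inference.
# It assumes a constructed VQModel instance (e.g. built via instantiate_from_config
# with a ddconfig/lossconfig); the latent spatial size depends on that config.
import torch

@torch.no_grad()
def vq_roundtrip(model, x):
    """x: (N, 3, H, W) image batch, roughly in [-1, 1]."""
    quant, emb_loss, info = model.encode(x)   # quantized latents + codebook loss + indices info
    xrec = model.decode(quant)                # reconstruction, same shape as x
    return xrec, emb_loss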
+ + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQSegmentationModel(VQModel): + def __init__(self, n_labels, *args, **kwargs): + super().__init__(*args, **kwargs) + self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1)) + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + return opt_ae + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val") + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + total_loss = log_dict_ae["val/total_loss"] + self.log("val/total_loss", total_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True) + return aeloss + + @torch.no_grad() + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + # convert logits to indices + xrec = torch.argmax(xrec, dim=1, keepdim=True) + xrec = F.one_hot(xrec, num_classes=x.shape[1]) + xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float() + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + return log + + +class VQNoDiscModel(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None + ): + super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim, + ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key, + colorize_nlabels=colorize_nlabels) + + def training_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train") + output = pl.TrainResult(minimize=aeloss) + output.log("train/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return output + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val") + rec_loss = log_dict_ae["val/rec_loss"] + output = pl.EvalResult(checkpoint_on=rec_loss) + 
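# note (editor): pl.TrainResult / pl.EvalResult come from the pre-1.0 PyTorch
# Lightning "Result" API and have been removed from recent Lightning releases,
# so VQNoDiscModel as written only runs against an old pytorch_lightning
# version; the other models in this file log via self.log(...) instead.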
output.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + output.log_dict(log_dict_ae) + + return output + + def configure_optimizers(self): + optimizer = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=self.learning_rate, betas=(0.5, 0.9)) + return optimizer + + +class GumbelVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + temperature_scheduler_config, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + kl_weight=1e-8, + remap=None, + ): + + z_channels = ddconfig["z_channels"] + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + + self.loss.n_classes = n_embed + self.vocab_size = n_embed + + self.quantize = GumbelQuantize(z_channels, embed_dim, + n_embed=n_embed, + kl_weight=kl_weight, temp_init=1.0, + remap=remap) + + self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def temperature_scheduling(self): + self.quantize.temperature = self.temperature_scheduler(self.global_step) + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode_code(self, code_b): + raise NotImplementedError + + def training_step(self, batch, batch_idx, optimizer_idx): + self.temperature_scheduling() + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + x = self.get_input(batch, self.image_key) + xrec, qloss = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + rec_loss = log_dict_ae["val/rec_loss"] + self.log("val/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log("val/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def log_images(self, batch, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + # encode + h = self.encoder(x) + h = self.quant_conv(h) + quant, _, _ = self.quantize(h) + # decode + x_rec = self.decode(quant) + 
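# `quant` holds the latents selected by the Gumbel quantizer above; decoding
# them gives the reconstruction that is logged next to the raw inputs below.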
log["inputs"] = x + log["reconstructions"] = x_rec + return log + + +class EMAVQ(VQModel): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + ): + super().__init__(ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=ignore_keys, + image_key=image_key, + colorize_nlabels=colorize_nlabels, + monitor=monitor, + ) + self.quantize = EMAVectorQuantizer(n_embed=n_embed, + embedding_dim=embed_dim, + beta=0.25, + remap=remap) + def configure_optimizers(self): + lr = self.learning_rate + #Remove self.quantize from parameter list since it is updated via EMA + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] \ No newline at end of file diff --git a/repositories/taming/modules/diffusionmodules/model.py b/repositories/taming/modules/diffusionmodules/model.py new file mode 100644 index 000000000..d3a5db6aa --- /dev/null +++ b/repositories/taming/modules/diffusionmodules/model.py @@ -0,0 +1,776 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = 
torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == 
x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h 
= self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + 
self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + 
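# --- Editor's note (not part of the diff): VUNet.forward above references a
# timestep `t` that is not among its arguments, so the use_timestep=True branch
# would raise a NameError as written; only the default use_timestep=False path
# is usable. For the Encoder/Decoder pair, the spatial size shrinks by
# 2**(len(ch_mult)-1); a quick, illustrative shape check (config values are
# placeholders in the spirit of the f=16 VQGAN-style setups used elsewhere):
import torch
from taming.modules.diffusionmodules.model import Encoder, Decoder

ddconfig = dict(double_z=False, z_channels=256, resolution=256, in_channels=3,
                out_ch=3, ch=128, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,
                attn_resolutions=(16,), dropout=0.0)
enc, dec = Encoder(**ddconfig), Decoder(**ddconfig)
x = torch.randn(1, 3, 256, 256)
z = enc(x)        # (1, 256, 16, 16): 256 / 2**(len(ch_mult)-1) = 16
xrec = dec(z)     # back to (1, 3, 256, 256)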
+class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + diff --git a/repositories/taming/modules/discriminator/model.py b/repositories/taming/modules/discriminator/model.py new file mode 100644 index 000000000..2aaa3110d --- /dev/null +++ b/repositories/taming/modules/discriminator/model.py @@ -0,0 +1,67 @@ +import functools +import torch.nn as nn + + +from taming.modules.util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + 
super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) diff --git a/repositories/taming/modules/losses/__init__.py b/repositories/taming/modules/losses/__init__.py new file mode 100644 index 000000000..d09caf9eb --- /dev/null +++ b/repositories/taming/modules/losses/__init__.py @@ -0,0 +1,2 @@ +from taming.modules.losses.vqperceptual import DummyLoss + diff --git a/repositories/taming/modules/losses/lpips.py b/repositories/taming/modules/losses/lpips.py new file mode 100644 index 000000000..a72804476 --- /dev/null +++ b/repositories/taming/modules/losses/lpips.py @@ -0,0 +1,123 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + +from taming.util import get_ckpt_path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 
+ for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) + diff --git a/repositories/taming/modules/losses/segmentation.py b/repositories/taming/modules/losses/segmentation.py new file mode 100644 index 000000000..4ba77deb5 --- /dev/null +++ b/repositories/taming/modules/losses/segmentation.py @@ -0,0 +1,22 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class BCELoss(nn.Module): + def forward(self, prediction, target): + loss = F.binary_cross_entropy_with_logits(prediction,target) + return loss, {} + + +class BCELossWithQuant(nn.Module): + def __init__(self, codebook_weight=1.): + super().__init__() + self.codebook_weight = codebook_weight + + def forward(self, qloss, target, prediction, split): + bce_loss = F.binary_cross_entropy_with_logits(prediction,target) + loss = bce_loss + self.codebook_weight*qloss + return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/bce_loss".format(split): bce_loss.detach().mean(), + "{}/quant_loss".format(split): qloss.detach().mean() + } 
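As a quick illustration of the loss above (a standalone sketch, not part of the patch; the tensor shapes and the "train" split name are made up), BCELossWithQuant simply adds the codebook term, scaled by codebook_weight, on top of the pixel-wise BCE and logs the three components per split:

import torch
import torch.nn.functional as F

# Dummy inputs for illustration only: raw logits and binary targets for a
# 4-channel map, plus a scalar quantization loss coming from the VQ codebook.
prediction = torch.randn(2, 4, 16, 16)
target = torch.randint(0, 2, (2, 4, 16, 16)).float()
qloss = torch.tensor(0.1)
codebook_weight = 1.0

bce_loss = F.binary_cross_entropy_with_logits(prediction, target)
loss = bce_loss + codebook_weight * qloss   # same weighting as BCELossWithQuant.forward
log = {
    "train/total_loss": loss.detach().mean(),
    "train/bce_loss": bce_loss.detach().mean(),
    "train/quant_loss": qloss.detach().mean(),
}
print(loss.item(), log)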
diff --git a/repositories/taming/modules/losses/vqperceptual.py b/repositories/taming/modules/losses/vqperceptual.py new file mode 100644 index 000000000..c2febd445 --- /dev/null +++ b/repositories/taming/modules/losses/vqperceptual.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from taming.modules.losses.lpips import LPIPS +from taming.modules.discriminator.model import NLayerDiscriminator, weights_init + + +class DummyLoss(nn.Module): + def __init__(self): + super().__init__() + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train"): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert 
self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/repositories/taming/modules/misc/coord.py b/repositories/taming/modules/misc/coord.py new file mode 100644 index 000000000..ee69b0c89 --- /dev/null +++ b/repositories/taming/modules/misc/coord.py @@ -0,0 +1,31 @@ +import torch + +class CoordStage(object): + def __init__(self, n_embed, down_factor): + self.n_embed = n_embed + self.down_factor = down_factor + + def eval(self): + return self + + def encode(self, c): + """fake vqmodel interface""" + assert 0.0 <= c.min() and c.max() <= 1.0 + b,ch,h,w = c.shape + assert ch == 1 + + c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, + mode="area") + c = c.clamp(0.0, 1.0) + c = self.n_embed*c + c_quant = c.round() + c_ind = c_quant.to(dtype=torch.long) + + info = None, None, c_ind + return c_quant, None, info + + def decode(self, c): + c = c/self.n_embed + c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, + mode="nearest") + return c diff --git a/repositories/taming/modules/transformer/mingpt.py b/repositories/taming/modules/transformer/mingpt.py new file mode 100644 index 000000000..d14b7b681 --- /dev/null +++ b/repositories/taming/modules/transformer/mingpt.py @@ -0,0 +1,415 @@ +""" +taken from: https://github.com/karpathy/minGPT/ +GPT model: +- the initial stem consists of a combination of token encoding and a positional encoding +- the meat of it is a uniform sequence of Transformer blocks + - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block + - all blocks feed into a central residual pathway similar to resnets +- the final decoder is a linear projection into a vanilla Softmax classifier +""" + +import math +import logging + +import torch 
+import torch.nn as nn +from torch.nn import functional as F +from transformers import top_k_top_p_filtering + +logger = logging.getLogger(__name__) + + +class GPTConfig: + """ base GPT config, params common to all GPT versions """ + embd_pdrop = 0.1 + resid_pdrop = 0.1 + attn_pdrop = 0.1 + + def __init__(self, vocab_size, block_size, **kwargs): + self.vocab_size = vocab_size + self.block_size = block_size + for k,v in kwargs.items(): + setattr(self, k, v) + + +class GPT1Config(GPTConfig): + """ GPT-1 like network roughly 125M params """ + n_layer = 12 + n_head = 12 + n_embd = 768 + + +class CausalSelfAttention(nn.Module): + """ + A vanilla multi-head masked self-attention layer with a projection at the end. + It is possible to use torch.nn.MultiheadAttention here but I am including an + explicit implementation here to show that there is nothing too scary here. + """ + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(config.n_embd, config.n_embd) + self.query = nn.Linear(config.n_embd, config.n_embd) + self.value = nn.Linear(config.n_embd, config.n_embd) + # regularization + self.attn_drop = nn.Dropout(config.attn_pdrop) + self.resid_drop = nn.Dropout(config.resid_pdrop) + # output projection + self.proj = nn.Linear(config.n_embd, config.n_embd) + # causal mask to ensure that attention is only applied to the left in the input sequence + mask = torch.tril(torch.ones(config.block_size, + config.block_size)) + if hasattr(config, "n_unmasked"): + mask[:config.n_unmasked, :config.n_unmasked] = 1 + self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size)) + self.n_head = config.n_head + + def forward(self, x, layer_past=None): + B, T, C = x.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + present = torch.stack((k, v)) + if layer_past is not None: + past_key, past_value = layer_past + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + if layer_past is None: + att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) + + att = F.softmax(att, dim=-1) + att = self.attn_drop(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_drop(self.proj(y)) + return y, present # TODO: check that this does not break anything + + +class Block(nn.Module): + """ an unassuming Transformer block """ + def __init__(self, config): + super().__init__() + self.ln1 = nn.LayerNorm(config.n_embd) + self.ln2 = nn.LayerNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.mlp = nn.Sequential( + nn.Linear(config.n_embd, 4 * config.n_embd), + nn.GELU(), # nice + nn.Linear(4 * config.n_embd, config.n_embd), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x, layer_past=None, return_present=False): + # TODO: check that training still works + if return_present: assert not self.training + # 
layer past: tuple of length two with B, nh, T, hs + attn, present = self.attn(self.ln1(x), layer_past=layer_past) + + x = x + attn + x = x + self.mlp(self.ln2(x)) + if layer_past is not None or return_present: + return x, present + return x + + +class GPT(nn.Module): + """ the full GPT language model, with a context size of block_size """ + def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256, + embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + super().__init__() + config = GPTConfig(vocab_size=vocab_size, block_size=block_size, + embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, + n_layer=n_layer, n_head=n_head, n_embd=n_embd, + n_unmasked=n_unmasked) + # input embedding stem + self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.block_size = config.block_size + self.apply(self._init_weights) + self.config = config + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, idx, embeddings=None, targets=None): + # forward the GPT model + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + t = token_embeddings.shape[1] + assert t <= self.block_size, "Cannot forward, model block size is exhausted." 
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.ln_f(x) + logits = self.head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss + + def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None): + # inference only + assert not self.training + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + if past is not None: + assert past_length is not None + past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head + past_shape = list(past.shape) + expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head] + assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}" + position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector + else: + position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :] + + x = self.drop(token_embeddings + position_embeddings) + presents = [] # accumulate over layers + for i, block in enumerate(self.blocks): + x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True) + presents.append(present) + + x = self.ln_f(x) + logits = self.head(x) + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head + + +class DummyGPT(nn.Module): + # for debugging + def __init__(self, add_value=1): + super().__init__() + self.add_value = add_value + + def forward(self, idx): + return idx + self.add_value, None + + +class CodeGPT(nn.Module): + """Takes in semi-embeddings""" + def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256, + embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0): + super().__init__() + config = GPTConfig(vocab_size=vocab_size, block_size=block_size, + embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop, + n_layer=n_layer, n_head=n_head, n_embd=n_embd, + n_unmasked=n_unmasked) + # input embedding stem + self.tok_emb = nn.Linear(in_channels, config.n_embd) + self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd)) + self.drop = nn.Dropout(config.embd_pdrop) + # transformer + self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)]) + # decoder head + self.ln_f = nn.LayerNorm(config.n_embd) + self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.block_size = config.block_size + self.apply(self._init_weights) + self.config = config + logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) + + def get_block_size(self): + return self.block_size + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + 
def forward(self, idx, embeddings=None, targets=None): + # forward the GPT model + token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector + + if embeddings is not None: # prepend explicit embeddings + token_embeddings = torch.cat((embeddings, token_embeddings), dim=1) + + t = token_embeddings.shape[1] + assert t <= self.block_size, "Cannot forward, model block size is exhausted." + position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector + x = self.drop(token_embeddings + position_embeddings) + x = self.blocks(x) + x = self.taming_cinln_f(x) + logits = self.head(x) + + # if we are given some desired targets also calculate the loss + loss = None + if targets is not None: + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) + + return logits, loss + + + +#### sampling utils + +def top_k_logits(logits, k): + v, ix = torch.topk(logits, k) + out = logits.clone() + out[out < v[:, [-1]]] = -float('Inf') + return out + +@torch.no_grad() +def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): + """ + take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in + the sequence, feeding the predictions back into the model each time. Clearly the sampling + has quadratic complexity unlike an RNN that is only linear, and has a finite context window + of block_size, unlike an RNN that has an infinite context window. + """ + block_size = model.get_block_size() + model.eval() + for k in range(steps): + x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed + logits, _ = model(x_cond) + # pluck the logits at the final step and scale by temperature + logits = logits[:, -1, :] / temperature + # optionally crop probabilities to only the top k options + if top_k is not None: + logits = top_k_logits(logits, top_k) + # apply softmax to convert to probabilities + probs = F.softmax(logits, dim=-1) + # sample from the distribution or take the most likely + if sample: + ix = torch.multinomial(probs, num_samples=1) + else: + _, ix = torch.topk(probs, k=1, dim=-1) + # append to the sequence and continue + x = torch.cat((x, ix), dim=1) + + return x + + +@torch.no_grad() +def sample_with_past(x, model, steps, temperature=1., sample_logits=True, + top_k=None, top_p=None, callback=None): + # x is conditioning + sample = x + cond_len = x.shape[1] + past = None + for n in range(steps): + if callback is not None: + callback(n) + logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1)) + if past is None: + past = [present] + else: + past.append(present) + logits = logits[:, -1, :] / temperature + if top_k is not None: + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + + probs = F.softmax(logits, dim=-1) + if not sample_logits: + _, x = torch.topk(probs, k=1, dim=-1) + else: + x = torch.multinomial(probs, num_samples=1) + # append to the sequence and continue + sample = torch.cat((sample, x), dim=1) + del past + sample = sample[:, cond_len:] # cut conditioning off + return sample + + +#### clustering utils + +class KMeans(nn.Module): + def __init__(self, ncluster=512, nc=3, niter=10): + super().__init__() + self.ncluster = ncluster + self.nc = nc + self.niter = niter + self.shape = (3,32,32) + self.register_buffer("C", torch.zeros(self.ncluster,nc)) + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def is_initialized(self): + return self.initialized.item() == 1 + + @torch.no_grad() + def 
initialize(self, x): + N, D = x.shape + assert D == self.nc, D + c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random + for i in range(self.niter): + # assign all pixels to the closest codebook element + a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1) + # move each codebook element to be the mean of the pixels that assigned to it + c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)]) + # re-assign any poorly positioned codebook elements + nanix = torch.any(torch.isnan(c), dim=1) + ndead = nanix.sum().item() + print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead)) + c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters + + self.C.copy_(c) + self.initialized.fill_(1) + + + def forward(self, x, reverse=False, shape=None): + if not reverse: + # flatten + bs,c,h,w = x.shape + assert c == self.nc + x = x.reshape(bs,c,h*w,1) + C = self.C.permute(1,0) + C = C.reshape(1,c,1,self.ncluster) + a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices + return a + else: + # flatten + bs, HW = x.shape + """ + c = self.C.reshape( 1, self.nc, 1, self.ncluster) + c = c[bs*[0],:,:,:] + c = c[:,:,HW*[0],:] + x = x.reshape(bs, 1, HW, 1) + x = x[:,3*[0],:,:] + x = torch.gather(c, dim=3, index=x) + """ + x = self.C[x] + x = x.permute(0,2,1) + shape = shape if shape is not None else self.shape + x = x.reshape(bs, *shape) + + return x diff --git a/repositories/taming/modules/transformer/permuter.py b/repositories/taming/modules/transformer/permuter.py new file mode 100644 index 000000000..0d43bb135 --- /dev/null +++ b/repositories/taming/modules/transformer/permuter.py @@ -0,0 +1,248 @@ +import torch +import torch.nn as nn +import numpy as np + + +class AbstractPermuter(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + def forward(self, x, reverse=False): + raise NotImplementedError + + +class Identity(AbstractPermuter): + def __init__(self): + super().__init__() + + def forward(self, x, reverse=False): + return x + + +class Subsample(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + C = 1 + indices = np.arange(H*W).reshape(C,H,W) + while min(H, W) > 1: + indices = indices.reshape(C,H//2,2,W//2,2) + indices = indices.transpose(0,2,4,1,3) + indices = indices.reshape(C*4,H//2, W//2) + H = H//2 + W = W//2 + C = C*4 + assert H == W == 1 + idx = torch.tensor(indices.ravel()) + self.register_buffer('forward_shuffle_idx', + nn.Parameter(idx, requires_grad=False)) + self.register_buffer('backward_shuffle_idx', + nn.Parameter(torch.argsort(idx), requires_grad=False)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +def mortonify(i, j): + """(i,j) index to linear morton code""" + i = np.uint64(i) + j = np.uint64(j) + + z = np.uint(0) + + for pos in range(32): + z = (z | + ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) | + ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1)) + ) + return z + + +class ZCurve(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)] + idx = np.argsort(reverseidx) + idx = torch.tensor(idx) + reverseidx = torch.tensor(reverseidx) + self.register_buffer('forward_shuffle_idx', + idx) + self.register_buffer('backward_shuffle_idx', + reverseidx) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, 
self.backward_shuffle_idx] + + +class SpiralOut(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + assert H == W + size = W + indices = np.arange(size*size).reshape(size,size) + + i0 = size//2 + j0 = size//2-1 + + i = i0 + j = j0 + + idx = [indices[i0, j0]] + step_mult = 0 + for c in range(1, size//2+1): + step_mult += 1 + # steps left + for k in range(step_mult): + i = i - 1 + j = j + idx.append(indices[i, j]) + + # step down + for k in range(step_mult): + i = i + j = j + 1 + idx.append(indices[i, j]) + + step_mult += 1 + if c < size//2: + # step right + for k in range(step_mult): + i = i + 1 + j = j + idx.append(indices[i, j]) + + # step up + for k in range(step_mult): + i = i + j = j - 1 + idx.append(indices[i, j]) + else: + # end reached + for k in range(step_mult-1): + i = i + 1 + idx.append(indices[i, j]) + + assert len(idx) == size*size + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class SpiralIn(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + assert H == W + size = W + indices = np.arange(size*size).reshape(size,size) + + i0 = size//2 + j0 = size//2-1 + + i = i0 + j = j0 + + idx = [indices[i0, j0]] + step_mult = 0 + for c in range(1, size//2+1): + step_mult += 1 + # steps left + for k in range(step_mult): + i = i - 1 + j = j + idx.append(indices[i, j]) + + # step down + for k in range(step_mult): + i = i + j = j + 1 + idx.append(indices[i, j]) + + step_mult += 1 + if c < size//2: + # step right + for k in range(step_mult): + i = i + 1 + j = j + idx.append(indices[i, j]) + + # step up + for k in range(step_mult): + i = i + j = j - 1 + idx.append(indices[i, j]) + else: + # end reached + for k in range(step_mult-1): + i = i + 1 + idx.append(indices[i, j]) + + assert len(idx) == size*size + idx = idx[::-1] + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class Random(nn.Module): + def __init__(self, H, W): + super().__init__() + indices = np.random.RandomState(1).permutation(H*W) + idx = torch.tensor(indices.ravel()) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +class AlternateParsing(AbstractPermuter): + def __init__(self, H, W): + super().__init__() + indices = np.arange(W*H).reshape(H,W) + for i in range(1, H, 2): + indices[i, :] = indices[i, ::-1] + idx = indices.flatten() + assert len(idx) == H*W + idx = torch.tensor(idx) + self.register_buffer('forward_shuffle_idx', idx) + self.register_buffer('backward_shuffle_idx', torch.argsort(idx)) + + def forward(self, x, reverse=False): + if not reverse: + return x[:, self.forward_shuffle_idx] + else: + return x[:, self.backward_shuffle_idx] + + +if __name__ == "__main__": + p0 = AlternateParsing(16, 16) + print(p0.forward_shuffle_idx) + print(p0.backward_shuffle_idx) + + x = torch.randint(0, 768, size=(11, 256)) + y = p0(x) + xre = p0(y, reverse=True) + assert torch.equal(x, xre) + + p1 = 
SpiralOut(2, 2) + print(p1.forward_shuffle_idx) + print(p1.backward_shuffle_idx) diff --git a/repositories/taming/modules/util.py b/repositories/taming/modules/util.py new file mode 100644 index 000000000..9ee16385d --- /dev/null +++ b/repositories/taming/modules/util.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn + + +def count_params(model): + total_params = sum(p.numel() for p in model.parameters()) + return total_params + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." 
+ ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class Labelator(AbstractEncoder): + """Net2Net Interface for Class-Conditional Model""" + def __init__(self, n_classes, quantize_interface=True): + super().__init__() + self.n_classes = n_classes + self.quantize_interface = quantize_interface + + def encode(self, c): + c = c[:,None] + if self.quantize_interface: + return c, None, [None, None, c.long()] + return c + + +class SOSProvider(AbstractEncoder): + # for unconditional training + def __init__(self, sos_token, quantize_interface=True): + super().__init__() + self.sos_token = sos_token + self.quantize_interface = quantize_interface + + def encode(self, x): + # get batch size from data and replicate sos_token + c = torch.ones(x.shape[0], 1)*self.sos_token + c = c.long().to(x.device) + if self.quantize_interface: + return c, None, [None, None, c] + return c diff --git a/repositories/taming/modules/vqvae/quantize.py b/repositories/taming/modules/vqvae/quantize.py new file mode 100644 index 000000000..d75544e41 --- /dev/null +++ b/repositories/taming/modules/vqvae/quantize.py @@ -0,0 +1,445 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... 
+ # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:,self.used,...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:,self.used,...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b*h*w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. 
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if 
self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = num_tokens + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. 
" + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + #EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + #EMA embedding average + embed_sum = encodings.transpose(0,1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + #normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/repositories/taming/util.py b/repositories/taming/util.py new file mode 100644 index 000000000..06053e5de --- /dev/null +++ b/repositories/taming/util.py @@ -0,0 +1,157 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + 
print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = {"keya": "a", + "keyb": "b", + "keyc": + {"cc1": 1, + "cc2": 2, + } + } + from omegaconf import OmegaConf + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") +