diff --git a/.gitignore b/.gitignore
index 362e9c257..481271499 100644
--- a/.gitignore
+++ b/.gitignore
@@ -43,7 +43,6 @@ cache
!package.json
# all dynamic stuff
-/repositories/**/*
/extensions/**/*
/outputs/**/*
/embeddings/**/*
diff --git a/.gitmodules b/.gitmodules
index 000ad0e26..58a8cd7d7 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -35,3 +35,4 @@
[submodule "modules/k-diffusion"]
path = modules/k-diffusion
url = https://github.com/crowsonkb/k-diffusion
+ ignore = dirty
diff --git a/configs/v2-1-stable-unclip-h-inference.yaml b/configs/v2-1-stable-unclip-h-inference.yaml
new file mode 100644
index 000000000..1bd0c64d3
--- /dev/null
+++ b/configs/v2-1-stable-unclip-h-inference.yaml
@@ -0,0 +1,80 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
+ params:
+ embedding_dropout: 0.25
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 96
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: crossattn-adm
+ scale_factor: 0.18215
+ monitor: val/loss_simple_ema
+ use_ema: False
+
+ embedder_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
+
+ noise_aug_config:
+ target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
+ params:
+ timestep_dim: 1024
+ noise_schedule_config:
+ timesteps: 1000
+ beta_schedule: squaredcos_cap_v2
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ num_classes: "sequential"
+ adm_in_channels: 2048
+ use_checkpoint: True
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: "vanilla-xformers"
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
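A note on how a config like the one above is typically consumed downstream: a minimal sketch, assuming the vendored `ldm` package is importable and exposes the usual `instantiate_from_config` helper; the checkpoint path is hypothetical.

```
# Sketch: load an inference YAML and build the model it describes.
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config  # helper assumed from the vendored ldm package

config = OmegaConf.load("configs/v2-1-stable-unclip-h-inference.yaml")
model = instantiate_from_config(config.model)  # ImageEmbeddingConditionedLatentDiffusion per the target above

state = torch.load("models/Stable-diffusion/sd21-unclip-h.ckpt", map_location="cpu")  # hypothetical path
model.load_state_dict(state.get("state_dict", state), strict=False)
model.eval()
```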
diff --git a/configs/v2-1-stable-unclip-l-inference.yaml b/configs/v2-1-stable-unclip-l-inference.yaml
new file mode 100644
index 000000000..335fd61f3
--- /dev/null
+++ b/configs/v2-1-stable-unclip-l-inference.yaml
@@ -0,0 +1,83 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion
+ params:
+ embedding_dropout: 0.25
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.0120
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 96
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: crossattn-adm
+ scale_factor: 0.18215
+ monitor: val/loss_simple_ema
+ use_ema: False
+
+ embedder_config:
+ target: ldm.modules.encoders.modules.ClipImageEmbedder
+ params:
+ model: "ViT-L/14"
+
+ noise_aug_config:
+ target: ldm.modules.encoders.modules.CLIPEmbeddingNoiseAugmentation
+ params:
+ clip_stats_path: "checkpoints/karlo_models/ViT-L-14_stats.th"
+ timestep_dim: 768
+ noise_schedule_config:
+ timesteps: 1000
+ beta_schedule: squaredcos_cap_v2
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ num_classes: "sequential"
+ adm_in_channels: 1536
+ use_checkpoint: True
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ attn_type: "vanilla-xformers"
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
\ No newline at end of file
diff --git a/configs/v2-midas-inference.yaml b/configs/v2-midas-inference.yaml
new file mode 100644
index 000000000..f20c30f61
--- /dev/null
+++ b/configs/v2-midas-inference.yaml
@@ -0,0 +1,74 @@
+model:
+ base_learning_rate: 5.0e-07
+ target: ldm.models.diffusion.ddpm.LatentDepth2ImageDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false
+ conditioning_key: hybrid
+ scale_factor: 0.18215
+ monitor: val/loss_simple_ema
+ finetune_keys: null
+ use_ema: False
+
+ depth_stage_config:
+ target: ldm.modules.midas.api.MiDaSInference
+ params:
+ model_type: "dpt_hybrid"
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ use_checkpoint: True
+ image_size: 32 # unused
+ in_channels: 5
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [ 4, 2, 1 ]
+ num_res_blocks: 2
+ channel_mult: [ 1, 2, 4, 4 ]
+ num_head_channels: 64 # need to fix for flash-attn
+ use_spatial_transformer: True
+ use_linear_in_transformer: True
+ transformer_depth: 1
+ context_dim: 1024
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ #attn_type: "vanilla-xformers"
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: [ ]
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+
diff --git a/installer.py b/installer.py
index 205f16cee..b5fcd8335 100644
--- a/installer.py
+++ b/installer.py
@@ -591,6 +591,7 @@ def install_packages():
# clone required repositories
def install_repositories():
+ """
if args.profile:
pr = cProfile.Profile()
pr.enable()
@@ -615,6 +616,7 @@ def d(name):
clone(blip_repo, d('BLIP'), blip_commit)
if args.profile:
print_profile(pr, 'Repositories')
+ """
# run extension installer
diff --git a/modules/paths.py b/modules/paths.py
index 0fd7b87b5..d3cce3f4a 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -17,22 +17,13 @@
# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
-# search for directory of stable diffusion in following places
-sd_path = None
-possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)]
-for possible_sd_path in possible_sd_paths:
- if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')):
- sd_path = os.path.abspath(possible_sd_path)
- break
-
-assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
-
+sd_path = os.path.join(script_path, 'repositories')
path_dirs = [
- (sd_path, 'ldm', 'Stable Diffusion', []),
- (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
- (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
- (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
- (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
+ (sd_path, 'ldm', 'ldm', []),
+ (sd_path, 'taming', 'Taming Transformers', []),
+ (os.path.join(sd_path, 'blip'), 'models/blip.py', 'BLIP', []),
+ (os.path.join(sd_path, 'codeformer'), 'inference_codeformer.py', 'CodeFormer', []),
+ (os.path.join('modules', 'k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
paths = {}
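The `path_dirs` tuples above are (directory, marker file, label, options). A minimal sketch of how such a table is typically consumed, assuming the rest of `paths.py` keeps the existing pattern of probing the marker file and registering the directory on `sys.path`:

```
# Sketch: resolve each entry, warn when the marker file is missing, otherwise
# put the directory on sys.path (prepended when "atstart" is set) and record
# it in the `paths` dict under its label.
for d, must_exist, what, options in path_dirs:
    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
    if not os.path.exists(must_exist_path):
        print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
        continue
    d = os.path.abspath(os.path.join(script_path, d))
    if "atstart" in options:
        sys.path.insert(0, d)
    else:
        sys.path.append(d)
    paths[what] = d
```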
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 163390ca6..3df76caa5 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -4,10 +4,10 @@
from modules import paths, sd_disable_initialization, devices
-sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_repo_configs_path = 'configs'
config_default = paths.sd_default_config
-config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
-config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
+config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference-512-base.yaml")
+config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-768-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
diff --git a/repositories/blip b/repositories/blip
new file mode 160000
index 000000000..3a29b7410
--- /dev/null
+++ b/repositories/blip
@@ -0,0 +1 @@
+Subproject commit 3a29b7410476bf5f2ba0955827390eb6ea1f4f9d
diff --git a/repositories/codeformer/.gitignore b/repositories/codeformer/.gitignore
new file mode 100644
index 000000000..15f690d16
--- /dev/null
+++ b/repositories/codeformer/.gitignore
@@ -0,0 +1,128 @@
+.vscode
+
+# ignored files
+version.py
+
+# ignored files with suffix
+*.html
+# *.png
+# *.jpeg
+# *.jpg
+*.pt
+*.gif
+*.pth
+*.dat
+*.zip
+
+# template
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+# project
+results/
+dlib/
+*_old*
+
diff --git a/repositories/codeformer/LICENSE b/repositories/codeformer/LICENSE
new file mode 100644
index 000000000..44bf750a2
--- /dev/null
+++ b/repositories/codeformer/LICENSE
@@ -0,0 +1,35 @@
+S-Lab License 1.0
+
+Copyright 2022 S-Lab
+
+Redistribution and use for non-commercial purpose in source and
+binary forms, with or without modification, are permitted provided
+that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in
+ the documentation and/or other materials provided with the
+ distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+In the event that redistribution and/or use for commercial purpose in
+source or binary forms, with or without modification is required,
+please contact the contributor(s) of the work.
\ No newline at end of file
diff --git a/repositories/codeformer/README.md b/repositories/codeformer/README.md
new file mode 100644
index 000000000..f75785393
--- /dev/null
+++ b/repositories/codeformer/README.md
@@ -0,0 +1,149 @@
+
+
+
+
+## Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022)
+
+[Paper](https://arxiv.org/abs/2206.11253) | [Project Page](https://shangchenzhou.com/projects/CodeFormer/) | [Video](https://youtu.be/d3VDpkXlueI)
+
+
+ [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer) [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer) ![visitors](https://visitor-badge-sczhou.glitch.me/badge?page_id=sczhou/CodeFormer)
+
+
+
+[Shangchen Zhou](https://shangchenzhou.com/), [Kelvin C.K. Chan](https://ckkelvinchan.github.io/), [Chongyi Li](https://li-chongyi.github.io/), [Chen Change Loy](https://www.mmlab-ntu.com/person/ccloy/)
+
+S-Lab, Nanyang Technological University
+
+
+
+
+:star: If CodeFormer is helpful to your images or projects, please help star this repo. Thanks! :hugs:
+
+**[News]**: :whale: *Due to copyright issues, we have to delay the release of the training code (expected by the end of this year). Please star and stay tuned for our future updates!*
+### Update
+- **2022.10.05**: Support video input `--input_path [YOUR_VIDEO.mp4]`. Try it to enhance your videos! :clapper:
+- **2022.09.14**: Integrated to :hugs: [Hugging Face](https://huggingface.co/spaces). Try out online demo! [![Hugging Face](https://img.shields.io/badge/Demo-%F0%9F%A4%97%20Hugging%20Face-blue)](https://huggingface.co/spaces/sczhou/CodeFormer)
+- **2022.09.09**: Integrated to :rocket: [Replicate](https://replicate.com/explore). Try out online demo! [![Replicate](https://img.shields.io/badge/Demo-%F0%9F%9A%80%20Replicate-blue)](https://replicate.com/sczhou/codeformer)
+- **2022.09.04**: Add face upsampling `--face_upsample` for high-resolution AI-created face enhancement.
+- **2022.08.23**: Some modifications on face detection and fusion for better AI-created face enhancement.
+- **2022.08.07**: Integrate [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement.
+- **2022.07.29**: Integrate new face detectors: `['RetinaFace' (default), 'YOLOv5']`.
+- **2022.07.17**: Add Colab demo of CodeFormer.
+- **2022.07.16**: Release inference code for face restoration. :blush:
+- **2022.06.21**: This repo is created.
+
+### TODO
+- [ ] Add checkpoint for face inpainting
+- [ ] Add checkpoint for face colorization
+- [ ] Add training code and config files
+- [x] ~~Add background image enhancement~~
+
+#### :panda_face: Try Enhancing Old Photos / Fixing AI-arts
+[Before/after comparison 1](https://imgsli.com/MTI3NTE2) [Before/after comparison 2](https://imgsli.com/MTI3NTE1) [Before/after comparison 3](https://imgsli.com/MTI3NTIw)
+
+#### Face Restoration
+
+
+
+
+#### Face Color Enhancement and Restoration
+
+
+
+#### Face Inpainting
+
+
+
+
+
+### Dependencies and Installation
+
+- PyTorch >= 1.7.1
+- CUDA >= 10.1
+- Other required packages in `requirements.txt`
+```
+# git clone this repository
+git clone https://github.com/sczhou/CodeFormer
+cd CodeFormer
+
+# create new anaconda env
+conda create -n codeformer python=3.8 -y
+conda activate codeformer
+
+# install python dependencies
+pip3 install -r requirements.txt
+python basicsr/setup.py develop
+```
+
+
+### Quick Inference
+
+#### Download Pre-trained Models:
+Download the facelib pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1b_3qwrzY_kTQh0-SnBoGBgOrJ_PLZSKm?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EvDxR7FcAbZMp_MA9ouq7aQB8XTppMb3-T0uGZ_2anI2mg?e=DXsJFo)] to the `weights/facelib` folder. You can download the pretrained models manually or by running the following command.
+```
+python scripts/download_pretrained_models.py facelib
+```
+
+Download the CodeFormer pretrained models from [[Google Drive](https://drive.google.com/drive/folders/1CNNByjHDFt0b95q54yMVp6Ifo5iuU6QS?usp=sharing) | [OneDrive](https://entuedu-my.sharepoint.com/:f:/g/personal/s200094_e_ntu_edu_sg/EoKFj4wo8cdIn2-TY2IV6CYBhZ0pIG4kUOeHdPR_A5nlbg?e=AO8UN9)] to the `weights/CodeFormer` folder. You can download the pretrained models manually or by running the following command.
+```
+python scripts/download_pretrained_models.py CodeFormer
+```
+
+#### Prepare Testing Data:
+You can put the testing images in the `inputs/TestWhole` folder. If you would like to test on cropped and aligned faces, you can put them in the `inputs/cropped_faces` folder.
+
+
+#### Testing on Face Restoration:
+[Note] If you want to compare against CodeFormer in your paper, please run the following command with `--has_aligned` (for cropped and aligned faces): the whole-image command involves a face-background fusion step that may damage hair texture at the boundary, which leads to an unfair comparison.
+
+🧑🏻 Face Restoration (cropped and aligned face)
+```
+# For cropped and aligned faces
+python inference_codeformer.py -w 0.5 --has_aligned --input_path [image folder]|[image path]
+```
+
+:framed_picture: Whole Image Enhancement
+```
+# For whole image
+# Add '--bg_upsampler realesrgan' to enhance the background regions with Real-ESRGAN
+# Add '--face_upsample' to further upsample restored faces with Real-ESRGAN
+python inference_codeformer.py -w 0.7 --input_path [image folder]|[image path]
+```
+
+:clapper: Video Enhancement
+```
+# For Windows/Mac users, please install ffmpeg first
+conda install -c conda-forge ffmpeg
+```
+```
+# For video clips
+# video path should end with '.mp4'|'.mov'|'.avi'
+python inference_codeformer.py --bg_upsampler realesrgan --face_upsample -w 1.0 --input_path [video path]
+```
+
+
+Fidelity weight *w* lies in [0, 1]. Generally, a smaller *w* tends to produce a higher-quality result, while a larger *w* yields a higher-fidelity result.
+
+The results will be saved in the `results` folder.
+
+### Citation
+If our work is useful for your research, please consider citing:
+
+ @inproceedings{zhou2022codeformer,
+ author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
+ title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
+ booktitle = {NeurIPS},
+ year = {2022}
+ }
+
+### License
+
+This project is licensed under the NTU S-Lab License 1.0. Redistribution and use should follow this license.
+
+### Acknowledgement
+
+This project is based on [BasicSR](https://github.com/XPixelGroup/BasicSR). Some code is borrowed from [Unleashing Transformers](https://github.com/samb-t/unleashing-transformers), [YOLOv5-face](https://github.com/deepcam-cn/yolov5-face), and [FaceXLib](https://github.com/xinntao/facexlib). We also adopt [Real-ESRGAN](https://github.com/xinntao/Real-ESRGAN) to support background image enhancement. Thanks for their awesome work.
+
+### Contact
+If you have any questions, please feel free to reach out to me at `shangchenzhou@gmail.com`.
diff --git a/repositories/codeformer/basicsr/VERSION b/repositories/codeformer/basicsr/VERSION
new file mode 100644
index 000000000..1892b9267
--- /dev/null
+++ b/repositories/codeformer/basicsr/VERSION
@@ -0,0 +1 @@
+1.3.2
diff --git a/repositories/codeformer/basicsr/__init__.py b/repositories/codeformer/basicsr/__init__.py
new file mode 100644
index 000000000..c7ffcccd7
--- /dev/null
+++ b/repositories/codeformer/basicsr/__init__.py
@@ -0,0 +1,11 @@
+# https://github.com/xinntao/BasicSR
+# flake8: noqa
+from .archs import *
+from .data import *
+from .losses import *
+from .metrics import *
+from .models import *
+from .ops import *
+from .train import *
+from .utils import *
+from .version import __gitsha__, __version__
diff --git a/repositories/codeformer/basicsr/archs/__init__.py b/repositories/codeformer/basicsr/archs/__init__.py
new file mode 100644
index 000000000..cfb1e4d7b
--- /dev/null
+++ b/repositories/codeformer/basicsr/archs/__init__.py
@@ -0,0 +1,25 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import ARCH_REGISTRY
+
+__all__ = ['build_network']
+
+# automatically scan and import arch modules for registry
+# scan all the files under the 'archs' folder and collect files ending with
+# '_arch.py'
+arch_folder = osp.dirname(osp.abspath(__file__))
+arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
+# import all the arch modules
+_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
+
+
+def build_network(opt):
+ opt = deepcopy(opt)
+ network_type = opt.pop('type')
+ net = ARCH_REGISTRY.get(network_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Network [{net.__class__.__name__}] is created.')
+ return net
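`build_network` above resolves the `type` key against `ARCH_REGISTRY` and forwards the remaining keys as constructor arguments. A usage sketch, assuming the vendored `basicsr` package (including its optional ops) imports cleanly; the option values are the CodeFormer defaults that appear later in this diff.

```
# Sketch: build a registered architecture from a plain option dict.
from basicsr.archs import build_network

opt = {
    'type': 'CodeFormer',   # looked up in ARCH_REGISTRY
    'dim_embd': 512,
    'n_head': 8,
    'n_layers': 9,
    'codebook_size': 1024,
}
net = build_network(opt)    # equivalent to CodeFormer(dim_embd=512, n_head=8, n_layers=9, codebook_size=1024)
```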
diff --git a/repositories/codeformer/basicsr/archs/arcface_arch.py b/repositories/codeformer/basicsr/archs/arcface_arch.py
new file mode 100644
index 000000000..fe5afb7bd
--- /dev/null
+++ b/repositories/codeformer/basicsr/archs/arcface_arch.py
@@ -0,0 +1,245 @@
+import torch.nn as nn
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+def conv3x3(inplanes, outplanes, stride=1):
+ """A simple wrapper for 3x3 convolution with padding.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ outplanes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ """
+ return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+ """Basic residual block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class IRBlock(nn.Module):
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
+ """
+ expansion = 1 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
+ super(IRBlock, self).__init__()
+ self.bn0 = nn.BatchNorm2d(inplanes)
+ self.conv1 = conv3x3(inplanes, inplanes)
+ self.bn1 = nn.BatchNorm2d(inplanes)
+ self.prelu = nn.PReLU()
+ self.conv2 = conv3x3(inplanes, planes, stride)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.use_se = use_se
+ if self.use_se:
+ self.se = SEBlock(planes)
+
+ def forward(self, x):
+ residual = x
+ out = self.bn0(x)
+ out = self.conv1(out)
+ out = self.bn1(out)
+ out = self.prelu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.use_se:
+ out = self.se(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.prelu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ """Bottleneck block used in the ResNetArcFace architecture.
+
+ Args:
+ inplanes (int): Channel number of inputs.
+ planes (int): Channel number of outputs.
+ stride (int): Stride in convolution. Default: 1.
+ downsample (nn.Module): The downsample module. Default: None.
+ """
+ expansion = 4 # output channel expansion ratio
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class SEBlock(nn.Module):
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
+
+ Args:
+ channel (int): Channel number of inputs.
+        reduction (int): Channel reduction ratio. Default: 16.
+ """
+
+ def __init__(self, channel, reduction=16):
+ super(SEBlock, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel),
+ nn.Sigmoid())
+
+ def forward(self, x):
+ b, c, _, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1, 1)
+ return x * y
+
+
+@ARCH_REGISTRY.register()
+class ResNetArcFace(nn.Module):
+ """ArcFace with ResNet architectures.
+
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
+
+ Args:
+ block (str): Block used in the ArcFace architecture.
+ layers (tuple(int)): Block numbers in each layer.
+        use_se (bool): Whether to use the SEBlock (squeeze and excitation block). Default: True.
+ """
+
+ def __init__(self, block, layers, use_se=True):
+ if block == 'IRBlock':
+ block = IRBlock
+ self.inplanes = 64
+ self.use_se = use_se
+ super(ResNetArcFace, self).__init__()
+
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.prelu = nn.PReLU()
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.bn4 = nn.BatchNorm2d(512)
+ self.dropout = nn.Dropout()
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
+ self.bn5 = nn.BatchNorm1d(512)
+
+ # initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Linear):
+ nn.init.xavier_normal_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, block, planes, num_blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se))
+ self.inplanes = planes
+ for _ in range(1, num_blocks):
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.prelu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.bn4(x)
+ x = self.dropout(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc5(x)
+ x = self.bn5(x)
+
+ return x
\ No newline at end of file
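A quick shape check for `ResNetArcFace` above, as a hedged sketch: the `(2, 2, 2, 2)` layer configuration is a typical choice rather than one pinned by this file, and the 128x128 single-channel input follows from the hard-coded `fc5 = nn.Linear(512 * 8 * 8, 512)` (128 is halved four times to 8).

```
# Sketch: ResNetArcFace maps a 1x128x128 crop to a 512-d identity embedding.
import torch
from basicsr.archs.arcface_arch import ResNetArcFace

net = ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2), use_se=True).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 1, 128, 128))  # (batch, channels=1, height, width)
print(emb.shape)  # torch.Size([1, 512])
```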
diff --git a/repositories/codeformer/basicsr/archs/arch_util.py b/repositories/codeformer/basicsr/archs/arch_util.py
new file mode 100644
index 000000000..bad45ab34
--- /dev/null
+++ b/repositories/codeformer/basicsr/archs/arch_util.py
@@ -0,0 +1,318 @@
+import collections.abc
+import math
+import torch
+import torchvision
+import warnings
+from distutils.version import LooseVersion
+from itertools import repeat
+from torch import nn as nn
+from torch.nn import functional as F
+from torch.nn import init as init
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
+from basicsr.utils import get_root_logger
+
+
+@torch.no_grad()
+def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
+ """Initialize network weights.
+
+ Args:
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
+ scale (float): Scale initialized weights, especially for residual
+ blocks. Default: 1.
+ bias_fill (float): The value to fill bias. Default: 0
+ kwargs (dict): Other arguments for initialization function.
+ """
+ if not isinstance(module_list, list):
+ module_list = [module_list]
+ for module in module_list:
+ for m in module.modules():
+ if isinstance(m, nn.Conv2d):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, nn.Linear):
+ init.kaiming_normal_(m.weight, **kwargs)
+ m.weight.data *= scale
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+ elif isinstance(m, _BatchNorm):
+ init.constant_(m.weight, 1)
+ if m.bias is not None:
+ m.bias.data.fill_(bias_fill)
+
+
+def make_layer(basic_block, num_basic_block, **kwarg):
+ """Make layers by stacking the same blocks.
+
+ Args:
+ basic_block (nn.module): nn.module class for basic block.
+ num_basic_block (int): number of blocks.
+
+ Returns:
+ nn.Sequential: Stacked blocks in nn.Sequential.
+ """
+ layers = []
+ for _ in range(num_basic_block):
+ layers.append(basic_block(**kwarg))
+ return nn.Sequential(*layers)
+
+
+class ResidualBlockNoBN(nn.Module):
+ """Residual block without BN.
+
+ It has a style of:
+ ---Conv-ReLU-Conv-+-
+ |________________|
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ Default: 64.
+ res_scale (float): Residual scale. Default: 1.
+ pytorch_init (bool): If set to True, use pytorch default init,
+ otherwise, use default_init_weights. Default: False.
+ """
+
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
+ super(ResidualBlockNoBN, self).__init__()
+ self.res_scale = res_scale
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
+ self.relu = nn.ReLU(inplace=True)
+
+ if not pytorch_init:
+ default_init_weights([self.conv1, self.conv2], 0.1)
+
+ def forward(self, x):
+ identity = x
+ out = self.conv2(self.relu(self.conv1(x)))
+ return identity + out * self.res_scale
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
+ """Warp an image or feature map with optical flow.
+
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Before pytorch 1.3, the default value is
+ align_corners=True. After pytorch 1.3, the default value is
+            align_corners=False. Here, we use True as the default.
+
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ assert x.size()[-2:] == flow.size()[1:3]
+ _, _, h, w = x.size()
+ # create mesh grid
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
+ grid.requires_grad = False
+
+ vgrid = grid + flow
+ # scale grid to [-1,1]
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
+
+ # TODO, what if align_corners=False
+ return output
+
+
+def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
+ """Resize a flow according to ratio or shape.
+
+ Args:
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
+ size_type (str): 'ratio' or 'shape'.
+ sizes (list[int | float]): the ratio for resizing or the final output
+ shape.
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
+ ratio > 1.0).
+ 2) The order of output_size should be [out_h, out_w].
+ interp_mode (str): The mode of interpolation for resizing.
+ Default: 'bilinear'.
+ align_corners (bool): Whether align corners. Default: False.
+
+ Returns:
+ Tensor: Resized flow.
+ """
+ _, _, flow_h, flow_w = flow.size()
+ if size_type == 'ratio':
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
+ elif size_type == 'shape':
+ output_h, output_w = sizes[0], sizes[1]
+ else:
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
+
+ input_flow = flow.clone()
+ ratio_h = output_h / flow_h
+ ratio_w = output_w / flow_w
+ input_flow[:, 0, :, :] *= ratio_w
+ input_flow[:, 1, :, :] *= ratio_h
+ resized_flow = F.interpolate(
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+ return resized_flow
+
+
+# TODO: may write a cpp file
+def pixel_unshuffle(x, scale):
+ """ Pixel unshuffle.
+
+ Args:
+ x (Tensor): Input feature with shape (b, c, hh, hw).
+ scale (int): Downsample ratio.
+
+ Returns:
+ Tensor: the pixel unshuffled feature.
+ """
+ b, c, hh, hw = x.size()
+ out_channel = c * (scale**2)
+ assert hh % scale == 0 and hw % scale == 0
+ h = hh // scale
+ w = hw // scale
+ x_view = x.view(b, c, h, scale, w, scale)
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
+
+
+class DCNv2Pack(ModulatedDeformConvPack):
+ """Modulated deformable conv for deformable alignment.
+
+ Different from the official DCNv2Pack, which generates offsets and masks
+    from the preceding features, this DCNv2Pack takes separate features to
+    generate offsets and masks.
+
+ Ref:
+ Delving Deep into Deformable Alignment in Video Super-Resolution.
+ """
+
+ def forward(self, x, feat):
+ out = self.conv_offset(feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+
+ offset_absmean = torch.mean(torch.abs(offset))
+ if offset_absmean > 50:
+ logger = get_root_logger()
+ logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+
+ if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, mask)
+ else:
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ 'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+ 'The distribution of values may be incorrect.',
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ low = norm_cdf((a - mean) / std)
+ up = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [low, up], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ r"""Fills the input Tensor with values drawn from a truncated
+ normal distribution.
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+ The values are effectively drawn from the
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+ with values outside :math:`[a, b]` redrawn until they are within
+ the bounds. The method used for generating the random values works
+ best when :math:`a \leq \text{mean} \leq b`.
+
+ Args:
+ tensor: an n-dimensional `torch.Tensor`
+ mean: the mean of the normal distribution
+ std: the standard deviation of the normal distribution
+ a: the minimum cutoff value
+ b: the maximum cutoff value
+
+ Examples:
+ >>> w = torch.empty(3, 5)
+ >>> nn.init.trunc_normal_(w)
+ """
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+# From PyTorch
+def _ntuple(n):
+
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+to_1tuple = _ntuple(1)
+to_2tuple = _ntuple(2)
+to_3tuple = _ntuple(3)
+to_4tuple = _ntuple(4)
+to_ntuple = _ntuple
\ No newline at end of file
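A small shape sketch for `pixel_unshuffle` above, which trades spatial resolution for channels and is the inverse of `nn.PixelShuffle`; the tensor sizes are arbitrary, and the import assumes the vendored `basicsr.ops` dependencies of this module load in your environment.

```
# Sketch: pixel_unshuffle folds each 2x2 spatial block into the channel dimension.
import torch
from basicsr.archs.arch_util import pixel_unshuffle

x = torch.randn(1, 3, 8, 8)
y = pixel_unshuffle(x, scale=2)
print(y.shape)  # torch.Size([1, 12, 4, 4]) -- channels * scale**2, spatial / scale

# PixelShuffle with the same scale reverses the layout exactly.
z = torch.nn.PixelShuffle(2)(y)
print(torch.equal(z, x))  # True
```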
diff --git a/repositories/codeformer/basicsr/archs/codeformer_arch.py b/repositories/codeformer/basicsr/archs/codeformer_arch.py
new file mode 100644
index 000000000..4d0d8027c
--- /dev/null
+++ b/repositories/codeformer/basicsr/archs/codeformer_arch.py
@@ -0,0 +1,276 @@
+import math
+import numpy as np
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+from typing import Optional, List
+
+from basicsr.archs.vqgan_arch import *
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def calc_mean_std(feat, eps=1e-5):
+ """Calculate mean and std for adaptive_instance_normalization.
+
+ Args:
+ feat (Tensor): 4D tensor.
+ eps (float): A small value added to the variance to avoid
+ divide-by-zero. Default: 1e-5.
+ """
+ size = feat.size()
+ assert len(size) == 4, 'The input feature should be 4D tensor.'
+ b, c = size[:2]
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
+ return feat_mean, feat_std
+
+
+def adaptive_instance_normalization(content_feat, style_feat):
+ """Adaptive instance normalization.
+
+    Adjust the reference features to have similar color and illumination
+    to those in the degraded features.
+
+ Args:
+ content_feat (Tensor): The reference feature.
+        style_feat (Tensor): The degraded features.
+ """
+ size = content_feat.size()
+ style_mean, style_std = calc_mean_std(style_feat)
+ content_mean, content_std = calc_mean_std(content_feat)
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
+
+
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+def _get_activation_fn(activation):
+ """Return an activation function given a string"""
+ if activation == "relu":
+ return F.relu
+ if activation == "gelu":
+ return F.gelu
+ if activation == "glu":
+ return F.glu
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+class TransformerSALayer(nn.Module):
+ def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
+ super().__init__()
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
+ # Implementation of Feedforward model - MLP
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
+
+ self.norm1 = nn.LayerNorm(embed_dim)
+ self.norm2 = nn.LayerNorm(embed_dim)
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.activation = _get_activation_fn(activation)
+
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+ return tensor if pos is None else tensor + pos
+
+ def forward(self, tgt,
+ tgt_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None):
+
+ # self attention
+ tgt2 = self.norm1(tgt)
+ q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)[0]
+ tgt = tgt + self.dropout1(tgt2)
+
+ # ffn
+ tgt2 = self.norm2(tgt)
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+ tgt = tgt + self.dropout2(tgt2)
+ return tgt
+
+class Fuse_sft_block(nn.Module):
+ def __init__(self, in_ch, out_ch):
+ super().__init__()
+ self.encode_enc = ResBlock(2*in_ch, out_ch)
+
+ self.scale = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ self.shift = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, True),
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
+
+ def forward(self, enc_feat, dec_feat, w=1):
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
+ scale = self.scale(enc_feat)
+ shift = self.shift(enc_feat)
+ residual = w * (dec_feat * scale + shift)
+ out = dec_feat + residual
+ return out
+
+
+@ARCH_REGISTRY.register()
+class CodeFormer(VQAutoEncoder):
+ def __init__(self, dim_embd=512, n_head=8, n_layers=9,
+ codebook_size=1024, latent_size=256,
+ connect_list=['32', '64', '128', '256'],
+ fix_modules=['quantize','generator']):
+ super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
+
+ if fix_modules is not None:
+ for module in fix_modules:
+ for param in getattr(self, module).parameters():
+ param.requires_grad = False
+
+ self.connect_list = connect_list
+ self.n_layers = n_layers
+ self.dim_embd = dim_embd
+ self.dim_mlp = dim_embd*2
+
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
+ self.feat_emb = nn.Linear(256, self.dim_embd)
+
+ # transformer
+ self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
+ for _ in range(self.n_layers)])
+
+ # logits_predict head
+ self.idx_pred_layer = nn.Sequential(
+ nn.LayerNorm(dim_embd),
+ nn.Linear(dim_embd, codebook_size, bias=False))
+
+ self.channels = {
+ '16': 512,
+ '32': 256,
+ '64': 256,
+ '128': 128,
+ '256': 128,
+ '512': 64,
+ }
+
+ # after second residual block for > 16, before attn layer for ==16
+ self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
+ # after first residual block for > 16, before attn layer for ==16
+ self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
+
+ # fuse_convs_dict
+ self.fuse_convs_dict = nn.ModuleDict()
+ for f_size in self.connect_list:
+ in_ch = self.channels[f_size]
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
+ # ################### Encoder #####################
+ enc_feat_dict = {}
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
+ for i, block in enumerate(self.encoder.blocks):
+ x = block(x)
+ if i in out_list:
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
+
+ lq_feat = x
+ # ################# Transformer ###################
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
+ # BCHW -> BC(HW) -> (HW)BC
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
+ query_emb = feat_emb
+ # Transformer encoder
+ for layer in self.ft_layers:
+ query_emb = layer(query_emb, query_pos=pos_emb)
+
+ # output logits
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
+ logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
+
+ if code_only: # for training stage II
+ # logits doesn't need softmax before cross_entropy loss
+ return logits, lq_feat
+
+ # ################# Quantization ###################
+ # if self.training:
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
+ # # b(hw)c -> bc(hw) -> bchw
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
+ # ------------
+ soft_one_hot = F.softmax(logits, dim=2)
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
+ quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
+ # preserve gradients
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
+
+ if detach_16:
+ quant_feat = quant_feat.detach() # for training stage III
+ if adain:
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
+
+ # ################## Generator ####################
+ x = quant_feat
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
+
+ for i, block in enumerate(self.generator.blocks):
+ x = block(x)
+ if i in fuse_list: # fuse after i-th block
+ f_size = str(x.shape[-1])
+ if w>0:
+ x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
+ out = x
+ # logits doesn't need softmax before cross_entropy loss
+ return out, logits, lq_feat
\ No newline at end of file
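The `adaptive_instance_normalization` helper above re-styles one feature map with the channel-wise statistics of another; CodeFormer optionally applies it to the quantized features when `adain=True`. A minimal sketch with arbitrary tensor sizes:

```
# Sketch: AdaIN makes the content features adopt the per-channel mean/std of the style features.
import torch
from basicsr.archs.codeformer_arch import adaptive_instance_normalization, calc_mean_std

content = torch.randn(2, 256, 16, 16)  # e.g. quantized codebook features
style = torch.randn(2, 256, 16, 16)    # e.g. low-quality encoder features

out = adaptive_instance_normalization(content, style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
print(torch.allclose(out_mean, style_mean, atol=1e-4))  # True: means now match the style features
print(torch.allclose(out_std, style_std, atol=1e-4))    # True: stds now match the style features
```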
diff --git a/repositories/codeformer/basicsr/archs/rrdbnet_arch.py b/repositories/codeformer/basicsr/archs/rrdbnet_arch.py
new file mode 100644
index 000000000..49a2d6c20
--- /dev/null
+++ b/repositories/codeformer/basicsr/archs/rrdbnet_arch.py
@@ -0,0 +1,119 @@
+import torch
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from .arch_util import default_init_weights, make_layer, pixel_unshuffle
+
+
+class ResidualDenseBlock(nn.Module):
+ """Residual Dense Block.
+
+ Used in RRDB block in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat=64, num_grow_ch=32):
+ super(ResidualDenseBlock, self).__init__()
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ # initialization
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
+
+ def forward(self, x):
+ x1 = self.lrelu(self.conv1(x))
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
+        # Empirically, we use 0.2 to scale the residual for better performance
+ return x5 * 0.2 + x
+
+
+class RRDB(nn.Module):
+ """Residual in Residual Dense Block.
+
+ Used in RRDB-Net in ESRGAN.
+
+ Args:
+ num_feat (int): Channel number of intermediate features.
+ num_grow_ch (int): Channels for each growth.
+ """
+
+ def __init__(self, num_feat, num_grow_ch=32):
+ super(RRDB, self).__init__()
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
+
+ def forward(self, x):
+ out = self.rdb1(x)
+ out = self.rdb2(out)
+ out = self.rdb3(out)
+        # Empirically, we use 0.2 to scale the residual for better performance
+ return out * 0.2 + x
+
+
+@ARCH_REGISTRY.register()
+class RRDBNet(nn.Module):
+ """Networks consisting of Residual in Residual Dense Block, which is used
+ in ESRGAN.
+
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
+
+ We extend ESRGAN for scale x2 and scale x1.
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
+    We first employ pixel-unshuffle (an inverse operation of pixel shuffle) to reduce the spatial size
+    and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
+
+ Args:
+ num_in_ch (int): Channel number of inputs.
+ num_out_ch (int): Channel number of outputs.
+ num_feat (int): Channel number of intermediate features.
+ Default: 64
+ num_block (int): Block number in the trunk network. Defaults: 23
+ num_grow_ch (int): Channels for each growth. Default: 32.
+ """
+
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
+ super(RRDBNet, self).__init__()
+ self.scale = scale
+ if scale == 2:
+ num_in_ch = num_in_ch * 4
+ elif scale == 1:
+ num_in_ch = num_in_ch * 16
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ # upsample
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+
+ def forward(self, x):
+ if self.scale == 2:
+ feat = pixel_unshuffle(x, scale=2)
+ elif self.scale == 1:
+ feat = pixel_unshuffle(x, scale=4)
+ else:
+ feat = x
+ feat = self.conv_first(feat)
+ body_feat = self.conv_body(self.body(feat))
+ feat = feat + body_feat
+ # upsample
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
+ return out
\ No newline at end of file
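A quick shape sketch for `RRDBNet` above: with `scale=4` the input passes straight to the trunk, and the two nearest-neighbour upsampling stages each double the resolution. The small `num_block` is only to keep the example cheap.

```
# Sketch: RRDBNet upscales a low-resolution image by the configured factor.
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet

net = RRDBNet(num_in_ch=3, num_out_ch=3, scale=4, num_feat=64, num_block=2, num_grow_ch=32).eval()
with torch.no_grad():
    sr = net(torch.randn(1, 3, 32, 32))
print(sr.shape)  # torch.Size([1, 3, 128, 128]) -- 4x the input resolution
```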
diff --git a/repositories/codeformer/basicsr/archs/vgg_arch.py b/repositories/codeformer/basicsr/archs/vgg_arch.py
new file mode 100644
index 000000000..23bb0103c
--- /dev/null
+++ b/repositories/codeformer/basicsr/archs/vgg_arch.py
@@ -0,0 +1,161 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+ 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
+
+
+@ARCH_REGISTRY.register()
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+    In this implementation, we allow users to choose whether to normalize the
+    input feature and which type of VGG network to use. Note that the pretrained
+    path must fit the vgg type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image. Importantly,
+            the input feature must be in the range [0, 1]. Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = {}
+
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
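
As an illustrative sketch (editorial annotation, not part of the patch): the feature extractor above can be driven roughly as follows, assuming `basicsr` is importable and torchvision can fetch the vgg19 weights (or VGG_PRETRAIN_PATH exists).

import torch
from basicsr.archs.vgg_arch import VGGFeatureExtractor

# request a few intermediate activations; inputs are expected in [0, 1]
extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'],
                                vgg_type='vgg19', use_input_norm=True, requires_grad=False)
x = torch.rand(1, 3, 224, 224)
features = extractor(x)  # dict mapping layer name -> feature tensor
for name, feat in features.items():
    print(name, tuple(feat.shape))
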
diff --git a/repositories/codeformer/basicsr/archs/vqgan_arch.py b/repositories/codeformer/basicsr/archs/vqgan_arch.py
new file mode 100644
index 000000000..5ac692633
--- /dev/null
+++ b/repositories/codeformer/basicsr/archs/vqgan_arch.py
@@ -0,0 +1,434 @@
+'''
+VQGAN code, adapted from the original created by the Unleashing Transformers authors:
+https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
+
+'''
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import ARCH_REGISTRY
+
+def normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+@torch.jit.script
+def swish(x):
+ return x*torch.sigmoid(x)
+
+
+# Define VQVAE classes
+class VectorQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.emb_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
+ 2 * torch.matmul(z_flattened, self.embedding.weight.t())
+
+ mean_distance = torch.mean(d)
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+ # min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
+ # [0-1], higher score, higher confidence
+ # min_encoding_scores = torch.exp(-min_encoding_scores/10)
+
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, {
+ "perplexity": perplexity,
+ "min_encodings": min_encodings,
+ "min_encoding_indices": min_encoding_indices,
+ "mean_distance": mean_distance
+ }
+
+ def get_codebook_feat(self, indices, shape):
+ # input indices: batch*token_num -> (batch*token_num)*1
+ # shape: batch, height, width, channel
+ indices = indices.view(-1,1)
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
+ min_encodings.scatter_(1, indices, 1)
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None: # reshape back to match original input shape
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantizer(nn.Module):
+ def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
+ super().__init__()
+ self.codebook_size = codebook_size # number of embeddings
+ self.emb_dim = emb_dim # dimension of embedding
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+ self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
+ self.embed = nn.Embedding(codebook_size, emb_dim)
+
+ def forward(self, z):
+ hard = self.straight_through if self.training else True
+
+ logits = self.proj(z)
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
+
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
+
+ return z_q, diff, {
+ "min_encoding_indices": min_encoding_indices
+ }
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
+
+ def forward(self, x):
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ return x
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
+ x = self.conv(x)
+
+ return x
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels=None):
+ super(ResBlock, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = normalize(in_channels)
+        self.conv1 = nn.Conv2d(in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+        self.norm2 = normalize(self.out_channels)
+        self.conv2 = nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+ if self.in_channels != self.out_channels:
+ self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, x_in):
+ x = x_in
+ x = self.norm1(x)
+ x = swish(x)
+ x = self.conv1(x)
+ x = self.norm2(x)
+ x = swish(x)
+ x = self.conv2(x)
+ if self.in_channels != self.out_channels:
+ x_in = self.conv_out(x_in)
+
+ return x + x_in
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = normalize(in_channels)
+ self.q = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.k = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.v = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+ self.proj_out = torch.nn.Conv2d(
+ in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0
+ )
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h*w)
+ q = q.permute(0, 2, 1)
+ k = k.reshape(b, c, h*w)
+ w_ = torch.bmm(q, k)
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = F.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h*w)
+ w_ = w_.permute(0, 2, 1)
+ h_ = torch.bmm(v, w_)
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class Encoder(nn.Module):
+ def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.attn_resolutions = attn_resolutions
+
+ curr_res = self.resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+
+ blocks = []
+        # initial convolution
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
+
+ # residual and downsampling blocks, with attention on smaller res (16x16)
+ for i in range(self.num_resolutions):
+ block_in_ch = nf * in_ch_mult[i]
+ block_out_ch = nf * ch_mult[i]
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+ if curr_res in attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != self.num_resolutions - 1:
+ blocks.append(Downsample(block_in_ch))
+ curr_res = curr_res // 2
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ # normalise and convert to latent size
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+class Generator(nn.Module):
+ def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
+ super().__init__()
+ self.nf = nf
+ self.ch_mult = ch_mult
+ self.num_resolutions = len(self.ch_mult)
+ self.num_res_blocks = res_blocks
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.in_channels = emb_dim
+ self.out_channels = 3
+ block_in_ch = self.nf * self.ch_mult[-1]
+ curr_res = self.resolution // 2 ** (self.num_resolutions-1)
+
+ blocks = []
+ # initial conv
+ blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
+
+ # non-local attention block
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+ blocks.append(AttnBlock(block_in_ch))
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
+
+ for i in reversed(range(self.num_resolutions)):
+ block_out_ch = self.nf * self.ch_mult[i]
+
+ for _ in range(self.num_res_blocks):
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
+ block_in_ch = block_out_ch
+
+ if curr_res in self.attn_resolutions:
+ blocks.append(AttnBlock(block_in_ch))
+
+ if i != 0:
+ blocks.append(Upsample(block_in_ch))
+ curr_res = curr_res * 2
+
+ blocks.append(normalize(block_in_ch))
+ blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
+
+ self.blocks = nn.ModuleList(blocks)
+
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+
+ return x
+
+
+@ARCH_REGISTRY.register()
+class VQAutoEncoder(nn.Module):
+ def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
+ beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
+ super().__init__()
+ logger = get_root_logger()
+ self.in_channels = 3
+ self.nf = nf
+ self.n_blocks = res_blocks
+ self.codebook_size = codebook_size
+ self.embed_dim = emb_dim
+ self.ch_mult = ch_mult
+ self.resolution = img_size
+ self.attn_resolutions = attn_resolutions
+ self.quantizer_type = quantizer
+ self.encoder = Encoder(
+ self.in_channels,
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+ if self.quantizer_type == "nearest":
+ self.beta = beta #0.25
+ self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
+ elif self.quantizer_type == "gumbel":
+ self.gumbel_num_hiddens = emb_dim
+ self.straight_through = gumbel_straight_through
+ self.kl_weight = gumbel_kl_weight
+ self.quantize = GumbelQuantizer(
+ self.codebook_size,
+ self.embed_dim,
+ self.gumbel_num_hiddens,
+ self.straight_through,
+ self.kl_weight
+ )
+ self.generator = Generator(
+ self.nf,
+ self.embed_dim,
+ self.ch_mult,
+ self.n_blocks,
+ self.resolution,
+ self.attn_resolutions
+ )
+
+        if model_path is not None:
+            chkpt = torch.load(model_path, map_location='cpu')
+            if 'params_ema' in chkpt:
+                self.load_state_dict(chkpt['params_ema'])
+                logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
+            elif 'params' in chkpt:
+                self.load_state_dict(chkpt['params'])
+                logger.info(f'vqgan is loaded from: {model_path} [params]')
+            else:
+                raise ValueError(f"Wrong params in checkpoint {model_path}: expected 'params_ema' or 'params'.")
+
+
+ def forward(self, x):
+ x = self.encoder(x)
+ quant, codebook_loss, quant_stats = self.quantize(x)
+ x = self.generator(quant)
+ return x, codebook_loss, quant_stats
+
+
+
+# patch based discriminator
+@ARCH_REGISTRY.register()
+class VQGANDiscriminator(nn.Module):
+ def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
+ super().__init__()
+
+ layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
+ ndf_mult = 1
+ ndf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n, 8)
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ ndf_mult_prev = ndf_mult
+ ndf_mult = min(2 ** n_layers, 8)
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(ndf * ndf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ layers += [
+ nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
+ self.main = nn.Sequential(*layers)
+
+        if model_path is not None:
+            chkpt = torch.load(model_path, map_location='cpu')
+            if 'params_d' in chkpt:
+                self.load_state_dict(chkpt['params_d'])
+            elif 'params' in chkpt:
+                self.load_state_dict(chkpt['params'])
+            else:
+                raise ValueError(f"Wrong params in checkpoint {model_path}: expected 'params_d' or 'params'.")
+
+ def forward(self, x):
+ return self.main(x)
\ No newline at end of file
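
A minimal usage sketch for the autoencoder above (editorial annotation, not part of the patch); the hyperparameters are illustrative toy values, not the CodeFormer training configuration.

import torch
from basicsr.archs.vqgan_arch import VQAutoEncoder

# toy setup: 64x64 inputs, nearest-neighbour codebook lookup
vqgan = VQAutoEncoder(img_size=64, nf=32, ch_mult=[1, 2, 2, 4], quantizer='nearest',
                      res_blocks=1, attn_resolutions=[16], codebook_size=512, emb_dim=64)
x = torch.rand(1, 3, 64, 64)
recon, codebook_loss, stats = vqgan(x)
print(recon.shape)                          # torch.Size([1, 3, 64, 64])
print(stats['min_encoding_indices'].shape)  # one codebook index per latent position
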
diff --git a/repositories/codeformer/basicsr/data/__init__.py b/repositories/codeformer/basicsr/data/__init__.py
new file mode 100644
index 000000000..c6adb4bb6
--- /dev/null
+++ b/repositories/codeformer/basicsr/data/__init__.py
@@ -0,0 +1,100 @@
+import importlib
+import numpy as np
+import random
+import torch
+import torch.utils.data
+from copy import deepcopy
+from functools import partial
+from os import path as osp
+
+from basicsr.data.prefetch_dataloader import PrefetchDataLoader
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.dist_util import get_dist_info
+from basicsr.utils.registry import DATASET_REGISTRY
+
+__all__ = ['build_dataset', 'build_dataloader']
+
+# automatically scan and import dataset modules for registry
+# scan all the files under the data folder with '_dataset' in file names
+data_folder = osp.dirname(osp.abspath(__file__))
+dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
+# import all the dataset modules
+_dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
+
+
+def build_dataset(dataset_opt):
+ """Build dataset from options.
+
+ Args:
+        dataset_opt (dict): Configuration for dataset. It must contain:
+ name (str): Dataset name.
+ type (str): Dataset type.
+ """
+ dataset_opt = deepcopy(dataset_opt)
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
+ logger = get_root_logger()
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
+ return dataset
+
+
+def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
+ """Build dataloader.
+
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset.
+ dataset_opt (dict): Dataset options. It contains the following keys:
+ phase (str): 'train' or 'val'.
+ num_worker_per_gpu (int): Number of workers for each GPU.
+ batch_size_per_gpu (int): Training batch size for each GPU.
+ num_gpu (int): Number of GPUs. Used only in the train phase.
+ Default: 1.
+ dist (bool): Whether in distributed training. Used only in the train
+ phase. Default: False.
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
+ seed (int | None): Seed. Default: None
+ """
+ phase = dataset_opt['phase']
+ rank, _ = get_dist_info()
+ if phase == 'train':
+ if dist: # distributed training
+ batch_size = dataset_opt['batch_size_per_gpu']
+ num_workers = dataset_opt['num_worker_per_gpu']
+ else: # non-distributed training
+ multiplier = 1 if num_gpu == 0 else num_gpu
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
+ dataloader_args = dict(
+ dataset=dataset,
+ batch_size=batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=sampler,
+ drop_last=True)
+ if sampler is None:
+ dataloader_args['shuffle'] = True
+ dataloader_args['worker_init_fn'] = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ elif phase in ['val', 'test']: # validation
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
+ else:
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
+
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
+
+ prefetch_mode = dataset_opt.get('prefetch_mode')
+ if prefetch_mode == 'cpu': # CPUPrefetcher
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
+ logger = get_root_logger()
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: ' f'num_prefetch_queue = {num_prefetch_queue}')
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
+ else:
+ # prefetch_mode=None: Normal dataloader
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
+ return torch.utils.data.DataLoader(**dataloader_args)
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+ # Set the worker seed to num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
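
A rough sketch of how the dataloader factory above is typically driven from an options dict (not part of the patch; it assumes the package's dataset modules import cleanly, and any map-style torch Dataset works here).

import torch
from torch.utils.data import TensorDataset
from basicsr.data import build_dataloader

dataset = TensorDataset(torch.randn(16, 3, 8, 8))
dataset_opt = {'phase': 'train', 'batch_size_per_gpu': 4, 'num_worker_per_gpu': 0}
loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None)
for batch in loader:
    pass  # a training step would consume `batch` here
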
diff --git a/repositories/codeformer/basicsr/data/data_sampler.py b/repositories/codeformer/basicsr/data/data_sampler.py
new file mode 100644
index 000000000..575452d9f
--- /dev/null
+++ b/repositories/codeformer/basicsr/data/data_sampler.py
@@ -0,0 +1,48 @@
+import math
+import torch
+from torch.utils.data.sampler import Sampler
+
+
+class EnlargedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ Modified from torch.utils.data.distributed.DistributedSampler
+ Support enlarging the dataset for iteration-based training, for saving
+    Supports enlarging the dataset for iteration-based training, to save
+    time when restarting the dataloader after each epoch.
+ Args:
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
+ num_replicas (int | None): Number of processes participating in
+ the training. It is usually the world_size.
+ rank (int | None): Rank of the current process within num_replicas.
+ ratio (int): Enlarging ratio. Default: 1.
+ """
+
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(self.total_size, generator=g).tolist()
+
+ dataset_size = len(self.dataset)
+ indices = [v % dataset_size for v in indices]
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
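
For reference, a single-process sketch of the sampler above (not part of the patch); ratio > 1 simply repeats the shuffled dataset so one 'epoch' spans more iterations.

import torch
from torch.utils.data import TensorDataset
from basicsr.data.data_sampler import EnlargedSampler

dataset = TensorDataset(torch.zeros(10, 1))
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=2)
sampler.set_epoch(0)
indices = list(iter(sampler))
print(len(indices))  # 20: the dataset is enlarged 2x; shuffling is deterministic per epoch
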
diff --git a/repositories/codeformer/basicsr/data/data_util.py b/repositories/codeformer/basicsr/data/data_util.py
new file mode 100644
index 000000000..63b1bce8e
--- /dev/null
+++ b/repositories/codeformer/basicsr/data/data_util.py
@@ -0,0 +1,305 @@
+import cv2
+import numpy as np
+import torch
+from os import path as osp
+from torch.nn import functional as F
+
+from basicsr.data.transforms import mod_crop
+from basicsr.utils import img2tensor, scandir
+
+
+def read_img_seq(path, require_mod_crop=False, scale=1):
+ """Read a sequence of images from a given folder path.
+
+ Args:
+ path (list[str] | str): List of image paths or image folder path.
+ require_mod_crop (bool): Require mod crop for each image.
+ Default: False.
+ scale (int): Scale factor for mod_crop. Default: 1.
+
+ Returns:
+ Tensor: size (t, c, h, w), RGB, [0, 1].
+ """
+ if isinstance(path, list):
+ img_paths = path
+ else:
+ img_paths = sorted(list(scandir(path, full_path=True)))
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+ if require_mod_crop:
+ imgs = [mod_crop(img, scale) for img in imgs]
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
+ imgs = torch.stack(imgs, dim=0)
+ return imgs
+
+
+def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
+ """Generate an index list for reading `num_frames` frames from a sequence
+ of images.
+
+ Args:
+ crt_idx (int): Current center index.
+ max_frame_num (int): Max number of the sequence of images (from 1).
+ num_frames (int): Reading num_frames frames.
+ padding (str): Padding mode, one of
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
+ Examples: current_idx = 0, num_frames = 5
+ The generated frame indices under different padding mode:
+ replicate: [0, 0, 0, 1, 2]
+ reflection: [2, 1, 0, 1, 2]
+ reflection_circle: [4, 3, 0, 1, 2]
+ circle: [3, 4, 0, 1, 2]
+
+ Returns:
+ list[int]: A list of indices.
+ """
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
+
+ max_frame_num = max_frame_num - 1 # start from 0
+ num_pad = num_frames // 2
+
+ indices = []
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
+ if i < 0:
+ if padding == 'replicate':
+ pad_idx = 0
+ elif padding == 'reflection':
+ pad_idx = -i
+ elif padding == 'reflection_circle':
+ pad_idx = crt_idx + num_pad - i
+ else:
+ pad_idx = num_frames + i
+ elif i > max_frame_num:
+ if padding == 'replicate':
+ pad_idx = max_frame_num
+ elif padding == 'reflection':
+ pad_idx = max_frame_num * 2 - i
+ elif padding == 'reflection_circle':
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
+ else:
+ pad_idx = i - num_frames
+ else:
+ pad_idx = i
+ indices.append(pad_idx)
+ return indices
+
+
+def paired_paths_from_lmdb(folders, keys):
+ """Generate paired paths from lmdb files.
+
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
+
+ lq.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records
+ 1)image name (with extension),
+ 2)image shape,
+ 3)compression level, separated by a white space.
+ Example: `baboon.png (120,125,3) 1`
+
+ We use the image name without extension as the lmdb key.
+ Note that we use the same key for the corresponding lq and gt images.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ Note that this key is different from lmdb keys.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
+        raise ValueError(f'{input_key} folder and {gt_key} folder should both be in lmdb '
+                         f'format. But received {input_key}: {input_folder}; '
+ f'{gt_key}: {gt_folder}')
+ # ensure that the two meta_info files are the same
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
+ else:
+ paths = []
+ for lmdb_key in sorted(input_lmdb_keys):
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
+ return paths
+
+
+def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
+    """Generate paired paths from a meta information file.
+
+ Each line in the meta information file contains the image names and
+ image shape (usually for gt), separated by a white space.
+
+    Example of a meta information file:
+ ```
+ 0001_s001.png (480,480,3)
+ 0001_s002.png (480,480,3)
+ ```
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ meta_info_file (str): Path to the meta information file.
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ with open(meta_info_file, 'r') as fin:
+ gt_names = [line.split(' ')[0] for line in fin]
+
+ paths = []
+ for gt_name in gt_names:
+ basename, ext = osp.splitext(osp.basename(gt_name))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ gt_path = osp.join(gt_folder, gt_name)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paired_paths_from_folder(folders, keys, filename_tmpl):
+ """Generate paired paths from folders.
+
+ Args:
+ folders (list[str]): A list of folder path. The order of list should
+ be [input_folder, gt_folder].
+ keys (list[str]): A list of keys identifying folders. The order should
+            be consistent with folders, e.g., ['lq', 'gt'].
+ filename_tmpl (str): Template for each filename. Note that the
+ template excludes the file extension. Usually the filename_tmpl is
+ for files in the input folder.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
+ f'But got {len(folders)}')
+ assert len(keys) == 2, ('The len of keys should be 2 with [input_key, gt_key]. ' f'But got {len(keys)}')
+ input_folder, gt_folder = folders
+ input_key, gt_key = keys
+
+ input_paths = list(scandir(input_folder))
+ gt_paths = list(scandir(gt_folder))
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
+ f'{len(input_paths)}, {len(gt_paths)}.')
+ paths = []
+ for gt_path in gt_paths:
+ basename, ext = osp.splitext(osp.basename(gt_path))
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
+ input_path = osp.join(input_folder, input_name)
+ assert input_name in input_paths, (f'{input_name} is not in ' f'{input_key}_paths.')
+ gt_path = osp.join(gt_folder, gt_path)
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
+ return paths
+
+
+def paths_from_folder(folder):
+ """Generate paths from folder.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+
+ paths = list(scandir(folder))
+ paths = [osp.join(folder, path) for path in paths]
+ return paths
+
+
+def paths_from_lmdb(folder):
+ """Generate paths from lmdb.
+
+ Args:
+ folder (str): Folder path.
+
+ Returns:
+ list[str]: Returned path list.
+ """
+ if not folder.endswith('.lmdb'):
+        raise ValueError(f'Folder {folder} should be in lmdb format.')
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
+ paths = [line.split('.')[0] for line in fin]
+ return paths
+
+
+def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
+ """Generate Gaussian kernel used in `duf_downsample`.
+
+ Args:
+ kernel_size (int): Kernel size. Default: 13.
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
+
+ Returns:
+ np.array: The Gaussian kernel.
+ """
+ from scipy.ndimage import filters as filters
+ kernel = np.zeros((kernel_size, kernel_size))
+ # set element at the middle to one, a dirac delta
+ kernel[kernel_size // 2, kernel_size // 2] = 1
+ # gaussian-smooth the dirac, resulting in a gaussian filter
+ return filters.gaussian_filter(kernel, sigma)
+
+
+def duf_downsample(x, kernel_size=13, scale=4):
+    """Downsampling with Gaussian kernel used in the DUF official code.
+
+ Args:
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
+ kernel_size (int): Kernel size. Default: 13.
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
+ Default: 4.
+
+ Returns:
+ Tensor: DUF downsampled frames.
+ """
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
+
+ squeeze_flag = False
+ if x.ndim == 4:
+ squeeze_flag = True
+ x = x.unsqueeze(0)
+ b, t, c, h, w = x.size()
+ x = x.view(-1, 1, h, w)
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
+
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
+ x = F.conv2d(x, gaussian_filter, stride=scale)
+ x = x[:, :, 2:-2, 2:-2]
+ x = x.view(b, t, c, x.size(2), x.size(3))
+ if squeeze_flag:
+ x = x.squeeze(0)
+ return x
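
Two of the helpers above in isolation, as a hedged sketch (not part of the patch; `generate_gaussian_kernel` additionally needs scipy at call time).

import torch
from basicsr.data.data_util import generate_frame_indices, duf_downsample

# frame-index padding at the start of a sequence
print(generate_frame_indices(0, max_frame_num=100, num_frames=5, padding='reflection'))
# -> [2, 1, 0, 1, 2]

# Gaussian 4x downsampling of a (b, t, c, h, w) clip
frames = torch.rand(1, 5, 3, 64, 64)
lq = duf_downsample(frames, kernel_size=13, scale=4)
print(lq.shape)  # torch.Size([1, 5, 3, 16, 16])
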
diff --git a/repositories/codeformer/basicsr/data/prefetch_dataloader.py b/repositories/codeformer/basicsr/data/prefetch_dataloader.py
new file mode 100644
index 000000000..508842505
--- /dev/null
+++ b/repositories/codeformer/basicsr/data/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Ref:
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Ref:
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Number of prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
+
+
+class CPUPrefetcher():
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+ """CUDA prefetcher.
+
+ Ref:
+ https://github.com/NVIDIA/apex/issues/304#
+
+    It may consume more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return None
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
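
The CPU prefetcher above is consumed with an explicit pull loop rather than a for-loop; a minimal sketch (not part of the patch), where the loader can be any DataLoader.

import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.randn(8, 3, 4, 4)), batch_size=2)
prefetcher = CPUPrefetcher(loader)

batch = prefetcher.next()
while batch is not None:
    # a training step would consume `batch` here
    batch = prefetcher.next()
prefetcher.reset()  # rewind for the next epoch
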
diff --git a/repositories/codeformer/basicsr/data/transforms.py b/repositories/codeformer/basicsr/data/transforms.py
new file mode 100644
index 000000000..aead9dc73
--- /dev/null
+++ b/repositories/codeformer/basicsr/data/transforms.py
@@ -0,0 +1,165 @@
+import cv2
+import random
+
+
+def mod_crop(img, scale):
+ """Mod crop images, used during testing.
+
+ Args:
+ img (ndarray): Input image.
+ scale (int): Scale factor.
+
+ Returns:
+ ndarray: Result image.
+ """
+ img = img.copy()
+ if img.ndim in (2, 3):
+ h, w = img.shape[0], img.shape[1]
+ h_remainder, w_remainder = h % scale, w % scale
+ img = img[:h - h_remainder, :w - w_remainder, ...]
+ else:
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
+ return img
+
+
+def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path):
+ """Paired random crop.
+
+ It crops lists of lq and gt images with corresponding locations.
+
+ Args:
+ img_gts (list[ndarray] | ndarray): GT images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
+ should have the same shape. If the input is an ndarray, it will
+ be transformed to a list containing itself.
+ gt_patch_size (int): GT patch size.
+ scale (int): Scale factor.
+ gt_path (str): Path to ground-truth.
+
+ Returns:
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
+ only have one element, just return ndarray.
+ """
+
+ if not isinstance(img_gts, list):
+ img_gts = [img_gts]
+ if not isinstance(img_lqs, list):
+ img_lqs = [img_lqs]
+
+ h_lq, w_lq, _ = img_lqs[0].shape
+ h_gt, w_gt, _ = img_gts[0].shape
+ lq_patch_size = gt_patch_size // scale
+
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
+        raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
+                         f'multiplication of LQ ({h_lq}, {w_lq}).')
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.')
+
+ # randomly choose top and left coordinates for lq patch
+ top = random.randint(0, h_lq - lq_patch_size)
+ left = random.randint(0, w_lq - lq_patch_size)
+
+ # crop lq patch
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+
+ # crop corresponding gt patch
+ top_gt, left_gt = int(top * scale), int(left * scale)
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ if len(img_gts) == 1:
+ img_gts = img_gts[0]
+ if len(img_lqs) == 1:
+ img_lqs = img_lqs[0]
+ return img_gts, img_lqs
+
+
+def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
+
+ We use vertical flip and transpose for rotation implementation.
+ All the images in the list use the same augmentation.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
+ is an ndarray, it will be transformed to a list.
+ hflip (bool): Horizontal flip. Default: True.
+        rotation (bool): Rotation. Default: True.
+        flows (list[ndarray]): Flows to be augmented. If the input is an
+ ndarray, it will be transformed to a list.
+ Dimension is (h, w, 2). Default: None.
+ return_status (bool): Return the status of flip and rotation.
+ Default: False.
+
+ Returns:
+ list[ndarray] | ndarray: Augmented images and flows. If returned
+ results only have one element, just return ndarray.
+
+ """
+ hflip = hflip and random.random() < 0.5
+ vflip = rotation and random.random() < 0.5
+ rot90 = rotation and random.random() < 0.5
+
+ def _augment(img):
+ if hflip: # horizontal
+ cv2.flip(img, 1, img)
+ if vflip: # vertical
+ cv2.flip(img, 0, img)
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ def _augment_flow(flow):
+ if hflip: # horizontal
+ cv2.flip(flow, 1, flow)
+ flow[:, :, 0] *= -1
+ if vflip: # vertical
+ cv2.flip(flow, 0, flow)
+ flow[:, :, 1] *= -1
+ if rot90:
+ flow = flow.transpose(1, 0, 2)
+ flow = flow[:, :, [1, 0]]
+ return flow
+
+ if not isinstance(imgs, list):
+ imgs = [imgs]
+ imgs = [_augment(img) for img in imgs]
+ if len(imgs) == 1:
+ imgs = imgs[0]
+
+ if flows is not None:
+ if not isinstance(flows, list):
+ flows = [flows]
+ flows = [_augment_flow(flow) for flow in flows]
+ if len(flows) == 1:
+ flows = flows[0]
+ return imgs, flows
+ else:
+ if return_status:
+ return imgs, (hflip, vflip, rot90)
+ else:
+ return imgs
+
+
+def img_rotate(img, angle, center=None, scale=1.0):
+ """Rotate image.
+
+ Args:
+ img (ndarray): Image to be rotated.
+ angle (float): Rotation angle in degrees. Positive values mean
+ counter-clockwise rotation.
+ center (tuple[int]): Rotation center. If the center is None,
+ initialize it as the center of the image. Default: None.
+ scale (float): Isotropic scale factor. Default: 1.0.
+ """
+ (h, w) = img.shape[:2]
+
+ if center is None:
+ center = (w // 2, h // 2)
+
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
+ rotated_img = cv2.warpAffine(img, matrix, (w, h))
+ return rotated_img
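
A small sketch of the paired crop and augmentation pipeline above (not part of the patch; random arrays stand in for real image pairs, and `gt_path` is only used in error messages).

import numpy as np
from basicsr.data.transforms import paired_random_crop, augment, img_rotate

gt = np.random.rand(128, 128, 3).astype(np.float32)  # ground truth
lq = np.random.rand(32, 32, 3).astype(np.float32)    # 4x downscaled input
gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=64, scale=4, gt_path='demo.png')
gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)
rotated = img_rotate(gt_patch, angle=30)
print(gt_patch.shape, lq_patch.shape, rotated.shape)  # (64, 64, 3) (16, 16, 3) (64, 64, 3)
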
diff --git a/repositories/codeformer/basicsr/losses/__init__.py b/repositories/codeformer/basicsr/losses/__init__.py
new file mode 100644
index 000000000..2b184e74c
--- /dev/null
+++ b/repositories/codeformer/basicsr/losses/__init__.py
@@ -0,0 +1,26 @@
+from copy import deepcopy
+
+from basicsr.utils import get_root_logger
+from basicsr.utils.registry import LOSS_REGISTRY
+from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
+ gradient_penalty_loss, r1_penalty)
+
+__all__ = [
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
+ 'r1_penalty', 'g_path_regularize'
+]
+
+
+def build_loss(opt):
+ """Build loss from options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Loss type.
+ """
+ opt = deepcopy(opt)
+ loss_type = opt.pop('type')
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
+ logger = get_root_logger()
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
+ return loss
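
The factory above resolves a loss class from LOSS_REGISTRY by its `type` key; a quick sketch (not part of the patch; importing `basicsr.losses` pulls in `lpips`, so that package must be installed).

import torch
from basicsr.losses import build_loss

l1 = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
pred, target = torch.rand(2, 3, 8, 8), torch.rand(2, 3, 8, 8)
print(l1(pred, target))  # scalar tensor
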
diff --git a/repositories/codeformer/basicsr/losses/loss_util.py b/repositories/codeformer/basicsr/losses/loss_util.py
new file mode 100644
index 000000000..744eeb46d
--- /dev/null
+++ b/repositories/codeformer/basicsr/losses/loss_util.py
@@ -0,0 +1,95 @@
+import functools
+from torch.nn import functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are 'none', 'mean' and 'sum'.
+
+ Returns:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ else:
+ return loss.sum()
+
+
+def weight_reduce_loss(loss, weight=None, reduction='mean'):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights. Default: None.
+ reduction (str): Same as built-in losses of PyTorch. Options are
+ 'none', 'mean' and 'sum'. Default: 'mean'.
+
+ Returns:
+ Tensor: Loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ assert weight.dim() == loss.dim()
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
+ loss = loss * weight
+
+ # if weight is not specified or reduction is sum, just reduce the loss
+ if weight is None or reduction == 'sum':
+ loss = reduce_loss(loss, reduction)
+ # if reduction is mean, then compute mean over weight region
+ elif reduction == 'mean':
+ if weight.size(1) > 1:
+ weight = weight.sum()
+ else:
+ weight = weight.sum() * loss.size(1)
+ loss = loss.sum() / weight
+
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.5000)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, reduction='sum')
+ tensor(3.)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction)
+ return loss
+
+ return wrapper
diff --git a/repositories/codeformer/basicsr/losses/losses.py b/repositories/codeformer/basicsr/losses/losses.py
new file mode 100644
index 000000000..1bcf272cf
--- /dev/null
+++ b/repositories/codeformer/basicsr/losses/losses.py
@@ -0,0 +1,455 @@
+import math
+import lpips
+import torch
+from torch import autograd as autograd
+from torch import nn as nn
+from torch.nn import functional as F
+
+from basicsr.archs.vgg_arch import VGGFeatureExtractor
+from basicsr.utils.registry import LOSS_REGISTRY
+from .loss_util import weighted_loss
+
+_reduction_modes = ['none', 'mean', 'sum']
+
+
+@weighted_loss
+def l1_loss(pred, target):
+ return F.l1_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@weighted_loss
+def charbonnier_loss(pred, target, eps=1e-12):
+ return torch.sqrt((pred - target)**2 + eps)
+
+
+@LOSS_REGISTRY.register()
+class L1Loss(nn.Module):
+ """L1 (mean absolute error, MAE) loss.
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(L1Loss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class MSELoss(nn.Module):
+ """MSE (L2) loss.
+
+ Args:
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean'):
+ super(MSELoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class CharbonnierLoss(nn.Module):
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
+ variant of L1Loss).
+
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
+ Super-Resolution".
+
+ Args:
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
+ reduction (str): Specifies the reduction to apply to the output.
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
+ eps (float): A value used to control the curvature near zero.
+ Default: 1e-12.
+ """
+
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
+ super(CharbonnierLoss, self).__init__()
+ if reduction not in ['none', 'mean', 'sum']:
+ raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
+
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, pred, target, weight=None, **kwargs):
+ """
+ Args:
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
+ weights. Default: None.
+ """
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
+
+
+@LOSS_REGISTRY.register()
+class WeightedTVLoss(L1Loss):
+ """Weighted TV loss.
+
+ Args:
+ loss_weight (float): Loss weight. Default: 1.0.
+ """
+
+ def __init__(self, loss_weight=1.0):
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
+
+    def forward(self, pred, weight=None):
+        if weight is None:
+            y_weight = None
+            x_weight = None
+        else:
+            # crop the weight map to match the shifted differences
+            y_weight = weight[:, :, :-1, :]
+            x_weight = weight[:, :, :, :-1]
+        y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
+        x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
+
+        loss = x_diff + y_diff
+
+        return loss
+
+
+@LOSS_REGISTRY.register()
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+            1.0 in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+            loss will be calculated and the loss will be multiplied by the
+ weight. Default: 1.0.
+ style_weight (float): If `style_weight > 0`, the style loss will be
+            calculated and the loss will be multiplied by the weight.
+ Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.,
+ criterion='l1'):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm)
+
+ self.criterion_type = criterion
+ if self.criterion_type == 'l1':
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == 'l2':
+            self.criterion = torch.nn.MSELoss()  # PyTorch has no L2loss class; MSE is the L2 criterion
+ elif self.criterion_type == 'mse':
+ self.criterion = torch.nn.MSELoss(reduction='mean')
+ elif self.criterion_type == 'fro':
+ self.criterion = None
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ style_loss += torch.norm(
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+ else:
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+ gt_features[k])) * self.layer_weights[k]
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
+
+
+@LOSS_REGISTRY.register()
+class LPIPSLoss(nn.Module):
+ def __init__(self,
+ loss_weight=1.0,
+ use_input_norm=True,
+ range_norm=False,):
+ super(LPIPSLoss, self).__init__()
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+ self.loss_weight = loss_weight
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, pred, target):
+ if self.range_norm:
+ pred = (pred + 1) / 2
+ target = (target + 1) / 2
+ if self.use_input_norm:
+ pred = (pred - self.mean) / self.std
+ target = (target - self.mean) / self.std
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+ return self.loss_weight * lpips_loss.mean()
+
+
+@LOSS_REGISTRY.register()
+class GANLoss(nn.Module):
+ """Define GAN loss.
+
+ Args:
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
+ real_label_val (float): The value for real label. Default: 1.0.
+ fake_label_val (float): The value for fake label. Default: 0.0.
+ loss_weight (float): Loss weight. Default: 1.0.
+ Note that loss_weight is only for generators; and it is always 1.0
+ for discriminators.
+ """
+
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
+ super(GANLoss, self).__init__()
+ self.gan_type = gan_type
+ self.loss_weight = loss_weight
+ self.real_label_val = real_label_val
+ self.fake_label_val = fake_label_val
+
+ if self.gan_type == 'vanilla':
+ self.loss = nn.BCEWithLogitsLoss()
+ elif self.gan_type == 'lsgan':
+ self.loss = nn.MSELoss()
+ elif self.gan_type == 'wgan':
+ self.loss = self._wgan_loss
+ elif self.gan_type == 'wgan_softplus':
+ self.loss = self._wgan_softplus_loss
+ elif self.gan_type == 'hinge':
+ self.loss = nn.ReLU()
+ else:
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
+
+ def _wgan_loss(self, input, target):
+ """wgan loss.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return -input.mean() if target else input.mean()
+
+ def _wgan_softplus_loss(self, input, target):
+        """wgan loss with softplus. Softplus is a smooth approximation to the
+ ReLU function.
+
+ In StyleGAN2, it is called:
+ Logistic loss for discriminator;
+ Non-saturating loss for generator.
+
+ Args:
+ input (Tensor): Input tensor.
+ target (bool): Target label.
+
+ Returns:
+ Tensor: wgan loss.
+ """
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
+
+ def get_target_label(self, input, target_is_real):
+ """Get target label.
+
+ Args:
+ input (Tensor): Input tensor.
+ target_is_real (bool): Whether the target is real or fake.
+
+ Returns:
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
+ return Tensor.
+ """
+
+ if self.gan_type in ['wgan', 'wgan_softplus']:
+ return target_is_real
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+ return input.new_ones(input.size()) * target_val
+
+ def forward(self, input, target_is_real, is_disc=False):
+ """
+ Args:
+ input (Tensor): The input for the loss module, i.e., the network
+ prediction.
+            target_is_real (bool): Whether the target is real or fake.
+ is_disc (bool): Whether the loss for discriminators or not.
+ Default: False.
+
+ Returns:
+ Tensor: GAN loss value.
+ """
+ if self.gan_type == 'hinge':
+ if is_disc: # for discriminators in hinge-gan
+ input = -input if target_is_real else input
+ loss = self.loss(1 + input).mean()
+ else: # for generators in hinge-gan
+ loss = -input.mean()
+ else: # other gan types
+ target_label = self.get_target_label(input, target_is_real)
+ loss = self.loss(input, target_label)
+
+ # loss_weight is always 1.0 for discriminators
+ return loss if is_disc else loss * self.loss_weight
+
+
+def r1_penalty(real_pred, real_img):
+ """R1 regularization for discriminator. The core idea is to
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Ref:
+ Eq. 9 in Which training methods for GANs do actually converge.
+ """
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
+ return grad_penalty
+
+
+def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
+
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
+
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
+
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
+
+
+def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
+ """Calculate gradient penalty for wgan-gp.
+
+ Args:
+ discriminator (nn.Module): Network for the discriminator.
+ real_data (Tensor): Real input data.
+ fake_data (Tensor): Fake input data.
+ weight (Tensor): Weight tensor. Default: None.
+
+ Returns:
+ Tensor: A tensor for gradient penalty.
+ """
+
+ batch_size = real_data.size(0)
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
+
+ # interpolate between real_data and fake_data
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
+
+ disc_interpolates = discriminator(interpolates)
+ gradients = autograd.grad(
+ outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones_like(disc_interpolates),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True)[0]
+
+ if weight is not None:
+ gradients = gradients * weight
+
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+ if weight is not None:
+ gradients_penalty /= torch.mean(weight)
+
+ return gradients_penalty
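
For orientation, a minimal sketch of how these loss utilities are typically combined in a GAN training step; it assumes GANLoss and r1_penalty from the hunk above are in scope, and the discriminator below is a stand-in, not part of the patch:

    import torch
    from torch import nn

    # stand-in discriminator: any module mapping images to a single logit works
    disc = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 1))
    gan_loss = GANLoss(gan_type='vanilla', loss_weight=0.1)

    real = torch.rand(4, 3, 32, 32, requires_grad=True)  # requires_grad so r1_penalty can differentiate
    fake = torch.rand(4, 3, 32, 32)

    # discriminator side: loss_weight is ignored (is_disc=True), plus R1 on real images
    d_loss = (gan_loss(disc(real), True, is_disc=True)
              + gan_loss(disc(fake.detach()), False, is_disc=True)
              + 10 * r1_penalty(disc(real), real))

    # generator side: loss_weight (0.1 here) is applied only on this branch
    g_loss = gan_loss(disc(fake), True, is_disc=False)
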
diff --git a/repositories/codeformer/basicsr/metrics/__init__.py b/repositories/codeformer/basicsr/metrics/__init__.py
new file mode 100644
index 000000000..19d55cc83
--- /dev/null
+++ b/repositories/codeformer/basicsr/metrics/__init__.py
@@ -0,0 +1,19 @@
+from copy import deepcopy
+
+from basicsr.utils.registry import METRIC_REGISTRY
+from .psnr_ssim import calculate_psnr, calculate_ssim
+
+__all__ = ['calculate_psnr', 'calculate_ssim']
+
+
+def calculate_metric(data, opt):
+ """Calculate metric from data and options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+            type (str): Metric type (the registered metric name).
+ """
+ opt = deepcopy(opt)
+ metric_type = opt.pop('type')
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
+ return metric
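
A hedged usage sketch of the dispatch above: calculate_metric pops 'type' from opt and forwards everything else, together with data, to the registered metric (here the registered calculate_psnr; both names assumed in scope):

    import numpy as np

    data = {
        'img1': np.random.randint(0, 256, (64, 64, 3)).astype(np.float64),
        'img2': np.random.randint(0, 256, (64, 64, 3)).astype(np.float64),
    }
    opt = {'type': 'calculate_psnr', 'crop_border': 4, 'test_y_channel': False}
    psnr = calculate_metric(data, opt)  # -> METRIC_REGISTRY.get('calculate_psnr')(**data, **opt)
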
diff --git a/repositories/codeformer/basicsr/metrics/metric_util.py b/repositories/codeformer/basicsr/metrics/metric_util.py
new file mode 100644
index 000000000..4d18f0f78
--- /dev/null
+++ b/repositories/codeformer/basicsr/metrics/metric_util.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from basicsr.utils.matlab_functions import bgr2ycbcr
+
+
+def reorder_image(img, input_order='HWC'):
+ """Reorder images to 'HWC' order.
+
+ If the input_order is (h, w), return (h, w, 1);
+ If the input_order is (c, h, w), return (h, w, c);
+ If the input_order is (h, w, c), return as it is.
+
+ Args:
+ img (ndarray): Input image.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ If the input image shape is (h, w), input_order will not have
+ effects. Default: 'HWC'.
+
+ Returns:
+ ndarray: reordered image.
+ """
+
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
+ if len(img.shape) == 2:
+ img = img[..., None]
+ if input_order == 'CHW':
+ img = img.transpose(1, 2, 0)
+ return img
+
+
+def to_y_channel(img):
+ """Change to Y channel of YCbCr.
+
+ Args:
+ img (ndarray): Images with range [0, 255].
+
+ Returns:
+        (ndarray): Images with range [0, 255] (float type) without rounding.
+ """
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 3 and img.shape[2] == 3:
+ img = bgr2ycbcr(img, y_only=True)
+ img = img[..., None]
+ return img * 255.
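
A small sketch of the two helpers above, assuming they are in scope: reorder_image normalizes layouts to HWC before metrics are computed, and to_y_channel converts a BGR image to its Y channel while keeping the [0, 255] range:

    import numpy as np

    chw = np.random.rand(3, 16, 16) * 255.            # channel-first input
    hwc = reorder_image(chw, input_order='CHW')       # -> shape (16, 16, 3)
    gray = reorder_image(np.random.rand(16, 16))      # (h, w) -> (16, 16, 1)
    y = to_y_channel(hwc)                             # BGR -> Y, float, still in [0, 255]
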
diff --git a/repositories/codeformer/basicsr/metrics/psnr_ssim.py b/repositories/codeformer/basicsr/metrics/psnr_ssim.py
new file mode 100644
index 000000000..bbd950699
--- /dev/null
+++ b/repositories/codeformer/basicsr/metrics/psnr_ssim.py
@@ -0,0 +1,128 @@
+import cv2
+import numpy as np
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.registry import METRIC_REGISTRY
+
+
+@METRIC_REGISTRY.register()
+def calculate_psnr(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the PSNR calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: psnr result.
+ """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def _ssim(img1, img2):
+ """Calculate SSIM (structural similarity) for one channel images.
+
+ It is called by func:`calculate_ssim`.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255] with order 'HWC'.
+ img2 (ndarray): Images with range [0, 255] with order 'HWC'.
+
+ Returns:
+ float: ssim result.
+ """
+
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+@METRIC_REGISTRY.register()
+def calculate_ssim(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
+ """Calculate SSIM (structural similarity).
+
+ Ref:
+ Image quality assessment: From error visibility to structural similarity
+
+ The results are the same as that of the official released MATLAB code in
+ https://ece.uwaterloo.ca/~z70wang/research/ssim/.
+
+ For three-channel images, SSIM is calculated for each channel and then
+ averaged.
+
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+ pixels are not involved in the SSIM calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: ssim result.
+ """
+
+    assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img1 = reorder_image(img1, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img1 = to_y_channel(img1)
+ img2 = to_y_channel(img2)
+
+ ssims = []
+ for i in range(img1.shape[2]):
+ ssims.append(_ssim(img1[..., i], img2[..., i]))
+ return np.array(ssims).mean()
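
A hedged example of the two registered metrics above (needs numpy and opencv-python; the function names are assumed to be in scope):

    import numpy as np

    gt = np.random.randint(0, 256, (128, 128, 3)).astype(np.float64)
    noisy = np.clip(gt + np.random.normal(0, 5, gt.shape), 0, 255)

    print(calculate_psnr(gt, noisy, crop_border=4))                       # in dB, higher is better
    print(calculate_ssim(gt, noisy, crop_border=4, test_y_channel=True))  # in [0, 1], higher is better
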
diff --git a/repositories/codeformer/basicsr/models/__init__.py b/repositories/codeformer/basicsr/models/__init__.py
new file mode 100644
index 000000000..00bde45f0
--- /dev/null
+++ b/repositories/codeformer/basicsr/models/__init__.py
@@ -0,0 +1,30 @@
+import importlib
+from copy import deepcopy
+from os import path as osp
+
+from basicsr.utils import get_root_logger, scandir
+from basicsr.utils.registry import MODEL_REGISTRY
+
+__all__ = ['build_model']
+
+# automatically scan and import model modules for registry
+# scan all the files under the 'models' folder and collect files ending with
+# '_model.py'
+model_folder = osp.dirname(osp.abspath(__file__))
+model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
+# import all the model modules
+_model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
+
+
+def build_model(opt):
+ """Build model from options.
+
+ Args:
+        opt (dict): Configuration. It must contain:
+ model_type (str): Model type.
+ """
+ opt = deepcopy(opt)
+ model = MODEL_REGISTRY.get(opt['model_type'])(opt)
+ logger = get_root_logger()
+ logger.info(f'Model [{model.__class__.__name__}] is created.')
+ return model
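
For reference, a sketch of the options dict build_model expects; the model_type value below is hypothetical and must match a class registered with @MODEL_REGISTRY.register() in one of the scanned *_model.py files:

    opt = {
        'model_type': 'SRModel',  # hypothetical registered name
        'is_train': False,
        # ... remaining keys are consumed by the model class itself
    }
    model = build_model(opt)      # logs: Model [SRModel] is created.
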
diff --git a/repositories/.placeholder b/repositories/codeformer/basicsr/ops/__init__.py
similarity index 100%
rename from repositories/.placeholder
rename to repositories/codeformer/basicsr/ops/__init__.py
diff --git a/repositories/codeformer/basicsr/ops/dcn/__init__.py b/repositories/codeformer/basicsr/ops/dcn/__init__.py
new file mode 100644
index 000000000..32e3592f8
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/dcn/__init__.py
@@ -0,0 +1,7 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
+ modulated_deform_conv)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
+ 'modulated_deform_conv'
+]
diff --git a/repositories/codeformer/basicsr/ops/dcn/deform_conv.py b/repositories/codeformer/basicsr/ops/dcn/deform_conv.py
new file mode 100644
index 000000000..734154f9e
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/dcn/deform_conv.py
@@ -0,0 +1,377 @@
+import math
+import torch
+from torch import nn as nn
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn import functional as F
+from torch.nn.modules.utils import _pair, _single
+
+try:
+ from . import deform_conv_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ deform_conv_ext = load(
+ 'deform_conv',
+ sources=[
+ os.path.join(module_path, 'src', 'deform_conv_ext.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda.cpp'),
+ os.path.join(module_path, 'src', 'deform_conv_cuda_kernel.cu'),
+ ],
+ )
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64):
+ if input is not None and input.dim() != 4:
+ raise ValueError(f'Expected 4D tensor as input, got {input.dim()}' 'D tensor instead.')
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(DeformConvFunction._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_ext.deform_conv_forward(input, weight,
+ offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
+ ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
+ ctx.deformable_groups, cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
+ ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+ cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError('convolution input is too small (output would be ' f'{"x".join(map(str, output_size))})')
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+ or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, \
+ f'in_channels {in_channels} is not divisible by groups {groups}'
+ assert out_channels % groups == 0, \
+ f'out_channels {out_channels} is not divisible ' \
+ f'by groups {groups}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+ return out
+
+
+class DeformConvPack(DeformConv):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
+ self.deformable_groups)
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ dilation=_pair(self.dilation),
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ super(ModulatedDeformConvPack, self).init_weights()
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
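
A hedged sketch of the DCNv2 wrapper above. It only runs when the deform_conv_ext CUDA extension is available (pre-built, or JIT-compiled by setting BASICSR_JIT=True as handled at the top of this file) and a CUDA device is present:

    import torch

    # ModulatedDeformConvPack predicts offsets and masks internally via conv_offset
    dcn = ModulatedDeformConvPack(16, 32, kernel_size=3, padding=1, deformable_groups=1).cuda()
    x = torch.randn(2, 16, 64, 64, device='cuda')
    y = dcn(x)
    print(y.shape)  # torch.Size([2, 32, 64, 64])
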
diff --git a/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 000000000..5d9424908
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,685 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <torch/extension.h>
+#include <ATen/ATen.h>
+
+#include <cmath>
+#include <vector>
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+    const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+    const int kernel_h, const int kernel_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ TORCH_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ TORCH_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ TORCH_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ TORCH_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ TORCH_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ TORCH_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ TORCH_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ TORCH_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ TORCH_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ TORCH_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ TORCH_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ TORCH_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ TORCH_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+  TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+    // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}
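
The same output-size arithmetic recurs throughout this file (shape_check and both forward paths) and in ModulatedDeformConvFunction._infer_shape on the Python side; written out as a standalone helper for reference (standard convolution geometry, not part of the patch):

    def conv_output_size(size, kernel, stride, pad, dilation):
        return (size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1

    assert conv_output_size(64, kernel=3, stride=1, pad=1, dilation=1) == 64
    assert conv_output_size(64, kernel=3, stride=2, pad=1, dilation=1) == 32
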
diff --git a/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
new file mode 100644
index 000000000..98752dccf
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -0,0 +1,867 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <THC/THCAtomics.cuh>
+#include <stdio.h>
+#include <math.h>
+#include <float.h>
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+template <typename scalar_t>
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
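+// im2col for deformable conv: each thread handles one (channel, batch, y, x)
+// output position and writes its kernel_h * kernel_w offset-shifted samples
+// into the column buffer.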
+template <typename scalar_t>
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const scalar_t map_h = i * dilation_h + offset_h;
+ //const scalar_t map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
+void deformable_im2col(
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h, const int ksize_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ // todo: check parallel_imgs is correctly passed in
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+ deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+ height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+ }
+}
+
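+// col2im for deformable conv: scatters gradients from the column buffer back
+// onto the input feature map; atomicAdd is needed because several column
+// entries can map to the same input pixel.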
+template <typename scalar_t>
+__global__ void deformable_col2im_gpu_kernel(
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+ 2 * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index];
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+void deformable_col2im(
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im)
+{
+
+ // todo: make sure parallel_imgs is passed in correctly
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+ deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+ }
+}
+
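+// Gradient of the column buffer with respect to the sampling offsets; each
+// thread accumulates over all input channels of its deformable group.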
+template <typename scalar_t>
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+ const scalar_t *data_im, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col, scalar_t *grad_offset)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ const scalar_t weight = get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos];
+ cnt += 1;
+ }
+
+ grad_offset[index] = val;
+ }
+}
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+
+ deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+ height_col, width_col, grad_offset_);
+ }));
+}
+
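+// The dmcn_* helpers below mirror the plain deformable helpers above and are
+// used by the modulated (DCNv2) kernels.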
+template <typename scalar_t>
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template <typename scalar_t>
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
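+// Modulated im2col: same gather as the plain version, but every sampled value
+// is additionally scaled by the learned modulation mask.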
+template <typename scalar_t>
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t val = static_cast<scalar_t>(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const float map_h = i * dilation_h + offset_h;
+ //const float map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val * mask;
+ data_col_ptr += batch_size * height_col * width_col;
+ //data_col_ptr += height_col * width_col;
+ }
+ }
+ }
+}
+
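+// Modulated col2im: same atomicAdd scatter as the plain version, with the
+// incoming gradient pre-multiplied by the modulation mask.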
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
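+// Computes gradients for both the offsets and the modulation mask; the mask
+// gradient is written once per kernel tap, from the even (h-offset) channel.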
+template <typename scalar_t>
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_im,
+ const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_offset, scalar_t *grad_mask)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ else
+ {
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+ }
+ const scalar_t weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+ }
+}
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+
+ modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, channels, deformable_group, height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor grad_im)
+{
+
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *grad_im_ = grad_im.data_ptr<scalar_t>();
+
+ modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ at::Tensor grad_offset, at::Tensor grad_mask)
+{
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data_ptr<scalar_t>();
+ const scalar_t *data_im_ = data_im.data_ptr<scalar_t>();
+ const scalar_t *data_offset_ = data_offset.data_ptr<scalar_t>();
+ const scalar_t *data_mask_ = data_mask.data_ptr<scalar_t>();
+ scalar_t *grad_offset_ = grad_offset.data_ptr<scalar_t>();
+ scalar_t *grad_mask_ = grad_mask.data_ptr<scalar_t>();
+
+ modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS>>>(
+ num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+ grad_offset_, grad_mask_);
+ }));
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
diff --git a/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp
new file mode 100644
index 000000000..41c6df6f7
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/dcn/src/deform_conv_ext.cpp
@@ -0,0 +1,164 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include <ATen/DeviceGuard.h>
+#include <torch/extension.h>
+
+#include <cmath>
+#include <vector>
+
+#define WITH_CUDA // always use cuda
+#ifdef WITH_CUDA
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step);
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step);
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias);
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias);
+#endif
+
+int deform_conv_forward(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_forward_cuda(input, weight, offset, output, columns,
+ ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group,
+ deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_input(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_input_cuda(input, offset, gradOutput,
+ gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH,
+ dilationW, dilationH, group, deformable_group, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+int deform_conv_backward_parameters(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return deform_conv_backward_parameters_cuda(input, offset, gradOutput,
+ gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW,
+ dilationH, group, deformable_group, scale, im2col_step);
+#else
+ AT_ERROR("deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_forward(input, weight, bias, ones,
+ offset, mask, output, columns, kernel_h, kernel_w, stride_h,
+ stride_w, pad_h, pad_w, dilation_h, dilation_w, group,
+ deformable_group, with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+void modulated_deform_conv_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ if (input.device().is_cuda()) {
+#ifdef WITH_CUDA
+ return modulated_deform_conv_cuda_backward(input, weight, bias, ones,
+ offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset,
+ grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w,
+ pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+ with_bias);
+#else
+ AT_ERROR("modulated deform conv is not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("modulated deform conv is not implemented on CPU");
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_conv_forward", &deform_conv_forward,
+ "deform forward");
+ m.def("deform_conv_backward_input", &deform_conv_backward_input,
+ "deform_conv_backward_input");
+ m.def("deform_conv_backward_parameters",
+ &deform_conv_backward_parameters,
+ "deform_conv_backward_parameters");
+ m.def("modulated_deform_conv_forward",
+ &modulated_deform_conv_forward,
+ "modulated deform conv forward");
+ m.def("modulated_deform_conv_backward",
+ &modulated_deform_conv_backward,
+ "modulated deform conv backward");
+}
diff --git a/repositories/codeformer/basicsr/ops/fused_act/__init__.py b/repositories/codeformer/basicsr/ops/fused_act/__init__.py
new file mode 100644
index 000000000..241dc0754
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/fused_act/__init__.py
@@ -0,0 +1,3 @@
+from .fused_act import FusedLeakyReLU, fused_leaky_relu
+
+__all__ = ['FusedLeakyReLU', 'fused_leaky_relu']
diff --git a/repositories/codeformer/basicsr/ops/fused_act/fused_act.py b/repositories/codeformer/basicsr/ops/fused_act/fused_act.py
new file mode 100644
index 000000000..588f815e5
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/fused_act/fused_act.py
@@ -0,0 +1,89 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
+
+import torch
+from torch import nn
+from torch.autograd import Function
+
+try:
+ from . import fused_act_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ fused_act_ext = load(
+ 'fused',
+ sources=[
+ os.path.join(module_path, 'src', 'fused_bias_act.cpp'),
+ os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'),
+ ],
+ )
+
+
+class FusedLeakyReLUFunctionBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, out, negative_slope, scale):
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ empty = grad_output.new_empty(0)
+
+ grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale)
+
+ dim = [0]
+
+ if grad_input.ndim > 2:
+ dim += list(range(2, grad_input.ndim))
+
+ grad_bias = grad_input.sum(dim).detach()
+
+ return grad_input, grad_bias
+
+ @staticmethod
+ def backward(ctx, gradgrad_input, gradgrad_bias):
+ out, = ctx.saved_tensors
+ gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
+ ctx.scale)
+
+ return gradgrad_out, None, None, None
+
+
+class FusedLeakyReLUFunction(Function):
+
+ @staticmethod
+ def forward(ctx, input, bias, negative_slope, scale):
+ empty = input.new_empty(0)
+ out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
+ ctx.save_for_backward(out)
+ ctx.negative_slope = negative_slope
+ ctx.scale = scale
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ out, = ctx.saved_tensors
+
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
+
+ return grad_input, grad_bias, None, None
+
+
+class FusedLeakyReLU(nn.Module):
+
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
+ super().__init__()
+
+ self.bias = nn.Parameter(torch.zeros(channel))
+ self.negative_slope = negative_slope
+ self.scale = scale
+
+ def forward(self, input):
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
+
+
+def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp
new file mode 100644
index 000000000..85ed0a79f
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act.cpp
@@ -0,0 +1,26 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor fused_bias_act(const torch::Tensor& input,
+ const torch::Tensor& bias,
+ const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(bias);
+
+ return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
+}
diff --git a/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
new file mode 100644
index 000000000..54c7ff53c
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/fused_act/src/fused_bias_act_kernel.cu
@@ -0,0 +1,100 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/extension.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+
+template <typename scalar_t>
+static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
+ int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
+ int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
+
+ scalar_t zero = 0.0;
+
+ for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
+ scalar_t x = p_x[xi];
+
+ if (use_bias) {
+ x += p_b[(xi / step_b) % size_b];
+ }
+
+ scalar_t ref = use_ref ? p_ref[xi] : zero;
+
+ scalar_t y;
+
+ switch (act * 10 + grad) {
+ default:
+ case 10: y = x; break;
+ case 11: y = x; break;
+ case 12: y = 0.0; break;
+
+ case 30: y = (x > 0.0) ? x : x * alpha; break;
+ case 31: y = (ref > 0.0) ? x : x * alpha; break;
+ case 32: y = 0.0; break;
+ }
+
+ out[xi] = y * scale;
+ }
+}
+
+
+torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
+ int act, int grad, float alpha, float scale) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ auto x = input.contiguous();
+ auto b = bias.contiguous();
+ auto ref = refer.contiguous();
+
+ int use_bias = b.numel() ? 1 : 0;
+ int use_ref = ref.numel() ? 1 : 0;
+
+ int size_x = x.numel();
+ int size_b = b.numel();
+ int step_b = 1;
+
+ for (int i = 1 + 1; i < x.dim(); i++) {
+ step_b *= x.size(i);
+ }
+
+ int loop_x = 4;
+ int block_size = 4 * 32;
+ int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
+
+ auto y = torch::empty_like(x);
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
+ fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ y.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ b.data_ptr<scalar_t>(),
+ ref.data_ptr<scalar_t>(),
+ act,
+ grad,
+ alpha,
+ scale,
+ loop_x,
+ size_x,
+ step_b,
+ size_b,
+ use_bias,
+ use_ref
+ );
+ });
+
+ return y;
+}
diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/__init__.py b/repositories/codeformer/basicsr/ops/upfirdn2d/__init__.py
new file mode 100644
index 000000000..397e85bea
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/upfirdn2d/__init__.py
@@ -0,0 +1,3 @@
+from .upfirdn2d import upfirdn2d
+
+__all__ = ['upfirdn2d']
diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
new file mode 100644
index 000000000..43d0b6783
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d.cpp
@@ -0,0 +1,24 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp
+#include <torch/extension.h>
+
+
+torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1);
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
+
+torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
+ int up_x, int up_y, int down_x, int down_y,
+ int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
+ CHECK_CUDA(input);
+ CHECK_CUDA(kernel);
+
+ return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
+}
diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
new file mode 100644
index 000000000..8870063ba
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/upfirdn2d/src/upfirdn2d_kernel.cu
@@ -0,0 +1,370 @@
+// from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d_kernel.cu
+// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
+//
+// This work is made available under the Nvidia Source Code License-NC.
+// To view a copy of this license, visit
+// https://nvlabs.github.io/stylegan2/license.html
+
+#include <torch/extension.h>
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
+ int c = a / b;
+
+ if (c * b > a) {
+ c--;
+ }
+
+ return c;
+}
+
+struct UpFirDn2DKernelParams {
+ int up_x;
+ int up_y;
+ int down_x;
+ int down_y;
+ int pad_x0;
+ int pad_x1;
+ int pad_y0;
+ int pad_y1;
+
+ int major_dim;
+ int in_h;
+ int in_w;
+ int minor_dim;
+ int kernel_h;
+ int kernel_w;
+ int out_h;
+ int out_w;
+ int loop_major;
+ int loop_x;
+};
+
+template <typename scalar_t>
+__global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ int minor_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ int out_y = minor_idx / p.minor_dim;
+ minor_idx -= out_y * p.minor_dim;
+ int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (out_x_base >= p.out_w || out_y >= p.out_h ||
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0;
+ int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h);
+ int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y;
+ int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y;
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major && major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, out_x = out_x_base;
+ loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) {
+ int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0;
+ int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w);
+ int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x;
+ int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x;
+
+ const scalar_t *x_p =
+ &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim +
+ minor_idx];
+ const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x];
+ int x_px = p.minor_dim;
+ int k_px = -p.up_x;
+ int x_py = p.in_w * p.minor_dim;
+ int k_py = -p.up_y * p.kernel_w;
+
+ scalar_t v = 0.0f;
+
+ for (int y = 0; y < h; y++) {
+ for (int x = 0; x < w; x++) {
+ v += static_cast<scalar_t>(*x_p) * static_cast<scalar_t>(*k_p);
+ x_p += x_px;
+ k_p += k_px;
+ }
+
+ x_p += x_py - w * x_px;
+ k_p += k_py - w * k_px;
+ }
+
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+}
+
+template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
+__global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input,
+ const scalar_t *kernel,
+ const UpFirDn2DKernelParams p) {
+ const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
+ const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
+
+ __shared__ volatile float sk[kernel_h][kernel_w];
+ __shared__ volatile float sx[tile_in_h][tile_in_w];
+
+ int minor_idx = blockIdx.x;
+ int tile_out_y = minor_idx / p.minor_dim;
+ minor_idx -= tile_out_y * p.minor_dim;
+ tile_out_y *= tile_out_h;
+ int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
+ int major_idx_base = blockIdx.z * p.loop_major;
+
+ if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h |
+ major_idx_base >= p.major_dim) {
+ return;
+ }
+
+ for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w;
+ tap_idx += blockDim.x) {
+ int ky = tap_idx / kernel_w;
+ int kx = tap_idx - ky * kernel_w;
+ scalar_t v = 0.0;
+
+ if (kx < p.kernel_w & ky < p.kernel_h) {
+ v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
+ }
+
+ sk[ky][kx] = v;
+ }
+
+ for (int loop_major = 0, major_idx = major_idx_base;
+ loop_major < p.loop_major & major_idx < p.major_dim;
+ loop_major++, major_idx++) {
+ for (int loop_x = 0, tile_out_x = tile_out_x_base;
+ loop_x < p.loop_x & tile_out_x < p.out_w;
+ loop_x++, tile_out_x += tile_out_w) {
+ int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
+ int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
+ int tile_in_x = floor_div(tile_mid_x, up_x);
+ int tile_in_y = floor_div(tile_mid_y, up_y);
+
+ __syncthreads();
+
+ for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w;
+ in_idx += blockDim.x) {
+ int rel_in_y = in_idx / tile_in_w;
+ int rel_in_x = in_idx - rel_in_y * tile_in_w;
+ int in_x = rel_in_x + tile_in_x;
+ int in_y = rel_in_y + tile_in_y;
+
+ scalar_t v = 0.0;
+
+ if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
+ v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) *
+ p.minor_dim +
+ minor_idx];
+ }
+
+ sx[rel_in_y][rel_in_x] = v;
+ }
+
+ __syncthreads();
+ for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w;
+ out_idx += blockDim.x) {
+ int rel_out_y = out_idx / tile_out_w;
+ int rel_out_x = out_idx - rel_out_y * tile_out_w;
+ int out_x = rel_out_x + tile_out_x;
+ int out_y = rel_out_y + tile_out_y;
+
+ int mid_x = tile_mid_x + rel_out_x * down_x;
+ int mid_y = tile_mid_y + rel_out_y * down_y;
+ int in_x = floor_div(mid_x, up_x);
+ int in_y = floor_div(mid_y, up_y);
+ int rel_in_x = in_x - tile_in_x;
+ int rel_in_y = in_y - tile_in_y;
+ int kernel_x = (in_x + 1) * up_x - mid_x - 1;
+ int kernel_y = (in_y + 1) * up_y - mid_y - 1;
+
+ scalar_t v = 0.0;
+
+#pragma unroll
+ for (int y = 0; y < kernel_h / up_y; y++)
+#pragma unroll
+ for (int x = 0; x < kernel_w / up_x; x++)
+ v += sx[rel_in_y + y][rel_in_x + x] *
+ sk[kernel_y + y * up_y][kernel_x + x * up_x];
+
+ if (out_x < p.out_w & out_y < p.out_h) {
+ out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim +
+ minor_idx] = v;
+ }
+ }
+ }
+ }
+}
+
+torch::Tensor upfirdn2d_op(const torch::Tensor &input,
+ const torch::Tensor &kernel, int up_x, int up_y,
+ int down_x, int down_y, int pad_x0, int pad_x1,
+ int pad_y0, int pad_y1) {
+ int curDevice = -1;
+ cudaGetDevice(&curDevice);
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
+
+ UpFirDn2DKernelParams p;
+
+ auto x = input.contiguous();
+ auto k = kernel.contiguous();
+
+ p.major_dim = x.size(0);
+ p.in_h = x.size(1);
+ p.in_w = x.size(2);
+ p.minor_dim = x.size(3);
+ p.kernel_h = k.size(0);
+ p.kernel_w = k.size(1);
+ p.up_x = up_x;
+ p.up_y = up_y;
+ p.down_x = down_x;
+ p.down_y = down_y;
+ p.pad_x0 = pad_x0;
+ p.pad_x1 = pad_x1;
+ p.pad_y0 = pad_y0;
+ p.pad_y1 = pad_y1;
+
+ p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) /
+ p.down_y;
+ p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) /
+ p.down_x;
+
+ auto out =
+ at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
+
+ int mode = -1;
+
+ int tile_out_h = -1;
+ int tile_out_w = -1;
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 1;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 3 && p.kernel_w <= 3) {
+ mode = 2;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 3;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 4;
+ tile_out_h = 16;
+ tile_out_w = 64;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 4 && p.kernel_w <= 4) {
+ mode = 5;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 &&
+ p.kernel_h <= 2 && p.kernel_w <= 2) {
+ mode = 6;
+ tile_out_h = 8;
+ tile_out_w = 32;
+ }
+
+ dim3 block_size;
+ dim3 grid_size;
+
+ if (tile_out_h > 0 && tile_out_w > 0) {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 1;
+ block_size = dim3(32 * 8, 1, 1);
+ grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
+ (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ } else {
+ p.loop_major = (p.major_dim - 1) / 16384 + 1;
+ p.loop_x = 4;
+ block_size = dim3(4, 32, 1);
+ grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1,
+ (p.out_w - 1) / (p.loop_x * block_size.y) + 1,
+ (p.major_dim - 1) / p.loop_major + 1);
+ }
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
+ switch (mode) {
+ case 1:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 2:
+ upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 3:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 4:
+ upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 5:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ case 6:
+ upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32>
+ <<<grid_size, block_size, 0, stream>>>(out.data_ptr<scalar_t>(),
+ x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+
+ break;
+
+ default:
+ upfirdn2d_kernel_large<scalar_t><<<grid_size, block_size, 0, stream>>>(
+ out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
+ k.data_ptr<scalar_t>(), p);
+ }
+ });
+
+ return out;
+}
diff --git a/repositories/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py b/repositories/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py
new file mode 100644
index 000000000..667f96e1d
--- /dev/null
+++ b/repositories/codeformer/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -0,0 +1,186 @@
+# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
+
+import torch
+from torch.autograd import Function
+from torch.nn import functional as F
+
+try:
+ from . import upfirdn2d_ext
+except ImportError:
+ import os
+ BASICSR_JIT = os.getenv('BASICSR_JIT')
+ if BASICSR_JIT == 'True':
+ from torch.utils.cpp_extension import load
+ module_path = os.path.dirname(__file__)
+ upfirdn2d_ext = load(
+ 'upfirdn2d',
+ sources=[
+ os.path.join(module_path, 'src', 'upfirdn2d.cpp'),
+ os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'),
+ ],
+ )
+
+
+class UpFirDn2dBackward(Function):
+
+ @staticmethod
+ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
+
+ up_x, up_y = up
+ down_x, down_y = down
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
+
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
+
+ grad_input = upfirdn2d_ext.upfirdn2d(
+ grad_output,
+ grad_kernel,
+ down_x,
+ down_y,
+ up_x,
+ up_y,
+ g_pad_x0,
+ g_pad_x1,
+ g_pad_y0,
+ g_pad_y1,
+ )
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
+
+ ctx.save_for_backward(kernel)
+
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ ctx.up_x = up_x
+ ctx.up_y = up_y
+ ctx.down_x = down_x
+ ctx.down_y = down_y
+ ctx.pad_x0 = pad_x0
+ ctx.pad_x1 = pad_x1
+ ctx.pad_y0 = pad_y0
+ ctx.pad_y1 = pad_y1
+ ctx.in_size = in_size
+ ctx.out_size = out_size
+
+ return grad_input
+
+ @staticmethod
+ def backward(ctx, gradgrad_input):
+ kernel, = ctx.saved_tensors
+
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
+
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
+ gradgrad_input,
+ kernel,
+ ctx.up_x,
+ ctx.up_y,
+ ctx.down_x,
+ ctx.down_y,
+ ctx.pad_x0,
+ ctx.pad_x1,
+ ctx.pad_y0,
+ ctx.pad_y1,
+ )
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
+ # ctx.out_size[1], ctx.in_size[3])
+ gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1])
+
+ return gradgrad_out, None, None, None, None, None, None, None, None
+
+
+class UpFirDn2d(Function):
+
+ @staticmethod
+ def forward(ctx, input, kernel, up, down, pad):
+ up_x, up_y = up
+ down_x, down_y = down
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
+
+ kernel_h, kernel_w = kernel.shape
+ batch, channel, in_h, in_w = input.shape
+ ctx.in_size = input.shape
+
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+ ctx.out_size = (out_h, out_w)
+
+ ctx.up = (up_x, up_y)
+ ctx.down = (down_x, down_y)
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
+
+ g_pad_x0 = kernel_w - pad_x0 - 1
+ g_pad_y0 = kernel_h - pad_y0 - 1
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
+
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
+
+ out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1)
+ # out = out.view(major, out_h, out_w, minor)
+ out = out.view(-1, channel, out_h, out_w)
+
+ return out
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ kernel, grad_kernel = ctx.saved_tensors
+
+ grad_input = UpFirDn2dBackward.apply(
+ grad_output,
+ kernel,
+ grad_kernel,
+ ctx.up,
+ ctx.down,
+ ctx.pad,
+ ctx.g_pad,
+ ctx.in_size,
+ ctx.out_size,
+ )
+
+ return grad_input, None, None, None, None
+
+
+def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
+ if input.device.type == 'cpu':
+ out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+ else:
+ out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]))
+
+ return out
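
A usage sketch for the `upfirdn2d()` wrapper above: on a CPU tensor it takes the `upfirdn2d_native` path, so no compiled extension is needed; the 4-tap kernel is just an illustrative separable blur, not something shipped with this module.

```python
import torch
from basicsr.ops.upfirdn2d.upfirdn2d import upfirdn2d

k1d = torch.tensor([1., 3., 3., 1.])
kernel = k1d[:, None] * k1d[None, :]
kernel = kernel / kernel.sum()                        # normalized 4x4 FIR kernel

x = torch.randn(1, 3, 64, 64)                         # (batch, channel, h, w)
y = upfirdn2d(x, kernel, up=2, down=1, pad=(1, 2))    # upsample by 2, then filter
print(y.shape)  # torch.Size([1, 3, 128, 128]) per the out_h/out_w formula in UpFirDn2d
```
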
+
+
+def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
+ _, channel, in_h, in_w = input.shape
+ input = input.reshape(-1, in_h, in_w, 1)
+
+ _, in_h, in_w, minor = input.shape
+ kernel_h, kernel_w = kernel.shape
+
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
+
+ out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
+ out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+
+ out = out.permute(0, 3, 1, 2)
+ out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
+ out = F.conv2d(out, w)
+ out = out.reshape(
+ -1,
+ minor,
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
+ )
+ out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+
+ return out.view(-1, channel, out_h, out_w)
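
Because the CPU branch reduces to `upfirdn2d_native`, which is built from differentiable torch ops, a quick gradient check is possible without the CUDA extension. This is only a sanity sketch; double precision keeps `gradcheck`'s tolerances happy.

```python
import torch
from basicsr.ops.upfirdn2d.upfirdn2d import upfirdn2d

kernel = (torch.ones(3, 3) / 9.0).double()
x = torch.randn(1, 1, 8, 8, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(
    lambda inp: upfirdn2d(inp, kernel, up=1, down=2, pad=(1, 1)),
    (x,), eps=1e-6, atol=1e-4)
print(ok)  # True
```
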
diff --git a/repositories/codeformer/basicsr/setup.py b/repositories/codeformer/basicsr/setup.py
new file mode 100644
index 000000000..382a2aa10
--- /dev/null
+++ b/repositories/codeformer/basicsr/setup.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python
+
+from setuptools import find_packages, setup
+
+import os
+import subprocess
+import sys
+import time
+import torch
+from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
+
+version_file = './basicsr/version.py'
+
+
+def readme():
+ with open('README.md', encoding='utf-8') as f:
+ content = f.read()
+ return content
+
+
+def get_git_hash():
+
+ def _minimal_ext_cmd(cmd):
+ # construct minimal environment
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH', 'HOME']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ # LANGUAGE is used on win32
+ env['LANGUAGE'] = 'C'
+ env['LANG'] = 'C'
+ env['LC_ALL'] = 'C'
+ out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ return out
+
+ try:
+ out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
+ sha = out.strip().decode('ascii')
+ except OSError:
+ sha = 'unknown'
+
+ return sha
+
+
+def get_hash():
+ if os.path.exists('.git'):
+ sha = get_git_hash()[:7]
+ elif os.path.exists(version_file):
+ try:
+ from version import __version__
+ sha = __version__.split('+')[-1]
+ except ImportError:
+ raise ImportError('Unable to get git version')
+ else:
+ sha = 'unknown'
+
+ return sha
+
+
+def write_version_py():
+ content = """# GENERATED VERSION FILE
+# TIME: {}
+__version__ = '{}'
+__gitsha__ = '{}'
+version_info = ({})
+"""
+ sha = get_hash()
+ with open('./basicsr/VERSION', 'r') as f:
+ SHORT_VERSION = f.read().strip()
+ VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
+
+ version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
+ with open(version_file, 'w') as f:
+ f.write(version_file_str)
+
+
+def get_version():
+ with open(version_file, 'r') as f:
+ exec(compile(f.read(), version_file, 'exec'))
+ return locals()['__version__']
+
+
+def make_cuda_ext(name, module, sources, sources_cuda=None):
+ if sources_cuda is None:
+ sources_cuda = []
+ define_macros = []
+ extra_compile_args = {'cxx': []}
+
+ if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
+ define_macros += [('WITH_CUDA', None)]
+ extension = CUDAExtension
+ extra_compile_args['nvcc'] = [
+ '-D__CUDA_NO_HALF_OPERATORS__',
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
+ '-D__CUDA_NO_HALF2_OPERATORS__',
+ ]
+ sources += sources_cuda
+ else:
+ print(f'Compiling {name} without CUDA')
+ extension = CppExtension
+
+ return extension(
+ name=f'{module}.{name}',
+ sources=[os.path.join(*module.split('.'), p) for p in sources],
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args)
+
+
+def get_requirements(filename='requirements.txt'):
+ with open(os.path.join('.', filename), 'r') as f:
+ requires = [line.replace('\n', '') for line in f.readlines()]
+ return requires
+
+
+if __name__ == '__main__':
+ if '--cuda_ext' in sys.argv:
+ ext_modules = [
+ make_cuda_ext(
+ name='deform_conv_ext',
+ module='ops.dcn',
+ sources=['src/deform_conv_ext.cpp'],
+ sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
+ make_cuda_ext(
+ name='fused_act_ext',
+ module='ops.fused_act',
+ sources=['src/fused_bias_act.cpp'],
+ sources_cuda=['src/fused_bias_act_kernel.cu']),
+ make_cuda_ext(
+ name='upfirdn2d_ext',
+ module='ops.upfirdn2d',
+ sources=['src/upfirdn2d.cpp'],
+ sources_cuda=['src/upfirdn2d_kernel.cu']),
+ ]
+ sys.argv.remove('--cuda_ext')
+ else:
+ ext_modules = []
+
+ write_version_py()
+ setup(
+ name='basicsr',
+ version=get_version(),
+ description='Open Source Image and Video Super-Resolution Toolbox',
+ long_description=readme(),
+ long_description_content_type='text/markdown',
+ author='Xintao Wang',
+ author_email='xintao.wang@outlook.com',
+ keywords='computer vision, restoration, super resolution',
+ url='https://github.com/xinntao/BasicSR',
+ include_package_data=True,
+ packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'License :: OSI Approved :: Apache Software License',
+ 'Operating System :: OS Independent',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ ],
+ license='Apache License 2.0',
+ setup_requires=['cython', 'numpy'],
+ install_requires=get_requirements(),
+ ext_modules=ext_modules,
+ cmdclass={'build_ext': BuildExtension},
+ zip_safe=False)
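
For reference, `write_version_py()` above simply renders the `content` template into `basicsr/version.py`. With a hypothetical `basicsr/VERSION` file containing `1.3.2`, the generated module would look roughly like this (timestamp and git sha are placeholders); `get_version()` later `exec()`s it to recover `__version__`.

```python
# GENERATED VERSION FILE
# TIME: Mon Mar 20 12:00:00 2023
__version__ = '1.3.2'
__gitsha__ = 'a1b2c3d'
version_info = (1, 3, 2)
```
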
diff --git a/repositories/codeformer/basicsr/train.py b/repositories/codeformer/basicsr/train.py
new file mode 100644
index 000000000..a01c0dfcc
--- /dev/null
+++ b/repositories/codeformer/basicsr/train.py
@@ -0,0 +1,225 @@
+import argparse
+import datetime
+import logging
+import math
+import copy
+import random
+import time
+import torch
+from os import path as osp
+
+from basicsr.data import build_dataloader, build_dataset
+from basicsr.data.data_sampler import EnlargedSampler
+from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
+from basicsr.models import build_model
+from basicsr.utils import (MessageLogger, check_resume, get_env_info, get_root_logger, init_tb_logger,
+ init_wandb_logger, make_exp_dirs, mkdir_and_rename, set_random_seed)
+from basicsr.utils.dist_util import get_dist_info, init_dist
+from basicsr.utils.options import dict2str, parse
+
+import warnings
+# ignore UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`.
+warnings.filterwarnings("ignore", category=UserWarning)
+
+def parse_options(root_path, is_train=True):
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
+ parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ opt = parse(args.opt, root_path, is_train=is_train)
+
+ # distributed settings
+ if args.launcher == 'none':
+ opt['dist'] = False
+ print('Disable distributed.', flush=True)
+ else:
+ opt['dist'] = True
+ if args.launcher == 'slurm' and 'dist_params' in opt:
+ init_dist(args.launcher, **opt['dist_params'])
+ else:
+ init_dist(args.launcher)
+
+ opt['rank'], opt['world_size'] = get_dist_info()
+
+ # random seed
+ seed = opt.get('manual_seed')
+ if seed is None:
+ seed = random.randint(1, 10000)
+ opt['manual_seed'] = seed
+ set_random_seed(seed + opt['rank'])
+
+ return opt
+
+
+def init_loggers(opt):
+ log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
+ logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
+ logger.info(get_env_info())
+ logger.info(dict2str(opt))
+
+ # initialize wandb logger before tensorboard logger to allow proper sync:
+ if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project') is not None):
+ assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
+ init_wandb_logger(opt)
+ tb_logger = None
+ if opt['logger'].get('use_tb_logger'):
+ tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
+ return logger, tb_logger
+
+
+def create_train_val_dataloader(opt, logger):
+ # create train and val dataloaders
+ train_loader, val_loader = None, None
+ for phase, dataset_opt in opt['datasets'].items():
+ if phase == 'train':
+ dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
+ train_set = build_dataset(dataset_opt)
+ train_sampler = EnlargedSampler(train_set, opt['world_size'], opt['rank'], dataset_enlarge_ratio)
+ train_loader = build_dataloader(
+ train_set,
+ dataset_opt,
+ num_gpu=opt['num_gpu'],
+ dist=opt['dist'],
+ sampler=train_sampler,
+ seed=opt['manual_seed'])
+
+ num_iter_per_epoch = math.ceil(
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
+ total_iters = int(opt['train']['total_iter'])
+ total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
+ logger.info('Training statistics:'
+ f'\n\tNumber of train images: {len(train_set)}'
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+
+ elif phase == 'val':
+ val_set = build_dataset(dataset_opt)
+ val_loader = build_dataloader(
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+ logger.info(f'Number of val images/folders in {dataset_opt["name"]}: ' f'{len(val_set)}')
+ else:
+ raise ValueError(f'Dataset phase {phase} is not recognized.')
+
+ return train_loader, train_sampler, val_loader, total_epochs, total_iters
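
A worked example of the iteration bookkeeping in `create_train_val_dataloader()` above, with hypothetical numbers:

```python
import math

num_images = 70_000            # hypothetical training-set size
dataset_enlarge_ratio = 1
batch_size_per_gpu = 4
world_size = 2                 # two GPUs
total_iter = 400_000

num_iter_per_epoch = math.ceil(num_images * dataset_enlarge_ratio / (batch_size_per_gpu * world_size))
total_epochs = math.ceil(total_iter / num_iter_per_epoch)
print(num_iter_per_epoch, total_epochs)  # 8750 46
```
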
+
+
+def train_pipeline(root_path):
+ # parse options, set distributed setting, set random seed
+ opt = parse_options(root_path, is_train=True)
+
+ torch.backends.cudnn.benchmark = True
+ # torch.backends.cudnn.deterministic = True
+
+ # load resume states if necessary
+ if opt['path'].get('resume_state'):
+ device_id = torch.cuda.current_device()
+ resume_state = torch.load(
+ opt['path']['resume_state'], map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ resume_state = None
+
+ # mkdir for experiments and logger
+ if resume_state is None:
+ make_exp_dirs(opt)
+ if opt['logger'].get('use_tb_logger') and opt['rank'] == 0:
+ mkdir_and_rename(osp.join('tb_logger', opt['name']))
+
+ # initialize loggers
+ logger, tb_logger = init_loggers(opt)
+
+ # create train and validation dataloaders
+ result = create_train_val_dataloader(opt, logger)
+ train_loader, train_sampler, val_loader, total_epochs, total_iters = result
+
+ # create model
+ if resume_state: # resume training
+ check_resume(opt, resume_state['iter'])
+ model = build_model(opt)
+ model.resume_training(resume_state) # handle optimizers and schedulers
+ logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " f"iter: {resume_state['iter']}.")
+ start_epoch = resume_state['epoch']
+ current_iter = resume_state['iter']
+ else:
+ model = build_model(opt)
+ start_epoch = 0
+ current_iter = 0
+
+ # create message logger (formatted outputs)
+ msg_logger = MessageLogger(opt, current_iter, tb_logger)
+
+ # dataloader prefetcher
+ prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
+ if prefetch_mode is None or prefetch_mode == 'cpu':
+ prefetcher = CPUPrefetcher(train_loader)
+ elif prefetch_mode == 'cuda':
+ prefetcher = CUDAPrefetcher(train_loader, opt)
+ logger.info(f'Use {prefetch_mode} prefetch dataloader')
+ if opt['datasets']['train'].get('pin_memory') is not True:
+ raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
+ else:
+ raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. Supported ones are: None, 'cuda', 'cpu'.")
+
+ # training
+ logger.info(f'Start training from epoch: {start_epoch}, iter: {current_iter+1}')
+ data_time, iter_time = time.time(), time.time()
+ start_time = time.time()
+
+ for epoch in range(start_epoch, total_epochs + 1):
+ train_sampler.set_epoch(epoch)
+ prefetcher.reset()
+ train_data = prefetcher.next()
+
+ while train_data is not None:
+ data_time = time.time() - data_time
+
+ current_iter += 1
+ if current_iter > total_iters:
+ break
+ # update learning rate
+ model.update_learning_rate(current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
+ # training
+ model.feed_data(train_data)
+ model.optimize_parameters(current_iter)
+ iter_time = time.time() - iter_time
+ # log
+ if current_iter % opt['logger']['print_freq'] == 0:
+ log_vars = {'epoch': epoch, 'iter': current_iter}
+ log_vars.update({'lrs': model.get_current_learning_rate()})
+ log_vars.update({'time': iter_time, 'data_time': data_time})
+ log_vars.update(model.get_current_log())
+ msg_logger(log_vars)
+
+ # save models and training states
+ if current_iter % opt['logger']['save_checkpoint_freq'] == 0:
+ logger.info('Saving models and training states.')
+ model.save(epoch, current_iter)
+
+ # validation
+ if opt.get('val') is not None and opt['datasets'].get('val') is not None \
+ and (current_iter % opt['val']['val_freq'] == 0):
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+
+ data_time = time.time()
+ iter_time = time.time()
+ train_data = prefetcher.next()
+ # end of iter
+
+ # end of epoch
+
+ consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
+ logger.info(f'End of training. Time consumed: {consumed_time}')
+ logger.info('Save the latest model.')
+ model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
+ if opt.get('val') is not None and opt['datasets'].get('val'):
+ model.validation(val_loader, current_iter, tb_logger, opt['val']['save_img'])
+ if tb_logger:
+ tb_logger.close()
+
+
+if __name__ == '__main__':
+ root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
+ train_pipeline(root_path)
diff --git a/repositories/codeformer/basicsr/utils/__init__.py b/repositories/codeformer/basicsr/utils/__init__.py
new file mode 100644
index 000000000..5fcc1d540
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/__init__.py
@@ -0,0 +1,29 @@
+from .file_client import FileClient
+from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
+from .logger import MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
+from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
+
+__all__ = [
+ # file_client.py
+ 'FileClient',
+ # img_util.py
+ 'img2tensor',
+ 'tensor2img',
+ 'imfrombytes',
+ 'imwrite',
+ 'crop_border',
+ # logger.py
+ 'MessageLogger',
+ 'init_tb_logger',
+ 'init_wandb_logger',
+ 'get_root_logger',
+ 'get_env_info',
+ # misc.py
+ 'set_random_seed',
+ 'get_time_str',
+ 'mkdir_and_rename',
+ 'make_exp_dirs',
+ 'scandir',
+ 'check_resume',
+ 'sizeof_fmt'
+]
diff --git a/repositories/codeformer/basicsr/utils/dist_util.py b/repositories/codeformer/basicsr/utils/dist_util.py
new file mode 100644
index 000000000..0fab887b2
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/dist_util.py
@@ -0,0 +1,82 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
+import functools
+import os
+import subprocess
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+
+def init_dist(launcher, backend='nccl', **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method('spawn')
+ if launcher == 'pytorch':
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == 'slurm':
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError(f'Invalid launcher type: {launcher}')
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ rank = int(os.environ['RANK'])
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_slurm(backend, port=None):
+ """Initialize slurm distributed training environment.
+
+ If argument ``port`` is not specified, then the master port will be system
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
+ environment variable, then a default port ``29500`` will be used.
+
+ Args:
+ backend (str): Backend of torch.distributed.
+ port (int, optional): Master port. Defaults to None.
+ """
+ proc_id = int(os.environ['SLURM_PROCID'])
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
+ # specify master port
+ if port is not None:
+ os.environ['MASTER_PORT'] = str(port)
+ elif 'MASTER_PORT' in os.environ:
+ pass # use MASTER_PORT in the environment variable
+ else:
+ # 29500 is torch.distributed default port
+ os.environ['MASTER_PORT'] = '29500'
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['RANK'] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def get_dist_info():
+ if dist.is_available():
+ initialized = dist.is_initialized()
+ else:
+ initialized = False
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
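
A minimal sketch of how the helpers above are typically used; in a single-process run `get_dist_info()` reports rank 0 and world size 1, so a `@master_only` function still executes:

```python
from basicsr.utils.dist_util import get_dist_info, master_only


@master_only
def save_checkpoint_banner():
    # only rank 0 reaches this body; other ranks get None back
    print('saving checkpoint ...')


rank, world_size = get_dist_info()
print(rank, world_size)   # 0 1 when torch.distributed is not initialized
save_checkpoint_banner()
```
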
diff --git a/repositories/codeformer/basicsr/utils/download_util.py b/repositories/codeformer/basicsr/utils/download_util.py
new file mode 100644
index 000000000..2a2679157
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/download_util.py
@@ -0,0 +1,95 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+from .misc import sizeof_fmt
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ print(response_file_size)
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += chunk_size
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
\ No newline at end of file
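
A usage sketch for `load_file_from_url()` above; the URL and target directory are placeholders, not endpoints shipped with this module:

```python
from basicsr.utils.download_util import load_file_from_url

path = load_file_from_url(
    url='https://example.com/models/some_model.pth',  # placeholder URL
    model_dir='weights',                              # created if missing
    progress=True,
    file_name=None)                                   # keep the name from the URL
print(path)  # .../weights/some_model.pth (download is skipped if the file already exists)
```
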
diff --git a/repositories/codeformer/basicsr/utils/file_client.py b/repositories/codeformer/basicsr/utils/file_client.py
new file mode 100644
index 000000000..7f38d9796
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/file_client.py
@@ -0,0 +1,167 @@
+# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+ # mc.pyvector serves as a pointer to a memory cache buffer
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+ client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+ The client loads a file or text in a specified backend from its path
+ and return it as a binary file. it can also register other backend
+ accessor with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
diff --git a/repositories/codeformer/basicsr/utils/img_util.py b/repositories/codeformer/basicsr/utils/img_util.py
new file mode 100644
index 000000000..5aba82ce0
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/img_util.py
@@ -0,0 +1,171 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
+ shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+ # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+ """This implementation is slightly faster than tensor2img.
+ It now only supports torch tensor with shape (1, c, h, w).
+
+ Args:
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+ min_max (tuple[int]): min and max values for clamp.
+ """
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+ output = output.type(torch.uint8).cpu().numpy()
+ if rgb2bgr:
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+ return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+ float32 (bool): Whether to change to float32. If True, will also
+ normalize to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+ crop_border (int): Crop border for each end of height and width.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
+
\ No newline at end of file
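
A round-trip sketch for `img2tensor()` / `tensor2img()` above:

```python
import numpy as np
from basicsr.utils import img2tensor, tensor2img

img = np.random.rand(32, 32, 3).astype(np.float32)      # HWC, BGR, [0, 1]
t = img2tensor(img, bgr2rgb=True, float32=True)          # CHW, RGB, torch.float32
print(t.shape)                                           # torch.Size([3, 32, 32])

restored = tensor2img(t, rgb2bgr=True, min_max=(0, 1))   # back to HWC, BGR, uint8 in [0, 255]
print(restored.shape, restored.dtype)                    # (32, 32, 3) uint8
```
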
diff --git a/repositories/codeformer/basicsr/utils/lmdb_util.py b/repositories/codeformer/basicsr/utils/lmdb_util.py
new file mode 100644
index 000000000..e0a10f60f
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/lmdb_util.py
@@ -0,0 +1,196 @@
+import cv2
+import lmdb
+import sys
+from multiprocessing import Pool
+from os import path as osp
+from tqdm import tqdm
+
+
+def make_lmdb_from_imgs(data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None):
+ """Make lmdb from images.
+
+ Contents of lmdb. The file structure is:
+ example.lmdb
+ ├── data.mdb
+ ├── lock.mdb
+ ├── meta_info.txt
+
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
+ https://lmdb.readthedocs.io/en/release/ for more details.
+
+ The meta_info.txt is a specified txt file to record the meta information
+ of our datasets. It will be automatically created when preparing
+ datasets by our provided dataset tools.
+ Each line in the txt file records 1)image name (with extension),
+ 2)image shape, and 3)compression level, separated by a white space.
+
+ For example, the meta information could be:
+ `000_00000000.png (720,1280,3) 1`, which means:
+ 1) image name (with extension): 000_00000000.png;
+ 2) image shape: (720,1280,3);
+ 3) compression level: 1
+
+ We use the image name without extension as the lmdb key.
+
+ If `multiprocessing_read` is True, it will read all the images to memory
+ using multiprocessing. Thus, your server needs to have enough memory.
+
+ Args:
+ data_path (str): Data path for reading images.
+ lmdb_path (str): Lmdb save path.
+ img_path_list (str): Image path list.
+ keys (str): Used for lmdb keys.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ multiprocessing_read (bool): Whether use multiprocessing to read all
+ the images to memory. Default: False.
+ n_thread (int): For multiprocessing.
+ map_size (int | None): Map size for lmdb env. If None, use the
+ estimated size from images. Default: None
+ """
+
+ assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
+ f'but got {len(img_path_list)} and {len(keys)}')
+ print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
+ print(f'Total images: {len(img_path_list)}')
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ if multiprocessing_read:
+ # read all the images to memory (multiprocessing)
+ dataset = {} # use dict to keep the order for multiprocessing
+ shapes = {}
+ print(f'Read images with multiprocessing, #thread: {n_thread} ...')
+ pbar = tqdm(total=len(img_path_list), unit='image')
+
+ def callback(arg):
+ """get the image data and update pbar."""
+ key, dataset[key], shapes[key] = arg
+ pbar.update(1)
+ pbar.set_description(f'Read {key}')
+
+ pool = Pool(n_thread)
+ for path, key in zip(img_path_list, keys):
+ pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback)
+ pool.close()
+ pool.join()
+ pbar.close()
+ print(f'Finish reading {len(img_path_list)} images.')
+
+ # create lmdb environment
+ if map_size is None:
+ # obtain data size for one image
+ img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ data_size_per_img = img_byte.nbytes
+ print('Data size per image is: ', data_size_per_img)
+ data_size = data_size_per_img * len(img_path_list)
+ map_size = data_size * 10
+
+ env = lmdb.open(lmdb_path, map_size=map_size)
+
+ # write data to lmdb
+ pbar = tqdm(total=len(img_path_list), unit='chunk')
+ txn = env.begin(write=True)
+ txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ for idx, (path, key) in enumerate(zip(img_path_list, keys)):
+ pbar.update(1)
+ pbar.set_description(f'Write {key}')
+ key_byte = key.encode('ascii')
+ if multiprocessing_read:
+ img_byte = dataset[key]
+ h, w, c = shapes[key]
+ else:
+ _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level)
+ h, w, c = img_shape
+
+ txn.put(key_byte, img_byte)
+ # write meta information
+ txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
+ if idx % batch == 0:
+ txn.commit()
+ txn = env.begin(write=True)
+ pbar.close()
+ txn.commit()
+ env.close()
+ txt_file.close()
+ print('\nFinish writing lmdb.')
+
+
+def read_img_worker(path, key, compress_level):
+ """Read image worker.
+
+ Args:
+ path (str): Image path.
+ key (str): Image key.
+ compress_level (int): Compress level when encoding images.
+
+ Returns:
+ str: Image key.
+ byte: Image byte.
+ tuple[int]: Image shape.
+ """
+
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+ if img.ndim == 2:
+ h, w = img.shape
+ c = 1
+ else:
+ h, w, c = img.shape
+ _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
+ return (key, img_byte, (h, w, c))
+
+
+class LmdbMaker():
+ """LMDB Maker.
+
+ Args:
+ lmdb_path (str): Lmdb save path.
+ map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
+ batch (int): After processing batch images, lmdb commits.
+ Default: 5000.
+ compress_level (int): Compress level when encoding images. Default: 1.
+ """
+
+ def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1):
+ if not lmdb_path.endswith('.lmdb'):
+ raise ValueError("lmdb_path must end with '.lmdb'.")
+ if osp.exists(lmdb_path):
+ print(f'Folder {lmdb_path} already exists. Exit.')
+ sys.exit(1)
+
+ self.lmdb_path = lmdb_path
+ self.batch = batch
+ self.compress_level = compress_level
+ self.env = lmdb.open(lmdb_path, map_size=map_size)
+ self.txn = self.env.begin(write=True)
+ self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
+ self.counter = 0
+
+ def put(self, img_byte, key, img_shape):
+ self.counter += 1
+ key_byte = key.encode('ascii')
+ self.txn.put(key_byte, img_byte)
+ # write meta information
+ h, w, c = img_shape
+ self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
+ if self.counter % self.batch == 0:
+ self.txn.commit()
+ self.txn = self.env.begin(write=True)
+
+ def close(self):
+ self.txn.commit()
+ self.env.close()
+ self.txt_file.close()
diff --git a/repositories/codeformer/basicsr/utils/logger.py b/repositories/codeformer/basicsr/utils/logger.py
new file mode 100644
index 000000000..9714bf59c
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/logger.py
@@ -0,0 +1,169 @@
+import datetime
+import logging
+import time
+
+from .dist_util import get_dist_info, master_only
+
+initialized_logger = {}
+
+
+class MessageLogger():
+ """Message logger for printing.
+ Args:
+ opt (dict): Config. It contains the following keys:
+ name (str): Exp name.
+ logger (dict): Contains 'print_freq' (str) for logger interval.
+ train (dict): Contains 'total_iter' (int) for total iters.
+ use_tb_logger (bool): Use tensorboard logger.
+ start_iter (int): Start iter. Default: 1.
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
+ """
+
+ def __init__(self, opt, start_iter=1, tb_logger=None):
+ self.exp_name = opt['name']
+ self.interval = opt['logger']['print_freq']
+ self.start_iter = start_iter
+ self.max_iters = opt['train']['total_iter']
+ self.use_tb_logger = opt['logger']['use_tb_logger']
+ self.tb_logger = tb_logger
+ self.start_time = time.time()
+ self.logger = get_root_logger()
+
+ @master_only
+ def __call__(self, log_vars):
+ """Format logging message.
+ Args:
+ log_vars (dict): It contains the following keys:
+ epoch (int): Epoch number.
+ iter (int): Current iter.
+ lrs (list): List for learning rates.
+ time (float): Iter time.
+ data_time (float): Data time for each iter.
+ """
+ # epoch, iter, learning rates
+ epoch = log_vars.pop('epoch')
+ current_iter = log_vars.pop('iter')
+ lrs = log_vars.pop('lrs')
+
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
+ for v in lrs:
+ message += f'{v:.3e},'
+ message += ')] '
+
+ # time and estimated time
+ if 'time' in log_vars.keys():
+ iter_time = log_vars.pop('time')
+ data_time = log_vars.pop('data_time')
+
+ total_time = time.time() - self.start_time
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ message += f'[eta: {eta_str}, '
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
+
+ # other items, especially losses
+ for k, v in log_vars.items():
+ message += f'{k}: {v:.4e} '
+ # tensorboard logger
+ if self.use_tb_logger:
+ if k.startswith('l_'):
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
+ else:
+ self.tb_logger.add_scalar(k, v, current_iter)
+ self.logger.info(message)
+
+
+@master_only
+def init_tb_logger(log_dir):
+ from torch.utils.tensorboard import SummaryWriter
+ tb_logger = SummaryWriter(log_dir=log_dir)
+ return tb_logger
+
+
+@master_only
+def init_wandb_logger(opt):
+ """We now only use wandb to sync tensorboard log."""
+ import wandb
+ logger = logging.getLogger('basicsr')
+
+ project = opt['logger']['wandb']['project']
+ resume_id = opt['logger']['wandb'].get('resume_id')
+ if resume_id:
+ wandb_id = resume_id
+ resume = 'allow'
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
+ else:
+ wandb_id = wandb.util.generate_id()
+ resume = 'never'
+
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
+
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
+
+
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+ Args:
+ logger_name (str): root logger name. Default: 'basicsr'.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+ rank, _ = get_dist_info()
+ if rank != 0:
+ logger.setLevel('ERROR')
+ elif log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ # file_handler = logging.FileHandler(log_file, 'w')
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
+
+
+def get_env_info():
+ """Get environment information.
+ Currently, only log the software version.
+ """
+ import torch
+ import torchvision
+
+ from basicsr.version import __version__
+ msg = r"""
+ ____ _ _____ ____
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
+ ______ __ __ __ __
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
+ """
+ msg += ('\nVersion Information: '
+ f'\n\tBasicSR: {__version__}'
+ f'\n\tPyTorch: {torch.__version__}'
+ f'\n\tTorchVision: {torchvision.__version__}')
+ return msg
\ No newline at end of file
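
A sketch of how the logging pieces above fit together; the `opt` dict is a stripped-down placeholder carrying only the keys `MessageLogger` actually reads, and the log path is hypothetical:

```python
import logging
import os
from basicsr.utils import MessageLogger, get_root_logger

os.makedirs('experiments/demo', exist_ok=True)            # the log directory must exist
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO,
                         log_file='experiments/demo/train_demo.log')

opt = {
    'name': 'demo_experiment',
    'logger': {'print_freq': 100, 'use_tb_logger': False},
    'train': {'total_iter': 1000},
}
msg_logger = MessageLogger(opt, start_iter=1, tb_logger=None)
# one formatted line: epoch/iter, lr, eta, timings and any loss terms (here 'l_pix')
msg_logger({'epoch': 0, 'iter': 100, 'lrs': [1e-4], 'time': 0.5, 'data_time': 0.1, 'l_pix': 0.0123})
```
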
diff --git a/repositories/codeformer/basicsr/utils/matlab_functions.py b/repositories/codeformer/basicsr/utils/matlab_functions.py
new file mode 100644
index 000000000..c6ce1004a
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/matlab_functions.py
@@ -0,0 +1,347 @@
+import math
+import numpy as np
+import torch
+
+
+def cubic(x):
+ """cubic function used for calculate_weights_indices."""
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
+ (absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ """Calculate weights and indices, used for imresize function.
+
+ Args:
+ in_length (int): Input length.
+ out_length (int): Output length.
+ scale (float): Scale factor.
+ kernel_width (int): Kernel width.
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ """
+
+ if (scale < 1) and antialiasing:
+ # Use a modified kernel (larger kernel width) to simultaneously
+ # interpolate and antialias
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ p = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
+ out_length, p)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
+
+ # apply cubic kernel
+ if (scale < 1) and antialiasing:
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, p)
+
+ # If a column in weights is all zero, get rid of it. only consider the
+ # first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, p - 2)
+ weights = weights.narrow(1, 1, p - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, p - 2)
+ weights = weights.narrow(1, 0, p - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+@torch.no_grad()
+def imresize(img, scale, antialiasing=True):
+ """imresize function same as MATLAB.
+
+ It now only supports bicubic.
+ The same scale applies for both height and width.
+
+ Args:
+ img (Tensor | Numpy array):
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
+ scale (float): Scale factor. The same scale applies for both height
+ and width.
+ antialiasing (bool): Whether to apply anti-aliasing when downsampling.
+ Default: True.
+
+ Returns:
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
+ """
+ if type(img).__module__ == np.__name__: # numpy type
+ numpy_type = True
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
+ else:
+ numpy_type = False
+
+ in_c, in_h, in_w = img.size()
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # get weights and indices
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
+ antialiasing)
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
+ antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
+
+ sym_patch = img[:, :sym_len_hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_he:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
+ kernel_width = weights_h.size(1)
+ for i in range(out_h):
+ idx = int(indices_h[i][0])
+ for j in range(in_c):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_we:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
+ kernel_width = weights_w.size(1)
+ for i in range(out_w):
+ idx = int(indices_w[i][0])
+ for j in range(in_c):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+
+ if numpy_type:
+ out_2 = out_2.numpy().transpose(1, 2, 0)
+ return out_2
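
A usage sketch for `imresize()` above (numpy in, numpy out; bicubic with antialiasing applied for scale < 1):

```python
import numpy as np
from basicsr.utils.matlab_functions import imresize

img = np.random.rand(64, 48, 3).astype(np.float32)   # (h, w, c), values in [0, 1]
small = imresize(img, scale=0.5)                      # MATLAB-style bicubic downscale
print(small.shape)                                    # (32, 24, 3)
```
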
+
+
+def rgb2ycbcr(img, y_only=False):
+ """Convert a RGB image to YCbCr image.
+
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def bgr2ycbcr(img, y_only=False):
+ """Convert a BGR image to YCbCr image.
+
+ The bgr version of rgb2ycbcr.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+ y_only (bool): Whether to only return Y channel. Default: False.
+
+ Returns:
+ ndarray: The converted YCbCr image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img)
+ if y_only:
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
+ else:
+ out_img = np.matmul(
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2rgb(img):
+ """Convert a YCbCr image to RGB image.
+
+ This function produces the same results as Matlab's ycbcr2rgb function.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted RGB image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def ycbcr2bgr(img):
+ """Convert a YCbCr image to BGR image.
+
+ The bgr version of ycbcr2rgb.
+ It implements the ITU-R BT.601 conversion for standard-definition
+ television. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
+
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
+ In OpenCV, it implements a JPEG conversion. See more details in
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ ndarray: The converted BGR image. The output image has the same type
+ and range as input image.
+ """
+ img_type = img.dtype
+ img = _convert_input_type_range(img) * 255
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
+ out_img = _convert_output_type_range(out_img, img_type)
+ return out_img
+
+
+def _convert_input_type_range(img):
+ """Convert the type and range of the input image.
+
+ It converts the input image to np.float32 type and range of [0, 1].
+ It is mainly used for pre-processing the input image in colorspace
+    conversion functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The input image. It accepts:
+ 1. np.uint8 type with range [0, 255];
+ 2. np.float32 type with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with type of np.float32 and range of
+ [0, 1].
+ """
+ img_type = img.dtype
+ img = img.astype(np.float32)
+ if img_type == np.float32:
+ pass
+ elif img_type == np.uint8:
+ img /= 255.
+ else:
+ raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
+ return img
+
+
+def _convert_output_type_range(img, dst_type):
+ """Convert the type and range of the image according to dst_type.
+
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
+ images will be converted to np.uint8 type with range [0, 255]. If
+ `dst_type` is np.float32, it converts the image to np.float32 type with
+ range [0, 1].
+    It is mainly used for post-processing images in colorspace conversion
+ functions such as rgb2ycbcr and ycbcr2rgb.
+
+ Args:
+ img (ndarray): The image to be converted with np.float32 type and
+ range [0, 255].
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
+ converts the image to np.uint8 type with range [0, 255]. If
+ dst_type is np.float32, it converts the image to np.float32 type
+ with range [0, 1].
+
+ Returns:
+ (ndarray): The converted image with desired type and range.
+ """
+ if dst_type not in (np.uint8, np.float32):
+ raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
+ if dst_type == np.uint8:
+ img = img.round()
+ else:
+ img /= 255.
+ return img.astype(dst_type)
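+
+
+if __name__ == '__main__':
+    # Minimal usage sketch added for illustration (not part of the upstream
+    # CodeFormer file): convert a random float32 RGB image to YCbCr and back,
+    # and check that the round-trip error stays small.
+    rgb = np.random.rand(8, 8, 3).astype(np.float32)
+    ycbcr = rgb2ycbcr(rgb)            # float32 output in [0, 1], same range as the input
+    restored = ycbcr2rgb(ycbcr)
+    print('max round-trip error:', float(np.abs(rgb - restored).max()))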
diff --git a/repositories/codeformer/basicsr/utils/misc.py b/repositories/codeformer/basicsr/utils/misc.py
new file mode 100644
index 000000000..3b444ff3b
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/misc.py
@@ -0,0 +1,134 @@
+import numpy as np
+import os
+import random
+import time
+import torch
+from os import path as osp
+
+from .dist_util import master_only
+from .logger import get_root_logger
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def mkdir_and_rename(path):
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
+
+ Args:
+ path (str): Folder path.
+ """
+ if osp.exists(path):
+ new_name = path + '_archived_' + get_time_str()
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
+ os.rename(path, new_name)
+ os.makedirs(path, exist_ok=True)
+
+
+@master_only
+def make_exp_dirs(opt):
+ """Make dirs for experiments."""
+ path_opt = opt['path'].copy()
+ if opt['is_train']:
+ mkdir_and_rename(path_opt.pop('experiments_root'))
+ else:
+ mkdir_and_rename(path_opt.pop('results_root'))
+ for key, path in path_opt.items():
+ if ('strict_load' not in key) and ('pretrain_network' not in key) and ('resume' not in key):
+ os.makedirs(path, exist_ok=True)
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+        A generator yielding all files of interest, with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def check_resume(opt, resume_iter):
+ """Check resume states and pretrain_network paths.
+
+ Args:
+ opt (dict): Options.
+ resume_iter (int): Resume iteration.
+ """
+ logger = get_root_logger()
+ if opt['path']['resume_state']:
+ # get all the networks
+ networks = [key for key in opt.keys() if key.startswith('network_')]
+ flag_pretrain = False
+ for network in networks:
+ if opt['path'].get(f'pretrain_{network}') is not None:
+ flag_pretrain = True
+ if flag_pretrain:
+ logger.warning('pretrain_network path will be ignored during resuming.')
+ # set pretrained model paths
+ for network in networks:
+ name = f'pretrain_{network}'
+ basename = network.replace('network_', '')
+ if opt['path'].get('ignore_resume_networks') is None or (basename
+ not in opt['path']['ignore_resume_networks']):
+ opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
+ logger.info(f"Set {name} to {opt['path'][name]}")
+
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+ Return:
+       str: Formatted file size.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
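+
+
+if __name__ == '__main__':
+    # Small usage sketch added for illustration (not part of the upstream file);
+    # run it in package context, e.g. python -m basicsr.utils.misc.
+    print(get_time_str())              # e.g. 20230101_120000
+    print(sizeof_fmt(1536))            # 1.5 KB
+    print(sizeof_fmt(3 * 1024 ** 3))   # 3.0 GB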
diff --git a/repositories/codeformer/basicsr/utils/options.py b/repositories/codeformer/basicsr/utils/options.py
new file mode 100644
index 000000000..db490e4aa
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/options.py
@@ -0,0 +1,108 @@
+import yaml
+import time
+from collections import OrderedDict
+from os import path as osp
+from basicsr.utils.misc import get_time_str
+
+def ordered_yaml():
+ """Support OrderedDict for yaml.
+
+ Returns:
+ yaml Loader and Dumper.
+ """
+ try:
+ from yaml import CDumper as Dumper
+ from yaml import CLoader as Loader
+ except ImportError:
+ from yaml import Dumper, Loader
+
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+
+ def dict_representer(dumper, data):
+ return dumper.represent_dict(data.items())
+
+ def dict_constructor(loader, node):
+ return OrderedDict(loader.construct_pairs(node))
+
+ Dumper.add_representer(OrderedDict, dict_representer)
+ Loader.add_constructor(_mapping_tag, dict_constructor)
+ return Loader, Dumper
+
+
+def parse(opt_path, root_path, is_train=True):
+ """Parse option file.
+
+ Args:
+        opt_path (str): Option file path.
+        root_path (str): Root path used to build the experiment/result directories.
+        is_train (bool): Whether in training mode. Default: True.
+
+ Returns:
+ (dict): Options.
+ """
+ with open(opt_path, mode='r') as f:
+ Loader, _ = ordered_yaml()
+ opt = yaml.load(f, Loader=Loader)
+
+ opt['is_train'] = is_train
+
+ # opt['name'] = f"{get_time_str()}_{opt['name']}"
+ if opt['path'].get('resume_state', None): # Shangchen added
+ resume_state_path = opt['path'].get('resume_state')
+ opt['name'] = resume_state_path.split("/")[-3]
+ else:
+ opt['name'] = f"{get_time_str()}_{opt['name']}"
+
+
+ # datasets
+ for phase, dataset in opt['datasets'].items():
+ # for several datasets, e.g., test_1, test_2
+ phase = phase.split('_')[0]
+ dataset['phase'] = phase
+ if 'scale' in opt:
+ dataset['scale'] = opt['scale']
+ if dataset.get('dataroot_gt') is not None:
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
+ if dataset.get('dataroot_lq') is not None:
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
+
+ # paths
+ for key, val in opt['path'].items():
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
+ opt['path'][key] = osp.expanduser(val)
+
+ if is_train:
+ experiments_root = osp.join(root_path, 'experiments', opt['name'])
+ opt['path']['experiments_root'] = experiments_root
+ opt['path']['models'] = osp.join(experiments_root, 'models')
+ opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
+ opt['path']['log'] = experiments_root
+ opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
+
+ else: # test
+ results_root = osp.join(root_path, 'results', opt['name'])
+ opt['path']['results_root'] = results_root
+ opt['path']['log'] = results_root
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
+
+ return opt
+
+
+def dict2str(opt, indent_level=1):
+ """dict to string for printing options.
+
+ Args:
+ opt (dict): Option dict.
+ indent_level (int): Indent level. Default: 1.
+
+ Return:
+ (str): Option string for printing.
+ """
+ msg = '\n'
+ for k, v in opt.items():
+ if isinstance(v, dict):
+ msg += ' ' * (indent_level * 2) + k + ':['
+ msg += dict2str(v, indent_level + 1)
+ msg += ' ' * (indent_level * 2) + ']\n'
+ else:
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
+ return msg
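+
+
+if __name__ == '__main__':
+    # Usage sketch added for illustration (not part of the upstream file); run in
+    # package context. It loads a small YAML snippet with the OrderedDict-aware
+    # loader and pretty-prints it.
+    Loader, _ = ordered_yaml()
+    demo_opt = yaml.load('name: demo\ntrain:\n  lr: 0.0001\n', Loader=Loader)
+    print(type(demo_opt))              # collections.OrderedDict
+    print(dict2str(demo_opt))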
diff --git a/repositories/codeformer/basicsr/utils/realesrgan_utils.py b/repositories/codeformer/basicsr/utils/realesrgan_utils.py
new file mode 100644
index 000000000..6b7a8b460
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/realesrgan_utils.py
@@ -0,0 +1,299 @@
+import cv2
+import math
+import numpy as np
+import os
+import queue
+import threading
+import torch
+from basicsr.utils.download_util import load_file_from_url
+from torch.nn import functional as F
+
+# ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+class RealESRGANer():
+ """A helper class for upsampling images with RealESRGAN.
+
+ Args:
+ scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (str): The path to the pretrained model. It can be a URL (the model will be downloaded automatically).
+ model (nn.Module): The defined network. Default: None.
+        tile (int): Since very large inputs can exhaust GPU memory, this option first crops the input
+            image into tiles, processes each tile separately, and then merges them back into one image.
+            0 means tiling is disabled. Default: 0.
+ tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+ pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (bool): Whether to use half precision during inference. Default: False.
+ """
+
+ def __init__(self,
+ scale,
+ model_path,
+ model=None,
+ tile=0,
+ tile_pad=10,
+ pre_pad=10,
+ half=False,
+ device=None,
+ gpu_id=None):
+ self.scale = scale
+ self.tile_size = tile
+ self.tile_pad = tile_pad
+ self.pre_pad = pre_pad
+ self.mod_scale = None
+ self.half = half
+
+ # initialize model
+ if gpu_id:
+ self.device = torch.device(
+ f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device
+ else:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
+        # if the model_path starts with https, it will first download models to the folder: weights/realesrgan
+ if model_path.startswith('https://'):
+ model_path = load_file_from_url(
+ url=model_path, model_dir=os.path.join('weights/realesrgan'), progress=True, file_name=None)
+ loadnet = torch.load(model_path, map_location=torch.device('cpu'))
+ # prefer to use params_ema
+ if 'params_ema' in loadnet:
+ keyname = 'params_ema'
+ else:
+ keyname = 'params'
+ model.load_state_dict(loadnet[keyname], strict=True)
+ model.eval()
+ self.model = model.to(self.device)
+ if self.half:
+ self.model = self.model.half()
+
+ def pre_process(self, img):
+ """Pre-process, such as pre-pad and mod pad, so that the images can be divisible
+ """
+ img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+ self.img = img.unsqueeze(0).to(self.device)
+ if self.half:
+ self.img = self.img.half()
+
+ # pre_pad
+ if self.pre_pad != 0:
+ self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect')
+ # mod pad for divisible borders
+ if self.scale == 2:
+ self.mod_scale = 2
+ elif self.scale == 1:
+ self.mod_scale = 4
+ if self.mod_scale is not None:
+ self.mod_pad_h, self.mod_pad_w = 0, 0
+ _, _, h, w = self.img.size()
+ if (h % self.mod_scale != 0):
+ self.mod_pad_h = (self.mod_scale - h % self.mod_scale)
+ if (w % self.mod_scale != 0):
+ self.mod_pad_w = (self.mod_scale - w % self.mod_scale)
+ self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect')
+
+ def process(self):
+ # model inference
+ self.output = self.model(self.img)
+
+ def tile_process(self):
+ """It will first crop input images to tiles, and then process each tile.
+ Finally, all the processed tiles are merged into one images.
+
+ Modified from: https://github.com/ata4/esrgan-launcher
+ """
+ batch, channel, height, width = self.img.shape
+ output_height = height * self.scale
+ output_width = width * self.scale
+ output_shape = (batch, channel, output_height, output_width)
+
+ # start with black image
+ self.output = self.img.new_zeros(output_shape)
+ tiles_x = math.ceil(width / self.tile_size)
+ tiles_y = math.ceil(height / self.tile_size)
+
+ # loop over all tiles
+ for y in range(tiles_y):
+ for x in range(tiles_x):
+ # extract tile from input image
+ ofs_x = x * self.tile_size
+ ofs_y = y * self.tile_size
+ # input tile area on total image
+ input_start_x = ofs_x
+ input_end_x = min(ofs_x + self.tile_size, width)
+ input_start_y = ofs_y
+ input_end_y = min(ofs_y + self.tile_size, height)
+
+ # input tile area on total image with padding
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+ # input tile dimensions
+ input_tile_width = input_end_x - input_start_x
+ input_tile_height = input_end_y - input_start_y
+ tile_idx = y * tiles_x + x + 1
+ input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
+
+ # upscale tile
+ try:
+ with torch.no_grad():
+ output_tile = self.model(input_tile)
+ except RuntimeError as error:
+ print('Error', error)
+ # print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
+
+ # output tile area on total image
+ output_start_x = input_start_x * self.scale
+ output_end_x = input_end_x * self.scale
+ output_start_y = input_start_y * self.scale
+ output_end_y = input_end_y * self.scale
+
+ # output tile area without padding
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+ # put tile into output image
+ self.output[:, :, output_start_y:output_end_y,
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
+ output_start_x_tile:output_end_x_tile]
+
+ def post_process(self):
+ # remove extra pad
+ if self.mod_scale is not None:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale]
+ # remove prepad
+ if self.pre_pad != 0:
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale]
+ return self.output
+
+ @torch.no_grad()
+ def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'):
+ h_input, w_input = img.shape[0:2]
+ # img: numpy
+ img = img.astype(np.float32)
+ if np.max(img) > 256: # 16-bit image
+ max_range = 65535
+ print('\tInput is a 16-bit image')
+ else:
+ max_range = 255
+ img = img / max_range
+ if len(img.shape) == 2: # gray image
+ img_mode = 'L'
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+ elif img.shape[2] == 4: # RGBA image with alpha channel
+ img_mode = 'RGBA'
+ alpha = img[:, :, 3]
+ img = img[:, :, 0:3]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ if alpha_upsampler == 'realesrgan':
+ alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+ else:
+ img_mode = 'RGB'
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # ------------------- process image (without the alpha channel) ------------------- #
+ try:
+ with torch.no_grad():
+ self.pre_process(img)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_img_t = self.post_process()
+ output_img = output_img_t.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+ if img_mode == 'L':
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+ del output_img_t
+ torch.cuda.empty_cache()
+ except RuntimeError as error:
+ print(f"Failed inference for RealESRGAN: {error}")
+
+ # ------------------- process the alpha channel if necessary ------------------- #
+ if img_mode == 'RGBA':
+ if alpha_upsampler == 'realesrgan':
+ self.pre_process(alpha)
+ if self.tile_size > 0:
+ self.tile_process()
+ else:
+ self.process()
+ output_alpha = self.post_process()
+ output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+ output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+ output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+ else: # use the cv2 resize for alpha channel
+ h, w = alpha.shape[0:2]
+ output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR)
+
+ # merge the alpha channel
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+ output_img[:, :, 3] = output_alpha
+
+ # ------------------------------ return ------------------------------ #
+ if max_range == 65535: # 16-bit image
+ output = (output_img * 65535.0).round().astype(np.uint16)
+ else:
+ output = (output_img * 255.0).round().astype(np.uint8)
+
+ if outscale is not None and outscale != float(self.scale):
+ output = cv2.resize(
+ output, (
+ int(w_input * outscale),
+ int(h_input * outscale),
+ ), interpolation=cv2.INTER_LANCZOS4)
+
+ return output, img_mode
+
+
+class PrefetchReader(threading.Thread):
+ """Prefetch images.
+
+ Args:
+        img_list (list[str]): A list of image paths to be read.
+ num_prefetch_queue (int): Number of prefetch queue.
+ """
+
+ def __init__(self, img_list, num_prefetch_queue):
+ super().__init__()
+ self.que = queue.Queue(num_prefetch_queue)
+ self.img_list = img_list
+
+ def run(self):
+ for img_path in self.img_list:
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ self.que.put(img)
+
+ self.que.put(None)
+
+ def __next__(self):
+ next_item = self.que.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class IOConsumer(threading.Thread):
+
+ def __init__(self, opt, que, qid):
+ super().__init__()
+ self._queue = que
+ self.qid = qid
+ self.opt = opt
+
+ def run(self):
+ while True:
+ msg = self._queue.get()
+ if isinstance(msg, str) and msg == 'quit':
+ break
+
+ output = msg['output']
+ save_path = msg['save_path']
+ cv2.imwrite(save_path, output)
+ print(f'IO worker {self.qid} is done.')
\ No newline at end of file
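+
+# Usage sketch (illustrative only; the RRDBNet import, weight path and file names
+# below are assumptions, not part of the upstream file):
+#
+#     import cv2
+#     from basicsr.archs.rrdbnet_arch import RRDBNet
+#     from basicsr.utils.realesrgan_utils import RealESRGANer
+#
+#     net = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+#     upsampler = RealESRGANer(scale=2, model_path='weights/realesrgan/RealESRGAN_x2plus.pth',
+#                              model=net, tile=400, tile_pad=10, pre_pad=0, half=False)
+#     img = cv2.imread('input.png', cv2.IMREAD_UNCHANGED)   # BGR, as enhance() expects
+#     output, _ = upsampler.enhance(img, outscale=2)
+#     cv2.imwrite('output.png', output)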
diff --git a/repositories/codeformer/basicsr/utils/registry.py b/repositories/codeformer/basicsr/utils/registry.py
new file mode 100644
index 000000000..655753b3b
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/registry.py
@@ -0,0 +1,82 @@
+# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
+
+
+class Registry():
+ """
+ The registry that provides name -> object mapping, to support third-party
+ users' custom modules.
+
+ To create a registry (e.g. a backbone registry):
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY = Registry('BACKBONE')
+
+ To register an object:
+
+ .. code-block:: python
+
+ @BACKBONE_REGISTRY.register()
+ class MyBackbone():
+ ...
+
+ Or:
+
+ .. code-block:: python
+
+ BACKBONE_REGISTRY.register(MyBackbone)
+ """
+
+ def __init__(self, name):
+ """
+ Args:
+ name (str): the name of this registry
+ """
+ self._name = name
+ self._obj_map = {}
+
+ def _do_register(self, name, obj):
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
+ f"in '{self._name}' registry!")
+ self._obj_map[name] = obj
+
+ def register(self, obj=None):
+ """
+        Register the given object under the name `obj.__name__`.
+ Can be used as either a decorator or not.
+ See docstring of this class for usage.
+ """
+ if obj is None:
+ # used as a decorator
+ def deco(func_or_class):
+ name = func_or_class.__name__
+ self._do_register(name, func_or_class)
+ return func_or_class
+
+ return deco
+
+ # used as a function call
+ name = obj.__name__
+ self._do_register(name, obj)
+
+ def get(self, name):
+ ret = self._obj_map.get(name)
+ if ret is None:
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
+ return ret
+
+ def __contains__(self, name):
+ return name in self._obj_map
+
+ def __iter__(self):
+ return iter(self._obj_map.items())
+
+ def keys(self):
+ return self._obj_map.keys()
+
+
+DATASET_REGISTRY = Registry('dataset')
+ARCH_REGISTRY = Registry('arch')
+MODEL_REGISTRY = Registry('model')
+LOSS_REGISTRY = Registry('loss')
+METRIC_REGISTRY = Registry('metric')
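+
+
+if __name__ == '__main__':
+    # Tiny usage sketch added for illustration (not part of the upstream file):
+    # register a class in a throwaway registry and look it up by name.
+    DEMO_REGISTRY = Registry('demo')
+
+    @DEMO_REGISTRY.register()
+    class ToyArch:
+        pass
+
+    print(DEMO_REGISTRY.get('ToyArch') is ToyArch)   # True
+    print('ToyArch' in DEMO_REGISTRY)                # True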
diff --git a/repositories/codeformer/basicsr/utils/video_util.py b/repositories/codeformer/basicsr/utils/video_util.py
new file mode 100644
index 000000000..20a2ff14c
--- /dev/null
+++ b/repositories/codeformer/basicsr/utils/video_util.py
@@ -0,0 +1,125 @@
+'''
+This code is modified from Real-ESRGAN:
+https://github.com/xinntao/Real-ESRGAN/blob/master/inference_realesrgan_video.py
+
+'''
+import cv2
+import sys
+import numpy as np
+
+try:
+ import ffmpeg
+except ImportError:
+ import pip
+ pip.main(['install', '--user', 'ffmpeg-python'])
+ import ffmpeg
+
+def get_video_meta_info(video_path):
+ ret = {}
+ probe = ffmpeg.probe(video_path)
+ video_streams = [stream for stream in probe['streams'] if stream['codec_type'] == 'video']
+ has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
+ ret['width'] = video_streams[0]['width']
+ ret['height'] = video_streams[0]['height']
+ ret['fps'] = eval(video_streams[0]['avg_frame_rate'])
+ ret['audio'] = ffmpeg.input(video_path).audio if has_audio else None
+ ret['nb_frames'] = int(video_streams[0]['nb_frames'])
+ return ret
+
+class VideoReader:
+ def __init__(self, video_path):
+ self.paths = [] # for image&folder type
+ self.audio = None
+ try:
+ self.stream_reader = (
+ ffmpeg.input(video_path).output('pipe:', format='rawvideo', pix_fmt='bgr24',
+ loglevel='error').run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
+ except FileNotFoundError:
+ print('Please install ffmpeg (not ffmpeg-python) by running\n',
+ '\t$ conda install -c conda-forge ffmpeg')
+ sys.exit(0)
+
+ meta = get_video_meta_info(video_path)
+ self.width = meta['width']
+ self.height = meta['height']
+ self.input_fps = meta['fps']
+ self.audio = meta['audio']
+ self.nb_frames = meta['nb_frames']
+
+ self.idx = 0
+
+ def get_resolution(self):
+ return self.height, self.width
+
+ def get_fps(self):
+ if self.input_fps is not None:
+ return self.input_fps
+ return 24
+
+ def get_audio(self):
+ return self.audio
+
+ def __len__(self):
+ return self.nb_frames
+
+ def get_frame_from_stream(self):
+ img_bytes = self.stream_reader.stdout.read(self.width * self.height * 3) # 3 bytes for one pixel
+ if not img_bytes:
+ return None
+ img = np.frombuffer(img_bytes, np.uint8).reshape([self.height, self.width, 3])
+ return img
+
+ def get_frame_from_list(self):
+ if self.idx >= self.nb_frames:
+ return None
+ img = cv2.imread(self.paths[self.idx])
+ self.idx += 1
+ return img
+
+ def get_frame(self):
+ return self.get_frame_from_stream()
+
+
+ def close(self):
+ self.stream_reader.stdin.close()
+ self.stream_reader.wait()
+
+
+class VideoWriter:
+ def __init__(self, video_save_path, height, width, fps, audio):
+ if height > 2160:
+            print('You are generating a video larger than 4K, which will be very slow due to IO speed.',
+                  'We highly recommend decreasing the outscale (i.e., -s).')
+ if audio is not None:
+ self.stream_writer = (
+ ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}',
+ framerate=fps).output(
+ audio,
+ video_save_path,
+ pix_fmt='yuv420p',
+ vcodec='libx264',
+ loglevel='error',
+ acodec='copy').overwrite_output().run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
+ else:
+ self.stream_writer = (
+ ffmpeg.input('pipe:', format='rawvideo', pix_fmt='bgr24', s=f'{width}x{height}',
+ framerate=fps).output(
+ video_save_path, pix_fmt='yuv420p', vcodec='libx264',
+ loglevel='error').overwrite_output().run_async(
+ pipe_stdin=True, pipe_stdout=True, cmd='ffmpeg'))
+
+ def write_frame(self, frame):
+ try:
+ frame = frame.astype(np.uint8).tobytes()
+ self.stream_writer.stdin.write(frame)
+ except BrokenPipeError:
+ print('Please re-install ffmpeg and libx264 by running\n',
+ '\t$ conda install -c conda-forge ffmpeg\n',
+ '\t$ conda install -c conda-forge x264')
+ sys.exit(0)
+
+ def close(self):
+ self.stream_writer.stdin.close()
+ self.stream_writer.wait()
\ No newline at end of file
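+
+# Usage sketch (illustrative only; the file names are placeholders, not part of
+# the upstream file): copy a video frame by frame, preserving fps and audio.
+#
+#     reader = VideoReader('input.mp4')
+#     height, width = reader.get_resolution()
+#     writer = VideoWriter('output.mp4', height, width, reader.get_fps(), reader.get_audio())
+#     while True:
+#         frame = reader.get_frame()
+#         if frame is None:
+#             break
+#         writer.write_frame(frame)   # a real pipeline would enhance the frame here
+#     reader.close()
+#     writer.close()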
diff --git a/repositories/codeformer/facelib/detection/__init__.py b/repositories/codeformer/facelib/detection/__init__.py
new file mode 100644
index 000000000..5d1f8fc21
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/__init__.py
@@ -0,0 +1,100 @@
+import os
+import torch
+from torch import nn
+from copy import deepcopy
+
+from facelib.utils import load_file_from_url
+from facelib.utils import download_pretrained_models
+from facelib.detection.yolov5face.models.common import Conv
+
+from .retinaface.retinaface import RetinaFace
+from .yolov5face.face_detector import YoloDetector
+
+
+def init_detection_model(model_name, half=False, device='cuda'):
+ if 'retinaface' in model_name:
+ model = init_retinaface_model(model_name, half, device)
+ elif 'YOLOv5' in model_name:
+ model = init_yolov5face_model(model_name, device)
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ return model
+
+
+def init_retinaface_model(model_name, half=False, device='cuda'):
+ if model_name == 'retinaface_resnet50':
+ model = RetinaFace(network_name='resnet50', half=half)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth'
+ elif model_name == 'retinaface_mobile0.25':
+ model = RetinaFace(network_name='mobile0.25', half=half)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_mobilenet0.25_Final.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ # remove unnecessary 'module.'
+ for k, v in deepcopy(load_net).items():
+ if k.startswith('module.'):
+ load_net[k[7:]] = v
+ load_net.pop(k)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+
+ return model
+
+
+def init_yolov5face_model(model_name, device='cuda'):
+ if model_name == 'YOLOv5l':
+ model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth'
+ elif model_name == 'YOLOv5n':
+ model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5n-face.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ model.detector.load_state_dict(load_net, strict=True)
+ model.detector.eval()
+ model.detector = model.detector.to(device).float()
+
+ for m in model.detector.modules():
+ if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+ m.inplace = True # pytorch 1.7.0 compatibility
+ elif isinstance(m, Conv):
+ m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
+
+ return model
+
+
+# Download from Google Drive
+# def init_yolov5face_model(model_name, device='cuda'):
+# if model_name == 'YOLOv5l':
+# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5l.yaml', device=device)
+# f_id = {'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV'}
+# elif model_name == 'YOLOv5n':
+# model = YoloDetector(config_name='facelib/detection/yolov5face/models/yolov5n.yaml', device=device)
+# f_id = {'yolov5n-face.pth': '1fhcpFvWZqghpGXjYPIne2sw1Fy4yhw6o'}
+# else:
+# raise NotImplementedError(f'{model_name} is not implemented.')
+
+# model_path = os.path.join('weights/facelib', list(f_id.keys())[0])
+# if not os.path.exists(model_path):
+# download_pretrained_models(file_ids=f_id, save_path_root='weights/facelib')
+
+# load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+# model.detector.load_state_dict(load_net, strict=True)
+# model.detector.eval()
+# model.detector = model.detector.to(device).float()
+
+# for m in model.detector.modules():
+# if type(m) in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
+# m.inplace = True # pytorch 1.7.0 compatibility
+# elif isinstance(m, Conv):
+# m._non_persistent_buffers_set = set() # pytorch 1.6.0 compatibility
+
+# return model
\ No newline at end of file
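+
+# Usage sketch (illustrative only; the image path is a placeholder, not part of
+# the upstream file):
+#
+#     import cv2
+#     from facelib.detection import init_detection_model
+#
+#     detector = init_detection_model('retinaface_resnet50', half=False, device='cuda')
+#     img = cv2.imread('face.jpg')   # BGR, as detect_faces() expects
+#     dets = detector.detect_faces(img, conf_threshold=0.8)
+#     # each row: x1, y1, x2, y2, score, followed by five (x, y) landmark pairs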
diff --git a/repositories/codeformer/facelib/detection/align_trans.py b/repositories/codeformer/facelib/detection/align_trans.py
new file mode 100644
index 000000000..07f1eb365
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/align_trans.py
@@ -0,0 +1,219 @@
+import cv2
+import numpy as np
+
+from .matlab_cp2tform import get_similarity_transform_for_cv2
+
+# reference facial points, a list of coordinates (x,y)
+REFERENCE_FACIAL_POINTS = [[30.29459953, 51.69630051], [65.53179932, 51.50139999], [48.02519989, 71.73660278],
+ [33.54930115, 92.3655014], [62.72990036, 92.20410156]]
+
+DEFAULT_CROP_SIZE = (96, 112)
+
+
+class FaceWarpException(Exception):
+
+ def __str__(self):
+        return 'In File {}:{}'.format(__file__, super().__str__())
+
+
+def get_reference_facial_points(output_size=None, inner_padding_factor=0.0, outer_padding=(0, 0), default_square=False):
+ """
+ Function:
+ ----------
+        get the 5 reference key points according to the crop settings:
+ 0. Set default crop_size:
+ if default_square:
+ crop_size = (112, 112)
+ else:
+ crop_size = (96, 112)
+        1. Pad the crop_size by inner_padding_factor on each side;
+ 2. Resize crop_size into (output_size - outer_padding*2),
+ pad into output_size with outer_padding;
+ 3. Output reference_5point;
+ Parameters:
+ ----------
+ @output_size: (w, h) or None
+ size of aligned face image
+    @inner_padding_factor: float in [0, 1]
+        padding factor for the inner (w, h) region
+    @outer_padding: (w_pad, h_pad)
+        outer padding in pixels added on each side
+ @default_square: True or False
+ if True:
+ default crop_size = (112, 112)
+ else:
+ default crop_size = (96, 112);
+ !!! make sure, if output_size is not None:
+ (output_size - outer_padding)
+ = some_scale * (default crop_size * (1.0 +
+ inner_padding_factor))
+ Returns:
+ ----------
+ @reference_5point: 5x2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+
+ tmp_5pts = np.array(REFERENCE_FACIAL_POINTS)
+ tmp_crop_size = np.array(DEFAULT_CROP_SIZE)
+
+ # 0) make the inner region a square
+ if default_square:
+ size_diff = max(tmp_crop_size) - tmp_crop_size
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += size_diff
+
+ if (output_size and output_size[0] == tmp_crop_size[0] and output_size[1] == tmp_crop_size[1]):
+
+ return tmp_5pts
+
+ if (inner_padding_factor == 0 and outer_padding == (0, 0)):
+ if output_size is None:
+ return tmp_5pts
+ else:
+ raise FaceWarpException('No paddings to do, output_size must be None or {}'.format(tmp_crop_size))
+
+ # check output size
+ if not (0 <= inner_padding_factor <= 1.0):
+ raise FaceWarpException('Not (0 <= inner_padding_factor <= 1.0)')
+
+ if ((inner_padding_factor > 0 or outer_padding[0] > 0 or outer_padding[1] > 0) and output_size is None):
+        output_size = (tmp_crop_size * (1 + inner_padding_factor * 2)).astype(np.int32)
+ output_size += np.array(outer_padding)
+ if not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1]):
+ raise FaceWarpException('Not (outer_padding[0] < output_size[0] and outer_padding[1] < output_size[1])')
+
+    # 1) pad the inner region according to inner_padding_factor
+ if inner_padding_factor > 0:
+ size_diff = tmp_crop_size * inner_padding_factor * 2
+ tmp_5pts += size_diff / 2
+ tmp_crop_size += np.round(size_diff).astype(np.int32)
+
+ # 2) resize the padded inner region
+ size_bf_outer_pad = np.array(output_size) - np.array(outer_padding) * 2
+
+ if size_bf_outer_pad[0] * tmp_crop_size[1] != size_bf_outer_pad[1] * tmp_crop_size[0]:
+        raise FaceWarpException('Must have (output_size - outer_padding) '
+                                '= some_scale * (crop_size * (1.0 + inner_padding_factor))')
+
+ scale_factor = size_bf_outer_pad[0].astype(np.float32) / tmp_crop_size[0]
+ tmp_5pts = tmp_5pts * scale_factor
+ # size_diff = tmp_crop_size * (scale_factor - min(scale_factor))
+ # tmp_5pts = tmp_5pts + size_diff / 2
+ tmp_crop_size = size_bf_outer_pad
+
+ # 3) add outer_padding to make output_size
+ reference_5point = tmp_5pts + np.array(outer_padding)
+ tmp_crop_size = output_size
+
+ return reference_5point
+
+
+def get_affine_transform_matrix(src_pts, dst_pts):
+ """
+ Function:
+ ----------
+ get affine transform matrix 'tfm' from src_pts to dst_pts
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points matrix, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points matrix, each row is a pair of coordinates (x, y)
+ Returns:
+ ----------
+ @tfm: 2x3 np.array
+ transform matrix from src_pts to dst_pts
+ """
+
+ tfm = np.float32([[1, 0, 0], [0, 1, 0]])
+ n_pts = src_pts.shape[0]
+ ones = np.ones((n_pts, 1), src_pts.dtype)
+ src_pts_ = np.hstack([src_pts, ones])
+ dst_pts_ = np.hstack([dst_pts, ones])
+
+ A, res, rank, s = np.linalg.lstsq(src_pts_, dst_pts_)
+
+ if rank == 3:
+ tfm = np.float32([[A[0, 0], A[1, 0], A[2, 0]], [A[0, 1], A[1, 1], A[2, 1]]])
+ elif rank == 2:
+ tfm = np.float32([[A[0, 0], A[1, 0], 0], [A[0, 1], A[1, 1], 0]])
+
+ return tfm
+
+
+def warp_and_crop_face(src_img, facial_pts, reference_pts=None, crop_size=(96, 112), align_type='similarity'):
+ """
+ Function:
+ ----------
+        warp and crop the face region of src_img according to the given facial landmarks
+ Parameters:
+ ----------
+    @src_img: HxWxC np.array
+        input image
+ @facial_pts: could be
+ 1)a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ @reference_pts: could be
+ 1) a list of K coordinates (x,y)
+ or
+ 2) Kx2 or 2xK np.array
+ each row or col is a pair of coordinates (x, y)
+ or
+ 3) None
+ if None, use default reference facial points
+ @crop_size: (w, h)
+ output face image size
+ @align_type: transform type, could be one of
+ 1) 'similarity': use similarity transform
+ 2) 'cv2_affine': use the first 3 points to do affine transform,
+ by calling cv2.getAffineTransform()
+ 3) 'affine': use all points to do affine transform
+ Returns:
+ ----------
+ @face_img: output face image with size (w, h) = @crop_size
+ """
+
+ if reference_pts is None:
+ if crop_size[0] == 96 and crop_size[1] == 112:
+ reference_pts = REFERENCE_FACIAL_POINTS
+ else:
+ default_square = False
+ inner_padding_factor = 0
+ outer_padding = (0, 0)
+ output_size = crop_size
+
+ reference_pts = get_reference_facial_points(output_size, inner_padding_factor, outer_padding,
+ default_square)
+
+ ref_pts = np.float32(reference_pts)
+ ref_pts_shp = ref_pts.shape
+ if max(ref_pts_shp) < 3 or min(ref_pts_shp) != 2:
+ raise FaceWarpException('reference_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if ref_pts_shp[0] == 2:
+ ref_pts = ref_pts.T
+
+ src_pts = np.float32(facial_pts)
+ src_pts_shp = src_pts.shape
+ if max(src_pts_shp) < 3 or min(src_pts_shp) != 2:
+ raise FaceWarpException('facial_pts.shape must be (K,2) or (2,K) and K>2')
+
+ if src_pts_shp[0] == 2:
+ src_pts = src_pts.T
+
+ if src_pts.shape != ref_pts.shape:
+ raise FaceWarpException('facial_pts and reference_pts must have the same shape')
+
+ if align_type == 'cv2_affine':
+ tfm = cv2.getAffineTransform(src_pts[0:3], ref_pts[0:3])
+ elif align_type == 'affine':
+ tfm = get_affine_transform_matrix(src_pts, ref_pts)
+ else:
+ tfm = get_similarity_transform_for_cv2(src_pts, ref_pts)
+
+ face_img = cv2.warpAffine(src_img, tfm, (crop_size[0], crop_size[1]))
+
+ return face_img
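+
+
+if __name__ == '__main__':
+    # Small usage sketch added for illustration (not part of the upstream file);
+    # run in package context, e.g. python -m facelib.detection.align_trans.
+    # It prints the default 96x112 reference landmarks and their 112x112 counterparts.
+    print(get_reference_facial_points())
+    print(get_reference_facial_points(default_square=True))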
diff --git a/repositories/codeformer/facelib/detection/matlab_cp2tform.py b/repositories/codeformer/facelib/detection/matlab_cp2tform.py
new file mode 100644
index 000000000..b2a8b54a9
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/matlab_cp2tform.py
@@ -0,0 +1,317 @@
+import numpy as np
+from numpy.linalg import inv, lstsq
+from numpy.linalg import matrix_rank as rank
+from numpy.linalg import norm
+
+
+class MatlabCp2tormException(Exception):
+
+ def __str__(self):
+        return 'In File {}:{}'.format(__file__, super().__str__())
+
+
+def tformfwd(trans, uv):
+ """
+ Function:
+ ----------
+ apply affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of transformed coordinates (x, y)
+ """
+ uv = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy = np.dot(uv, trans)
+ xy = xy[:, 0:-1]
+ return xy
+
+
+def tforminv(trans, uv):
+ """
+ Function:
+ ----------
+ apply the inverse of affine transform 'trans' to uv
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix
+ @uv: Kx2 np.array
+ each row is a pair of coordinates (x, y)
+
+ Returns:
+ ----------
+ @xy: Kx2 np.array
+ each row is a pair of inverse-transformed coordinates (x, y)
+ """
+ Tinv = inv(trans)
+ xy = tformfwd(Tinv, uv)
+ return xy
+
+
+def findNonreflectiveSimilarity(uv, xy, options=None):
+ options = {'K': 2}
+
+ K = options['K']
+ M = xy.shape[0]
+ x = xy[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ y = xy[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+
+ tmp1 = np.hstack((x, y, np.ones((M, 1)), np.zeros((M, 1))))
+ tmp2 = np.hstack((y, -x, np.zeros((M, 1)), np.ones((M, 1))))
+ X = np.vstack((tmp1, tmp2))
+
+ u = uv[:, 0].reshape((-1, 1)) # use reshape to keep a column vector
+ v = uv[:, 1].reshape((-1, 1)) # use reshape to keep a column vector
+ U = np.vstack((u, v))
+
+ # We know that X * r = U
+ if rank(X) >= 2 * K:
+ r, _, _, _ = lstsq(X, U, rcond=-1)
+ r = np.squeeze(r)
+ else:
+ raise Exception('cp2tform:twoUniquePointsReq')
+ sc = r[0]
+ ss = r[1]
+ tx = r[2]
+ ty = r[3]
+
+ Tinv = np.array([[sc, -ss, 0], [ss, sc, 0], [tx, ty, 1]])
+ T = inv(Tinv)
+ T[:, 2] = np.array([0, 0, 1])
+
+ return T, Tinv
+
+
+def findSimilarity(uv, xy, options=None):
+ options = {'K': 2}
+
+ # uv = np.array(uv)
+ # xy = np.array(xy)
+
+ # Solve for trans1
+ trans1, trans1_inv = findNonreflectiveSimilarity(uv, xy, options)
+
+ # Solve for trans2
+
+ # manually reflect the xy data across the Y-axis
+ xyR = xy
+ xyR[:, 0] = -1 * xyR[:, 0]
+
+ trans2r, trans2r_inv = findNonreflectiveSimilarity(uv, xyR, options)
+
+ # manually reflect the tform to undo the reflection done on xyR
+ TreflectY = np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]])
+
+ trans2 = np.dot(trans2r, TreflectY)
+
+ # Figure out if trans1 or trans2 is better
+ xy1 = tformfwd(trans1, uv)
+ norm1 = norm(xy1 - xy)
+
+ xy2 = tformfwd(trans2, uv)
+ norm2 = norm(xy2 - xy)
+
+ if norm1 <= norm2:
+ return trans1, trans1_inv
+ else:
+ trans2_inv = inv(trans2)
+ return trans2, trans2_inv
+
+
+def get_similarity_transform(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'trans':
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y, 1] = [u, v, 1] * trans
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ @reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+ trans_inv: 3x3 np.array
+ inverse of trans, transform matrix from xy to uv
+ """
+
+ if reflective:
+ trans, trans_inv = findSimilarity(src_pts, dst_pts)
+ else:
+ trans, trans_inv = findNonreflectiveSimilarity(src_pts, dst_pts)
+
+ return trans, trans_inv
+
+
+def cvt_tform_mat_for_cv2(trans):
+ """
+ Function:
+ ----------
+ Convert Transform Matrix 'trans' into 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @trans: 3x3 np.array
+ transform matrix from uv to xy
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ cv2_trans = trans[:, 0:2].T
+
+ return cv2_trans
+
+
+def get_similarity_transform_for_cv2(src_pts, dst_pts, reflective=True):
+ """
+ Function:
+ ----------
+ Find Similarity Transform Matrix 'cv2_trans' which could be
+ directly used by cv2.warpAffine():
+ u = src_pts[:, 0]
+ v = src_pts[:, 1]
+ x = dst_pts[:, 0]
+ y = dst_pts[:, 1]
+ [x, y].T = cv_trans * [u, v, 1].T
+
+ Parameters:
+ ----------
+ @src_pts: Kx2 np.array
+ source points, each row is a pair of coordinates (x, y)
+ @dst_pts: Kx2 np.array
+ destination points, each row is a pair of transformed
+ coordinates (x, y)
+ reflective: True or False
+ if True:
+ use reflective similarity transform
+ else:
+ use non-reflective similarity transform
+
+ Returns:
+ ----------
+ @cv2_trans: 2x3 np.array
+ transform matrix from src_pts to dst_pts, could be directly used
+ for cv2.warpAffine()
+ """
+ trans, trans_inv = get_similarity_transform(src_pts, dst_pts, reflective)
+ cv2_trans = cvt_tform_mat_for_cv2(trans)
+
+ return cv2_trans
+
+
+if __name__ == '__main__':
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ # In Matlab, run:
+ #
+ # uv = [u'; v'];
+ # xy = [x'; y'];
+ # tform_sim=cp2tform(uv,xy,'similarity');
+ #
+ # trans = tform_sim.tdata.T
+ # ans =
+ # -0.0764 -1.6190 0
+ # 1.6190 -0.0764 0
+ # -3.2156 0.0290 1.0000
+ # trans_inv = tform_sim.tdata.Tinv
+ # ans =
+ #
+ # -0.0291 0.6163 0
+ # -0.6163 -0.0291 0
+ # -0.0756 1.9826 1.0000
+ # xy_m=tformfwd(tform_sim, u,v)
+ #
+ # xy_m =
+ #
+ # -3.2156 0.0290
+ # 1.1833 -9.9143
+ # 5.0323 2.8853
+ # uv_m=tforminv(tform_sim, x,y)
+ #
+ # uv_m =
+ #
+ # 0.5698 1.3953
+ # 6.0872 2.2733
+ # -2.6570 4.3314
+ """
+ u = [0, 6, -2]
+ v = [0, 3, 5]
+ x = [-1, 0, 4]
+ y = [-1, -10, 4]
+
+ uv = np.array((u, v)).T
+ xy = np.array((x, y)).T
+
+ print('\n--->uv:')
+ print(uv)
+ print('\n--->xy:')
+ print(xy)
+
+ trans, trans_inv = get_similarity_transform(uv, xy)
+
+ print('\n--->trans matrix:')
+ print(trans)
+
+ print('\n--->trans_inv matrix:')
+ print(trans_inv)
+
+ print('\n---> apply transform to uv')
+ print('\nxy_m = uv_augmented * trans')
+ uv_aug = np.hstack((uv, np.ones((uv.shape[0], 1))))
+ xy_m = np.dot(uv_aug, trans)
+ print(xy_m)
+
+ print('\nxy_m = tformfwd(trans, uv)')
+ xy_m = tformfwd(trans, uv)
+ print(xy_m)
+
+ print('\n---> apply inverse transform to xy')
+ print('\nuv_m = xy_augmented * trans_inv')
+ xy_aug = np.hstack((xy, np.ones((xy.shape[0], 1))))
+ uv_m = np.dot(xy_aug, trans_inv)
+ print(uv_m)
+
+ print('\nuv_m = tformfwd(trans_inv, xy)')
+ uv_m = tformfwd(trans_inv, xy)
+ print(uv_m)
+
+ uv_m = tforminv(trans, xy)
+ print('\nuv_m = tforminv(trans, xy)')
+ print(uv_m)
diff --git a/repositories/codeformer/facelib/detection/retinaface/retinaface.py b/repositories/codeformer/facelib/detection/retinaface/retinaface.py
new file mode 100644
index 000000000..02593556d
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/retinaface/retinaface.py
@@ -0,0 +1,370 @@
+import cv2
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.models._utils import IntermediateLayerGetter as IntermediateLayerGetter
+
+from facelib.detection.align_trans import get_reference_facial_points, warp_and_crop_face
+from facelib.detection.retinaface.retinaface_net import FPN, SSH, MobileNetV1, make_bbox_head, make_class_head, make_landmark_head
+from facelib.detection.retinaface.retinaface_utils import (PriorBox, batched_decode, batched_decode_landm, decode, decode_landm,
+ py_cpu_nms)
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def generate_config(network_name):
+
+ cfg_mnet = {
+ 'name': 'mobilenet0.25',
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
+ 'steps': [8, 16, 32],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 32,
+ 'ngpu': 1,
+ 'epoch': 250,
+ 'decay1': 190,
+ 'decay2': 220,
+ 'image_size': 640,
+ 'return_layers': {
+ 'stage1': 1,
+ 'stage2': 2,
+ 'stage3': 3
+ },
+ 'in_channel': 32,
+ 'out_channel': 64
+ }
+
+ cfg_re50 = {
+ 'name': 'Resnet50',
+ 'min_sizes': [[16, 32], [64, 128], [256, 512]],
+ 'steps': [8, 16, 32],
+ 'variance': [0.1, 0.2],
+ 'clip': False,
+ 'loc_weight': 2.0,
+ 'gpu_train': True,
+ 'batch_size': 24,
+ 'ngpu': 4,
+ 'epoch': 100,
+ 'decay1': 70,
+ 'decay2': 90,
+ 'image_size': 840,
+ 'return_layers': {
+ 'layer2': 1,
+ 'layer3': 2,
+ 'layer4': 3
+ },
+ 'in_channel': 256,
+ 'out_channel': 256
+ }
+
+ if network_name == 'mobile0.25':
+ return cfg_mnet
+ elif network_name == 'resnet50':
+ return cfg_re50
+ else:
+ raise NotImplementedError(f'network_name={network_name}')
+
+
+class RetinaFace(nn.Module):
+
+ def __init__(self, network_name='resnet50', half=False, phase='test'):
+ super(RetinaFace, self).__init__()
+ self.half_inference = half
+ cfg = generate_config(network_name)
+ self.backbone = cfg['name']
+
+ self.model_name = f'retinaface_{network_name}'
+ self.cfg = cfg
+ self.phase = phase
+ self.target_size, self.max_size = 1600, 2150
+ self.resize, self.scale, self.scale1 = 1., None, None
+ self.mean_tensor = torch.tensor([[[[104.]], [[117.]], [[123.]]]]).to(device)
+ self.reference = get_reference_facial_points(default_square=True)
+ # Build network.
+ backbone = None
+ if cfg['name'] == 'mobilenet0.25':
+ backbone = MobileNetV1()
+ self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+ elif cfg['name'] == 'Resnet50':
+ import torchvision.models as models
+ backbone = models.resnet50(pretrained=False)
+ self.body = IntermediateLayerGetter(backbone, cfg['return_layers'])
+
+ in_channels_stage2 = cfg['in_channel']
+ in_channels_list = [
+ in_channels_stage2 * 2,
+ in_channels_stage2 * 4,
+ in_channels_stage2 * 8,
+ ]
+
+ out_channels = cfg['out_channel']
+ self.fpn = FPN(in_channels_list, out_channels)
+ self.ssh1 = SSH(out_channels, out_channels)
+ self.ssh2 = SSH(out_channels, out_channels)
+ self.ssh3 = SSH(out_channels, out_channels)
+
+ self.ClassHead = make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
+ self.BboxHead = make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
+ self.LandmarkHead = make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])
+
+ self.to(device)
+ self.eval()
+ if self.half_inference:
+ self.half()
+
+ def forward(self, inputs):
+ out = self.body(inputs)
+
+ if self.backbone == 'mobilenet0.25' or self.backbone == 'Resnet50':
+ out = list(out.values())
+ # FPN
+ fpn = self.fpn(out)
+
+ # SSH
+ feature1 = self.ssh1(fpn[0])
+ feature2 = self.ssh2(fpn[1])
+ feature3 = self.ssh3(fpn[2])
+ features = [feature1, feature2, feature3]
+
+ bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
+ tmp = [self.LandmarkHead[i](feature) for i, feature in enumerate(features)]
+ ldm_regressions = (torch.cat(tmp, dim=1))
+
+ if self.phase == 'train':
+ output = (bbox_regressions, classifications, ldm_regressions)
+ else:
+ output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions)
+ return output
+
+ def __detect_faces(self, inputs):
+ # get scale
+ height, width = inputs.shape[2:]
+ self.scale = torch.tensor([width, height, width, height], dtype=torch.float32).to(device)
+ tmp = [width, height, width, height, width, height, width, height, width, height]
+ self.scale1 = torch.tensor(tmp, dtype=torch.float32).to(device)
+
+        # forward
+ inputs = inputs.to(device)
+ if self.half_inference:
+ inputs = inputs.half()
+ loc, conf, landmarks = self(inputs)
+
+ # get priorbox
+ priorbox = PriorBox(self.cfg, image_size=inputs.shape[2:])
+ priors = priorbox.forward().to(device)
+
+ return loc, conf, landmarks, priors
+
+ # single image detection
+ def transform(self, image, use_origin_size):
+ # convert to opencv format
+ if isinstance(image, Image.Image):
+ image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
+ image = image.astype(np.float32)
+
+ # testing scale
+ im_size_min = np.min(image.shape[0:2])
+ im_size_max = np.max(image.shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
+
+ # resize
+ if resize != 1:
+ image = cv2.resize(image, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+
+ # convert to torch.tensor format
+ # image -= (104, 117, 123)
+ image = image.transpose(2, 0, 1)
+ image = torch.from_numpy(image).unsqueeze(0)
+
+ return image, resize
+
+ def detect_faces(
+ self,
+ image,
+ conf_threshold=0.8,
+ nms_threshold=0.4,
+ use_origin_size=True,
+ ):
+ """
+ Params:
+            image: BGR image (np.ndarray or PIL.Image).
+ """
+ image, self.resize = self.transform(image, use_origin_size)
+ image = image.to(device)
+ if self.half_inference:
+ image = image.half()
+ image = image - self.mean_tensor
+
+ loc, conf, landmarks, priors = self.__detect_faces(image)
+
+ boxes = decode(loc.data.squeeze(0), priors.data, self.cfg['variance'])
+ boxes = boxes * self.scale / self.resize
+ boxes = boxes.cpu().numpy()
+
+ scores = conf.squeeze(0).data.cpu().numpy()[:, 1]
+
+ landmarks = decode_landm(landmarks.squeeze(0), priors, self.cfg['variance'])
+ landmarks = landmarks * self.scale1 / self.resize
+ landmarks = landmarks.cpu().numpy()
+
+ # ignore low scores
+ inds = np.where(scores > conf_threshold)[0]
+ boxes, landmarks, scores = boxes[inds], landmarks[inds], scores[inds]
+
+ # sort
+ order = scores.argsort()[::-1]
+ boxes, landmarks, scores = boxes[order], landmarks[order], scores[order]
+
+ # do NMS
+ bounding_boxes = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landmarks[keep]
+ # self.t['forward_pass'].toc()
+ # print(self.t['forward_pass'].average_time)
+ # import sys
+ # sys.stdout.flush()
+ return np.concatenate((bounding_boxes, landmarks), axis=1)
+
+ def __align_multi(self, image, boxes, landmarks, limit=None):
+
+ if len(boxes) < 1:
+ return [], []
+
+ if limit:
+ boxes = boxes[:limit]
+ landmarks = landmarks[:limit]
+
+ faces = []
+ for landmark in landmarks:
+ facial5points = [[landmark[2 * j], landmark[2 * j + 1]] for j in range(5)]
+
+ warped_face = warp_and_crop_face(np.array(image), facial5points, self.reference, crop_size=(112, 112))
+ faces.append(warped_face)
+
+ return np.concatenate((boxes, landmarks), axis=1), faces
+
+ def align_multi(self, img, conf_threshold=0.8, limit=None):
+
+ rlt = self.detect_faces(img, conf_threshold=conf_threshold)
+ boxes, landmarks = rlt[:, 0:5], rlt[:, 5:]
+
+ return self.__align_multi(img, boxes, landmarks, limit)
+
+ # batched detection
+ def batched_transform(self, frames, use_origin_size):
+ """
+ Arguments:
+            frames: a list of PIL.Image, or a torch.Tensor of shape [n, h, w, c]
+                (dtype=torch.float32, BGR format).
+ use_origin_size: whether to use origin size.
+ """
+ from_PIL = True if isinstance(frames[0], Image.Image) else False
+
+ # convert to opencv format
+ if from_PIL:
+ frames = [cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR) for frame in frames]
+ frames = np.asarray(frames, dtype=np.float32)
+
+ # testing scale
+ im_size_min = np.min(frames[0].shape[0:2])
+ im_size_max = np.max(frames[0].shape[0:2])
+ resize = float(self.target_size) / float(im_size_min)
+
+ # prevent bigger axis from being more than max_size
+ if np.round(resize * im_size_max) > self.max_size:
+ resize = float(self.max_size) / float(im_size_max)
+ resize = 1 if use_origin_size else resize
+
+ # resize
+ if resize != 1:
+ if not from_PIL:
+ frames = F.interpolate(frames, scale_factor=resize)
+ else:
+ frames = [
+ cv2.resize(frame, None, None, fx=resize, fy=resize, interpolation=cv2.INTER_LINEAR)
+ for frame in frames
+ ]
+
+ # convert to torch.tensor format
+ if not from_PIL:
+ frames = frames.transpose(1, 2).transpose(1, 3).contiguous()
+ else:
+ frames = frames.transpose((0, 3, 1, 2))
+ frames = torch.from_numpy(frames)
+
+ return frames, resize
+
+ def batched_detect_faces(self, frames, conf_threshold=0.8, nms_threshold=0.4, use_origin_size=True):
+ """
+ Arguments:
+ frames: a list of PIL.Image, or np.array(shape=[n, h, w, c],
+ type=np.uint8, BGR format).
+ conf_threshold: confidence threshold.
+ nms_threshold: nms threshold.
+ use_origin_size: whether to use origin size.
+ Returns:
+ final_bounding_boxes: list of np.array ([n_boxes, 5],
+ type=np.float32).
+ final_landmarks: list of np.array ([n_boxes, 10], type=np.float32).
+ """
+ # self.t['forward_pass'].tic()
+ frames, self.resize = self.batched_transform(frames, use_origin_size)
+ frames = frames.to(device)
+ frames = frames - self.mean_tensor
+
+ b_loc, b_conf, b_landmarks, priors = self.__detect_faces(frames)
+
+ final_bounding_boxes, final_landmarks = [], []
+
+ # decode
+ priors = priors.unsqueeze(0)
+ b_loc = batched_decode(b_loc, priors, self.cfg['variance']) * self.scale / self.resize
+ b_landmarks = batched_decode_landm(b_landmarks, priors, self.cfg['variance']) * self.scale1 / self.resize
+ b_conf = b_conf[:, :, 1]
+
+ # index for selection
+ b_indice = b_conf > conf_threshold
+
+ # concat
+ b_loc_and_conf = torch.cat((b_loc, b_conf.unsqueeze(-1)), dim=2).float()
+
+ for pred, landm, inds in zip(b_loc_and_conf, b_landmarks, b_indice):
+
+ # ignore low scores
+ pred, landm = pred[inds, :], landm[inds, :]
+ if pred.shape[0] == 0:
+ final_bounding_boxes.append(np.array([], dtype=np.float32))
+ final_landmarks.append(np.array([], dtype=np.float32))
+ continue
+
+ # sort
+ # order = score.argsort(descending=True)
+ # box, landm, score = box[order], landm[order], score[order]
+
+ # to CPU
+ bounding_boxes, landm = pred.cpu().numpy(), landm.cpu().numpy()
+
+ # NMS
+ keep = py_cpu_nms(bounding_boxes, nms_threshold)
+ bounding_boxes, landmarks = bounding_boxes[keep, :], landm[keep]
+
+ # append
+ final_bounding_boxes.append(bounding_boxes)
+ final_landmarks.append(landmarks)
+ # self.t['forward_pass'].toc(average=True)
+ # self.batch_time += self.t['forward_pass'].diff
+ # self.total_frame += len(frames)
+ # print(self.batch_time / self.total_frame)
+
+ return final_bounding_boxes, final_landmarks
diff --git a/repositories/codeformer/facelib/detection/retinaface/retinaface_net.py b/repositories/codeformer/facelib/detection/retinaface/retinaface_net.py
new file mode 100644
index 000000000..ab6aa82d3
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/retinaface/retinaface_net.py
@@ -0,0 +1,196 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_bn(inp, oup, stride=1, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_bn_no_relu(inp, oup, stride):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
+ nn.BatchNorm2d(oup),
+ )
+
+
+def conv_bn1X1(inp, oup, stride, leaky=0):
+ return nn.Sequential(
+ nn.Conv2d(inp, oup, 1, stride, padding=0, bias=False), nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True))
+
+
+def conv_dw(inp, oup, stride, leaky=0.1):
+ return nn.Sequential(
+ nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
+ nn.BatchNorm2d(inp),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
+ nn.BatchNorm2d(oup),
+ nn.LeakyReLU(negative_slope=leaky, inplace=True),
+ )
+
+
+class SSH(nn.Module):
+
+ def __init__(self, in_channel, out_channel):
+ super(SSH, self).__init__()
+ assert out_channel % 4 == 0
+ leaky = 0
+ if (out_channel <= 64):
+ leaky = 0.1
+ self.conv3X3 = conv_bn_no_relu(in_channel, out_channel // 2, stride=1)
+
+ self.conv5X5_1 = conv_bn(in_channel, out_channel // 4, stride=1, leaky=leaky)
+ self.conv5X5_2 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ self.conv7X7_2 = conv_bn(out_channel // 4, out_channel // 4, stride=1, leaky=leaky)
+ self.conv7x7_3 = conv_bn_no_relu(out_channel // 4, out_channel // 4, stride=1)
+
+ def forward(self, input):
+ conv3X3 = self.conv3X3(input)
+
+ conv5X5_1 = self.conv5X5_1(input)
+ conv5X5 = self.conv5X5_2(conv5X5_1)
+
+ conv7X7_2 = self.conv7X7_2(conv5X5_1)
+ conv7X7 = self.conv7x7_3(conv7X7_2)
+
+ out = torch.cat([conv3X3, conv5X5, conv7X7], dim=1)
+ out = F.relu(out)
+ return out
+
+
+class FPN(nn.Module):
+
+ def __init__(self, in_channels_list, out_channels):
+ super(FPN, self).__init__()
+ leaky = 0
+ if (out_channels <= 64):
+ leaky = 0.1
+ self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
+ self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
+ self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)
+
+ self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
+ self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)
+
+ def forward(self, input):
+ # names = list(input.keys())
+ # input = list(input.values())
+
+ output1 = self.output1(input[0])
+ output2 = self.output2(input[1])
+ output3 = self.output3(input[2])
+
+ up3 = F.interpolate(output3, size=[output2.size(2), output2.size(3)], mode='nearest')
+ output2 = output2 + up3
+ output2 = self.merge2(output2)
+
+ up2 = F.interpolate(output2, size=[output1.size(2), output1.size(3)], mode='nearest')
+ output1 = output1 + up2
+ output1 = self.merge1(output1)
+
+ out = [output1, output2, output3]
+ return out
+
+
+class MobileNetV1(nn.Module):
+
+ def __init__(self):
+ super(MobileNetV1, self).__init__()
+ self.stage1 = nn.Sequential(
+ conv_bn(3, 8, 2, leaky=0.1), # 3
+ conv_dw(8, 16, 1), # 7
+ conv_dw(16, 32, 2), # 11
+ conv_dw(32, 32, 1), # 19
+ conv_dw(32, 64, 2), # 27
+ conv_dw(64, 64, 1), # 43
+ )
+ self.stage2 = nn.Sequential(
+ conv_dw(64, 128, 2), # 43 + 16 = 59
+ conv_dw(128, 128, 1), # 59 + 32 = 91
+ conv_dw(128, 128, 1), # 91 + 32 = 123
+ conv_dw(128, 128, 1), # 123 + 32 = 155
+ conv_dw(128, 128, 1), # 155 + 32 = 187
+ conv_dw(128, 128, 1), # 187 + 32 = 219
+ )
+ self.stage3 = nn.Sequential(
+ conv_dw(128, 256, 2), # 219 +3 2 = 241
+ conv_dw(256, 256, 1), # 241 + 64 = 301
+ )
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = nn.Linear(256, 1000)
+
+ def forward(self, x):
+ x = self.stage1(x)
+ x = self.stage2(x)
+ x = self.stage3(x)
+ x = self.avg(x)
+ # x = self.model(x)
+ x = x.view(-1, 256)
+ x = self.fc(x)
+ return x
+
+
+class ClassHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(ClassHead, self).__init__()
+ self.num_anchors = num_anchors
+ self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 2)
+
+
+class BboxHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(BboxHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 4)
+
+
+class LandmarkHead(nn.Module):
+
+ def __init__(self, inchannels=512, num_anchors=3):
+ super(LandmarkHead, self).__init__()
+ self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0)
+
+ def forward(self, x):
+ out = self.conv1x1(x)
+ out = out.permute(0, 2, 3, 1).contiguous()
+
+ return out.view(out.shape[0], -1, 10)
+
+
+def make_class_head(fpn_num=3, inchannels=64, anchor_num=2):
+ classhead = nn.ModuleList()
+ for i in range(fpn_num):
+ classhead.append(ClassHead(inchannels, anchor_num))
+ return classhead
+
+
+def make_bbox_head(fpn_num=3, inchannels=64, anchor_num=2):
+ bboxhead = nn.ModuleList()
+ for i in range(fpn_num):
+ bboxhead.append(BboxHead(inchannels, anchor_num))
+ return bboxhead
+
+
+def make_landmark_head(fpn_num=3, inchannels=64, anchor_num=2):
+ landmarkhead = nn.ModuleList()
+ for i in range(fpn_num):
+ landmarkhead.append(LandmarkHead(inchannels, anchor_num))
+ return landmarkhead
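+
+
+if __name__ == "__main__":
+    # Illustrative sanity check, not part of the original RetinaFace code: it
+    # runs dummy tensors through the building blocks above to show the
+    # expected output shapes. The channel and spatial sizes below are
+    # arbitrary assumptions chosen only for this demonstration.
+    x = torch.randn(1, 3, 224, 224)
+    backbone = MobileNetV1()
+    print(backbone(x).shape)  # torch.Size([1, 1000]) -- classifier head output
+
+    feat = torch.randn(1, 64, 28, 28)
+    ssh = SSH(64, 64)
+    print(ssh(feat).shape)  # torch.Size([1, 64, 28, 28]) -- context module keeps resolution
+
+    class_head = make_class_head(fpn_num=1, inchannels=64, anchor_num=2)
+    print(class_head[0](feat).shape)  # torch.Size([1, 1568, 2]) -- 28*28*2 anchors, 2 classes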
diff --git a/repositories/codeformer/facelib/detection/retinaface/retinaface_utils.py b/repositories/codeformer/facelib/detection/retinaface/retinaface_utils.py
new file mode 100644
index 000000000..8c3577577
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/retinaface/retinaface_utils.py
@@ -0,0 +1,421 @@
+import numpy as np
+import torch
+import torchvision
+from itertools import product as product
+from math import ceil
+
+
+class PriorBox(object):
+
+ def __init__(self, cfg, image_size=None, phase='train'):
+ super(PriorBox, self).__init__()
+ self.min_sizes = cfg['min_sizes']
+ self.steps = cfg['steps']
+ self.clip = cfg['clip']
+ self.image_size = image_size
+ self.feature_maps = [[ceil(self.image_size[0] / step), ceil(self.image_size[1] / step)] for step in self.steps]
+ self.name = 's'
+
+ def forward(self):
+ anchors = []
+ for k, f in enumerate(self.feature_maps):
+ min_sizes = self.min_sizes[k]
+ for i, j in product(range(f[0]), range(f[1])):
+ for min_size in min_sizes:
+ s_kx = min_size / self.image_size[1]
+ s_ky = min_size / self.image_size[0]
+ dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
+ dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
+ for cy, cx in product(dense_cy, dense_cx):
+ anchors += [cx, cy, s_kx, s_ky]
+
+ # back to torch land
+ output = torch.Tensor(anchors).view(-1, 4)
+ if self.clip:
+ output.clamp_(max=1, min=0)
+ return output
+
+
+def py_cpu_nms(dets, thresh):
+ """Pure Python NMS baseline."""
+ keep = torchvision.ops.nms(
+ boxes=torch.Tensor(dets[:, :4]),
+ scores=torch.Tensor(dets[:, 4]),
+ iou_threshold=thresh,
+ )
+
+ return list(keep)
+
+
+def point_form(boxes):
+ """ Convert prior_boxes to (xmin, ymin, xmax, ymax)
+ representation for comparison to point form ground truth data.
+ Args:
+ boxes: (tensor) center-size default boxes from priorbox layers.
+ Return:
+ boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes.
+ """
+ return torch.cat(
+ (
+ boxes[:, :2] - boxes[:, 2:] / 2, # xmin, ymin
+ boxes[:, :2] + boxes[:, 2:] / 2),
+ 1) # xmax, ymax
+
+
+def center_size(boxes):
+ """ Convert prior_boxes to (cx, cy, w, h)
+ representation for comparison to center-size form ground truth data.
+ Args:
+ boxes: (tensor) point_form boxes
+ Return:
+        boxes: (tensor) Converted (cx, cy, w, h) form of boxes.
+ """
+    return torch.cat(
+        ((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
+         boxes[:, 2:] - boxes[:, :2]),  # w, h
+        1)
+
+
+def intersect(box_a, box_b):
+ """ We resize both tensors to [A,B,2] without new malloc:
+ [A,2] -> [A,1,2] -> [A,B,2]
+ [B,2] -> [1,B,2] -> [A,B,2]
+ Then we compute the area of intersect between box_a and box_b.
+ Args:
+ box_a: (tensor) bounding boxes, Shape: [A,4].
+ box_b: (tensor) bounding boxes, Shape: [B,4].
+ Return:
+ (tensor) intersection area, Shape: [A,B].
+ """
+ A = box_a.size(0)
+ B = box_b.size(0)
+ max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
+ min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
+ inter = torch.clamp((max_xy - min_xy), min=0)
+ return inter[:, :, 0] * inter[:, :, 1]
+
+
+def jaccard(box_a, box_b):
+ """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
+ is simply the intersection over union of two boxes. Here we operate on
+ ground truth boxes and default boxes.
+ E.g.:
+ A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
+ Args:
+ box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
+ box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
+ Return:
+ jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
+ """
+ inter = intersect(box_a, box_b)
+ area_a = ((box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
+ area_b = ((box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
+ union = area_a + area_b - inter
+ return inter / union # [A,B]
+
+
+def matrix_iou(a, b):
+ """
+    Return IoU of a and b, numpy version for data augmentation.
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
+ return area_i / (area_a[:, np.newaxis] + area_b - area_i)
+
+
+def matrix_iof(a, b):
+ """
+    Return IoF (intersection over foreground) of a and b, numpy version for data augmentation.
+ """
+ lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
+ rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])
+
+ area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
+ area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
+ return area_i / np.maximum(area_a[:, np.newaxis], 1)
+
+
+def match(threshold, truths, priors, variances, labels, landms, loc_t, conf_t, landm_t, idx):
+ """Match each prior box with the ground truth box of the highest jaccard
+ overlap, encode the bounding boxes, then return the matched indices
+ corresponding to both confidence and location preds.
+ Args:
+ threshold: (float) The overlap threshold used when matching boxes.
+ truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
+ priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4].
+ variances: (tensor) Variances corresponding to each prior coord,
+ Shape: [num_priors, 4].
+ labels: (tensor) All the class labels for the image, Shape: [num_obj].
+ landms: (tensor) Ground truth landms, Shape [num_obj, 10].
+ loc_t: (tensor) Tensor to be filled w/ encoded location targets.
+ conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
+ landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
+ idx: (int) current batch index
+ Return:
+ The matched indices corresponding to 1)location 2)confidence
+ 3)landm preds.
+ """
+ # jaccard index
+ overlaps = jaccard(truths, point_form(priors))
+ # (Bipartite Matching)
+ # [1,num_objects] best prior for each ground truth
+ best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)
+
+ # ignore hard gt
+ valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
+ best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
+ if best_prior_idx_filter.shape[0] <= 0:
+ loc_t[idx] = 0
+ conf_t[idx] = 0
+ return
+
+ # [1,num_priors] best ground truth for each prior
+ best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
+ best_truth_idx.squeeze_(0)
+ best_truth_overlap.squeeze_(0)
+ best_prior_idx.squeeze_(1)
+ best_prior_idx_filter.squeeze_(1)
+ best_prior_overlap.squeeze_(1)
+ best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior
+ # TODO refactor: index best_prior_idx with long tensor
+ # ensure every gt matches with its prior of max overlap
+    for j in range(best_prior_idx.size(0)):  # determine which ground-truth box each best-matched anchor predicts
+        best_truth_idx[best_prior_idx[j]] = j
+    matches = truths[best_truth_idx]  # Shape: [num_priors,4] the gt bbox matched to each anchor
+    conf = labels[best_truth_idx]  # Shape: [num_priors] the gt label matched to each anchor
+    conf[best_truth_overlap < threshold] = 0  # label as background: anchors with overlap < 0.35 are treated as negatives
+ loc = encode(matches, priors, variances)
+
+ matches_landm = landms[best_truth_idx]
+ landm = encode_landm(matches_landm, priors, variances)
+ loc_t[idx] = loc # [num_priors,4] encoded offsets to learn
+ conf_t[idx] = conf # [num_priors] top class label for each prior
+ landm_t[idx] = landm
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, 2:])
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+
+def encode_landm(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 10].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded landm (tensor), Shape: [num_priors, 10]
+ """
+
+ # dist b/t match center and prior's center
+ matched = torch.reshape(matched, (matched.size(0), 5, 2))
+ priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_cy = priors[:, 1].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_w = priors[:, 2].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors_h = priors[:, 3].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2)
+ priors = torch.cat([priors_cx, priors_cy, priors_w, priors_h], dim=2)
+ g_cxcy = matched[:, :, :2] - priors[:, :, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, :, 2:])
+ # g_cxcy /= priors[:, :, 2:]
+ g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
+ # return target for smooth_l1_loss
+ return g_cxcy
+
+
+# Adapted from https://github.com/Hakuyume/chainer-ssd
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+
+def decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ tmp = (
+ priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
+ priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
+ )
+ landms = torch.cat(tmp, dim=1)
+ return landms
+
+
+def batched_decode(b_loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ b_loc (tensor): location predictions for loc layers,
+ Shape: [num_batches,num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+ boxes = (
+ priors[:, :, :2] + b_loc[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, 2:] * torch.exp(b_loc[:, :, 2:] * variances[1]),
+ )
+ boxes = torch.cat(boxes, dim=2)
+
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
+ boxes[:, :, 2:] += boxes[:, :, :2]
+ return boxes
+
+
+def batched_decode_landm(pre, priors, variances):
+ """Decode landm from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ pre (tensor): landm predictions for loc layers,
+ Shape: [num_batches,num_priors,10]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [1,num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded landm predictions
+ """
+ landms = (
+ priors[:, :, :2] + pre[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 2:4] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 4:6] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 6:8] * variances[0] * priors[:, :, 2:],
+ priors[:, :, :2] + pre[:, :, 8:10] * variances[0] * priors[:, :, 2:],
+ )
+ landms = torch.cat(landms, dim=2)
+ return landms
+
+
+def log_sum_exp(x):
+ """Utility function for computing log_sum_exp while determining
+ This will be used to determine unaveraged confidence loss across
+ all examples in a batch.
+ Args:
+ x (Variable(tensor)): conf_preds from conf layers
+ """
+ x_max = x.data.max()
+ return torch.log(torch.sum(torch.exp(x - x_max), 1, keepdim=True)) + x_max
+
+
+# Original author: Francisco Massa:
+# https://github.com/fmassa/object-detection.torch
+# Ported to PyTorch by Max deGroot (02/01/2017)
+def nms(boxes, scores, overlap=0.5, top_k=200):
+ """Apply non-maximum suppression at test time to avoid detecting too many
+ overlapping bounding boxes for a given object.
+ Args:
+ boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
+        scores: (tensor) The class prediction scores for the img, Shape: [num_priors].
+        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
+        top_k: (int) The maximum number of box preds to consider.
+ Return:
+ The indices of the kept boxes with respect to num_priors.
+ """
+
+ keep = torch.Tensor(scores.size(0)).fill_(0).long()
+ if boxes.numel() == 0:
+ return keep
+ x1 = boxes[:, 0]
+ y1 = boxes[:, 1]
+ x2 = boxes[:, 2]
+ y2 = boxes[:, 3]
+ area = torch.mul(x2 - x1, y2 - y1)
+ v, idx = scores.sort(0) # sort in ascending order
+ # I = I[v >= 0.01]
+ idx = idx[-top_k:] # indices of the top-k largest vals
+ xx1 = boxes.new()
+ yy1 = boxes.new()
+ xx2 = boxes.new()
+ yy2 = boxes.new()
+ w = boxes.new()
+ h = boxes.new()
+
+ # keep = torch.Tensor()
+ count = 0
+ while idx.numel() > 0:
+ i = idx[-1] # index of current largest val
+ # keep.append(i)
+ keep[count] = i
+ count += 1
+ if idx.size(0) == 1:
+ break
+ idx = idx[:-1] # remove kept element from view
+ # load bboxes of next highest vals
+ torch.index_select(x1, 0, idx, out=xx1)
+ torch.index_select(y1, 0, idx, out=yy1)
+ torch.index_select(x2, 0, idx, out=xx2)
+ torch.index_select(y2, 0, idx, out=yy2)
+ # store element-wise max with next highest score
+ xx1 = torch.clamp(xx1, min=x1[i])
+ yy1 = torch.clamp(yy1, min=y1[i])
+ xx2 = torch.clamp(xx2, max=x2[i])
+ yy2 = torch.clamp(yy2, max=y2[i])
+ w.resize_as_(xx2)
+ h.resize_as_(yy2)
+ w = xx2 - xx1
+ h = yy2 - yy1
+ # check sizes of xx1 and xx2.. after each iteration
+ w = torch.clamp(w, min=0.0)
+ h = torch.clamp(h, min=0.0)
+ inter = w * h
+ # IoU = i / (area(a) + area(b) - i)
+        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
+ union = (rem_areas - inter) + area[i]
+ IoU = inter / union # store result in iou
+ # keep only elements with an IoU <= overlap
+ idx = idx[IoU.le(overlap)]
+ return keep, count
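+
+
+if __name__ == "__main__":
+    # Illustrative example, not part of the original file: generate priors for
+    # a small input and check that decoding zero offsets reproduces the priors
+    # in point form. The cfg values below mirror the usual "mobile0.25"
+    # settings but are assumptions made only for this demo.
+    cfg = {
+        'min_sizes': [[16, 32], [64, 128], [256, 512]],
+        'steps': [8, 16, 32],
+        'clip': False,
+        'variance': [0.1, 0.2],
+    }
+    priors = PriorBox(cfg, image_size=(64, 64)).forward()
+    boxes = decode(torch.zeros_like(priors), priors, cfg['variance'])
+    assert torch.allclose(boxes, point_form(priors))
+    print(priors.shape)  # torch.Size([168, 4]) for a 64x64 input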
diff --git a/repositories/codeformer/facelib/detection/yolov5face/__init__.py b/repositories/codeformer/facelib/detection/yolov5face/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/codeformer/facelib/detection/yolov5face/face_detector.py b/repositories/codeformer/facelib/detection/yolov5face/face_detector.py
new file mode 100644
index 000000000..79fdba0c9
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/face_detector.py
@@ -0,0 +1,142 @@
+import copy
+import os
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import Conv
+from facelib.detection.yolov5face.models.yolo import Model
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+ check_img_size,
+ non_max_suppression_face,
+ scale_coords,
+ scale_coords_landmarks,
+)
+
+IS_HIGH_VERSION = tuple(map(int, torch.__version__.split('+')[0].split('.')[:2])) >= (1, 9)
+
+
+def isListempty(inList):
+ if isinstance(inList, list): # Is a list
+ return all(map(isListempty, inList))
+ return False # Not a list
+
+class YoloDetector:
+ def __init__(
+ self,
+ config_name,
+ min_face=10,
+ target_size=None,
+ device='cuda',
+ ):
+ """
+        config_name: name of the .yaml config with the network configuration, from the models/ folder.
+        min_face: minimal face size in pixels.
+        target_size: target size of the smaller image axis (choose a lower value for faster inference),
+            e.g. 480, 720, 1080. None keeps the original resolution.
+ """
+ self._class_path = Path(__file__).parent.absolute()
+ self.target_size = target_size
+ self.min_face = min_face
+ self.detector = Model(cfg=config_name)
+ self.device = device
+
+
+ def _preprocess(self, imgs):
+ """
+ Preprocessing image before passing through the network. Resize and conversion to torch tensor.
+ """
+ pp_imgs = []
+ for img in imgs:
+ h0, w0 = img.shape[:2] # orig hw
+ if self.target_size:
+ r = self.target_size / min(h0, w0) # resize image to img_size
+ if r < 1:
+ img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=cv2.INTER_LINEAR)
+
+ imgsz = check_img_size(max(img.shape[:2]), s=self.detector.stride.max()) # check img_size
+ img = letterbox(img, new_shape=imgsz)[0]
+ pp_imgs.append(img)
+ pp_imgs = np.array(pp_imgs)
+ pp_imgs = pp_imgs.transpose(0, 3, 1, 2)
+ pp_imgs = torch.from_numpy(pp_imgs).to(self.device)
+ pp_imgs = pp_imgs.float() # uint8 to fp16/32
+ return pp_imgs / 255.0 # 0 - 255 to 0.0 - 1.0
+
+ def _postprocess(self, imgs, origimgs, pred, conf_thres, iou_thres):
+ """
+ Postprocessing of raw pytorch model output.
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ bboxes = [[] for _ in range(len(origimgs))]
+ landmarks = [[] for _ in range(len(origimgs))]
+
+ pred = non_max_suppression_face(pred, conf_thres, iou_thres)
+
+ for image_id, origimg in enumerate(origimgs):
+ img_shape = origimg.shape
+ image_height, image_width = img_shape[:2]
+ gn = torch.tensor(img_shape)[[1, 0, 1, 0]] # normalization gain whwh
+ gn_lks = torch.tensor(img_shape)[[1, 0, 1, 0, 1, 0, 1, 0, 1, 0]] # normalization gain landmarks
+ det = pred[image_id].cpu()
+ scale_coords(imgs[image_id].shape[1:], det[:, :4], img_shape).round()
+ scale_coords_landmarks(imgs[image_id].shape[1:], det[:, 5:15], img_shape).round()
+
+ for j in range(det.size()[0]):
+ box = (det[j, :4].view(1, 4) / gn).view(-1).tolist()
+ box = list(
+ map(int, [box[0] * image_width, box[1] * image_height, box[2] * image_width, box[3] * image_height])
+ )
+ if box[3] - box[1] < self.min_face:
+ continue
+ lm = (det[j, 5:15].view(1, 10) / gn_lks).view(-1).tolist()
+ lm = list(map(int, [i * image_width if j % 2 == 0 else i * image_height for j, i in enumerate(lm)]))
+ lm = [lm[i : i + 2] for i in range(0, len(lm), 2)]
+ bboxes[image_id].append(box)
+ landmarks[image_id].append(lm)
+ return bboxes, landmarks
+
+ def detect_faces(self, imgs, conf_thres=0.7, iou_thres=0.5):
+ """
+ Get bbox coordinates and keypoints of faces on original image.
+ Params:
+            imgs: image or list of images to detect faces on, in BGR order (converted to RGB internally before inference)
+ conf_thres: confidence threshold for each prediction
+ iou_thres: threshold for NMS (filter of intersecting bboxes)
+ Returns:
+ bboxes: list of arrays with 4 coordinates of bounding boxes with format x1,y1,x2,y2.
+ points: list of arrays with coordinates of 5 facial keypoints (eyes, nose, lips corners).
+ """
+ # Pass input images through face detector
+ images = imgs if isinstance(imgs, list) else [imgs]
+ images = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in images]
+ origimgs = copy.deepcopy(images)
+
+ images = self._preprocess(images)
+
+ if IS_HIGH_VERSION:
+ with torch.inference_mode(): # for pytorch>=1.9
+ pred = self.detector(images)[0]
+ else:
+ with torch.no_grad(): # for pytorch<1.9
+ pred = self.detector(images)[0]
+
+ bboxes, points = self._postprocess(images, origimgs, pred, conf_thres, iou_thres)
+
+ # return bboxes, points
+ if not isListempty(points):
+ bboxes = np.array(bboxes).reshape(-1,4)
+ points = np.array(points).reshape(-1,10)
+ padding = bboxes[:,0].reshape(-1,1)
+ return np.concatenate((bboxes, padding, points), axis=1)
+ else:
+ return None
+
+ def __call__(self, *args):
+        return self.detect_faces(*args)  # this class defines detect_faces rather than predict
diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/__init__.py b/repositories/codeformer/facelib/detection/yolov5face/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/common.py b/repositories/codeformer/facelib/detection/yolov5face/models/common.py
new file mode 100644
index 000000000..497a00444
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/models/common.py
@@ -0,0 +1,299 @@
+# This file contains modules common to various models
+
+import math
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.utils.datasets import letterbox
+from facelib.detection.yolov5face.utils.general import (
+ make_divisible,
+ non_max_suppression,
+ scale_coords,
+ xyxy2xywh,
+)
+
+
+def autopad(k, p=None): # kernel, padding
+ # Pad to 'same'
+ if p is None:
+ p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
+ return p
+
+
+def channel_shuffle(x, groups):
+ batchsize, num_channels, height, width = x.data.size()
+ channels_per_group = torch.div(num_channels, groups, rounding_mode="trunc")
+
+ # reshape
+ x = x.view(batchsize, groups, channels_per_group, height, width)
+ x = torch.transpose(x, 1, 2).contiguous()
+
+ # flatten
+ return x.view(batchsize, -1, height, width)
+
+
+def DWConv(c1, c2, k=1, s=1, act=True):
+ # Depthwise convolution
+ return Conv(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
+
+
+class Conv(nn.Module):
+ # Standard convolution
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
+
+ def forward(self, x):
+ return self.act(self.bn(self.conv(x)))
+
+ def fuseforward(self, x):
+ return self.act(self.conv(x))
+
+
+class StemBlock(nn.Module):
+ def __init__(self, c1, c2, k=3, s=2, p=None, g=1, act=True):
+ super().__init__()
+ self.stem_1 = Conv(c1, c2, k, s, p, g, act)
+ self.stem_2a = Conv(c2, c2 // 2, 1, 1, 0)
+ self.stem_2b = Conv(c2 // 2, c2, 3, 2, 1)
+ self.stem_2p = nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)
+ self.stem_3 = Conv(c2 * 2, c2, 1, 1, 0)
+
+ def forward(self, x):
+ stem_1_out = self.stem_1(x)
+ stem_2a_out = self.stem_2a(stem_1_out)
+ stem_2b_out = self.stem_2b(stem_2a_out)
+ stem_2p_out = self.stem_2p(stem_1_out)
+ return self.stem_3(torch.cat((stem_2b_out, stem_2p_out), 1))
+
+
+class Bottleneck(nn.Module):
+ # Standard bottleneck
+ def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_, c2, 3, 1, g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class BottleneckCSP(nn.Module):
+ # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
+ self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
+ self.cv4 = Conv(2 * c_, c2, 1, 1)
+ self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
+ self.act = nn.LeakyReLU(0.1, inplace=True)
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ y1 = self.cv3(self.m(self.cv1(x)))
+ y2 = self.cv2(x)
+ return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
+
+
+class C3(nn.Module):
+ # CSP Bottleneck with 3 convolutions
+ def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c1, c_, 1, 1)
+ self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
+ self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
+
+ def forward(self, x):
+ return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+
+class ShuffleV2Block(nn.Module):
+ def __init__(self, inp, oup, stride):
+ super().__init__()
+
+ if not 1 <= stride <= 3:
+ raise ValueError("illegal stride value")
+ self.stride = stride
+
+ branch_features = oup // 2
+
+ if self.stride > 1:
+ self.branch1 = nn.Sequential(
+ self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(inp),
+ nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ )
+ else:
+ self.branch1 = nn.Sequential()
+
+ self.branch2 = nn.Sequential(
+ nn.Conv2d(
+ inp if (self.stride > 1) else branch_features,
+ branch_features,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False,
+ ),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
+ nn.BatchNorm2d(branch_features),
+ nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
+ nn.BatchNorm2d(branch_features),
+ nn.SiLU(),
+ )
+
+ @staticmethod
+ def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
+ return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
+
+ def forward(self, x):
+ if self.stride == 1:
+ x1, x2 = x.chunk(2, dim=1)
+ out = torch.cat((x1, self.branch2(x2)), dim=1)
+ else:
+ out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
+ out = channel_shuffle(out, 2)
+ return out
+
+
+class SPP(nn.Module):
+ # Spatial pyramid pooling layer used in YOLOv3-SPP
+ def __init__(self, c1, c2, k=(5, 9, 13)):
+ super().__init__()
+ c_ = c1 // 2 # hidden channels
+ self.cv1 = Conv(c1, c_, 1, 1)
+ self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
+ self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
+
+ def forward(self, x):
+ x = self.cv1(x)
+ return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
+
+
+class Focus(nn.Module):
+ # Focus wh information into c-space
+ def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
+ super().__init__()
+ self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
+
+ def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
+ return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
+
+
+class Concat(nn.Module):
+ # Concatenate a list of tensors along dimension
+ def __init__(self, dimension=1):
+ super().__init__()
+ self.d = dimension
+
+ def forward(self, x):
+ return torch.cat(x, self.d)
+
+
+class NMS(nn.Module):
+ # Non-Maximum Suppression (NMS) module
+ conf = 0.25 # confidence threshold
+ iou = 0.45 # IoU threshold
+ classes = None # (optional list) filter by class
+
+ def forward(self, x):
+ return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes)
+
+
+class AutoShape(nn.Module):
+ # input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
+ img_size = 640 # inference size (pixels)
+ conf = 0.25 # NMS confidence threshold
+ iou = 0.45 # NMS IoU threshold
+ classes = None # (optional list) filter by class
+
+ def __init__(self, model):
+ super().__init__()
+ self.model = model.eval()
+
+ def autoshape(self):
+ print("autoShape already enabled, skipping... ") # model already converted to model.autoshape()
+ return self
+
+ def forward(self, imgs, size=640, augment=False, profile=False):
+ # Inference from various sources. For height=720, width=1280, RGB images example inputs are:
+ # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(720,1280,3)
+ # PIL: = Image.open('image.jpg') # HWC x(720,1280,3)
+ # numpy: = np.zeros((720,1280,3)) # HWC
+ # torch: = torch.zeros(16,3,720,1280) # BCHW
+ # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
+
+ p = next(self.model.parameters()) # for device and type
+ if isinstance(imgs, torch.Tensor): # torch
+            return self.model(imgs.to(p.device).type_as(p))  # inference (Model.forward takes a single input)
+
+ # Pre-process
+ n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
+ shape0, shape1 = [], [] # image and inference shapes
+ for i, im in enumerate(imgs):
+ im = np.array(im) # to numpy
+ if im.shape[0] < 5: # image in CHW
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
+ im = im[:, :, :3] if im.ndim == 3 else np.tile(im[:, :, None], 3) # enforce 3ch input
+ s = im.shape[:2] # HWC
+ shape0.append(s) # image shape
+ g = size / max(s) # gain
+ shape1.append([y * g for y in s])
+ imgs[i] = im # update
+ shape1 = [make_divisible(x, int(self.stride.max())) for x in np.stack(shape1, 0).max(0)] # inference shape
+ x = [letterbox(im, new_shape=shape1, auto=False)[0] for im in imgs] # pad
+ x = np.stack(x, 0) if n > 1 else x[0][None] # stack
+ x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255.0 # uint8 to fp16/32
+
+ # Inference
+ with torch.no_grad():
+            y = self.model(x)[0]  # forward (Model.forward takes a single input)
+ y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
+
+ # Post-process
+ for i in range(n):
+ scale_coords(shape1, y[i][:, :4], shape0[i])
+
+ return Detections(imgs, y, self.names)
+
+
+class Detections:
+ # detections class for YOLOv5 inference results
+ def __init__(self, imgs, pred, names=None):
+ super().__init__()
+ d = pred[0].device # device
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1.0, 1.0], device=d) for im in imgs] # normalizations
+ self.imgs = imgs # list of images as numpy arrays
+ self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
+ self.names = names # class names
+ self.xyxy = pred # xyxy pixels
+ self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
+ self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
+ self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
+ self.n = len(self.pred)
+
+ def __len__(self):
+ return self.n
+
+ def tolist(self):
+ # return a list of Detections objects, i.e. 'for result in results.tolist():'
+ x = [Detections([self.imgs[i]], [self.pred[i]], self.names) for i in range(self.n)]
+ for d in x:
+ for k in ["imgs", "pred", "xyxy", "xyxyn", "xywh", "xywhn"]:
+ setattr(d, k, getattr(d, k)[0]) # pop out of list
+ return x
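+
+
+if __name__ == "__main__":
+    # Illustrative shape check, not part of the original file: Focus moves
+    # spatial detail into channels (halving H and W), and a stride-2
+    # ShuffleV2Block halves the resolution again while changing channels.
+    # All sizes below are arbitrary assumptions for the demonstration.
+    x = torch.randn(1, 3, 64, 64)
+    focus = Focus(3, 16)
+    y = focus(x)
+    print(y.shape)  # torch.Size([1, 16, 32, 32])
+    block = ShuffleV2Block(16, 32, stride=2)
+    print(block(y).shape)  # torch.Size([1, 32, 16, 16])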
diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/experimental.py b/repositories/codeformer/facelib/detection/yolov5face/models/experimental.py
new file mode 100644
index 000000000..37ba4c442
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/models/experimental.py
@@ -0,0 +1,45 @@
+# This file contains experimental modules
+
+import numpy as np
+import torch
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import Conv
+
+
+class CrossConv(nn.Module):
+ # Cross Convolution Downsample
+ def __init__(self, c1, c2, k=3, s=1, g=1, e=1.0, shortcut=False):
+ # ch_in, ch_out, kernel, stride, groups, expansion, shortcut
+ super().__init__()
+ c_ = int(c2 * e) # hidden channels
+ self.cv1 = Conv(c1, c_, (1, k), (1, s))
+ self.cv2 = Conv(c_, c2, (k, 1), (s, 1), g=g)
+ self.add = shortcut and c1 == c2
+
+ def forward(self, x):
+ return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
+
+
+class MixConv2d(nn.Module):
+ # Mixed Depthwise Conv https://arxiv.org/abs/1907.09595
+ def __init__(self, c1, c2, k=(1, 3), s=1, equal_ch=True):
+ super().__init__()
+ groups = len(k)
+ if equal_ch: # equal c_ per group
+ i = torch.linspace(0, groups - 1e-6, c2).floor() # c2 indices
+ c_ = [(i == g).sum() for g in range(groups)] # intermediate channels
+ else: # equal weight.numel() per group
+ b = [c2] + [0] * groups
+ a = np.eye(groups + 1, groups, k=-1)
+ a -= np.roll(a, 1, axis=1)
+ a *= np.array(k) ** 2
+ a[0] = 1
+ c_ = np.linalg.lstsq(a, b, rcond=None)[0].round() # solve for equal weight indices, ax = b
+
+ self.m = nn.ModuleList([nn.Conv2d(c1, int(c_[g]), k[g], s, k[g] // 2, bias=False) for g in range(groups)])
+ self.bn = nn.BatchNorm2d(c2)
+ self.act = nn.LeakyReLU(0.1, inplace=True)
+
+ def forward(self, x):
+ return x + self.act(self.bn(torch.cat([m(x) for m in self.m], 1)))
diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/yolo.py b/repositories/codeformer/facelib/detection/yolov5face/models/yolo.py
new file mode 100644
index 000000000..70845d972
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/models/yolo.py
@@ -0,0 +1,235 @@
+import math
+from copy import deepcopy
+from pathlib import Path
+
+import torch
+import yaml # for torch hub
+from torch import nn
+
+from facelib.detection.yolov5face.models.common import (
+ C3,
+ NMS,
+ SPP,
+ AutoShape,
+ Bottleneck,
+ BottleneckCSP,
+ Concat,
+ Conv,
+ DWConv,
+ Focus,
+ ShuffleV2Block,
+ StemBlock,
+)
+from facelib.detection.yolov5face.models.experimental import CrossConv, MixConv2d
+from facelib.detection.yolov5face.utils.autoanchor import check_anchor_order
+from facelib.detection.yolov5face.utils.general import make_divisible
+from facelib.detection.yolov5face.utils.torch_utils import copy_attr, fuse_conv_and_bn
+
+
+class Detect(nn.Module):
+ stride = None # strides computed during build
+ export = False # onnx export
+
+ def __init__(self, nc=80, anchors=(), ch=()): # detection layer
+ super().__init__()
+ self.nc = nc # number of classes
+ self.no = nc + 5 + 10 # number of outputs per anchor
+
+ self.nl = len(anchors) # number of detection layers
+ self.na = len(anchors[0]) // 2 # number of anchors
+ self.grid = [torch.zeros(1)] * self.nl # init grid
+ a = torch.tensor(anchors).float().view(self.nl, -1, 2)
+ self.register_buffer("anchors", a) # shape(nl,na,2)
+ self.register_buffer("anchor_grid", a.clone().view(self.nl, 1, -1, 1, 1, 2)) # shape(nl,1,na,1,1,2)
+ self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
+
+ def forward(self, x):
+ z = [] # inference output
+ if self.export:
+ for i in range(self.nl):
+ x[i] = self.m[i](x[i])
+ return x
+ for i in range(self.nl):
+ x[i] = self.m[i](x[i]) # conv
+ bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
+ x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
+
+ if not self.training: # inference
+ if self.grid[i].shape[2:4] != x[i].shape[2:4]:
+ self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
+
+ y = torch.full_like(x[i], 0)
+ y[..., [0, 1, 2, 3, 4, 15]] = x[i][..., [0, 1, 2, 3, 4, 15]].sigmoid()
+ y[..., 5:15] = x[i][..., 5:15]
+
+ y[..., 0:2] = (y[..., 0:2] * 2.0 - 0.5 + self.grid[i].to(x[i].device)) * self.stride[i] # xy
+ y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
+
+ y[..., 5:7] = (
+ y[..., 5:7] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x1 y1
+ y[..., 7:9] = (
+ y[..., 7:9] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x2 y2
+ y[..., 9:11] = (
+ y[..., 9:11] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x3 y3
+ y[..., 11:13] = (
+ y[..., 11:13] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x4 y4
+ y[..., 13:15] = (
+ y[..., 13:15] * self.anchor_grid[i] + self.grid[i].to(x[i].device) * self.stride[i]
+ ) # landmark x5 y5
+
+ z.append(y.view(bs, -1, self.no))
+
+ return x if self.training else (torch.cat(z, 1), x)
+
+ @staticmethod
+ def _make_grid(nx=20, ny=20):
+ # yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)], indexing="ij") # for pytorch>=1.10
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
+ return torch.stack((xv, yv), 2).view((1, 1, ny, nx, 2)).float()
+
+
+class Model(nn.Module):
+ def __init__(self, cfg="yolov5s.yaml", ch=3, nc=None): # model, input channels, number of classes
+ super().__init__()
+ self.yaml_file = Path(cfg).name
+ with Path(cfg).open(encoding="utf8") as f:
+ self.yaml = yaml.safe_load(f) # model dict
+
+ # Define model
+ ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
+ if nc and nc != self.yaml["nc"]:
+ self.yaml["nc"] = nc # override yaml value
+
+ self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch]) # model, savelist
+ self.names = [str(i) for i in range(self.yaml["nc"])] # default names
+
+ # Build strides, anchors
+ m = self.model[-1] # Detect()
+ if isinstance(m, Detect):
+ s = 128 # 2x min stride
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
+ m.anchors /= m.stride.view(-1, 1, 1)
+ check_anchor_order(m)
+ self.stride = m.stride
+ self._initialize_biases() # only run once
+
+ def forward(self, x):
+ return self.forward_once(x) # single-scale inference, train
+
+ def forward_once(self, x):
+ y = [] # outputs
+ for m in self.model:
+ if m.f != -1: # if not from previous layer
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
+
+ x = m(x) # run
+ y.append(x if m.i in self.save else None) # save output
+
+ return x
+
+ def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
+ # https://arxiv.org/abs/1708.02002 section 3.3
+ m = self.model[-1] # Detect() module
+ for mi, s in zip(m.m, m.stride): # from
+ b = mi.bias.view(m.na, -1) # conv.bias(255) to (3,85)
+ b.data[:, 4] += math.log(8 / (640 / s) ** 2) # obj (8 objects per 640 image)
+ b.data[:, 5:] += math.log(0.6 / (m.nc - 0.99)) if cf is None else torch.log(cf / cf.sum()) # cls
+ mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+ def _print_biases(self):
+ m = self.model[-1] # Detect() module
+ for mi in m.m: # from
+ b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
+ print(("%6g Conv2d.bias:" + "%10.3g" * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
+
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
+ print("Fusing layers... ")
+ for m in self.model.modules():
+ if isinstance(m, Conv) and hasattr(m, "bn"):
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
+ delattr(m, "bn") # remove batchnorm
+ m.forward = m.fuseforward # update forward
+ elif type(m) is nn.Upsample:
+ m.recompute_scale_factor = None # torch 1.11.0 compatibility
+ return self
+
+ def nms(self, mode=True): # add or remove NMS module
+ present = isinstance(self.model[-1], NMS) # last layer is NMS
+ if mode and not present:
+ print("Adding NMS... ")
+ m = NMS() # module
+ m.f = -1 # from
+ m.i = self.model[-1].i + 1 # index
+ self.model.add_module(name=str(m.i), module=m) # add
+ self.eval()
+ elif not mode and present:
+ print("Removing NMS... ")
+ self.model = self.model[:-1] # remove
+ return self
+
+ def autoshape(self): # add autoShape module
+ print("Adding autoShape... ")
+ m = AutoShape(self) # wrap model
+ copy_attr(m, self, include=("yaml", "nc", "hyp", "names", "stride"), exclude=()) # copy attributes
+ return m
+
+
+def parse_model(d, ch): # model_dict, input_channels(3)
+ anchors, nc, gd, gw = d["anchors"], d["nc"], d["depth_multiple"], d["width_multiple"]
+ na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
+ no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
+
+ layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
+ for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
+ m = eval(m) if isinstance(m, str) else m # eval strings
+ for j, a in enumerate(args):
+ try:
+ args[j] = eval(a) if isinstance(a, str) else a # eval strings
+            except Exception:
+                pass  # keep non-evaluable strings (e.g. 'nearest') as-is
+
+ n = max(round(n * gd), 1) if n > 1 else n # depth gain
+ if m in [
+ Conv,
+ Bottleneck,
+ SPP,
+ DWConv,
+ MixConv2d,
+ Focus,
+ CrossConv,
+ BottleneckCSP,
+ C3,
+ ShuffleV2Block,
+ StemBlock,
+ ]:
+ c1, c2 = ch[f], args[0]
+
+ c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
+
+ args = [c1, c2, *args[1:]]
+ if m in [BottleneckCSP, C3]:
+ args.insert(2, n)
+ n = 1
+ elif m is nn.BatchNorm2d:
+ args = [ch[f]]
+ elif m is Concat:
+ c2 = sum(ch[-1 if x == -1 else x + 1] for x in f)
+ elif m is Detect:
+ args.append([ch[x + 1] for x in f])
+ if isinstance(args[1], int): # number of anchors
+ args[1] = [list(range(args[1] * 2))] * len(f)
+ else:
+ c2 = ch[f]
+
+ m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
+ t = str(m)[8:-2].replace("__main__.", "") # module type
+ np = sum(x.numel() for x in m_.parameters()) # number params
+ m_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number params
+ save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
+ layers.append(m_)
+ ch.append(c2)
+ return nn.Sequential(*layers), sorted(save)
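+
+
+if __name__ == "__main__":
+    # Illustrative example, not part of the original file: Detect._make_grid
+    # builds the per-cell (x, y) offsets that are added to the sigmoid box
+    # outputs at inference time. The 4x3 feature-map size is an arbitrary
+    # choice for the demonstration.
+    grid = Detect._make_grid(nx=4, ny=3)
+    print(grid.shape)  # torch.Size([1, 1, 3, 4, 2])
+    print(grid[0, 0, 0, :, 0])  # tensor([0., 1., 2., 3.]) -- x offsets of the first row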
diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/yolov5l.yaml b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5l.yaml
new file mode 100644
index 000000000..0532b0e22
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5l.yaml
@@ -0,0 +1,47 @@
+# parameters
+nc: 1 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# anchors
+anchors:
+ - [4,5, 8,10, 13,16] # P3/8
+ - [23,29, 43,55, 73,105] # P4/16
+ - [146,217, 231,300, 335,433] # P5/32
+
+# YOLOv5 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, StemBlock, [64, 3, 2]], # 0-P1/2
+ [-1, 3, C3, [128]],
+ [-1, 1, Conv, [256, 3, 2]], # 2-P3/8
+ [-1, 9, C3, [256]],
+ [-1, 1, Conv, [512, 3, 2]], # 4-P4/16
+ [-1, 9, C3, [512]],
+ [-1, 1, Conv, [1024, 3, 2]], # 6-P5/32
+ [-1, 1, SPP, [1024, [3,5,7]]],
+ [-1, 3, C3, [1024, False]], # 8
+ ]
+
+# YOLOv5 head
+head:
+ [[-1, 1, Conv, [512, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 5], 1, Concat, [1]], # cat backbone P4
+ [-1, 3, C3, [512, False]], # 12
+
+ [-1, 1, Conv, [256, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 3], 1, Concat, [1]], # cat backbone P3
+ [-1, 3, C3, [256, False]], # 16 (P3/8-small)
+
+ [-1, 1, Conv, [256, 3, 2]],
+ [[-1, 13], 1, Concat, [1]], # cat head P4
+ [-1, 3, C3, [512, False]], # 19 (P4/16-medium)
+
+ [-1, 1, Conv, [512, 3, 2]],
+ [[-1, 9], 1, Concat, [1]], # cat head P5
+ [-1, 3, C3, [1024, False]], # 22 (P5/32-large)
+
+ [[16, 19, 22], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
+ ]
\ No newline at end of file
diff --git a/repositories/codeformer/facelib/detection/yolov5face/models/yolov5n.yaml b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5n.yaml
new file mode 100644
index 000000000..caba6bed6
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/models/yolov5n.yaml
@@ -0,0 +1,45 @@
+# parameters
+nc: 1 # number of classes
+depth_multiple: 1.0 # model depth multiple
+width_multiple: 1.0 # layer channel multiple
+
+# anchors
+anchors:
+ - [4,5, 8,10, 13,16] # P3/8
+ - [23,29, 43,55, 73,105] # P4/16
+ - [146,217, 231,300, 335,433] # P5/32
+
+# YOLOv5 backbone
+backbone:
+ # [from, number, module, args]
+ [[-1, 1, StemBlock, [32, 3, 2]], # 0-P2/4
+ [-1, 1, ShuffleV2Block, [128, 2]], # 1-P3/8
+ [-1, 3, ShuffleV2Block, [128, 1]], # 2
+ [-1, 1, ShuffleV2Block, [256, 2]], # 3-P4/16
+ [-1, 7, ShuffleV2Block, [256, 1]], # 4
+ [-1, 1, ShuffleV2Block, [512, 2]], # 5-P5/32
+ [-1, 3, ShuffleV2Block, [512, 1]], # 6
+ ]
+
+# YOLOv5 head
+head:
+ [[-1, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 4], 1, Concat, [1]], # cat backbone P4
+ [-1, 1, C3, [128, False]], # 10
+
+ [-1, 1, Conv, [128, 1, 1]],
+ [-1, 1, nn.Upsample, [None, 2, 'nearest']],
+ [[-1, 2], 1, Concat, [1]], # cat backbone P3
+ [-1, 1, C3, [128, False]], # 14 (P3/8-small)
+
+ [-1, 1, Conv, [128, 3, 2]],
+ [[-1, 11], 1, Concat, [1]], # cat head P4
+ [-1, 1, C3, [128, False]], # 17 (P4/16-medium)
+
+ [-1, 1, Conv, [128, 3, 2]],
+ [[-1, 7], 1, Concat, [1]], # cat head P5
+ [-1, 1, C3, [128, False]], # 20 (P5/32-large)
+
+ [[14, 17, 20], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
+ ]
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/__init__.py b/repositories/codeformer/facelib/detection/yolov5face/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/autoanchor.py b/repositories/codeformer/facelib/detection/yolov5face/utils/autoanchor.py
new file mode 100644
index 000000000..a4eba3e94
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/utils/autoanchor.py
@@ -0,0 +1,12 @@
+# Auto-anchor utils
+
+
+def check_anchor_order(m):
+ # Check anchor order against stride order for YOLOv5 Detect() module m, and correct if necessary
+ a = m.anchor_grid.prod(-1).view(-1) # anchor area
+ da = a[-1] - a[0] # delta a
+ ds = m.stride[-1] - m.stride[0] # delta s
+    if da.sign() != ds.sign():  # anchor order does not match stride order
+ print("Reversing anchor order")
+ m.anchors[:] = m.anchors.flip(0)
+ m.anchor_grid[:] = m.anchor_grid.flip(0)
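+
+
+if __name__ == "__main__":
+    # Illustrative sanity check, not part of the original file: a dummy module
+    # whose anchors shrink while its strides grow should have its anchors
+    # reversed. SimpleNamespace stands in for a real Detect() module here.
+    import torch
+    from types import SimpleNamespace
+
+    anchors = torch.tensor([[[8.0, 8.0]], [[4.0, 4.0]], [[2.0, 2.0]]])
+    m = SimpleNamespace(
+        anchors=anchors.clone(),
+        anchor_grid=anchors.clone().view(3, 1, 1, 1, 1, 2),
+        stride=torch.tensor([8.0, 16.0, 32.0]),
+    )
+    check_anchor_order(m)
+    print(m.anchors[:, 0, 0])  # tensor([2., 4., 8.]) -- smallest anchors now come first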
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py b/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py
new file mode 100755
index 000000000..e672b136f
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/utils/datasets.py
@@ -0,0 +1,35 @@
+import cv2
+import numpy as np
+
+
+def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scale_fill=False, scaleup=True):
+ # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
+ shape = img.shape[:2] # current shape [height, width]
+ if isinstance(new_shape, int):
+ new_shape = (new_shape, new_shape)
+
+ # Scale ratio (new / old)
+ r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+ if not scaleup: # only scale down, do not scale up (for better test mAP)
+ r = min(r, 1.0)
+
+ # Compute padding
+ ratio = r, r # width, height ratios
+ new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+ dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
+ if auto: # minimum rectangle
+ dw, dh = np.mod(dw, 64), np.mod(dh, 64) # wh padding
+ elif scale_fill: # stretch
+ dw, dh = 0.0, 0.0
+ new_unpad = (new_shape[1], new_shape[0])
+ ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
+
+ dw /= 2 # divide padding into 2 sides
+ dh /= 2
+
+ if shape[::-1] != new_unpad: # resize
+ img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
+ top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+ left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+ img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
+ return img, ratio, (dw, dh)
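+
+
+if __name__ == "__main__":
+    # Illustrative example, not part of the original file: letterbox scales a
+    # 100x200 dummy image so the longer side matches new_shape and pads the
+    # result to a 64-multiple rectangle (no padding is needed in this case).
+    img = np.zeros((100, 200, 3), dtype=np.uint8)
+    out, ratio, (dw, dh) = letterbox(img, new_shape=640)
+    print(out.shape, ratio, (dw, dh))  # (320, 640, 3) (3.2, 3.2) (0.0, 0.0)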
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py b/repositories/codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py
new file mode 100644
index 000000000..4b8b63134
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/utils/extract_ckpt.py
@@ -0,0 +1,5 @@
+import torch
+import sys
+sys.path.insert(0,'./facelib/detection/yolov5face')
+model = torch.load('facelib/detection/yolov5face/yolov5n-face.pt', map_location='cpu')['model']
+torch.save(model.state_dict(),'weights/facelib/yolov5n-face.pth')
\ No newline at end of file
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/general.py b/repositories/codeformer/facelib/detection/yolov5face/utils/general.py
new file mode 100755
index 000000000..1c8e14f56
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/utils/general.py
@@ -0,0 +1,271 @@
+import math
+import time
+
+import numpy as np
+import torch
+import torchvision
+
+
+def check_img_size(img_size, s=32):
+ # Verify img_size is a multiple of stride s
+ new_size = make_divisible(img_size, int(s)) # ceil gs-multiple
+ # if new_size != img_size:
+ # print(f"WARNING: --img-size {img_size:g} must be multiple of max stride {s:g}, updating to {new_size:g}")
+ return new_size
+
+
+def make_divisible(x, divisor):
+ # Returns x evenly divisible by divisor
+ return math.ceil(x / divisor) * divisor
+
+
+def xyxy2xywh(x):
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
+ y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
+ y[:, 2] = x[:, 2] - x[:, 0] # width
+ y[:, 3] = x[:, 3] - x[:, 1] # height
+ return y
+
+
+def xywh2xyxy(x):
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
+ return y
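+
+
+# Illustrative worked example (not from the original file): for the xyxy box
+# [0, 0, 10, 20], xyxy2xywh gives [5, 10, 10, 20] (center-x, center-y, w, h),
+# and xywh2xyxy maps it back to [0, 0, 10, 20].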
+
+
+def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
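+    # Undo the letterboxing: subtract the per-side padding, then divide by the resize gain
+    # so the boxes are expressed in the original image's coordinate system.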
+ coords[:, [0, 2]] -= pad[0] # x padding
+ coords[:, [1, 3]] -= pad[1] # y padding
+ coords[:, :4] /= gain
+ clip_coords(coords, img0_shape)
+ return coords
+
+
+def clip_coords(boxes, img_shape):
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
+
+
+def box_iou(box1, box2):
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
+ """
+ Return intersection-over-union (Jaccard index) of boxes.
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
+ Arguments:
+ box1 (Tensor[N, 4])
+ box2 (Tensor[M, 4])
+ Returns:
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
+ IoU values for every element in boxes1 and boxes2
+ """
+
+ def box_area(box):
+ return (box[2] - box[0]) * (box[3] - box[1])
+
+ area1 = box_area(box1.T)
+ area2 = box_area(box2.T)
+
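+    # box1[:, None, 2:] broadcasts to (N, 1, 2) against box2[:, 2:] of shape (M, 2), so the
+    # min/max yield (N, M, 2) intersection extents; .clamp(0).prod(2) gives the (N, M)
+    # matrix of intersection areas used for the pairwise IoU below.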
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
+ return inter / (area1[:, None] + area2 - inter)
+
+
+def non_max_suppression_face(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+ Returns:
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+ """
+
+ nc = prediction.shape[2] - 15 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ # Apply constraints
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label = labels[xi]
+ v = torch.zeros((len(label), nc + 15), device=x.device)
+ v[:, :4] = label[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label)), label[:, 0].long() + 15] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, landmarks, cls)
+ if multi_label:
+ i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 15, None], x[:, 5:15], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 15:].max(1, keepdim=True)
+ x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # If none remain process next image
+ n = x.shape[0] # number of boxes
+ if not n:
+ continue
+
+ # Batched NMS
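+        # Offsetting each box by class_index * max_wh pushes different classes into disjoint
+        # coordinate ranges, so one torchvision.ops.nms call cannot suppress boxes across
+        # classes (with agnostic=True the offset is zero and classes compete normally).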
+ c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ break # time limit exceeded
+
+ return output
+
+
+def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, labels=()):
+ """Performs Non-Maximum Suppression (NMS) on inference results
+
+ Returns:
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
+ """
+
+ nc = prediction.shape[2] - 5 # number of classes
+ xc = prediction[..., 4] > conf_thres # candidates
+
+ # Settings
+ # (pixels) maximum box width and height
+ max_wh = 4096
+ time_limit = 10.0 # seconds to quit after
+ redundant = True # require redundant detections
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
+ merge = False # use merge-NMS
+
+ t = time.time()
+ output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
+ for xi, x in enumerate(prediction): # image index, image inference
+ x = x[xc[xi]] # confidence
+
+ # Cat apriori labels if autolabelling
+ if labels and len(labels[xi]):
+ label_id = labels[xi]
+ v = torch.zeros((len(label_id), nc + 5), device=x.device)
+ v[:, :4] = label_id[:, 1:5] # box
+ v[:, 4] = 1.0 # conf
+ v[range(len(label_id)), label_id[:, 0].long() + 5] = 1.0 # cls
+ x = torch.cat((x, v), 0)
+
+ # If none remain process next image
+ if not x.shape[0]:
+ continue
+
+ # Compute conf
+ x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
+
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
+ box = xywh2xyxy(x[:, :4])
+
+ # Detections matrix nx6 (xyxy, conf, cls)
+ if multi_label:
+ i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
+ x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
+ else: # best class only
+ conf, j = x[:, 5:].max(1, keepdim=True)
+ x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
+
+ # Filter by class
+ if classes is not None:
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
+
+ # Check shape
+ n = x.shape[0] # number of boxes
+ if not n: # no boxes
+ continue
+
+ x = x[x[:, 4].argsort(descending=True)] # sort by confidence
+
+ # Batched NMS
+ c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
+ if merge and (1 < n < 3e3): # Merge NMS (boxes merged using weighted mean)
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
+ weights = iou * scores[None] # box weights
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
+ if redundant:
+ i = i[iou.sum(1) > 1] # require redundancy
+
+ output[xi] = x[i]
+ if (time.time() - t) > time_limit:
+ print(f"WARNING: NMS time limit {time_limit}s exceeded")
+ break # time limit exceeded
+
+ return output
+
+
+def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
+ # Rescale coords (xyxy) from img1_shape to img0_shape
+ if ratio_pad is None: # calculate from img0_shape
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
+ else:
+ gain = ratio_pad[0][0]
+ pad = ratio_pad[1]
+
+ coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
+ coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
+ coords[:, :10] /= gain
+ coords[:, 0].clamp_(0, img0_shape[1]) # x1
+ coords[:, 1].clamp_(0, img0_shape[0]) # y1
+ coords[:, 2].clamp_(0, img0_shape[1]) # x2
+ coords[:, 3].clamp_(0, img0_shape[0]) # y2
+ coords[:, 4].clamp_(0, img0_shape[1]) # x3
+ coords[:, 5].clamp_(0, img0_shape[0]) # y3
+ coords[:, 6].clamp_(0, img0_shape[1]) # x4
+ coords[:, 7].clamp_(0, img0_shape[0]) # y4
+ coords[:, 8].clamp_(0, img0_shape[1]) # x5
+ coords[:, 9].clamp_(0, img0_shape[0]) # y5
+ return coords
diff --git a/repositories/codeformer/facelib/detection/yolov5face/utils/torch_utils.py b/repositories/codeformer/facelib/detection/yolov5face/utils/torch_utils.py
new file mode 100644
index 000000000..af2d06587
--- /dev/null
+++ b/repositories/codeformer/facelib/detection/yolov5face/utils/torch_utils.py
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+
+def fuse_conv_and_bn(conv, bn):
+ # Fuse convolution and batchnorm layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
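+    # Folding BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta into the convolution:
+    # W_fused = diag(gamma / sqrt(var + eps)) @ W_conv and
+    # b_fused = gamma * (b_conv - mean) / sqrt(var + eps) + beta,
+    # which is what the torch.mm calls below compute.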
+ fusedconv = (
+ nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ kernel_size=conv.kernel_size,
+ stride=conv.stride,
+ padding=conv.padding,
+ groups=conv.groups,
+ bias=True,
+ )
+ .requires_grad_(False)
+ .to(conv.weight.device)
+ )
+
+ # prepare filters
+ w_conv = conv.weight.clone().view(conv.out_channels, -1)
+ w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
+ fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
+
+ # prepare spatial bias
+ b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
+ b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
+ fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
+
+ return fusedconv
+
+
+def copy_attr(a, b, include=(), exclude=()):
+ # Copy attributes from b to a, options to only include [...] and to exclude [...]
+ for k, v in b.__dict__.items():
+ if (include and k not in include) or k.startswith("_") or k in exclude:
+ continue
+
+ setattr(a, k, v)
diff --git a/repositories/codeformer/facelib/parsing/__init__.py b/repositories/codeformer/facelib/parsing/__init__.py
new file mode 100644
index 000000000..72656e4b5
--- /dev/null
+++ b/repositories/codeformer/facelib/parsing/__init__.py
@@ -0,0 +1,23 @@
+import torch
+
+from facelib.utils import load_file_from_url
+from .bisenet import BiSeNet
+from .parsenet import ParseNet
+
+
+def init_parsing_model(model_name='bisenet', half=False, device='cuda'):
+ if model_name == 'bisenet':
+ model = BiSeNet(num_class=19)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_bisenet.pth'
+ elif model_name == 'parsenet':
+ model = ParseNet(in_size=512, out_size=512, parsing_ch=19)
+ model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
+ else:
+ raise NotImplementedError(f'{model_name} is not implemented.')
+
+ model_path = load_file_from_url(url=model_url, model_dir='weights/facelib', progress=True, file_name=None)
+ load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
+ model.load_state_dict(load_net, strict=True)
+ model.eval()
+ model = model.to(device)
+ return model
diff --git a/repositories/codeformer/facelib/parsing/bisenet.py b/repositories/codeformer/facelib/parsing/bisenet.py
new file mode 100644
index 000000000..3898cab76
--- /dev/null
+++ b/repositories/codeformer/facelib/parsing/bisenet.py
@@ -0,0 +1,140 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .resnet import ResNet18
+
+
+class ConvBNReLU(nn.Module):
+
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1):
+ super(ConvBNReLU, self).__init__()
+ self.conv = nn.Conv2d(in_chan, out_chan, kernel_size=ks, stride=stride, padding=padding, bias=False)
+ self.bn = nn.BatchNorm2d(out_chan)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = F.relu(self.bn(x))
+ return x
+
+
+class BiSeNetOutput(nn.Module):
+
+ def __init__(self, in_chan, mid_chan, num_class):
+ super(BiSeNetOutput, self).__init__()
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
+ self.conv_out = nn.Conv2d(mid_chan, num_class, kernel_size=1, bias=False)
+
+ def forward(self, x):
+ feat = self.conv(x)
+ out = self.conv_out(feat)
+ return out, feat
+
+
+class AttentionRefinementModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(AttentionRefinementModule, self).__init__()
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
+ self.bn_atten = nn.BatchNorm2d(out_chan)
+ self.sigmoid_atten = nn.Sigmoid()
+
+ def forward(self, x):
+ feat = self.conv(x)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv_atten(atten)
+ atten = self.bn_atten(atten)
+ atten = self.sigmoid_atten(atten)
+ out = torch.mul(feat, atten)
+ return out
+
+
+class ContextPath(nn.Module):
+
+ def __init__(self):
+ super(ContextPath, self).__init__()
+ self.resnet = ResNet18()
+ self.arm16 = AttentionRefinementModule(256, 128)
+ self.arm32 = AttentionRefinementModule(512, 128)
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
+
+ def forward(self, x):
+ feat8, feat16, feat32 = self.resnet(x)
+ h8, w8 = feat8.size()[2:]
+ h16, w16 = feat16.size()[2:]
+ h32, w32 = feat32.size()[2:]
+
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
+ avg = self.conv_avg(avg)
+ avg_up = F.interpolate(avg, (h32, w32), mode='nearest')
+
+ feat32_arm = self.arm32(feat32)
+ feat32_sum = feat32_arm + avg_up
+ feat32_up = F.interpolate(feat32_sum, (h16, w16), mode='nearest')
+ feat32_up = self.conv_head32(feat32_up)
+
+ feat16_arm = self.arm16(feat16)
+ feat16_sum = feat16_arm + feat32_up
+ feat16_up = F.interpolate(feat16_sum, (h8, w8), mode='nearest')
+ feat16_up = self.conv_head16(feat16_up)
+
+ return feat8, feat16_up, feat32_up # x8, x8, x16
+
+
+class FeatureFusionModule(nn.Module):
+
+ def __init__(self, in_chan, out_chan):
+ super(FeatureFusionModule, self).__init__()
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
+ self.conv1 = nn.Conv2d(out_chan, out_chan // 4, kernel_size=1, stride=1, padding=0, bias=False)
+ self.conv2 = nn.Conv2d(out_chan // 4, out_chan, kernel_size=1, stride=1, padding=0, bias=False)
+ self.relu = nn.ReLU(inplace=True)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, fsp, fcp):
+ fcat = torch.cat([fsp, fcp], dim=1)
+ feat = self.convblk(fcat)
+ atten = F.avg_pool2d(feat, feat.size()[2:])
+ atten = self.conv1(atten)
+ atten = self.relu(atten)
+ atten = self.conv2(atten)
+ atten = self.sigmoid(atten)
+ feat_atten = torch.mul(feat, atten)
+ feat_out = feat_atten + feat
+ return feat_out
+
+
+class BiSeNet(nn.Module):
+
+ def __init__(self, num_class):
+ super(BiSeNet, self).__init__()
+ self.cp = ContextPath()
+ self.ffm = FeatureFusionModule(256, 256)
+ self.conv_out = BiSeNetOutput(256, 256, num_class)
+ self.conv_out16 = BiSeNetOutput(128, 64, num_class)
+ self.conv_out32 = BiSeNetOutput(128, 64, num_class)
+
+ def forward(self, x, return_feat=False):
+ h, w = x.size()[2:]
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # return res3b1 feature
+ feat_sp = feat_res8 # replace spatial path feature with res3b1 feature
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
+
+ out, feat = self.conv_out(feat_fuse)
+ out16, feat16 = self.conv_out16(feat_cp8)
+ out32, feat32 = self.conv_out32(feat_cp16)
+
+ out = F.interpolate(out, (h, w), mode='bilinear', align_corners=True)
+ out16 = F.interpolate(out16, (h, w), mode='bilinear', align_corners=True)
+ out32 = F.interpolate(out32, (h, w), mode='bilinear', align_corners=True)
+
+ if return_feat:
+ feat = F.interpolate(feat, (h, w), mode='bilinear', align_corners=True)
+ feat16 = F.interpolate(feat16, (h, w), mode='bilinear', align_corners=True)
+ feat32 = F.interpolate(feat32, (h, w), mode='bilinear', align_corners=True)
+ return out, out16, out32, feat, feat16, feat32
+ else:
+ return out, out16, out32
diff --git a/repositories/codeformer/facelib/parsing/parsenet.py b/repositories/codeformer/facelib/parsing/parsenet.py
new file mode 100644
index 000000000..e178ebe43
--- /dev/null
+++ b/repositories/codeformer/facelib/parsing/parsenet.py
@@ -0,0 +1,194 @@
+"""Modified from https://github.com/chaofengc/PSFRGAN
+"""
+import numpy as np
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class NormLayer(nn.Module):
+ """Normalization Layers.
+
+ Args:
+ channels: input channels, for batch norm and instance norm.
+ input_size: input shape without batch size, for layer norm.
+ """
+
+ def __init__(self, channels, normalize_shape=None, norm_type='bn'):
+ super(NormLayer, self).__init__()
+ norm_type = norm_type.lower()
+ self.norm_type = norm_type
+ if norm_type == 'bn':
+ self.norm = nn.BatchNorm2d(channels, affine=True)
+ elif norm_type == 'in':
+ self.norm = nn.InstanceNorm2d(channels, affine=False)
+ elif norm_type == 'gn':
+ self.norm = nn.GroupNorm(32, channels, affine=True)
+ elif norm_type == 'pixel':
+ self.norm = lambda x: F.normalize(x, p=2, dim=1)
+ elif norm_type == 'layer':
+ self.norm = nn.LayerNorm(normalize_shape)
+ elif norm_type == 'none':
+ self.norm = lambda x: x * 1.0
+ else:
+            assert 1 == 0, f'Norm type {norm_type} not supported.'
+
+ def forward(self, x, ref=None):
+ if self.norm_type == 'spade':
+ return self.norm(x, ref)
+ else:
+ return self.norm(x)
+
+
+class ReluLayer(nn.Module):
+ """Relu Layer.
+
+ Args:
+ relu type: type of relu layer, candidates are
+ - ReLU
+ - LeakyReLU: default relu slope 0.2
+ - PRelu
+ - SELU
+ - none: direct pass
+ """
+
+ def __init__(self, channels, relu_type='relu'):
+ super(ReluLayer, self).__init__()
+ relu_type = relu_type.lower()
+ if relu_type == 'relu':
+ self.func = nn.ReLU(True)
+ elif relu_type == 'leakyrelu':
+ self.func = nn.LeakyReLU(0.2, inplace=True)
+ elif relu_type == 'prelu':
+ self.func = nn.PReLU(channels)
+ elif relu_type == 'selu':
+ self.func = nn.SELU(True)
+ elif relu_type == 'none':
+ self.func = lambda x: x * 1.0
+ else:
+            assert 1 == 0, f'Relu type {relu_type} not supported.'
+
+ def forward(self, x):
+ return self.func(x)
+
+
+class ConvLayer(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ scale='none',
+ norm_type='none',
+ relu_type='none',
+ use_pad=True,
+ bias=True):
+ super(ConvLayer, self).__init__()
+ self.use_pad = use_pad
+ self.norm_type = norm_type
+ if norm_type in ['bn']:
+ bias = False
+
+ stride = 2 if scale == 'down' else 1
+
+ self.scale_func = lambda x: x
+ if scale == 'up':
+ self.scale_func = lambda x: nn.functional.interpolate(x, scale_factor=2, mode='nearest')
+
+ self.reflection_pad = nn.ReflectionPad2d(int(np.ceil((kernel_size - 1.) / 2)))
+ self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=bias)
+
+ self.relu = ReluLayer(out_channels, relu_type)
+ self.norm = NormLayer(out_channels, norm_type=norm_type)
+
+ def forward(self, x):
+ out = self.scale_func(x)
+ if self.use_pad:
+ out = self.reflection_pad(out)
+ out = self.conv2d(out)
+ out = self.norm(out)
+ out = self.relu(out)
+ return out
+
+
+class ResidualBlock(nn.Module):
+ """
+ Residual block recommended in: http://torch.ch/blog/2016/02/04/resnets.html
+ """
+
+ def __init__(self, c_in, c_out, relu_type='prelu', norm_type='bn', scale='none'):
+ super(ResidualBlock, self).__init__()
+
+ if scale == 'none' and c_in == c_out:
+ self.shortcut_func = lambda x: x
+ else:
+ self.shortcut_func = ConvLayer(c_in, c_out, 3, scale)
+
+ scale_config_dict = {'down': ['none', 'down'], 'up': ['up', 'none'], 'none': ['none', 'none']}
+ scale_conf = scale_config_dict[scale]
+
+ self.conv1 = ConvLayer(c_in, c_out, 3, scale_conf[0], norm_type=norm_type, relu_type=relu_type)
+ self.conv2 = ConvLayer(c_out, c_out, 3, scale_conf[1], norm_type=norm_type, relu_type='none')
+
+ def forward(self, x):
+ identity = self.shortcut_func(x)
+
+ res = self.conv1(x)
+ res = self.conv2(res)
+ return identity + res
+
+
+class ParseNet(nn.Module):
+
+ def __init__(self,
+ in_size=128,
+ out_size=128,
+ min_feat_size=32,
+ base_ch=64,
+ parsing_ch=19,
+ res_depth=10,
+ relu_type='LeakyReLU',
+ norm_type='bn',
+ ch_range=[32, 256]):
+ super().__init__()
+ self.res_depth = res_depth
+ act_args = {'norm_type': norm_type, 'relu_type': relu_type}
+ min_ch, max_ch = ch_range
+
+ ch_clip = lambda x: max(min_ch, min(x, max_ch)) # noqa: E731
+ min_feat_size = min(in_size, min_feat_size)
+
+ down_steps = int(np.log2(in_size // min_feat_size))
+ up_steps = int(np.log2(out_size // min_feat_size))
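+        # With the defaults used in facelib/parsing/__init__.py (in_size=out_size=512,
+        # min_feat_size=32) this gives log2(512 / 32) = 4 stride-2 encoder blocks and
+        # 4 matching upsampling decoder blocks.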
+
+ # =============== define encoder-body-decoder ====================
+ self.encoder = []
+ self.encoder.append(ConvLayer(3, base_ch, 3, 1))
+ head_ch = base_ch
+ for i in range(down_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch * 2)
+ self.encoder.append(ResidualBlock(cin, cout, scale='down', **act_args))
+ head_ch = head_ch * 2
+
+ self.body = []
+ for i in range(res_depth):
+ self.body.append(ResidualBlock(ch_clip(head_ch), ch_clip(head_ch), **act_args))
+
+ self.decoder = []
+ for i in range(up_steps):
+ cin, cout = ch_clip(head_ch), ch_clip(head_ch // 2)
+ self.decoder.append(ResidualBlock(cin, cout, scale='up', **act_args))
+ head_ch = head_ch // 2
+
+ self.encoder = nn.Sequential(*self.encoder)
+ self.body = nn.Sequential(*self.body)
+ self.decoder = nn.Sequential(*self.decoder)
+ self.out_img_conv = ConvLayer(ch_clip(head_ch), 3)
+ self.out_mask_conv = ConvLayer(ch_clip(head_ch), parsing_ch)
+
+ def forward(self, x):
+ feat = self.encoder(x)
+ x = feat + self.body(feat)
+ x = self.decoder(x)
+ out_img = self.out_img_conv(x)
+ out_mask = self.out_mask_conv(x)
+ return out_mask, out_img
diff --git a/repositories/codeformer/facelib/parsing/resnet.py b/repositories/codeformer/facelib/parsing/resnet.py
new file mode 100644
index 000000000..fec8e82cf
--- /dev/null
+++ b/repositories/codeformer/facelib/parsing/resnet.py
@@ -0,0 +1,69 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
+
+
+class BasicBlock(nn.Module):
+
+ def __init__(self, in_chan, out_chan, stride=1):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
+ self.bn1 = nn.BatchNorm2d(out_chan)
+ self.conv2 = conv3x3(out_chan, out_chan)
+ self.bn2 = nn.BatchNorm2d(out_chan)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ if in_chan != out_chan or stride != 1:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_chan, out_chan, kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(out_chan),
+ )
+
+ def forward(self, x):
+ residual = self.conv1(x)
+ residual = F.relu(self.bn1(residual))
+ residual = self.conv2(residual)
+ residual = self.bn2(residual)
+
+ shortcut = x
+ if self.downsample is not None:
+ shortcut = self.downsample(x)
+
+ out = shortcut + residual
+ out = self.relu(out)
+ return out
+
+
+def create_layer_basic(in_chan, out_chan, bnum, stride=1):
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
+ for i in range(bnum - 1):
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
+ return nn.Sequential(*layers)
+
+
+class ResNet18(nn.Module):
+
+ def __init__(self):
+ super(ResNet18, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = F.relu(self.bn1(x))
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ feat8 = self.layer2(x) # 1/8
+ feat16 = self.layer3(feat8) # 1/16
+ feat32 = self.layer4(feat16) # 1/32
+ return feat8, feat16, feat32
diff --git a/repositories/codeformer/facelib/utils/__init__.py b/repositories/codeformer/facelib/utils/__init__.py
new file mode 100644
index 000000000..f03b1c2ba
--- /dev/null
+++ b/repositories/codeformer/facelib/utils/__init__.py
@@ -0,0 +1,7 @@
+from .face_utils import align_crop_face_landmarks, compute_increased_bbox, get_valid_bboxes, paste_face_back
+from .misc import img2tensor, load_file_from_url, download_pretrained_models, scandir
+
+__all__ = [
+ 'align_crop_face_landmarks', 'compute_increased_bbox', 'get_valid_bboxes', 'load_file_from_url',
+ 'download_pretrained_models', 'paste_face_back', 'img2tensor', 'scandir'
+]
diff --git a/repositories/codeformer/facelib/utils/face_restoration_helper.py b/repositories/codeformer/facelib/utils/face_restoration_helper.py
new file mode 100644
index 000000000..5d3fb8f3b
--- /dev/null
+++ b/repositories/codeformer/facelib/utils/face_restoration_helper.py
@@ -0,0 +1,460 @@
+import cv2
+import numpy as np
+import os
+import torch
+from torchvision.transforms.functional import normalize
+
+from facelib.detection import init_detection_model
+from facelib.parsing import init_parsing_model
+from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
+
+
+def get_largest_face(det_faces, h, w):
+
+ def get_location(val, length):
+ if val < 0:
+ return 0
+ elif val > length:
+ return length
+ else:
+ return val
+
+ face_areas = []
+ for det_face in det_faces:
+ left = get_location(det_face[0], w)
+ right = get_location(det_face[2], w)
+ top = get_location(det_face[1], h)
+ bottom = get_location(det_face[3], h)
+ face_area = (right - left) * (bottom - top)
+ face_areas.append(face_area)
+ largest_idx = face_areas.index(max(face_areas))
+ return det_faces[largest_idx], largest_idx
+
+
+def get_center_face(det_faces, h=0, w=0, center=None):
+ if center is not None:
+ center = np.array(center)
+ else:
+ center = np.array([w / 2, h / 2])
+ center_dist = []
+ for det_face in det_faces:
+ face_center = np.array([(det_face[0] + det_face[2]) / 2, (det_face[1] + det_face[3]) / 2])
+ dist = np.linalg.norm(face_center - center)
+ center_dist.append(dist)
+ center_idx = center_dist.index(min(center_dist))
+ return det_faces[center_idx], center_idx
+
+
+class FaceRestoreHelper(object):
+ """Helper for the face restoration pipeline (base class)."""
+
+ def __init__(self,
+ upscale_factor,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ template_3points=False,
+ pad_blur=False,
+ use_parse=False,
+ device=None):
+ self.template_3points = template_3points # improve robustness
+ self.upscale_factor = int(upscale_factor)
+ # the cropped face ratio based on the square face
+ self.crop_ratio = crop_ratio # (h, w)
+        assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ratio only supports >=1'
+ self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
+
+ if self.template_3points:
+ self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
+ else:
+ # standard 5 landmarks for FFHQ faces with 512 x 512
+ # facexlib
+ self.face_template = np.array([[192.98138, 239.94708], [318.90277, 240.1936], [256.63416, 314.01935],
+ [201.26117, 371.41043], [313.08905, 371.15118]])
+
+ # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54
+ # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
+ # [198.22603, 372.82502], [313.91018, 372.75659]])
+
+
+ self.face_template = self.face_template * (face_size / 512.0)
+ if self.crop_ratio[0] > 1:
+ self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
+ if self.crop_ratio[1] > 1:
+ self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2
+ self.save_ext = save_ext
+ self.pad_blur = pad_blur
+ if self.pad_blur is True:
+ self.template_3points = False
+
+ self.all_landmarks_5 = []
+ self.det_faces = []
+ self.affine_matrices = []
+ self.inverse_affine_matrices = []
+ self.cropped_faces = []
+ self.restored_faces = []
+ self.pad_input_imgs = []
+
+ if device is None:
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ else:
+ self.device = device
+
+ # init face detection model
+ self.face_det = init_detection_model(det_model, half=False, device=self.device)
+
+ # init face parsing model
+ self.use_parse = use_parse
+ self.face_parse = init_parsing_model(model_name='parsenet', device=self.device)
+
+ def set_upscale_factor(self, upscale_factor):
+ self.upscale_factor = upscale_factor
+
+ def read_image(self, img):
+ """img can be image path or cv2 loaded image."""
+ # self.input_img is Numpy array, (h, w, c), BGR, uint8, [0, 255]
+ if isinstance(img, str):
+ img = cv2.imread(img)
+
+ if np.max(img) > 256: # 16-bit image
+ img = img / 65535 * 255
+ if len(img.shape) == 2: # gray image
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif img.shape[2] == 4: # BGRA image with alpha channel
+ img = img[:, :, 0:3]
+
+ self.input_img = img
+ self.is_gray = is_gray(img, threshold=5)
+ if self.is_gray:
+ print('Grayscale input: True')
+
+ if min(self.input_img.shape[:2])<512:
+ f = 512.0/min(self.input_img.shape[:2])
+ self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
+
+ def get_face_landmarks_5(self,
+ only_keep_largest=False,
+ only_center_face=False,
+ resize=None,
+ blur_ratio=0.01,
+ eye_dist_threshold=None):
+ if resize is None:
+ scale = 1
+ input_img = self.input_img
+ else:
+ h, w = self.input_img.shape[0:2]
+ scale = resize / min(h, w)
+ scale = max(1, scale) # always scale up
+ h, w = int(h * scale), int(w * scale)
+ interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
+ input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
+
+ with torch.no_grad():
+ bboxes = self.face_det.detect_faces(input_img)
+
+ if bboxes is None or bboxes.shape[0] == 0:
+ return 0
+ else:
+ bboxes = bboxes / scale
+
+ for bbox in bboxes:
+ # remove faces with too small eye distance: side faces or too small faces
+ eye_dist = np.linalg.norm([bbox[6] - bbox[8], bbox[7] - bbox[9]])
+ if eye_dist_threshold is not None and (eye_dist < eye_dist_threshold):
+ continue
+
+ if self.template_3points:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 11, 2)])
+ else:
+ landmark = np.array([[bbox[i], bbox[i + 1]] for i in range(5, 15, 2)])
+ self.all_landmarks_5.append(landmark)
+ self.det_faces.append(bbox[0:5])
+
+ if len(self.det_faces) == 0:
+ return 0
+ if only_keep_largest:
+ h, w, _ = self.input_img.shape
+ self.det_faces, largest_idx = get_largest_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[largest_idx]]
+ elif only_center_face:
+ h, w, _ = self.input_img.shape
+ self.det_faces, center_idx = get_center_face(self.det_faces, h, w)
+ self.all_landmarks_5 = [self.all_landmarks_5[center_idx]]
+
+ # pad blurry images
+ if self.pad_blur:
+ self.pad_input_imgs = []
+ for landmarks in self.all_landmarks_5:
+ # get landmarks
+ eye_left = landmarks[0, :]
+ eye_right = landmarks[1, :]
+ eye_avg = (eye_left + eye_right) * 0.5
+ mouth_avg = (landmarks[3, :] + landmarks[4, :]) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1.5
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+ border = max(int(np.rint(qsize * 0.1)), 3)
+
+ # get pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = [
+ max(-pad[0] + border, 1),
+ max(-pad[1] + border, 1),
+ max(pad[2] - self.input_img.shape[0] + border, 1),
+ max(pad[3] - self.input_img.shape[1] + border, 1)
+ ]
+
+ if max(pad) > 1:
+ # pad image
+ pad_img = np.pad(self.input_img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ # modify landmark coords
+ landmarks[:, 0] += pad[0]
+ landmarks[:, 1] += pad[1]
+ # blur pad images
+ h, w, _ = pad_img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * blur_ratio)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(pad_img, 0, ksize=(blur, blur))
+ # blur_img = cv2.GaussianBlur(pad_img, (blur, blur), 0)
+
+ pad_img = pad_img.astype('float32')
+ pad_img += (blur_img - pad_img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ pad_img += (np.median(pad_img, axis=(0, 1)) - pad_img) * np.clip(mask, 0.0, 1.0)
+ pad_img = np.clip(pad_img, 0, 255) # float32, [0, 255]
+ self.pad_input_imgs.append(pad_img)
+ else:
+ self.pad_input_imgs.append(np.copy(self.input_img))
+
+ return len(self.all_landmarks_5)
+
+ def align_warp_face(self, save_cropped_path=None, border_mode='constant'):
+ """Align and warp faces with face template.
+ """
+ if self.pad_blur:
+ assert len(self.pad_input_imgs) == len(
+ self.all_landmarks_5), f'Mismatched samples: {len(self.pad_input_imgs)} and {len(self.all_landmarks_5)}'
+ for idx, landmark in enumerate(self.all_landmarks_5):
+ # use 5 landmarks to get affine matrix
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(landmark, self.face_template, method=cv2.LMEDS)[0]
+ self.affine_matrices.append(affine_matrix)
+ # warp and crop faces
+ if border_mode == 'constant':
+ border_mode = cv2.BORDER_CONSTANT
+ elif border_mode == 'reflect101':
+ border_mode = cv2.BORDER_REFLECT101
+ elif border_mode == 'reflect':
+ border_mode = cv2.BORDER_REFLECT
+ if self.pad_blur:
+ input_img = self.pad_input_imgs[idx]
+ else:
+ input_img = self.input_img
+ cropped_face = cv2.warpAffine(
+ input_img, affine_matrix, self.face_size, borderMode=border_mode, borderValue=(135, 133, 132)) # gray
+ self.cropped_faces.append(cropped_face)
+ # save the cropped face
+ if save_cropped_path is not None:
+ path = os.path.splitext(save_cropped_path)[0]
+ save_path = f'{path}_{idx:02d}.{self.save_ext}'
+ imwrite(cropped_face, save_path)
+
+ def get_inverse_affine(self, save_inverse_affine_path=None):
+ """Get inverse affine matrix."""
+ for idx, affine_matrix in enumerate(self.affine_matrices):
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ inverse_affine *= self.upscale_factor
+ self.inverse_affine_matrices.append(inverse_affine)
+ # save inverse affine matrices
+ if save_inverse_affine_path is not None:
+ path, _ = os.path.splitext(save_inverse_affine_path)
+ save_path = f'{path}_{idx:02d}.pth'
+ torch.save(inverse_affine, save_path)
+
+
+ def add_restored_face(self, face):
+ if self.is_gray:
+ face = bgr2gray(face) # convert img into grayscale
+ self.restored_faces.append(face)
+
+
+ def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
+ h, w, _ = self.input_img.shape
+ h_up, w_up = int(h * self.upscale_factor), int(w * self.upscale_factor)
+
+ if upsample_img is None:
+ # simply resize the background
+ # upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+ upsample_img = cv2.resize(self.input_img, (w_up, h_up), interpolation=cv2.INTER_LINEAR)
+ else:
+ upsample_img = cv2.resize(upsample_img, (w_up, h_up), interpolation=cv2.INTER_LANCZOS4)
+
+ assert len(self.restored_faces) == len(
+ self.inverse_affine_matrices), ('length of restored_faces and affine_matrices are different.')
+
+ inv_mask_borders = []
+ for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
+ if face_upsampler is not None:
+ restored_face = face_upsampler.enhance(restored_face, outscale=self.upscale_factor)[0]
+ inverse_affine /= self.upscale_factor
+ inverse_affine[:, 2] *= self.upscale_factor
+ face_size = (self.face_size[0]*self.upscale_factor, self.face_size[1]*self.upscale_factor)
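+                # The restored face has already been enlarged by upscale_factor here, so the
+                # linear part of the inverse affine is scaled back down while the translation
+                # column (index 2) is kept in the upscaled (w_up, h_up) frame.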
+ else:
+ # Add an offset to inverse affine matrix, for more precise back alignment
+ if self.upscale_factor > 1:
+ extra_offset = 0.5 * self.upscale_factor
+ else:
+ extra_offset = 0
+ inverse_affine[:, 2] += extra_offset
+ face_size = self.face_size
+ inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
+
+ # if draw_box or not self.use_parse: # use square parse maps
+ # mask = np.ones(face_size, dtype=np.float32)
+ # inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # # remove the black borders
+ # inv_mask_erosion = cv2.erode(
+ # inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ # pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ # total_face_area = np.sum(inv_mask_erosion) # // 3
+ # # add border
+ # if draw_box:
+ # h, w = face_size
+ # mask_border = np.ones((h, w, 3), dtype=np.float32)
+ # border = int(1400/np.sqrt(total_face_area))
+ # mask_border[border:h-border, border:w-border,:] = 0
+ # inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ # inv_mask_borders.append(inv_mask_border)
+ # if not self.use_parse:
+ # # compute the fusion edge based on the area of face
+ # w_edge = int(total_face_area**0.5) // 20
+ # erosion_radius = w_edge * 2
+ # inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ # blur_size = w_edge * 2
+ # inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ # if len(upsample_img.shape) == 2: # upsample_img is gray image
+ # upsample_img = upsample_img[:, :, None]
+ # inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # always use square mask
+ mask = np.ones(face_size, dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(
+ inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8))
+ pasted_face = inv_mask_erosion[:, :, None] * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) # // 3
+ # add border
+ if draw_box:
+ h, w = face_size
+ mask_border = np.ones((h, w, 3), dtype=np.float32)
+ border = int(1400/np.sqrt(total_face_area))
+ mask_border[border:h-border, border:w-border,:] = 0
+ inv_mask_border = cv2.warpAffine(mask_border, inverse_affine, (w_up, h_up))
+ inv_mask_borders.append(inv_mask_border)
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ if len(upsample_img.shape) == 2: # upsample_img is gray image
+ upsample_img = upsample_img[:, :, None]
+ inv_soft_mask = inv_soft_mask[:, :, None]
+
+ # parse mask
+ if self.use_parse:
+ # inference
+ face_input = cv2.resize(restored_face, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_input = img2tensor(face_input.astype('float32') / 255., bgr2rgb=True, float32=True)
+ normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ face_input = torch.unsqueeze(face_input, 0).to(self.device)
+ with torch.no_grad():
+ out = self.face_parse(face_input)[0]
+ out = out.argmax(dim=1).squeeze().cpu().numpy()
+
+ parse_mask = np.zeros(out.shape)
+ MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
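+                # One entry per parsing class: 255 keeps the region in the blending mask, 0 drops
+                # it (background, neck, cloth, hair and hat, assuming the usual 19-class
+                # CelebAMask-HQ-style label ordering).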
+ for idx, color in enumerate(MASK_COLORMAP):
+ parse_mask[out == idx] = color
+ # blur the mask
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ parse_mask = cv2.GaussianBlur(parse_mask, (101, 101), 11)
+ # remove the black borders
+ thres = 10
+ parse_mask[:thres, :] = 0
+ parse_mask[-thres:, :] = 0
+ parse_mask[:, :thres] = 0
+ parse_mask[:, -thres:] = 0
+ parse_mask = parse_mask / 255.
+
+ parse_mask = cv2.resize(parse_mask, face_size)
+ parse_mask = cv2.warpAffine(parse_mask, inverse_affine, (w_up, h_up), flags=3)
+ inv_soft_parse_mask = parse_mask[:, :, None]
+ # pasted_face = inv_restored
+                fuse_mask = (inv_soft_parse_mask < inv_soft_mask).astype('int')
+                inv_soft_mask = inv_soft_parse_mask * fuse_mask + inv_soft_mask * (1 - fuse_mask)
+
+            if len(upsample_img.shape) == 3 and upsample_img.shape[2] == 4:  # alpha channel
+                alpha = upsample_img[:, :, 3:]
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img[:, :, 0:3]
+                upsample_img = np.concatenate((upsample_img, alpha), axis=2)
+            else:
+                upsample_img = inv_soft_mask * pasted_face + (1 - inv_soft_mask) * upsample_img
+
+        if np.max(upsample_img) > 256:  # 16-bit image
+ upsample_img = upsample_img.astype(np.uint16)
+ else:
+ upsample_img = upsample_img.astype(np.uint8)
+
+ # draw bounding box
+ if draw_box:
+ # upsample_input_img = cv2.resize(input_img, (w_up, h_up))
+ img_color = np.ones([*upsample_img.shape], dtype=np.float32)
+ img_color[:,:,0] = 0
+ img_color[:,:,1] = 255
+ img_color[:,:,2] = 0
+ for inv_mask_border in inv_mask_borders:
+ upsample_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_img
+ # upsample_input_img = inv_mask_border * img_color + (1 - inv_mask_border) * upsample_input_img
+
+ if save_path is not None:
+ path = os.path.splitext(save_path)[0]
+ save_path = f'{path}.{self.save_ext}'
+ imwrite(upsample_img, save_path)
+ return upsample_img
+
+ def clean_all(self):
+ self.all_landmarks_5 = []
+ self.restored_faces = []
+ self.affine_matrices = []
+ self.cropped_faces = []
+ self.inverse_affine_matrices = []
+ self.det_faces = []
+ self.pad_input_imgs = []
\ No newline at end of file
diff --git a/repositories/codeformer/facelib/utils/face_utils.py b/repositories/codeformer/facelib/utils/face_utils.py
new file mode 100644
index 000000000..f1474a2a4
--- /dev/null
+++ b/repositories/codeformer/facelib/utils/face_utils.py
@@ -0,0 +1,248 @@
+import cv2
+import numpy as np
+import torch
+
+
+def compute_increased_bbox(bbox, increase_area, preserve_aspect=True):
+ left, top, right, bot = bbox
+ width = right - left
+ height = bot - top
+
+ if preserve_aspect:
+ width_increase = max(increase_area, ((1 + 2 * increase_area) * height - width) / (2 * width))
+ height_increase = max(increase_area, ((1 + 2 * increase_area) * width - height) / (2 * height))
+ else:
+ width_increase = height_increase = increase_area
+ left = int(left - width_increase * width)
+ top = int(top - height_increase * height)
+ right = int(right + width_increase * width)
+ bot = int(bot + height_increase * height)
+ return (left, top, right, bot)
+
+
+def get_valid_bboxes(bboxes, h, w):
+ left = max(bboxes[0], 0)
+ top = max(bboxes[1], 0)
+ right = min(bboxes[2], w)
+ bottom = min(bboxes[3], h)
+ return (left, top, right, bottom)
+
+
+def align_crop_face_landmarks(img,
+ landmarks,
+ output_size,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=False,
+ shrink_ratio=(1, 1)):
+ """Align and crop face with landmarks.
+
+ The output_size and transform_size are based on width. The height is
+    adjusted by shrink_ratio[0] / shrink_ratio[1].
+
+ Modified from:
+ https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
+
+ Args:
+ img (Numpy array): Input image.
+ landmarks (Numpy array): 5 or 68 or 98 landmarks.
+ output_size (int): Output face size.
+        transform_size (int): Transform size. Usually four times the
+            output_size.
+        enable_padding (bool): Default: True.
+        shrink_ratio (float | tuple[float] | list[float]): Shrink the whole
+            face for height and width (crop larger area). Default: (1, 1).
+
+ Returns:
+ (Numpy array): Cropped face.
+ """
+ lm_type = 'retinaface_5' # Options: dlib_5, retinaface_5
+
+ if isinstance(shrink_ratio, (float, int)):
+ shrink_ratio = (shrink_ratio, shrink_ratio)
+ if transform_size is None:
+ transform_size = output_size * 4
+
+ # Parse landmarks
+ lm = np.array(landmarks)
+ if lm.shape[0] == 5 and lm_type == 'retinaface_5':
+ eye_left = lm[0]
+ eye_right = lm[1]
+ mouth_avg = (lm[3] + lm[4]) * 0.5
+ elif lm.shape[0] == 5 and lm_type == 'dlib_5':
+ lm_eye_left = lm[2:4]
+ lm_eye_right = lm[0:2]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = lm[4]
+ elif lm.shape[0] == 68:
+ lm_eye_left = lm[36:42]
+ lm_eye_right = lm[42:48]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[48] + lm[54]) * 0.5
+ elif lm.shape[0] == 98:
+ lm_eye_left = lm[60:68]
+ lm_eye_right = lm[68:76]
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ mouth_avg = (lm[76] + lm[82]) * 0.5
+
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Get the oriented crop rectangle
+ # x: half width of the oriented crop rectangle
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ # - np.flipud(eye_to_mouth) * [-1, 1]: rotate 90 clockwise
+ # norm with the hypotenuse: get the direction
+ x /= np.hypot(*x) # get the hypotenuse of a right triangle
+ rect_scale = 1 # TODO: you can edit it to get larger rect
+ x *= max(np.hypot(*eye_to_eye) * 2.0 * rect_scale, np.hypot(*eye_to_mouth) * 1.8 * rect_scale)
+ # y: half height of the oriented crop rectangle
+ y = np.flipud(x) * [-1, 1]
+
+ x *= shrink_ratio[1] # width
+ y *= shrink_ratio[0] # height
+
+ # c: center
+ c = eye_avg + eye_to_mouth * 0.1
+ # quad: (left_top, left_bottom, right_bottom, right_top)
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ # qsize: side length of the square
+ qsize = np.hypot(*x) * 2
+
+ quad_ori = np.copy(quad)
+ # Shrink, for large face
+ # TODO: do we really need shrink
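+    # shrink > 1 only when the detected quad is at least four times the requested output size;
+    # in that case the image is downscaled first so the crop and warp below stay cheap.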
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ h, w = img.shape[0:2]
+ rsize = (int(np.rint(float(w) / shrink)), int(np.rint(float(h) / shrink)))
+ img = cv2.resize(img, rsize, interpolation=cv2.INTER_AREA)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop
+ h, w = img.shape[0:2]
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, w), min(crop[3] + border, h))
+ if crop[2] - crop[0] < w or crop[3] - crop[1] < h:
+ img = img[crop[1]:crop[3], crop[0]:crop[2], :]
+ quad -= crop[0:2]
+
+ # Pad
+ # pad: (width_left, height_top, width_right, height_bottom)
+ h, w = img.shape[0:2]
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+ int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - w + border, 0), max(pad[3] - h + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(img, ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+ h, w = img.shape[0:2]
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]),
+ 1.0 - np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = int(qsize * 0.02)
+ if blur % 2 == 0:
+ blur += 1
+ blur_img = cv2.boxFilter(img, 0, ksize=(blur, blur))
+
+ img = img.astype('float32')
+ img += (blur_img - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = np.clip(img, 0, 255) # float32, [0, 255]
+ quad += pad[:2]
+
+ # Transform use cv2
+ h_ratio = shrink_ratio[0] / shrink_ratio[1]
+ dst_h, dst_w = int(transform_size * h_ratio), transform_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(quad, template, method=cv2.LMEDS)[0]
+ cropped_face = cv2.warpAffine(
+ img, affine_matrix, (dst_w, dst_h), borderMode=cv2.BORDER_CONSTANT, borderValue=(135, 133, 132)) # gray
+
+ if output_size < transform_size:
+ cropped_face = cv2.resize(
+ cropped_face, (output_size, int(output_size * h_ratio)), interpolation=cv2.INTER_LINEAR)
+
+ if return_inverse_affine:
+ dst_h, dst_w = int(output_size * h_ratio), output_size
+ template = np.array([[0, 0], [0, dst_h], [dst_w, dst_h], [dst_w, 0]])
+ # use cv2.LMEDS method for the equivalence to skimage transform
+ # ref: https://blog.csdn.net/yichxi/article/details/115827338
+ affine_matrix = cv2.estimateAffinePartial2D(
+ quad_ori, np.array([[0, 0], [0, output_size], [dst_w, dst_h], [dst_w, 0]]), method=cv2.LMEDS)[0]
+ inverse_affine = cv2.invertAffineTransform(affine_matrix)
+ else:
+ inverse_affine = None
+ return cropped_face, inverse_affine
+
+
+def paste_face_back(img, face, inverse_affine):
+ h, w = img.shape[0:2]
+ face_h, face_w = face.shape[0:2]
+ inv_restored = cv2.warpAffine(face, inverse_affine, (w, h))
+ mask = np.ones((face_h, face_w, 3), dtype=np.float32)
+ inv_mask = cv2.warpAffine(mask, inverse_affine, (w, h))
+ # remove the black borders
+ inv_mask_erosion = cv2.erode(inv_mask, np.ones((2, 2), np.uint8))
+ inv_restored_remove_border = inv_mask_erosion * inv_restored
+ total_face_area = np.sum(inv_mask_erosion) // 3
+ # compute the fusion edge based on the area of face
+ w_edge = int(total_face_area**0.5) // 20
+ erosion_radius = w_edge * 2
+ inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
+ blur_size = w_edge * 2
+ inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
+ img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * img
+ # float32, [0, 255]
+ return img
+
+
+if __name__ == '__main__':
+ import os
+
+ from facelib.detection import init_detection_model
+ from facelib.utils.face_restoration_helper import get_largest_face
+
+ img_path = '/home/wxt/datasets/ffhq/ffhq_wild/00009.png'
+    img_name = os.path.splitext(os.path.basename(img_path))[0]
+
+ # initialize model
+ det_net = init_detection_model('retinaface_resnet50', half=False)
+ img_ori = cv2.imread(img_path)
+ h, w = img_ori.shape[0:2]
+ # if larger than 800, scale it
+ scale = max(h / 800, w / 800)
+ if scale > 1:
+ img = cv2.resize(img_ori, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_LINEAR)
+
+ with torch.no_grad():
+ bboxes = det_net.detect_faces(img, 0.97)
+ if scale > 1:
+ bboxes *= scale # the score is incorrect
+ bboxes = get_largest_face(bboxes, h, w)[0]
+
+ landmarks = np.array([[bboxes[i], bboxes[i + 1]] for i in range(5, 15, 2)])
+
+ cropped_face, inverse_affine = align_crop_face_landmarks(
+ img_ori,
+ landmarks,
+ output_size=512,
+ transform_size=None,
+ enable_padding=True,
+ return_inverse_affine=True,
+ shrink_ratio=(1, 1))
+
+    cv2.imwrite(f'tmp/{img_name}_cropped_face.png', cropped_face)
+ img = paste_face_back(img_ori, cropped_face, inverse_affine)
+ cv2.imwrite(f'tmp/{img_name}_back.png', img)
diff --git a/repositories/codeformer/facelib/utils/misc.py b/repositories/codeformer/facelib/utils/misc.py
new file mode 100644
index 000000000..7f5c95506
--- /dev/null
+++ b/repositories/codeformer/facelib/utils/misc.py
@@ -0,0 +1,174 @@
+import cv2
+import os
+import os.path as osp
+import numpy as np
+from PIL import Image
+import torch
+from torch.hub import download_url_to_file, get_dir
+from urllib.parse import urlparse
+# from basicsr.utils.download_util import download_file_from_google_drive
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+def download_pretrained_models(file_ids, save_path_root):
+ import gdown
+
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ file_url = 'https://drive.google.com/uc?id='+file_id
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+            user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+            if user_response.lower() == 'y':
+                print(f'Overwriting {file_name} at {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+ elif user_response.lower() == 'n':
+ print(f'Skipping {file_name}')
+ else:
+ raise ValueError('Wrong input. Only accepts Y/N.')
+ else:
+ print(f'Downloading {file_name} to {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ """
+ if model_dir is None:
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(os.path.join(ROOT_DIR, model_dir), exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(ROOT_DIR, model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+ Returns:
+ A generator for all the interested files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
+
+
+def is_gray(img, threshold=10):
+ img = Image.fromarray(img)
+ if len(img.getbands()) == 1:
+ return True
+ img1 = np.asarray(img.getchannel(channel=0), dtype=np.int16)
+ img2 = np.asarray(img.getchannel(channel=1), dtype=np.int16)
+ img3 = np.asarray(img.getchannel(channel=2), dtype=np.int16)
+ diff1 = (img1 - img2).var()
+ diff2 = (img2 - img3).var()
+ diff3 = (img3 - img1).var()
+ diff_sum = (diff1 + diff2 + diff3) / 3.0
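+    # A visually grayscale photo has nearly identical R/G/B channels, so the mean
+    # inter-channel difference variance stays at or below the threshold.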
+ if diff_sum <= threshold:
+ return True
+ else:
+ return False
+
+def rgb2gray(img, out_channel=3):
+ r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+ if out_channel == 3:
+ gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+ return gray
+
+def bgr2gray(img, out_channel=3):
+ b, g, r = img[:,:,0], img[:,:,1], img[:,:,2]
+ gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
+ if out_channel == 3:
+ gray = gray[:,:,np.newaxis].repeat(3, axis=2)
+ return gray
diff --git a/repositories/codeformer/inference_codeformer.py b/repositories/codeformer/inference_codeformer.py
new file mode 100644
index 000000000..cde1094af
--- /dev/null
+++ b/repositories/codeformer/inference_codeformer.py
@@ -0,0 +1,269 @@
+import os
+import cv2
+import argparse
+import glob
+import torch
+from torchvision.transforms.functional import normalize
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+from facelib.utils.misc import is_gray
+import torch.nn.functional as F
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+pretrain_model_url = {
+ 'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
+}
+
+def set_realesrgan():
+ from basicsr.archs.rrdbnet_arch import RRDBNet
+ from basicsr.utils.realesrgan_utils import RealESRGANer
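+ # Note: relies on the module-level `args` parsed in the __main__ block below (e.g. args.bg_tile).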
+
+ cuda_is_available = torch.cuda.is_available()
+ half = True if cuda_is_available else False
+ model = RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=2,
+ )
+ upsampler = RealESRGANer(
+ scale=2,
+ model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
+ model=model,
+ tile=args.bg_tile,
+ tile_pad=40,
+ pre_pad=0,
+ half=half, # need to set False in CPU mode
+ )
+
+ if not cuda_is_available: # CPU
+ import warnings
+ warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA. '
+ 'The unoptimized RealESRGAN is slow on CPU. '
+ 'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` from the command.',
+ category=RuntimeWarning)
+ return upsampler
+
+if __name__ == '__main__':
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
+ help='Input image, video or folder. Default: inputs/whole_imgs')
+ parser.add_argument('-o', '--output_path', type=str, default=None,
+ help='Output folder. Default: results/_')
+ parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
+ help='Balance the quality and fidelity. Default: 0.5')
+ parser.add_argument('-s', '--upscale', type=int, default=2,
+ help='The final upsampling scale of the image. Default: 2')
+ parser.add_argument('--has_aligned', action='store_true', help='Inputs are cropped and aligned faces. Default: False')
+ parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
+ parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
+ # large det_model: 'YOLOv5l', 'retinaface_resnet50'
+ # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
+ parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
+ help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n. \
+ Default: retinaface_resnet50')
+ parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
+ parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
+ parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
+ parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
+ parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None')
+
+ args = parser.parse_args()
+
+ # ------------------------ input & output ------------------------
+ w = args.fidelity_weight
+ input_video = False
+ if args.input_path.endswith(('jpg', 'png')): # input single img path
+ input_img_list = [args.input_path]
+ result_root = f'results/test_img_{w}'
+ elif args.input_path.endswith(('mp4', 'mov', 'avi')): # input video path
+ from basicsr.utils.video_util import VideoReader, VideoWriter
+ input_img_list = []
+ vidreader = VideoReader(args.input_path)
+ image = vidreader.get_frame()
+ while image is not None:
+ input_img_list.append(image)
+ image = vidreader.get_frame()
+ audio = vidreader.get_audio()
+ fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
+ video_name = os.path.basename(args.input_path)[:-4]
+ result_root = f'results/{video_name}_{w}'
+ input_video = True
+ vidreader.close()
+ else: # input img folder
+ if args.input_path.endswith('/'): # strip a trailing slash
+ args.input_path = args.input_path[:-1]
+ # scan all the jpg and png images
+ input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jp][pn]g')))
+ result_root = f'results/{os.path.basename(args.input_path)}_{w}'
+
+ if args.output_path is not None: # set output path
+ result_root = args.output_path
+
+ test_img_num = len(input_img_list)
+ if test_img_num == 0:
+ raise FileNotFoundError("\nInput file is not found.")
+
+ # ------------------ set up background upsampler ------------------
+ if args.bg_upsampler == 'realesrgan':
+ bg_upsampler = set_realesrgan()
+ else:
+ bg_upsampler = None
+
+ # ------------------ set up face upsampler ------------------
+ if args.face_upsample:
+ face_upsampler = None
+ # if bg_upsampler is not None:
+ # face_upsampler = bg_upsampler
+ # else:
+ # face_upsampler = set_realesrgan()
+ else:
+ face_upsampler = None
+
+ # ------------------ set up CodeFormer restorer -------------------
+ net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
+ connect_list=['32', '64', '128', '256']).to(device)
+
+ # ckpt_path = 'weights/CodeFormer/codeformer.pth'
+ ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'],
+ model_dir='weights/CodeFormer', progress=True, file_name=None)
+ checkpoint = torch.load(ckpt_path)['params_ema']
+ net.load_state_dict(checkpoint)
+ net.eval()
+
+ # ------------------ set up FaceRestoreHelper -------------------
+ # large det_model: 'YOLOv5l', 'retinaface_resnet50'
+ # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
+ if not args.has_aligned:
+ print(f'Face detection model: {args.detection_model}')
+ if bg_upsampler is not None:
+ print(f'Background upsampling: True, Face upsampling: {args.face_upsample}')
+ else:
+ print(f'Background upsampling: False, Face upsampling: {args.face_upsample}')
+
+ face_helper = FaceRestoreHelper(
+ args.upscale,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model = args.detection_model,
+ save_ext='png',
+ use_parse=True,
+ device=device)
+
+ # -------------------- start processing ---------------------
+ for i, img_path in enumerate(input_img_list):
+ # clean all the intermediate results to process the next image
+ face_helper.clean_all()
+
+ if isinstance(img_path, str):
+ img_name = os.path.basename(img_path)
+ basename, ext = os.path.splitext(img_name)
+ print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+ else: # for video processing
+ basename = str(i).zfill(6)
+ img_name = f'{video_name}_{basename}' if input_video else basename
+ print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
+ img = img_path
+
+ if args.has_aligned:
+ # the input faces are already cropped and aligned
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_helper.is_gray = is_gray(img, threshold=5)
+ if face_helper.is_gray:
+ print('Grayscale input: True')
+ face_helper.cropped_faces = [img]
+ else:
+ face_helper.read_image(img)
+ # get face landmarks for each face
+ num_det_faces = face_helper.get_face_landmarks_5(
+ only_center_face=args.only_center_face, resize=640, eye_dist_threshold=5)
+ print(f'\tdetect {num_det_faces} faces')
+ # align and warp each face
+ face_helper.align_warp_face()
+
+ # face restoration for each cropped face
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
+ # prepare data
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+ try:
+ with torch.no_grad():
+ output = net(cropped_face_t, w=w, adain=True)[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as error:
+ print(f'\tFailed inference for CodeFormer: {error}')
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
+
+ restored_face = restored_face.astype('uint8')
+ face_helper.add_restored_face(restored_face)
+
+ # paste_back
+ if not args.has_aligned:
+ # upsample the background
+ if bg_upsampler is not None:
+ # Now only support RealESRGAN for upsampling background
+ bg_img = bg_upsampler.enhance(img, outscale=args.upscale)[0]
+ else:
+ bg_img = None
+ face_helper.get_inverse_affine(None)
+ # paste each restored face to the input image
+ if args.face_upsample and face_upsampler is not None:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box, face_upsampler=face_upsampler)
+ else:
+ restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=args.draw_box)
+
+ # save faces
+ for idx, (cropped_face, restored_face) in enumerate(zip(face_helper.cropped_faces, face_helper.restored_faces)):
+ # save cropped face
+ if not args.has_aligned:
+ save_crop_path = os.path.join(result_root, 'cropped_faces', f'{basename}_{idx:02d}.png')
+ imwrite(cropped_face, save_crop_path)
+ # save restored face
+ if args.has_aligned:
+ save_face_name = f'{basename}.png'
+ else:
+ save_face_name = f'{basename}_{idx:02d}.png'
+ if args.suffix is not None:
+ save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
+ save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
+ imwrite(restored_face, save_restore_path)
+
+ # save restored img
+ if not args.has_aligned and restored_img is not None:
+ if args.suffix is not None:
+ basename = f'{basename}_{args.suffix}'
+ save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
+ imwrite(restored_img, save_restore_path)
+
+ # save enhanced video
+ if input_video:
+ print('Saving video...')
+ # load images
+ video_frames = []
+ img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g')))
+ for img_path in img_list:
+ img = cv2.imread(img_path)
+ video_frames.append(img)
+ # write images to video
+ height, width = video_frames[0].shape[:2]
+ if args.suffix is not None:
+ video_name = f'{video_name}_{args.suffix}'
+ save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
+ vidwriter = VideoWriter(save_restore_path, height, width, fps, audio)
+
+ for f in video_frames:
+ vidwriter.write_frame(f)
+ vidwriter.close()
+
+ print(f'\nAll results are saved in {result_root}')
diff --git a/repositories/codeformer/requirements.txt b/repositories/codeformer/requirements.txt
new file mode 100644
index 000000000..7e1950a06
--- /dev/null
+++ b/repositories/codeformer/requirements.txt
@@ -0,0 +1,17 @@
+addict
+future
+lmdb
+numpy
+opencv-python
+Pillow
+pyyaml
+requests
+scikit-image
+scipy
+tb-nightly
+torch>=1.7.1
+torchvision
+tqdm
+yapf
+lpips
+gdown # supports downloading the large file from Google Drive
\ No newline at end of file
diff --git a/repositories/codeformer/scripts/crop_align_face.py b/repositories/codeformer/scripts/crop_align_face.py
new file mode 100755
index 000000000..31e66266a
--- /dev/null
+++ b/repositories/codeformer/scripts/crop_align_face.py
@@ -0,0 +1,192 @@
+"""
+brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
+author: lzhbrian (https://lzhbrian.me)
+link: https://gist.github.com/lzhbrian/bde87ab23b499dd02ba4f588258f57d5
+date: 2020.1.5
+note: code is heavily borrowed from
+ https://github.com/NVlabs/ffhq-dataset
+ http://dlib.net/face_landmark_detection.py.html
+requirements:
+ conda install Pillow numpy scipy
+ conda install -c conda-forge dlib
+ # download face landmark model from:
+ # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+"""
+
+import cv2
+import dlib
+import glob
+import numpy as np
+import os
+import PIL
+import PIL.Image
+import scipy
+import scipy.ndimage
+import sys
+import argparse
+
+# download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+predictor = dlib.shape_predictor('weights/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat')
+
+
+def get_landmark(filepath, only_keep_largest=True):
+ """get landmark with dlib
+ :return: np.array shape=(68, 2)
+ """
+ detector = dlib.get_frontal_face_detector()
+
+ img = dlib.load_rgb_image(filepath)
+ dets = detector(img, 1)
+
+ # Shangchen modified
+ print("Number of faces detected: {}".format(len(dets)))
+ if only_keep_largest:
+ print('If multiple faces are detected, only the largest one is kept.')
+ face_areas = []
+ for k, d in enumerate(dets):
+ face_area = (d.right() - d.left()) * (d.bottom() - d.top())
+ face_areas.append(face_area)
+
+ largest_idx = face_areas.index(max(face_areas))
+ d = dets[largest_idx]
+ shape = predictor(img, d)
+ print("Part 0: {}, Part 1: {} ...".format(
+ shape.part(0), shape.part(1)))
+ else:
+ for k, d in enumerate(dets):
+ print("Detection {}: Left: {} Top: {} Right: {} Bottom: {}".format(
+ k, d.left(), d.top(), d.right(), d.bottom()))
+ # Get the landmarks/parts for the face in box d.
+ shape = predictor(img, d)
+ print("Part 0: {}, Part 1: {} ...".format(
+ shape.part(0), shape.part(1)))
+
+ t = list(shape.parts())
+ a = []
+ for tt in t:
+ a.append([tt.x, tt.y])
+ lm = np.array(a)
+ # lm is a shape=(68,2) np.array
+ return lm
+
+def align_face(filepath, out_path):
+ """
+ :param filepath: str
+ :return: PIL Image
+ """
+ try:
+ lm = get_landmark(filepath)
+ except Exception:
+ print('No landmark ...')
+ return
+
+ lm_chin = lm[0:17] # left-right
+ lm_eyebrow_left = lm[17:22] # left-right
+ lm_eyebrow_right = lm[22:27] # left-right
+ lm_nose = lm[27:31] # top-down
+ lm_nostrils = lm[31:36] # top-down
+ lm_eye_left = lm[36:42] # left-clockwise
+ lm_eye_right = lm[42:48] # left-clockwise
+ lm_mouth_outer = lm[48:60] # left-clockwise
+ lm_mouth_inner = lm[60:68] # left-clockwise
+
+ # Calculate auxiliary vectors.
+ eye_left = np.mean(lm_eye_left, axis=0)
+ eye_right = np.mean(lm_eye_right, axis=0)
+ eye_avg = (eye_left + eye_right) * 0.5
+ eye_to_eye = eye_right - eye_left
+ mouth_left = lm_mouth_outer[0]
+ mouth_right = lm_mouth_outer[6]
+ mouth_avg = (mouth_left + mouth_right) * 0.5
+ eye_to_mouth = mouth_avg - eye_avg
+
+ # Choose oriented crop rectangle.
+ x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+ x /= np.hypot(*x)
+ x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+ y = np.flipud(x) * [-1, 1]
+ c = eye_avg + eye_to_mouth * 0.1
+ quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+ qsize = np.hypot(*x) * 2
+
+ # read image
+ img = PIL.Image.open(filepath)
+
+ output_size = 512
+ transform_size = 4096
+ enable_padding = False
+
+ # Shrink.
+ shrink = int(np.floor(qsize / output_size * 0.5))
+ if shrink > 1:
+ rsize = (int(np.rint(float(img.size[0]) / shrink)),
+ int(np.rint(float(img.size[1]) / shrink)))
+ img = img.resize(rsize, PIL.Image.ANTIALIAS)
+ quad /= shrink
+ qsize /= shrink
+
+ # Crop.
+ border = max(int(np.rint(qsize * 0.1)), 3)
+ crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+ crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
+ min(crop[2] + border,
+ img.size[0]), min(crop[3] + border, img.size[1]))
+ if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+ img = img.crop(crop)
+ quad -= crop[0:2]
+
+ # Pad.
+ pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
+ int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
+ pad = (max(-pad[0] + border,
+ 0), max(-pad[1] + border,
+ 0), max(pad[2] - img.size[0] + border,
+ 0), max(pad[3] - img.size[1] + border, 0))
+ if enable_padding and max(pad) > border - 4:
+ pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+ img = np.pad(
+ np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)),
+ 'reflect')
+ h, w, _ = img.shape
+ y, x, _ = np.ogrid[:h, :w, :1]
+ mask = np.maximum(
+ 1.0 -
+ np.minimum(np.float32(x) / pad[0],
+ np.float32(w - 1 - x) / pad[2]), 1.0 -
+ np.minimum(np.float32(y) / pad[1],
+ np.float32(h - 1 - y) / pad[3]))
+ blur = qsize * 0.02
+ img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
+ img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+ img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+ img = PIL.Image.fromarray(
+ np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+ quad += pad[:2]
+
+ img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
+ (quad + 0.5).flatten(), PIL.Image.BILINEAR)
+
+ if output_size < transform_size:
+ img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
+
+ # Save aligned image.
+ print('saving: ', out_path)
+ img.save(out_path)
+
+ return img, np.max(quad[:, 0]) - np.min(quad[:, 0])
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--in_dir', type=str, default='./inputs/whole_imgs')
+ parser.add_argument('--out_dir', type=str, default='./inputs/cropped_faces')
+ args = parser.parse_args()
+
+ img_list = sorted(glob.glob(f'{args.in_dir}/*.png'))
+
+ for in_path in img_list:
+ out_path = os.path.join(args.out_dir, in_path.split("/")[-1])
+ out_path = out_path.replace('.jpg', '.png')
+ size_ = align_face(in_path, out_path)
\ No newline at end of file
diff --git a/repositories/codeformer/scripts/download_pretrained_models.py b/repositories/codeformer/scripts/download_pretrained_models.py
new file mode 100644
index 000000000..daa6e8ca1
--- /dev/null
+++ b/repositories/codeformer/scripts/download_pretrained_models.py
@@ -0,0 +1,40 @@
+import argparse
+import os
+from os import path as osp
+
+from basicsr.utils.download_util import load_file_from_url
+
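+# Example usage (run from the repository root):
+#   python scripts/download_pretrained_models.py CodeFormer
+#   python scripts/download_pretrained_models.py facelib
+#   python scripts/download_pretrained_models.py all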
+
+def download_pretrained_models(method, file_urls):
+ save_path_root = f'./weights/{method}'
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_url in file_urls.items():
+ save_path = load_file_from_url(url=file_url, model_dir=save_path_root, progress=True, file_name=file_name)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ 'method',
+ type=str,
+ help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
+ args = parser.parse_args()
+
+ file_urls = {
+ 'CodeFormer': {
+ 'codeformer.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+ },
+ 'facelib': {
+ # 'yolov5l-face.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/yolov5l-face.pth',
+ 'detection_Resnet50_Final.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
+ 'parsing_parsenet.pth': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth'
+ }
+ }
+
+ if args.method == 'all':
+ for method in file_urls.keys():
+ download_pretrained_models(method, file_urls[method])
+ else:
+ download_pretrained_models(args.method, file_urls[args.method])
\ No newline at end of file
diff --git a/repositories/codeformer/scripts/download_pretrained_models_from_gdrive.py b/repositories/codeformer/scripts/download_pretrained_models_from_gdrive.py
new file mode 100644
index 000000000..7df5be6fc
--- /dev/null
+++ b/repositories/codeformer/scripts/download_pretrained_models_from_gdrive.py
@@ -0,0 +1,60 @@
+import argparse
+import os
+from os import path as osp
+
+# from basicsr.utils.download_util import download_file_from_google_drive
+import gdown
+
+
+def download_pretrained_models(method, file_ids):
+ save_path_root = f'./weights/{method}'
+ os.makedirs(save_path_root, exist_ok=True)
+
+ for file_name, file_id in file_ids.items():
+ file_url = 'https://drive.google.com/uc?id='+file_id
+ save_path = osp.abspath(osp.join(save_path_root, file_name))
+ if osp.exists(save_path):
+ user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
+ if user_response.lower() == 'y':
+ print(f'Overwriting {file_name} at {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+ elif user_response.lower() == 'n':
+ print(f'Skipping {file_name}')
+ else:
+ raise ValueError('Wrong input. Only accepts Y/N.')
+ else:
+ print(f'Downloading {file_name} to {save_path}')
+ gdown.download(file_url, save_path, quiet=False)
+ # download_file_from_google_drive(file_id, save_path)
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ 'method',
+ type=str,
+ help=("Options: 'CodeFormer' 'facelib'. Set to 'all' to download all the models."))
+ args = parser.parse_args()
+
+ # file name: file id
+ # 'dlib': {
+ # 'mmod_human_face_detector-4cb19393.dat': '1qD-OqY8M6j4PWUP_FtqfwUPFPRMu6ubX',
+ # 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1vF3WBUApw4662v9Pw6wke3uk1qxnmLdg',
+ # 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1tJyIVdCHaU6IDMDx86BZCxLGZfsWB8yq'
+ # }
+ file_ids = {
+ 'CodeFormer': {
+ 'codeformer.pth': '1v_E_vZvP-dQPF55Kc5SRCjaKTQXDz-JB'
+ },
+ 'facelib': {
+ 'yolov5l-face.pth': '131578zMA6B2x8VQHyHfa6GEPtulMCNzV',
+ 'parsing_parsenet.pth': '16pkohyZZ8ViHGBk3QtVqxLZKzdo466bK'
+ }
+ }
+
+ if args.method == 'all':
+ for method in file_ids.keys():
+ download_pretrained_models(method, file_ids[method])
+ else:
+ download_pretrained_models(args.method, file_ids[args.method])
\ No newline at end of file
diff --git a/repositories/codeformer/web-demos/hugging_face/app.py b/repositories/codeformer/web-demos/hugging_face/app.py
new file mode 100644
index 000000000..7da0fc947
--- /dev/null
+++ b/repositories/codeformer/web-demos/hugging_face/app.py
@@ -0,0 +1,280 @@
+"""
+This file is used for deploying hugging face demo:
+https://huggingface.co/spaces/sczhou/CodeFormer
+"""
+
+import sys
+sys.path.append('CodeFormer')
+import os
+import cv2
+import torch
+import torch.nn.functional as F
+import gradio as gr
+
+from torchvision.transforms.functional import normalize
+
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.utils.download_util import load_file_from_url
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+from facelib.utils.misc import is_gray
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.realesrgan_utils import RealESRGANer
+
+from basicsr.utils.registry import ARCH_REGISTRY
+
+
+os.system("pip freeze")
+
+pretrain_model_url = {
+ 'codeformer': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
+ 'detection': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth',
+ 'parsing': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth',
+ 'realesrgan': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth'
+}
+# download weights
+if not os.path.exists('CodeFormer/weights/CodeFormer/codeformer.pth'):
+ load_file_from_url(url=pretrain_model_url['codeformer'], model_dir='CodeFormer/weights/CodeFormer', progress=True, file_name=None)
+if not os.path.exists('CodeFormer/weights/facelib/detection_Resnet50_Final.pth'):
+ load_file_from_url(url=pretrain_model_url['detection'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
+if not os.path.exists('CodeFormer/weights/facelib/parsing_parsenet.pth'):
+ load_file_from_url(url=pretrain_model_url['parsing'], model_dir='CodeFormer/weights/facelib', progress=True, file_name=None)
+if not os.path.exists('CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth'):
+ load_file_from_url(url=pretrain_model_url['realesrgan'], model_dir='CodeFormer/weights/realesrgan', progress=True, file_name=None)
+
+# download images
+torch.hub.download_url_to_file(
+ 'https://replicate.com/api/models/sczhou/codeformer/files/fa3fe3d1-76b0-4ca8-ac0d-0a925cb0ff54/06.png',
+ '01.png')
+torch.hub.download_url_to_file(
+ 'https://replicate.com/api/models/sczhou/codeformer/files/a1daba8e-af14-4b00-86a4-69cec9619b53/04.jpg',
+ '02.jpg')
+torch.hub.download_url_to_file(
+ 'https://replicate.com/api/models/sczhou/codeformer/files/542d64f9-1712-4de7-85f7-3863009a7c3d/03.jpg',
+ '03.jpg')
+torch.hub.download_url_to_file(
+ 'https://replicate.com/api/models/sczhou/codeformer/files/a11098b0-a18a-4c02-a19a-9a7045d68426/010.jpg',
+ '04.jpg')
+torch.hub.download_url_to_file(
+ 'https://replicate.com/api/models/sczhou/codeformer/files/7cf19c2c-e0cf-4712-9af8-cf5bdbb8d0ee/012.jpg',
+ '05.jpg')
+
+def imread(img_path):
+ img = cv2.imread(img_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
+
+# set enhancer with RealESRGAN
+def set_realesrgan():
+ half = True if torch.cuda.is_available() else False
+ model = RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=2,
+ )
+ upsampler = RealESRGANer(
+ scale=2,
+ model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth",
+ model=model,
+ tile=400,
+ tile_pad=40,
+ pre_pad=0,
+ half=half,
+ )
+ return upsampler
+
+upsampler = set_realesrgan()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+codeformer_net = ARCH_REGISTRY.get("CodeFormer")(
+ dim_embd=512,
+ codebook_size=1024,
+ n_head=8,
+ n_layers=9,
+ connect_list=["32", "64", "128", "256"],
+).to(device)
+ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth"
+checkpoint = torch.load(ckpt_path)["params_ema"]
+codeformer_net.load_state_dict(checkpoint)
+codeformer_net.eval()
+
+os.makedirs('output', exist_ok=True)
+
+def inference(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
+ """Run a single prediction on the model"""
+ try: # global try
+ # take the default setting for the demo
+ has_aligned = False
+ only_center_face = False
+ draw_box = False
+ detection_model = "retinaface_resnet50"
+ print('Inp:', image, background_enhance, face_upsample, upscale, codeformer_fidelity)
+
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+ print('\timage size:', img.shape)
+
+ upscale = int(upscale) # convert type to int
+ if upscale > 4: # avoid memory exceeded due to too large upscale
+ upscale = 4
+ if upscale > 2 and max(img.shape[:2])>1000: # avoid memory exceeded due to too large img resolution
+ upscale = 2
+ if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution
+ upscale = 1
+ background_enhance = False
+ face_upsample = False
+
+ face_helper = FaceRestoreHelper(
+ upscale,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model=detection_model,
+ save_ext="png",
+ use_parse=True,
+ device=device,
+ )
+ bg_upsampler = upsampler if background_enhance else None
+ face_upsampler = upsampler if face_upsample else None
+
+ if has_aligned:
+ # the input faces are already cropped and aligned
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ face_helper.is_gray = is_gray(img, threshold=5)
+ if face_helper.is_gray:
+ print('\tgrayscale input: True')
+ face_helper.cropped_faces = [img]
+ else:
+ face_helper.read_image(img)
+ # get face landmarks for each face
+ num_det_faces = face_helper.get_face_landmarks_5(
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
+ )
+ print(f'\tdetect {num_det_faces} faces')
+ # align and warp each face
+ face_helper.align_warp_face()
+
+ # face restoration for each cropped face
+ for idx, cropped_face in enumerate(face_helper.cropped_faces):
+ # prepare data
+ cropped_face_t = img2tensor(
+ cropped_face / 255.0, bgr2rgb=True, float32=True
+ )
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
+
+ try:
+ with torch.no_grad():
+ output = codeformer_net(
+ cropped_face_t, w=codeformer_fidelity, adain=True
+ )[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except RuntimeError as error:
+ print(f"Failed inference for CodeFormer: {error}")
+ restored_face = tensor2img(
+ cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
+ )
+
+ restored_face = restored_face.astype("uint8")
+ face_helper.add_restored_face(restored_face)
+
+ # paste_back
+ if not has_aligned:
+ # upsample the background
+ if bg_upsampler is not None:
+ # Now only support RealESRGAN for upsampling background
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+ else:
+ bg_img = None
+ face_helper.get_inverse_affine(None)
+ # paste each restored face to the input image
+ if face_upsample and face_upsampler is not None:
+ restored_img = face_helper.paste_faces_to_input_image(
+ upsample_img=bg_img,
+ draw_box=draw_box,
+ face_upsampler=face_upsampler,
+ )
+ else:
+ restored_img = face_helper.paste_faces_to_input_image(
+ upsample_img=bg_img, draw_box=draw_box
+ )
+
+ # save restored img
+ save_path = 'output/out.png'
+ imwrite(restored_img, str(save_path))
+
+ restored_img = cv2.cvtColor(restored_img, cv2.COLOR_BGR2RGB)
+ return restored_img, save_path
+ except Exception as error:
+ print('Global exception', error)
+ return None, None
+
+
+title = "CodeFormer: Robust Face Restoration and Enhancement Network"
+description = r"""
+Official Gradio demo for Towards Robust Blind Face Restoration with Codebook Lookup Transformer (NeurIPS 2022).
+🔥 CodeFormer is a robust face restoration algorithm for old photos or AI-generated faces.
+🤗 Try CodeFormer for improved stable-diffusion generation!
+"""
+article = r"""
+If CodeFormer is helpful, please help to ⭐ the Github Repo. Thanks!
+[![GitHub Stars](https://img.shields.io/github/stars/sczhou/CodeFormer?style=social)](https://github.com/sczhou/CodeFormer)
+
+---
+
+📝 **Citation**
+
+If our work is useful for your research, please consider citing:
+```bibtex
+@inproceedings{zhou2022codeformer,
+ author = {Zhou, Shangchen and Chan, Kelvin C.K. and Li, Chongyi and Loy, Chen Change},
+ title = {Towards Robust Blind Face Restoration with Codebook Lookup TransFormer},
+ booktitle = {NeurIPS},
+ year = {2022}
+}
+```
+
+📋 **License**
+
+This project is licensed under S-Lab License 1.0.
+Redistribution and use for non-commercial purposes should follow this license.
+
+📧 **Contact**
+
+If you have any questions, please feel free to reach out to me at shangchenzhou@gmail.com.
+
+
+"""
+
+demo = gr.Interface(
+ inference, [
+ gr.inputs.Image(type="filepath", label="Input"),
+ gr.inputs.Checkbox(default=True, label="Background_Enhance"),
+ gr.inputs.Checkbox(default=True, label="Face_Upsample"),
+ gr.inputs.Number(default=2, label="Rescaling_Factor (up to 4)"),
+ gr.Slider(0, 1, value=0.5, step=0.01, label='Codeformer_Fidelity (0 for better quality, 1 for better identity)')
+ ], [
+ gr.outputs.Image(type="numpy", label="Output"),
+ gr.outputs.File(label="Download the output")
+ ],
+ title=title,
+ description=description,
+ article=article,
+ examples=[
+ ['01.png', True, True, 2, 0.7],
+ ['02.jpg', True, True, 2, 0.7],
+ ['03.jpg', True, True, 2, 0.7],
+ ['04.jpg', True, True, 2, 0.1],
+ ['05.jpg', True, True, 2, 0.1]
+ ]
+ )
+
+demo.queue(concurrency_count=2)
+demo.launch()
\ No newline at end of file
diff --git a/repositories/codeformer/web-demos/replicate/cog.yaml b/repositories/codeformer/web-demos/replicate/cog.yaml
new file mode 100644
index 000000000..3f4589690
--- /dev/null
+++ b/repositories/codeformer/web-demos/replicate/cog.yaml
@@ -0,0 +1,30 @@
+"""
+This file is used for deploying replicate demo:
+https://replicate.com/sczhou/codeformer
+"""
+
+build:
+ gpu: true
+ cuda: "11.3"
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "ipython==8.4.0"
+ - "future==0.18.2"
+ - "lmdb==1.3.0"
+ - "scikit-image==0.19.3"
+ - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
+ - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113"
+ - "scipy==1.9.0"
+ - "gdown==4.5.1"
+ - "pyyaml==6.0"
+ - "tb-nightly==2.11.0a20220906"
+ - "tqdm==4.64.1"
+ - "yapf==0.32.0"
+ - "lpips==0.1.4"
+ - "Pillow==9.2.0"
+ - "opencv-python==4.6.0.66"
+
+predict: "predict.py:Predictor"
diff --git a/repositories/codeformer/web-demos/replicate/predict.py b/repositories/codeformer/web-demos/replicate/predict.py
new file mode 100644
index 000000000..61935e9e7
--- /dev/null
+++ b/repositories/codeformer/web-demos/replicate/predict.py
@@ -0,0 +1,189 @@
+"""
+This file is used for deploying replicate demo:
+https://replicate.com/sczhou/codeformer
+running: cog predict -i image=@inputs/whole_imgs/04.jpg -i codeformer_fidelity=0.5 -i upscale=2
+push: cog push r8.im/sczhou/codeformer
+"""
+
+import tempfile
+import cv2
+import torch
+from torchvision.transforms.functional import normalize
+try:
+ from cog import BasePredictor, Input, Path
+except Exception:
+ print('please install cog package')
+
+from basicsr.utils import imwrite, img2tensor, tensor2img
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from basicsr.utils.realesrgan_utils import RealESRGANer
+from basicsr.utils.registry import ARCH_REGISTRY
+from facelib.utils.face_restoration_helper import FaceRestoreHelper
+
+
+class Predictor(BasePredictor):
+ def setup(self):
+ """Load the model into memory to make running multiple predictions efficient"""
+ self.device = "cuda:0"
+ self.upsampler = set_realesrgan()
+ self.net = ARCH_REGISTRY.get("CodeFormer")(
+ dim_embd=512,
+ codebook_size=1024,
+ n_head=8,
+ n_layers=9,
+ connect_list=["32", "64", "128", "256"],
+ ).to(self.device)
+ ckpt_path = "weights/CodeFormer/codeformer.pth"
+ checkpoint = torch.load(ckpt_path)[
+ "params_ema"
+ ] # update file permission if cannot load
+ self.net.load_state_dict(checkpoint)
+ self.net.eval()
+
+ def predict(
+ self,
+ image: Path = Input(description="Input image"),
+ codeformer_fidelity: float = Input(
+ default=0.5,
+ ge=0,
+ le=1,
+ description="Balance the quality (lower number) and fidelity (higher number).",
+ ),
+ background_enhance: bool = Input(
+ description="Enhance background image with Real-ESRGAN", default=True
+ ),
+ face_upsample: bool = Input(
+ description="Upsample restored faces for high-resolution AI-created images",
+ default=True,
+ ),
+ upscale: int = Input(
+ description="The final upsampling scale of the image",
+ default=2,
+ ),
+ ) -> Path:
+ """Run a single prediction on the model"""
+
+ # take the default setting for the demo
+ has_aligned = False
+ only_center_face = False
+ draw_box = False
+ detection_model = "retinaface_resnet50"
+
+ self.face_helper = FaceRestoreHelper(
+ upscale,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model=detection_model,
+ save_ext="png",
+ use_parse=True,
+ device=self.device,
+ )
+
+ bg_upsampler = self.upsampler if background_enhance else None
+ face_upsampler = self.upsampler if face_upsample else None
+
+ img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+
+ if has_aligned:
+ # the input faces are already cropped and aligned
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+ self.face_helper.cropped_faces = [img]
+ else:
+ self.face_helper.read_image(img)
+ # get face landmarks for each face
+ num_det_faces = self.face_helper.get_face_landmarks_5(
+ only_center_face=only_center_face, resize=640, eye_dist_threshold=5
+ )
+ print(f"\tdetect {num_det_faces} faces")
+ # align and warp each face
+ self.face_helper.align_warp_face()
+
+ # face restoration for each cropped face
+ for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
+ # prepare data
+ cropped_face_t = img2tensor(
+ cropped_face / 255.0, bgr2rgb=True, float32=True
+ )
+ normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
+ cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)
+
+ try:
+ with torch.no_grad():
+ output = self.net(
+ cropped_face_t, w=codeformer_fidelity, adain=True
+ )[0]
+ restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
+ del output
+ torch.cuda.empty_cache()
+ except Exception as error:
+ print(f"\tFailed inference for CodeFormer: {error}")
+ restored_face = tensor2img(
+ cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
+ )
+
+ restored_face = restored_face.astype("uint8")
+ self.face_helper.add_restored_face(restored_face)
+
+ # paste_back
+ if not has_aligned:
+ # upsample the background
+ if bg_upsampler is not None:
+ # Now only support RealESRGAN for upsampling background
+ bg_img = bg_upsampler.enhance(img, outscale=upscale)[0]
+ else:
+ bg_img = None
+ self.face_helper.get_inverse_affine(None)
+ # paste each restored face to the input image
+ if face_upsample and face_upsampler is not None:
+ restored_img = self.face_helper.paste_faces_to_input_image(
+ upsample_img=bg_img,
+ draw_box=draw_box,
+ face_upsampler=face_upsampler,
+ )
+ else:
+ restored_img = self.face_helper.paste_faces_to_input_image(
+ upsample_img=bg_img, draw_box=draw_box
+ )
+
+ # save restored img
+ out_path = Path(tempfile.mkdtemp()) / 'output.png'
+ imwrite(restored_img, str(out_path))
+
+ return out_path
+
+
+def imread(img_path):
+ img = cv2.imread(img_path)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ return img
+
+
+def set_realesrgan():
+ if not torch.cuda.is_available(): # CPU
+ import warnings
+
+ warnings.warn(
+ "The unoptimized RealESRGAN is slow on CPU. We do not use it. "
+ "If you really want to use it, please modify the corresponding codes.",
+ category=RuntimeWarning,
+ )
+ upsampler = None
+ else:
+ model = RRDBNet(
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_block=23,
+ num_grow_ch=32,
+ scale=2,
+ )
+ upsampler = RealESRGANer(
+ scale=2,
+ model_path="./weights/realesrgan/RealESRGAN_x2plus.pth",
+ model=model,
+ tile=400,
+ tile_pad=40,
+ pre_pad=0,
+ half=True,
+ )
+ return upsampler
diff --git a/repositories/codeformer/weights/CodeFormer/.gitkeep b/repositories/codeformer/weights/CodeFormer/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/codeformer/weights/README.md b/repositories/codeformer/weights/README.md
new file mode 100644
index 000000000..67ad334bd
--- /dev/null
+++ b/repositories/codeformer/weights/README.md
@@ -0,0 +1,3 @@
+# Weights
+
+Put the downloaded pre-trained models to this folder.
\ No newline at end of file
diff --git a/repositories/codeformer/weights/facelib/.gitkeep b/repositories/codeformer/weights/facelib/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/README.md b/repositories/ldm/README.md
new file mode 100644
index 000000000..2dfddca33
--- /dev/null
+++ b/repositories/ldm/README.md
@@ -0,0 +1,302 @@
+# Stable Diffusion Version 2
+![t2i](assets/stable-samples/txt2img/768/merged-0006.png)
+![t2i](assets/stable-samples/txt2img/768/merged-0002.png)
+![t2i](assets/stable-samples/txt2img/768/merged-0005.png)
+
+This repository contains [Stable Diffusion](https://github.com/CompVis/stable-diffusion) models trained from scratch and will be continuously updated with
+new checkpoints. The following list provides an overview of all currently available models. More coming soon.
+
+## News
+
+
+**March 24, 2023**
+
+*Stable UnCLIP 2.1*
+
+- New stable diffusion finetune (_Stable unCLIP 2.1_, [Hugging Face](https://huggingface.co/stabilityai/)) at 768x768 resolution, based on SD2.1-768. This model allows for image variations and mixing operations as described in [*Hierarchical Text-Conditional Image Generation with CLIP Latents*](https://arxiv.org/abs/2204.06125), and, thanks to its modularity, can be combined with other models such as [KARLO](https://github.com/kakaobrain/karlo). Comes in two variants: [*Stable unCLIP-L*](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/blob/main/sd21-unclip-l.ckpt) and [*Stable unCLIP-H*](https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/blob/main/sd21-unclip-h.ckpt), which are conditioned on CLIP ViT-L and ViT-H image embeddings, respectively. Instructions are available [here](doc/UNCLIP.MD).
+
+- A public demo of SD-unCLIP is already available at [clipdrop.co/stable-diffusion-reimagine](https://clipdrop.co/stable-diffusion-reimagine)
+
+
+**December 7, 2022**
+
+*Version 2.1*
+
+- New stable diffusion model (_Stable Diffusion 2.1-v_, [Hugging Face](https://huggingface.co/stabilityai/stable-diffusion-2-1)) at 768x768 resolution and (_Stable Diffusion 2.1-base_, [HuggingFace](https://huggingface.co/stabilityai/stable-diffusion-2-1-base)) at 512x512 resolution, both based on the same number of parameters and architecture as 2.0 and fine-tuned on 2.0, on a less restrictive NSFW filtering of the [LAION-5B](https://laion.ai/blog/laion-5b/) dataset.
+By default, the attention operation of the model is evaluated at full precision when `xformers` is not installed. To enable fp16 (which can cause numerical instabilities with the vanilla attention module on the v2.1 model), run your script with `ATTN_PRECISION=fp16 python <your_script.py>`.
+
+**November 24, 2022**
+
+*Version 2.0*
+
+- New stable diffusion model (_Stable Diffusion 2.0-v_) at 768x768 resolution. Same number of parameters in the U-Net as 1.5, but uses [OpenCLIP-ViT/H](https://github.com/mlfoundations/open_clip) as the text encoder and is trained from scratch. _SD 2.0-v_ is a so-called [v-prediction](https://arxiv.org/abs/2202.00512) model.
+- The above model is finetuned from _SD 2.0-base_, which was trained as a standard noise-prediction model on 512x512 images and is also made available.
+- Added a [x4 upscaling latent text-guided diffusion model](#image-upscaling-with-stable-diffusion).
+- New [depth-guided stable diffusion model](#depth-conditional-stable-diffusion), finetuned from _SD 2.0-base_. The model is conditioned on monocular depth estimates inferred via [MiDaS](https://github.com/isl-org/MiDaS) and can be used for structure-preserving img2img and shape-conditional synthesis.
+
+ ![d2i](assets/stable-samples/depth2img/depth2img01.png)
+- A [text-guided inpainting model](#image-inpainting-with-stable-diffusion), finetuned from SD _2.0-base_.
+
+We follow the [original repository](https://github.com/CompVis/stable-diffusion) and provide basic inference scripts to sample from the models.
+
+________________
+*The original Stable Diffusion model was created in a collaboration with [CompVis](https://arxiv.org/abs/2202.00512) and [RunwayML](https://runwayml.com/) and builds upon the work:*
+
+[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://ommer-lab.com/research/latent-diffusion-models/)
+[Robin Rombach](https://github.com/rromb)\*,
+[Andreas Blattmann](https://github.com/ablattmann)\*,
+[Dominik Lorenz](https://github.com/qp-qp)\,
+[Patrick Esser](https://github.com/pesser),
+[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
+_[CVPR '22 Oral](https://openaccess.thecvf.com/content/CVPR2022/html/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.html) |
+[GitHub](https://github.com/CompVis/latent-diffusion) | [arXiv](https://arxiv.org/abs/2112.10752) | [Project page](https://ommer-lab.com/research/latent-diffusion-models/)_
+
+and [many others](#shout-outs).
+
+Stable Diffusion is a latent text-to-image diffusion model.
+________________________________
+
+## Requirements
+
+You can update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
+
+```
+conda install pytorch==1.12.1 torchvision==0.13.1 -c pytorch
+pip install transformers==4.19.2 diffusers invisible-watermark
+pip install -e .
+```
+#### xformers efficient attention
+For more efficiency and speed on GPUs,
+we highly recommend installing the [xformers](https://github.com/facebookresearch/xformers)
+library.
+
+Tested on A100 with CUDA 11.4.
+Installation needs a somewhat recent version of nvcc and gcc/g++; obtain those, e.g., via
+```commandline
+export CUDA_HOME=/usr/local/cuda-11.4
+conda install -c nvidia/label/cuda-11.4.0 cuda-nvcc
+conda install -c conda-forge gcc
+conda install -c conda-forge gxx_linux-64==9.5.0
+```
+
+Then, run the following (compiling takes up to 30 min).
+
+```commandline
+cd ..
+git clone https://github.com/facebookresearch/xformers.git
+cd xformers
+git submodule update --init --recursive
+pip install -r requirements.txt
+pip install -e .
+cd ../stablediffusion
+```
+Upon successful installation, the code will automatically default to [memory efficient attention](https://github.com/facebookresearch/xformers)
+for the self- and cross-attention layers in the U-Net and autoencoder.
+
+## General Disclaimer
+Stable Diffusion models are general text-to-image diffusion models and therefore mirror biases and (mis-)conceptions that are present
+in their training data. Although efforts were made to reduce the inclusion of explicit pornographic material, **we do not recommend using the provided weights for services or products without additional safety mechanisms and considerations.
+The weights are research artifacts and should be treated as such.**
+Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](https://huggingface.co/stabilityai/stable-diffusion-2).
+The weights are available via [the StabilityAI organization at Hugging Face](https://huggingface.co/StabilityAI) under the [CreativeML Open RAIL++-M License](LICENSE-MODEL).
+
+
+
+## Stable Diffusion v2
+
+Stable Diffusion v2 refers to a specific configuration of the model
+architecture that uses a downsampling-factor 8 autoencoder with an 865M UNet
+and OpenCLIP ViT-H/14 text encoder for the diffusion model. The _SD 2-v_ model produces 768x768 px outputs.
+
+Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
+5.0, 6.0, 7.0, 8.0) and 50 DDIM sampling steps show the relative improvements of the checkpoints:
+
+![sd evaluation results](assets/model-variants.jpg)
+
+
+
+### Text-to-Image
+![txt2img-stable2](assets/stable-samples/txt2img/merged-0003.png)
+![txt2img-stable2](assets/stable-samples/txt2img/merged-0001.png)
+
+Stable Diffusion 2 is a latent diffusion model conditioned on the penultimate text embeddings of a CLIP ViT-H/14 text encoder.
+We provide a [reference script for sampling](#reference-sampling-script).
+#### Reference Sampling Script
+
+This script incorporates an [invisible watermarking](https://github.com/ShieldMnt/invisible-watermark) of the outputs, to help viewers [identify the images as machine-generated](scripts/tests/test_watermark.py).
+We provide the configs for the _SD2-v_ (768px) and _SD2-base_ (512px) models.
+
+First, download the weights for [_SD2.1-v_](https://huggingface.co/stabilityai/stable-diffusion-2-1) and [_SD2.1-base_](https://huggingface.co/stabilityai/stable-diffusion-2-1-base).
+
+To sample from the _SD2.1-v_ model, run the following:
+
+```
+python scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/v2-inference-v.yaml --H 768 --W 768
+```
+or try out the Web Demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/stabilityai/stable-diffusion).
+
+To sample from the base model, use
+```
+python scripts/txt2img.py --prompt "a professional photograph of an astronaut riding a horse" --ckpt <path/to/model.ckpt> --config <path/to/config.yaml>
+```
+
+By default, this uses the [DDIM sampler](https://arxiv.org/abs/2010.02502), and renders images of size 768x768 (which it was trained on) in 50 steps.
+Empirically, the v-models can be sampled with higher guidance scales.
+
+Note: The inference config for all model versions is designed to be used with EMA-only checkpoints.
+For this reason `use_ema=False` is set in the configuration; otherwise the code will try to switch from
+non-EMA to EMA weights.
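+
+For reference, this corresponds to the following setting in the provided inference configs (minimal excerpt):
+
+```
+model:
+  params:
+    use_ema: False
+```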
+
+#### Enable Intel® Extension for PyTorch* optimizations in Text-to-Image script
+
+If you're planning on running Text-to-Image on an Intel® CPU, try to sample an image with TorchScript and Intel® Extension for PyTorch* optimizations. Intel® Extension for PyTorch* extends PyTorch by enabling up-to-date feature optimizations for an extra performance boost on Intel® hardware. It can optimize the memory layout of operators to the channels-last memory format, which is generally beneficial for Intel CPUs, take advantage of the most advanced instruction set available on a machine, optimize operators, and more.
+
+**Prerequisites**
+
+Before running the script, make sure you have all the needed libraries installed (the optimization was verified on `Ubuntu 20.04`). Install [jemalloc](https://github.com/jemalloc/jemalloc), [numactl](https://linux.die.net/man/8/numactl), Intel® OpenMP and Intel® Extension for PyTorch*.
+
+```bash
+apt-get install numactl libjemalloc-dev
+pip install intel-openmp
+pip install intel_extension_for_pytorch -f https://software.intel.com/ipex-whl-stable
+```
+
+To sample from the _SD2.1-v_ model with TorchScript+IPEX optimizations, run the following. Remember to specify the desired number of instances you want to run the program on ([more](https://github.com/intel/intel-extension-for-pytorch/blob/master/intel_extension_for_pytorch/cpu/launch.py#L48)).
+
+```
+MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-v-fp32.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex
+```
+
+To sample from the base model with IPEX optimizations, use
+
+```
+MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-fp32.yaml --n_samples 1 --n_iter 4 --precision full --device cpu --torchscript --ipex
+```
+
+If you're using a CPU that supports `bfloat16`, consider sampling from the model with bfloat16 enabled for a performance boost, like so:
+
+```bash
+# SD2.1-v
+MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-v-bf16.yaml --H 768 --W 768 --precision full --device cpu --torchscript --ipex --bf16
+# SD2.1-base
+MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000 python -m intel_extension_for_pytorch.cpu.launch --ninstance <number of instances> --enable_jemalloc scripts/txt2img.py --prompt \"a corgi is playing guitar, oil on canvas\" --ckpt <path/to/model.ckpt> --config configs/stable-diffusion/intel/v2-inference-bf16.yaml --precision full --device cpu --torchscript --ipex --bf16
+```
+
+### Image Modification with Stable Diffusion
+
+![depth2img-stable2](assets/stable-samples/depth2img/merged-0000.png)
+#### Depth-Conditional Stable Diffusion
+
+To augment the well-established [img2img](https://github.com/CompVis/stable-diffusion#image-modification-with-stable-diffusion) functionality of Stable Diffusion, we provide a _shape-preserving_ stable diffusion model.
+
+
+Note that the original method for image modification introduces significant semantic changes w.r.t. the initial image.
+If that is not desired, download our [depth-conditional stable diffusion](https://huggingface.co/stabilityai/stable-diffusion-2-depth) model and the `dpt_hybrid` MiDaS [model weights](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), place the latter in a folder `midas_models` and sample via
+```
+python scripts/gradio/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml
+```
+
+or
+
+```
+streamlit run scripts/streamlit/depth2img.py configs/stable-diffusion/v2-midas-inference.yaml
+```
+
+This method can be used on the samples of the base model itself.
+For example, take [this sample](assets/stable-samples/depth2img/old_man.png) generated by an anonymous discord user.
+Using the [gradio](https://gradio.app) or [streamlit](https://streamlit.io/) script `depth2img.py`, the MiDaS model first infers a monocular depth estimate given this input,
+and the diffusion model is then conditioned on the (relative) depth output.
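+
+As a rough sketch of that preprocessing step (mirroring the `AddMiDaS` helper added under `ldm/data/util.py` in this repository; the `"dpt_hybrid"` model type and the dictionary keys are taken from that helper and the instructions above, everything else is illustrative):
+
+```
+import torch
+from ldm.modules.midas.api import load_midas_transform
+
+transform = load_midas_transform("dpt_hybrid")
+
+def add_midas_input(sample):
+    # sample['jpg'] is an HWC tensor in [-1, 1]
+    x = ((sample['jpg'] + 1.0) * 0.5).detach().cpu().numpy()  # to a [0, 1] numpy array
+    x = transform({"image": x})["image"]                      # MiDaS input transform
+    sample['midas_in'] = torch.from_numpy(x) * 2 - 1.0        # back to a [-1, 1] tensor
+    return sample
+```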
+
+
+*depth2image demo*
+
+This model is particularly useful for a photorealistic style; see the [examples](assets/stable-samples/depth2img).
+For a maximum strength of 1.0, the model removes all pixel-based information and only relies on the text prompt and the inferred monocular depth estimate.
+
+![depth2img-stable3](assets/stable-samples/depth2img/merged-0005.png)
+
+#### Classic Img2Img
+
+For running the "classic" img2img, use
+```
+python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img <path/to/input/image> --strength 0.8 --ckpt <path/to/model.ckpt>
+```
+and adapt the checkpoint and config paths accordingly.
+
+### Image Upscaling with Stable Diffusion
+![upscaling-x4](assets/stable-samples/upscaling/merged-dog.png)
+After [downloading the weights](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler), run
+```
+python scripts/gradio/superresolution.py configs/stable-diffusion/x4-upscaling.yaml
+```
+
+or
+
+```
+streamlit run scripts/streamlit/superresolution.py -- configs/stable-diffusion/x4-upscaling.yaml
+```
+
+for a Gradio or Streamlit demo of the text-guided x4 superresolution model.
+This model can be used both on real inputs and on synthesized examples. For the latter, we recommend setting a higher
+`noise_level`, e.g. `noise_level=100`.
+
+### Image Inpainting with Stable Diffusion
+
+![inpainting-stable2](assets/stable-inpainting/merged-leopards.png)
+
+[Download the SD 2.0-inpainting checkpoint](https://huggingface.co/stabilityai/stable-diffusion-2-inpainting) and run
+
+```
+python scripts/gradio/inpainting.py configs/stable-diffusion/v2-inpainting-inference.yaml
+```
+
+or
+
+```
+streamlit run scripts/streamlit/inpainting.py -- configs/stable-diffusion/v2-inpainting-inference.yaml
+```
+
+for a Gradio or Streamlit demo of the inpainting model.
+This script adds invisible watermarking to the demo in the [RunwayML](https://github.com/runwayml/stable-diffusion/blob/main/scripts/inpaint_st.py) repository, but both should work interchangeably with the checkpoints/configs.
+
+
+
+## Shout-Outs
+- Thanks to [Hugging Face](https://huggingface.co/) and in particular [Apolinário](https://github.com/apolinario) for support with our model releases!
+- Stable Diffusion would not be possible without [LAION](https://laion.ai/) and their efforts to create open, large-scale datasets.
+- The [DeepFloyd team](https://twitter.com/deepfloydai) at Stability AI, for creating the subset of [LAION-5B](https://laion.ai/blog/laion-5b/) dataset used to train the model.
+- Stable Diffusion 2.0 uses [OpenCLIP](https://laion.ai/blog/large-openclip/), trained by [Romain Beaumont](https://github.com/rom1504).
+- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
+and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
+Thanks for open-sourcing!
+- [CompVis](https://github.com/CompVis/stable-diffusion) initial stable diffusion release
+- [Patrick](https://github.com/pesser)'s [implementation](https://github.com/runwayml/stable-diffusion/blob/main/scripts/inpaint_st.py) of the streamlit demo for inpainting.
+- `img2img` is an application of [SDEdit](https://arxiv.org/abs/2108.01073) by [Chenlin Meng](https://cs.stanford.edu/~chenlin/) from the [Stanford AI Lab](https://cs.stanford.edu/~ermon/website/).
+- [Kat's implementation](https://github.com/CompVis/latent-diffusion/pull/51) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler, and [more](https://github.com/crowsonkb/k-diffusion).
+- [DPMSolver](https://arxiv.org/abs/2206.00927) [integration](https://github.com/CompVis/stable-diffusion/pull/440) by [Cheng Lu](https://github.com/LuChengTHU).
+- Facebook's [xformers](https://github.com/facebookresearch/xformers) for efficient attention computation.
+- [MiDaS](https://github.com/isl-org/MiDaS) for monocular depth estimation.
+
+
+## License
+
+The code in this repository is released under the MIT License.
+
+The weights are available via [the StabilityAI organization at Hugging Face](https://huggingface.co/StabilityAI), and released under the [CreativeML Open RAIL++-M License](LICENSE-MODEL) License.
+
+## BibTeX
+
+```
+@misc{rombach2021highresolution,
+ title={High-Resolution Image Synthesis with Latent Diffusion Models},
+ author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
+ year={2021},
+ eprint={2112.10752},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
+
+
diff --git a/repositories/ldm/data/__init__.py b/repositories/ldm/data/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/data/util.py b/repositories/ldm/data/util.py
new file mode 100644
index 000000000..5b60ceb23
--- /dev/null
+++ b/repositories/ldm/data/util.py
@@ -0,0 +1,24 @@
+import torch
+
+from ldm.modules.midas.api import load_midas_transform
+
+
+class AddMiDaS(object):
+ def __init__(self, model_type):
+ super().__init__()
+ self.transform = load_midas_transform(model_type)
+
+ def pt2np(self, x):
+ x = ((x + 1.0) * .5).detach().cpu().numpy()
+ return x
+
+ def np2pt(self, x):
+ x = torch.from_numpy(x) * 2 - 1.
+ return x
+
+ def __call__(self, sample):
+ # sample['jpg'] is tensor hwc in [-1, 1] at this point
+ x = self.pt2np(sample['jpg'])
+ x = self.transform({"image": x})["image"]
+ sample['midas_in'] = x
+ return sample
\ No newline at end of file
diff --git a/repositories/ldm/models/autoencoder.py b/repositories/ldm/models/autoencoder.py
new file mode 100644
index 000000000..d12254999
--- /dev/null
+++ b/repositories/ldm/models/autoencoder.py
@@ -0,0 +1,219 @@
+import torch
+import pytorch_lightning as pl
+import torch.nn.functional as F
+from contextlib import contextmanager
+
+from ldm.modules.diffusionmodules.model import Encoder, Decoder
+from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
+
+from ldm.util import instantiate_from_config
+from ldm.modules.ema import LitEma
+
+
+class AutoencoderKL(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ ema_decay=None,
+ learn_logvar=False
+ ):
+ super().__init__()
+ self.learn_logvar = learn_logvar
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ self.use_ema = ema_decay is not None
+ if self.use_ema:
+ self.ema_decay = ema_decay
+ assert 0. < ema_decay < 1.
+ self.model_ema = LitEma(self, decay=ema_decay)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.parameters())
+ self.model_ema.copy_to(self)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self)
+
+ def encode(self, x):
+ h = self.encoder(x)
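+        # the 1x1 quant_conv maps encoder features to the concatenated mean/logvar ("moments") of a diagonal Gaussian posterior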
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ log_dict = self._validation_step(batch, batch_idx)
+ with self.ema_scope():
+ log_dict_ema = self._validation_step(batch, batch_idx, postfix="_ema")
+ return log_dict
+
+ def _validation_step(self, batch, batch_idx, postfix=""):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val"+postfix)
+
+ self.log(f"val{postfix}/rec_loss", log_dict_ae[f"val{postfix}/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ ae_params_list = list(self.encoder.parameters()) + list(self.decoder.parameters()) + list(
+ self.quant_conv.parameters()) + list(self.post_quant_conv.parameters())
+ if self.learn_logvar:
+ print(f"{self.__class__.__name__}: Learning logvar")
+ ae_params_list.append(self.loss.logvar)
+ opt_ae = torch.optim.Adam(ae_params_list,
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ if log_ema or self.use_ema:
+ with self.ema_scope():
+ xrec_ema, posterior_ema = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec_ema.shape[1] > 3
+ xrec_ema = self.to_rgb(xrec_ema)
+ log["samples_ema"] = self.decode(torch.randn_like(posterior_ema.sample()))
+ log["reconstructions_ema"] = xrec_ema
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+ def __init__(self, *args, vq_interface=False, **kwargs):
+ self.vq_interface = vq_interface
+ super().__init__()
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
+
diff --git a/repositories/ldm/models/diffusion/__init__.py b/repositories/ldm/models/diffusion/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/models/diffusion/ddim.py b/repositories/ldm/models/diffusion/ddim.py
new file mode 100644
index 000000000..c6cfd5712
--- /dev/null
+++ b/repositories/ldm/models/diffusion/ddim.py
@@ -0,0 +1,337 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, extract_into_tensor
+
+
+class DDIMSampler(object):
+ def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+ self.device = device
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != self.device:
+ attr = attr.to(self.device)
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None, # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ ucg_schedule=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for DDIM sampling is {size}, eta {eta}')
+
+ samples, intermediates = self.ddim_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ ucg_schedule=ucg_schedule
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, dynamic_threshold=None,
+ ucg_schedule=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ if ucg_schedule is not None:
+ assert len(ucg_schedule) == len(time_range)
+ unconditional_guidance_scale = ucg_schedule[i]
+
+ outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0 = outs
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ if isinstance(c, dict):
+ assert isinstance(unconditional_conditioning, dict)
+ c_in = dict()
+ for k in c:
+ if isinstance(c[k], list):
+ c_in[k] = [torch.cat([
+ unconditional_conditioning[k][i],
+ c[k][i]]) for i in range(len(c[k]))]
+ else:
+ c_in[k] = torch.cat([
+ unconditional_conditioning[k],
+ c[k]])
+ elif isinstance(c, list):
+ c_in = list()
+ assert isinstance(unconditional_conditioning, list)
+ for i in range(len(c)):
+ c_in.append(torch.cat([unconditional_conditioning[i], c[i]]))
+ else:
+ c_in = torch.cat([unconditional_conditioning, c])
+ model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
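+            # classifier-free guidance: move the prediction away from the unconditional branch by the guidance scale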
+ model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ if dynamic_threshold is not None:
+ raise NotImplementedError()
+
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
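+        # DDIM update: x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * z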
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ @torch.no_grad()
+ def encode(self, x0, c, t_enc, use_original_steps=False, return_intermediates=None,
+ unconditional_guidance_scale=1.0, unconditional_conditioning=None, callback=None):
+ num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[0]
+
+ assert t_enc <= num_reference_steps
+ num_steps = t_enc
+
+ if use_original_steps:
+ alphas_next = self.alphas_cumprod[:num_steps]
+ alphas = self.alphas_cumprod_prev[:num_steps]
+ else:
+ alphas_next = self.ddim_alphas[:num_steps]
+ alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
+
+ x_next = x0
+ intermediates = []
+ inter_steps = []
+ for i in tqdm(range(num_steps), desc='Encoding Image'):
+ t = torch.full((x0.shape[0],), i, device=self.model.device, dtype=torch.long)
+ if unconditional_guidance_scale == 1.:
+ noise_pred = self.model.apply_model(x_next, t, c)
+ else:
+ assert unconditional_conditioning is not None
+ e_t_uncond, noise_pred = torch.chunk(
+ self.model.apply_model(torch.cat((x_next, x_next)), torch.cat((t, t)),
+ torch.cat((unconditional_conditioning, c))), 2)
+ noise_pred = e_t_uncond + unconditional_guidance_scale * (noise_pred - e_t_uncond)
+
+ xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
+ weighted_noise_pred = alphas_next[i].sqrt() * (
+ (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()) * noise_pred
+ x_next = xt_weighted + weighted_noise_pred
+ if return_intermediates and i % (
+ num_steps // return_intermediates) == 0 and i < num_steps - 1:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ elif return_intermediates and i >= num_steps - 2:
+ intermediates.append(x_next)
+ inter_steps.append(i)
+ if callback: callback(i)
+
+ out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
+ if return_intermediates:
+ out.update({'intermediates': intermediates})
+ return x_next, out
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ # t serves as an index to gather the correct alphas
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)
+
+ @torch.no_grad()
+ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None,
+ use_original_steps=False, callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
\ No newline at end of file
diff --git a/repositories/ldm/models/diffusion/ddpm.py b/repositories/ldm/models/diffusion/ddpm.py
new file mode 100644
index 000000000..3350c032f
--- /dev/null
+++ b/repositories/ldm/models/diffusion/ddpm.py
@@ -0,0 +1,1873 @@
+"""
+wild mixture of
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import torch
+import torch.nn as nn
+import numpy as np
+import pytorch_lightning as pl
+from torch.optim.lr_scheduler import LambdaLR
+from einops import rearrange, repeat
+from contextlib import contextmanager, nullcontext
+from functools import partial
+import itertools
+from tqdm import tqdm
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities.distributed import rank_zero_only
+from omegaconf import ListConfig
+
+from ldm.util import log_txt_as_img, exists, default, ismap, isimage, mean_flat, count_params, instantiate_from_config
+from ldm.modules.ema import LitEma
+from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
+from ldm.models.autoencoder import IdentityFirstStage, AutoencoderKL
+from ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
+from ldm.models.diffusion.ddim import DDIMSampler
+
+
+__conditioning_keys__ = {'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def uniform_on_device(r1, r2, shape, device):
+ return (r1 - r2) * torch.rand(*shape, device=device) + r2
+
+
+class DDPM(pl.LightningModule):
+ # classic DDPM with Gaussian diffusion, in image space
+ def __init__(self,
+ unet_config,
+ timesteps=1000,
+ beta_schedule="linear",
+ loss_type="l2",
+ ckpt_path=None,
+ ignore_keys=[],
+ load_only_unet=False,
+ monitor="val/loss",
+ use_ema=True,
+ first_stage_key="image",
+ image_size=256,
+ channels=3,
+ log_every_t=100,
+ clip_denoised=True,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3,
+ given_betas=None,
+ original_elbo_weight=0.,
+ v_posterior=0., # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
+ l_simple_weight=1.,
+ conditioning_key=None,
+ parameterization="eps", # all assuming fixed variance schedules
+ scheduler_config=None,
+ use_positional_encodings=False,
+ learn_logvar=False,
+ logvar_init=0.,
+ make_it_fit=False,
+ ucg_training=None,
+ reset_ema=False,
+ reset_num_ema_updates=False,
+ ):
+ super().__init__()
+ assert parameterization in ["eps", "x0", "v"], 'currently only supporting "eps" and "x0" and "v"'
+ self.parameterization = parameterization
+ print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.image_size = image_size # try conv?
+ self.channels = channels
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(unet_config, conditioning_key)
+ count_params(self.model, verbose=True)
+ self.use_ema = use_ema
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+
+ self.use_scheduler = scheduler_config is not None
+ if self.use_scheduler:
+ self.scheduler_config = scheduler_config
+
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ self.make_it_fit = make_it_fit
+ if reset_ema: assert exists(ckpt_path)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys, only_model=load_only_unet)
+ if reset_ema:
+ assert self.use_ema
+ print(f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ self.register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, timesteps=timesteps,
+ linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s)
+
+ self.loss_type = loss_type
+
+ self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init, size=(self.num_timesteps,))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ self.ucg_training = ucg_training or dict()
+ if self.ucg_training:
+ self.ucg_prng = np.random.RandomState()
+
+ def register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer('posterior_mean_coef1', to_torch(
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+ self.register_buffer('posterior_mean_coef2', to_torch(
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(self.betas ** 2 / (
+ 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context=None):
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ print(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ print(f"{context}: Restored training weights")
+
+ @torch.no_grad()
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ if self.make_it_fit:
+ n_params = len([name for name, _ in
+ itertools.chain(self.named_parameters(),
+ self.named_buffers())])
+ for name, param in tqdm(
+ itertools.chain(self.named_parameters(),
+ self.named_buffers()),
+ desc="Fitting old weights to new weights",
+ total=n_params
+ ):
+ if not name in sd:
+ continue
+ old_shape = sd[name].shape
+ new_shape = param.shape
+ assert len(old_shape) == len(new_shape)
+ if len(new_shape) > 2:
+ # we only modify first two axes
+ assert new_shape[2:] == old_shape[2:]
+ # assumes first axis corresponds to output dim
+ if not new_shape == old_shape:
+ new_param = param.clone()
+ old_param = sd[name]
+ if len(new_shape) == 1:
+ for i in range(new_param.shape[0]):
+ new_param[i] = old_param[i % old_shape[0]]
+ elif len(new_shape) >= 2:
+ for i in range(new_param.shape[0]):
+ for j in range(new_param.shape[1]):
+ new_param[i, j] = old_param[i % old_shape[0], j % old_shape[1]]
+
+ n_used_old = torch.ones(old_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_old[j % old_shape[1]] += 1
+ n_used_new = torch.zeros(new_shape[1])
+ for j in range(new_param.shape[1]):
+ n_used_new[j] = n_used_old[j % old_shape[1]]
+
+ n_used_new = n_used_new[None, :]
+ while len(n_used_new.shape) < len(new_shape):
+ n_used_new = n_used_new.unsqueeze(-1)
+ new_param /= n_used_new
+
+ sd[name] = new_param
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys:\n {missing}")
+ if len(unexpected) > 0:
+ print(f"\nUnexpected Keys:\n {unexpected}")
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def predict_start_from_z_and_v(self, x_t, t, v):
+ # self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ # self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
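+        # v-parameterization: v = sqrt(alphas_cumprod) * noise - sqrt(1 - alphas_cumprod) * x0,
+        # hence x0 = sqrt(alphas_cumprod) * x_t - sqrt(1 - alphas_cumprod) * v, which is computed below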
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ def predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, t, clip_denoised: bool):
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, return_intermediates=False):
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
+ img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, batch_size=16, return_intermediates=False):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop((batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self, x_start, t, noise=None):
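+        # closed-form forward process q(x_t | x_0): x_t = sqrt(alphas_cumprod_t) * x_0 + sqrt(1 - alphas_cumprod_t) * noise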
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def get_v(self, x, noise, t):
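+        # v-prediction target (used when parameterization == "v"): v = sqrt(alphas_cumprod) * noise - sqrt(1 - alphas_cumprod) * x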
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ )
+
+ def get_loss(self, pred, target, mean=True):
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
+ else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+ return loss
+
+ def p_losses(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x, *args, **kwargs):
+ # b, c, h, w, device, img_size, = *x.shape, x.device, self.image_size
+ # assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(self, batch):
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch, batch_idx):
+ for k in self.ucg_training:
+ p = self.ucg_training[k]["p"]
+ val = self.ucg_training[k]["val"]
+ if val is None:
+ val = ""
+ for i in range(len(batch[k])):
+ if self.ucg_prng.choice(2, p=[1 - p, p]):
+ batch[k][i] = val
+
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict, prog_bar=True,
+ logger=True, on_step=True, on_epoch=True)
+
+ self.log("global_step", self.global_step,
+ prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch, batch_idx):
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
+ self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+ self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
+
+ def on_train_batch_end(self, *args, **kwargs):
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples):
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=2, sample=True, return_keys=None, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N, return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """main class"""
+
+ def __init__(self,
+ first_stage_config,
+ cond_stage_config,
+ num_timesteps_cond=None,
+ cond_stage_key="image",
+ cond_stage_trainable=False,
+ concat_mode=True,
+ cond_stage_forward=None,
+ conditioning_key=None,
+ scale_factor=1.0,
+ scale_by_std=False,
+ force_null_conditioning=False,
+ *args, **kwargs):
+ self.force_null_conditioning = force_null_conditioning
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # for backwards compatibility after implementation of DiffusionWrapper
+ if conditioning_key is None:
+ conditioning_key = 'concat' if concat_mode else 'crossattn'
+ if cond_stage_config == '__is_unconditional__' and not self.force_null_conditioning:
+ conditioning_key = None
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ reset_ema = kwargs.pop("reset_ema", False)
+ reset_num_ema_updates = kwargs.pop("reset_num_ema_updates", False)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+ self.concat_mode = concat_mode
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ try:
+ self.num_downs = len(first_stage_config.params.ddconfig.ch_mult) - 1
+ except:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.cond_stage_forward = cond_stage_forward
+ self.clip_denoised = False
+ self.bbox_tokenizer = None
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+ self.restarted_from_ckpt = True
+ if reset_ema:
+ assert self.use_ema
+ print(
+ f"Resetting ema to pure model weights. This is useful when restoring from an ema-only checkpoint.")
+ self.model_ema = LitEma(self.model)
+ if reset_num_ema_updates:
+ print(" +++++++++++ WARNING: RESETTING NUM_EMA UPDATES TO ZERO +++++++++++ ")
+ assert self.use_ema
+ self.model_ema.reset_num_updates()
+
+ def make_cond_schedule(self, ):
+ self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long)
+ ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+ # only for very first batch
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ print("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ print(f"setting self.scale_factor to {self.scale_factor}")
+ print("### USING STD-RESCALING ###")
+
+ def register_schedule(self,
+ given_betas=None, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ super().register_schedule(given_betas, beta_schedule, timesteps, linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config):
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config):
+ if not self.cond_stage_trainable:
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__":
+ print(f"Training {self.__class__.__name__} as an unconditional model.")
+ self.cond_stage_model = None
+ # self.be_unconditional = True
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ assert config != '__is_first_stage__'
+ assert config != '__is_unconditional__'
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def _get_denoise_row_from_list(self, samples, desc='', force_no_decoder_quantization=False):
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device),
+ force_not_quantize=force_no_decoder_quantization))
+ n_imgs_per_row = len(denoise_row)
+ denoise_row = torch.stack(denoise_row) # n_log_step, n_row, C, H, W
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ def get_first_stage_encoding(self, encoder_posterior):
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample()
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
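+        # latents are multiplied by scale_factor (0.18215 in the SD 2.x configs) so they have roughly unit std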
+ return self.scale_factor * z
+
+ def get_learned_conditioning(self, c):
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def meshgrid(self, h, w):
+ y = torch.arange(0, h).view(h, 1, 1).repeat(1, w, 1)
+ x = torch.arange(0, w).view(1, w, 1).repeat(h, 1, 1)
+
+ arr = torch.cat([y, x], dim=-1)
+ return arr
+
+ def delta_border(self, h, w):
+ """
+ :param h: height
+ :param w: width
+ :return: normalized distance to image border,
+        with min distance = 0 at border and max dist = 0.5 at image center
+ """
+ lower_right_corner = torch.tensor([h - 1, w - 1]).view(1, 1, 2)
+ arr = self.meshgrid(h, w) / lower_right_corner
+ dist_left_up = torch.min(arr, dim=-1, keepdims=True)[0]
+ dist_right_down = torch.min(1 - arr, dim=-1, keepdims=True)[0]
+ edge_dist = torch.min(torch.cat([dist_left_up, dist_right_down], dim=-1), dim=-1)[0]
+ return edge_dist
+
+ def get_weighting(self, h, w, Ly, Lx, device):
+ weighting = self.delta_border(h, w)
+ weighting = torch.clip(weighting, self.split_input_params["clip_min_weight"],
+ self.split_input_params["clip_max_weight"], )
+ weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device)
+
+ if self.split_input_params["tie_braker"]:
+ L_weighting = self.delta_border(Ly, Lx)
+ L_weighting = torch.clip(L_weighting,
+ self.split_input_params["clip_min_tie_weight"],
+ self.split_input_params["clip_max_tie_weight"])
+
+ L_weighting = L_weighting.view(1, 1, Ly * Lx).to(device)
+ weighting = weighting * L_weighting
+ return weighting
+
+ def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
+ """
+ :param x: img of size (bs, c, h, w)
+ :return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
+ """
+ bs, nc, h, w = x.shape
+
+ # number of crops in image
+ Ly = (h - kernel_size[0]) // stride[0] + 1
+ Lx = (w - kernel_size[1]) // stride[1] + 1
+
+ if uf == 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold = torch.nn.Fold(output_size=x.shape[2:], **fold_params)
+
+ weighting = self.get_weighting(kernel_size[0], kernel_size[1], Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h, w) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0], kernel_size[1], Ly * Lx))
+
+ elif uf > 1 and df == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] * uf, kernel_size[0] * uf),
+ dilation=1, padding=0,
+ stride=(stride[0] * uf, stride[1] * uf))
+ fold = torch.nn.Fold(output_size=(x.shape[2] * uf, x.shape[3] * uf), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] * uf, kernel_size[1] * uf, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h * uf, w * uf) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] * uf, kernel_size[1] * uf, Ly * Lx))
+
+ elif df > 1 and uf == 1:
+ fold_params = dict(kernel_size=kernel_size, dilation=1, padding=0, stride=stride)
+ unfold = torch.nn.Unfold(**fold_params)
+
+ fold_params2 = dict(kernel_size=(kernel_size[0] // df, kernel_size[0] // df),
+ dilation=1, padding=0,
+ stride=(stride[0] // df, stride[1] // df))
+ fold = torch.nn.Fold(output_size=(x.shape[2] // df, x.shape[3] // df), **fold_params2)
+
+ weighting = self.get_weighting(kernel_size[0] // df, kernel_size[1] // df, Ly, Lx, x.device).to(x.dtype)
+ normalization = fold(weighting).view(1, 1, h // df, w // df) # normalizes the overlap
+ weighting = weighting.view((1, 1, kernel_size[0] // df, kernel_size[1] // df, Ly * Lx))
+
+ else:
+ raise NotImplementedError
+
+ return fold, unfold, normalization, weighting
+
+ @torch.no_grad()
+ def get_input(self, batch, k, return_first_stage_outputs=False, force_c_encode=False,
+ cond_key=None, return_original_cond=False, bs=None, return_x=False):
+ x = super().get_input(batch, k)
+ if bs is not None:
+ x = x[:bs]
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+
+ if self.model.conditioning_key is not None and not self.force_null_conditioning:
+ if cond_key is None:
+ cond_key = self.cond_stage_key
+ if cond_key != self.first_stage_key:
+ if cond_key in ['caption', 'coordinates_bbox', "txt"]:
+ xc = batch[cond_key]
+ elif cond_key in ['class_label', 'cls']:
+ xc = batch
+ else:
+ xc = super().get_input(batch, cond_key).to(self.device)
+ else:
+ xc = x
+ if not self.cond_stage_trainable or force_c_encode:
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ c = self.get_learned_conditioning(xc.to(self.device))
+ else:
+ c = xc
+ if bs is not None:
+ c = c[:bs]
+
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ ckey = __conditioning_keys__[self.model.conditioning_key]
+ c = {ckey: c, 'pos_x': pos_x, 'pos_y': pos_y}
+
+ else:
+ c = None
+ xc = None
+ if self.use_positional_encodings:
+ pos_x, pos_y = self.compute_latent_shifts(batch)
+ c = {'pos_x': pos_x, 'pos_y': pos_y}
+ out = [z, c]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([x, xrec])
+ if return_x:
+ out.extend([x])
+ if return_original_cond:
+ out.append(xc)
+ return out
+
+ @torch.no_grad()
+ def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
+ if predict_cids:
+ if z.dim() == 4:
+ z = torch.argmax(z.exp(), dim=1).long()
+ z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
+ z = rearrange(z, 'b h w c -> b c h w').contiguous()
+
+ z = 1. / self.scale_factor * z
+ return self.first_stage_model.decode(z)
+
+ @torch.no_grad()
+ def encode_first_stage(self, x):
+ return self.first_stage_model.encode(x)
+
+ def shared_step(self, batch, **kwargs):
+ x, c = self.get_input(batch, self.first_stage_key)
+ loss = self(x, c)
+ return loss
+
+ def forward(self, x, c, *args, **kwargs):
+ t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
+ if self.model.conditioning_key is not None:
+ assert c is not None
+ if self.cond_stage_trainable:
+ c = self.get_learned_conditioning(c)
+ if self.shorten_cond_schedule: # TODO: drop this option
+ tc = self.cond_ids[t].to(self.device)
+ c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
+ return self.p_losses(x, c, t, *args, **kwargs)
+
+ def apply_model(self, x_noisy, t, cond, return_ids=False):
+ if isinstance(cond, dict):
+ # hybrid case, cond is expected to be a dict
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ x_recon = self.model(x_noisy, t, **cond)
+
+ if isinstance(x_recon, tuple) and not return_ids:
+ return x_recon[0]
+ else:
+ return x_recon
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart) / \
+ extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = torch.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def p_losses(self, x_start, cond, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_output = self.apply_model(x_noisy, t, cond)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
+ loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
+
+ logvar_t = self.logvar[t].to(self.device)
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ # loss = loss_simple / torch.exp(self.logvar) + self.logvar
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+
+ loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=False, quantize_denoised=False,
+ return_x0=False, score_corrector=None, corrector_kwargs=None):
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, return_ids=return_codebook_ids)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c, **corrector_kwargs)
+
+ if return_codebook_ids:
+ model_out, logits = model_out
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ if quantize_denoised:
+ x_recon, _, [_, _, indices] = self.first_stage_model.quantize(x_recon)
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ if return_codebook_ids:
+ return model_mean, posterior_variance, posterior_log_variance, logits
+ elif return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, c, t, clip_denoised=False, repeat_noise=False,
+ return_codebook_ids=False, quantize_denoised=False, return_x0=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised,
+ return_codebook_ids=return_codebook_ids,
+ quantize_denoised=quantize_denoised,
+ return_x0=return_x0,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if return_codebook_ids:
+ raise DeprecationWarning("Support dropped.")
+ model_mean, _, model_log_variance, logits = outputs
+ elif return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+
+ if return_codebook_ids:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, logits.argmax(dim=1)
+ if return_x0:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def progressive_denoising(self, cond, shape, verbose=True, callback=None, quantize_denoised=False,
+ img_callback=None, mask=None, x0=None, temperature=1., noise_dropout=0.,
+ score_corrector=None, corrector_kwargs=None, batch_size=None, x_T=None, start_T=None,
+ log_every_t=None):
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ timesteps = self.num_timesteps
+ if batch_size is not None:
+ b = batch_size if batch_size is not None else shape[0]
+ shape = [batch_size] + list(shape)
+ else:
+ b = batch_size = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=self.device)
+ else:
+ img = x_T
+ intermediates = []
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Progressive Generation',
+ total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+ if type(temperature) == float:
+ temperature = [temperature] * timesteps
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=self.device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img, x0_partial = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised, return_x0=True,
+ temperature=temperature[i], noise_dropout=noise_dropout,
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs)
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(x0_partial)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_loop(self, cond, shape, return_intermediates=False,
+ x_T=None, verbose=True, callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, start_T=None,
+ log_every_t=None):
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+ iterator = tqdm(reversed(range(0, timesteps)), desc='Sampling t', total=timesteps) if verbose else reversed(
+ range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3] # spatial size has to match
+
+ for i in iterator:
+ ts = torch.full((b,), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))
+
+ img = self.p_sample(img, cond, ts,
+ clip_denoised=self.clip_denoised,
+ quantize_denoised=quantize_denoised)
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self, cond, batch_size=16, return_intermediates=False, x_T=None,
+ verbose=True, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, shape=None, **kwargs):
+ if shape is None:
+ shape = (batch_size, self.channels, self.image_size, self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {key: cond[key][:batch_size] if not isinstance(cond[key], list) else
+ list(map(lambda x: x[:batch_size], cond[key])) for key in cond}
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates, x_T=x_T,
+ verbose=verbose, timesteps=timesteps, quantize_denoised=quantize_denoised,
+ mask=mask, x0=x0)
+
+ @torch.no_grad()
+ def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.image_size, self.image_size)
+ samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size,
+ shape, cond, verbose=False, **kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond, batch_size=batch_size,
+ return_intermediates=True, **kwargs)
+
+ return samples, intermediates
+
+ @torch.no_grad()
+ def get_unconditional_conditioning(self, batch_size, null_label=None):
+ if null_label is not None:
+ xc = null_label
+ if isinstance(xc, ListConfig):
+ xc = list(xc)
+ if isinstance(xc, dict) or isinstance(xc, list):
+ c = self.get_learned_conditioning(xc)
+ else:
+ if hasattr(xc, "to"):
+ xc = xc.to(self.device)
+ c = self.get_learned_conditioning(xc)
+ else:
+ if self.cond_stage_key in ["class_label", "cls"]:
+ xc = self.cond_stage_model.get_unconditional_conditioning(batch_size, device=self.device)
+ return self.get_learned_conditioning(xc)
+ else:
+ raise NotImplementedError("todo")
+ if isinstance(c, list): # in case the encoder gives us a list
+ for i in range(len(c)):
+ c[i] = repeat(c[i], '1 ... -> b ...', b=batch_size).to(self.device)
+ else:
+ c = repeat(c, '1 ... -> b ...', b=batch_size).to(self.device)
+ return c
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=50, ddim_eta=0., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key,
+ return_first_stage_outputs=True,
+ force_c_encode=True,
+ return_original_cond=True,
+ bs=N)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', "cls"]:
+ try:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ except KeyError:
+ # probably no "human_label" in batch
+ pass
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if quantize_denoised and not isinstance(self.first_stage_model, AutoencoderKL) and not isinstance(
+ self.first_stage_model, IdentityFirstStage):
+ # also display when quantizing x0 while sampling
+ with ema_scope("Plotting Quantized Denoised"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ quantize_denoised=True)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True,
+ # quantize_denoised=True)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_x0_quantized"] = x_samples
+
+ if unconditional_guidance_scale > 1.0:
+ uc = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ if self.model.conditioning_key == "crossattn-adm":
+ uc = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if inpaint:
+ # make a simple center square
+ b, h, w = z.shape[0], z.shape[2], z.shape[3]
+ mask = torch.ones(N, h, w).to(self.device)
+ # zeros will be filled in
+ mask[:, h // 4:3 * h // 4, w // 4:3 * w // 4] = 0.
+ mask = mask[:, None, ...]
+ with ema_scope("Plotting Inpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_inpainting"] = x_samples
+ log["mask"] = mask
+
+ # outpaint
+ mask = 1. - mask
+ with ema_scope("Plotting Outpaint"):
+ samples, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim, eta=ddim_eta,
+ ddim_steps=ddim_steps, x0=z[:N], mask=mask)
+ x_samples = self.decode_first_stage(samples.to(self.device))
+ log["samples_outpainting"] = x_samples
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.cond_stage_trainable:
+ print(f"{self.__class__.__name__}: Also optimizing conditioner params!")
+ params = params + list(self.cond_stage_model.parameters())
+ if self.learn_logvar:
+ print('Diffusion model optimizing logvar')
+ params.append(self.logvar)
+ opt = torch.optim.AdamW(params, lr=lr)
+ if self.use_scheduler:
+ assert 'target' in self.scheduler_config
+ scheduler = instantiate_from_config(self.scheduler_config)
+
+ print("Setting up LambdaLR scheduler...")
+ scheduler = [
+ {
+ 'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule),
+ 'interval': 'step',
+ 'frequency': 1
+ }]
+ return [opt], scheduler
+ return opt
+
+ @torch.no_grad()
+ def to_rgb(self, x):
+ x = x.float()
+ if not hasattr(self, "colorize"):
+ self.colorize = torch.randn(3, x.shape[1], 1, 1).to(x)
+ x = nn.functional.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
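+# --- Editor's note: illustrative sketch, not part of the upstream module. ---
+# p_losses() above picks its regression target from self.parameterization:
+# "eps" regresses the injected noise, "x0" the clean latent, and "v" the
+# velocity target of Salimans & Ho (2022), v_t = sqrt(alpha_bar_t) * eps - sqrt(1 - alpha_bar_t) * x_0.
+# A minimal standalone version of the "v" target, assuming the usual DDPM
+# buffers indexed by an integer timestep tensor t, could look like this:
+def _sketch_v_target(x_start, noise, t, sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod):
+    # broadcast the per-timestep scalars over the (B, C, H, W) latents
+    a = sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
+    s = sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
+    return a * noise - s * x_start
+
+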
+class DiffusionWrapper(pl.LightningModule):
+ def __init__(self, diff_model_config, conditioning_key):
+ super().__init__()
+ self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+ assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']
+
+ def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t)
+ elif self.conditioning_key == 'crossattn':
+ if not self.sequential_cross_attn:
+ cc = torch.cat(c_crossattn, 1)
+ else:
+ cc = c_crossattn
+ if hasattr(self, "scripted_diffusion_model"):
+ # TorchScript changes names of the arguments
+ # with argument cc defined as context=cc scripted model will produce
+ # an error: RuntimeError: forward() is missing value for argument 'argument_3'.
+ out = self.scripted_diffusion_model(x, t, cc)
+ else:
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc)
+ elif self.conditioning_key == 'hybrid-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'crossattn-adm':
+ assert c_adm is not None
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ else:
+ raise NotImplementedError()
+
+ return out
+
+
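+# --- Editor's note: illustrative sketch, not part of the upstream module. ---
+# DiffusionWrapper.forward() above routes the conditioning lists according to
+# conditioning_key: 'concat' entries are stacked onto the latent along the
+# channel axis, 'crossattn' entries are concatenated along the token axis and
+# passed as `context`, and the 'hybrid-adm' / 'crossattn-adm' variants
+# additionally pass `c_adm` as the vector conditioning `y`. Assuming a wrapper
+# instance `w` and toy tensors, the matching call patterns are:
+def _sketch_wrapper_call(w, x, t, txt_tokens, lowres_latent, pooled_vec):
+    if w.conditioning_key == 'crossattn':
+        return w(x, t, c_crossattn=[txt_tokens])
+    if w.conditioning_key == 'hybrid':
+        return w(x, t, c_concat=[lowres_latent], c_crossattn=[txt_tokens])
+    if w.conditioning_key == 'crossattn-adm':
+        return w(x, t, c_crossattn=[txt_tokens], c_adm=pooled_vec)
+    raise NotImplementedError(w.conditioning_key)
+
+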
+class LatentUpscaleDiffusion(LatentDiffusion):
+ def __init__(self, *args, low_scale_config, low_scale_key="LR", noise_level_key=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ # assumes that neither the cond_stage nor the low_scale_model contain trainable params
+ assert not self.cond_stage_trainable
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+ self.noise_level_key = noise_level_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, log_mode=False):
+ if not log_mode:
+ z, c = super().get_input(batch, k, force_c_encode=True, bs=bs)
+ else:
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+ x_low = batch[self.low_scale_key][:bs]
+ x_low = rearrange(x_low, 'b h w c -> b c h w')
+ x_low = x_low.to(memory_format=torch.contiguous_format).float()
+ zx, noise_level = self.low_scale_model(x_low)
+ if self.noise_level_key is not None:
+ # get noise level from batch instead, e.g. when extracting a custom noise level for bsr
+ raise NotImplementedError('TODO')
+
+ all_conds = {"c_concat": [zx], "c_crossattn": [c], "c_adm": noise_level}
+ if log_mode:
+ # TODO: maybe disable if too expensive
+ x_low_rec = self.low_scale_model.decode(zx)
+ return z, all_conds, x, xrec, xc, x_low, x_low_rec, noise_level
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ plot_denoise_rows=False, plot_progressive_rows=True, plot_diffusion_rows=True,
+ unconditional_guidance_scale=1., unconditional_guidance_label=None, use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc, x_low, x_low_rec, noise_level = self.get_input(batch, self.first_stage_key, bs=N,
+ log_mode=True)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ log["x_lr"] = x_low
+ log[f"x_lr_rec_@noise_levels{'-'.join(map(lambda x: str(x), list(noise_level.cpu().numpy())))}"] = x_low_rec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_tmp = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ # TODO explore better "unconditional" choices for the other keys
+ # maybe guide away from empty text label and highest noise level and maximally degraded zx?
+ uc = dict()
+ for k in c:
+ if k == "c_crossattn":
+ assert isinstance(c[k], list) and len(c[k]) == 1
+ uc[k] = [uc_tmp]
+ elif k == "c_adm": # todo: only run with text-based guidance?
+ assert isinstance(c[k], torch.Tensor)
+ #uc[k] = torch.ones_like(c[k]) * self.low_scale_model.max_noise_level
+ uc[k] = c[k]
+ elif isinstance(c[k], list):
+ uc[k] = [c[k][i] for i in range(len(c[k]))]
+ else:
+ uc[k] = c[k]
+
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ if plot_progressive_rows:
+ with ema_scope("Plotting Progressives"):
+ img, progressives = self.progressive_denoising(c,
+ shape=(self.channels, self.image_size, self.image_size),
+ batch_size=N)
+ prog_row = self._get_denoise_row_from_list(progressives, desc="Progressive Generation")
+ log["progressive_row"] = prog_row
+
+ return log
+
+
+class LatentFinetuneDiffusion(LatentDiffusion):
+ """
+    Basis for different finetuning tasks, such as inpainting or depth2image
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys: tuple,
+ finetune_keys=("model.diffusion_model.input_blocks.0.0.weight",
+ "model_ema.diffusion_modelinput_blocks00weight"
+ ),
+ keep_finetune_dims=4,
+ # if model was trained without concat mode before and we would like to keep these channels
+ c_concat_log_start=None, # to log reconstruction of c_concat codes
+ c_concat_log_end=None,
+ *args, **kwargs
+ ):
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", list())
+ super().__init__(*args, **kwargs)
+ self.finetune_keys = finetune_keys
+ self.concat_keys = concat_keys
+ self.keep_dims = keep_finetune_dims
+ self.c_concat_log_start = c_concat_log_start
+ self.c_concat_log_end = c_concat_log_end
+ if exists(self.finetune_keys): assert exists(ckpt_path), 'can only finetune from a given checkpoint'
+ if exists(ckpt_path):
+ self.init_from_ckpt(ckpt_path, ignore_keys)
+
+ def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+
+ # make it explicit, finetune by including extra input channels
+ if exists(self.finetune_keys) and k in self.finetune_keys:
+ new_entry = None
+ for name, param in self.named_parameters():
+ if name in self.finetune_keys:
+ print(
+ f"modifying key '{name}' and keeping its original {self.keep_dims} (channels) dimensions only")
+ new_entry = torch.zeros_like(param) # zero init
+ assert exists(new_entry), 'did not find matching parameter to modify'
+ new_entry[:, :self.keep_dims, ...] = sd[k]
+ sd[k] = new_entry
+
+ missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+ if len(missing) > 0:
+ print(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ print(f"Unexpected Keys: {unexpected}")
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, sample=True, ddim_steps=200, ddim_eta=1., return_keys=None,
+ quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
+ plot_diffusion_rows=True, unconditional_guidance_scale=1., unconditional_guidance_label=None,
+ use_ema_scope=True,
+ **kwargs):
+ ema_scope = self.ema_scope if use_ema_scope else nullcontext
+ use_ddim = ddim_steps is not None
+
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True)
+ c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ if self.model.conditioning_key is not None:
+ if hasattr(self.cond_stage_model, "decode"):
+ xc = self.cond_stage_model.decode(c)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ["caption", "txt"]:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ elif self.cond_stage_key in ['class_label', 'cls']:
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"], size=x.shape[2] // 25)
+ log['conditioning'] = xc
+ elif isimage(xc):
+ log["conditioning"] = xc
+ if ismap(xc):
+ log["original_conditioning"] = self.to_rgb(xc)
+
+ if not (self.c_concat_log_start is None and self.c_concat_log_end is None):
+ log["c_concat_decoded"] = self.decode_first_stage(c_cat[:, self.c_concat_log_start:self.c_concat_log_end])
+
+ if plot_diffusion_rows:
+ # get diffusion row
+ diffusion_row = list()
+ z_start = z[:n_row]
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(z_start)
+ z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
+ diffusion_row.append(self.decode_first_stage(z_noisy))
+
+ diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
+ diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
+ diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
+ diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
+ log["diffusion_row"] = diffusion_grid
+
+ if sample:
+ # get denoise row
+ with ema_scope("Sampling"):
+ samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta)
+ # samples, z_denoise_row = self.sample(cond=c, batch_size=N, return_intermediates=True)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ if unconditional_guidance_scale > 1.0:
+ uc_cross = self.get_unconditional_conditioning(N, unconditional_guidance_label)
+ uc_cat = c_cat
+ uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
+ with ema_scope("Sampling with classifier-free guidance"):
+ samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
+ batch_size=N, ddim=use_ddim,
+ ddim_steps=ddim_steps, eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_full,
+ )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+
+ return log
+
+
+class LatentInpaintDiffusion(LatentFinetuneDiffusion):
+ """
+ can either run as pure inpainting model (only concat mode) or with mixed conditionings,
+ e.g. mask as concat and text via cross-attn.
+ To disable finetuning mode, set finetune_keys to None
+ """
+
+ def __init__(self,
+ concat_keys=("mask", "masked_image"),
+ masked_image_key="masked_image",
+ *args, **kwargs
+ ):
+ super().__init__(concat_keys, *args, **kwargs)
+ self.masked_image_key = masked_image_key
+ assert self.masked_image_key in concat_keys
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for inpainting'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = rearrange(batch[ck], 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ bchw = z.shape
+ if ck != self.masked_image_key:
+ cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
+ else:
+ cc = self.get_first_stage_encoding(self.encode_first_stage(cc))
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super(LatentInpaintDiffusion, self).log_images(*args, **kwargs)
+ log["masked_image"] = rearrange(args[0]["masked_image"],
+ 'b h w c -> b c h w').to(memory_format=torch.contiguous_format).float()
+ return log
+
+
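+# Editor's note: with the default 4-channel latent space, the LatentInpaintDiffusion
+# conditioning assembled in get_input() above concatenates the resized 1-channel
+# mask with the 4-channel encoding of the masked image, i.e. c_concat adds
+# 1 + 4 = 5 extra input channels. keep_finetune_dims=4 in LatentFinetuneDiffusion
+# then keeps the original 4 latent channels of input_blocks.0.0.weight from the
+# checkpoint and zero-initializes the newly added ones.
+
+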
+class LatentDepth2ImageDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on monocular depth estimation
+ """
+
+ def __init__(self, depth_stage_config, concat_keys=("midas_in",), *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.depth_model = instantiate_from_config(depth_stage_config)
+ self.depth_stage_key = concat_keys[0]
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for depth2img'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ c_cat = list()
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ cc = self.depth_model(cc)
+ cc = torch.nn.functional.interpolate(
+ cc,
+ size=z.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+
+ depth_min, depth_max = torch.amin(cc, dim=[1, 2, 3], keepdim=True), torch.amax(cc, dim=[1, 2, 3],
+ keepdim=True)
+ cc = 2. * (cc - depth_min) / (depth_max - depth_min + 0.001) - 1.
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ depth = self.depth_model(args[0][self.depth_stage_key])
+ depth_min, depth_max = torch.amin(depth, dim=[1, 2, 3], keepdim=True), \
+ torch.amax(depth, dim=[1, 2, 3], keepdim=True)
+ log["depth"] = 2. * (depth - depth_min) / (depth_max - depth_min) - 1.
+ return log
+
+
+class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
+ """
+ condition on low-res image (and optionally on some spatial noise augmentation)
+ """
+ def __init__(self, concat_keys=("lr",), reshuffle_patch_size=None,
+ low_scale_config=None, low_scale_key=None, *args, **kwargs):
+ super().__init__(concat_keys=concat_keys, *args, **kwargs)
+ self.reshuffle_patch_size = reshuffle_patch_size
+ self.low_scale_model = None
+ if low_scale_config is not None:
+ print("Initializing a low-scale model")
+ assert exists(low_scale_key)
+ self.instantiate_low_stage(low_scale_config)
+ self.low_scale_key = low_scale_key
+
+ def instantiate_low_stage(self, config):
+ model = instantiate_from_config(config)
+ self.low_scale_model = model.eval()
+ self.low_scale_model.train = disabled_train
+ for param in self.low_scale_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def get_input(self, batch, k, cond_key=None, bs=None, return_first_stage_outputs=False):
+ # note: restricted to non-trainable encoders currently
+ assert not self.cond_stage_trainable, 'trainable cond stages not yet supported for upscaling-ft'
+ z, c, x, xrec, xc = super().get_input(batch, self.first_stage_key, return_first_stage_outputs=True,
+ force_c_encode=True, return_original_cond=True, bs=bs)
+
+ assert exists(self.concat_keys)
+ assert len(self.concat_keys) == 1
+ # optionally make spatial noise_level here
+ c_cat = list()
+ noise_level = None
+ for ck in self.concat_keys:
+ cc = batch[ck]
+ cc = rearrange(cc, 'b h w c -> b c h w')
+ if exists(self.reshuffle_patch_size):
+ assert isinstance(self.reshuffle_patch_size, int)
+ cc = rearrange(cc, 'b c (p1 h) (p2 w) -> b (p1 p2 c) h w',
+ p1=self.reshuffle_patch_size, p2=self.reshuffle_patch_size)
+ if bs is not None:
+ cc = cc[:bs]
+ cc = cc.to(self.device)
+ if exists(self.low_scale_model) and ck == self.low_scale_key:
+ cc, noise_level = self.low_scale_model(cc)
+ c_cat.append(cc)
+ c_cat = torch.cat(c_cat, dim=1)
+ if exists(noise_level):
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c], "c_adm": noise_level}
+ else:
+ all_conds = {"c_concat": [c_cat], "c_crossattn": [c]}
+ if return_first_stage_outputs:
+ return z, all_conds, x, xrec, xc
+ return z, all_conds
+
+ @torch.no_grad()
+ def log_images(self, *args, **kwargs):
+ log = super().log_images(*args, **kwargs)
+ log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
+ return log
+
+
+class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
+ def __init__(self, embedder_config, embedding_key="jpg", embedding_dropout=0.5,
+ freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.embed_key = embedding_key
+ self.embedding_dropout = embedding_dropout
+ self._init_embedder(embedder_config, freeze_embedder)
+ self._init_noise_aug(noise_aug_config)
+
+ def _init_embedder(self, config, freeze=True):
+ embedder = instantiate_from_config(config)
+ if freeze:
+ self.embedder = embedder.eval()
+ self.embedder.train = disabled_train
+ for param in self.embedder.parameters():
+ param.requires_grad = False
+
+ def _init_noise_aug(self, config):
+ if config is not None:
+ # use the KARLO schedule for noise augmentation on CLIP image embeddings
+ noise_augmentor = instantiate_from_config(config)
+ assert isinstance(noise_augmentor, nn.Module)
+ noise_augmentor = noise_augmentor.eval()
+ noise_augmentor.train = disabled_train
+ self.noise_augmentor = noise_augmentor
+ else:
+ self.noise_augmentor = None
+
+ def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
+ outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
+ z, c = outputs[0], outputs[1]
+ img = batch[self.embed_key][:bs]
+ img = rearrange(img, 'b h w c -> b c h w')
+ c_adm = self.embedder(img)
+ if self.noise_augmentor is not None:
+ c_adm, noise_level_emb = self.noise_augmentor(c_adm)
+ # assume this gives embeddings of noise levels
+ c_adm = torch.cat((c_adm, noise_level_emb), 1)
+ if self.training:
+ c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
+ device=c_adm.device)[:, None]) * c_adm
+ all_conds = {"c_crossattn": [c], "c_adm": c_adm}
+ noutputs = [z, all_conds]
+ noutputs.extend(outputs[2:])
+ return noutputs
+
+ @torch.no_grad()
+ def log_images(self, batch, N=8, n_row=4, **kwargs):
+ log = dict()
+ z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
+ return_original_cond=True)
+ log["inputs"] = x
+ log["reconstruction"] = xrec
+ assert self.model.conditioning_key is not None
+ assert self.cond_stage_key in ["caption", "txt"]
+ xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
+ log["conditioning"] = xc
+ uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
+ unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
+
+ uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
+ ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
+ with ema_scope(f"Sampling"):
+ samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
+ ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc_, )
+ x_samples_cfg = self.decode_first_stage(samples_cfg)
+ log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
+ return log
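+
+
+# --- Editor's note: illustrative sketch, not part of the upstream module. ---
+# During training, ImageEmbeddingConditionedLatentDiffusion.get_input() above
+# drops the whole image embedding of a sample with probability
+# `embedding_dropout` (classifier-free-guidance style) by multiplying it with a
+# per-sample Bernoulli mask. A minimal reproduction of that masking step:
+def _sketch_embedding_dropout(c_adm, p_drop):
+    keep = torch.bernoulli((1. - p_drop) * torch.ones(c_adm.shape[0], device=c_adm.device))
+    return keep[:, None] * c_adm  # zeroes the full embedding of dropped samples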
diff --git a/repositories/ldm/models/diffusion/dpm_solver/__init__.py b/repositories/ldm/models/diffusion/dpm_solver/__init__.py
new file mode 100644
index 000000000..7427f38c0
--- /dev/null
+++ b/repositories/ldm/models/diffusion/dpm_solver/__init__.py
@@ -0,0 +1 @@
+from .sampler import DPMSolverSampler
\ No newline at end of file
diff --git a/repositories/ldm/models/diffusion/dpm_solver/dpm_solver.py b/repositories/ldm/models/diffusion/dpm_solver/dpm_solver.py
new file mode 100644
index 000000000..da8d41f9c
--- /dev/null
+++ b/repositories/ldm/models/diffusion/dpm_solver/dpm_solver.py
@@ -0,0 +1,1163 @@
+import torch
+import torch.nn.functional as F
+import math
+from tqdm import tqdm
+
+
+class NoiseScheduleVP:
+ def __init__(
+ self,
+ schedule='discrete',
+ betas=None,
+ alphas_cumprod=None,
+ continuous_beta_0=0.1,
+ continuous_beta_1=20.,
+ ):
+ """Create a wrapper class for the forward SDE (VP type).
+ ***
+    Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
+    We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
+ ***
+    The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
+ We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
+ Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
+ log_alpha_t = self.marginal_log_mean_coeff(t)
+ sigma_t = self.marginal_std(t)
+ lambda_t = self.marginal_lambda(t)
+ Moreover, as lambda(t) is an invertible function, we also support its inverse function:
+ t = self.inverse_lambda(lambda_t)
+ ===============================================================
+ We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
+ 1. For discrete-time DPMs:
+ For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
+ t_i = (i + 1) / N
+ e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
+ We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
+ Args:
+ betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
+ alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
+    Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
+    **Important**: Please pay special attention to the argument `alphas_cumprod`:
+    The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
+ q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
+ Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
+ alpha_{t_n} = \sqrt{\hat{alpha_n}},
+ and
+ log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
+ 2. For continuous-time DPMs:
+ We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
+ schedule are the default settings in DDPM and improved-DDPM:
+ Args:
+ beta_min: A `float` number. The smallest beta for the linear schedule.
+ beta_max: A `float` number. The largest beta for the linear schedule.
+ cosine_s: A `float` number. The hyperparameter in the cosine schedule.
+ cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
+ T: A `float` number. The ending time of the forward process.
+ ===============================================================
+ Args:
+ schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
+ 'linear' or 'cosine' for continuous-time DPMs.
+ Returns:
+ A wrapper object of the forward SDE (VP type).
+
+ ===============================================================
+ Example:
+ # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', betas=betas)
+ # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
+ >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
+ # For continuous-time DPMs (VPSDE), linear schedule:
+ >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
+ """
+
+ if schedule not in ['discrete', 'linear', 'cosine']:
+ raise ValueError(
+ "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
+ schedule))
+
+ self.schedule = schedule
+ if schedule == 'discrete':
+ if betas is not None:
+ log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
+ else:
+ assert alphas_cumprod is not None
+ log_alphas = 0.5 * torch.log(alphas_cumprod)
+ self.total_N = len(log_alphas)
+ self.T = 1.
+ self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
+ self.log_alpha_array = log_alphas.reshape((1, -1,))
+ else:
+ self.total_N = 1000
+ self.beta_0 = continuous_beta_0
+ self.beta_1 = continuous_beta_1
+ self.cosine_s = 0.008
+ self.cosine_beta_max = 999.
+ self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
+ self.schedule = schedule
+ if schedule == 'cosine':
+ # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
+            # Note that T = 0.9946 may not be the optimal setting. However, we find it works well.
+ self.T = 0.9946
+ else:
+ self.T = 1.
+
+ def marginal_log_mean_coeff(self, t):
+ """
+ Compute log(alpha_t) of a given continuous-time label t in [0, T].
+ """
+ if self.schedule == 'discrete':
+ return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device),
+ self.log_alpha_array.to(t.device)).reshape((-1))
+ elif self.schedule == 'linear':
+ return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
+ elif self.schedule == 'cosine':
+ log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
+ log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
+ return log_alpha_t
+
+ def marginal_alpha(self, t):
+ """
+ Compute alpha_t of a given continuous-time label t in [0, T].
+ """
+ return torch.exp(self.marginal_log_mean_coeff(t))
+
+ def marginal_std(self, t):
+ """
+ Compute sigma_t of a given continuous-time label t in [0, T].
+ """
+ return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
+
+ def marginal_lambda(self, t):
+ """
+ Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
+ """
+ log_mean_coeff = self.marginal_log_mean_coeff(t)
+ log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
+ return log_mean_coeff - log_std
+
+ def inverse_lambda(self, lamb):
+ """
+ Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
+ """
+ if self.schedule == 'linear':
+ tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ Delta = self.beta_0 ** 2 + tmp
+ return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
+ elif self.schedule == 'discrete':
+ log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
+ t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]),
+ torch.flip(self.t_array.to(lamb.device), [1]))
+ return t.reshape((-1,))
+ else:
+ log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
+ t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (
+ 1. + self.cosine_s) / math.pi - self.cosine_s
+ t = t_fn(log_alpha)
+ return t
+
+
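+# --- Editor's note: illustrative sketch, not part of the upstream module. ---
+# For a discrete schedule, marginal_alpha(t)**2 + marginal_std(t)**2 == 1 at
+# every t, and inverse_lambda() inverts marginal_lambda() on the time grid.
+# A quick self-check with a toy linear beta array:
+def _sketch_noise_schedule_check():
+    betas = torch.linspace(1e-4, 2e-2, 1000)
+    ns = NoiseScheduleVP('discrete', betas=betas)
+    t = ns.t_array.reshape(-1)[::100]  # a few grid points in (0, 1]
+    unit = ns.marginal_alpha(t) ** 2 + ns.marginal_std(t) ** 2
+    t_round_trip = ns.inverse_lambda(ns.marginal_lambda(t))
+    return torch.allclose(unit, torch.ones_like(unit)), torch.allclose(t_round_trip, t, atol=1e-4)
+
+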
+def model_wrapper(
+ model,
+ noise_schedule,
+ model_type="noise",
+ model_kwargs={},
+ guidance_type="uncond",
+ condition=None,
+ unconditional_condition=None,
+ guidance_scale=1.,
+ classifier_fn=None,
+ classifier_kwargs={},
+):
+ """Create a wrapper function for the noise prediction model.
+    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
+    first wrap the model function into a noise prediction model that accepts continuous time as input.
+ We support four types of the diffusion model by setting `model_type`:
+ 1. "noise": noise prediction model. (Trained by predicting noise).
+ 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
+ 3. "v": velocity prediction model. (Trained by predicting the velocity).
+ The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
+ [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
+ arXiv preprint arXiv:2202.00512 (2022).
+ [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
+ arXiv preprint arXiv:2210.02303 (2022).
+
+ 4. "score": marginal score function. (Trained by denoising score matching).
+        Note that the score function and the noise prediction model follow a simple relationship:
+ ```
+ noise(x_t, t) = -sigma_t * score(x_t, t)
+ ```
+ We support three types of guided sampling by DPMs by setting `guidance_type`:
+ 1. "uncond": unconditional sampling by DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, **model_kwargs) -> noise | x_start | v | score
+ ``
+ The input `classifier_fn` has the following format:
+ ``
+ classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
+ ``
+ [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
+ in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
+ 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
+ The input `model` has the following format:
+ ``
+ model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
+ ``
+ And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
+ [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
+ arXiv preprint arXiv:2207.12598 (2022).
+
+    The `t_input` is the time label of the model, which may be a discrete-time label (i.e. 0 to 999)
+    or a continuous-time label (i.e. epsilon to T).
+    We wrap the model function to accept only `x` and `t_continuous` as inputs, and to output the predicted noise:
+ ``
+ def model_fn(x, t_continuous) -> noise:
+ t_input = get_model_input_time(t_continuous)
+ return noise_pred(model, x, t_input, **model_kwargs)
+ ``
+    where `t_continuous` is the continuous-time label (i.e. epsilon to T). We use `model_fn` for DPM-Solver.
+ ===============================================================
+ Args:
+ model: A diffusion model with the corresponding format described above.
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ model_type: A `str`. The parameterization type of the diffusion model.
+ "noise" or "x_start" or "v" or "score".
+ model_kwargs: A `dict`. A dict for the other inputs of the model function.
+ guidance_type: A `str`. The type of the guidance for sampling.
+ "uncond" or "classifier" or "classifier-free".
+ condition: A pytorch tensor. The condition for the guided sampling.
+ Only used for "classifier" or "classifier-free" guidance type.
+ unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
+ Only used for "classifier-free" guidance type.
+ guidance_scale: A `float`. The scale for the guided sampling.
+ classifier_fn: A classifier function. Only used for the classifier guidance.
+ classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
+ Returns:
+ A noise prediction model that accepts the noised data and the continuous time as the inputs.
+ """
+
+ def get_model_input_time(t_continuous):
+ """
+ Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
+ For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
+ For continuous-time DPMs, we just use `t_continuous`.
+ """
+ if noise_schedule.schedule == 'discrete':
+ return (t_continuous - 1. / noise_schedule.total_N) * 1000.
+ else:
+ return t_continuous
+
+ def noise_pred_fn(x, t_continuous, cond=None):
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ t_input = get_model_input_time(t_continuous)
+ if cond is None:
+ output = model(x, t_input, **model_kwargs)
+ else:
+ output = model(x, t_input, cond, **model_kwargs)
+ if model_type == "noise":
+ return output
+ elif model_type == "x_start":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
+ elif model_type == "v":
+ alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
+ elif model_type == "score":
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ dims = x.dim()
+ return -expand_dims(sigma_t, dims) * output
+
+ def cond_grad_fn(x, t_input):
+ """
+ Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
+ """
+ with torch.enable_grad():
+ x_in = x.detach().requires_grad_(True)
+ log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
+ return torch.autograd.grad(log_prob.sum(), x_in)[0]
+
+ def model_fn(x, t_continuous):
+ """
+        The noise prediction model function that is used for DPM-Solver.
+ """
+ if t_continuous.reshape((-1,)).shape[0] == 1:
+ t_continuous = t_continuous.expand((x.shape[0]))
+ if guidance_type == "uncond":
+ return noise_pred_fn(x, t_continuous)
+ elif guidance_type == "classifier":
+ assert classifier_fn is not None
+ t_input = get_model_input_time(t_continuous)
+ cond_grad = cond_grad_fn(x, t_input)
+ sigma_t = noise_schedule.marginal_std(t_continuous)
+ noise = noise_pred_fn(x, t_continuous)
+ return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
+ elif guidance_type == "classifier-free":
+ if guidance_scale == 1. or unconditional_condition is None:
+ return noise_pred_fn(x, t_continuous, cond=condition)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t_continuous] * 2)
+ if isinstance(condition, dict):
+ assert isinstance(unconditional_condition, dict)
+ c_in = dict()
+ for k in condition:
+ if isinstance(condition[k], list):
+ c_in[k] = [torch.cat([unconditional_condition[k][i], condition[k][i]]) for i in range(len(condition[k]))]
+ else:
+ c_in[k] = torch.cat([unconditional_condition[k], condition[k]])
+ else:
+ c_in = torch.cat([unconditional_condition, condition])
+ noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
+ return noise_uncond + guidance_scale * (noise - noise_uncond)
+
+ assert model_type in ["noise", "x_start", "v"]
+ assert guidance_type in ["uncond", "classifier", "classifier-free"]
+ return model_fn
+
+
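+# --- Editor's note: illustrative usage sketch, not part of the upstream module.
+# Wrapping a toy unconditional epsilon-model and constructing the solver in
+# data-prediction mode (DPM-Solver++, see the class docstring below); the
+# actual sampling entry point is defined further down in DPM_Solver and is
+# omitted here. `betas` is assumed to be a 1-D tensor of discrete-time betas.
+def _sketch_wrap_and_build_solver(betas):
+    ns = NoiseScheduleVP('discrete', betas=betas)
+    toy_model = lambda x, t_input: torch.randn_like(x)  # stands in for a real noise-prediction net
+    model_fn = model_wrapper(toy_model, ns, model_type="noise", guidance_type="uncond")
+    solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+    # e.g. 20 steps on a uniform time grid from T = 1 to t_0 = 1e-3:
+    timesteps = solver.get_time_steps('time_uniform', t_T=1., t_0=1e-3, N=20, device=betas.device)
+    return solver, timesteps
+
+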
+class DPM_Solver:
+ def __init__(self, model_fn, noise_schedule, predict_x0=False, thresholding=False, max_val=1.):
+ """Construct a DPM-Solver.
+ We support both the noise prediction model ("predicting epsilon") and the data prediction model ("predicting x0").
+ If `predict_x0` is False, we use the solver for the noise prediction model (DPM-Solver).
+ If `predict_x0` is True, we use the solver for the data prediction model (DPM-Solver++).
+ In such case, we further support the "dynamic thresholding" in [1] when `thresholding` is True.
+ The "dynamic thresholding" can greatly improve the sample quality for pixel-space DPMs with large guidance scales.
+ Args:
+ model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
+ ``
+ def model_fn(x, t_continuous):
+ return noise
+ ``
+ noise_schedule: A noise schedule object, such as NoiseScheduleVP.
+ predict_x0: A `bool`. If true, use the data prediction model; else, use the noise prediction model.
+ thresholding: A `bool`. Valid when `predict_x0` is True. Whether to use the "dynamic thresholding" in [1].
+ max_val: A `float`. Valid when both `predict_x0` and `thresholding` are True. The max value for thresholding.
+
+ [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
+ """
+ self.model = model_fn
+ self.noise_schedule = noise_schedule
+ self.predict_x0 = predict_x0
+ self.thresholding = thresholding
+ self.max_val = max_val
+
+ def noise_prediction_fn(self, x, t):
+ """
+ Return the noise prediction model.
+ """
+ return self.model(x, t)
+
+ def data_prediction_fn(self, x, t):
+ """
+ Return the data prediction model (with thresholding).
+ """
+ noise = self.noise_prediction_fn(x, t)
+ dims = x.dim()
+ alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
+ x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
+ if self.thresholding:
+            p = 0.995  # A hyperparameter from the Imagen paper [1].
+ s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
+ s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
+ x0 = torch.clamp(x0, -s, s) / s
+ return x0
+
+ def model_fn(self, x, t):
+ """
+ Convert the model to the noise prediction model or the data prediction model.
+ """
+ if self.predict_x0:
+ return self.data_prediction_fn(x, t)
+ else:
+ return self.noise_prediction_fn(x, t)
+
+ def get_time_steps(self, skip_type, t_T, t_0, N, device):
+ """Compute the intermediate time steps for sampling.
+ Args:
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+            N: An `int`. The total number of time-step intervals (the returned tensor has N + 1 entries).
+ device: A torch device.
+ Returns:
+ A pytorch tensor of the time steps, with the shape (N + 1,).
+ """
+ if skip_type == 'logSNR':
+ lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
+ lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
+ logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
+ return self.noise_schedule.inverse_lambda(logSNR_steps)
+ elif skip_type == 'time_uniform':
+ return torch.linspace(t_T, t_0, N + 1).to(device)
+ elif skip_type == 'time_quadratic':
+ t_order = 2
+ t = torch.linspace(t_T ** (1. / t_order), t_0 ** (1. / t_order), N + 1).pow(t_order).to(device)
+ return t
+ else:
+ raise ValueError(
+ "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
+
+ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
+ """
+ Get the order of each step for sampling by the singlestep DPM-Solver.
+        We combine DPM-Solver-1, 2 and 3 to use all the function evaluations; the combined scheme is named "DPM-Solver-fast".
+ Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
+ - If order == 1:
+ We take `steps` steps of DPM-Solver-1 (i.e. DDIM).
+ - If order == 2:
+ - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of DPM-Solver-2.
+ - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If order == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
+ ============================================
+ Args:
+ order: An `int`. The max order for the solver (2 or 3).
+ steps: An `int`. The total number of function evaluations (NFE).
+ skip_type: A `str`. The type for the spacing of the time steps. We support three types:
+ - 'logSNR': uniform logSNR for the time steps.
+ - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ device: A torch device.
+ Returns:
+ timesteps_outer: A pytorch tensor of the outer time steps for the singlestep solvers.
+ orders: A list of the solver order for each step.
+ """
+ if order == 3:
+ K = steps // 3 + 1
+ if steps % 3 == 0:
+ orders = [3, ] * (K - 2) + [2, 1]
+ elif steps % 3 == 1:
+ orders = [3, ] * (K - 1) + [1]
+ else:
+ orders = [3, ] * (K - 1) + [2]
+ elif order == 2:
+ if steps % 2 == 0:
+ K = steps // 2
+ orders = [2, ] * K
+ else:
+ K = steps // 2 + 1
+ orders = [2, ] * (K - 1) + [1]
+ elif order == 1:
+ K = 1
+ orders = [1, ] * steps
+ else:
+ raise ValueError("'order' must be '1' or '2' or '3'.")
+ if skip_type == 'logSNR':
+ # To reproduce the results in DPM-Solver paper
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
+ else:
+ timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
+ torch.cumsum(torch.tensor([0, ] + orders), dim=0).to(device)]
+ return timesteps_outer, orders
+
+ def denoise_to_zero_fn(self, x, s):
+ """
+ Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infinity by first-order discretization.
+ """
+ return self.data_prediction_fn(x, s)
+
+ def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
+ """
+ DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
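+ # First-order update (DPM-Solver-1, equivalent to DDIM) over the half-logSNR step
+ # h = lambda_t - lambda_s. In the noise-prediction branch this is
+ # x_t = (alpha_t / alpha_s) * x - sigma_t * (e^h - 1) * model_s, and the
+ # data-prediction branch below is the same update rewritten in terms of x0.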
+ if self.predict_x0:
+ phi_1 = torch.expm1(-h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+ else:
+ phi_1 = torch.expm1(h)
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False,
+ solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-2 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the second-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 0.5
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
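+ # The intermediate time s1 sits a fraction r1 of the way from lambda_s to lambda_t
+ # in half-logSNR space; the default r1 = 0.5 places it at the midpoint.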
+ lambda_s1 = lambda_s + r1 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(
+ s1), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
+ alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_1 = torch.expm1(-h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(alpha_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r1) * expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * (
+ model_s1 - model_s)
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_1 = torch.expm1(h)
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (0.5 / r1) * expand_dims(sigma_t * phi_1, dims) * (model_s1 - model_s)
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r1) * expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * (model_s1 - model_s)
+ )
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1}
+ else:
+ return x_t
+
+ def singlestep_dpm_solver_third_update(self, x, s, t, r1=1. / 3., r2=2. / 3., model_s=None, model_s1=None,
+ return_intermediate=False, solver_type='dpm_solver'):
+ """
+ Singlestep solver DPM-Solver-3 from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ r1: A `float`. The hyperparameter of the third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ model_s: A pytorch tensor. The model function evaluated at time `s`.
+ If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
+ model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
+ If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ if r1 is None:
+ r1 = 1. / 3.
+ if r2 is None:
+ r2 = 2. / 3.
+ ns = self.noise_schedule
+ dims = x.dim()
+ lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
+ h = lambda_t - lambda_s
+ lambda_s1 = lambda_s + r1 * h
+ lambda_s2 = lambda_s + r2 * h
+ s1 = ns.inverse_lambda(lambda_s1)
+ s2 = ns.inverse_lambda(lambda_s2)
+ log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(
+ s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
+ sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(
+ s2), ns.marginal_std(t)
+ alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
+
+ if self.predict_x0:
+ phi_11 = torch.expm1(-r1 * h)
+ phi_12 = torch.expm1(-r2 * h)
+ phi_1 = torch.expm1(-h)
+ phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
+ phi_2 = phi_1 / h + 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(sigma_s1 / sigma_s, dims) * x
+ - expand_dims(alpha_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(sigma_s2 / sigma_s, dims) * x
+ - expand_dims(alpha_s2 * phi_12, dims) * model_s
+ + r2 / r1 * expand_dims(alpha_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + (1. / r2) * expand_dims(alpha_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(sigma_t / sigma_s, dims) * x
+ - expand_dims(alpha_t * phi_1, dims) * model_s
+ + expand_dims(alpha_t * phi_2, dims) * D1
+ - expand_dims(alpha_t * phi_3, dims) * D2
+ )
+ else:
+ phi_11 = torch.expm1(r1 * h)
+ phi_12 = torch.expm1(r2 * h)
+ phi_1 = torch.expm1(h)
+ phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
+ phi_2 = phi_1 / h - 1.
+ phi_3 = phi_2 / h - 0.5
+
+ if model_s is None:
+ model_s = self.model_fn(x, s)
+ if model_s1 is None:
+ x_s1 = (
+ expand_dims(torch.exp(log_alpha_s1 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s1 * phi_11, dims) * model_s
+ )
+ model_s1 = self.model_fn(x_s1, s1)
+ x_s2 = (
+ expand_dims(torch.exp(log_alpha_s2 - log_alpha_s), dims) * x
+ - expand_dims(sigma_s2 * phi_12, dims) * model_s
+ - r2 / r1 * expand_dims(sigma_s2 * phi_22, dims) * (model_s1 - model_s)
+ )
+ model_s2 = self.model_fn(x_s2, s2)
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - (1. / r2) * expand_dims(sigma_t * phi_2, dims) * (model_s2 - model_s)
+ )
+ elif solver_type == 'taylor':
+ D1_0 = (1. / r1) * (model_s1 - model_s)
+ D1_1 = (1. / r2) * (model_s2 - model_s)
+ D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
+ D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_s), dims) * x
+ - expand_dims(sigma_t * phi_1, dims) * model_s
+ - expand_dims(sigma_t * phi_2, dims) * D1
+ - expand_dims(sigma_t * phi_3, dims) * D2
+ )
+
+ if return_intermediate:
+ return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
+ else:
+ return x_t
+
+ def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpm_solver"):
+ """
+ Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if solver_type not in ['dpm_solver', 'taylor']:
+ raise ValueError("'solver_type' must be either 'dpm_solver' or 'taylor', got {}".format(solver_type))
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_1, model_prev_0 = model_prev_list
+ t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(
+ t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0 = h_0 / h
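+ # D1_0 is a backward-difference estimate of the change of the model output over one
+ # step of size h in lambda, built from the two cached model evaluations.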
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ if self.predict_x0:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1_0
+ )
+ else:
+ if solver_type == 'dpm_solver':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - 0.5 * expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * D1_0
+ )
+ elif solver_type == 'taylor':
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1_0
+ )
+ return x_t
+
+ def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpm_solver'):
+ """
+ Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ ns = self.noise_schedule
+ dims = x.dim()
+ model_prev_2, model_prev_1, model_prev_0 = model_prev_list
+ t_prev_2, t_prev_1, t_prev_0 = t_prev_list
+ lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(
+ t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
+ log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
+ sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
+ alpha_t = torch.exp(log_alpha_t)
+
+ h_1 = lambda_prev_1 - lambda_prev_2
+ h_0 = lambda_prev_0 - lambda_prev_1
+ h = lambda_t - lambda_prev_0
+ r0, r1 = h_0 / h, h_1 / h
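+ # D1 and D2 below are finite-difference estimates of the first and second
+ # derivatives of the model output with respect to lambda (scaled by h and h^2),
+ # built from the three cached model values.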
+ D1_0 = expand_dims(1. / r0, dims) * (model_prev_0 - model_prev_1)
+ D1_1 = expand_dims(1. / r1, dims) * (model_prev_1 - model_prev_2)
+ D1 = D1_0 + expand_dims(r0 / (r0 + r1), dims) * (D1_0 - D1_1)
+ D2 = expand_dims(1. / (r0 + r1), dims) * (D1_0 - D1_1)
+ if self.predict_x0:
+ x_t = (
+ expand_dims(sigma_t / sigma_prev_0, dims) * x
+ - expand_dims(alpha_t * (torch.exp(-h) - 1.), dims) * model_prev_0
+ + expand_dims(alpha_t * ((torch.exp(-h) - 1.) / h + 1.), dims) * D1
+ - expand_dims(alpha_t * ((torch.exp(-h) - 1. + h) / h ** 2 - 0.5), dims) * D2
+ )
+ else:
+ x_t = (
+ expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
+ - expand_dims(sigma_t * (torch.exp(h) - 1.), dims) * model_prev_0
+ - expand_dims(sigma_t * ((torch.exp(h) - 1.) / h - 1.), dims) * D1
+ - expand_dims(sigma_t * ((torch.exp(h) - 1. - h) / h ** 2 - 0.5), dims) * D2
+ )
+ return x_t
+
+ def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpm_solver', r1=None,
+ r2=None):
+ """
+ Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `s`.
+ s: A pytorch tensor. The starting time, with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ r1: A `float`. The hyperparameter of the second-order or third-order solver.
+ r2: A `float`. The hyperparameter of the third-order solver.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
+ elif order == 2:
+ return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1)
+ elif order == 3:
+ return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate,
+ solver_type=solver_type, r1=r1, r2=r2)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpm_solver'):
+ """
+ Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
+ model_prev_list: A list of pytorch tensors. The previously computed model values.
+ t_prev_list: A list of pytorch tensors. The previous times, each with the shape (x.shape[0],).
+ t: A pytorch tensor. The ending time, with the shape (x.shape[0],).
+ order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_t: A pytorch tensor. The approximated solution at time `t`.
+ """
+ if order == 1:
+ return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
+ elif order == 2:
+ return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ elif order == 3:
+ return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
+ else:
+ raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
+
+ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5,
+ solver_type='dpm_solver'):
+ """
+ The adaptive step size solver based on singlestep DPM-Solver.
+ Args:
+ x: A pytorch tensor. The initial value at time `t_T`.
+ order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
+ t_T: A `float`. The starting time of the sampling (default is T).
+ t_0: A `float`. The ending time of the sampling (default is epsilon).
+ h_init: A `float`. The initial step size (for logSNR).
+ atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
+ rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
+ theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
+ t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
+ current time and `t_0` is less than `t_err`. The default setting is 1e-5.
+ solver_type: either 'dpm_solver' or 'taylor'. The type for the high-order solvers.
+ The type slightly impacts the performance. We recommend using the 'dpm_solver' type.
+ Returns:
+ x_0: A pytorch tensor. The approximated solution at time `t_0`.
+ [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
+ """
+ ns = self.noise_schedule
+ s = t_T * torch.ones((x.shape[0],)).to(x)
+ lambda_s = ns.marginal_lambda(s)
+ lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
+ h = h_init * torch.ones_like(s).to(x)
+ x_prev = x
+ nfe = 0
+ if order == 2:
+ r1 = 0.5
+ lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ solver_type=solver_type,
+ **kwargs)
+ elif order == 3:
+ r1, r2 = 1. / 3., 2. / 3.
+ lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1,
+ return_intermediate=True,
+ solver_type=solver_type)
+ higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2,
+ solver_type=solver_type,
+ **kwargs)
+ else:
+ raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
+ while torch.abs((s - t_0)).mean() > t_err:
+ t = ns.inverse_lambda(lambda_s + h)
+ x_lower, lower_noise_kwargs = lower_update(x, s, t)
+ x_higher = higher_update(x, s, t, **lower_noise_kwargs)
+ delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
+ norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
+ E = norm_fn((x_higher - x_lower) / delta).max()
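+ # Accept the step only if the scaled local error estimate E is within tolerance;
+ # otherwise retry from the same s with the shrunken step size h computed below.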
+ if torch.all(E <= 1.):
+ x = x_higher
+ s = t
+ x_prev = x_lower
+ lambda_s = ns.marginal_lambda(s)
+ h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
+ nfe += order
+ print('adaptive solver nfe', nfe)
+ return x
+
+ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform',
+ method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
+ atol=0.0078, rtol=0.05,
+ ):
+ """
+ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
+ =====================================================
+ We support the following algorithms for both noise prediction model and data prediction model:
+ - 'singlestep':
+ Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
+ We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
+ The total number of function evaluations (NFE) == `steps`.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ - If `order` == 1:
+ - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
+ - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
+ - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If `order` == 3:
+ - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
+ - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
+ - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
+ - 'multistep':
+ Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
+ We initialize the first `order` values by lower order multistep solvers.
+ Given a fixed NFE == `steps`, the sampling procedure is:
+ Denote K = steps.
+ - If `order` == 1:
+ - We use K steps of DPM-Solver-1 (i.e. DDIM).
+ - If `order` == 2:
+ - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
+ - If `order` == 3:
+ - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
+ - 'singlestep_fixed':
+ Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
+ We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
+ - 'adaptive':
+ Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
+ We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
+ You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
+ (NFE) and the sample quality.
+ - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
+ - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
+ =====================================================
+ Some advice on choosing the algorithm:
+ - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
+ Use singlestep DPM-Solver ("DPM-Solver-fast" in the paper) with `order = 3`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=False)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
+ skip_type='time_uniform', method='singlestep')
+ - For **guided sampling with large guidance scale** by DPMs:
+ Use multistep DPM-Solver with `predict_x0 = True` and `order = 2`.
+ e.g.
+ >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, predict_x0=True)
+ >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
+ skip_type='time_uniform', method='multistep')
+ We support three types of `skip_type`:
+ - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
+ - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
+ - 'time_quadratic': quadratic time for the time steps.
+ =====================================================
+ Args:
+ x: A pytorch tensor. The initial value at time `t_start`
+ e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
+ steps: An `int`. The total number of function evaluations (NFE).
+ t_start: A `float`. The starting time of the sampling.
+ If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
+ t_end: A `float`. The ending time of the sampling.
+ If `t_end` is None, we use 1. / self.noise_schedule.total_N.
+ e.g. if total_N == 1000, we have `t_end` == 1e-3.
+ For discrete-time DPMs:
+ - We recommend `t_end` == 1. / self.noise_schedule.total_N.
+ For continuous-time DPMs:
+ - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
+ order: An `int`. The order of DPM-Solver.
+ skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
+ method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
+ denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
+ Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
+ This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and
+ score_sde (https://arxiv.org/abs/2011.13456). This trick can improve the FID
+ of diffusion models sampled by diffusion SDEs on low-resolution images
+ (such as CIFAR-10). However, we observed that it does not matter for
+ high-resolution images. As it needs an additional NFE, we do not recommend
+ it for high-resolution images.
+ lower_order_final: A `bool`. Whether to use lower-order solvers at the final steps.
+ Only valid for `method=multistep` and `steps < 15`. We empirically find that
+ this trick is key to stabilizing sampling by DPM-Solver with very few steps
+ (especially for steps <= 10), so we recommend setting it to `True`.
+ solver_type: A `str`. The Taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`.
+ atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
+ Returns:
+ x_end: A pytorch tensor. The approximated solution at time `t_end`.
+ """
+ t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
+ t_T = self.noise_schedule.T if t_start is None else t_start
+ device = x.device
+ if method == 'adaptive':
+ with torch.no_grad():
+ x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol,
+ solver_type=solver_type)
+ elif method == 'multistep':
+ assert steps >= order
+ timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+ assert timesteps.shape[0] - 1 == steps
+ with torch.no_grad():
+ vec_t = timesteps[0].expand((x.shape[0]))
+ model_prev_list = [self.model_fn(x, vec_t)]
+ t_prev_list = [vec_t]
+ # Init the first `order` values by lower order multistep DPM-Solver.
+ for init_order in tqdm(range(1, order), desc="DPM init order"):
+ vec_t = timesteps[init_order].expand(x.shape[0])
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, init_order,
+ solver_type=solver_type)
+ model_prev_list.append(self.model_fn(x, vec_t))
+ t_prev_list.append(vec_t)
+ # Compute the remaining values by `order`-th order multistep DPM-Solver.
+ for step in tqdm(range(order, steps + 1), desc="DPM multistep"):
+ vec_t = timesteps[step].expand(x.shape[0])
+ if lower_order_final and steps < 15:
+ step_order = min(order, steps + 1 - step)
+ else:
+ step_order = order
+ x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order,
+ solver_type=solver_type)
+ for i in range(order - 1):
+ t_prev_list[i] = t_prev_list[i + 1]
+ model_prev_list[i] = model_prev_list[i + 1]
+ t_prev_list[-1] = vec_t
+ # We do not need to evaluate the final model value.
+ if step < steps:
+ model_prev_list[-1] = self.model_fn(x, vec_t)
+ elif method in ['singlestep', 'singlestep_fixed']:
+ if method == 'singlestep':
+ timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order,
+ skip_type=skip_type,
+ t_T=t_T, t_0=t_0,
+ device=device)
+ elif method == 'singlestep_fixed':
+ K = steps // order
+ orders = [order, ] * K
+ timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
+ for i, order in enumerate(orders):
+ t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1]
+ timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(),
+ N=order, device=device)
+ lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
+ vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0])
+ h = lambda_inner[-1] - lambda_inner[0]
+ r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
+ r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
+ x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2)
+ if denoise_to_zero:
+ x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
+ return x
+
+
+#############################################################
+# other utility functions
+#############################################################
+
+def interpolate_fn(x, xp, yp):
+ """
+ A piecewise linear function y = f(x), using xp and yp as keypoints.
+ We implement f(x) in a differentiable way (i.e. applicable for autograd).
+ The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
+ Args:
+ x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
+ xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
+ yp: PyTorch tensor with shape [C, K].
+ Returns:
+ The function values f(x), with shape [N, C].
+ """
+ N, K = x.shape[0], xp.shape[1]
+ all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
+ sorted_all_x, x_indices = torch.sort(all_x, dim=2)
+ x_idx = torch.argmin(x_indices, dim=2)
+ cand_start_idx = x_idx - 1
+ start_idx = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(1, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
+ start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
+ end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
+ start_idx2 = torch.where(
+ torch.eq(x_idx, 0),
+ torch.tensor(0, device=x.device),
+ torch.where(
+ torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
+ ),
+ )
+ y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
+ start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
+ end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
+ cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
+ return cand
+
+
+def expand_dims(v, dims):
+ """
+ Expand the tensor `v` to `dims` dimensions.
+ Args:
+ `v`: a PyTorch tensor with shape [N].
+ `dims`: an `int`.
+ Returns:
+ a PyTorch tensor with shape [N, 1, 1, ..., 1] whose total number of dimensions is `dims`.
+ """
+ return v[(...,) + (None,) * (dims - 1)]
\ No newline at end of file
diff --git a/repositories/ldm/models/diffusion/dpm_solver/sampler.py b/repositories/ldm/models/diffusion/dpm_solver/sampler.py
new file mode 100644
index 000000000..e4d0d0a38
--- /dev/null
+++ b/repositories/ldm/models/diffusion/dpm_solver/sampler.py
@@ -0,0 +1,96 @@
+"""SAMPLING ONLY."""
+import torch
+
+from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
+
+MODEL_TYPES = {
+ "eps": "noise",
+ "v": "v"
+}
+
+
+class DPMSolverSampler(object):
+ def __init__(self, model, device=torch.device("cuda"), **kwargs):
+ super().__init__()
+ self.model = model
+ self.device = device
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+ self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != self.device:
+ attr = attr.to(self.device)
+ setattr(self, name, attr)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ ctmp = conditioning[list(conditioning.keys())[0]]
+ while isinstance(ctmp, list): ctmp = ctmp[0]
+ if isinstance(ctmp, torch.Tensor):
+ cbs = ctmp.shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ elif isinstance(conditioning, list):
+ for ctmp in conditioning:
+ if ctmp.shape[0] != batch_size:
+ print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+ else:
+ if isinstance(conditioning, torch.Tensor):
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+
+ print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
+
+ device = self.model.betas.device
+ if x_T is None:
+ img = torch.randn(size, device=device)
+ else:
+ img = x_T
+
+ ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+ model_fn = model_wrapper(
+ lambda x, t, c: self.model.apply_model(x, t, c),
+ ns,
+ model_type=MODEL_TYPES[self.model.parameterization],
+ guidance_type="classifier-free",
+ condition=conditioning,
+ unconditional_condition=unconditional_conditioning,
+ guidance_scale=unconditional_guidance_scale,
+ )
+
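+ # Second-order multistep DPM-Solver on the data (x0) prediction with uniform time
+ # steps and lower_order_final enabled; this is the configuration recommended in
+ # dpm_solver.py for guided sampling with a large guidance scale.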
+ dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
+ x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2,
+ lower_order_final=True)
+
+ return x.to(device), None
diff --git a/repositories/ldm/models/diffusion/plms.py b/repositories/ldm/models/diffusion/plms.py
new file mode 100644
index 000000000..9d31b3994
--- /dev/null
+++ b/repositories/ldm/models/diffusion/plms.py
@@ -0,0 +1,245 @@
+"""SAMPLING ONLY."""
+
+import torch
+import numpy as np
+from tqdm import tqdm
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
+from ldm.models.diffusion.sampling_util import norm_thresholding
+
+
+class PLMSSampler(object):
+ def __init__(self, model, schedule="linear", device=torch.device("cuda"), **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+ self.device = device
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != self.device:
+ attr = attr.to(self.device)
+ setattr(self, name, attr)
+
+ def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
+ if ddim_eta != 0:
+ raise ValueError('ddim_eta must be 0 for PLMS')
+ self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # ddim sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
+ 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+ dynamic_threshold=None,
+ **kwargs
+ ):
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ if cbs != batch_size:
+ print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+ self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
+ # sampling
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ print(f'Data shape for PLMS sampling is {size}')
+
+ samples, intermediates = self.plms_sampling(conditioning, size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask, x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ dynamic_threshold=dynamic_threshold,
+ )
+ return samples, intermediates
+
+ @torch.no_grad()
+ def plms_sampling(self, cond, shape,
+ x_T=None, ddim_use_original_steps=False,
+ callback=None, timesteps=None, quantize_denoised=False,
+ mask=None, x0=None, img_callback=None, log_every_t=100,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None,
+ dynamic_threshold=None):
+ device = self.model.betas.device
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {'x_inter': [img], 'pred_x0': [img]}
+ time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
+ print(f"Running PLMS Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
+ old_eps = []
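+ # Rolling buffer of recent noise predictions; the pseudo linear multistep updates
+ # below use up to the last three of them (Adams-Bashforth style).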
+
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b,), step, device=device, dtype=torch.long)
+ ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)
+
+ if mask is not None:
+ assert x0 is not None
+ img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised, temperature=temperature,
+ noise_dropout=noise_dropout, score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ old_eps=old_eps, t_next=ts_next,
+ dynamic_threshold=dynamic_threshold)
+ img, pred_x0, e_t = outs
+ old_eps.append(e_t)
+ if len(old_eps) >= 4:
+ old_eps.pop(0)
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+
+ return img, intermediates
+
+ @torch.no_grad()
+ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
+ temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
+ unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,
+ dynamic_threshold=None):
+ b, *_, device = *x.shape, x.device
+
+ def get_model_output(x, t):
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ e_t = self.model.apply_model(x, t, c)
+ else:
+ x_in = torch.cat([x] * 2)
+ t_in = torch.cat([t] * 2)
+ c_in = torch.cat([unconditional_conditioning, c])
+ e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
+ e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps"
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
+
+ return e_t
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ def get_x_prev_and_pred_x0(e_t, index):
+ # select parameters corresponding to the currently considered timestep
+ a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
+ a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
+ sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
+
+ # current prediction for x_0
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+ if dynamic_threshold is not None:
+ pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
+ # direction pointing to x_t
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+ noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+ return x_prev, pred_x0
+
+ e_t = get_model_output(x, t)
+ if len(old_eps) == 0:
+ # Pseudo Improved Euler (2nd order)
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
+ e_t_next = get_model_output(x_prev, t_next)
+ e_t_prime = (e_t + e_t_next) / 2
+ elif len(old_eps) == 1:
+ # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (3 * e_t - old_eps[-1]) / 2
+ elif len(old_eps) == 2:
+ # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
+ elif len(old_eps) >= 3:
+ # 4th order Pseudo Linear Multistep (Adams-Bashforth)
+ e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
+
+ x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
+
+ return x_prev, pred_x0, e_t
diff --git a/repositories/ldm/models/diffusion/sampling_util.py b/repositories/ldm/models/diffusion/sampling_util.py
new file mode 100644
index 000000000..7eff02be6
--- /dev/null
+++ b/repositories/ldm/models/diffusion/sampling_util.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions.
+ From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ return x[(...,) + (None,) * dims_to_append]
+
+
+def norm_thresholding(x0, value):
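+ # Rescale each sample so its RMS norm is at most `value`; samples whose RMS norm is
+ # already below `value` are unchanged, because s is clamped to a minimum of `value`.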
+ s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
+ return x0 * (value / s)
+
+
+def spatial_norm_thresholding(x0, value):
+ # b c h w
+ s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
+ return x0 * (value / s)
\ No newline at end of file
diff --git a/repositories/ldm/modules/attention.py b/repositories/ldm/modules/attention.py
new file mode 100644
index 000000000..509cd8737
--- /dev/null
+++ b/repositories/ldm/modules/attention.py
@@ -0,0 +1,341 @@
+from inspect import isfunction
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange, repeat
+from typing import Optional, Any
+
+from ldm.modules.diffusionmodules.util import checkpoint
+
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+
+# CrossAttn precision handling
+import os
+_ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
+
+def exists(val):
+ return val is not None
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+# feedforward
+class GEGLU(nn.Module):
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(
+ nn.Linear(dim, inner_dim),
+ nn.GELU()
+ ) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(
+ project_in,
+ nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out)
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class SpatialSelfAttention(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(
+ nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout)
+ )
+
+ def forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ # force cast to fp32 to avoid overflowing
+ if _ATTN_PRECISION =="fp32":
+ with torch.autocast(enabled=False, device_type = 'cuda'):
+ q, k = q.float(), k.float()
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+ else:
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ del q, k
+
+ if exists(mask):
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', sim, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
+
+
+class MemoryEfficientCrossAttention(nn.Module):
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
+ super().__init__()
+ print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
+ f"{heads} heads.")
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x, context=None, mask=None):
+ q = self.to_q(x)
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
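+ # Split the heads out of the channel dimension and fold them into the batch
+ # dimension: (b, n, heads * dim_head) -> (b * heads, n, dim_head), the layout
+ # expected by xformers.ops.memory_efficient_attention.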
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
+ .contiguous(),
+ (q, k, v),
+ )
+
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ if exists(mask):
+ raise NotImplementedError
+ out = (
+ out.unsqueeze(0)
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
+ .permute(0, 2, 1, 3)
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
+ )
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+ ATTENTION_MODES = {
+ "softmax": CrossAttention, # vanilla attention
+ "softmax-xformers": MemoryEfficientCrossAttention
+ }
+ def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
+ disable_self_attn=False):
+ super().__init__()
+ attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
+ assert attn_mode in self.ATTENTION_MODES
+ attn_cls = self.ATTENTION_MODES[attn_mode]
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim,
+ heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None):
+ return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
+
+ def _forward(self, x, context=None):
+ x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
+ x = self.attn2(self.norm2(x), context=context) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data.
+ First, project the input (aka embedding)
+ and reshape to b, t, d.
+ Then apply standard transformer action.
+ Finally, reshape to image
+ NEW: use_linear for more efficiency instead of the 1x1 convs
+ """
+ def __init__(self, in_channels, n_heads, d_head,
+ depth=1, dropout=0., context_dim=None,
+ disable_self_attn=False, use_linear=False,
+ use_checkpoint=True):
+ super().__init__()
+ if exists(context_dim) and not isinstance(context_dim, list):
+ context_dim = [context_dim]
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = Normalize(in_channels)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ self.transformer_blocks = nn.ModuleList(
+ [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
+ disable_self_attn=disable_self_attn, checkpoint=use_checkpoint)
+ for d in range(depth)]
+ )
+ if not use_linear:
+ self.proj_out = zero_module(nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels)) # project back to in_channels, mirroring the conv branch above
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ # note: if no context is given, cross-attention defaults to self-attention
+ if not isinstance(context, list):
+ context = [context]
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context[i])
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
diff --git a/repositories/ldm/modules/diffusionmodules/__init__.py b/repositories/ldm/modules/diffusionmodules/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/diffusionmodules/model.py b/repositories/ldm/modules/diffusionmodules/model.py
new file mode 100644
index 000000000..b089eebbe
--- /dev/null
+++ b/repositories/ldm/modules/diffusionmodules/model.py
@@ -0,0 +1,852 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import rearrange
+from typing import Optional, Any
+
+from ldm.modules.attention import MemoryEfficientCrossAttention
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except:
+ XFORMERS_IS_AVAILBLE = False
+ print("No module 'xformers'. Proceeding without it.")
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
+
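+# Editor's note — a small usage sketch (illustrative values, not from the upstream code):
+#   t = torch.tensor([0, 10, 999])
+#   emb = get_timestep_embedding(t, 320)   # -> shape (3, 320)
+# The first 160 channels hold sin terms and the last 160 hold cos terms at
+# geometrically spaced frequencies; odd embedding_dim values are zero-padded.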
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
+
+class MemoryEfficientAttnBlock(nn.Module):
+ """
+ Uses xformers efficient implementation,
+ see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
+ Note: this is a single-head self-attention operation
+ """
+ #
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.attention_op: Optional[Any] = None
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ B, C, H, W = q.shape
+ q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v))
+
+ q, k, v = map(
+ lambda t: t.unsqueeze(3)
+ .reshape(B, t.shape[1], 1, C)
+ .permute(0, 2, 1, 3)
+ .reshape(B * 1, t.shape[1], C)
+ .contiguous(),
+ (q, k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
+
+ out = (
+ out.unsqueeze(0)
+ .reshape(B, 1, out.shape[1], C)
+ .permute(0, 2, 1, 3)
+ .reshape(B, out.shape[1], C)
+ )
+ out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
+ out = self.proj_out(out)
+ return x+out
+
+
+class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
+ def forward(self, x, context=None, mask=None):
+ b, c, h, w = x.shape
+ x_in = x # keep the original NCHW input for the residual connection
+ x = rearrange(x, 'b c h w -> b (h w) c')
+ out = super().forward(x, context=context, mask=mask)
+ out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
+ return x_in + out
+
+
+def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
+ assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown'
+ if XFORMERS_IS_AVAILBLE and attn_type == "vanilla":
+ attn_type = "vanilla-xformers"
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ assert attn_kwargs is None
+ return AttnBlock(in_channels)
+ elif attn_type == "vanilla-xformers":
+ print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+ return MemoryEfficientAttnBlock(in_channels)
+ elif attn_type == "memory-efficient-cross-attn":
+ attn_kwargs["query_dim"] = in_channels
+ return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ raise NotImplementedError()
+
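+# Editor's note (illustrative usage, not part of upstream): make_attn(512) returns a
+# vanilla AttnBlock(512), or a MemoryEfficientAttnBlock(512) when xformers is
+# importable, since "vanilla" is silently upgraded to "vanilla-xformers" above.
+# The "memory-efficient-cross-attn" branch expects attn_kwargs to be a dict
+# (e.g. {"context_dim": 768}, a hypothetical value) so query_dim can be injected.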
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ attn_type="vanilla", **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+ def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)])
+
+ self.conv_out = nn.Conv2d(mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+ def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
+ z_channels=intermediate_chn, double_z=False, resolution=resolution,
+ attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
+ mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+ def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
+ dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels*ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
+ resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult, resolution=resolution, ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
+ out_channels=tmp_chn, depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+ def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size//in_size))+1
+ factor_up = 1.+ (out_size % in_size)
+ print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
+ self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
+ attn_resolutions=[], in_channels=None, ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor==1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
+ return x
diff --git a/repositories/ldm/modules/diffusionmodules/openaimodel.py b/repositories/ldm/modules/diffusionmodules/openaimodel.py
new file mode 100644
index 000000000..cc3875c63
--- /dev/null
+++ b/repositories/ldm/modules/diffusionmodules/openaimodel.py
@@ -0,0 +1,807 @@
+from abc import abstractmethod
+import math
+
+import numpy as np
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ldm.modules.diffusionmodules.util import (
+ checkpoint,
+ conv_nd,
+ linear,
+ avg_pool_nd,
+ zero_module,
+ normalization,
+ timestep_embedding,
+)
+from ldm.modules.attention import SpatialTransformer
+from ldm.util import exists
+
+
+# dummy replace
+def convert_module_to_f16(x):
+ pass
+
+def convert_module_to_f32(x):
+ pass
+
+
+## go
+class AttentionPool2d(nn.Module):
+ """
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
+ """
+
+ def __init__(
+ self,
+ spacial_dim: int,
+ embed_dim: int,
+ num_heads_channels: int,
+ output_dim: int = None,
+ ):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
+ self.num_heads = embed_dim // num_heads_channels
+ self.attention = QKVAttention(self.num_heads)
+
+ def forward(self, x):
+ b, c, *_spatial = x.shape
+ x = x.reshape(b, c, -1) # NC(HW)
+ x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
+ x = self.qkv_proj(x)
+ x = self.attention(x)
+ x = self.c_proj(x)
+ return x[:, :, 0]
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+class TransposedUpsample(nn.Module):
+ 'Learned 2x upsampling without padding'
+ def __init__(self, channels, out_channels=None, ks=5):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+
+ self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
+
+ def forward(self,x):
+ return self.up(x)
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ return checkpoint(
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
+ )
+
+
+ def _forward(self, x, emb):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ use_new_attention_order=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ if use_new_attention_order:
+ # split qkv before split heads
+ self.attention = QKVAttention(self.num_heads)
+ else:
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x):
+ return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
+ #return pt_checkpoint(self._forward, x) # pytorch
+
+ def _forward(self, x):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += th.DoubleTensor([matmul_ops])
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class QKVAttention(nn.Module):
+ """
+ A module which performs QKV attention and splits in a different order.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv):
+ """
+ Apply QKV attention.
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.chunk(3, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum(
+ "bct,bcs->bts",
+ (q * scale).view(bs * self.n_heads, ch, length),
+ (k * scale).view(bs * self.n_heads, ch, length),
+ ) # More stable with f16 than dividing afterwards
+ weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
+ a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
+ return a.reshape(bs, -1, length)
+
+ @staticmethod
+ def count_flops(model, _x, y):
+ return count_flops_attn(model, _x, y)
+
+
+class Timestep(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, t):
+ return timestep_embedding(t, self.dim)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+ :param num_heads_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+ :param use_new_attention_order: use a different attention pattern for potentially
+ increased efficiency.
+ """
+
+ def __init__(
+ self,
+ image_size,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ num_classes=None,
+ use_checkpoint=False,
+ use_fp16=False,
+ use_bf16=False,
+ num_heads=-1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ resblock_updown=False,
+ use_new_attention_order=False,
+ use_spatial_transformer=False, # custom transformer support
+ transformer_depth=1, # custom transformer support
+ context_dim=None, # custom transformer support
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
+ legacy=True,
+ disable_self_attentions=None,
+ num_attention_blocks=None,
+ disable_middle_self_attn=False,
+ use_linear_in_transformer=False,
+ adm_in_channels=None,
+ ):
+ super().__init__()
+ if use_spatial_transformer:
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
+
+ if context_dim is not None:
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
+ from omegaconf.listconfig import ListConfig
+ if type(context_dim) == ListConfig:
+ context_dim = list(context_dim)
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.image_size = image_size
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ if isinstance(num_res_blocks, int):
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
+ else:
+ if len(num_res_blocks) != len(channel_mult):
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
+ "as a list/tuple (per-level) with the same length as channel_mult")
+ self.num_res_blocks = num_res_blocks
+ if disable_self_attentions is not None:
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
+ assert len(disable_self_attentions) == len(channel_mult)
+ if num_attention_blocks is not None:
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
+ print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
+ f"attention will still not be set.")
+
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.num_classes = num_classes
+ self.use_checkpoint = use_checkpoint
+ self.dtype = th.float16 if use_fp16 else th.float32
+ self.dtype = th.bfloat16 if use_bf16 else self.dtype
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.predict_codebook_ids = n_embed is not None
+
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.num_classes is not None:
+ if isinstance(self.num_classes, int):
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+ elif self.num_classes == "continuous":
+ print("setting up linear c_adm embedding layer")
+ self.label_emb = nn.Linear(1, time_embed_dim)
+ elif self.num_classes == "sequential":
+ assert adm_in_channels is not None
+ self.label_emb = nn.Sequential(
+ nn.Sequential(
+ linear(adm_in_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ )
+ else:
+ raise ValueError()
+
+ self.input_blocks = nn.ModuleList(
+ [
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
+ )
+ ]
+ )
+ self._feature_size = model_channels
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for nr in range(self.num_res_blocks[level]):
+ layers = [
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ self.middle_block = TimestepEmbedSequential(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ ),
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(self.num_res_blocks[level] + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=model_channels * mult,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ if legacy:
+ #num_heads = 1
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
+ if exists(disable_self_attentions):
+ disabled_sa = disable_self_attentions[level]
+ else:
+ disabled_sa = False
+
+ if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=dim_head,
+ use_new_attention_order=use_new_attention_order,
+ ) if not use_spatial_transformer else SpatialTransformer(
+ ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
+ use_checkpoint=use_checkpoint
+ )
+ )
+ if level and i == self.num_res_blocks[level]:
+ out_ch = ch
+ layers.append(
+ ResBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+ if self.predict_codebook_ids:
+ self.id_predictor = nn.Sequential(
+ normalization(ch),
+ conv_nd(dims, model_channels, n_embed, 1),
+ #nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
+ )
+
+ def convert_to_fp16(self):
+ """
+ Convert the torso of the model to float16.
+ """
+ self.input_blocks.apply(convert_module_to_f16)
+ self.middle_block.apply(convert_module_to_f16)
+ self.output_blocks.apply(convert_module_to_f16)
+
+ def convert_to_fp32(self):
+ """
+ Convert the torso of the model to float32.
+ """
+ self.input_blocks.apply(convert_module_to_f32)
+ self.middle_block.apply(convert_module_to_f32)
+ self.output_blocks.apply(convert_module_to_f32)
+
+ def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
+ """
+ Apply the model to an input batch.
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param context: conditioning plugged in via crossattn
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.num_classes is not None
+ ), "must specify y if and only if the model is class-conditional"
+ hs = []
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+ emb = self.time_embed(t_emb)
+
+ if self.num_classes is not None:
+ assert y.shape[0] == x.shape[0]
+ emb = emb + self.label_emb(y)
+
+ h = x.type(self.dtype)
+ for module in self.input_blocks:
+ h = module(h, emb, context)
+ hs.append(h)
+ h = self.middle_block(h, emb, context)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context)
+ h = h.type(x.dtype)
+ if self.predict_codebook_ids:
+ return self.id_predictor(h)
+ else:
+ return self.out(h)
diff --git a/repositories/ldm/modules/diffusionmodules/upscaling.py b/repositories/ldm/modules/diffusionmodules/upscaling.py
new file mode 100644
index 000000000..038166620
--- /dev/null
+++ b/repositories/ldm/modules/diffusionmodules/upscaling.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+
+from ldm.modules.diffusionmodules.util import extract_into_tensor, make_beta_schedule
+from ldm.util import default
+
+
+class AbstractLowScaleModel(nn.Module):
+ # for concatenating a downsampled image to the latent representation
+ def __init__(self, noise_schedule_config=None):
+ super(AbstractLowScaleModel, self).__init__()
+ if noise_schedule_config is not None:
+ self.register_schedule(**noise_schedule_config)
+
+ def register_schedule(self, beta_schedule="linear", timesteps=1000,
+ linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
+ cosine_s=cosine_s)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
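+ # Editor's note: q_sample above implements the closed-form forward diffusion
+ # q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I), i.e.
+ #   z = sqrt(alphas_cumprod[t]) * x_start + sqrt(1 - alphas_cumprod[t]) * noise
+ # with the per-sample coefficients gathered by extract_into_tensor.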
+ def forward(self, x):
+ return x, None
+
+ def decode(self, x):
+ return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+ # no noise level conditioning
+ def __init__(self):
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+ self.max_noise_level = 0
+
+ def forward(self, x):
+ # fix to constant noise level
+ return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+ super().__init__(noise_schedule_config=noise_schedule_config)
+ self.max_noise_level = max_noise_level
+
+ def forward(self, x, noise_level=None):
+ if noise_level is None:
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ else:
+ assert isinstance(noise_level, torch.Tensor)
+ z = self.q_sample(x, noise_level)
+ return z, noise_level
+
+
+
diff --git a/repositories/ldm/modules/diffusionmodules/util.py b/repositories/ldm/modules/diffusionmodules/util.py
new file mode 100644
index 000000000..daf35da7b
--- /dev/null
+++ b/repositories/ldm/modules/diffusionmodules/util.py
@@ -0,0 +1,278 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+from einops import repeat
+
+from ldm.util import instantiate_from_config
+
+
+def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+ if schedule == "linear":
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+ )
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "squaredcos_cap_v2": # used for karlo prior
+ # return early
+ return betas_for_alpha_bar(
+ n_timestep,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
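+# Editor's note — an illustrative call using the function's own defaults (not values
+# taken from any config in this repository):
+#   betas = make_beta_schedule("linear", 1000)
+# gives a numpy array of 1000 betas running from 1e-4 up to 2e-2, spaced linearly in
+# sqrt(beta) and then squared, as in the "linear" branch above.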
+
+def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+ else:
+ raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
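+# Illustrative: make_ddim_timesteps('uniform', num_ddim_timesteps=50, num_ddpm_timesteps=1000)
+# returns array([1, 21, 41, ..., 981]), i.e. every 20th DDPM step shifted by one.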
+
+
+def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
+ # select alphas for computing the variance schedule
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
+
+    # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+ print(f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
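+# Illustrative: the "squaredcos_cap_v2" branch of make_beta_schedule is equivalent to
+#   betas_for_alpha_bar(n_timestep, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2)
+# with each beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.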
+
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
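+# Illustrative: for a of shape [T], t of shape [B] and x_shape == (B, C, H, W), the result
+# has shape [B, 1, 1, 1], so the gathered per-timestep values broadcast over image batches.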
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
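+# Illustrative usage (hypothetical `block` module, not part of this file):
+#   out = checkpoint(block, (x, emb), block.parameters(), flag=True)
+# With flag=True the forward pass runs under no_grad and activations are recomputed during
+# backward via CheckpointFunction below, trading extra compute for lower memory.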
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+ ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(),
+ "dtype": torch.get_autocast_gpu_dtype(),
+ "cache_enabled": torch.is_autocast_cache_enabled()}
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad(), \
+ torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
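+# Illustrative: timestep_embedding(torch.tensor([0, 999]), dim=320) returns a [2, 320]
+# tensor whose first half holds cos terms and second half sin terms of geometrically
+# spaced frequencies; with repeat_only=True the raw timestep is simply tiled to `dim`.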
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
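+# Illustrative: with repeat=True a single noise sample of shape [1, ...] is drawn and tiled
+# across the batch dimension, so every batch element receives identical noise.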
diff --git a/repositories/ldm/modules/distributions/__init__.py b/repositories/ldm/modules/distributions/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/distributions/distributions.py b/repositories/ldm/modules/distributions/distributions.py
new file mode 100644
index 000000000..f2b8ef901
--- /dev/null
+++ b/repositories/ldm/modules/distributions/distributions.py
@@ -0,0 +1,92 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+
+ def sample(self):
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1,2,3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
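+# Illustrative: `parameters` is a [B, 2*C, H, W] tensor that is chunked along dim=1 into
+# mean and log-variance; kl() and nll() reduce over dims [1, 2, 3], i.e. all non-batch
+# dimensions of a 4-D latent.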
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + torch.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
+ )
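+# Illustrative: this evaluates the classical closed form
+#   KL(N(m1, v1) || N(m2, v2)) = 0.5 * (log(v2 / v1) + (v1 + (m1 - m2)^2) / v2 - 1)
+# elementwise, with v_i = exp(logvar_i) and shapes broadcast as noted above.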
diff --git a/repositories/ldm/modules/ema.py b/repositories/ldm/modules/ema.py
new file mode 100644
index 000000000..bded25019
--- /dev/null
+++ b/repositories/ldm/modules/ema.py
@@ -0,0 +1,80 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
+ else torch.tensor(-1, dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # remove as '.'-character is not allowed in buffers
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def reset_num_updates(self):
+ del self.num_updates
+ self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
+
+ def forward(self, model):
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
+ shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
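+# Illustrative usage (hypothetical training loop, not part of this module):
+#   ema = LitEma(model)
+#   ema(model)                                           # after each optimizer step
+#   ema.store(model.parameters()); ema.copy_to(model)    # evaluate with EMA weights
+#   ema.restore(model.parameters())                      # restore the training weights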
diff --git a/repositories/ldm/modules/encoders/__init__.py b/repositories/ldm/modules/encoders/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/encoders/modules.py b/repositories/ldm/modules/encoders/modules.py
new file mode 100644
index 000000000..523a7d853
--- /dev/null
+++ b/repositories/ldm/modules/encoders/modules.py
@@ -0,0 +1,350 @@
+import torch
+import torch.nn as nn
+import kornia
+from torch.utils.checkpoint import checkpoint
+
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+import open_clip
+from ldm.util import default, count_params, autocast
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.ucg_rate = ucg_rate
+
+ def forward(self, batch, key=None, disable_dropout=False):
+ if key is None:
+ key = self.key
+ # this is for use in crossattn
+ c = batch[key][:, None]
+ if self.ucg_rate > 0. and not disable_dropout:
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
+ c = c.long()
+ c = self.embedding(c)
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
+ uc = torch.ones((bs,), device=device) * uc_class
+ uc = {self.key: uc}
+ return uc
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
+ freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ LAYERS = [
+ "last",
+ "pooled",
+ "hidden"
+ ]
+
+ def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class ClipImageEmbedder(nn.Module):
+ def __init__(
+ self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=True,
+ ucg_rate=0.
+ ):
+ super().__init__()
+ from clip import load as load_clip
+ self.model, _ = load_clip(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+ self.ucg_rate = ucg_rate
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic', align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # re-normalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x, no_dropout=False):
+ # x is assumed to be in range [-1,1]
+ out = self.model.encode_image(self.preprocess(x))
+ out = out.to(x.dtype)
+ if self.ucg_rate > 0. and not no_dropout:
+ out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
+ return out
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate"
+ ]
+
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(text)
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
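+# Illustrative note: with layer="penultimate", layer_idx == 1 makes text_transformer_forward
+# stop one resblock early, so the hidden states passed to ln_final come from the
+# second-to-last transformer block rather than the last one.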
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+ freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
+ pretrained=version, )
+ del model.transformer
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "penultimate":
+ raise NotImplementedError()
+ self.layer_idx = 1
+
+ self.antialias = antialias
+
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+ self.ucg_rate = ucg_rate
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic', align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ if self.ucg_rate > 0. and not no_dropout:
+ z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ img = self.preprocess(img)
+ x = self.model.visual(img)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+ clip_max_length=77, t5_max_length=77):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
+ self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+from ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
+from ldm.modules.diffusionmodules.openaimodel import Timestep
+
+
+class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
+ def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
+ super().__init__(*args, **kwargs)
+ if clip_stats_path is None:
+ clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
+ else:
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
+ self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
+ self.register_buffer("data_std", clip_std[None, :], persistent=False)
+ self.time_embed = Timestep(timestep_dim)
+
+ def scale(self, x):
+ # re-normalize to centered mean and unit variance
+ x = (x - self.data_mean) * 1. / self.data_std
+ return x
+
+ def unscale(self, x):
+ # back to original data stats
+ x = (x * self.data_std) + self.data_mean
+ return x
+
+ def forward(self, x, noise_level=None):
+ if noise_level is None:
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ else:
+ assert isinstance(noise_level, torch.Tensor)
+ x = self.scale(x)
+ z = self.q_sample(x, noise_level)
+ z = self.unscale(z)
+ noise_level = self.time_embed(noise_level)
+ return z, noise_level
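+# Illustrative: given CLIP image embeddings x of shape [B, timestep_dim],
+# CLIPEmbeddingNoiseAugmentation normalizes them with the stored data statistics, applies
+# q_sample at the chosen noise level, undoes the normalization, and returns the noised
+# embeddings together with a sinusoidal Timestep embedding of that noise level.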
diff --git a/repositories/ldm/modules/image_degradation/__init__.py b/repositories/ldm/modules/image_degradation/__init__.py
new file mode 100644
index 000000000..7836cada8
--- /dev/null
+++ b/repositories/ldm/modules/image_degradation/__init__.py
@@ -0,0 +1,2 @@
+from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
+from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
diff --git a/repositories/ldm/modules/image_degradation/bsrgan.py b/repositories/ldm/modules/image_degradation/bsrgan.py
new file mode 100644
index 000000000..32ef56169
--- /dev/null
+++ b/repositories/ldm/modules/image_degradation/bsrgan.py
@@ -0,0 +1,730 @@
+# -*- coding: utf-8 -*-
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+    # Calculate the big kernel's size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
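+# Illustrative: anisotropic_Gaussian(ksize=15, theta=np.pi / 4, l1=6, l2=1) yields a 15x15
+# kernel elongated along the 45-degree direction; l1 == l2 gives an isotropic Gaussian.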
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharpening weight. Default: 0.5.
+        radius (int): Kernel size of the Gaussian blur (forced to be odd). Default: 50.
+        threshold (int): Threshold on |I - B| * 255 used to build the sharpening mask. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
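+# Illustrative: with the defaults this computes K = clip(I + 0.5 * (I - GaussianBlur(I)), 0, 1)
+# and blends K with the input only where |I - GaussianBlur(I)| * 255 exceeds `threshold`,
+# using a blurred (soft) version of that mask.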
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=2 * random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', 2 * random.randint(2, 11) + 3, wd * random.random())
+ img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(30, 95)
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
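+# Illustrative usage (not part of the upstream module):
+#   lq, hq = degradation_bsrgan(img_01, sf=4, lq_patchsize=72)
+# where lq is a random lq_patchsize x lq_patchsize patch and hq the matching
+# (lq_patchsize * sf) x (lq_patchsize * sf) high-quality patch.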
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    image: HxWxC uint8 image
+    sf: scale factor
+    isp_model: camera ISP model
+    Returns
+    -------
+    example: dict with key "image" holding the degraded low-quality image (uint8),
+             roughly 1/sf the resolution of the input
+ """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+ image = image.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+ if idx1 > idx2: # keep downsample3 last
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ image = add_blur(image, sf=sf)
+
+ elif i == 1:
+ image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.filters.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=2, noise_level2=25)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ example = {"image":image}
+ return example
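+# Illustrative usage (not part of the upstream module):
+#   lq = degradation_bsrgan_variant(hq_uint8, sf=4)["image"]   # uint8, roughly 1/4 resolution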
+
+
+# TODO: in case there is a pickle error, one needs to replace a += x with a = a + x in add_speckle_noise etc...
+def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
+ """
+ This is an extended degradation model by combining
+ the degradation models of BSRGAN and Real-ESRGAN
+ ----------
+    img: HXWXC, [0, 1], its size should be larger than (lq_patchsizexsf)x(lq_patchsizexsf)
+ sf: scale factor
+    shuffle_prob: probability of shuffling the degradation order
+ use_sharp: sharpening the img
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+
+ h1, w1 = img.shape[:2]
+ img = img.copy()[:w1 - w1 % sf, :h1 - h1 % sf, ...] # mod crop
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ if use_sharp:
+ img = add_sharpening(img)
+ hq = img.copy()
+
+ if random.random() < shuffle_prob:
+ shuffle_order = random.sample(range(13), 13)
+ else:
+ shuffle_order = list(range(13))
+ # local shuffle for noise, JPEG is always the last one
+ shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6)))
+ shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13)))
+
+ poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1
+
+ for i in shuffle_order:
+ if i == 0:
+ img = add_blur(img, sf=sf)
+ elif i == 1:
+ img = add_resize(img, sf=sf)
+ elif i == 2:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 3:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 4:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 5:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ elif i == 6:
+ img = add_JPEG_noise(img)
+ elif i == 7:
+ img = add_blur(img, sf=sf)
+ elif i == 8:
+ img = add_resize(img, sf=sf)
+ elif i == 9:
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
+ elif i == 10:
+ if random.random() < poisson_prob:
+ img = add_Poisson_noise(img)
+ elif i == 11:
+ if random.random() < speckle_prob:
+ img = add_speckle_noise(img)
+ elif i == 12:
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+ else:
+ print('check the shuffle!')
+
+ # resize to desired size
+ img = cv2.resize(img, (int(1 / sf * hq.shape[1]), int(1 / sf * hq.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf, lq_patchsize)
+
+ return img, hq
+
+
+if __name__ == '__main__':
+    print("hey")
+    img = util.imread_uint('utils/test.png', 3)
+    print(img)
+    img = img[:448, :448]
+    h = img.shape[0] // 4
+    print("resizing to", h)
+    sf = 4
+    deg_fn = partial(degradation_bsrgan_variant, sf=sf)
+    for i in range(20):
+        print(i)
+        img_hq = util.uint2single(img)
+        # degradation_bsrgan_variant expects a uint8 image and returns {"image": uint8 image}
+        img_lq = util.uint2single(deg_fn(img)["image"])
+        print(img_lq)
+        img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+        print(img_lq.shape)
+        print("bicubic", img_lq_bicubic.shape)
+        print(img_hq.shape)
+        lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                interpolation=0)
+        lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+                                        interpolation=0)
+        img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+        util.imsave(img_concat, str(i) + '.png')
+
+
diff --git a/repositories/ldm/modules/image_degradation/bsrgan_light.py b/repositories/ldm/modules/image_degradation/bsrgan_light.py
new file mode 100644
index 000000000..808c7f882
--- /dev/null
+++ b/repositories/ldm/modules/image_degradation/bsrgan_light.py
@@ -0,0 +1,651 @@
+# -*- coding: utf-8 -*-
+import numpy as np
+import cv2
+import torch
+
+from functools import partial
+import random
+from scipy import ndimage
+import scipy
+import scipy.stats as ss
+from scipy.interpolate import interp2d
+from scipy.linalg import orth
+import albumentations
+
+import ldm.modules.image_degradation.utils_image as util
+
+"""
+# --------------------------------------------
+# Super-Resolution
+# --------------------------------------------
+#
+# Kai Zhang (cskaizhang@gmail.com)
+# https://github.com/cszn
+# From 2019/03--2021/08
+# --------------------------------------------
+"""
+
+def modcrop_np(img, sf):
+ '''
+ Args:
+ img: numpy image, WxH or WxHxC
+ sf: scale factor
+ Return:
+ cropped image
+ '''
+ w, h = img.shape[:2]
+ im = np.copy(img)
+ return im[:w - w % sf, :h - h % sf, ...]
+
+
+"""
+# --------------------------------------------
+# anisotropic Gaussian kernels
+# --------------------------------------------
+"""
+
+
+def analytic_kernel(k):
+ """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)"""
+ k_size = k.shape[0]
+    # Calculate the big kernel's size
+ big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2))
+ # Loop over the small kernel to fill the big one
+ for r in range(k_size):
+ for c in range(k_size):
+ big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k
+    # Crop the edges of the big kernel to ignore very small values and reduce the run time of SR
+ crop = k_size // 2
+ cropped_big_k = big_k[crop:-crop, crop:-crop]
+ # Normalize to 1
+ return cropped_big_k / cropped_big_k.sum()
+
+
+def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6):
+ """ generate an anisotropic Gaussian kernel
+ Args:
+ ksize : e.g., 15, kernel size
+ theta : [0, pi], rotation angle range
+ l1 : [0.1,50], scaling of eigenvalues
+ l2 : [0.1,l1], scaling of eigenvalues
+ If l1 = l2, will get an isotropic Gaussian kernel.
+ Returns:
+ k : kernel
+ """
+
+ v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.]))
+ V = np.array([[v[0], v[1]], [v[1], -v[0]]])
+ D = np.array([[l1, 0], [0, l2]])
+ Sigma = np.dot(np.dot(V, D), np.linalg.inv(V))
+ k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize)
+
+ return k
+
+
+def gm_blur_kernel(mean, cov, size=15):
+ center = size / 2.0 + 0.5
+ k = np.zeros([size, size])
+ for y in range(size):
+ for x in range(size):
+ cy = y - center + 1
+ cx = x - center + 1
+ k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov)
+
+ k = k / np.sum(k)
+ return k
+
+
+def shift_pixel(x, sf, upper_left=True):
+ """shift pixel for super-resolution with different scale factors
+ Args:
+ x: WxHxC or WxH
+ sf: scale factor
+ upper_left: shift direction
+ """
+ h, w = x.shape[:2]
+ shift = (sf - 1) * 0.5
+ xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
+ if upper_left:
+ x1 = xv + shift
+ y1 = yv + shift
+ else:
+ x1 = xv - shift
+ y1 = yv - shift
+
+ x1 = np.clip(x1, 0, w - 1)
+ y1 = np.clip(y1, 0, h - 1)
+
+ if x.ndim == 2:
+ x = interp2d(xv, yv, x)(x1, y1)
+ if x.ndim == 3:
+ for i in range(x.shape[-1]):
+ x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1)
+
+ return x
+
+
+def blur(x, k):
+ '''
+ x: image, NxcxHxW
+ k: kernel, Nx1xhxw
+ '''
+ n, c = x.shape[:2]
+ p1, p2 = (k.shape[-2] - 1) // 2, (k.shape[-1] - 1) // 2
+ x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate')
+ k = k.repeat(1, c, 1, 1)
+ k = k.view(-1, 1, k.shape[2], k.shape[3])
+ x = x.view(1, -1, x.shape[2], x.shape[3])
+ x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n * c)
+ x = x.view(n, c, x.shape[2], x.shape[3])
+
+ return x
+
+
+def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0):
+ """"
+ # modified version of https://github.com/assafshocher/BlindSR_dataset_generator
+ # Kai Zhang
+ # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var
+ # max_var = 2.5 * sf
+ """
+ # Set random eigen-vals (lambdas) and angle (theta) for COV matrix
+ lambda_1 = min_var + np.random.rand() * (max_var - min_var)
+ lambda_2 = min_var + np.random.rand() * (max_var - min_var)
+ theta = np.random.rand() * np.pi # random theta
+ noise = -noise_level + np.random.rand(*k_size) * noise_level * 2
+
+ # Set COV matrix using Lambdas and Theta
+ LAMBDA = np.diag([lambda_1, lambda_2])
+ Q = np.array([[np.cos(theta), -np.sin(theta)],
+ [np.sin(theta), np.cos(theta)]])
+ SIGMA = Q @ LAMBDA @ Q.T
+ INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :]
+
+ # Set expectation position (shifting kernel for aligned image)
+ MU = k_size // 2 - 0.5 * (scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2)
+ MU = MU[None, None, :, None]
+
+ # Create meshgrid for Gaussian
+ [X, Y] = np.meshgrid(range(k_size[0]), range(k_size[1]))
+ Z = np.stack([X, Y], 2)[:, :, :, None]
+
+    # Calculate Gaussian for every pixel of the kernel
+ ZZ = Z - MU
+ ZZ_t = ZZ.transpose(0, 1, 3, 2)
+ raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise)
+
+ # shift the kernel so it will be centered
+ # raw_kernel_centered = kernel_shift(raw_kernel, scale_factor)
+
+ # Normalize the kernel and return
+ # kernel = raw_kernel_centered / np.sum(raw_kernel_centered)
+ kernel = raw_kernel / np.sum(raw_kernel)
+ return kernel
+
+
+def fspecial_gaussian(hsize, sigma):
+ hsize = [hsize, hsize]
+ siz = [(hsize[0] - 1.0) / 2.0, (hsize[1] - 1.0) / 2.0]
+ std = sigma
+ [x, y] = np.meshgrid(np.arange(-siz[1], siz[1] + 1), np.arange(-siz[0], siz[0] + 1))
+ arg = -(x * x + y * y) / (2 * std * std)
+ h = np.exp(arg)
+ h[h < scipy.finfo(float).eps * h.max()] = 0
+ sumh = h.sum()
+ if sumh != 0:
+ h = h / sumh
+ return h
+
+
+def fspecial_laplacian(alpha):
+ alpha = max([0, min([alpha, 1])])
+ h1 = alpha / (alpha + 1)
+ h2 = (1 - alpha) / (alpha + 1)
+ h = [[h1, h2, h1], [h2, -4 / (alpha + 1), h2], [h1, h2, h1]]
+ h = np.array(h)
+ return h
+
+
+def fspecial(filter_type, *args, **kwargs):
+ '''
+ python code from:
+ https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py
+ '''
+ if filter_type == 'gaussian':
+ return fspecial_gaussian(*args, **kwargs)
+ if filter_type == 'laplacian':
+ return fspecial_laplacian(*args, **kwargs)
+
+
+"""
+# --------------------------------------------
+# degradation models
+# --------------------------------------------
+"""
+
+
+def bicubic_degradation(x, sf=3):
+ '''
+ Args:
+ x: HxWxC image, [0, 1]
+ sf: down-scale factor
+ Return:
+ bicubicly downsampled LR image
+ '''
+ x = util.imresize_np(x, scale=1 / sf)
+ return x
+
+
+def srmd_degradation(x, k, sf=3):
+ ''' blur + bicubic downsampling
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2018learning,
+ title={Learning a single convolutional super-resolution network for multiple degradations},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={3262--3271},
+ year={2018}
+ }
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror'
+ x = bicubic_degradation(x, sf=sf)
+ return x
+
+
+def dpsr_degradation(x, k, sf=3):
+ ''' bicubic downsampling + blur
+ Args:
+ x: HxWxC image, [0, 1]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ Reference:
+ @inproceedings{zhang2019deep,
+ title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels},
+ author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei},
+ booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
+ pages={1671--1681},
+ year={2019}
+ }
+ '''
+ x = bicubic_degradation(x, sf=sf)
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ return x
+
+
+def classical_degradation(x, k, sf=3):
+ ''' blur + downsampling
+ Args:
+ x: HxWxC image, [0, 1]/[0, 255]
+ k: hxw, double
+ sf: down-scale factor
+ Return:
+ downsampled LR image
+ '''
+ x = ndimage.convolve(x, np.expand_dims(k, axis=2), mode='wrap')
+ # x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2))
+ st = 0
+ return x[st::sf, st::sf, ...]
+
+
+def add_sharpening(img, weight=0.5, radius=50, threshold=10):
+ """USM sharpening. borrowed from real-ESRGAN
+ Input image: I; Blurry image: B.
+ 1. K = I + weight * (I - B)
+ 2. Mask = 1 if abs(I - B) > threshold, else: 0
+ 3. Blur mask:
+ 4. Out = Mask * K + (1 - Mask) * I
+ Args:
+ img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
+        weight (float): Sharpening weight. Default: 0.5.
+        radius (int): Kernel size of the Gaussian blur (forced to be odd). Default: 50.
+        threshold (int): Threshold on |I - B| * 255 used to build the sharpening mask. Default: 10.
+ """
+ if radius % 2 == 0:
+ radius += 1
+ blur = cv2.GaussianBlur(img, (radius, radius), 0)
+ residual = img - blur
+ mask = np.abs(residual) * 255 > threshold
+ mask = mask.astype('float32')
+ soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
+
+ K = img + weight * residual
+ K = np.clip(K, 0, 1)
+ return soft_mask * K + (1 - soft_mask) * img
+
+
+def add_blur(img, sf=4):
+ wd2 = 4.0 + sf
+ wd = 2.0 + 0.2 * sf
+
+ wd2 = wd2/4
+ wd = wd/4
+
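+    # with probability 0.5 use a random anisotropic Gaussian kernel, otherwise an isotropic Gaussian from fspecial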
+ if random.random() < 0.5:
+ l1 = wd2 * random.random()
+ l2 = wd2 * random.random()
+ k = anisotropic_Gaussian(ksize=random.randint(2, 11) + 3, theta=random.random() * np.pi, l1=l1, l2=l2)
+ else:
+ k = fspecial('gaussian', random.randint(2, 4) + 3, wd * random.random())
+ img = ndimage.convolve(img, np.expand_dims(k, axis=2), mode='mirror')
+
+ return img
+
+
+def add_resize(img, sf=4):
+ rnum = np.random.rand()
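+    # randomly upscale (p=0.2), downscale (p=0.7) or keep the current size (p=0.1)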
+ if rnum > 0.8: # up
+ sf1 = random.uniform(1, 2)
+ elif rnum < 0.7: # down
+ sf1 = random.uniform(0.5 / sf, 1)
+ else:
+ sf1 = 1.0
+ img = cv2.resize(img, (int(sf1 * img.shape[1]), int(sf1 * img.shape[0])), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ return img
+
+
+# def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+# noise_level = random.randint(noise_level1, noise_level2)
+# rnum = np.random.rand()
+# if rnum > 0.6: # add color Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+# elif rnum < 0.4: # add grayscale Gaussian noise
+# img += np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+# else: # add noise
+# L = noise_level2 / 255.
+# D = np.diag(np.random.rand(3))
+# U = orth(np.random.rand(3, 3))
+# conv = np.dot(np.dot(np.transpose(U), D), U)
+# img += np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+# img = np.clip(img, 0.0, 1.0)
+# return img
+
+def add_Gaussian_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ rnum = np.random.rand()
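+    # choose between per-channel color noise (p=0.4), shared grayscale noise (p=0.4) and channel-correlated noise (p=0.2)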
+ if rnum > 0.6: # add color Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4: # add grayscale Gaussian noise
+ img = img + np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else: # add noise
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img = img + np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_speckle_noise(img, noise_level1=2, noise_level2=25):
+ noise_level = random.randint(noise_level1, noise_level2)
+ img = np.clip(img, 0.0, 1.0)
+ rnum = random.random()
+ if rnum > 0.6:
+ img += img * np.random.normal(0, noise_level / 255.0, img.shape).astype(np.float32)
+ elif rnum < 0.4:
+ img += img * np.random.normal(0, noise_level / 255.0, (*img.shape[:2], 1)).astype(np.float32)
+ else:
+ L = noise_level2 / 255.
+ D = np.diag(np.random.rand(3))
+ U = orth(np.random.rand(3, 3))
+ conv = np.dot(np.dot(np.transpose(U), D), U)
+ img += img * np.random.multivariate_normal([0, 0, 0], np.abs(L ** 2 * conv), img.shape[:2]).astype(np.float32)
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_Poisson_noise(img):
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ vals = 10 ** (2 * random.random() + 2.0) # [2, 4]
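+    # vals acts as a photon-count scale: the larger vals is, the weaker the Poisson noise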
+ if random.random() < 0.5:
+ img = np.random.poisson(img * vals).astype(np.float32) / vals
+ else:
+ img_gray = np.dot(img[..., :3], [0.299, 0.587, 0.114])
+ img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
+ noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
+ img += noise_gray[:, :, np.newaxis]
+ img = np.clip(img, 0.0, 1.0)
+ return img
+
+
+def add_JPEG_noise(img):
+ quality_factor = random.randint(80, 95)
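+    # round-trip the image through JPEG encoding/decoding in uint8 BGR space, then convert back to float RGB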
+ img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
+ result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
+ img = cv2.imdecode(encimg, 1)
+ img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
+ return img
+
+
+def random_crop(lq, hq, sf=4, lq_patchsize=64):
+ h, w = lq.shape[:2]
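+    # sample a random low-quality patch and the aligned, sf-times larger patch from the high-quality image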
+ rnd_h = random.randint(0, h - lq_patchsize)
+ rnd_w = random.randint(0, w - lq_patchsize)
+ lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
+
+ rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
+ hq = hq[rnd_h_H:rnd_h_H + lq_patchsize * sf, rnd_w_H:rnd_w_H + lq_patchsize * sf, :]
+ return lq, hq
+
+
+def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
+ """
+ This is the degradation model of BSRGAN from the paper
+ "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+ ----------
+    img: HxWxC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
+ sf: scale factor
+ isp_model: camera ISP model
+ Returns
+ -------
+ img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
+ hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
+ """
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = img.shape[:2]
+    img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop so that H and W are divisible by sf
+ h, w = img.shape[:2]
+
+ if h < lq_patchsize * sf or w < lq_patchsize * sf:
+ raise ValueError(f'img size ({h1}X{w1}) is too small!')
+
+ hq = img.copy()
+
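+    # with probability scale2_prob, for sf=4 first downsample by 2 and run the remaining pipeline with sf=2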
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ img = cv2.resize(img, (int(1 / 2 * img.shape[1]), int(1 / 2 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ img = util.imresize_np(img, 1 / 2, True)
+ img = np.clip(img, 0.0, 1.0)
+ sf = 2
+
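+    # apply the 7 degradation stages (two blurs, two downsampling steps, Gaussian noise, JPEG, camera ISP noise) in a random order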
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # ensure downsample2 (which records a, b) runs before downsample3
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+ if i == 0:
+ img = add_blur(img, sf=sf)
+
+ elif i == 1:
+ img = add_blur(img, sf=sf)
+
+ elif i == 2:
+ a, b = img.shape[1], img.shape[0]
+ # downsample2
+ if random.random() < 0.75:
+ sf1 = random.uniform(1, 2 * sf)
+ img = cv2.resize(img, (int(1 / sf1 * img.shape[1]), int(1 / sf1 * img.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ img = ndimage.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ img = img[0::sf, 0::sf, ...] # nearest downsampling
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ img = cv2.resize(img, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ img = np.clip(img, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ img = add_Gaussian_noise(img, noise_level1=2, noise_level2=8)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ img = add_JPEG_noise(img)
+
+ elif i == 6:
+ # add processed camera sensor noise
+ if random.random() < isp_prob and isp_model is not None:
+ with torch.no_grad():
+ img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ img = add_JPEG_noise(img)
+
+ # random crop
+ img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
+
+ return img, hq
+
+
+# todo no isp_model?
+def degradation_bsrgan_variant(image, sf=4, isp_model=None, up=False):
+    """
+    This is the degradation model of BSRGAN from the paper
+    "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
+    ----------
+    image: HxWxC uint8 image
+    sf: scale factor
+    isp_model: camera ISP model (unused in this variant)
+    up: if True, resize the degraded image back to the original resolution
+    Returns
+    -------
+    example: dict with key "image" holding the degraded low-quality image (uint8)
+    """
+ image = util.uint2single(image)
+ isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
+ sf_ori = sf
+
+ h1, w1 = image.shape[:2]
+    image = image.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]  # mod crop so that H and W are divisible by sf
+ h, w = image.shape[:2]
+
+ hq = image.copy()
+
+ if sf == 4 and random.random() < scale2_prob: # downsample1
+ if np.random.rand() < 0.5:
+ image = cv2.resize(image, (int(1 / 2 * image.shape[1]), int(1 / 2 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ image = util.imresize_np(image, 1 / 2, True)
+ image = np.clip(image, 0.0, 1.0)
+ sf = 2
+
+ shuffle_order = random.sample(range(7), 7)
+ idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
+    if idx1 > idx2:  # ensure downsample2 (which records a, b) runs before downsample3
+ shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
+
+ for i in shuffle_order:
+
+        if i == 0:
+            image = add_blur(image, sf=sf)
+
+        # elif i == 1:
+        #     image = add_blur(image, sf=sf)
+
+ elif i == 2:
+ a, b = image.shape[1], image.shape[0]
+ # downsample2
+ if random.random() < 0.8:
+ sf1 = random.uniform(1, 2 * sf)
+ image = cv2.resize(image, (int(1 / sf1 * image.shape[1]), int(1 / sf1 * image.shape[0])),
+ interpolation=random.choice([1, 2, 3]))
+ else:
+ k = fspecial('gaussian', 25, random.uniform(0.1, 0.6 * sf))
+ k_shifted = shift_pixel(k, sf)
+ k_shifted = k_shifted / k_shifted.sum() # blur with shifted kernel
+ image = ndimage.convolve(image, np.expand_dims(k_shifted, axis=2), mode='mirror')
+ image = image[0::sf, 0::sf, ...] # nearest downsampling
+
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 3:
+ # downsample3
+ image = cv2.resize(image, (int(1 / sf * a), int(1 / sf * b)), interpolation=random.choice([1, 2, 3]))
+ image = np.clip(image, 0.0, 1.0)
+
+ elif i == 4:
+ # add Gaussian noise
+ image = add_Gaussian_noise(image, noise_level1=1, noise_level2=2)
+
+ elif i == 5:
+ # add JPEG noise
+ if random.random() < jpeg_prob:
+ image = add_JPEG_noise(image)
+ #
+ # elif i == 6:
+ # # add processed camera sensor noise
+ # if random.random() < isp_prob and isp_model is not None:
+ # with torch.no_grad():
+ # img, hq = isp_model.forward(img.copy(), hq)
+
+ # add final JPEG compression noise
+ image = add_JPEG_noise(image)
+ image = util.single2uint(image)
+ if up:
+ image = cv2.resize(image, (w1, h1), interpolation=cv2.INTER_CUBIC) # todo: random, as above? want to condition on it then
+ example = {"image": image}
+ return example
+
+
+
+
+if __name__ == '__main__':
+ print("hey")
+ img = util.imread_uint('utils/test.png', 3)
+ img = img[:448, :448]
+ h = img.shape[0] // 4
+ print("resizing to", h)
+ sf = 4
+ deg_fn = partial(degradation_bsrgan_variant, sf=sf)
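+    # run the degradation 20 times and save side-by-side comparisons: nearest-upscaled bicubic LR | nearest-upscaled degraded LR | HQ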
+ for i in range(20):
+ print(i)
+ img_hq = img
+ img_lq = deg_fn(img)["image"]
+ img_hq, img_lq = util.uint2single(img_hq), util.uint2single(img_lq)
+ print(img_lq)
+ img_lq_bicubic = albumentations.SmallestMaxSize(max_size=h, interpolation=cv2.INTER_CUBIC)(image=img_hq)["image"]
+ print(img_lq.shape)
+ print("bicubic", img_lq_bicubic.shape)
+ print(img_hq.shape)
+ lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ lq_bicubic_nearest = cv2.resize(util.single2uint(img_lq_bicubic),
+ (int(sf * img_lq.shape[1]), int(sf * img_lq.shape[0])),
+ interpolation=0)
+ img_concat = np.concatenate([lq_bicubic_nearest, lq_nearest, util.single2uint(img_hq)], axis=1)
+ util.imsave(img_concat, str(i) + '.png')
diff --git a/repositories/ldm/modules/image_degradation/utils/test.png b/repositories/ldm/modules/image_degradation/utils/test.png
new file mode 100644
index 000000000..4249b43de
Binary files /dev/null and b/repositories/ldm/modules/image_degradation/utils/test.png differ
diff --git a/repositories/ldm/modules/image_degradation/utils_image.py b/repositories/ldm/modules/image_degradation/utils_image.py
new file mode 100644
index 000000000..0175f155a
--- /dev/null
+++ b/repositories/ldm/modules/image_degradation/utils_image.py
@@ -0,0 +1,916 @@
+import os
+import math
+import random
+import numpy as np
+import torch
+import cv2
+from torchvision.utils import make_grid
+from datetime import datetime
+#import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+
+
+os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+
+'''
+# --------------------------------------------
+# Kai Zhang (github: https://github.com/cszn)
+# 03/Mar/2019
+# --------------------------------------------
+# https://github.com/twhui/SRGAN-pyTorch
+# https://github.com/xinntao/BasicSR
+# --------------------------------------------
+'''
+
+
+IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif']
+
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def get_timestamp():
+ return datetime.now().strftime('%y%m%d-%H%M%S')
+
+
+def imshow(x, title=None, cbar=False, figsize=None):
+ plt.figure(figsize=figsize)
+ plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray')
+ if title:
+ plt.title(title)
+ if cbar:
+ plt.colorbar()
+ plt.show()
+
+
+def surf(Z, cmap='rainbow', figsize=None):
+ plt.figure(figsize=figsize)
+ ax3 = plt.axes(projection='3d')
+
+ w, h = Z.shape[:2]
+ xx = np.arange(0,w,1)
+ yy = np.arange(0,h,1)
+ X, Y = np.meshgrid(xx, yy)
+ ax3.plot_surface(X,Y,Z,cmap=cmap)
+ #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap)
+ plt.show()
+
+
+'''
+# --------------------------------------------
+# get image paths
+# --------------------------------------------
+'''
+
+
+def get_image_paths(dataroot):
+ paths = None # return None if dataroot is None
+ if dataroot is not None:
+ paths = sorted(_get_paths_from_images(dataroot))
+ return paths
+
+
+def _get_paths_from_images(path):
+ assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
+ images = []
+ for dirpath, _, fnames in sorted(os.walk(path)):
+ for fname in sorted(fnames):
+ if is_image_file(fname):
+ img_path = os.path.join(dirpath, fname)
+ images.append(img_path)
+ assert images, '{:s} has no valid image file'.format(path)
+ return images
+
+
+'''
+# --------------------------------------------
+# split large images into small images
+# --------------------------------------------
+'''
+
+
+def patches_from_image(img, p_size=512, p_overlap=64, p_max=800):
+ w, h = img.shape[:2]
+ patches = []
+ if w > p_max and h > p_max:
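+        # start coordinates of the overlapping patches, including one final patch flush with the image border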
+        w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=int))
+        h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=int))
+ w1.append(w-p_size)
+ h1.append(h-p_size)
+# print(w1)
+# print(h1)
+ for i in w1:
+ for j in h1:
+ patches.append(img[i:i+p_size, j:j+p_size,:])
+ else:
+ patches.append(img)
+
+ return patches
+
+
+def imssave(imgs, img_path):
+ """
+ imgs: list, N images of size WxHxC
+ """
+ img_name, ext = os.path.splitext(os.path.basename(img_path))
+
+ for i, img in enumerate(imgs):
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ new_path = os.path.join(os.path.dirname(img_path), img_name+str('_s{:04d}'.format(i))+'.png')
+ cv2.imwrite(new_path, img)
+
+
+def split_imageset(original_dataroot, taget_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
+ """
+    Split the large images from original_dataroot into small overlapping patches of size (p_size)x(p_size)
+    and save them into taget_dataroot; only images larger than (p_max)x(p_max) are split.
+    Args:
+        original_dataroot: directory containing the source images
+        taget_dataroot: directory where the patches are saved
+        p_size: size of the small patches
+        p_overlap: overlap between neighbouring patches; the training patch size is a good choice
+        p_max: images smaller than (p_max)x(p_max) are kept unchanged.
+ """
+ paths = get_image_paths(original_dataroot)
+ for img_path in paths:
+ # img_name, ext = os.path.splitext(os.path.basename(img_path))
+ img = imread_uint(img_path, n_channels=n_channels)
+ patches = patches_from_image(img, p_size, p_overlap, p_max)
+ imssave(patches, os.path.join(taget_dataroot,os.path.basename(img_path)))
+ #if original_dataroot == taget_dataroot:
+ #del img_path
+
+'''
+# --------------------------------------------
+# makedir
+# --------------------------------------------
+'''
+
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+
+def mkdirs(paths):
+ if isinstance(paths, str):
+ mkdir(paths)
+ else:
+ for path in paths:
+ mkdir(path)
+
+
+def mkdir_and_rename(path):
+ if os.path.exists(path):
+ new_name = path + '_archived_' + get_timestamp()
+        print('Path already exists. Renaming it to [{:s}]'.format(new_name))
+ os.rename(path, new_name)
+ os.makedirs(path)
+
+
+'''
+# --------------------------------------------
+# read image from path
+# opencv is fast, but read BGR numpy image
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# get uint8 image of size HxWxn_channels (RGB)
+# --------------------------------------------
+def imread_uint(path, n_channels=3):
+ # input: path
+ # output: HxWx3(RGB or GGG), or HxWx1 (G)
+ if n_channels == 1:
+ img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
+ img = np.expand_dims(img, axis=2) # HxWx1
+ elif n_channels == 3:
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
+ else:
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
+ return img
+
+
+# --------------------------------------------
+# matlab's imwrite
+# --------------------------------------------
+def imsave(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+def imwrite(img, img_path):
+ img = np.squeeze(img)
+ if img.ndim == 3:
+ img = img[:, :, [2, 1, 0]]
+ cv2.imwrite(img_path, img)
+
+
+
+# --------------------------------------------
+# get single image of size HxWxn_channels (BGR)
+# --------------------------------------------
+def read_img(path):
+ # read image by cv2
+ # return: Numpy float32, HWC, BGR, [0,1]
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
+ img = img.astype(np.float32) / 255.
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ # some images have 4 channels
+ if img.shape[2] > 3:
+ img = img[:, :, :3]
+ return img
+
+
+'''
+# --------------------------------------------
+# image format conversion
+# --------------------------------------------
+# numpy(single) <---> numpy(uint)
+# numpy(single) <---> tensor
+# numpy(uint) <---> tensor
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# numpy(single) [0, 1] <---> numpy(uint)
+# --------------------------------------------
+
+
+def uint2single(img):
+
+ return np.float32(img/255.)
+
+
+def single2uint(img):
+
+ return np.uint8((img.clip(0, 1)*255.).round())
+
+
+def uint162single(img):
+
+ return np.float32(img/65535.)
+
+
+def single2uint16(img):
+
+ return np.uint16((img.clip(0, 1)*65535.).round())
+
+
+# --------------------------------------------
+# numpy(uint) (HxWxC or HxW) <---> tensor
+# --------------------------------------------
+
+
+# convert uint to 4-dimensional torch tensor
+def uint2tensor4(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
+
+
+# convert uint to 3-dimensional torch tensor
+def uint2tensor3(img):
+ if img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
+
+
+# convert 2/3/4-dimensional torch tensor to uint
+def tensor2uint(img):
+ img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ return np.uint8((img*255.0).round())
+
+
+# --------------------------------------------
+# numpy(single) (HxWxC) <---> tensor
+# --------------------------------------------
+
+
+# convert single (HxWxC) to 3-dimensional torch tensor
+def single2tensor3(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
+
+
+# convert single (HxWxC) to 4-dimensional torch tensor
+def single2tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
+
+
+# convert torch tensor to single
+def tensor2single(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+
+ return img
+
+# convert torch tensor to single
+def tensor2single3(img):
+ img = img.data.squeeze().float().cpu().numpy()
+ if img.ndim == 3:
+ img = np.transpose(img, (1, 2, 0))
+ elif img.ndim == 2:
+ img = np.expand_dims(img, axis=2)
+ return img
+
+
+def single2tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
+
+
+def single32tensor5(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
+
+
+def single42tensor4(img):
+ return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
+
+
+# from skimage.io import imread, imsave
+def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
+ '''
+ Converts a torch Tensor into an image Numpy array of BGR channel order
+ Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
+ Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
+ '''
+ tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # squeeze first, then clamp
+ tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
+ n_dim = tensor.dim()
+ if n_dim == 4:
+ n_img = len(tensor)
+ img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 3:
+ img_np = tensor.numpy()
+ img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
+ elif n_dim == 2:
+ img_np = tensor.numpy()
+ else:
+ raise TypeError(
+ 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
+ if out_type == np.uint8:
+ img_np = (img_np * 255.0).round()
+        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
+ return img_np.astype(out_type)
+
+
+'''
+# --------------------------------------------
+# Augmentation, flip and/or rotate
+# --------------------------------------------
+# The following two are enough.
+# (1) augment_img: numpy image of WxHxC or WxH
+# (2) augment_img_tensor4: tensor image 1xCxWxH
+# --------------------------------------------
+'''
+
+
+def augment_img(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
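+    # the 8 modes enumerate the symmetries of the square: 4 rotations, each optionally combined with a vertical flip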
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return np.flipud(np.rot90(img))
+ elif mode == 2:
+ return np.flipud(img)
+ elif mode == 3:
+ return np.rot90(img, k=3)
+ elif mode == 4:
+ return np.flipud(np.rot90(img, k=2))
+ elif mode == 5:
+ return np.rot90(img)
+ elif mode == 6:
+ return np.rot90(img, k=2)
+ elif mode == 7:
+ return np.flipud(np.rot90(img, k=3))
+
+
+def augment_img_tensor4(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.rot90(1, [2, 3]).flip([2])
+ elif mode == 2:
+ return img.flip([2])
+ elif mode == 3:
+ return img.rot90(3, [2, 3])
+ elif mode == 4:
+ return img.rot90(2, [2, 3]).flip([2])
+ elif mode == 5:
+ return img.rot90(1, [2, 3])
+ elif mode == 6:
+ return img.rot90(2, [2, 3])
+ elif mode == 7:
+ return img.rot90(3, [2, 3]).flip([2])
+
+
+def augment_img_tensor(img, mode=0):
+ '''Kai Zhang (github: https://github.com/cszn)
+ '''
+ img_size = img.size()
+ img_np = img.data.cpu().numpy()
+ if len(img_size) == 3:
+ img_np = np.transpose(img_np, (1, 2, 0))
+ elif len(img_size) == 4:
+ img_np = np.transpose(img_np, (2, 3, 1, 0))
+ img_np = augment_img(img_np, mode=mode)
+ img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
+ if len(img_size) == 3:
+ img_tensor = img_tensor.permute(2, 0, 1)
+ elif len(img_size) == 4:
+ img_tensor = img_tensor.permute(3, 2, 0, 1)
+
+ return img_tensor.type_as(img)
+
+
+def augment_img_np3(img, mode=0):
+ if mode == 0:
+ return img
+ elif mode == 1:
+ return img.transpose(1, 0, 2)
+ elif mode == 2:
+ return img[::-1, :, :]
+ elif mode == 3:
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 4:
+ return img[:, ::-1, :]
+ elif mode == 5:
+ img = img[:, ::-1, :]
+ img = img.transpose(1, 0, 2)
+ return img
+ elif mode == 6:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ return img
+ elif mode == 7:
+ img = img[:, ::-1, :]
+ img = img[::-1, :, :]
+ img = img.transpose(1, 0, 2)
+ return img
+
+
+def augment_imgs(img_list, hflip=True, rot=True):
+ # horizontal flip OR rotate
+ hflip = hflip and random.random() < 0.5
+ vflip = rot and random.random() < 0.5
+ rot90 = rot and random.random() < 0.5
+
+ def _augment(img):
+ if hflip:
+ img = img[:, ::-1, :]
+ if vflip:
+ img = img[::-1, :, :]
+ if rot90:
+ img = img.transpose(1, 0, 2)
+ return img
+
+ return [_augment(img) for img in img_list]
+
+
+'''
+# --------------------------------------------
+# modcrop and shave
+# --------------------------------------------
+'''
+
+
+def modcrop(img_in, scale):
+ # img_in: Numpy, HWC or HW
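+    # crop the bottom and right borders so that H and W become divisible by scale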
+ img = np.copy(img_in)
+ if img.ndim == 2:
+ H, W = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r]
+ elif img.ndim == 3:
+ H, W, C = img.shape
+ H_r, W_r = H % scale, W % scale
+ img = img[:H - H_r, :W - W_r, :]
+ else:
+ raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
+ return img
+
+
+def shave(img_in, border=0):
+ # img_in: Numpy, HWC or HW
+ img = np.copy(img_in)
+ h, w = img.shape[:2]
+ img = img[border:h-border, border:w-border]
+ return img
+
+
+'''
+# --------------------------------------------
+# image processing process on numpy image
+# channel_convert(in_c, tar_type, img_list):
+# rgb2ycbcr(img, only_y=True):
+# bgr2ycbcr(img, only_y=True):
+# ycbcr2rgb(img):
+# --------------------------------------------
+'''
+
+
+def rgb2ycbcr(img, only_y=True):
+ '''same as matlab rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
+ [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def ycbcr2rgb(img):
+ '''same as matlab ycbcr2rgb
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def bgr2ycbcr(img, only_y=True):
+ '''bgr version of rgb2ycbcr
+ only_y: only return Y channel
+ Input:
+ uint8, [0, 255]
+ float, [0, 1]
+ '''
+ in_img_type = img.dtype
+    img = img.astype(np.float32)
+ if in_img_type != np.uint8:
+ img *= 255.
+ # convert
+ if only_y:
+ rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
+ else:
+ rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
+ [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
+ if in_img_type == np.uint8:
+ rlt = rlt.round()
+ else:
+ rlt /= 255.
+ return rlt.astype(in_img_type)
+
+
+def channel_convert(in_c, tar_type, img_list):
+ # conversion among BGR, gray and y
+ if in_c == 3 and tar_type == 'gray': # BGR to gray
+ gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in gray_list]
+ elif in_c == 3 and tar_type == 'y': # BGR to y
+ y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
+ return [np.expand_dims(img, axis=2) for img in y_list]
+ elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
+ return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
+ else:
+ return img_list
+
+
+'''
+# --------------------------------------------
+# metric, PSNR and SSIM
+# --------------------------------------------
+'''
+
+
+# --------------------------------------------
+# PSNR
+# --------------------------------------------
+def calculate_psnr(img1, img2, border=0):
+ # img1 and img2 have range [0, 255]
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20 * math.log10(255.0 / math.sqrt(mse))
+
+
+# --------------------------------------------
+# SSIM
+# --------------------------------------------
+def calculate_ssim(img1, img2, border=0):
+ '''calculate SSIM
+ the same outputs as MATLAB's
+ img1, img2: [0, 255]
+ '''
+ #img1 = img1.squeeze()
+ #img2 = img2.squeeze()
+ if not img1.shape == img2.shape:
+ raise ValueError('Input images must have the same dimensions.')
+ h, w = img1.shape[:2]
+ img1 = img1[border:h-border, border:w-border]
+ img2 = img2[border:h-border, border:w-border]
+
+ if img1.ndim == 2:
+ return ssim(img1, img2)
+ elif img1.ndim == 3:
+ if img1.shape[2] == 3:
+ ssims = []
+ for i in range(3):
+ ssims.append(ssim(img1[:,:,i], img2[:,:,i]))
+ return np.array(ssims).mean()
+ elif img1.shape[2] == 1:
+ return ssim(np.squeeze(img1), np.squeeze(img2))
+ else:
+ raise ValueError('Wrong input image dimensions.')
+
+
+def ssim(img1, img2):
+ C1 = (0.01 * 255)**2
+ C2 = (0.03 * 255)**2
+
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+ kernel = cv2.getGaussianKernel(11, 1.5)
+ window = np.outer(kernel, kernel.transpose())
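+    # 11x11 Gaussian window with sigma=1.5, as in the original SSIM formulation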
+
+ mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
+ mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
+ mu1_sq = mu1**2
+ mu2_sq = mu2**2
+ mu1_mu2 = mu1 * mu2
+ sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
+ sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
+ sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
+
+ ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
+ (sigma1_sq + sigma2_sq + C2))
+ return ssim_map.mean()
+
+
+'''
+# --------------------------------------------
+# matlab's bicubic imresize (numpy and torch) [0, 1]
+# --------------------------------------------
+'''
+
+
+# matlab 'imresize' function, now only support 'bicubic'
+def cubic(x):
+ absx = torch.abs(x)
+ absx2 = absx**2
+ absx3 = absx**3
+ return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
+ (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx <= 2)).type_as(absx))
+
+
+def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
+ if (scale < 1) and (antialiasing):
+        # Use a modified kernel to simultaneously interpolate and antialias: larger kernel width
+ kernel_width = kernel_width / scale
+
+ # Output-space coordinates
+ x = torch.linspace(1, out_length, out_length)
+
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
+ # in output space maps to 0.5 in input space, and 0.5+scale in output
+ # space maps to 1.5 in input space.
+ u = x / scale + 0.5 * (1 - 1 / scale)
+
+ # What is the left-most pixel that can be involved in the computation?
+ left = torch.floor(u - kernel_width / 2)
+
+ # What is the maximum number of pixels that can be involved in the
+ # computation? Note: it's OK to use an extra pixel here; if the
+ # corresponding weights are all zero, it will be eliminated at the end
+ # of this function.
+ P = math.ceil(kernel_width) + 2
+
+ # The indices of the input pixels involved in computing the k-th output
+ # pixel are in row k of the indices matrix.
+ indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view(
+ 1, P).expand(out_length, P)
+
+ # The weights used to compute the k-th output pixel are in row k of the
+ # weights matrix.
+ distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices
+ # apply cubic kernel
+ if (scale < 1) and (antialiasing):
+ weights = scale * cubic(distance_to_center * scale)
+ else:
+ weights = cubic(distance_to_center)
+ # Normalize the weights matrix so that each row sums to 1.
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
+ weights = weights / weights_sum.expand(out_length, P)
+
+ # If a column in weights is all zero, get rid of it. only consider the first and last column.
+ weights_zero_tmp = torch.sum((weights == 0), 0)
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 1, P - 2)
+ weights = weights.narrow(1, 1, P - 2)
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
+ indices = indices.narrow(1, 0, P - 2)
+ weights = weights.narrow(1, 0, P - 2)
+ weights = weights.contiguous()
+ indices = indices.contiguous()
+ sym_len_s = -indices.min() + 1
+ sym_len_e = indices.max() - in_length
+ indices = indices + sym_len_s - 1
+ return weights, indices, int(sym_len_s), int(sym_len_e)
+
+
+# --------------------------------------------
+# imresize for tensor image [0, 1]
+# --------------------------------------------
+def imresize(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: pytorch tensor, CHW or HW [0,1]
+ # output: CHW or HW [0,1] w/o round
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(0)
+ in_C, in_H, in_W = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
+ img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:, :sym_len_Hs, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[:, -sym_len_He:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(in_C, out_H, in_W)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
+ out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :, :sym_len_Ws]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, :, -sym_len_We:]
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
+ out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(in_C, out_H, out_W)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+ return out_2
+
+
+# --------------------------------------------
+# imresize for numpy image [0, 1]
+# --------------------------------------------
+def imresize_np(img, scale, antialiasing=True):
+ # Now the scale should be the same for H and W
+ # input: img: Numpy, HWC or HW [0,1]
+ # output: HWC or HW [0,1] w/o round
+ img = torch.from_numpy(img)
+ need_squeeze = True if img.dim() == 2 else False
+ if need_squeeze:
+ img.unsqueeze_(2)
+
+ in_H, in_W, in_C = img.size()
+ out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
+ kernel_width = 4
+ kernel = 'cubic'
+
+ # Return the desired dimension order for performing the resize. The
+ # strategy is to perform the resize first along the dimension with the
+ # smallest scale factor.
+ # Now we do not support this.
+
+ # get weights and indices
+ weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
+ in_H, out_H, scale, kernel, kernel_width, antialiasing)
+ weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
+ in_W, out_W, scale, kernel, kernel_width, antialiasing)
+ # process H dimension
+ # symmetric copying
+ img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C)
+ img_aug.narrow(0, sym_len_Hs, in_H).copy_(img)
+
+ sym_patch = img[:sym_len_Hs, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv)
+
+ sym_patch = img[-sym_len_He:, :, :]
+ inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(0, inv_idx)
+ img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv)
+
+ out_1 = torch.FloatTensor(out_H, in_W, in_C)
+ kernel_width = weights_H.size(1)
+ for i in range(out_H):
+ idx = int(indices_H[i][0])
+ for j in range(out_C):
+ out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i])
+
+ # process W dimension
+ # symmetric copying
+ out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C)
+ out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1)
+
+ sym_patch = out_1[:, :sym_len_Ws, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv)
+
+ sym_patch = out_1[:, -sym_len_We:, :]
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
+ out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv)
+
+ out_2 = torch.FloatTensor(out_H, out_W, in_C)
+ kernel_width = weights_W.size(1)
+ for i in range(out_W):
+ idx = int(indices_W[i][0])
+ for j in range(out_C):
+ out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i])
+ if need_squeeze:
+ out_2.squeeze_()
+
+ return out_2.numpy()
+
+
+if __name__ == '__main__':
+ print('---')
+# img = imread_uint('test.bmp', 3)
+# img = uint2single(img)
+# img_bicubic = imresize_np(img, 1/4)
\ No newline at end of file
diff --git a/repositories/ldm/modules/karlo/__init__.py b/repositories/ldm/modules/karlo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/karlo/diffusers_pipeline.py b/repositories/ldm/modules/karlo/diffusers_pipeline.py
new file mode 100644
index 000000000..07f72b35a
--- /dev/null
+++ b/repositories/ldm/modules/karlo/diffusers_pipeline.py
@@ -0,0 +1,512 @@
+# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import functional as F
+
+from transformers import CLIPTextModelWithProjection, CLIPTokenizer
+from transformers.models.clip.modeling_clip import CLIPTextModelOutput
+
+from ...models import PriorTransformer, UNet2DConditionModel, UNet2DModel
+from ...pipelines import DiffusionPipeline, ImagePipelineOutput
+from ...schedulers import UnCLIPScheduler
+from ...utils import is_accelerate_available, logging, randn_tensor
+from .text_proj import UnCLIPTextProjModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class UnCLIPPipeline(DiffusionPipeline):
+ """
+ Pipeline for text-to-image generation using unCLIP
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+ Args:
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ prior ([`PriorTransformer`]):
+            The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ text_proj ([`UnCLIPTextProjModel`]):
+ Utility class to prepare and combine the embeddings before they are passed to the decoder.
+ decoder ([`UNet2DConditionModel`]):
+ The decoder to invert the image embedding into an image.
+ super_res_first ([`UNet2DModel`]):
+ Super resolution unet. Used in all but the last step of the super resolution diffusion process.
+ super_res_last ([`UNet2DModel`]):
+ Super resolution unet. Used in the last step of the super resolution diffusion process.
+ prior_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the prior denoising process. Just a modified DDPMScheduler.
+ decoder_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the decoder denoising process. Just a modified DDPMScheduler.
+ super_res_scheduler ([`UnCLIPScheduler`]):
+ Scheduler used in the super resolution denoising process. Just a modified DDPMScheduler.
+ """
+
+ prior: PriorTransformer
+ decoder: UNet2DConditionModel
+ text_proj: UnCLIPTextProjModel
+ text_encoder: CLIPTextModelWithProjection
+ tokenizer: CLIPTokenizer
+ super_res_first: UNet2DModel
+ super_res_last: UNet2DModel
+
+ prior_scheduler: UnCLIPScheduler
+ decoder_scheduler: UnCLIPScheduler
+ super_res_scheduler: UnCLIPScheduler
+
+ def __init__(
+ self,
+ prior: PriorTransformer,
+ decoder: UNet2DConditionModel,
+ text_encoder: CLIPTextModelWithProjection,
+ tokenizer: CLIPTokenizer,
+ text_proj: UnCLIPTextProjModel,
+ super_res_first: UNet2DModel,
+ super_res_last: UNet2DModel,
+ prior_scheduler: UnCLIPScheduler,
+ decoder_scheduler: UnCLIPScheduler,
+ super_res_scheduler: UnCLIPScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ prior=prior,
+ decoder=decoder,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ text_proj=text_proj,
+ super_res_first=super_res_first,
+ super_res_last=super_res_last,
+ prior_scheduler=prior_scheduler,
+ decoder_scheduler=decoder_scheduler,
+ super_res_scheduler=super_res_scheduler,
+ )
+
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
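+        # scale the initial latents by the scheduler's init_noise_sigma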
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if text_model_output is None:
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ text_mask = text_inputs.attention_mask.bool().to(device)
+
+ if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+ removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device))
+
+ text_embeddings = text_encoder_output.text_embeds
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+
+ else:
+ batch_size = text_model_output[0].shape[0]
+ text_embeddings, text_encoder_hidden_states = text_model_output[0], text_model_output[1]
+ text_mask = text_attention_mask
+
+ text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ if do_classifier_free_guidance:
+ uncond_tokens = [""] * batch_size
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_text_mask = uncond_input.attention_mask.bool().to(device)
+ uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(device))
+
+ uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
+ uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+
+ seq_len = uncond_embeddings.shape[1]
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
+
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ uncond_text_mask = uncond_text_mask.repeat_interleave(num_images_per_prompt, dim=0)
+
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+ text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
+
+ text_mask = torch.cat([uncond_text_mask, text_mask])
+
+ return text_embeddings, text_encoder_hidden_states, text_mask
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, the pipeline's
+ models have their state dicts saved to CPU and then are moved to a `torch.device('meta') and loaded to GPU only
+        models have their state dicts saved to CPU and then are moved to `torch.device('meta')` and loaded to the GPU only
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ # TODO: self.prior.post_process_latents is not covered by the offload hooks, so it fails if added to the list
+ models = [
+ self.decoder,
+ self.text_proj,
+ self.text_encoder,
+ self.super_res_first,
+ self.super_res_last,
+ ]
+ for cpu_offloaded_model in models:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if self.device != torch.device("meta") or not hasattr(self.decoder, "_hf_hook"):
+ return self.device
+ for module in self.decoder.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ prior_num_inference_steps: int = 25,
+ decoder_num_inference_steps: int = 25,
+ super_res_num_inference_steps: int = 7,
+ generator: Optional[torch.Generator] = None,
+ prior_latents: Optional[torch.FloatTensor] = None,
+ decoder_latents: Optional[torch.FloatTensor] = None,
+ super_res_latents: Optional[torch.FloatTensor] = None,
+ text_model_output: Optional[Union[CLIPTextModelOutput, Tuple]] = None,
+ text_attention_mask: Optional[torch.Tensor] = None,
+ prior_guidance_scale: float = 4.0,
+ decoder_guidance_scale: float = 8.0,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation. This can only be left undefined if
+ `text_model_output` and `text_attention_mask` is passed.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ prior_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the prior. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ decoder_num_inference_steps (`int`, *optional*, defaults to 25):
+ The number of denoising steps for the decoder. More denoising steps usually lead to a higher quality
+ image at the expense of slower inference.
+ super_res_num_inference_steps (`int`, *optional*, defaults to 7):
+ The number of denoising steps for super resolution. More denoising steps usually lead to a higher
+ quality image at the expense of slower inference.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ prior_latents (`torch.FloatTensor` of shape (batch size, embeddings dimension), *optional*):
+ Pre-generated noisy latents to be used as inputs for the prior.
+ decoder_latents (`torch.FloatTensor` of shape (batch size, channels, height, width), *optional*):
+ Pre-generated noisy latents to be used as inputs for the decoder.
+ super_res_latents (`torch.FloatTensor` of shape (batch size, channels, super res height, super res width), *optional*):
+                Pre-generated noisy latents to be used as inputs for the super resolution.
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+            decoder_guidance_scale (`float`, *optional*, defaults to 8.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ text_model_output (`CLIPTextModelOutput`, *optional*):
+ Pre-defined CLIPTextModel outputs that can be derived from the text encoder. Pre-defined text outputs
+ can be passed for tasks like text embedding interpolations. Make sure to also pass
+                `text_attention_mask` in this case. `prompt` can then be left as `None`.
+ text_attention_mask (`torch.Tensor`, *optional*):
+ Pre-defined CLIP text attention mask that can be derived from the tokenizer. Pre-defined text attention
+ masks are necessary when passing `text_model_output`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+ """
+ if prompt is not None:
+ if isinstance(prompt, str):
+ batch_size = 1
+ elif isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+ else:
+ batch_size = text_model_output[0].shape[0]
+
+ device = self._execution_device
+
+ batch_size = batch_size * num_images_per_prompt
+
+ do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
+
+ text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, text_model_output, text_attention_mask
+ )
+
+ # prior
+
+ self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=device)
+ prior_timesteps_tensor = self.prior_scheduler.timesteps
+
+ embedding_dim = self.prior.config.embedding_dim
+
+ prior_latents = self.prepare_latents(
+ (batch_size, embedding_dim),
+ text_embeddings.dtype,
+ device,
+ generator,
+ prior_latents,
+ self.prior_scheduler,
+ )
+
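+        # denoise the prior latents into a CLIP image embedding, applying classifier-free guidance to the predicted embedding when enabled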
+ for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
+
+ predicted_image_embedding = self.prior(
+ latent_model_input,
+ timestep=t,
+ proj_embedding=text_embeddings,
+ encoder_hidden_states=text_encoder_hidden_states,
+ attention_mask=text_mask,
+ ).predicted_image_embedding
+
+ if do_classifier_free_guidance:
+ predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
+ predicted_image_embedding_text - predicted_image_embedding_uncond
+ )
+
+ if i + 1 == prior_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = prior_timesteps_tensor[i + 1]
+
+ prior_latents = self.prior_scheduler.step(
+ predicted_image_embedding,
+ timestep=t,
+ sample=prior_latents,
+ generator=generator,
+ prev_timestep=prev_timestep,
+ ).prev_sample
+
+ prior_latents = self.prior.post_process_latents(prior_latents)
+
+ image_embeddings = prior_latents
+
+ # done prior
+
+ # decoder
+
+ text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
+ image_embeddings=image_embeddings,
+ text_embeddings=text_embeddings,
+ text_encoder_hidden_states=text_encoder_hidden_states,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ )
+
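+        # pad the attention mask with ones to cover the extra context tokens produced by text_proj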
+ decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
+
+ self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=device)
+ decoder_timesteps_tensor = self.decoder_scheduler.timesteps
+
+ num_channels_latents = self.decoder.in_channels
+ height = self.decoder.sample_size
+ width = self.decoder.sample_size
+
+ decoder_latents = self.prepare_latents(
+ (batch_size, num_channels_latents, height, width),
+ text_encoder_hidden_states.dtype,
+ device,
+ generator,
+ decoder_latents,
+ self.decoder_scheduler,
+ )
+
+ for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
+
+ noise_pred = self.decoder(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=text_encoder_hidden_states,
+ class_labels=additive_clip_time_embeddings,
+ attention_mask=decoder_text_mask,
+ ).sample
+
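+            # the decoder predicts noise and variance jointly; guidance is applied to the noise part only,
+            # while the text-conditioned variance prediction is kept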
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
+ noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
+ noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
+ noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
+
+ if i + 1 == decoder_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = decoder_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ decoder_latents = self.decoder_scheduler.step(
+ noise_pred, t, decoder_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ decoder_latents = decoder_latents.clamp(-1, 1)
+
+ image_small = decoder_latents
+
+ # done decoder
+
+ # super res
+
+ self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=device)
+ super_res_timesteps_tensor = self.super_res_scheduler.timesteps
+
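+        # The super-res UNet concatenates the noisy latents with the upscaled 64px image along
+        # the channel dimension, so only half of in_channels belongs to the latents themselves.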
+ channels = self.super_res_first.in_channels // 2
+ height = self.super_res_first.sample_size
+ width = self.super_res_first.sample_size
+
+ super_res_latents = self.prepare_latents(
+ (batch_size, channels, height, width),
+ image_small.dtype,
+ device,
+ generator,
+ super_res_latents,
+ self.super_res_scheduler,
+ )
+
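+        # Only pass `antialias` when F.interpolate exposes the parameter (older PyTorch builds may not).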
+ interpolate_antialias = {}
+ if "antialias" in inspect.signature(F.interpolate).parameters:
+ interpolate_antialias["antialias"] = True
+
+ image_upscaled = F.interpolate(
+ image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
+ )
+
+ for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
+ # no classifier free guidance
+
+ if i == super_res_timesteps_tensor.shape[0] - 1:
+ unet = self.super_res_last
+ else:
+ unet = self.super_res_first
+
+ latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
+
+ noise_pred = unet(
+ sample=latent_model_input,
+ timestep=t,
+ ).sample
+
+ if i + 1 == super_res_timesteps_tensor.shape[0]:
+ prev_timestep = None
+ else:
+ prev_timestep = super_res_timesteps_tensor[i + 1]
+
+ # compute the previous noisy sample x_t -> x_t-1
+ super_res_latents = self.super_res_scheduler.step(
+ noise_pred, t, super_res_latents, prev_timestep=prev_timestep, generator=generator
+ ).prev_sample
+
+ image = super_res_latents
+ # done super res
+
+ # post processing
+
+ image = image * 0.5 + 0.5
+ image = image.clamp(0, 1)
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+
+ if output_type == "pil":
+ image = self.numpy_to_pil(image)
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
\ No newline at end of file
diff --git a/repositories/ldm/modules/karlo/kakao/__init__.py b/repositories/ldm/modules/karlo/kakao/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/karlo/kakao/models/__init__.py b/repositories/ldm/modules/karlo/kakao/models/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/karlo/kakao/models/clip.py b/repositories/ldm/modules/karlo/kakao/models/clip.py
new file mode 100644
index 000000000..961d81502
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/models/clip.py
@@ -0,0 +1,182 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+# ------------------------------------------------------------------------------------
+# Adapted from OpenAI's CLIP (https://github.com/openai/CLIP/)
+# ------------------------------------------------------------------------------------
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import clip
+
+from clip.model import CLIP, convert_weights
+from clip.simple_tokenizer import SimpleTokenizer, default_bpe
+
+
+"""===== Monkey-Patching original CLIP for JIT compile ====="""
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = F.layer_norm(
+ x.type(torch.float32),
+ self.normalized_shape,
+ self.weight,
+ self.bias,
+ self.eps,
+ )
+ return ret.type(orig_type)
+
+
+clip.model.LayerNorm = LayerNorm
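+# Remove the original CLIP.forward before scripting; CustomizedCLIP below defines its own
+# forward marked with @torch.jit.ignore.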
+delattr(clip.model.CLIP, "forward")
+
+"""===== End of Monkey-Patching ====="""
+
+
+class CustomizedCLIP(CLIP):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ @torch.jit.export
+ def encode_image(self, image):
+ return self.visual(image)
+
+ @torch.jit.export
+ def encode_text(self, text):
+ # re-define this function to return unpooled text features
+
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ x_seq = x
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x_out = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+ return x_out, x_seq
+
+ @torch.jit.ignore
+ def forward(self, image, text):
+ super().forward(image, text)
+
+ @classmethod
+ def load_from_checkpoint(cls, ckpt_path: str):
+ state_dict = torch.load(ckpt_path, map_location="cpu").state_dict()
+
+ vit = "visual.proj" in state_dict
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len(
+ [
+ k
+ for k in state_dict.keys()
+ if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
+ ]
+ )
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round(
+ (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
+ )
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [
+ len(
+ set(
+ k.split(".")[2]
+ for k in state_dict
+ if k.startswith(f"visual.layer{b}")
+ )
+ )
+ for b in [1, 2, 3, 4]
+ ]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round(
+ (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
+ )
+ vision_patch_size = None
+ assert (
+ output_width**2 + 1
+ == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ )
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(
+ set(
+ k.split(".")[2]
+ for k in state_dict
+ if k.startswith("transformer.resblocks")
+ )
+ )
+
+ model = cls(
+ embed_dim,
+ image_resolution,
+ vision_layers,
+ vision_width,
+ vision_patch_size,
+ context_length,
+ vocab_size,
+ transformer_width,
+ transformer_heads,
+ transformer_layers,
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ convert_weights(model)
+ model.load_state_dict(state_dict)
+ model.eval()
+ model.float()
+ return model
+
+
+class CustomizedTokenizer(SimpleTokenizer):
+ def __init__(self):
+ super().__init__(bpe_path=default_bpe())
+
+ self.sot_token = self.encoder["<|startoftext|>"]
+ self.eot_token = self.encoder["<|endoftext|>"]
+
+ def padded_tokens_and_mask(self, texts, text_ctx):
+ assert isinstance(texts, list) and all(
+ isinstance(elem, str) for elem in texts
+ ), "texts should be a list of strings"
+
+ all_tokens = [
+ [self.sot_token] + self.encode(text) + [self.eot_token] for text in texts
+ ]
+
+ mask = [
+ [True] * min(text_ctx, len(tokens))
+ + [False] * max(text_ctx - len(tokens), 0)
+ for tokens in all_tokens
+ ]
+ mask = torch.tensor(mask, dtype=torch.bool)
+ result = torch.zeros(len(all_tokens), text_ctx, dtype=torch.int)
+ for i, tokens in enumerate(all_tokens):
+ if len(tokens) > text_ctx:
+ tokens = tokens[:text_ctx]
+ tokens[-1] = self.eot_token
+ result[i, : len(tokens)] = torch.tensor(tokens)
+
+ return result, mask
diff --git a/repositories/ldm/modules/karlo/kakao/models/decoder_model.py b/repositories/ldm/modules/karlo/kakao/models/decoder_model.py
new file mode 100644
index 000000000..84e96c9b2
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/models/decoder_model.py
@@ -0,0 +1,193 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import copy
+import torch
+
+from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
+from ldm.modules.karlo.kakao.modules.unet import PLMImUNet
+
+
+class Text2ImProgressiveModel(torch.nn.Module):
+ """
+ A decoder that generates 64x64px images based on the text prompt.
+
+ :param config: yaml config to define the decoder.
+ :param tokenizer: tokenizer used in clip.
+ """
+
+ def __init__(
+ self,
+ config,
+ tokenizer,
+ ):
+ super().__init__()
+
+ self._conf = config
+ self._model_conf = config.model.hparams
+ self._diffusion_kwargs = dict(
+ steps=config.diffusion.steps,
+ learn_sigma=config.diffusion.learn_sigma,
+ sigma_small=config.diffusion.sigma_small,
+ noise_schedule=config.diffusion.noise_schedule,
+ use_kl=config.diffusion.use_kl,
+ predict_xstart=config.diffusion.predict_xstart,
+ rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
+ timestep_respacing=config.diffusion.timestep_respacing,
+ )
+ self._tokenizer = tokenizer
+
+ self.model = self.create_plm_dec_model()
+
+ cf_token, cf_mask = self.set_cf_text_tensor()
+ self.register_buffer("cf_token", cf_token, persistent=False)
+ self.register_buffer("cf_mask", cf_mask, persistent=False)
+
+ @classmethod
+ def load_from_checkpoint(cls, config, tokenizer, ckpt_path, strict: bool = True):
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+
+ model = cls(config, tokenizer)
+ model.load_state_dict(ckpt, strict=strict)
+ return model
+
+ def create_plm_dec_model(self):
+ image_size = self._model_conf.image_size
+ if self._model_conf.channel_mult == "":
+ if image_size == 256:
+ channel_mult = (1, 1, 2, 2, 4, 4)
+ elif image_size == 128:
+ channel_mult = (1, 1, 2, 3, 4)
+ elif image_size == 64:
+ channel_mult = (1, 2, 3, 4)
+ else:
+ raise ValueError(f"unsupported image size: {image_size}")
+ else:
+ channel_mult = tuple(
+ int(ch_mult) for ch_mult in self._model_conf.channel_mult.split(",")
+ )
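+        # Sanity check: with one multiplier per resolution level, the bottleneck feature map is
+        # image_size / 2 ** (len(channel_mult) - 1), which this assert pins to 8px.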
+ assert 2 ** (len(channel_mult) + 2) == image_size
+
+ attention_ds = []
+ for res in self._model_conf.attention_resolutions.split(","):
+ attention_ds.append(image_size // int(res))
+
+ return PLMImUNet(
+ text_ctx=self._model_conf.text_ctx,
+ xf_width=self._model_conf.xf_width,
+ in_channels=3,
+ model_channels=self._model_conf.num_channels,
+ out_channels=6 if self._model_conf.learn_sigma else 3,
+ num_res_blocks=self._model_conf.num_res_blocks,
+ attention_resolutions=tuple(attention_ds),
+ dropout=self._model_conf.dropout,
+ channel_mult=channel_mult,
+ num_heads=self._model_conf.num_heads,
+ num_head_channels=self._model_conf.num_head_channels,
+ num_heads_upsample=self._model_conf.num_heads_upsample,
+ use_scale_shift_norm=self._model_conf.use_scale_shift_norm,
+ resblock_updown=self._model_conf.resblock_updown,
+ clip_dim=self._model_conf.clip_dim,
+ clip_emb_mult=self._model_conf.clip_emb_mult,
+ clip_emb_type=self._model_conf.clip_emb_type,
+ clip_emb_drop=self._model_conf.clip_emb_drop,
+ )
+
+ def set_cf_text_tensor(self):
+ return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
+
+ def get_sample_fn(self, timestep_respacing):
+ use_ddim = timestep_respacing.startswith(("ddim", "fast"))
+
+ diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
+ diffusion_kwargs.update(timestep_respacing=timestep_respacing)
+ diffusion = create_gaussian_diffusion(**diffusion_kwargs)
+ sample_fn = (
+ diffusion.ddim_sample_loop_progressive
+ if use_ddim
+ else diffusion.p_sample_loop_progressive
+ )
+
+ return sample_fn
+
+ def forward(
+ self,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat=None,
+ cf_guidance_scales=None,
+ timestep_respacing=None,
+ ):
+        # classifier-free guidance (cfg) must be enabled at inference time
+ assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
+ assert img_feat is not None
+
+ bsz = txt_feat.shape[0]
+ img_sz = self._model_conf.image_size
+
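+        # Classifier-free guidance: the batch holds conditional and unconditional halves of the same
+        # x_t; the noise prediction is split, recombined with the guidance scale, then duplicated
+        # again so the batch shape stays consistent for the sampler.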
+ def guided_model_fn(x_t, ts, **kwargs):
+ half = x_t[: len(x_t) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.model(combined, ts, **kwargs)
+ eps, rest = model_out[:, :3], model_out[:, 3:]
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cf_guidance_scales.view(-1, 1, 1, 1) * (
+ cond_eps - uncond_eps
+ )
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ cf_feat = self.model.cf_param.unsqueeze(0)
+ cf_feat = cf_feat.expand(bsz // 2, -1)
+ feat = torch.cat([img_feat, cf_feat.to(txt_feat.device)], dim=0)
+
+ cond = {
+ "y": feat,
+ "txt_feat": txt_feat,
+ "txt_feat_seq": txt_feat_seq,
+ "mask": mask,
+ }
+ sample_fn = self.get_sample_fn(timestep_respacing)
+ sample_outputs = sample_fn(
+ guided_model_fn,
+ (bsz, 3, img_sz, img_sz),
+ noise=None,
+ device=txt_feat.device,
+ clip_denoised=True,
+ model_kwargs=cond,
+ )
+
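+        # The sampler yields the doubled CFG batch; only the conditional half is yielded to the caller.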
+ for out in sample_outputs:
+ sample = out["sample"]
+ yield sample if cf_guidance_scales is None else sample[
+ : sample.shape[0] // 2
+ ]
+
+
+class Text2ImModel(Text2ImProgressiveModel):
+ def forward(
+ self,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat=None,
+ cf_guidance_scales=None,
+ timestep_respacing=None,
+ ):
+ last_out = None
+ for out in super().forward(
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat,
+ cf_guidance_scales,
+ timestep_respacing,
+ ):
+ last_out = out
+ return last_out
diff --git a/repositories/ldm/modules/karlo/kakao/models/prior_model.py b/repositories/ldm/modules/karlo/kakao/models/prior_model.py
new file mode 100644
index 000000000..03ef230d2
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/models/prior_model.py
@@ -0,0 +1,138 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import copy
+import torch
+
+from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
+from ldm.modules.karlo.kakao.modules.xf import PriorTransformer
+
+
+class PriorDiffusionModel(torch.nn.Module):
+ """
+    A prior that generates a CLIP image feature from the text prompt.
+
+    :param config: yaml config to define the prior.
+    :param tokenizer: tokenizer used in CLIP.
+    :param clip_mean: mean used to normalize the CLIP image feature (zero-mean, unit variance).
+    :param clip_std: std used to normalize the CLIP image feature (zero-mean, unit variance).
+ """
+
+ def __init__(self, config, tokenizer, clip_mean, clip_std):
+ super().__init__()
+
+ self._conf = config
+ self._model_conf = config.model.hparams
+ self._diffusion_kwargs = dict(
+ steps=config.diffusion.steps,
+ learn_sigma=config.diffusion.learn_sigma,
+ sigma_small=config.diffusion.sigma_small,
+ noise_schedule=config.diffusion.noise_schedule,
+ use_kl=config.diffusion.use_kl,
+ predict_xstart=config.diffusion.predict_xstart,
+ rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
+ timestep_respacing=config.diffusion.timestep_respacing,
+ )
+ self._tokenizer = tokenizer
+
+ self.register_buffer("clip_mean", clip_mean[None, :], persistent=False)
+ self.register_buffer("clip_std", clip_std[None, :], persistent=False)
+
+ causal_mask = self.get_causal_mask()
+ self.register_buffer("causal_mask", causal_mask, persistent=False)
+
+ self.model = PriorTransformer(
+ text_ctx=self._model_conf.text_ctx,
+ xf_width=self._model_conf.xf_width,
+ xf_layers=self._model_conf.xf_layers,
+ xf_heads=self._model_conf.xf_heads,
+ xf_final_ln=self._model_conf.xf_final_ln,
+ clip_dim=self._model_conf.clip_dim,
+ )
+
+ cf_token, cf_mask = self.set_cf_text_tensor()
+ self.register_buffer("cf_token", cf_token, persistent=False)
+ self.register_buffer("cf_mask", cf_mask, persistent=False)
+
+ @classmethod
+ def load_from_checkpoint(
+ cls, config, tokenizer, clip_mean, clip_std, ckpt_path, strict: bool = True
+ ):
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+
+ model = cls(config, tokenizer, clip_mean, clip_std)
+ model.load_state_dict(ckpt, strict=strict)
+ return model
+
+ def set_cf_text_tensor(self):
+ return self._tokenizer.padded_tokens_and_mask([""], self.model.text_ctx)
+
+ def get_sample_fn(self, timestep_respacing):
+ use_ddim = timestep_respacing.startswith(("ddim", "fast"))
+
+ diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
+ diffusion_kwargs.update(timestep_respacing=timestep_respacing)
+ diffusion = create_gaussian_diffusion(**diffusion_kwargs)
+ sample_fn = diffusion.ddim_sample_loop if use_ddim else diffusion.p_sample_loop
+
+ return sample_fn
+
+ def get_causal_mask(self):
+ seq_len = self._model_conf.text_ctx + 4
+ mask = torch.empty(seq_len, seq_len)
+ mask.fill_(float("-inf"))
+ mask.triu_(1)
+ mask = mask[None, ...]
+ return mask
+
+ def forward(
+ self,
+ txt_feat,
+ txt_feat_seq,
+ mask,
+ cf_guidance_scales=None,
+ timestep_respacing=None,
+ denoised_fn=True,
+ ):
+        # classifier-free guidance (cfg) must be enabled at inference time
+ assert cf_guidance_scales is not None and all(cf_guidance_scales > 0.0)
+
+ bsz_ = txt_feat.shape[0]
+ bsz = bsz_ // 2
+
+ def guided_model_fn(x_t, ts, **kwargs):
+ half = x_t[: len(x_t) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.model(combined, ts, **kwargs)
+ eps, rest = (
+ model_out[:, : int(x_t.shape[1])],
+ model_out[:, int(x_t.shape[1]) :],
+ )
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cf_guidance_scales.view(-1, 1) * (
+ cond_eps - uncond_eps
+ )
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ cond = {
+ "text_emb": txt_feat,
+ "text_enc": txt_feat_seq,
+ "mask": mask,
+ "causal_mask": self.causal_mask,
+ }
+ sample_fn = self.get_sample_fn(timestep_respacing)
+ sample = sample_fn(
+ guided_model_fn,
+ (bsz_, self.model.clip_dim),
+ noise=None,
+ device=txt_feat.device,
+ clip_denoised=False,
+ denoised_fn=lambda x: torch.clamp(x, -10, 10),
+ model_kwargs=cond,
+ )
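+        # The prior samples in a normalized CLIP-embedding space: map back with the stored mean/std
+        # and return only the conditional half of the doubled CFG batch.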
+ sample = (sample * self.clip_std) + self.clip_mean
+
+ return sample[:bsz]
diff --git a/repositories/ldm/modules/karlo/kakao/models/sr_256_1k.py b/repositories/ldm/modules/karlo/kakao/models/sr_256_1k.py
new file mode 100644
index 000000000..1e874f6f1
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/models/sr_256_1k.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+from ldm.modules.karlo.kakao.models.sr_64_256 import SupRes64to256Progressive
+
+
+class SupRes256to1kProgressive(SupRes64to256Progressive):
+ pass # no difference currently
diff --git a/repositories/ldm/modules/karlo/kakao/models/sr_64_256.py b/repositories/ldm/modules/karlo/kakao/models/sr_64_256.py
new file mode 100644
index 000000000..32687afe3
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/models/sr_64_256.py
@@ -0,0 +1,88 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import copy
+import torch
+
+from ldm.modules.karlo.kakao.modules.unet import SuperResUNetModel
+from ldm.modules.karlo.kakao.modules import create_gaussian_diffusion
+
+
+class ImprovedSupRes64to256ProgressiveModel(torch.nn.Module):
+ """
+    The ImprovedSR model fine-tunes the pretrained DDPM-based SR model using adversarial and perceptual losses.
+    Specifically, the low-resolution sample is iteratively recovered over 6 steps with the frozen pretrained SR model.
+    In one additional final step, a separate fine-tuned model recovers the high-frequency details.
+    This approach greatly improves the fidelity of 256x256px images, even with a small number of reverse steps.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self._config = config
+ self._diffusion_kwargs = dict(
+ steps=config.diffusion.steps,
+ learn_sigma=config.diffusion.learn_sigma,
+ sigma_small=config.diffusion.sigma_small,
+ noise_schedule=config.diffusion.noise_schedule,
+ use_kl=config.diffusion.use_kl,
+ predict_xstart=config.diffusion.predict_xstart,
+ rescale_learned_sigmas=config.diffusion.rescale_learned_sigmas,
+ )
+
+ self.model_first_steps = SuperResUNetModel(
+ in_channels=3, # auto-changed to 6 inside the model
+ model_channels=config.model.hparams.channels,
+ out_channels=3,
+ num_res_blocks=config.model.hparams.depth,
+ attention_resolutions=(), # no attention
+ dropout=config.model.hparams.dropout,
+ channel_mult=config.model.hparams.channels_multiple,
+ resblock_updown=True,
+ use_middle_attention=False,
+ )
+ self.model_last_step = SuperResUNetModel(
+ in_channels=3, # auto-changed to 6 inside the model
+ model_channels=config.model.hparams.channels,
+ out_channels=3,
+ num_res_blocks=config.model.hparams.depth,
+ attention_resolutions=(), # no attention
+ dropout=config.model.hparams.dropout,
+ channel_mult=config.model.hparams.channels_multiple,
+ resblock_updown=True,
+ use_middle_attention=False,
+ )
+
+ @classmethod
+ def load_from_checkpoint(cls, config, ckpt_path, strict: bool = True):
+ ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+
+ model = cls(config)
+ model.load_state_dict(ckpt, strict=strict)
+ return model
+
+ def get_sample_fn(self, timestep_respacing):
+ diffusion_kwargs = copy.deepcopy(self._diffusion_kwargs)
+ diffusion_kwargs.update(timestep_respacing=timestep_respacing)
+ diffusion = create_gaussian_diffusion(**diffusion_kwargs)
+ return diffusion.p_sample_loop_progressive_for_improved_sr
+
+ def forward(self, low_res, timestep_respacing="7", **kwargs):
+ assert (
+ timestep_respacing == "7"
+ ), "different respacing method may work, but no guaranteed"
+
+ sample_fn = self.get_sample_fn(timestep_respacing)
+ sample_outputs = sample_fn(
+ self.model_first_steps,
+ self.model_last_step,
+ shape=low_res.shape,
+ clip_denoised=True,
+ model_kwargs=dict(low_res=low_res),
+ **kwargs,
+ )
+ for x in sample_outputs:
+ sample = x["sample"]
+ yield sample
diff --git a/repositories/ldm/modules/karlo/kakao/modules/__init__.py b/repositories/ldm/modules/karlo/kakao/modules/__init__.py
new file mode 100644
index 000000000..11d4358a6
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/modules/__init__.py
@@ -0,0 +1,49 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+
+from .diffusion import gaussian_diffusion as gd
+from .diffusion.respace import (
+ SpacedDiffusion,
+ space_timesteps,
+)
+
+
+def create_gaussian_diffusion(
+ steps,
+ learn_sigma,
+ sigma_small,
+ noise_schedule,
+ use_kl,
+ predict_xstart,
+ rescale_learned_sigmas,
+ timestep_respacing,
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if not timestep_respacing:
+ timestep_respacing = [steps]
+
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type,
+ )
diff --git a/repositories/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py b/repositories/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py
new file mode 100644
index 000000000..6a111aa09
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/modules/diffusion/gaussian_diffusion.py
@@ -0,0 +1,828 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+import enum
+import math
+
+import numpy as np
+import torch as th
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(
+ beta_start, beta_end, warmup_time, dtype=np.float64
+ )
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start**0.5,
+ beta_end**0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
+ )
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+class GaussianDiffusion(th.nn.Module):
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type,
+ ):
+ super(GaussianDiffusion, self).__init__()
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])
+ alphas_cumprod_next = np.append(alphas_cumprod[1:], 0.0)
+ assert alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
+ sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - alphas_cumprod)
+ log_one_minus_alphas_cumprod = np.log(1.0 - alphas_cumprod)
+ sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod)
+ sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = (
+ betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ posterior_log_variance_clipped = np.log(
+ np.append(posterior_variance[1], posterior_variance[1:])
+ )
+ posterior_mean_coef1 = (
+ betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+ )
+ posterior_mean_coef2 = (
+ (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
+ )
+
+ self.register_buffer("betas", th.from_numpy(betas), persistent=False)
+ self.register_buffer(
+ "alphas_cumprod", th.from_numpy(alphas_cumprod), persistent=False
+ )
+ self.register_buffer(
+ "alphas_cumprod_prev", th.from_numpy(alphas_cumprod_prev), persistent=False
+ )
+ self.register_buffer(
+ "alphas_cumprod_next", th.from_numpy(alphas_cumprod_next), persistent=False
+ )
+
+ self.register_buffer(
+ "sqrt_alphas_cumprod", th.from_numpy(sqrt_alphas_cumprod), persistent=False
+ )
+ self.register_buffer(
+ "sqrt_one_minus_alphas_cumprod",
+ th.from_numpy(sqrt_one_minus_alphas_cumprod),
+ persistent=False,
+ )
+ self.register_buffer(
+ "log_one_minus_alphas_cumprod",
+ th.from_numpy(log_one_minus_alphas_cumprod),
+ persistent=False,
+ )
+ self.register_buffer(
+ "sqrt_recip_alphas_cumprod",
+ th.from_numpy(sqrt_recip_alphas_cumprod),
+ persistent=False,
+ )
+ self.register_buffer(
+ "sqrt_recipm1_alphas_cumprod",
+ th.from_numpy(sqrt_recipm1_alphas_cumprod),
+ persistent=False,
+ )
+
+ self.register_buffer(
+ "posterior_variance", th.from_numpy(posterior_variance), persistent=False
+ )
+ self.register_buffer(
+ "posterior_log_variance_clipped",
+ th.from_numpy(posterior_log_variance_clipped),
+ persistent=False,
+ )
+ self.register_buffer(
+ "posterior_mean_coef1",
+ th.from_numpy(posterior_mean_coef1),
+ persistent=False,
+ )
+ self.register_buffer(
+ "posterior_mean_coef2",
+ th.from_numpy(posterior_mean_coef2),
+ persistent=False,
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ )
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
+ )
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
+ * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ model_kwargs=None,
+ **ignore_kwargs,
+ ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, C = x.shape[:2]
+ assert t.shape == (B,)
+ model_output = model(x, t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
+ model_output, model_var_values = th.split(model_output, C, dim=1)
+ if self.model_var_type == ModelVarType.LEARNED:
+ model_log_variance = model_var_values
+ model_variance = th.exp(model_log_variance)
+ else:
+ min_log = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x.shape
+ )
+ max_log = _extract_into_tensor(th.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ th.cat([self.posterior_variance[1][None], self.betas[1:]]),
+ th.log(th.cat([self.posterior_variance[1][None], self.betas[1:]])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
+ )
+ model_mean = model_output
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]:
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(
+ x_start=pred_xstart, x_t=x, t=t
+ )
+ else:
+ raise NotImplementedError(self.model_mean_type)
+
+ assert (
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ )
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = (
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ )
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(
+ x_start=out["pred_xstart"], x_t=x, t=t
+ )
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(
+ cond_fn, out, x, t, model_kwargs=model_kwargs
+ )
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for idx, i in enumerate(indices):
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def p_sample_loop_progressive_for_improved_sr(
+ self,
+ model,
+ model_aux,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+        Modified version of p_sample_loop_progressive for sampling from the improved SR model.
+ """
+
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for idx, i in enumerate(indices):
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model_aux if len(indices) - 1 == idx else model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
+ + th.sqrt(1 - alpha_bar_next) * eps
+ )
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = arr.to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
diff --git a/repositories/ldm/modules/karlo/kakao/modules/diffusion/respace.py b/repositories/ldm/modules/karlo/kakao/modules/diffusion/respace.py
new file mode 100644
index 000000000..70c808f8b
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/modules/diffusion/respace.py
@@ -0,0 +1,112 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+
+    For example, if there are 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ elif section_counts == "fast27":
+ steps = space_timesteps(num_timesteps, "10,10,3,2,2")
+ # Help reduce DDIM artifacts from noisiest timesteps.
+ steps.remove(num_timesteps - 1)
+ steps.add(num_timesteps - 3)
+ return steps
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
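+        # Spread section_count indices evenly across this section, endpoints included
+        # (fractional stride, rounded to the nearest integer step).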
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ timestep_map = []
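+        # Recompute betas over the retained timesteps so the cumulative alpha product across the
+        # skipped steps matches the original schedule.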
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ timestep_map.append(i)
+ kwargs["betas"] = th.tensor(new_betas).numpy()
+ super().__init__(**kwargs)
+ self.register_buffer("timestep_map", th.tensor(timestep_map), persistent=False)
+
+ def p_mean_variance(self, model, *args, **kwargs):
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ def wrapped(x, ts, **kwargs):
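+            # Map the respaced timestep index back to the original diffusion timestep expected by the model.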
+ ts_cpu = ts.detach().to("cpu")
+ return model(
+ x, self.timestep_map[ts_cpu].to(device=ts.device, dtype=ts.dtype), **kwargs
+ )
+
+ return wrapped
diff --git a/repositories/ldm/modules/karlo/kakao/modules/nn.py b/repositories/ldm/modules/karlo/kakao/modules/nn.py
new file mode 100644
index 000000000..2eef3f5a0
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/modules/nn.py
@@ -0,0 +1,114 @@
+# ------------------------------------------------------------------------------------
+# Adapted from Guided-Diffusion repo (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+import math
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GroupNorm32(nn.GroupNorm):
+ def __init__(self, num_groups, num_channels, swish, eps=1e-5):
+ super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps)
+ self.swish = swish
+
+ def forward(self, x):
+ y = super().forward(x.float()).to(x.dtype)
+ if self.swish == 1.0:
+ y = F.silu(y)
+ elif self.swish:
+            y = y * th.sigmoid(y * float(self.swish))
+ return y
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def normalization(channels, swish=0.0):
+ """
+ Make a standard normalization layer, with an optional swish activation.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(num_channels=channels, num_groups=32, swish=swish)
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = th.exp(
+ -math.log(max_period)
+ * th.arange(start=0, end=half, dtype=th.float32, device=timesteps.device)
+ / half
+ )
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
diff --git a/repositories/ldm/modules/karlo/kakao/modules/resample.py b/repositories/ldm/modules/karlo/kakao/modules/resample.py
new file mode 100644
index 000000000..485421aa4
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/modules/resample.py
@@ -0,0 +1,68 @@
+# ------------------------------------------------------------------------------------
+# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+from abc import abstractmethod
+
+import torch as th
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(th.nn.Module):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / th.sum(w)
+ indices = p.multinomial(batch_size, replacement=True)
+ weights = 1 / (len(p) * p[indices])
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ super(UniformSampler, self).__init__()
+ self.diffusion = diffusion
+ self.register_buffer(
+ "_weights", th.ones([diffusion.num_timesteps]), persistent=False
+ )
+
+ def weights(self):
+ return self._weights
diff --git a/repositories/ldm/modules/karlo/kakao/modules/unet.py b/repositories/ldm/modules/karlo/kakao/modules/unet.py
new file mode 100644
index 000000000..c99d0b791
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/modules/unet.py
@@ -0,0 +1,792 @@
+# ------------------------------------------------------------------------------------
+# Modified from Guided-Diffusion (https://github.com/openai/guided-diffusion)
+# ------------------------------------------------------------------------------------
+
+import math
+from abc import abstractmethod
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nn import (
+ avg_pool_nd,
+ conv_nd,
+ linear,
+ normalization,
+ timestep_embedding,
+ zero_module,
+)
+from .xf import LayerNorm
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, encoder_out=None, mask=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ elif isinstance(layer, AttentionBlock):
+ x = layer(x, encoder_out, mask=mask)
+ else:
+ x = layer(x)
+ return x
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
+ )
+ else:
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self, channels, use_conv, dims=2, out_channels=None):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(
+ dims, self.channels, self.out_channels, 3, stride=stride, padding=1
+ )
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+
+ self.in_layers = nn.Sequential(
+ normalization(channels, swish=1.0),
+ nn.Identity(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ linear(
+ emb_channels,
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(
+ self.out_channels, swish=0.0 if use_scale_shift_norm else 1.0
+ ),
+ nn.SiLU() if use_scale_shift_norm else nn.Identity(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
+ if self.use_scale_shift_norm:
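+            # FiLM-style conditioning: the embedding predicts a per-channel scale
+            # and shift that are applied right after normalization.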
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = th.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class ResBlockNoTimeEmbedding(nn.Module):
+ """
+ A residual block without time embedding
+
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ """
+
+ def __init__(
+ self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_conv=False,
+ dims=2,
+ use_checkpoint=False,
+ up=False,
+ down=False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+
+ self.in_layers = nn.Sequential(
+ normalization(channels, swish=1.0),
+ nn.Identity(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels, swish=1.0),
+ nn.Dropout(p=dropout),
+ zero_module(
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
+ ),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(
+ dims, channels, self.out_channels, 3, padding=1
+ )
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+
+ def forward(self, x, emb=None):
+ """
+ Apply the block to a Tensor, NOT conditioned on a timestep embedding.
+
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert emb is None
+
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ h = self.out_layers(h)
+ return self.skip_connection(x) + h
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ use_checkpoint=False,
+ encoder_channels=None,
+ ):
+ super().__init__()
+ self.channels = channels
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert (
+ channels % num_head_channels == 0
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ self.num_heads = channels // num_head_channels
+ self.use_checkpoint = use_checkpoint
+ self.norm = normalization(channels, swish=0.0)
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
+ self.attention = QKVAttention(self.num_heads)
+
+ if encoder_channels is not None:
+ self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1)
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
+
+ def forward(self, x, encoder_out=None, mask=None):
+ b, c, *spatial = x.shape
+ qkv = self.qkv(self.norm(x).view(b, c, -1))
+ if encoder_out is not None:
+ encoder_out = self.encoder_kv(encoder_out)
+ h = self.attention(qkv, encoder_out, mask=mask)
+ else:
+ h = self.attention(qkv)
+ h = self.proj_out(h)
+ return x + h.reshape(b, c, *spatial)
+
+
+class QKVAttention(nn.Module):
+ """
+    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping.
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv, encoder_kv=None, mask=None):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ if encoder_kv is not None:
+ assert encoder_kv.shape[1] == self.n_heads * ch * 2
+ ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1)
+ k = th.cat([ek, k], dim=-1)
+ v = th.cat([ev, v], dim=-1)
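+        # Scaling q and k each by 1/ch**0.25 divides the attention logits by sqrt(ch);
+        # doing it before the einsum is more stable in reduced precision.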
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = th.einsum("bct,bcs->bts", q * scale, k * scale)
+ if mask is not None:
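+            # The incoming mask only covers the encoder tokens; pad it with zeros over
+            # the `length` self-attention positions so spatial tokens stay fully visible.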
+ mask = F.pad(mask, (0, length), value=0.0)
+ mask = (
+ mask.unsqueeze(1)
+ .expand(-1, self.n_heads, -1)
+ .reshape(bs * self.n_heads, 1, -1)
+ )
+ weight = weight + mask
+ weight = th.softmax(weight, dim=-1)
+ a = th.einsum("bts,bcs->bct", weight, v)
+ return a.reshape(bs, -1, length)
+
+
+class UNetModel(nn.Module):
+ """
+ The full UNet model with attention and timestep embedding.
+
+ :param in_channels: channels in the input Tensor.
+ :param model_channels: base channel count for the model.
+ :param out_channels: channels in the output Tensor.
+ :param num_res_blocks: number of residual blocks per downsample.
+ :param attention_resolutions: a collection of downsample rates at which
+ attention will take place. May be a set, list, or tuple.
+ For example, if this contains 4, then at 4x downsampling, attention
+ will be used.
+ :param dropout: the dropout probability.
+ :param channel_mult: channel multiplier for each level of the UNet.
+ :param conv_resample: if True, use learned convolutions for upsampling and
+ downsampling.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param clip_dim: dimension of clip feature.
+ :param num_classes: if specified (as an int), then this model will be
+ class-conditional with `num_classes` classes.
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
+ :param num_heads: the number of attention heads in each attention layer.
+    :param num_head_channels: if specified, ignore num_heads and instead use
+ a fixed channel width per attention head.
+ :param num_heads_upsample: works with num_heads to set a different number
+ of heads for upsampling. Deprecated.
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
+ :param resblock_updown: use residual blocks for up/downsampling.
+    :param encoder_channels: used to project encoder key/values to the query dimension in AttentionBlock.
+    :param use_time_embedding: whether to condition on a timestep embedding.
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ model_channels,
+ out_channels,
+ num_res_blocks,
+ attention_resolutions,
+ dropout=0,
+ channel_mult=(1, 2, 4, 8),
+ conv_resample=True,
+ dims=2,
+ clip_dim=None,
+ use_checkpoint=False,
+ num_heads=1,
+ num_head_channels=-1,
+ num_heads_upsample=-1,
+ use_scale_shift_norm=False,
+ use_middle_attention=True,
+ resblock_updown=False,
+ encoder_channels=None,
+ use_time_embedding=True,
+ ):
+ super().__init__()
+
+ if num_heads_upsample == -1:
+ num_heads_upsample = num_heads
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.clip_dim = clip_dim
+ self.use_checkpoint = use_checkpoint
+ self.num_heads = num_heads
+ self.num_head_channels = num_head_channels
+ self.num_heads_upsample = num_heads_upsample
+ self.use_middle_attention = use_middle_attention
+ self.use_time_embedding = use_time_embedding
+
+ if self.use_time_embedding:
+ time_embed_dim = model_channels * 4
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+
+ if self.clip_dim is not None:
+ self.clip_emb = nn.Linear(clip_dim, time_embed_dim)
+ else:
+ time_embed_dim = None
+
+ CustomResidualBlock = (
+ ResBlock if self.use_time_embedding else ResBlockNoTimeEmbedding
+ )
+ ch = input_ch = int(channel_mult[0] * model_channels)
+ self.input_blocks = nn.ModuleList(
+ [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1))]
+ )
+ self._feature_size = ch
+ input_block_chans = [ch]
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=int(mult * model_channels),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(mult * model_channels)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ )
+ )
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True,
+ )
+ if resblock_updown
+ else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch
+ )
+ )
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+ self._feature_size += ch
+
+ self.middle_block = TimestepEmbedSequential(
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ *(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ ),
+ )
+ if self.use_middle_attention
+ else tuple(), # add AttentionBlock or not
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ ),
+ )
+ self._feature_size += ch
+
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ CustomResidualBlock(
+ ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=int(model_channels * mult),
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ )
+ ]
+ ch = int(model_channels * mult)
+ if ds in attention_resolutions:
+ layers.append(
+ AttentionBlock(
+ ch,
+ use_checkpoint=use_checkpoint,
+ num_heads=num_heads_upsample,
+ num_head_channels=num_head_channels,
+ encoder_channels=encoder_channels,
+ )
+ )
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ CustomResidualBlock(
+ ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True,
+ )
+ if resblock_updown
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+ )
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+ self._feature_size += ch
+
+ self.out = nn.Sequential(
+ normalization(ch, swish=1.0),
+ nn.Identity(),
+ zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)),
+ )
+
+ def forward(self, x, timesteps, y=None):
+ """
+ Apply the model to an input batch.
+
+ :param x: an [N x C x ...] Tensor of inputs.
+ :param timesteps: a 1-D batch of timesteps.
+ :param y: an [N] Tensor of labels, if class-conditional.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ assert (y is not None) == (
+ self.clip_dim is not None
+ ), "must specify y if and only if the model is clip-rep-conditional"
+
+ hs = []
+ if self.use_time_embedding:
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+ if self.clip_dim is not None:
+ emb = emb + self.clip_emb(y)
+ else:
+ emb = None
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb)
+ hs.append(h)
+ h = self.middle_block(h, emb)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb)
+
+ return self.out(h)
+
+
+class SuperResUNetModel(UNetModel):
+ """
+ A UNetModel that performs super-resolution.
+
+ Expects an extra kwarg `low_res` to condition on a low-resolution image.
+    Assumes that the low-resolution conditioning image has the same spatial shape as the input.
+ """
+
+ def __init__(self, *args, **kwargs):
+ if "in_channels" in kwargs:
+ kwargs = dict(kwargs)
+ kwargs["in_channels"] = kwargs["in_channels"] * 2
+ else:
+ # Curse you, Python. Or really, just curse positional arguments :|.
+ args = list(args)
+ args[1] = args[1] * 2
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x, timesteps, low_res=None, **kwargs):
+ _, _, new_height, new_width = x.shape
+ assert new_height == low_res.shape[2] and new_width == low_res.shape[3]
+
+ x = th.cat([x, low_res], dim=1)
+ return super().forward(x, timesteps, **kwargs)
+
+
+class PLMImUNet(UNetModel):
+ """
+ A UNetModel that conditions on text with a pretrained text encoder in CLIP.
+
+ :param text_ctx: number of text tokens to expect.
+ :param xf_width: width of the transformer.
+    :param clip_emb_mult: number of extra tokens obtained by projecting the CLIP embedding.
+    :param clip_emb_type: type of condition (here, fixed to the CLIP image feature).
+    :param clip_emb_drop: dropout ratio of the CLIP image feature for classifier-free guidance.
+ """
+
+ def __init__(
+ self,
+ text_ctx,
+ xf_width,
+ *args,
+ clip_emb_mult=None,
+ clip_emb_type="image",
+ clip_emb_drop=0.0,
+ **kwargs,
+ ):
+ self.text_ctx = text_ctx
+ self.xf_width = xf_width
+ self.clip_emb_mult = clip_emb_mult
+ self.clip_emb_type = clip_emb_type
+ self.clip_emb_drop = clip_emb_drop
+
+ if not xf_width:
+ super().__init__(*args, **kwargs, encoder_channels=None)
+ else:
+ super().__init__(*args, **kwargs, encoder_channels=xf_width)
+
+ # Project text encoded feat seq from pre-trained text encoder in CLIP
+ self.text_seq_proj = nn.Sequential(
+ nn.Linear(self.clip_dim, xf_width),
+ LayerNorm(xf_width),
+ )
+ # Project CLIP text feat
+ self.text_feat_proj = nn.Linear(self.clip_dim, self.model_channels * 4)
+
+ assert clip_emb_mult is not None
+ assert clip_emb_type == "image"
+ assert self.clip_dim is not None, "CLIP representation dim should be specified"
+
+ self.clip_tok_proj = nn.Linear(
+ self.clip_dim, self.xf_width * self.clip_emb_mult
+ )
+ if self.clip_emb_drop > 0:
+ self.cf_param = nn.Parameter(th.empty(self.clip_dim, dtype=th.float32))
+
+ def proc_clip_emb_drop(self, feat):
+ if self.clip_emb_drop > 0:
+ bsz, feat_dim = feat.shape
+ assert (
+ feat_dim == self.clip_dim
+ ), f"CLIP input dim: {feat_dim}, model CLIP dim: {self.clip_dim}"
+ drop_idx = th.rand((bsz,), device=feat.device) < self.clip_emb_drop
+ feat = th.where(
+ drop_idx[..., None], self.cf_param[None].type_as(feat), feat
+ )
+ return feat
+
+ def forward(
+ self, x, timesteps, txt_feat=None, txt_feat_seq=None, mask=None, y=None
+ ):
+ bsz = x.shape[0]
+ hs = []
+ emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
+ emb = emb + self.clip_emb(y)
+
+ xf_out = self.text_seq_proj(txt_feat_seq)
+ xf_out = xf_out.permute(0, 2, 1)
+ emb = emb + self.text_feat_proj(txt_feat)
+ xf_out = th.cat(
+ [
+ self.clip_tok_proj(y).reshape(bsz, -1, self.clip_emb_mult),
+ xf_out,
+ ],
+ dim=2,
+ )
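+        # Prepend `clip_emb_mult` tokens projected from the CLIP embedding `y`; the mask
+        # below is padded with True on that side so these tokens are always attended to,
+        # then converted to an additive (0 / -inf) form.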
+ mask = F.pad(mask, (self.clip_emb_mult, 0), value=True)
+ mask = th.where(mask, 0.0, float("-inf"))
+
+ h = x
+ for module in self.input_blocks:
+ h = module(h, emb, xf_out, mask=mask)
+ hs.append(h)
+ h = self.middle_block(h, emb, xf_out, mask=mask)
+ for module in self.output_blocks:
+ h = th.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, xf_out, mask=mask)
+ h = self.out(h)
+
+ return h
diff --git a/repositories/ldm/modules/karlo/kakao/modules/xf.py b/repositories/ldm/modules/karlo/kakao/modules/xf.py
new file mode 100644
index 000000000..66d7d4a2f
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/modules/xf.py
@@ -0,0 +1,231 @@
+# ------------------------------------------------------------------------------------
+# Adapted from the repos below:
+# (a) Guided-Diffusion (https://github.com/openai/guided-diffusion)
+# (b) CLIP ViT (https://github.com/openai/CLIP/)
+# ------------------------------------------------------------------------------------
+
+import math
+
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nn import timestep_embedding
+
+
+def convert_module_to_f16(param):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
+ param.weight.data = param.weight.data.half()
+ if param.bias is not None:
+ param.bias.data = param.bias.data.half()
+
+
+class LayerNorm(nn.LayerNorm):
+ """
+ Implementation that supports fp16 inputs but fp32 gains/biases.
+ """
+
+ def forward(self, x: th.Tensor):
+ return super().forward(x.float()).to(x.dtype)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, n_ctx, width, heads):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.heads = heads
+ self.c_qkv = nn.Linear(width, width * 3)
+ self.c_proj = nn.Linear(width, width)
+ self.attention = QKVMultiheadAttention(heads, n_ctx)
+
+ def forward(self, x, mask=None):
+ x = self.c_qkv(x)
+ x = self.attention(x, mask=mask)
+ x = self.c_proj(x)
+ return x
+
+
+class MLP(nn.Module):
+ def __init__(self, width):
+ super().__init__()
+ self.width = width
+ self.c_fc = nn.Linear(width, width * 4)
+ self.c_proj = nn.Linear(width * 4, width)
+ self.gelu = nn.GELU()
+
+ def forward(self, x):
+ return self.c_proj(self.gelu(self.c_fc(x)))
+
+
+class QKVMultiheadAttention(nn.Module):
+ def __init__(self, n_heads: int, n_ctx: int):
+ super().__init__()
+ self.n_heads = n_heads
+ self.n_ctx = n_ctx
+
+ def forward(self, qkv, mask=None):
+ bs, n_ctx, width = qkv.shape
+ attn_ch = width // self.n_heads // 3
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
+ qkv = qkv.view(bs, n_ctx, self.n_heads, -1)
+ q, k, v = th.split(qkv, attn_ch, dim=-1)
+ weight = th.einsum("bthc,bshc->bhts", q * scale, k * scale)
+ wdtype = weight.dtype
+ if mask is not None:
+ weight = weight + mask[:, None, ...]
+ weight = th.softmax(weight, dim=-1).type(wdtype)
+ return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ n_ctx: int,
+ width: int,
+ heads: int,
+ ):
+ super().__init__()
+
+ self.attn = MultiheadAttention(
+ n_ctx,
+ width,
+ heads,
+ )
+ self.ln_1 = LayerNorm(width)
+ self.mlp = MLP(width)
+ self.ln_2 = LayerNorm(width)
+
+ def forward(self, x, mask=None):
+ x = x + self.attn(self.ln_1(x), mask=mask)
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ n_ctx: int,
+ width: int,
+ layers: int,
+ heads: int,
+ ):
+ super().__init__()
+ self.n_ctx = n_ctx
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(
+ n_ctx,
+ width,
+ heads,
+ )
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x, mask=None):
+ for block in self.resblocks:
+ x = block(x, mask=mask)
+ return x
+
+
+class PriorTransformer(nn.Module):
+ """
+ A Causal Transformer that conditions on CLIP text embedding, text.
+
+ :param text_ctx: number of text tokens to expect.
+ :param xf_width: width of the transformer.
+ :param xf_layers: depth of the transformer.
+ :param xf_heads: heads in the transformer.
+ :param xf_final_ln: use a LayerNorm after the output layer.
+ :param clip_dim: dimension of clip feature.
+ """
+
+ def __init__(
+ self,
+ text_ctx,
+ xf_width,
+ xf_layers,
+ xf_heads,
+ xf_final_ln,
+ clip_dim,
+ ):
+ super().__init__()
+
+ self.text_ctx = text_ctx
+ self.xf_width = xf_width
+ self.xf_layers = xf_layers
+ self.xf_heads = xf_heads
+ self.clip_dim = clip_dim
+ self.ext_len = 4
+
+ self.time_embed = nn.Sequential(
+ nn.Linear(xf_width, xf_width),
+ nn.SiLU(),
+ nn.Linear(xf_width, xf_width),
+ )
+ self.text_enc_proj = nn.Linear(clip_dim, xf_width)
+ self.text_emb_proj = nn.Linear(clip_dim, xf_width)
+ self.clip_img_proj = nn.Linear(clip_dim, xf_width)
+ self.out_proj = nn.Linear(xf_width, clip_dim)
+ self.transformer = Transformer(
+ text_ctx + self.ext_len,
+ xf_width,
+ xf_layers,
+ xf_heads,
+ )
+ if xf_final_ln:
+ self.final_ln = LayerNorm(xf_width)
+ else:
+ self.final_ln = None
+
+ self.positional_embedding = nn.Parameter(
+ th.empty(1, text_ctx + self.ext_len, xf_width)
+ )
+ self.prd_emb = nn.Parameter(th.randn((1, 1, xf_width)))
+
+ nn.init.normal_(self.prd_emb, std=0.01)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ text_emb=None,
+ text_enc=None,
+ mask=None,
+ causal_mask=None,
+ ):
+ bsz = x.shape[0]
+ mask = F.pad(mask, (0, self.ext_len), value=True)
+
+ t_emb = self.time_embed(timestep_embedding(timesteps, self.xf_width))
+ text_enc = self.text_enc_proj(text_enc)
+ text_emb = self.text_emb_proj(text_emb)
+ x = self.clip_img_proj(x)
+
+ input_seq = [
+ text_enc,
+ text_emb[:, None, :],
+ t_emb[:, None, :],
+ x[:, None, :],
+ self.prd_emb.to(x.dtype).expand(bsz, -1, -1),
+ ]
+ input = th.cat(input_seq, dim=1)
+ input = input + self.positional_embedding.to(input.dtype)
+
+ mask = th.where(mask, 0.0, float("-inf"))
+ mask = (mask[:, None, :] + causal_mask).to(input.dtype)
+
+ out = self.transformer(input, mask=mask)
+ if self.final_ln is not None:
+ out = self.final_ln(out)
+
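+        # The last position holds the prd_emb prediction token; its output is projected
+        # back to clip_dim as the predicted CLIP image embedding.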
+ out = self.out_proj(out[:, -1])
+
+ return out
diff --git a/repositories/ldm/modules/karlo/kakao/sampler.py b/repositories/ldm/modules/karlo/kakao/sampler.py
new file mode 100644
index 000000000..b56bf2f20
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/sampler.py
@@ -0,0 +1,272 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+
+# source: https://github.com/kakaobrain/karlo/blob/3c68a50a16d76b48a15c181d1c5a5e0879a90f85/karlo/sampler/t2i.py#L15
+# ------------------------------------------------------------------------------------
+
+from typing import Iterator
+
+import torch
+import torchvision.transforms.functional as TVF
+from torchvision.transforms import InterpolationMode
+
+from .template import BaseSampler, CKPT_PATH
+
+
+class T2ISampler(BaseSampler):
+ """
+ A sampler for text-to-image generation.
+ :param root_dir: directory for model checkpoints.
+ :param sampling_type: ["default", "fast"]
+ """
+
+ def __init__(
+ self,
+ root_dir: str,
+ sampling_type: str = "default",
+ ):
+ super().__init__(root_dir, sampling_type)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ root_dir: str,
+ clip_model_path: str,
+ clip_stat_path: str,
+ sampling_type: str = "default",
+ ):
+
+ model = cls(
+ root_dir=root_dir,
+ sampling_type=sampling_type,
+ )
+ model.load_clip(clip_model_path)
+ model.load_prior(
+ f"{CKPT_PATH['prior']}",
+ clip_stat_path=clip_stat_path,
+ prior_config="configs/karlo/prior_1B_vit_l.yaml"
+ )
+ model.load_decoder(f"{CKPT_PATH['decoder']}", decoder_config="configs/karlo/decoder_900M_vit_l.yaml")
+ model.load_sr_64_256(CKPT_PATH["sr_256"], sr_config="configs/karlo/improved_sr_64_256_1.4B.yaml")
+ return model
+
+ def preprocess(
+ self,
+ prompt: str,
+ bsz: int,
+ ):
+ """Setup prompts & cfg scales"""
+ prompts_batch = [prompt for _ in range(bsz)]
+
+ prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
+ prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+
+ decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
+ decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+
+ """ Get CLIP text feature """
+ clip_model = self._clip
+ tokenizer = self._tokenizer
+ max_txt_length = self._prior.model.text_ctx
+
+ tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
+ cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
+ if not (cf_token.shape == tok.shape):
+ cf_token = cf_token.expand(tok.shape[0], -1)
+ cf_mask = cf_mask.expand(tok.shape[0], -1)
+
+ tok = torch.cat([tok, cf_token], dim=0)
+ mask = torch.cat([mask, cf_mask], dim=0)
+
+ tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+ txt_feat, txt_feat_seq = clip_model.encode_text(tok)
+
+ return (
+ prompts_batch,
+ prior_cf_scales_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ )
+
+ def __call__(
+ self,
+ prompt: str,
+ bsz: int,
+ progressive_mode=None,
+ ) -> Iterator[torch.Tensor]:
+ assert progressive_mode in ("loop", "stage", "final")
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ (
+ prompts_batch,
+ prior_cf_scales_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ ) = self.preprocess(
+ prompt,
+ bsz,
+ )
+
+ """ Transform CLIP text feature into image feature """
+ img_feat = self._prior(
+ txt_feat,
+ txt_feat_seq,
+ mask,
+ prior_cf_scales_batch,
+ timestep_respacing=self._prior_sm,
+ )
+
+ """ Generate 64x64px images """
+ images_64_outputs = self._decoder(
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ img_feat,
+ cf_guidance_scales=decoder_cf_scales_batch,
+ timestep_respacing=self._decoder_sm,
+ )
+
+ images_64 = None
+ for k, out in enumerate(images_64_outputs):
+ images_64 = out
+ if progressive_mode == "loop":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+ if progressive_mode == "stage":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+
+ images_64 = torch.clamp(images_64, -1, 1)
+
+ """ Upsample 64x64 to 256x256 """
+ images_256 = TVF.resize(
+ images_64,
+ [256, 256],
+ interpolation=InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+ images_256_outputs = self._sr_64_256(
+ images_256, timestep_respacing=self._sr_sm
+ )
+
+ for k, out in enumerate(images_256_outputs):
+ images_256 = out
+ if progressive_mode == "loop":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+ if progressive_mode == "stage":
+ yield torch.clamp(out * 0.5 + 0.5, 0.0, 1.0)
+
+ yield torch.clamp(images_256 * 0.5 + 0.5, 0.0, 1.0)
+
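+# Usage sketch (paths and prompt are hypothetical):
+#   sampler = T2ISampler.from_pretrained("models/karlo", "clip.pt", "clip_stats.th")
+#   for images in sampler("a watercolor fox", bsz=1, progressive_mode="final"):
+#       pass  # `images` is a float tensor of 256x256 RGB samples clamped to [0, 1]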
+
+class PriorSampler(BaseSampler):
+ """
+ A sampler for text-to-image generation, but only the prior.
+ :param root_dir: directory for model checkpoints.
+ :param sampling_type: ["default", "fast"]
+ """
+
+ def __init__(
+ self,
+ root_dir: str,
+ sampling_type: str = "default",
+ ):
+ super().__init__(root_dir, sampling_type)
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ root_dir: str,
+ clip_model_path: str,
+ clip_stat_path: str,
+ sampling_type: str = "default",
+ ):
+ model = cls(
+ root_dir=root_dir,
+ sampling_type=sampling_type,
+ )
+ model.load_clip(clip_model_path)
+ model.load_prior(
+ f"{CKPT_PATH['prior']}",
+ clip_stat_path=clip_stat_path,
+ prior_config="configs/karlo/prior_1B_vit_l.yaml"
+ )
+ return model
+
+ def preprocess(
+ self,
+ prompt: str,
+ bsz: int,
+ ):
+ """Setup prompts & cfg scales"""
+ prompts_batch = [prompt for _ in range(bsz)]
+
+ prior_cf_scales_batch = [self._prior_cf_scale] * len(prompts_batch)
+ prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device="cuda")
+
+ decoder_cf_scales_batch = [self._decoder_cf_scale] * len(prompts_batch)
+ decoder_cf_scales_batch = torch.tensor(decoder_cf_scales_batch, device="cuda")
+
+ """ Get CLIP text feature """
+ clip_model = self._clip
+ tokenizer = self._tokenizer
+ max_txt_length = self._prior.model.text_ctx
+
+ tok, mask = tokenizer.padded_tokens_and_mask(prompts_batch, max_txt_length)
+ cf_token, cf_mask = tokenizer.padded_tokens_and_mask([""], max_txt_length)
+ if not (cf_token.shape == tok.shape):
+ cf_token = cf_token.expand(tok.shape[0], -1)
+ cf_mask = cf_mask.expand(tok.shape[0], -1)
+
+ tok = torch.cat([tok, cf_token], dim=0)
+ mask = torch.cat([mask, cf_mask], dim=0)
+
+ tok, mask = tok.to(device="cuda"), mask.to(device="cuda")
+ txt_feat, txt_feat_seq = clip_model.encode_text(tok)
+
+ return (
+ prompts_batch,
+ prior_cf_scales_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ )
+
+ def __call__(
+ self,
+ prompt: str,
+ bsz: int,
+ progressive_mode=None,
+ ) -> Iterator[torch.Tensor]:
+ assert progressive_mode in ("loop", "stage", "final")
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ (
+ prompts_batch,
+ prior_cf_scales_batch,
+ decoder_cf_scales_batch,
+ txt_feat,
+ txt_feat_seq,
+ tok,
+ mask,
+ ) = self.preprocess(
+ prompt,
+ bsz,
+ )
+
+ """ Transform CLIP text feature into image feature """
+ img_feat = self._prior(
+ txt_feat,
+ txt_feat_seq,
+ mask,
+ prior_cf_scales_batch,
+ timestep_respacing=self._prior_sm,
+ )
+
+ yield img_feat
diff --git a/repositories/ldm/modules/karlo/kakao/template.py b/repositories/ldm/modules/karlo/kakao/template.py
new file mode 100644
index 000000000..949e80e67
--- /dev/null
+++ b/repositories/ldm/modules/karlo/kakao/template.py
@@ -0,0 +1,141 @@
+# ------------------------------------------------------------------------------------
+# Karlo-v1.0.alpha
+# Copyright (c) 2022 KakaoBrain. All Rights Reserved.
+# ------------------------------------------------------------------------------------
+
+import os
+import logging
+import torch
+
+from omegaconf import OmegaConf
+
+from ldm.modules.karlo.kakao.models.clip import CustomizedCLIP, CustomizedTokenizer
+from ldm.modules.karlo.kakao.models.prior_model import PriorDiffusionModel
+from ldm.modules.karlo.kakao.models.decoder_model import Text2ImProgressiveModel
+from ldm.modules.karlo.kakao.models.sr_64_256 import ImprovedSupRes64to256ProgressiveModel
+
+
+SAMPLING_CONF = {
+ "default": {
+ "prior_sm": "25",
+ "prior_n_samples": 1,
+ "prior_cf_scale": 4.0,
+ "decoder_sm": "50",
+ "decoder_cf_scale": 8.0,
+ "sr_sm": "7",
+ },
+ "fast": {
+ "prior_sm": "25",
+ "prior_n_samples": 1,
+ "prior_cf_scale": 4.0,
+ "decoder_sm": "25",
+ "decoder_cf_scale": 8.0,
+ "sr_sm": "7",
+ },
+}
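+# "default" and "fast" share the prior and SR settings and differ only in the number
+# of respaced decoder steps (50 vs. 25).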
+
+CKPT_PATH = {
+ "prior": "prior-ckpt-step=01000000-of-01000000.ckpt",
+ "decoder": "decoder-ckpt-step=01000000-of-01000000.ckpt",
+ "sr_256": "improved-sr-ckpt-step=1.2M.ckpt",
+}
+
+
+class BaseSampler:
+ _PRIOR_CLASS = PriorDiffusionModel
+ _DECODER_CLASS = Text2ImProgressiveModel
+ _SR256_CLASS = ImprovedSupRes64to256ProgressiveModel
+
+ def __init__(
+ self,
+ root_dir: str,
+ sampling_type: str = "fast",
+ ):
+ self._root_dir = root_dir
+
+ sampling_type = SAMPLING_CONF[sampling_type]
+ self._prior_sm = sampling_type["prior_sm"]
+ self._prior_n_samples = sampling_type["prior_n_samples"]
+ self._prior_cf_scale = sampling_type["prior_cf_scale"]
+
+ assert self._prior_n_samples == 1
+
+ self._decoder_sm = sampling_type["decoder_sm"]
+ self._decoder_cf_scale = sampling_type["decoder_cf_scale"]
+
+ self._sr_sm = sampling_type["sr_sm"]
+
+ def __repr__(self):
+ line = ""
+ line += f"Prior, sampling method: {self._prior_sm}, cf_scale: {self._prior_cf_scale}\n"
+ line += f"Decoder, sampling method: {self._decoder_sm}, cf_scale: {self._decoder_cf_scale}\n"
+ line += f"SR(64->256), sampling method: {self._sr_sm}"
+
+ return line
+
+ def load_clip(self, clip_path: str):
+ clip = CustomizedCLIP.load_from_checkpoint(
+ os.path.join(self._root_dir, clip_path)
+ )
+ clip = torch.jit.script(clip)
+ clip.cuda()
+ clip.eval()
+
+ self._clip = clip
+ self._tokenizer = CustomizedTokenizer()
+
+ def load_prior(
+ self,
+ ckpt_path: str,
+ clip_stat_path: str,
+ prior_config: str = "configs/prior_1B_vit_l.yaml"
+ ):
+ logging.info(f"Loading prior: {ckpt_path}")
+
+ config = OmegaConf.load(prior_config)
+ clip_mean, clip_std = torch.load(
+ os.path.join(self._root_dir, clip_stat_path), map_location="cpu"
+ )
+
+ prior = self._PRIOR_CLASS.load_from_checkpoint(
+ config,
+ self._tokenizer,
+ clip_mean,
+ clip_std,
+ os.path.join(self._root_dir, ckpt_path),
+ strict=True,
+ )
+ prior.cuda()
+ prior.eval()
+ logging.info("done.")
+
+ self._prior = prior
+
+ def load_decoder(self, ckpt_path: str, decoder_config: str = "configs/decoder_900M_vit_l.yaml"):
+ logging.info(f"Loading decoder: {ckpt_path}")
+
+ config = OmegaConf.load(decoder_config)
+ decoder = self._DECODER_CLASS.load_from_checkpoint(
+ config,
+ self._tokenizer,
+ os.path.join(self._root_dir, ckpt_path),
+ strict=True,
+ )
+ decoder.cuda()
+ decoder.eval()
+ logging.info("done.")
+
+ self._decoder = decoder
+
+ def load_sr_64_256(self, ckpt_path: str, sr_config: str = "configs/improved_sr_64_256_1.4B.yaml"):
+ logging.info(f"Loading SR(64->256): {ckpt_path}")
+
+ config = OmegaConf.load(sr_config)
+ sr = self._SR256_CLASS.load_from_checkpoint(
+ config, os.path.join(self._root_dir, ckpt_path), strict=True
+ )
+ sr.cuda()
+ sr.eval()
+ logging.info("done.")
+
+ self._sr_64_256 = sr
\ No newline at end of file
diff --git a/repositories/ldm/modules/midas/__init__.py b/repositories/ldm/modules/midas/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/midas/api.py b/repositories/ldm/modules/midas/api.py
new file mode 100644
index 000000000..b58ebbffd
--- /dev/null
+++ b/repositories/ldm/modules/midas/api.py
@@ -0,0 +1,170 @@
+# based on https://github.com/isl-org/MiDaS
+
+import cv2
+import torch
+import torch.nn as nn
+from torchvision.transforms import Compose
+
+from ldm.modules.midas.midas.dpt_depth import DPTDepthModel
+from ldm.modules.midas.midas.midas_net import MidasNet
+from ldm.modules.midas.midas.midas_net_custom import MidasNet_small
+from ldm.modules.midas.midas.transforms import Resize, NormalizeImage, PrepareForNet
+
+
+ISL_PATHS = {
+ "dpt_large": "midas_models/dpt_large-midas-2f21e586.pt",
+ "dpt_hybrid": "midas_models/dpt_hybrid-midas-501f0c75.pt",
+ "midas_v21": "",
+ "midas_v21_small": "",
+}
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def load_midas_transform(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load transform only
+ if model_type == "dpt_large": # DPT-Large
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ elif model_type == "midas_v21_small":
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+ else:
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return transform
+
+
+def load_model(model_type):
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
+ # load network
+ model_path = ISL_PATHS[model_type]
+ if model_type == "dpt_large": # DPT-Large
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitl16_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "dpt_hybrid": # DPT-Hybrid
+ model = DPTDepthModel(
+ path=model_path,
+ backbone="vitb_rn50_384",
+ non_negative=True,
+ )
+ net_w, net_h = 384, 384
+ resize_mode = "minimal"
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+
+ elif model_type == "midas_v21":
+ model = MidasNet(model_path, non_negative=True)
+ net_w, net_h = 384, 384
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ elif model_type == "midas_v21_small":
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
+ non_negative=True, blocks={'expand': True})
+ net_w, net_h = 256, 256
+ resize_mode = "upper_bound"
+ normalization = NormalizeImage(
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ )
+
+ else:
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
+ assert False
+
+ transform = Compose(
+ [
+ Resize(
+ net_w,
+ net_h,
+ resize_target=None,
+ keep_aspect_ratio=True,
+ ensure_multiple_of=32,
+ resize_method=resize_mode,
+ image_interpolation_method=cv2.INTER_CUBIC,
+ ),
+ normalization,
+ PrepareForNet(),
+ ]
+ )
+
+ return model.eval(), transform
+
+
+class MiDaSInference(nn.Module):
+ MODEL_TYPES_TORCH_HUB = [
+ "DPT_Large",
+ "DPT_Hybrid",
+ "MiDaS_small"
+ ]
+ MODEL_TYPES_ISL = [
+ "dpt_large",
+ "dpt_hybrid",
+ "midas_v21",
+ "midas_v21_small",
+ ]
+
+ def __init__(self, model_type):
+ super().__init__()
+ assert (model_type in self.MODEL_TYPES_ISL)
+ model, _ = load_model(model_type)
+ self.model = model
+ self.model.train = disabled_train
+
+ def forward(self, x):
+        # x in 0..1, already preprocessed by the matching MiDaS transform (see load_midas_transform)
+        # NOTE: we expect that the correct transform has been applied during dataloading.
+ with torch.no_grad():
+ prediction = self.model(x)
+ prediction = torch.nn.functional.interpolate(
+ prediction.unsqueeze(1),
+ size=x.shape[2:],
+ mode="bicubic",
+ align_corners=False,
+ )
+ assert prediction.shape == (x.shape[0], 1, x.shape[2], x.shape[3])
+ return prediction
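+
+# Usage sketch (variable names are hypothetical, and the matching MiDaS checkpoint must
+# be present under midas_models/): depth = MiDaSInference("dpt_hybrid")(batch), where
+# `batch` is an [N x 3 x H x W] tensor normalized by the corresponding MiDaS transform;
+# the output has shape [N x 1 x H x W].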
+
diff --git a/repositories/ldm/modules/midas/midas/__init__.py b/repositories/ldm/modules/midas/midas/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/repositories/ldm/modules/midas/midas/base_model.py b/repositories/ldm/modules/midas/midas/base_model.py
new file mode 100644
index 000000000..5cf430239
--- /dev/null
+++ b/repositories/ldm/modules/midas/midas/base_model.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BaseModel(torch.nn.Module):
+ def load(self, path):
+ """Load model from file.
+
+ Args:
+ path (str): file path
+ """
+ parameters = torch.load(path, map_location=torch.device('cpu'))
+
+ if "optimizer" in parameters:
+ parameters = parameters["model"]
+
+ self.load_state_dict(parameters)
diff --git a/repositories/ldm/modules/midas/midas/blocks.py b/repositories/ldm/modules/midas/midas/blocks.py
new file mode 100644
index 000000000..2145d18fa
--- /dev/null
+++ b/repositories/ldm/modules/midas/midas/blocks.py
@@ -0,0 +1,342 @@
+import torch
+import torch.nn as nn
+
+from .vit import (
+ _make_pretrained_vitb_rn50_384,
+ _make_pretrained_vitl16_384,
+ _make_pretrained_vitb16_384,
+ forward_vit,
+)
+
+def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
+ if backbone == "vitl16_384":
+ pretrained = _make_pretrained_vitl16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
+ elif backbone == "vitb_rn50_384":
+ pretrained = _make_pretrained_vitb_rn50_384(
+ use_pretrained,
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
+ scratch = _make_scratch(
+ [256, 512, 768, 768], features, groups=groups, expand=expand
+        ) # ViT-B/16 + ResNet-50 hybrid (backbone)
+ elif backbone == "vitb16_384":
+ pretrained = _make_pretrained_vitb16_384(
+ use_pretrained, hooks=hooks, use_readout=use_readout
+ )
+ scratch = _make_scratch(
+ [96, 192, 384, 768], features, groups=groups, expand=expand
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
+ elif backbone == "resnext101_wsl":
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
+        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
+ elif backbone == "efficientnet_lite3":
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
+ else:
+ print(f"Backbone '{backbone}' not implemented")
+ assert False
+
+ return pretrained, scratch
+
+
+def _make_scratch(in_shape, out_shape, groups=1, expand=False):
+ scratch = nn.Module()
+
+ out_shape1 = out_shape
+ out_shape2 = out_shape
+ out_shape3 = out_shape
+ out_shape4 = out_shape
+ if expand==True:
+ out_shape1 = out_shape
+ out_shape2 = out_shape*2
+ out_shape3 = out_shape*4
+ out_shape4 = out_shape*8
+
+ scratch.layer1_rn = nn.Conv2d(
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer2_rn = nn.Conv2d(
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer3_rn = nn.Conv2d(
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+ scratch.layer4_rn = nn.Conv2d(
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
+ )
+
+ return scratch
+
+
+def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
+ efficientnet = torch.hub.load(
+ "rwightman/gen-efficientnet-pytorch",
+ "tf_efficientnet_lite3",
+ pretrained=use_pretrained,
+ exportable=exportable
+ )
+ return _make_efficientnet_backbone(efficientnet)
+
+
+def _make_efficientnet_backbone(effnet):
+ pretrained = nn.Module()
+
+ pretrained.layer1 = nn.Sequential(
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
+ )
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
+
+ return pretrained
+
+
+def _make_resnet_backbone(resnet):
+ pretrained = nn.Module()
+ pretrained.layer1 = nn.Sequential(
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
+ )
+
+ pretrained.layer2 = resnet.layer2
+ pretrained.layer3 = resnet.layer3
+ pretrained.layer4 = resnet.layer4
+
+ return pretrained
+
+
+def _make_pretrained_resnext101_wsl(use_pretrained):
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
+ return _make_resnet_backbone(resnet)
+
+
+
+class Interpolate(nn.Module):
+ """Interpolation module.
+ """
+
+ def __init__(self, scale_factor, mode, align_corners=False):
+ """Init.
+
+ Args:
+ scale_factor (float): scaling
+ mode (str): interpolation mode
+ """
+ super(Interpolate, self).__init__()
+
+ self.interp = nn.functional.interpolate
+ self.scale_factor = scale_factor
+ self.mode = mode
+ self.align_corners = align_corners
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: interpolated data
+ """
+
+ x = self.interp(
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
+ )
+
+ return x
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock, self).__init__()
+
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ output += self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=True
+ )
+
+ return output
+
+
+
+
+class ResidualConvUnit_custom(nn.Module):
+ """Residual convolution module.
+ """
+
+ def __init__(self, features, activation, bn):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.bn = bn
+
+ self.groups=1
+
+ self.conv1 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ self.conv2 = nn.Conv2d(
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
+ )
+
+ if self.bn==True:
+ self.bn1 = nn.BatchNorm2d(features)
+ self.bn2 = nn.BatchNorm2d(features)
+
+ self.activation = activation
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input
+
+ Returns:
+ tensor: output
+ """
+
+ out = self.activation(x)
+ out = self.conv1(out)
+ if self.bn==True:
+ out = self.bn1(out)
+
+ out = self.activation(out)
+ out = self.conv2(out)
+ if self.bn==True:
+ out = self.bn2(out)
+
+ if self.groups > 1:
+ out = self.conv_merge(out)
+
+ return self.skip_add.add(out, x)
+
+ # return out + x
+
+
+class FeatureFusionBlock_custom(nn.Module):
+ """Feature fusion block.
+ """
+
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
+ """Init.
+
+ Args:
+ features (int): number of features
+ """
+ super(FeatureFusionBlock_custom, self).__init__()
+
+ self.deconv = deconv
+ self.align_corners = align_corners
+
+ self.groups=1
+
+ self.expand = expand
+ out_features = features
+ if self.expand==True:
+ out_features = features//2
+
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
+
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
+
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, *xs):
+ """Forward pass.
+
+ Returns:
+ tensor: output
+ """
+ output = xs[0]
+
+ if len(xs) == 2:
+ res = self.resConfUnit1(xs[1])
+ output = self.skip_add.add(output, res)
+ # output += res
+
+ output = self.resConfUnit2(output)
+
+ output = nn.functional.interpolate(
+ output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
+ )
+
+ output = self.out_conv(output)
+
+ return output
+
diff --git a/repositories/ldm/modules/midas/midas/dpt_depth.py b/repositories/ldm/modules/midas/midas/dpt_depth.py
new file mode 100644
index 000000000..4e9aab5d2
--- /dev/null
+++ b/repositories/ldm/modules/midas/midas/dpt_depth.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .base_model import BaseModel
+from .blocks import (
+ FeatureFusionBlock,
+ FeatureFusionBlock_custom,
+ Interpolate,
+ _make_encoder,
+ forward_vit,
+)
+
+
+def _make_fusion_block(features, use_bn):
+ return FeatureFusionBlock_custom(
+ features,
+ nn.ReLU(False),
+ deconv=False,
+ bn=use_bn,
+ expand=False,
+ align_corners=True,
+ )
+
+
+class DPT(BaseModel):
+ def __init__(
+ self,
+ head,
+ features=256,
+ backbone="vitb_rn50_384",
+ readout="project",
+ channels_last=False,
+ use_bn=False,
+ ):
+
+ super(DPT, self).__init__()
+
+ self.channels_last = channels_last
+
+ hooks = {
+ "vitb_rn50_384": [0, 1, 8, 11],
+ "vitb16_384": [2, 5, 8, 11],
+ "vitl16_384": [5, 11, 17, 23],
+ }
+
+ # Instantiate backbone and reassemble blocks
+ self.pretrained, self.scratch = _make_encoder(
+ backbone,
+ features,
+            False, # use_pretrained: set to True to initialize the backbone with ImageNet weights
+ groups=1,
+ expand=False,
+ exportable=False,
+ hooks=hooks[backbone],
+ use_readout=readout,
+ )
+
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+ self.scratch.output_conv = head
+
+
+ def forward(self, x):
+        if self.channels_last:
+            x = x.contiguous(memory_format=torch.channels_last)
+
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return out
+
+
+class DPTDepthModel(DPT):
+ def __init__(self, path=None, non_negative=True, **kwargs):
+ features = kwargs["features"] if "features" in kwargs else 256
+
+ head = nn.Sequential(
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ super().__init__(head, **kwargs)
+
+ if path is not None:
+ self.load(path)
+
+ def forward(self, x):
+ return super().forward(x).squeeze(dim=1)
+
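+
+if __name__ == "__main__":
+    # Minimal smoke-test sketch (not part of upstream MiDaS): build the hybrid
+    # DPT depth model without pretrained weights and run a dummy 384x384 batch.
+    # Assumes a timm version compatible with this vendored code, i.e. one that
+    # still registers "vit_base_resnet50_384".
+    model = DPTDepthModel(path=None, backbone="vitb_rn50_384")
+    model.eval()
+    with torch.no_grad():
+        depth = model(torch.randn(1, 3, 384, 384))
+    print(depth.shape)  # expected: torch.Size([1, 384, 384])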
diff --git a/repositories/ldm/modules/midas/midas/midas_net.py b/repositories/ldm/modules/midas/midas/midas_net.py
new file mode 100644
index 000000000..8a9549778
--- /dev/null
+++ b/repositories/ldm/modules/midas/midas/midas_net.py
@@ -0,0 +1,76 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
+
+
+class MidasNet(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=256, non_negative=True):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+ features (int, optional): Number of features. Defaults to 256.
+            non_negative (bool, optional): If True, a final ReLU keeps the output non-negative. Defaults to True.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet, self).__init__()
+
+ use_pretrained = False if path is None else True
+
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
+
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
+ nn.ReLU(True),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
diff --git a/repositories/ldm/modules/midas/midas/midas_net_custom.py b/repositories/ldm/modules/midas/midas/midas_net_custom.py
new file mode 100644
index 000000000..50e4acb5e
--- /dev/null
+++ b/repositories/ldm/modules/midas/midas/midas_net_custom.py
@@ -0,0 +1,128 @@
+"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
+This file contains code that is adapted from
+https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
+"""
+import torch
+import torch.nn as nn
+
+from .base_model import BaseModel
+from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
+
+
+class MidasNet_small(BaseModel):
+ """Network for monocular depth estimation.
+ """
+
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
+ blocks={'expand': True}):
+ """Init.
+
+ Args:
+ path (str, optional): Path to saved model. Defaults to None.
+            features (int, optional): Number of features. Defaults to 64.
+            backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
+ """
+ print("Loading weights: ", path)
+
+ super(MidasNet_small, self).__init__()
+
+ use_pretrained = False if path else True
+
+ self.channels_last = channels_last
+ self.blocks = blocks
+ self.backbone = backbone
+
+ self.groups = 1
+
+ features1=features
+ features2=features
+ features3=features
+ features4=features
+ self.expand = False
+ if "expand" in self.blocks and self.blocks['expand'] == True:
+ self.expand = True
+ features1=features
+ features2=features*2
+ features3=features*4
+ features4=features*8
+
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
+
+ self.scratch.activation = nn.ReLU(False)
+
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
+
+
+ self.scratch.output_conv = nn.Sequential(
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
+ Interpolate(scale_factor=2, mode="bilinear"),
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
+ self.scratch.activation,
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ nn.ReLU(True) if non_negative else nn.Identity(),
+ nn.Identity(),
+ )
+
+ if path:
+ self.load(path)
+
+
+ def forward(self, x):
+ """Forward pass.
+
+ Args:
+ x (tensor): input data (image)
+
+ Returns:
+ tensor: depth
+ """
+        if self.channels_last:
+            print("self.channels_last = ", self.channels_last)
+            x = x.contiguous(memory_format=torch.channels_last)
+
+
+ layer_1 = self.pretrained.layer1(x)
+ layer_2 = self.pretrained.layer2(layer_1)
+ layer_3 = self.pretrained.layer3(layer_2)
+ layer_4 = self.pretrained.layer4(layer_3)
+
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+
+ path_4 = self.scratch.refinenet4(layer_4_rn)
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+ out = self.scratch.output_conv(path_1)
+
+ return torch.squeeze(out, dim=1)
+
+
+
+def fuse_model(m):
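+    # Fuses Conv2d+BatchNorm2d(+ReLU) runs in-place via torch.quantization, e.g.
+    # as preparation for quantization or faster inference; it tracks the two
+    # previously visited modules while walking m.named_modules().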
+ prev_previous_type = nn.Identity()
+ prev_previous_name = ''
+ previous_type = nn.Identity()
+ previous_name = ''
+ for name, module in m.named_modules():
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
+ # print("FUSED ", prev_previous_name, previous_name, name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
+ # print("FUSED ", prev_previous_name, previous_name)
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
+ # print("FUSED ", previous_name, name)
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
+
+ prev_previous_type = previous_type
+ prev_previous_name = previous_name
+ previous_type = type(module)
+ previous_name = name
\ No newline at end of file
diff --git a/repositories/ldm/modules/midas/midas/transforms.py b/repositories/ldm/modules/midas/midas/transforms.py
new file mode 100644
index 000000000..350cbc116
--- /dev/null
+++ b/repositories/ldm/modules/midas/midas/transforms.py
@@ -0,0 +1,234 @@
+import numpy as np
+import cv2
+import math
+
+
+def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
+    """Resize the sample to ensure the given size. Keeps aspect ratio.
+
+ Args:
+ sample (dict): sample
+ size (tuple): image size
+
+ Returns:
+ tuple: new size
+ """
+ shape = list(sample["disparity"].shape)
+
+ if shape[0] >= size[0] and shape[1] >= size[1]:
+ return sample
+
+ scale = [0, 0]
+ scale[0] = size[0] / shape[0]
+ scale[1] = size[1] / shape[1]
+
+ scale = max(scale)
+
+ shape[0] = math.ceil(scale * shape[0])
+ shape[1] = math.ceil(scale * shape[1])
+
+ # resize
+ sample["image"] = cv2.resize(
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
+ )
+
+ sample["disparity"] = cv2.resize(
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
+ )
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ tuple(shape[::-1]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return tuple(shape)
+
+
+class Resize(object):
+ """Resize sample to given size (width, height).
+ """
+
+ def __init__(
+ self,
+ width,
+ height,
+ resize_target=True,
+ keep_aspect_ratio=False,
+ ensure_multiple_of=1,
+ resize_method="lower_bound",
+ image_interpolation_method=cv2.INTER_AREA,
+ ):
+ """Init.
+
+ Args:
+ width (int): desired output width
+ height (int): desired output height
+ resize_target (bool, optional):
+ True: Resize the full sample (image, mask, target).
+ False: Resize image only.
+ Defaults to True.
+ keep_aspect_ratio (bool, optional):
+ True: Keep the aspect ratio of the input sample.
+ Output sample might not have the given width and height, and
+ resize behaviour depends on the parameter 'resize_method'.
+ Defaults to False.
+ ensure_multiple_of (int, optional):
+ Output width and height is constrained to be multiple of this parameter.
+ Defaults to 1.
+ resize_method (str, optional):
+ "lower_bound": Output will be at least as large as the given size.
+                "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
+                "minimal": Scale as little as possible. (Output size might be smaller than given size.)
+ Defaults to "lower_bound".
+ """
+ self.__width = width
+ self.__height = height
+
+ self.__resize_target = resize_target
+ self.__keep_aspect_ratio = keep_aspect_ratio
+ self.__multiple_of = ensure_multiple_of
+ self.__resize_method = resize_method
+ self.__image_interpolation_method = image_interpolation_method
+
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if max_val is not None and y > max_val:
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ if y < min_val:
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
+
+ return y
+
+ def get_size(self, width, height):
+ # determine new height and width
+ scale_height = self.__height / height
+ scale_width = self.__width / width
+
+ if self.__keep_aspect_ratio:
+ if self.__resize_method == "lower_bound":
+ # scale such that output size is lower bound
+ if scale_width > scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "upper_bound":
+ # scale such that output size is upper bound
+ if scale_width < scale_height:
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ elif self.__resize_method == "minimal":
+                # scale as little as possible
+ if abs(1 - scale_width) < abs(1 - scale_height):
+ # fit width
+ scale_height = scale_width
+ else:
+ # fit height
+ scale_width = scale_height
+ else:
+ raise ValueError(
+ f"resize_method {self.__resize_method} not implemented"
+ )
+
+ if self.__resize_method == "lower_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, min_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, min_val=self.__width
+ )
+ elif self.__resize_method == "upper_bound":
+ new_height = self.constrain_to_multiple_of(
+ scale_height * height, max_val=self.__height
+ )
+ new_width = self.constrain_to_multiple_of(
+ scale_width * width, max_val=self.__width
+ )
+ elif self.__resize_method == "minimal":
+ new_height = self.constrain_to_multiple_of(scale_height * height)
+ new_width = self.constrain_to_multiple_of(scale_width * width)
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
+
+ return (new_width, new_height)
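+
+    # Worked example (illustrative): with width=height=384, keep_aspect_ratio=True,
+    # ensure_multiple_of=32 and resize_method="lower_bound", a 1920x1080 input is
+    # scaled to fit its height and get_size returns (672, 384).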
+
+ def __call__(self, sample):
+ width, height = self.get_size(
+ sample["image"].shape[1], sample["image"].shape[0]
+ )
+
+ # resize sample
+ sample["image"] = cv2.resize(
+ sample["image"],
+ (width, height),
+ interpolation=self.__image_interpolation_method,
+ )
+
+ if self.__resize_target:
+ if "disparity" in sample:
+ sample["disparity"] = cv2.resize(
+ sample["disparity"],
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ if "depth" in sample:
+ sample["depth"] = cv2.resize(
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
+ )
+
+ sample["mask"] = cv2.resize(
+ sample["mask"].astype(np.float32),
+ (width, height),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ sample["mask"] = sample["mask"].astype(bool)
+
+ return sample
+
+
+class NormalizeImage(object):
+    """Normalize image by given mean and std.
+ """
+
+ def __init__(self, mean, std):
+ self.__mean = mean
+ self.__std = std
+
+ def __call__(self, sample):
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
+
+ return sample
+
+
+class PrepareForNet(object):
+ """Prepare sample for usage as network input.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, sample):
+ image = np.transpose(sample["image"], (2, 0, 1))
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
+
+ if "mask" in sample:
+ sample["mask"] = sample["mask"].astype(np.float32)
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
+
+ if "disparity" in sample:
+ disparity = sample["disparity"].astype(np.float32)
+ sample["disparity"] = np.ascontiguousarray(disparity)
+
+ if "depth" in sample:
+ depth = sample["depth"].astype(np.float32)
+ sample["depth"] = np.ascontiguousarray(depth)
+
+ return sample
diff --git a/repositories/ldm/modules/midas/midas/vit.py b/repositories/ldm/modules/midas/midas/vit.py
new file mode 100644
index 000000000..ea46b1be8
--- /dev/null
+++ b/repositories/ldm/modules/midas/midas/vit.py
@@ -0,0 +1,491 @@
+import torch
+import torch.nn as nn
+import timm
+import types
+import math
+import torch.nn.functional as F
+
+
+class Slice(nn.Module):
+ def __init__(self, start_index=1):
+ super(Slice, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ return x[:, self.start_index :]
+
+
+class AddReadout(nn.Module):
+ def __init__(self, start_index=1):
+ super(AddReadout, self).__init__()
+ self.start_index = start_index
+
+ def forward(self, x):
+ if self.start_index == 2:
+ readout = (x[:, 0] + x[:, 1]) / 2
+ else:
+ readout = x[:, 0]
+ return x[:, self.start_index :] + readout.unsqueeze(1)
+
+
+class ProjectReadout(nn.Module):
+ def __init__(self, in_features, start_index=1):
+ super(ProjectReadout, self).__init__()
+ self.start_index = start_index
+
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
+
+ def forward(self, x):
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])
+ features = torch.cat((x[:, self.start_index :], readout), -1)
+
+ return self.project(features)
+
+
+class Transpose(nn.Module):
+ def __init__(self, dim0, dim1):
+ super(Transpose, self).__init__()
+ self.dim0 = dim0
+ self.dim1 = dim1
+
+ def forward(self, x):
+ x = x.transpose(self.dim0, self.dim1)
+ return x
+
+
+def forward_vit(pretrained, x):
+ b, c, h, w = x.shape
+
+ glob = pretrained.model.forward_flex(x)
+
+ layer_1 = pretrained.activations["1"]
+ layer_2 = pretrained.activations["2"]
+ layer_3 = pretrained.activations["3"]
+ layer_4 = pretrained.activations["4"]
+
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
+
+ unflatten = nn.Sequential(
+ nn.Unflatten(
+ 2,
+ torch.Size(
+ [
+ h // pretrained.model.patch_size[1],
+ w // pretrained.model.patch_size[0],
+ ]
+ ),
+ )
+ )
+
+ if layer_1.ndim == 3:
+ layer_1 = unflatten(layer_1)
+ if layer_2.ndim == 3:
+ layer_2 = unflatten(layer_2)
+ if layer_3.ndim == 3:
+ layer_3 = unflatten(layer_3)
+ if layer_4.ndim == 3:
+ layer_4 = unflatten(layer_4)
+
+ layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)
+ layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)
+ layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)
+ layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)
+
+ return layer_1, layer_2, layer_3, layer_4
+
+
+def _resize_pos_embed(self, posemb, gs_h, gs_w):
+ posemb_tok, posemb_grid = (
+ posemb[:, : self.start_index],
+ posemb[0, self.start_index :],
+ )
+
+ gs_old = int(math.sqrt(len(posemb_grid)))
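+    # e.g. a ViT-B/16 trained at 384x384 has a 24x24 position grid (gs_old=24);
+    # for a 512x384 input it is bilinearly resized to gs_h=32, gs_w=24 below.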
+
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
+
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+
+ return posemb
+
+
+def forward_flex(self, x):
+ b, c, h, w = x.shape
+
+ pos_embed = self._resize_pos_embed(
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
+ )
+
+ B = x.shape[0]
+
+ if hasattr(self.patch_embed, "backbone"):
+ x = self.patch_embed.backbone(x)
+ if isinstance(x, (list, tuple)):
+ x = x[-1] # last feature if backbone outputs list/tuple of features
+
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
+
+ if getattr(self, "dist_token", None) is not None:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+ else:
+ cls_tokens = self.cls_token.expand(
+ B, -1, -1
+ ) # stole cls_tokens impl from Phil Wang, thanks
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ x = x + pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+
+ return x
+
+
+activations = {}
+
+
+def get_activation(name):
+ def hook(model, input, output):
+ activations[name] = output
+
+ return hook
+
+
+def get_readout_oper(vit_features, features, use_readout, start_index=1):
+ if use_readout == "ignore":
+ readout_oper = [Slice(start_index)] * len(features)
+ elif use_readout == "add":
+ readout_oper = [AddReadout(start_index)] * len(features)
+ elif use_readout == "project":
+ readout_oper = [
+ ProjectReadout(vit_features, start_index) for out_feat in features
+ ]
+ else:
+ assert (
+ False
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
+
+ return readout_oper
+
+
+def _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ size=[384, 384],
+ hooks=[2, 5, 8, 11],
+ vit_features=768,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ # 32, 48, 136, 384
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
+
+ hooks = [5, 11, 17, 23] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[256, 512, 1024, 1024],
+ hooks=hooks,
+ vit_features=1024,
+ use_readout=use_readout,
+ )
+
+
+def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model("vit_deit_base_patch16_384", pretrained=pretrained)
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
+ )
+
+
+def _make_pretrained_deitb16_distil_384(pretrained, use_readout="ignore", hooks=None):
+ model = timm.create_model(
+ "vit_deit_base_distilled_patch16_384", pretrained=pretrained
+ )
+
+ hooks = [2, 5, 8, 11] if hooks == None else hooks
+ return _make_vit_b16_backbone(
+ model,
+ features=[96, 192, 384, 768],
+ hooks=hooks,
+ use_readout=use_readout,
+ start_index=2,
+ )
+
+
+def _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=[0, 1, 8, 11],
+ vit_features=768,
+ use_vit_only=False,
+ use_readout="ignore",
+ start_index=1,
+):
+ pretrained = nn.Module()
+
+ pretrained.model = model
+
+ if use_vit_only == True:
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
+ else:
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
+ get_activation("1")
+ )
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
+ get_activation("2")
+ )
+
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
+
+ pretrained.activations = activations
+
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
+
+ if use_vit_only == True:
+ pretrained.act_postprocess1 = nn.Sequential(
+ readout_oper[0],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[0],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[0],
+ out_channels=features[0],
+ kernel_size=4,
+ stride=4,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+
+ pretrained.act_postprocess2 = nn.Sequential(
+ readout_oper[1],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[1],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.ConvTranspose2d(
+ in_channels=features[1],
+ out_channels=features[1],
+ kernel_size=2,
+ stride=2,
+ padding=0,
+ bias=True,
+ dilation=1,
+ groups=1,
+ ),
+ )
+ else:
+ pretrained.act_postprocess1 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+ pretrained.act_postprocess2 = nn.Sequential(
+ nn.Identity(), nn.Identity(), nn.Identity()
+ )
+
+ pretrained.act_postprocess3 = nn.Sequential(
+ readout_oper[2],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[2],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ )
+
+ pretrained.act_postprocess4 = nn.Sequential(
+ readout_oper[3],
+ Transpose(1, 2),
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
+ nn.Conv2d(
+ in_channels=vit_features,
+ out_channels=features[3],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ ),
+ nn.Conv2d(
+ in_channels=features[3],
+ out_channels=features[3],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ ),
+ )
+
+ pretrained.model.start_index = start_index
+ pretrained.model.patch_size = [16, 16]
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
+
+ # We inject this function into the VisionTransformer instances so that
+ # we can use it with interpolated position embeddings without modifying the library source.
+ pretrained.model._resize_pos_embed = types.MethodType(
+ _resize_pos_embed, pretrained.model
+ )
+
+ return pretrained
+
+
+def _make_pretrained_vitb_rn50_384(
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
+):
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
+
+ hooks = [0, 1, 8, 11] if hooks == None else hooks
+ return _make_vit_b_rn50_backbone(
+ model,
+ features=[256, 512, 768, 768],
+ size=[384, 384],
+ hooks=hooks,
+ use_vit_only=use_vit_only,
+ use_readout=use_readout,
+ )
diff --git a/repositories/ldm/modules/midas/utils.py b/repositories/ldm/modules/midas/utils.py
new file mode 100644
index 000000000..9a9d3b5b6
--- /dev/null
+++ b/repositories/ldm/modules/midas/utils.py
@@ -0,0 +1,189 @@
+"""Utils for monoDepth."""
+import sys
+import re
+import numpy as np
+import cv2
+import torch
+
+
+def read_pfm(path):
+ """Read pfm file.
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ tuple: (data, scale)
+ """
+ with open(path, "rb") as file:
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header.decode("ascii") == "PF":
+ color = True
+ elif header.decode("ascii") == "Pf":
+ color = False
+ else:
+ raise Exception("Not a PFM file: " + path)
+
+ dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
+ if dim_match:
+ width, height = list(map(int, dim_match.groups()))
+ else:
+ raise Exception("Malformed PFM header.")
+
+ scale = float(file.readline().decode("ascii").rstrip())
+ if scale < 0:
+ # little-endian
+ endian = "<"
+ scale = -scale
+ else:
+ # big-endian
+ endian = ">"
+
+ data = np.fromfile(file, endian + "f")
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+
+ return data, scale
+
+
+def write_pfm(path, image, scale=1):
+ """Write pfm file.
+
+ Args:
+        path (str): path to file
+ image (array): data
+ scale (int, optional): Scale. Defaults to 1.
+ """
+
+ with open(path, "wb") as file:
+ color = None
+
+ if image.dtype.name != "float32":
+ raise Exception("Image dtype must be float32.")
+
+ image = np.flipud(image)
+
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
+ color = True
+ elif (
+ len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
+ ): # greyscale
+ color = False
+ else:
+ raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
+
+        file.write(("PF\n" if color else "Pf\n").encode())
+ file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
+
+ endian = image.dtype.byteorder
+
+ if endian == "<" or endian == "=" and sys.byteorder == "little":
+ scale = -scale
+
+ file.write("%f\n".encode() % scale)
+
+ image.tofile(file)
+
+
+def read_image(path):
+ """Read image and output RGB image (0-1).
+
+ Args:
+ path (str): path to file
+
+ Returns:
+ array: RGB image (0-1)
+ """
+ img = cv2.imread(path)
+
+ if img.ndim == 2:
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
+
+ return img
+
+
+def resize_image(img):
+ """Resize image and make it fit for network.
+
+ Args:
+ img (array): image
+
+ Returns:
+ tensor: data ready for network
+ """
+ height_orig = img.shape[0]
+ width_orig = img.shape[1]
+
+ if width_orig > height_orig:
+ scale = width_orig / 384
+ else:
+ scale = height_orig / 384
+
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
+
+ img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
+
+ img_resized = (
+ torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
+ )
+ img_resized = img_resized.unsqueeze(0)
+
+ return img_resized
+
+
+def resize_depth(depth, width, height):
+ """Resize depth map and bring to CPU (numpy).
+
+ Args:
+ depth (tensor): depth
+ width (int): image width
+ height (int): image height
+
+ Returns:
+ array: processed depth
+ """
+ depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
+
+ depth_resized = cv2.resize(
+ depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
+ )
+
+ return depth_resized
+
+def write_depth(path, depth, bits=1):
+ """Write depth map to pfm and png file.
+
+ Args:
+ path (str): filepath without extension
+ depth (array): depth
+ """
+ write_pfm(path + ".pfm", depth.astype(np.float32))
+
+ depth_min = depth.min()
+ depth_max = depth.max()
+
+ max_val = (2**(8*bits))-1
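+    # bits=1 writes an 8-bit png (max_val=255); bits=2 writes a 16-bit png (max_val=65535).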
+
+ if depth_max - depth_min > np.finfo("float").eps:
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
+ else:
+        out = np.zeros(depth.shape, dtype=depth.dtype)
+
+ if bits == 1:
+ cv2.imwrite(path + ".png", out.astype("uint8"))
+ elif bits == 2:
+ cv2.imwrite(path + ".png", out.astype("uint16"))
+
+ return
diff --git a/repositories/ldm/util.py b/repositories/ldm/util.py
new file mode 100644
index 000000000..9ede259d5
--- /dev/null
+++ b/repositories/ldm/util.py
@@ -0,0 +1,207 @@
+import importlib
+
+import torch
+from torch import optim
+import numpy as np
+
+from inspect import isfunction
+from PIL import Image, ImageDraw, ImageFont
+
+
+def autocast(f):
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(enabled=True,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled()):
+ return f(*args, **kwargs)
+
+ return do_autocast
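+
+# Usage sketch: the decorator can wrap any function whose body should run under
+# CUDA autocast, e.g. `model.forward = autocast(model.forward)` or `@autocast`
+# above a def; it reuses the currently configured autocast dtype and cache state.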
+
+
+def log_txt_as_img(wh, xc, size=10):
+ # wh a tuple of (width, height)
+ # xc a list of captions to plot
+ b = len(xc)
+ txts = list()
+ for bi in range(b):
+ txt = Image.new("RGB", wh, color="white")
+ draw = ImageDraw.Draw(txt)
+ font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
+ nc = int(40 * (wh[0] / 256))
+ lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
+
+ try:
+ draw.text((0, 0), lines, fill="black", font=font)
+ except UnicodeEncodeError:
+            print("Can't encode string for logging. Skipping.")
+
+ txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
+ txts.append(txt)
+ txts = np.stack(txts)
+ txts = torch.tensor(txts)
+ return txts
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x,torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def mean_flat(tensor):
+ """
+ https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+ return total_params
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
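+
+# Example (illustrative): instantiate_from_config({"target": "torch.nn.Identity"})
+# imports torch.nn, resolves the Identity class and returns Identity(); an optional
+# "params" dict is forwarded as keyword arguments to the constructor.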
+
+
+class AdamWwithEMAandWings(optim.Optimizer):
+ # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
+ def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
+ weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
+ ema_power=1., param_names=()):
+ """AdamW that saves EMA versions of the parameters."""
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0.0 <= ema_decay <= 1.0:
+ raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
+ ema_power=ema_power, param_names=param_names)
+ super().__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super().__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ params_with_grad = []
+ grads = []
+ exp_avgs = []
+ exp_avg_sqs = []
+ ema_params_with_grad = []
+ state_sums = []
+ max_exp_avg_sqs = []
+ state_steps = []
+ amsgrad = group['amsgrad']
+ beta1, beta2 = group['betas']
+ ema_decay = group['ema_decay']
+ ema_power = group['ema_power']
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ params_with_grad.append(p)
+ if p.grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ grads.append(p.grad)
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of parameter values
+ state['param_exp_avg'] = p.detach().float().clone()
+
+ exp_avgs.append(state['exp_avg'])
+ exp_avg_sqs.append(state['exp_avg_sq'])
+ ema_params_with_grad.append(state['param_exp_avg'])
+
+ if amsgrad:
+ max_exp_avg_sqs.append(state['max_exp_avg_sq'])
+
+ # update the steps for each param group update
+ state['step'] += 1
+ # record the step after step update
+ state_steps.append(state['step'])
+
+ optim._functional.adamw(params_with_grad,
+ grads,
+ exp_avgs,
+ exp_avg_sqs,
+ max_exp_avg_sqs,
+ state_steps,
+ amsgrad=amsgrad,
+ beta1=beta1,
+ beta2=beta2,
+ lr=group['lr'],
+ weight_decay=group['weight_decay'],
+ eps=group['eps'],
+ maximize=False)
+
+ cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
+ for param, ema_param in zip(params_with_grad, ema_params_with_grad):
+ ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
+
+ return loss
\ No newline at end of file
diff --git a/repositories/taming/README.md b/repositories/taming/README.md
new file mode 100644
index 000000000..d295fbf75
--- /dev/null
+++ b/repositories/taming/README.md
@@ -0,0 +1,410 @@
+# Taming Transformers for High-Resolution Image Synthesis
+##### CVPR 2021 (Oral)
+![teaser](assets/mountain.jpeg)
+
+[**Taming Transformers for High-Resolution Image Synthesis**](https://compvis.github.io/taming-transformers/)
+[Patrick Esser](https://github.com/pesser)\*,
+[Robin Rombach](https://github.com/rromb)\*,
+[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
+\* equal contribution
+
+**tl;dr** We combine the efficiency of convolutional approaches with the expressivity of transformers by introducing a convolutional VQGAN, which learns a codebook of context-rich visual parts, whose composition is modeled with an autoregressive transformer.
+
+![teaser](assets/teaser.png)
+[arXiv](https://arxiv.org/abs/2012.09841) | [BibTeX](#bibtex) | [Project Page](https://compvis.github.io/taming-transformers/)
+
+
+### News
+#### 2022
+- More pretrained VQGANs (e.g. a f8-model with only 256 codebook entries) are available in our new work on [Latent Diffusion Models](https://github.com/CompVis/latent-diffusion).
+- Added scene synthesis models as proposed in the paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458), see [this section](#scene-image-synthesis).
+#### 2021
+- Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
+- Included a bugfix for the quantizer. For backward compatibility it is
+ disabled by default (which corresponds to always training with `beta=1.0`).
+ Use `legacy=False` in the quantizer config to enable it.
+ Thanks [richcmwang](https://github.com/richcmwang) and [wcshin-git](https://github.com/wcshin-git)!
+- Our paper received an update: See https://arxiv.org/abs/2012.09841v3 and the corresponding changelog.
+- Added a pretrained, [1.4B transformer model](https://k00.fr/s511rwcv) trained for class-conditional ImageNet synthesis, which obtains state-of-the-art FID scores among autoregressive approaches and outperforms BigGAN.
+- Added pretrained, unconditional models on [FFHQ](https://k00.fr/yndvfu95) and [CelebA-HQ](https://k00.fr/2xkmielf).
+- Added accelerated sampling via caching of keys/values in the self-attention operation, used in `scripts/sample_fast.py`.
+- Added a checkpoint of a [VQGAN](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) trained with f8 compression and Gumbel-Quantization.
+ See also our updated [reconstruction notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
+- We added a [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb) which compares two VQGANs and OpenAI's [DALL-E](https://github.com/openai/DALL-E). See also [this section](#more-resources).
+- We now include an overview of pretrained models in [Tab.1](#overview-of-pretrained-models). We added models for [COCO](#coco) and [ADE20k](#ade20k).
+- The streamlit demo now supports image completions.
+- We now include a couple of examples from the D-RIN dataset so you can run the
+ [D-RIN demo](#d-rin) without preparing the dataset first.
+- You can now jump right into sampling with our [Colab quickstart notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb).
+
+## Requirements
+A suitable [conda](https://conda.io/) environment named `taming` can be created
+and activated with:
+
+```
+conda env create -f environment.yaml
+conda activate taming
+```
+## Overview of pretrained models
+The following table provides an overview of all models that are currently available.
+FID scores were evaluated using [torch-fidelity](https://github.com/toshas/torch-fidelity).
+For reference, we also include a link to the recently released autoencoder of the [DALL-E](https://github.com/openai/DALL-E) model.
+See the corresponding [colab
+notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb)
+for a comparison and discussion of reconstruction capabilities.
+
+| Dataset | FID vs train | FID vs val | Link | Samples (256x256) | Comments
+| ------------- | ------------- | ------------- |------------- | ------------- |------------- |
+| FFHQ (f=16) | 9.6 | -- | [ffhq_transformer](https://k00.fr/yndvfu95) | [ffhq_samples](https://k00.fr/j626x093) |
+| CelebA-HQ (f=16) | 10.2 | -- | [celebahq_transformer](https://k00.fr/2xkmielf) | [celebahq_samples](https://k00.fr/j626x093) |
+| ADE20K (f=16) | -- | 35.5 | [ade20k_transformer](https://k00.fr/ot46cksa) | [ade20k_samples.zip](https://heibox.uni-heidelberg.de/f/70bb78cbaf844501b8fb/) [2k] | evaluated on val split (2k images)
+| COCO-Stuff (f=16) | -- | 20.4 | [coco_transformer](https://k00.fr/2zz6i2ce) | [coco_samples.zip](https://heibox.uni-heidelberg.de/f/a395a9be612f4a7a8054/) [5k] | evaluated on val split (5k images)
+| ImageNet (cIN) (f=16) | 15.98/15.78/6.59/5.88/5.20 | -- | [cin_transformer](https://k00.fr/s511rwcv) | [cin_samples](https://k00.fr/j626x093) | different decoding hyperparameters |
+| | | | || |
+| FacesHQ (f=16) | -- | -- | [faceshq_transformer](https://k00.fr/qqfl2do8)
+| S-FLCKR (f=16) | -- | -- | [sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
+| D-RIN (f=16) | -- | -- | [drin_transformer](https://k00.fr/39jcugc5)
+| | | | | || |
+| VQGAN ImageNet (f=16), 1024 | 10.54 | 7.94 | [vqgan_imagenet_f16_1024](https://heibox.uni-heidelberg.de/d/8088892a516d4e3baf92/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
+| VQGAN ImageNet (f=16), 16384 | 7.41 | 4.98 |[vqgan_imagenet_f16_16384](https://heibox.uni-heidelberg.de/d/a7530b09fed84f80a887/) | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
+| VQGAN OpenImages (f=8), 256 | -- | 1.49 |https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion).
+| VQGAN OpenImages (f=8), 16384 | -- | 1.14 |https://ommer-lab.com/files/latent-diffusion/vq-f8.zip | --- | Reconstruction-FIDs. Available via [latent diffusion](https://github.com/CompVis/latent-diffusion)
+| VQGAN OpenImages (f=8), 8192, GumbelQuantization | 3.24 | 1.49 |[vqgan_gumbel_f8](https://heibox.uni-heidelberg.de/d/2e5662443a6b4307b470/) | --- | Reconstruction-FIDs.
+| | | | | || |
+| DALL-E dVAE (f=8), 8192, GumbelQuantization | 33.88 | 32.01 | https://github.com/openai/DALL-E | [reconstructions](https://k00.fr/j626x093) | Reconstruction-FIDs.
+
+
+## Running pretrained models
+
+The commands below will start a streamlit demo which supports sampling at
+different resolutions and image completions. To run a non-interactive version
+of the sampling process, replace `streamlit run scripts/sample_conditional.py --`
+by `python scripts/make_samples.py --outdir ` and
+keep the remaining command line arguments.
+
+To sample from unconditional or class-conditional models,
+run `python scripts/sample_fast.py -r `.
+We describe below how to use this script to sample from the ImageNet, FFHQ, and CelebA-HQ models,
+respectively.
+
+### S-FLCKR
+![teaser](assets/sunset_and_ocean.jpg)
+
+You can also [run this model in a Colab
+notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/taming-transformers.ipynb),
+which includes all necessary steps to start sampling.
+
+Download the
+[2020-11-09T13-31-51_sflckr](https://heibox.uni-heidelberg.de/d/73487ab6e5314cb5adba/)
+folder and place it into `logs`. Then, run
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-09T13-31-51_sflckr/
+```
+
+### ImageNet
+![teaser](assets/imagenet.png)
+
+Download the [2021-04-03T19-39-50_cin_transformer](https://k00.fr/s511rwcv)
+folder and place it into logs. Sampling from the class-conditional ImageNet
+model does not require any data preparation. To produce 50 samples for each of
+the 1000 classes of ImageNet, with k=600 for top-k sampling, p=0.92 for nucleus
+sampling and temperature t=1.0, run
+
+```
+python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25
+```
+
+To restrict the model to certain classes, provide them via the `--classes` argument, separated by
+commas. For example, to sample 50 *ostriches*, *border collies* and *whiskey jugs*, run
+
+```
+python scripts/sample_fast.py -r logs/2021-04-03T19-39-50_cin_transformer/ -n 50 -k 600 -t 1.0 -p 0.92 --batch_size 25 --classes 9,232,901
+```
+We recommend experimenting with the autoregressive decoding parameters (top-k, top-p and temperature) for best results.
+
+### FFHQ/CelebA-HQ
+
+Download the [2021-04-23T18-19-01_ffhq_transformer](https://k00.fr/yndvfu95) and
+[2021-04-23T18-11-19_celebahq_transformer](https://k00.fr/2xkmielf)
+folders and place them into logs.
+Again, sampling from these unconditional models does not require any data preparation.
+To produce 50000 samples, with k=250 for top-k sampling,
+p=1.0 for nucleus sampling and temperature t=1.0, run
+
+```
+python scripts/sample_fast.py -r logs/2021-04-23T18-19-01_ffhq_transformer/
+```
+for FFHQ and
+
+```
+python scripts/sample_fast.py -r logs/2021-04-23T18-11-19_celebahq_transformer/
+```
+to sample from the CelebA-HQ model.
+For both models it can be advantageous to vary the top-k/top-p parameters for sampling.
+
+### FacesHQ
+![teaser](assets/faceshq.jpg)
+
+Download [2020-11-13T21-41-45_faceshq_transformer](https://k00.fr/qqfl2do8) and
+place it into `logs`. Follow the data preparation steps for
+[CelebA-HQ](#celeba-hq) and [FFHQ](#ffhq). Run
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-13T21-41-45_faceshq_transformer/
+```
+
+### D-RIN
+![teaser](assets/drin.jpg)
+
+Download [2020-11-20T12-54-32_drin_transformer](https://k00.fr/39jcugc5) and
+place it into `logs`. To run the demo on a couple of example depth maps
+included in the repository, run
+
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.imagenet.DRINExamples}}}"
+```
+
+To run the demo on the complete validation set, first follow the data preparation steps for
+[ImageNet](#imagenet) and then run
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T12-54-32_drin_transformer/
+```
+
+### COCO
+Download [2021-01-20T16-04-20_coco_transformer](https://k00.fr/2zz6i2ce) and
+place it into `logs`. To run the demo on a couple of example segmentation maps
+included in the repository, run
+
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2021-01-20T16-04-20_coco_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.coco.Examples}}}"
+```
+
+### ADE20k
+Download [2020-11-20T21-45-44_ade20k_transformer](https://k00.fr/ot46cksa) and
+place it into `logs`. To run the demo on a couple of example segmentation maps
+included in the repository, run
+
+```
+streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
+```
+
+## Scene Image Synthesis
+![teaser](assets/scene_images_samples.svg)
+Scene image generation based on bounding box conditionals as done in our CVPR2021 AI4CC workshop paper [High-Resolution Complex Scene Synthesis with Transformers](https://arxiv.org/abs/2105.06458) (see talk on [workshop page](https://visual.cs.brown.edu/workshops/aicc2021/#awards)). Supporting the datasets COCO and Open Images.
+
+### Training
+Download first-stage models [COCO-8k-VQGAN](https://heibox.uni-heidelberg.de/f/78dea9589974474c97c1/) for COCO or [COCO/Open-Images-8k-VQGAN](https://heibox.uni-heidelberg.de/f/461d9a9f4fcf48ab84f4/) for Open Images.
+Change `ckpt_path` in `data/coco_scene_images_transformer.yaml` and `data/open_images_scene_images_transformer.yaml` to point to the downloaded first-stage models.
+Download the full COCO/OI datasets and adapt `data_path` in the same files, unless the 100 example files provided for training and validation already suit your needs.
+
+Code can be run with
+`python main.py --base configs/coco_scene_images_transformer.yaml -t True --gpus 0,`
+or
+`python main.py --base configs/open_images_scene_images_transformer.yaml -t True --gpus 0,`
+
+### Sampling
+Train a model as described above or download a pre-trained model:
+ - [Open Images 1 billion parameter model](https://drive.google.com/file/d/1FEK-Z7hyWJBvFWQF50pzSK9y1W_CJEig/view?usp=sharing), trained for 100 epochs. On 256x256 pixels, FID 41.48±0.21, SceneFID 14.60±0.15, Inception Score 18.47±0.27. The model was trained with 2d crops of images and is thus well-prepared for the task of generating high-resolution images, e.g. 512x512.
+ - [Open Images distilled version of the above model with 125 million parameters](https://drive.google.com/file/d/1xf89g0mc78J3d8Bx5YhbK4tNRNlOoYaO) allows for sampling on smaller GPUs (4 GB is enough for sampling 256x256 px images). Model was trained for 60 epochs with 10% soft loss, 90% hard loss. On 256x256 pixels, FID 43.07±0.40, SceneFID 15.93±0.19, Inception Score 17.23±0.11.
+ - [COCO 30 epochs](https://heibox.uni-heidelberg.de/f/0d0b2594e9074c7e9a33/)
+ - [COCO 60 epochs](https://drive.google.com/file/d/1bInd49g2YulTJBjU32Awyt5qnzxxG5U9/) (find model statistics for both COCO versions in `assets/coco_scene_images_training.svg`)
+
+When downloading a pre-trained model, remember to change `ckpt_path` in `configs/*project.yaml` to point to your downloaded first-stage model (see ->Training).
+
+Scene image generation can be run with
+`python scripts/make_scene_samples.py --outdir=/some/outdir -r /path/to/pretrained/model --resolution=512,512`
+
+
+## Training on custom data
+
+Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
+Those are the steps to follow to make this work:
+1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
+2. put your .jpg files in a folder `your_folder`
+3. create two text files, `xx_train.txt` and `xx_test.txt`, that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`)
+4. adapt `configs/custom_vqgan.yaml` to point to these two files
+5. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to
+ train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.
+
+## Data Preparation
+
+### ImageNet
+The code will try to download (through [Academic
+Torrents](http://academictorrents.com/)) and prepare ImageNet the first time it
+is used. However, since ImageNet is quite large, this requires a lot of disk
+space and time. If you already have ImageNet on your disk, you can speed things
+up by putting the data into
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` (which defaults to
+`~/.cache/autoencoders/data/ILSVRC2012_{split}/data/`), where `{split}` is one
+of `train`/`validation`. It should have the following structure:
+
+```
+${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
+├── n01440764
+│ ├── n01440764_10026.JPEG
+│ ├── n01440764_10027.JPEG
+│ ├── ...
+├── n01443537
+│ ├── n01443537_10007.JPEG
+│ ├── n01443537_10014.JPEG
+│ ├── ...
+├── ...
+```
+
+If you haven't extracted the data, you can also place
+`ILSVRC2012_img_train.tar`/`ILSVRC2012_img_val.tar` (or symlinks to them) into
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/` /
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/`, which will then be
+extracted into the above structure without downloading it again. Note that this
+will only happen if neither a folder
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/` nor a file
+`${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready` exist. Remove them
+if you want to force running the dataset preparation again.
+
+You will then need to prepare the depth data using
+[MiDaS](https://github.com/intel-isl/MiDaS). Create a symlink
+`data/imagenet_depth` pointing to a folder with two subfolders `train` and
+`val`, each mirroring the structure of the corresponding ImageNet folder
+described above and containing a `png` file for each of ImageNet's `JPEG`
+files. The `png` encodes `float32` depth values obtained from MiDaS as RGBA
+images. We provide the script `scripts/extract_depth.py` to generate this data.
+**Please note** that this script uses [MiDaS via PyTorch
+Hub](https://pytorch.org/hub/intelisl_midas_v2/). When we prepared the data,
+the hub provided the [MiDaS
+v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2) version, but it now
+provides v2.1. We haven't tested our models with depth maps obtained
+via v2.1, so if you want to make sure that things work as expected, you must
+adjust the script so that it explicitly uses
+[v2.0](https://github.com/intel-isl/MiDaS/releases/tag/v2)!
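+
+For orientation, here is a minimal sketch of the float32-as-RGBA round trip described above
+(an illustration only; the exact byte layout produced by `scripts/extract_depth.py` may differ,
+and the file name is a placeholder):
+
+```
+import numpy as np
+from PIL import Image
+
+depth = np.random.rand(384, 384).astype(np.float32)  # stand-in for a MiDaS prediction
+
+# pack the raw float32 bytes into a 4-channel uint8 image and save it losslessly as RGBA PNG
+rgba = depth.view(np.uint8).reshape(depth.shape[0], depth.shape[1], 4)
+Image.fromarray(rgba, mode="RGBA").save("example_depth.png")
+
+# reading the PNG back recovers the float32 values exactly
+restored = np.array(Image.open("example_depth.png")).view(np.float32).squeeze(-1)
+assert np.array_equal(restored, depth)
+```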
+
+### CelebA-HQ
+Create a symlink `data/celebahq` pointing to a folder containing the `.npy`
+files of CelebA-HQ (instructions to obtain them can be found in the [PGGAN
+repository](https://github.com/tkarras/progressive_growing_of_gans)).
+
+### FFHQ
+Create a symlink `data/ffhq` pointing to the `images1024x1024` folder obtained
+from the [FFHQ repository](https://github.com/NVlabs/ffhq-dataset).
+
+### S-FLCKR
+Unfortunately, we are not allowed to distribute the images we collected for the
+S-FLCKR dataset and can therefore only describe how it was produced.
+There are many resources on [collecting images from the
+web](https://github.com/adrianmrit/flickrdatasets) to get started.
+We collected sufficiently large images from [flickr](https://www.flickr.com)
+(see `data/flickr_tags.txt` for a full list of tags used to find images)
+and various [subreddits](https://www.reddit.com/r/sfwpornnetwork/wiki/network)
+(see `data/subreddits.txt` for all subreddits that were used).
+Overall, we collected 107625 images, and split them randomly into 96861
+training images and 10764 validation images. We then obtained segmentation
+masks for each image using [DeepLab v2](https://arxiv.org/abs/1606.00915)
+trained on [COCO-Stuff](https://arxiv.org/abs/1612.03716). We used a [PyTorch
+reimplementation](https://github.com/kazuto1011/deeplab-pytorch) and include an
+example script for this process in `scripts/extract_segmentation.py`.
+
+### COCO
+Create a symlink `data/coco` containing the images from the 2017 split in
+`train2017` and `val2017`, and their annotations in `annotations`. Files can be
+obtained from the [COCO webpage](https://cocodataset.org/). In addition, we use
+the [Stuff+thing PNG-style annotations on COCO 2017
+trainval](http://calvin.inf.ed.ac.uk/wp-content/uploads/data/cocostuffdataset/stuffthingmaps_trainval2017.zip)
+annotations from [COCO-Stuff](https://github.com/nightrome/cocostuff), which
+should be placed under `data/cocostuffthings`.
+
+### ADE20k
+Create a symlink `data/ade20k_root` containing the contents of
+[ADEChallengeData2016.zip](http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip)
+from the [MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).
+
+## Training models
+
+### FacesHQ
+
+Train a VQGAN with
+```
+python main.py --base configs/faceshq_vqgan.yaml -t True --gpus 0,
+```
+
+Then, adjust the checkpoint path of the config key
+`model.params.first_stage_config.params.ckpt_path` in
+`configs/faceshq_transformer.yaml` (or download
+[2020-11-09T13-33-36_faceshq_vqgan](https://k00.fr/uxy5usa9) and place into `logs`, which
+corresponds to the preconfigured checkpoint path), then run
+```
+python main.py --base configs/faceshq_transformer.yaml -t True --gpus 0,
+```
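+
+If you prefer to set this checkpoint path programmatically instead of editing the YAML by hand,
+a sketch along the following lines works (PyYAML is assumed to be installed, the checkpoint path
+is a placeholder, and note that `safe_dump` drops comments and reorders keys):
+
+```
+import yaml
+
+cfg_path = "configs/faceshq_transformer.yaml"
+ckpt = "logs/2020-11-09T13-33-36_faceshq_vqgan/checkpoints/last.ckpt"  # placeholder path
+
+with open(cfg_path) as f:
+    cfg = yaml.safe_load(f)
+
+# model.params.first_stage_config.params.ckpt_path; for the depth-map VQGAN of the
+# D-RIN transformer below, set cond_stage_config instead of first_stage_config
+cfg["model"]["params"]["first_stage_config"]["params"]["ckpt_path"] = ckpt
+
+with open(cfg_path, "w") as f:
+    yaml.safe_dump(cfg, f)
+```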
+
+### D-RIN
+
+Train a VQGAN on ImageNet with
+```
+python main.py --base configs/imagenet_vqgan.yaml -t True --gpus 0,
+```
+
+or download a pretrained one from [2020-09-23T17-56-33_imagenet_vqgan](https://k00.fr/u0j2dtac)
+and place under `logs`. If you trained your own, adjust the path in the config
+key `model.params.first_stage_config.params.ckpt_path` of
+`configs/drin_transformer.yaml`.
+
+Train a VQGAN on Depth Maps of ImageNet with
+```
+python main.py --base configs/imagenetdepth_vqgan.yaml -t True --gpus 0,
+```
+
+or download a pretrained one from [2020-11-03T15-34-24_imagenetdepth_vqgan](https://k00.fr/55rlxs6i)
+and place under `logs`. If you trained your own, adjust the path in the config
+key `model.params.cond_stage_config.params.ckpt_path` of
+`configs/drin_transformer.yaml`.
+
+To train the transformer, run
+```
+python main.py --base configs/drin_transformer.yaml -t True --gpus 0,
+```
+
+## More Resources
+### Comparing Different First Stage Models
+The reconstruction and compression capabilities of different first stage models can be analyzed in this [colab notebook](https://colab.research.google.com/github/CompVis/taming-transformers/blob/master/scripts/reconstruction_usage.ipynb).
+In particular, the notebook compares two VQGANs with a downsampling factor of f=16 and codebook sizes of 1024 and 16384 respectively,
+a VQGAN with f=8 and 8192 codebook entries, and the discrete autoencoder of OpenAI's [DALL-E](https://github.com/openai/DALL-E) (which also has f=8 and 8192
+codebook entries).
+![firststages1](assets/first_stage_squirrels.png)
+![firststages2](assets/first_stage_mushrooms.png)
+
+### Other
+- A [video summary](https://www.youtube.com/watch?v=o7dqGcLDf0A&feature=emb_imp_woyt) by [Two Minute Papers](https://www.youtube.com/channel/UCbfYPyITQ-7l4upoX8nvctg).
+- A [video summary](https://www.youtube.com/watch?v=-wDSDtIAyWQ) by [Gradient Dude](https://www.youtube.com/c/GradientDude/about).
+- A [weights and biases report summarizing the paper](https://wandb.ai/ayush-thakur/taming-transformer/reports/-Overview-Taming-Transformers-for-High-Resolution-Image-Synthesis---Vmlldzo0NjEyMTY)
+by [ayulockin](https://github.com/ayulockin).
+- A [video summary](https://www.youtube.com/watch?v=JfUTd8fjtX8&feature=emb_imp_woyt) by [What's AI](https://www.youtube.com/channel/UCUzGQrN-lyyc0BWTYoJM_Sg).
+- Take a look at [ak9250's notebook](https://github.com/ak9250/taming-transformers/blob/master/tamingtransformerscolab.ipynb) if you want to run the streamlit demos on Colab.
+
+### Text-to-Image Optimization via CLIP
+VQGAN has been successfully used as an image generator guided by the [CLIP](https://github.com/openai/CLIP) model, both for pure image generation
+from scratch and image-to-image translation. We recommend the following notebooks/videos/resources:
+
+ - [Advadnoun's](https://twitter.com/advadnoun/status/1389316507134357506) Patreon and the corresponding LatentVision notebooks: https://www.patreon.com/patronizeme
+ - The [notebook](https://colab.research.google.com/drive/1L8oL-vLJXVcRzCFbPwOoMkPKJ8-aYdPN) of [Rivers Have Wings](https://twitter.com/RiversHaveWings).
+ - A [video](https://www.youtube.com/watch?v=90QDe6DQXF4&t=12s) explanation by [Dot CSV](https://www.youtube.com/channel/UCy5znSnfMsDwaLlROnZ7Qbg) (in Spanish, but English subtitles are available)
+
+![txt2img](assets/birddrawnbyachild.png)
+
+Text prompt: *'A bird drawn by a child'*
+
+## Shout-outs
+Thanks to everyone who makes their code and models available. In particular,
+
+- The architecture of our VQGAN is inspired by [Denoising Diffusion Probabilistic Models](https://github.com/hojonathanho/diffusion)
+- The very hackable transformer implementation [minGPT](https://github.com/karpathy/minGPT)
+- The good ol' [PatchGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) and [Learned Perceptual Similarity (LPIPS)](https://github.com/richzhang/PerceptualSimilarity)
+
+## BibTeX
+
+```
+@misc{esser2020taming,
+ title={Taming Transformers for High-Resolution Image Synthesis},
+ author={Patrick Esser and Robin Rombach and Björn Ommer},
+ year={2020},
+ eprint={2012.09841},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+```
diff --git a/repositories/taming/data/ade20k.py b/repositories/taming/data/ade20k.py
new file mode 100644
index 000000000..366dae972
--- /dev/null
+++ b/repositories/taming/data/ade20k.py
@@ -0,0 +1,124 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/ade20k_examples.txt",
+ data_root="data/ade20k_images",
+ segmentation_root="data/ade20k_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=151, shift_segmentation=False)
+
+
+# With semantic map and scene label
+class ADE20kBase(Dataset):
+ def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None):
+ self.split = self.get_split()
+ self.n_labels = 151 # unknown + 150
+ self.data_csv = {"train": "data/ade20k_train.txt",
+ "validation": "data/ade20k_test.txt"}[self.split]
+ self.data_root = "data/ade20k_root"
+ with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f:
+ self.scene_categories = f.read().splitlines()
+ self.scene_categories = dict(line.split() for line in self.scene_categories)
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, "images", l)
+ for l in self.image_paths],
+ "relative_segmentation_path_": [l.replace(".jpg", ".png")
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.data_root, "annotations",
+ l.replace(".jpg", ".png"))
+ for l in self.image_paths],
+ "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")]
+ for l in self.image_paths],
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size if size is not None else None
+ else:
+ self.crop_size = crop_size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+
+ if crop_size is not None:
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image, mask=segmentation)
+ else:
+ processed = {"image": image, "mask": segmentation}
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class ADE20kTrain(ADE20kBase):
+ # default to random_crop=True
+ def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None):
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ interpolation=interpolation, crop_size=crop_size)
+
+ def get_split(self):
+ return "train"
+
+
+class ADE20kValidation(ADE20kBase):
+ def get_split(self):
+ return "validation"
+
+
+if __name__ == "__main__":
+ dset = ADE20kValidation()
+ ex = dset[0]
+ for k in ["image", "scene_category", "segmentation"]:
+ print(type(ex[k]))
+        try:
+            print(ex[k].shape)
+        except AttributeError:
+            # scene_category is a plain string and has no .shape attribute
+            print(ex[k])
diff --git a/repositories/taming/data/annotated_objects_coco.py b/repositories/taming/data/annotated_objects_coco.py
new file mode 100644
index 000000000..af000ecd9
--- /dev/null
+++ b/repositories/taming/data/annotated_objects_coco.py
@@ -0,0 +1,139 @@
+import json
+from itertools import chain
+from pathlib import Path
+from typing import Iterable, Dict, List, Callable, Any
+from collections import defaultdict
+
+from tqdm import tqdm
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, ImageDescription, Category
+
+COCO_PATH_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_train2017.json',
+ 'stuff_annotations': 'annotations/stuff_train2017.json',
+ 'files': 'train2017'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'instances_annotations': 'annotations/instances_val2017.json',
+ 'stuff_annotations': 'annotations/stuff_val2017.json',
+ 'files': 'val2017'
+ }
+}
+
+
+def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]:
+ return {
+ str(img['id']): ImageDescription(
+ id=img['id'],
+ license=img.get('license'),
+ file_name=img['file_name'],
+ coco_url=img['coco_url'],
+ original_size=(img['width'], img['height']),
+ date_captured=img.get('date_captured'),
+ flickr_url=img.get('flickr_url')
+ )
+ for img in description_json
+ }
+
+
+def load_categories(category_json: Iterable) -> Dict[str, Category]:
+ return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name'])
+ for cat in category_json if cat['name'] != 'other'}
+
+
+def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription],
+ category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]:
+ annotations = defaultdict(list)
+ total = sum(len(a) for a in annotations_json)
+ for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total):
+ image_id = str(ann['image_id'])
+ if image_id not in image_descriptions:
+ raise ValueError(f'image_id [{image_id}] has no image description.')
+ category_id = ann['category_id']
+ try:
+ category_no = category_no_for_id(str(category_id))
+ except KeyError:
+ continue
+
+ width, height = image_descriptions[image_id].original_size
+ bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height)
+
+ annotations[image_id].append(
+ Annotation(
+ id=ann['id'],
+ area=bbox[2]*bbox[3], # use bbox area
+ is_group_of=ann['iscrowd'],
+ image_id=ann['image_id'],
+ bbox=bbox,
+ category_id=str(category_id),
+ category_no=category_no
+ )
+ )
+ return dict(annotations)
+
+
+class AnnotatedObjectsCoco(AnnotatedObjectsDataset):
+ def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ coco/
+ ├── annotations
+ │ ├── instances_train2017.json
+ │ ├── instances_val2017.json
+ │ ├── stuff_train2017.json
+ │ └── stuff_val2017.json
+ ├── train2017
+ │ ├── 000000000009.jpg
+ │ ├── 000000000025.jpg
+ │ └── ...
+ ├── val2017
+ │ ├── 000000000139.jpg
+ │ ├── 000000000285.jpg
+ │ └── ...
+        @param split: one of 'train' or 'validation'
+        @param target_image_size: desired image size (returns square images)
+        """
+ super().__init__(**kwargs)
+ self.use_things = use_things
+ self.use_stuff = use_stuff
+
+ with open(self.paths['instances_annotations']) as f:
+ inst_data_json = json.load(f)
+ with open(self.paths['stuff_annotations']) as f:
+ stuff_data_json = json.load(f)
+
+ category_jsons = []
+ annotation_jsons = []
+ if self.use_things:
+ category_jsons.append(inst_data_json['categories'])
+ annotation_jsons.append(inst_data_json['annotations'])
+ if self.use_stuff:
+ category_jsons.append(stuff_data_json['categories'])
+ annotation_jsons.append(stuff_data_json['annotations'])
+
+ self.categories = load_categories(chain(*category_jsons))
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = load_image_descriptions(inst_data_json['images'])
+ annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area,
+ self.min_objects_per_image, self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in COCO_PATH_STRUCTURE:
+            raise ValueError(f'Split [{self.split}] does not exist for COCO data.')
+ return COCO_PATH_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name)
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ # noinspection PyProtectedMember
+ return self.image_descriptions[image_id]._asdict()
diff --git a/repositories/taming/data/annotated_objects_dataset.py b/repositories/taming/data/annotated_objects_dataset.py
new file mode 100644
index 000000000..53cc346a1
--- /dev/null
+++ b/repositories/taming/data/annotated_objects_dataset.py
@@ -0,0 +1,218 @@
+from pathlib import Path
+from typing import Optional, List, Callable, Dict, Any, Union
+import warnings
+
+import PIL.Image as pil_image
+from torch import Tensor
+from torch.utils.data import Dataset
+from torchvision import transforms
+
+from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import load_object_from_string
+from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
+from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
+ Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
+
+
+class AnnotatedObjectsDataset(Dataset):
+ def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
+ min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
+ crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
+ encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
+ no_object_classes: Optional[int] = None):
+ self.data_path = data_path
+ self.split = split
+ self.keys = keys
+ self.target_image_size = target_image_size
+ self.min_object_area = min_object_area
+ self.min_objects_per_image = min_objects_per_image
+ self.max_objects_per_image = max_objects_per_image
+ self.crop_method = crop_method
+ self.random_flip = random_flip
+ self.no_tokens = no_tokens
+ self.use_group_parameter = use_group_parameter
+ self.encode_crop = encode_crop
+
+ self.annotations = None
+ self.image_descriptions = None
+ self.categories = None
+ self.category_ids = None
+ self.category_number = None
+ self.image_ids = None
+ self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
+ self.paths = self.build_paths(self.data_path)
+ self._conditional_builders = None
+ self.category_allow_list = None
+ if category_allow_list_target:
+ allow_list = load_object_from_string(category_allow_list_target)
+ self.category_allow_list = {name for name, _ in allow_list}
+ self.category_mapping = {}
+ if category_mapping_target:
+ self.category_mapping = load_object_from_string(category_mapping_target)
+ self.no_object_classes = no_object_classes
+
+ def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
+ top_level = Path(top_level)
+ sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
+ for path in sub_paths.values():
+ if not path.exists():
+ raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
+ return sub_paths
+
+ @staticmethod
+ def load_image_from_disk(path: Path) -> Image:
+ return pil_image.open(path).convert('RGB')
+
+ @staticmethod
+ def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
+ transform_functions = []
+ if crop_method == 'none':
+ transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
+ elif crop_method == 'center':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ CenterCropReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-1d':
+ transform_functions.extend([
+ transforms.Resize(target_image_size),
+ RandomCrop1dReturnCoordinates(target_image_size)
+ ])
+ elif crop_method == 'random-2d':
+ transform_functions.extend([
+ Random2dCropReturnCoordinates(target_image_size),
+ transforms.Resize(target_image_size)
+ ])
+ elif crop_method is None:
+ return None
+ else:
+ raise ValueError(f'Received invalid crop method [{crop_method}].')
+ if random_flip:
+ transform_functions.append(RandomHorizontalFlipReturn())
+ transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
+ return transform_functions
+
+ def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
+ crop_bbox = None
+ flipped = None
+ for t in self.transform_functions:
+ if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
+ crop_bbox, x = t(x)
+ elif isinstance(t, RandomHorizontalFlipReturn):
+ flipped, x = t(x)
+ else:
+ x = t(x)
+ return crop_bbox, flipped, x
+
+ @property
+ def no_classes(self) -> int:
+ return self.no_object_classes if self.no_object_classes else len(self.categories)
+
+ @property
+ def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder:
+ # cannot set this up in init because no_classes is only known after loading data in init of superclass
+ if self._conditional_builders is None:
+ self._conditional_builders = {
+ 'objects_center_points': ObjectsCenterPointsConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ ),
+ 'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
+ self.no_classes,
+ self.max_objects_per_image,
+ self.no_tokens,
+ self.encode_crop,
+ self.use_group_parameter,
+ getattr(self, 'use_additional_parameters', False)
+ )
+ }
+ return self._conditional_builders
+
+ def filter_categories(self) -> None:
+ if self.category_allow_list:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
+ if self.category_mapping:
+ self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}
+
+ def setup_category_id_and_number(self) -> None:
+ self.category_ids = list(self.categories.keys())
+ self.category_ids.sort()
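+        # '/m/01s55n' is an Open Images label id; when present it is moved to the end of
+        # the sorted list so that it always receives the highest category number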
+ if '/m/01s55n' in self.category_ids:
+ self.category_ids.remove('/m/01s55n')
+ self.category_ids.append('/m/01s55n')
+ self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
+ if self.category_allow_list is not None and self.category_mapping is None \
+ and len(self.category_ids) != len(self.category_allow_list):
+ warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
+ 'Make sure all names in category_allow_list exist.')
+
+ def clean_up_annotations_and_image_descriptions(self) -> None:
+ image_id_set = set(self.image_ids)
+ self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
+ self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}
+
+ @staticmethod
+ def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
+ min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
+ filtered = {}
+ for image_id, annotations in all_annotations.items():
+ annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
+ if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
+ filtered[image_id] = annotations_with_min_area
+ return filtered
+
+ def __len__(self):
+ return len(self.image_ids)
+
+ def __getitem__(self, n: int) -> Dict[str, Any]:
+ image_id = self.get_image_id(n)
+ sample = self.get_image_description(image_id)
+ sample['annotations'] = self.get_annotation(image_id)
+
+ if 'image' in self.keys:
+ sample['image_path'] = str(self.get_image_path(image_id))
+ sample['image'] = self.load_image_from_disk(sample['image_path'])
+ sample['image'] = convert_pil_to_tensor(sample['image'])
+ sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
+ sample['image'] = sample['image'].permute(1, 2, 0)
+
+ for conditional, builder in self.conditional_builders.items():
+ if conditional in self.keys:
+ sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped'])
+
+ if self.keys:
+ # only return specified keys
+ sample = {key: sample[key] for key in self.keys}
+ return sample
+
+ def get_image_id(self, no: int) -> str:
+ return self.image_ids[no]
+
+ def get_annotation(self, image_id: str) -> str:
+ return self.annotations[image_id]
+
+ def get_textual_label_for_category_id(self, category_id: str) -> str:
+ return self.categories[category_id].name
+
+ def get_textual_label_for_category_no(self, category_no: int) -> str:
+ return self.categories[self.get_category_id(category_no)].name
+
+ def get_category_number(self, category_id: str) -> int:
+ return self.category_number[category_id]
+
+ def get_category_id(self, category_no: int) -> str:
+ return self.category_ids[category_no]
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ raise NotImplementedError()
+
+ def get_path_structure(self):
+ raise NotImplementedError
+
+ def get_image_path(self, image_id: str) -> Path:
+ raise NotImplementedError
diff --git a/repositories/taming/data/annotated_objects_open_images.py b/repositories/taming/data/annotated_objects_open_images.py
new file mode 100644
index 000000000..aede6803d
--- /dev/null
+++ b/repositories/taming/data/annotated_objects_open_images.py
@@ -0,0 +1,137 @@
+from collections import defaultdict
+from csv import DictReader, reader as TupleReader
+from pathlib import Path
+from typing import Dict, List, Any
+import warnings
+
+from taming.data.annotated_objects_dataset import AnnotatedObjectsDataset
+from taming.data.helper_types import Annotation, Category
+from tqdm import tqdm
+
+OPEN_IMAGES_STRUCTURE = {
+ 'train': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'oidv6-train-annotations-bbox.csv',
+ 'file_list': 'train-images-boxable.csv',
+ 'files': 'train'
+ },
+ 'validation': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'validation-annotations-bbox.csv',
+ 'file_list': 'validation-images.csv',
+ 'files': 'validation'
+ },
+ 'test': {
+ 'top_level': '',
+ 'class_descriptions': 'class-descriptions-boxable.csv',
+ 'annotations': 'test-annotations-bbox.csv',
+ 'file_list': 'test-images.csv',
+ 'files': 'test'
+ }
+}
+
+
+def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str],
+ category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]:
+ annotations: Dict[str, List[Annotation]] = defaultdict(list)
+ with open(descriptor_path) as file:
+ reader = DictReader(file)
+ for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'):
+ width = float(row['XMax']) - float(row['XMin'])
+ height = float(row['YMax']) - float(row['YMin'])
+ area = width * height
+ category_id = row['LabelName']
+ if category_id in category_mapping:
+ category_id = category_mapping[category_id]
+ if area >= min_object_area and category_id in category_no_for_id:
+ annotations[row['ImageID']].append(
+ Annotation(
+ id=i,
+ image_id=row['ImageID'],
+ source=row['Source'],
+ category_id=category_id,
+ category_no=category_no_for_id[category_id],
+ confidence=float(row['Confidence']),
+ bbox=(float(row['XMin']), float(row['YMin']), width, height),
+ area=area,
+ is_occluded=bool(int(row['IsOccluded'])),
+ is_truncated=bool(int(row['IsTruncated'])),
+ is_group_of=bool(int(row['IsGroupOf'])),
+ is_depiction=bool(int(row['IsDepiction'])),
+ is_inside=bool(int(row['IsInside']))
+ )
+ )
+ if 'train' in str(descriptor_path) and i < 14000000:
+ warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].')
+ return dict(annotations)
+
+
+def load_image_ids(csv_path: Path) -> List[str]:
+ with open(csv_path) as file:
+ reader = DictReader(file)
+ return [row['image_name'] for row in reader]
+
+
+def load_categories(csv_path: Path) -> Dict[str, Category]:
+ with open(csv_path) as file:
+ reader = TupleReader(file)
+ return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader}
+
+
+class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset):
+ def __init__(self, use_additional_parameters: bool, **kwargs):
+ """
+ @param data_path: is the path to the following folder structure:
+ open_images/
+ │ oidv6-train-annotations-bbox.csv
+ ├── class-descriptions-boxable.csv
+ ├── oidv6-train-annotations-bbox.csv
+ ├── test
+ │ ├── 000026e7ee790996.jpg
+ │ ├── 000062a39995e348.jpg
+ │ └── ...
+ ├── test-annotations-bbox.csv
+ ├── test-images.csv
+ ├── train
+ │ ├── 000002b66c9c498e.jpg
+ │ ├── 000002b97e5471a0.jpg
+ │ └── ...
+ ├── train-images-boxable.csv
+ ├── validation
+ │ ├── 0001eeaf4aed83f9.jpg
+ │ ├── 0004886b7d043cfd.jpg
+ │ └── ...
+ ├── validation-annotations-bbox.csv
+ └── validation-images.csv
+        @param split: one of 'train', 'validation' or 'test'
+        @param target_image_size: desired image size (returns square images)
+ """
+
+ super().__init__(**kwargs)
+ self.use_additional_parameters = use_additional_parameters
+
+ self.categories = load_categories(self.paths['class_descriptions'])
+ self.filter_categories()
+ self.setup_category_id_and_number()
+
+ self.image_descriptions = {}
+ annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping,
+ self.category_number)
+ self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image,
+ self.max_objects_per_image)
+ self.image_ids = list(self.annotations.keys())
+ self.clean_up_annotations_and_image_descriptions()
+
+ def get_path_structure(self) -> Dict[str, str]:
+ if self.split not in OPEN_IMAGES_STRUCTURE:
+            raise ValueError(f'Split [{self.split}] does not exist for Open Images data.')
+ return OPEN_IMAGES_STRUCTURE[self.split]
+
+ def get_image_path(self, image_id: str) -> Path:
+ return self.paths['files'].joinpath(f'{image_id:0>16}.jpg')
+
+ def get_image_description(self, image_id: str) -> Dict[str, Any]:
+ image_path = self.get_image_path(image_id)
+ return {'file_path': str(image_path), 'file_name': image_path.name}
diff --git a/repositories/taming/data/base.py b/repositories/taming/data/base.py
new file mode 100644
index 000000000..e21667df4
--- /dev/null
+++ b/repositories/taming/data/base.py
@@ -0,0 +1,70 @@
+import bisect
+import numpy as np
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset, ConcatDataset
+
+
+class ConcatDatasetWithIndex(ConcatDataset):
+ """Modified from original pytorch code to return dataset idx"""
+ def __getitem__(self, idx):
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError("absolute value of index should not exceed dataset length")
+ idx = len(self) + idx
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx][sample_idx], dataset_idx
+
+
+class ImagePaths(Dataset):
+ def __init__(self, paths, size=None, random_crop=False, labels=None):
+ self.size = size
+ self.random_crop = random_crop
+
+ self.labels = dict() if labels is None else labels
+ self.labels["file_path_"] = paths
+ self._length = len(paths)
+
+ if self.size is not None and self.size > 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ if not self.random_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
+ self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return self._length
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def __getitem__(self, i):
+ example = dict()
+ example["image"] = self.preprocess_image(self.labels["file_path_"][i])
+ for k in self.labels:
+ example[k] = self.labels[k][i]
+ return example
+
+
+class NumpyPaths(ImagePaths):
+ def preprocess_image(self, image_path):
+ image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
+ image = np.transpose(image, (1,2,0))
+ image = Image.fromarray(image, mode="RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
diff --git a/repositories/taming/data/coco.py b/repositories/taming/data/coco.py
new file mode 100644
index 000000000..2b2f78384
--- /dev/null
+++ b/repositories/taming/data/coco.py
@@ -0,0 +1,176 @@
+import os
+import json
+import albumentations
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from torch.utils.data import Dataset
+
+from taming.data.sflckr import SegmentationBase # for examples included in repo
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=256, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/coco_examples.txt",
+ data_root="data/coco_images",
+ segmentation_root="data/coco_segmentations",
+ size=size, random_crop=random_crop,
+ interpolation=interpolation,
+ n_labels=183, shift_segmentation=True)
+
+
+class CocoBase(Dataset):
+ """needed for (image, caption, segmentation) pairs"""
+ def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False,
+ crop_size=None, force_no_crop=False, given_files=None):
+ self.split = self.get_split()
+ self.size = size
+ if crop_size is None:
+ self.crop_size = size
+ else:
+ self.crop_size = crop_size
+
+ self.onehot = onehot_segmentation # return segmentation as rgb or one hot
+ self.stuffthing = use_stuffthing # include thing in segmentation
+ if self.onehot and not self.stuffthing:
+            raise NotImplementedError("One hot mode is only supported for the "
+                                      "stuffthings version because labels are stored "
+                                      "a bit differently.")
+
+ data_json = datajson
+ with open(data_json) as json_file:
+ self.json_data = json.load(json_file)
+ self.img_id_to_captions = dict()
+ self.img_id_to_filepath = dict()
+ self.img_id_to_segmentation_filepath = dict()
+
+ assert data_json.split("/")[-1] in ["captions_train2017.json",
+ "captions_val2017.json"]
+ if self.stuffthing:
+ self.segmentation_prefix = (
+ "data/cocostuffthings/val2017" if
+ data_json.endswith("captions_val2017.json") else
+ "data/cocostuffthings/train2017")
+ else:
+ self.segmentation_prefix = (
+ "data/coco/annotations/stuff_val2017_pixelmaps" if
+ data_json.endswith("captions_val2017.json") else
+ "data/coco/annotations/stuff_train2017_pixelmaps")
+
+ imagedirs = self.json_data["images"]
+ self.labels = {"image_ids": list()}
+ for imgdir in tqdm(imagedirs, desc="ImgToPath"):
+ self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"])
+ self.img_id_to_captions[imgdir["id"]] = list()
+ pngfilename = imgdir["file_name"].replace("jpg", "png")
+ self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join(
+ self.segmentation_prefix, pngfilename)
+ if given_files is not None:
+ if pngfilename in given_files:
+ self.labels["image_ids"].append(imgdir["id"])
+ else:
+ self.labels["image_ids"].append(imgdir["id"])
+
+ capdirs = self.json_data["annotations"]
+ for capdir in tqdm(capdirs, desc="ImgToCaptions"):
+ # there are in average 5 captions per image
+ self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]]))
+
+ self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
+ if self.split=="validation":
+ self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler, self.cropper],
+ additional_targets={"segmentation": "image"})
+ if force_no_crop:
+ self.rescaler = albumentations.Resize(height=self.size, width=self.size)
+ self.preprocessor = albumentations.Compose(
+ [self.rescaler],
+ additional_targets={"segmentation": "image"})
+
+ def __len__(self):
+ return len(self.labels["image_ids"])
+
+ def preprocess_image(self, image_path, segmentation_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+
+ segmentation = Image.open(segmentation_path)
+ if not self.onehot and not segmentation.mode == "RGB":
+ segmentation = segmentation.convert("RGB")
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.onehot:
+ assert self.stuffthing
+ # stored in caffe format: unlabeled==255. stuff and thing from
+ # 0-181. to be compatible with the labels in
+ # https://github.com/nightrome/cocostuff/blob/master/labels.txt
+ # we shift stuffthing one to the right and put unlabeled in zero
+ # as long as segmentation is uint8 shifting to right handles the
+ # latter too
+ assert segmentation.dtype == np.uint8
+ segmentation = segmentation + 1
+
+ processed = self.preprocessor(image=image, segmentation=segmentation)
+ image, segmentation = processed["image"], processed["segmentation"]
+ image = (image / 127.5 - 1.0).astype(np.float32)
+
+ if self.onehot:
+ assert segmentation.dtype == np.uint8
+ # make it one hot
+ n_labels = 183
+ flatseg = np.ravel(segmentation)
+            # np.bool was removed in NumPy 1.24+; use the builtin bool dtype instead
+            onehot = np.zeros((flatseg.size, n_labels), dtype=bool)
+ onehot[np.arange(flatseg.size), flatseg] = True
+ onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int)
+ segmentation = onehot
+ else:
+ segmentation = (segmentation / 127.5 - 1.0).astype(np.float32)
+ return image, segmentation
+
+ def __getitem__(self, i):
+ img_path = self.img_id_to_filepath[self.labels["image_ids"][i]]
+ seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]]
+ image, segmentation = self.preprocess_image(img_path, seg_path)
+ captions = self.img_id_to_captions[self.labels["image_ids"][i]]
+ # randomly draw one of all available captions per image
+ caption = captions[np.random.randint(0, len(captions))]
+ example = {"image": image,
+ "caption": [str(caption[0])],
+ "segmentation": segmentation,
+ "img_path": img_path,
+ "seg_path": seg_path,
+ "filename_": img_path.split(os.sep)[-1]
+ }
+ return example
+
+
+class CocoImagesAndCaptionsTrain(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False):
+ super().__init__(size=size,
+ dataroot="data/coco/train2017",
+ datajson="data/coco/annotations/captions_train2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop)
+
+ def get_split(self):
+ return "train"
+
+
+class CocoImagesAndCaptionsValidation(CocoBase):
+ """returns a pair of (image, caption)"""
+ def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,
+ given_files=None):
+ super().__init__(size=size,
+ dataroot="data/coco/val2017",
+ datajson="data/coco/annotations/captions_val2017.json",
+ onehot_segmentation=onehot_segmentation,
+ use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop,
+ given_files=given_files)
+
+ def get_split(self):
+ return "validation"
diff --git a/repositories/taming/data/conditional_builder/objects_bbox.py b/repositories/taming/data/conditional_builder/objects_bbox.py
new file mode 100644
index 000000000..15881e76b
--- /dev/null
+++ b/repositories/taming/data/conditional_builder/objects_bbox.py
@@ -0,0 +1,60 @@
+from itertools import cycle
+from typing import List, Tuple, Callable, Optional
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
+ pad_list, get_plot_font_size, absolute_bbox
+
+
+class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
+ @property
+ def object_descriptor_length(self) -> int:
+ return 3
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_triples = [
+ (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
+ for ann in annotations
+ ]
+ empty_triple = (self.none, self.none, self.none)
+ object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
+ return object_triples
+
+ def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ object_triples = grouper(conditional_list, 3)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
+ for object_triple in object_triples if object_triple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ font = ImageFont.truetype(
+ "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
+ size=get_plot_font_size(font_size, figure_size)
+ )
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
+ annotation = self.representation_to_annotation(representation)
+ class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
+ bbox = absolute_bbox(bbox, width, height)
+ draw.rectangle(bbox, outline=color, width=line_width)
+ draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
diff --git a/repositories/taming/data/conditional_builder/objects_center_points.py b/repositories/taming/data/conditional_builder/objects_center_points.py
new file mode 100644
index 000000000..9a480329c
--- /dev/null
+++ b/repositories/taming/data/conditional_builder/objects_center_points.py
@@ -0,0 +1,168 @@
+import math
+import random
+import warnings
+from itertools import cycle
+from typing import List, Optional, Tuple, Callable
+
+from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
+from more_itertools.recipes import grouper
+from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \
+ additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \
+ absolute_bbox, rescale_annotations
+from taming.data.helper_types import BoundingBox, Annotation
+from taming.data.image_transforms import convert_pil_to_tensor
+from torch import LongTensor, Tensor
+
+
+class ObjectsCenterPointsConditionalBuilder:
+ def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool,
+ use_group_parameter: bool, use_additional_parameters: bool):
+ self.no_object_classes = no_object_classes
+ self.no_max_objects = no_max_objects
+ self.no_tokens = no_tokens
+ self.encode_crop = encode_crop
+ self.no_sections = int(math.sqrt(self.no_tokens))
+ self.use_group_parameter = use_group_parameter
+ self.use_additional_parameters = use_additional_parameters
+
+ @property
+ def none(self) -> int:
+ return self.no_tokens - 1
+
+ @property
+ def object_descriptor_length(self) -> int:
+ return 2
+
+ @property
+ def embedding_dim(self) -> int:
+ extra_length = 2 if self.encode_crop else 0
+ return self.no_max_objects * self.object_descriptor_length + extra_length
+
+ def tokenize_coordinates(self, x: float, y: float) -> int:
+ """
+ Express 2d coordinates with one number.
+ Example: assume self.no_tokens = 16, then no_sections = 4:
+ 0 0 0 0
+ 0 0 # 0
+ 0 0 0 0
+ 0 0 0 x
+ Then the # position corresponds to token 6, the x position to token 15.
+ @param x: float in [0, 1]
+ @param y: float in [0, 1]
+ @return: discrete tokenized coordinate
+ """
+ x_discrete = int(round(x * (self.no_sections - 1)))
+ y_discrete = int(round(y * (self.no_sections - 1)))
+ return y_discrete * self.no_sections + x_discrete
+
+ def coordinates_from_token(self, token: int) -> (float, float):
+ x = token % self.no_sections
+ y = token // self.no_sections
+ return x / (self.no_sections - 1), y / (self.no_sections - 1)
+
+ def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox:
+ x0, y0 = self.coordinates_from_token(token1)
+ x1, y1 = self.coordinates_from_token(token2)
+ return x0, y0, x1 - x0, y1 - y0
+
+ def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]:
+ return self.tokenize_coordinates(bbox[0], bbox[1]), \
+ self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3])
+
+ def inverse_build(self, conditional: LongTensor) \
+ -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]:
+ conditional_list = conditional.tolist()
+ crop_coordinates = None
+ if self.encode_crop:
+ crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
+ conditional_list = conditional_list[:-2]
+ table_of_content = grouper(conditional_list, self.object_descriptor_length)
+ assert conditional.shape[0] == self.embedding_dim
+ return [
+ (object_tuple[0], self.coordinates_from_token(object_tuple[1]))
+ for object_tuple in table_of_content if object_tuple[0] != self.none
+ ], crop_coordinates
+
+ def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
+ line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
+ plot = pil_image.new('RGB', figure_size, WHITE)
+ draw = pil_img_draw.Draw(plot)
+ circle_size = get_circle_size(figure_size)
+ font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf',
+ size=get_plot_font_size(font_size, figure_size))
+ width, height = plot.size
+ description, crop_coordinates = self.inverse_build(conditional)
+ for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)):
+ x_abs, y_abs = x * width, y * height
+ ann = self.representation_to_annotation(representation)
+ label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann)
+ ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size]
+ draw.ellipse(ellipse_bbox, fill=color, width=0)
+ draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font)
+ if crop_coordinates is not None:
+ draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
+ return convert_pil_to_tensor(plot) / 127.5 - 1.
+
+ def object_representation(self, annotation: Annotation) -> int:
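+        # fold the boolean annotation attributes into the class token as bit flags:
+        # the returned id is category_no + no_object_classes * modifier, so each
+        # (category, flag combination) maps to a distinct integer; see
+        # representation_to_annotation() for the inverse mapping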
+ modifier = 0
+ if self.use_group_parameter:
+ modifier |= 1 * (annotation.is_group_of is True)
+ if self.use_additional_parameters:
+ modifier |= 2 * (annotation.is_occluded is True)
+ modifier |= 4 * (annotation.is_depiction is True)
+ modifier |= 8 * (annotation.is_inside is True)
+ return annotation.category_no + self.no_object_classes * modifier
+
+ def representation_to_annotation(self, representation: int) -> Annotation:
+ category_no = representation % self.no_object_classes
+ modifier = representation // self.no_object_classes
+ # noinspection PyTypeChecker
+ return Annotation(
+ area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None,
+ category_no=category_no,
+ is_group_of=bool((modifier & 1) * self.use_group_parameter),
+ is_occluded=bool((modifier & 2) * self.use_additional_parameters),
+ is_depiction=bool((modifier & 4) * self.use_additional_parameters),
+ is_inside=bool((modifier & 8) * self.use_additional_parameters)
+ )
+
+ def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]:
+ return list(self.token_pair_from_bbox(crop_coordinates))
+
+ def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
+ object_tuples = [
+ (self.object_representation(a),
+ self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2))
+ for a in annotations
+ ]
+ empty_tuple = (self.none, self.none)
+ object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects)
+ return object_tuples
+
+ def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \
+ -> LongTensor:
+ if len(annotations) == 0:
+ warnings.warn('Did not receive any annotations.')
+ if len(annotations) > self.no_max_objects:
+ warnings.warn('Received more annotations than allowed.')
+ annotations = annotations[:self.no_max_objects]
+
+ if not crop_coordinates:
+ crop_coordinates = FULL_CROP
+
+ random.shuffle(annotations)
+ annotations = filter_annotations(annotations, crop_coordinates)
+ if self.encode_crop:
+ annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip)
+ if horizontal_flip:
+ crop_coordinates = horizontally_flip_bbox(crop_coordinates)
+ extra = self._crop_encoder(crop_coordinates)
+ else:
+ annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip)
+ extra = []
+
+ object_tuples = self._make_object_descriptors(annotations)
+ flattened = [token for tuple_ in object_tuples for token in tuple_] + extra
+ assert len(flattened) == self.embedding_dim
+ assert all(0 <= value < self.no_tokens for value in flattened)
+ return LongTensor(flattened)
diff --git a/repositories/taming/data/conditional_builder/utils.py b/repositories/taming/data/conditional_builder/utils.py
new file mode 100644
index 000000000..d0ee175f2
--- /dev/null
+++ b/repositories/taming/data/conditional_builder/utils.py
@@ -0,0 +1,105 @@
+import importlib
+from typing import List, Any, Tuple, Optional
+
+from taming.data.helper_types import BoundingBox, Annotation
+
+# source: seaborn, color palette tab10
+COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
+ (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
+BLACK = (0, 0, 0)
+GRAY_75 = (63, 63, 63)
+GRAY_50 = (127, 127, 127)
+GRAY_25 = (191, 191, 191)
+WHITE = (255, 255, 255)
+FULL_CROP = (0., 0., 1., 1.)
+
+
+def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
+ """
+ Give intersection area of two rectangles.
+ @param rectangle1: (x0, y0, w, h) of first rectangle
+ @param rectangle2: (x0, y0, w, h) of second rectangle
+ """
+ rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
+ rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
+ x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
+ y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
+ return x_overlap * y_overlap
+
+
+def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
+ return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
+
+
+def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
+ bbox = relative_bbox
+ bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
+ return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
+
+
+def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
+ return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
+
+
+def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
+ List[Annotation]:
+ def clamp(x: float):
+ return max(min(x, 1.), 0.)
+
+ def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
+ x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
+ y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
+ w = min(bbox[2] / crop_coordinates[2], 1 - x0)
+ h = min(bbox[3] / crop_coordinates[3], 1 - y0)
+ if flip:
+ x0 = 1 - (x0 + w)
+ return x0, y0, w, h
+
+ return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
+
+
+def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
+ return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
+
+
+def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
+ sl = slice(1) if short else slice(None)
+ string = ''
+ if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
+ return string
+ if annotation.is_group_of:
+ string += 'group'[sl] + ','
+ if annotation.is_occluded:
+ string += 'occluded'[sl] + ','
+ if annotation.is_depiction:
+ string += 'depiction'[sl] + ','
+ if annotation.is_inside:
+ string += 'inside'[sl]
+ return '(' + string.strip(",") + ')'
+
+
+def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
+ if font_size is None:
+ font_size = 10
+ if max(figure_size) >= 256:
+ font_size = 12
+ if max(figure_size) >= 512:
+ font_size = 15
+ return font_size
+
+
+def get_circle_size(figure_size: Tuple[int, int]) -> int:
+ circle_size = 2
+ if max(figure_size) >= 256:
+ circle_size = 3
+ if max(figure_size) >= 512:
+ circle_size = 4
+ return circle_size
+
+
+def load_object_from_string(object_string: str) -> Any:
+ """
+ Source: https://stackoverflow.com/a/10773699
+ """
+ module_name, class_name = object_string.rsplit(".", 1)
+ return getattr(importlib.import_module(module_name), class_name)
diff --git a/repositories/taming/data/custom.py b/repositories/taming/data/custom.py
new file mode 100644
index 000000000..33f302a4b
--- /dev/null
+++ b/repositories/taming/data/custom.py
@@ -0,0 +1,38 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class CustomBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ return example
+
+
+
+class CustomTrain(CustomBase):
+ def __init__(self, size, training_images_list_file):
+ super().__init__()
+ with open(training_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
+class CustomTest(CustomBase):
+ def __init__(self, size, test_images_list_file):
+ super().__init__()
+ with open(test_images_list_file, "r") as f:
+ paths = f.read().splitlines()
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+
+
diff --git a/repositories/taming/data/faceshq.py b/repositories/taming/data/faceshq.py
new file mode 100644
index 000000000..6912d04b6
--- /dev/null
+++ b/repositories/taming/data/faceshq.py
@@ -0,0 +1,134 @@
+import os
+import numpy as np
+import albumentations
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
+
+
+class FacesBase(Dataset):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.data = None
+ self.keys = None
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ example = self.data[i]
+ ex = {}
+ if self.keys is not None:
+ for k in self.keys:
+ ex[k] = example[k]
+ else:
+ ex = example
+ return ex
+
+
+class CelebAHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class CelebAHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/celebahq"
+ with open("data/celebahqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = NumpyPaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQTrain(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqtrain.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FFHQValidation(FacesBase):
+ def __init__(self, size, keys=None):
+ super().__init__()
+ root = "data/ffhq"
+ with open("data/ffhqvalidation.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ paths = [os.path.join(root, relpath) for relpath in relpaths]
+ self.data = ImagePaths(paths=paths, size=size, random_crop=False)
+ self.keys = keys
+
+
+class FacesHQTrain(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQTrain(size=size, keys=keys)
+ d2 = FFHQTrain(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
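+ # build a normalized per-pixel index map and crop it together with the image so the crop position stays recoverable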
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
+
+
+class FacesHQValidation(Dataset):
+ # CelebAHQ [0] + FFHQ [1]
+ def __init__(self, size, keys=None, crop_size=None, coord=False):
+ d1 = CelebAHQValidation(size=size, keys=keys)
+ d2 = FFHQValidation(size=size, keys=keys)
+ self.data = ConcatDatasetWithIndex([d1, d2])
+ self.coord = coord
+ if crop_size is not None:
+ self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ if self.coord:
+ self.cropper = albumentations.Compose([self.cropper],
+ additional_targets={"coord": "image"})
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ ex, y = self.data[i]
+ if hasattr(self, "cropper"):
+ if not self.coord:
+ out = self.cropper(image=ex["image"])
+ ex["image"] = out["image"]
+ else:
+ h,w,_ = ex["image"].shape
+ coord = np.arange(h*w).reshape(h,w,1)/(h*w)
+ out = self.cropper(image=ex["image"], coord=coord)
+ ex["image"] = out["image"]
+ ex["coord"] = out["coord"]
+ ex["class"] = y
+ return ex
diff --git a/repositories/taming/data/helper_types.py b/repositories/taming/data/helper_types.py
new file mode 100644
index 000000000..fb51e301d
--- /dev/null
+++ b/repositories/taming/data/helper_types.py
@@ -0,0 +1,49 @@
+from typing import Dict, Tuple, Optional, NamedTuple, Union
+from PIL.Image import Image as pil_image
+from torch import Tensor
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+Image = Union[Tensor, pil_image]
+BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
+CropMethodType = Literal['none', 'random', 'center', 'random-2d']
+SplitType = Literal['train', 'validation', 'test']
+
+
+class ImageDescription(NamedTuple):
+ id: int
+ file_name: str
+ original_size: Tuple[int, int] # w, h
+ url: Optional[str] = None
+ license: Optional[int] = None
+ coco_url: Optional[str] = None
+ date_captured: Optional[str] = None
+ flickr_url: Optional[str] = None
+ flickr_id: Optional[str] = None
+ coco_id: Optional[str] = None
+
+
+class Category(NamedTuple):
+ id: str
+ super_category: Optional[str]
+ name: str
+
+
+class Annotation(NamedTuple):
+ area: float
+ image_id: str
+ bbox: BoundingBox
+ category_no: int
+ category_id: str
+ id: Optional[int] = None
+ source: Optional[str] = None
+ confidence: Optional[float] = None
+ is_group_of: Optional[bool] = None
+ is_truncated: Optional[bool] = None
+ is_occluded: Optional[bool] = None
+ is_depiction: Optional[bool] = None
+ is_inside: Optional[bool] = None
+ segmentation: Optional[Dict] = None
diff --git a/repositories/taming/data/image_transforms.py b/repositories/taming/data/image_transforms.py
new file mode 100644
index 000000000..657ac3321
--- /dev/null
+++ b/repositories/taming/data/image_transforms.py
@@ -0,0 +1,132 @@
+import random
+import warnings
+from typing import Union
+
+import torch
+from torch import Tensor
+from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor
+try:
+ from torchvision.transforms.functional import get_image_size
+except ImportError: # older torchvision only exposes the private helper
+ from torchvision.transforms.functional import _get_image_size as get_image_size
+
+from taming.data.helper_types import BoundingBox, Image
+
+pil_to_tensor = PILToTensor()
+
+
+def convert_pil_to_tensor(image: Image) -> Tensor:
+ with warnings.catch_warnings():
+ # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194
+ warnings.simplefilter("ignore")
+ return pil_to_tensor(image)
+
+
+class RandomCrop1dReturnCoordinates(RandomCrop):
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+ if self.padding is not None:
+ img = F.pad(img, self.padding, self.fill, self.padding_mode)
+
+ width, height = get_image_size(img)
+ # pad the width if needed
+ if self.pad_if_needed and width < self.size[1]:
+ padding = [self.size[1] - width, 0]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+ # pad the height if needed
+ if self.pad_if_needed and height < self.size[0]:
+ padding = [0, self.size[0] - height]
+ img = F.pad(img, padding, self.fill, self.padding_mode)
+
+ i, j, h, w = self.get_params(img, self.size)
+ bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h
+ return bbox, F.crop(img, i, j, h, w)
+
+
+class Random2dCropReturnCoordinates(torch.nn.Module):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+
+ Based on:
+ torchvision.transforms.RandomCrop, torchvision 1.7.0
+ """
+
+ def __init__(self, min_size: int):
+ super().__init__()
+ self.min_size = min_size
+
+ def forward(self, img: Image) -> (BoundingBox, Image):
+ width, height = get_image_size(img)
+ max_size = min(width, height)
+ if max_size <= self.min_size:
+ size = max_size
+ else:
+ size = random.randint(self.min_size, max_size)
+ top = random.randint(0, height - size)
+ left = random.randint(0, width - size)
+ bbox = left / width, top / height, size / width, size / height
+ return bbox, F.crop(img, top, left, size, size)
+
+
+class CenterCropReturnCoordinates(CenterCrop):
+ @staticmethod
+ def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox:
+ if width > height:
+ w = height / width
+ h = 1.0
+ x0 = 0.5 - w / 2
+ y0 = 0.
+ else:
+ w = 1.0
+ h = width / height
+ x0 = 0.
+ y0 = 0.5 - h / 2
+ return x0, y0, w, h
+
+ def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]):
+ """
+ Additionally to cropping, returns the relative coordinates of the crop bounding box.
+ Args:
+ img (PIL Image or Tensor): Image to be cropped.
+
+ Returns:
+ Bounding box: x0, y0, w, h
+ PIL Image or Tensor: Cropped image.
+ Based on:
+ torchvision.transforms.CenterCrop (version 1.7.0)
+ """
+ width, height = get_image_size(img)
+ return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size)
+
+
+class RandomHorizontalFlipReturn(RandomHorizontalFlip):
+ def forward(self, img: Image) -> (bool, Image):
+ """
+ Additionally to flipping, returns a boolean whether it was flipped or not.
+ Args:
+ img (PIL Image or Tensor): Image to be flipped.
+
+ Returns:
+ flipped: whether the image was flipped or not
+ PIL Image or Tensor: Randomly flipped image.
+
+ Based on:
+ torchvision.transforms.RandomHorizontalFlip (version 1.7.0)
+ """
+ if torch.rand(1) < self.p:
+ return True, F.hflip(img)
+ return False, img
diff --git a/repositories/taming/data/imagenet.py b/repositories/taming/data/imagenet.py
new file mode 100644
index 000000000..9a02ec44b
--- /dev/null
+++ b/repositories/taming/data/imagenet.py
@@ -0,0 +1,558 @@
+import os, tarfile, glob, shutil
+import yaml
+import numpy as np
+from tqdm import tqdm
+from PIL import Image
+import albumentations
+from omegaconf import OmegaConf
+from torch.utils.data import Dataset
+
+from taming.data.base import ImagePaths
+from taming.util import download, retrieve
+import taming.data.utils as bdu
+
+
+def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
+ synsets = []
+ with open(path_to_yaml) as f:
+ di2s = yaml.load(f, Loader=yaml.FullLoader)
+ for idx in indices:
+ synsets.append(str(di2s[idx]))
+ print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
+ return synsets
+
+
+def str_to_indices(string):
+ """Expects a string in the format '32-123, 256, 280-321'"""
+ assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
+ subs = string.split(",")
+ indices = []
+ for sub in subs:
+ subsubs = sub.split("-")
+ assert len(subsubs) > 0
+ if len(subsubs) == 1:
+ indices.append(int(subsubs[0]))
+ else:
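+ # note: the upper bound of an 'a-b' range is exclusive (Python range semantics)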
+ rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))]
+ indices.extend(rang)
+ return sorted(indices)
+
+
+class ImageNetBase(Dataset):
+ def __init__(self, config=None):
+ self.config = config or OmegaConf.create()
+ if not isinstance(self.config, dict):
+ self.config = OmegaConf.to_container(self.config)
+ self._prepare()
+ self._prepare_synset_to_human()
+ self._prepare_idx_to_synset()
+ self._load()
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, i):
+ return self.data[i]
+
+ def _prepare(self):
+ raise NotImplementedError()
+
+ def _filter_relpaths(self, relpaths):
+ ignore = set([
+ "n06596364_9591.JPEG",
+ ])
+ relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
+ if "sub_indices" in self.config:
+ indices = str_to_indices(self.config["sub_indices"])
+ synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
+ files = []
+ for rpath in relpaths:
+ syn = rpath.split("/")[0]
+ if syn in synsets:
+ files.append(rpath)
+ return files
+ else:
+ return relpaths
+
+ def _prepare_synset_to_human(self):
+ SIZE = 2655750
+ URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
+ self.human_dict = os.path.join(self.root, "synset_human.txt")
+ if (not os.path.exists(self.human_dict) or
+ not os.path.getsize(self.human_dict)==SIZE):
+ download(URL, self.human_dict)
+
+ def _prepare_idx_to_synset(self):
+ URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
+ self.idx2syn = os.path.join(self.root, "index_synset.yaml")
+ if (not os.path.exists(self.idx2syn)):
+ download(URL, self.idx2syn)
+
+ def _load(self):
+ with open(self.txt_filelist, "r") as f:
+ self.relpaths = f.read().splitlines()
+ l1 = len(self.relpaths)
+ self.relpaths = self._filter_relpaths(self.relpaths)
+ print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
+
+ self.synsets = [p.split("/")[0] for p in self.relpaths]
+ self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
+
+ unique_synsets = np.unique(self.synsets)
+ class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
+ self.class_labels = [class_dict[s] for s in self.synsets]
+
+ with open(self.human_dict, "r") as f:
+ human_dict = f.read().splitlines()
+ human_dict = dict(line.split(maxsplit=1) for line in human_dict)
+
+ self.human_labels = [human_dict[s] for s in self.synsets]
+
+ labels = {
+ "relpath": np.array(self.relpaths),
+ "synsets": np.array(self.synsets),
+ "class_label": np.array(self.class_labels),
+ "human_label": np.array(self.human_labels),
+ }
+ self.data = ImagePaths(self.abspaths,
+ labels=labels,
+ size=retrieve(self.config, "size", default=0),
+ random_crop=self.random_crop)
+
+
+class ImageNetTrain(ImageNetBase):
+ NAME = "ILSVRC2012_train"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
+ FILES = [
+ "ILSVRC2012_img_train.tar",
+ ]
+ SIZES = [
+ 147897477120,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
+ default=True)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 1281167
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ print("Extracting sub-tars.")
+ subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
+ for subpath in tqdm(subpaths):
+ subdir = subpath[:-len(".tar")]
+ os.makedirs(subdir, exist_ok=True)
+ with tarfile.open(subpath, "r:") as tar:
+ tar.extractall(path=subdir)
+
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+class ImageNetValidation(ImageNetBase):
+ NAME = "ILSVRC2012_validation"
+ URL = "http://www.image-net.org/challenges/LSVRC/2012/"
+ AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
+ VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
+ FILES = [
+ "ILSVRC2012_img_val.tar",
+ "validation_synset.txt",
+ ]
+ SIZES = [
+ 6744924160,
+ 1950000,
+ ]
+
+ def _prepare(self):
+ self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
+ default=False)
+ cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
+ self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
+ self.datadir = os.path.join(self.root, "data")
+ self.txt_filelist = os.path.join(self.root, "filelist.txt")
+ self.expected_length = 50000
+ if not bdu.is_prepared(self.root):
+ # prep
+ print("Preparing dataset {} in {}".format(self.NAME, self.root))
+
+ datadir = self.datadir
+ if not os.path.exists(datadir):
+ path = os.path.join(self.root, self.FILES[0])
+ if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
+ import academictorrents as at
+ atpath = at.get(self.AT_HASH, datastore=self.root)
+ assert atpath == path
+
+ print("Extracting {} to {}".format(path, datadir))
+ os.makedirs(datadir, exist_ok=True)
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=datadir)
+
+ vspath = os.path.join(self.root, self.FILES[1])
+ if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
+ download(self.VS_URL, vspath)
+
+ with open(vspath, "r") as f:
+ synset_dict = f.read().splitlines()
+ synset_dict = dict(line.split() for line in synset_dict)
+
+ print("Reorganizing into synset folders")
+ synsets = np.unique(list(synset_dict.values()))
+ for s in synsets:
+ os.makedirs(os.path.join(datadir, s), exist_ok=True)
+ for k, v in synset_dict.items():
+ src = os.path.join(datadir, k)
+ dst = os.path.join(datadir, v)
+ shutil.move(src, dst)
+
+ filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
+ filelist = [os.path.relpath(p, start=datadir) for p in filelist]
+ filelist = sorted(filelist)
+ filelist = "\n".join(filelist)+"\n"
+ with open(self.txt_filelist, "w") as f:
+ f.write(filelist)
+
+ bdu.mark_prepared(self.root)
+
+
+def get_preprocessor(size=None, random_crop=False, additional_targets=None,
+ crop_size=None):
+ if size is not None and size > 0:
+ transforms = list()
+ rescaler = albumentations.SmallestMaxSize(max_size = size)
+ transforms.append(rescaler)
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=size,width=size)
+ transforms.append(cropper)
+ else:
+ cropper = albumentations.RandomCrop(height=size,width=size)
+ transforms.append(cropper)
+ flipper = albumentations.HorizontalFlip()
+ transforms.append(flipper)
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ elif crop_size is not None and crop_size > 0:
+ if not random_crop:
+ cropper = albumentations.CenterCrop(height=crop_size,width=crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=crop_size,width=crop_size)
+ transforms = [cropper]
+ preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ preprocessor = lambda **kwargs: kwargs
+ return preprocessor
+
+
+def rgba_to_depth(x):
+ assert x.dtype == np.uint8
+ assert len(x.shape) == 3 and x.shape[2] == 4
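+ # reinterpret the four uint8 RGBA bytes of each pixel as a single float32 depth value (no numeric conversion)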
+ y = x.copy()
+ y.dtype = np.float32
+ y = y.reshape(x.shape[:2])
+ return np.ascontiguousarray(y)
+
+
+class BaseWithDepth(Dataset):
+ DEFAULT_DEPTH_ROOT="data/imagenet_depth"
+
+ def __init__(self, config=None, size=None, random_crop=False,
+ crop_size=None, root=None):
+ self.config = config
+ self.base_dset = self.get_base_dset()
+ self.preprocessor = get_preprocessor(
+ size=size,
+ crop_size=crop_size,
+ random_crop=random_crop,
+ additional_targets={"depth": "image"})
+ self.crop_size = crop_size
+ if self.crop_size is not None:
+ self.rescaler = albumentations.Compose(
+ [albumentations.SmallestMaxSize(max_size = self.crop_size)],
+ additional_targets={"depth": "image"})
+ if root is not None:
+ self.DEFAULT_DEPTH_ROOT = root
+
+ def __len__(self):
+ return len(self.base_dset)
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
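+ # min-max normalize the depth map, then rescale it to [-1, 1]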
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = self.base_dset[i]
+ e["depth"] = self.preprocess_depth(self.get_depth_path(e))
+ # up if necessary
+ h,w,c = e["image"].shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ out = self.rescaler(image=e["image"], depth=e["depth"])
+ e["image"] = out["image"]
+ e["depth"] = out["depth"]
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+class ImageNetTrainWithDepth(BaseWithDepth):
+ # default to random_crop=True
+ def __init__(self, random_crop=True, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetTrain()
+ else:
+ return ImageNetTrain({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid)
+ return fid
+
+
+class ImageNetValidationWithDepth(BaseWithDepth):
+ def __init__(self, sub_indices=None, **kwargs):
+ self.sub_indices = sub_indices
+ super().__init__(**kwargs)
+
+ def get_base_dset(self):
+ if self.sub_indices is None:
+ return ImageNetValidation()
+ else:
+ return ImageNetValidation({"sub_indices": self.sub_indices})
+
+ def get_depth_path(self, e):
+ fid = os.path.splitext(e["relpath"])[0]+".png"
+ fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid)
+ return fid
+
+
+class RINTrainWithDepth(ImageNetTrainWithDepth):
+ def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class RINValidationWithDepth(ImageNetValidationWithDepth):
+ def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
+ sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
+ super().__init__(config=config, size=size, random_crop=random_crop,
+ sub_indices=sub_indices, crop_size=crop_size)
+
+
+class DRINExamples(Dataset):
+ def __init__(self):
+ self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
+ with open("data/drin_examples.txt", "r") as f:
+ relpaths = f.read().splitlines()
+ self.image_paths = [os.path.join("data/drin_images",
+ relpath) for relpath in relpaths]
+ self.depth_paths = [os.path.join("data/drin_depth",
+ relpath.replace(".JPEG", ".png")) for relpath in relpaths]
+
+ def __len__(self):
+ return len(self.image_paths)
+
+ def preprocess_image(self, image_path):
+ image = Image.open(image_path)
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ image = self.preprocessor(image=image)["image"]
+ image = (image/127.5 - 1.0).astype(np.float32)
+ return image
+
+ def preprocess_depth(self, path):
+ rgba = np.array(Image.open(path))
+ depth = rgba_to_depth(rgba)
+ depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min())
+ depth = 2.0*depth-1.0
+ return depth
+
+ def __getitem__(self, i):
+ e = dict()
+ e["image"] = self.preprocess_image(self.image_paths[i])
+ e["depth"] = self.preprocess_depth(self.depth_paths[i])
+ transformed = self.preprocessor(image=e["image"], depth=e["depth"])
+ e["image"] = transformed["image"]
+ e["depth"] = transformed["depth"]
+ return e
+
+
+def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
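+ # downscale a [-1, 1] float image by an integer factor; with keepshapes=True it is resized back to the original resolution using keepmode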
+ if factor is None or factor==1:
+ return x
+
+ dtype = x.dtype
+ assert dtype in [np.float32, np.float64]
+ assert x.min() >= -1
+ assert x.max() <= 1
+
+ keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR,
+ "bicubic": Image.BICUBIC}[keepmode]
+
+ lr = (x+1.0)*127.5
+ lr = lr.clip(0,255).astype(np.uint8)
+ lr = Image.fromarray(lr)
+
+ h, w, _ = x.shape
+ nh = h//factor
+ nw = w//factor
+ assert nh > 0 and nw > 0, (nh, nw)
+
+ lr = lr.resize((nw,nh), Image.BICUBIC)
+ if keepshapes:
+ lr = lr.resize((w,h), keepmode)
+ lr = np.array(lr)/127.5-1.0
+ lr = lr.astype(dtype)
+
+ return lr
+
+
+class ImageNetScale(Dataset):
+ def __init__(self, size=None, crop_size=None, random_crop=False,
+ up_factor=None, hr_factor=None, keep_mode="bicubic"):
+ self.base = self.get_base()
+
+ self.size = size
+ self.crop_size = crop_size if crop_size is not None else self.size
+ self.random_crop = random_crop
+ self.up_factor = up_factor
+ self.hr_factor = hr_factor
+ self.keep_mode = keep_mode
+
+ transforms = list()
+
+ if self.size is not None and self.size > 0:
+ rescaler = albumentations.SmallestMaxSize(max_size = self.size)
+ self.rescaler = rescaler
+ transforms.append(rescaler)
+
+ if self.crop_size is not None and self.crop_size > 0:
+ if len(transforms) == 0:
+ self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size)
+
+ if not self.random_crop:
+ cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size)
+ else:
+ cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size)
+ transforms.append(cropper)
+
+ if len(transforms) > 0:
+ if self.up_factor is not None:
+ additional_targets = {"lr": "image"}
+ else:
+ additional_targets = None
+ self.preprocessor = albumentations.Compose(transforms,
+ additional_targets=additional_targets)
+ else:
+ self.preprocessor = lambda **kwargs: kwargs
+
+ def __len__(self):
+ return len(self.base)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ # adjust resolution
+ image = imscale(image, self.hr_factor, keepshapes=False)
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+ if self.up_factor is None:
+ image = self.preprocessor(image=image)["image"]
+ example["image"] = image
+ else:
+ lr = imscale(image, self.up_factor, keepshapes=True,
+ keepmode=self.keep_mode)
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+class ImageNetScaleTrain(ImageNetScale):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetScaleValidation(ImageNetScale):
+ def get_base(self):
+ return ImageNetValidation()
+
+
+from skimage.feature import canny
+from skimage.color import rgb2gray
+
+
+class ImageNetEdges(ImageNetScale):
+ def __init__(self, up_factor=1, **kwargs):
+ super().__init__(up_factor=1, **kwargs)
+
+ def __getitem__(self, i):
+ example = self.base[i]
+ image = example["image"]
+ h,w,c = image.shape
+ if self.crop_size and min(h,w) < self.crop_size:
+ # have to upscale to be able to crop - this just uses bilinear
+ image = self.rescaler(image=image)["image"]
+
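+ # use a Canny edge map, replicated to three channels, as the low-resolution conditioning input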
+ lr = canny(rgb2gray(image), sigma=2)
+ lr = lr.astype(np.float32)
+ lr = lr[:,:,None][:,:,[0,0,0]]
+
+ out = self.preprocessor(image=image, lr=lr)
+ example["image"] = out["image"]
+ example["lr"] = out["lr"]
+
+ return example
+
+
+class ImageNetEdgesTrain(ImageNetEdges):
+ def __init__(self, random_crop=True, **kwargs):
+ super().__init__(random_crop=random_crop, **kwargs)
+
+ def get_base(self):
+ return ImageNetTrain()
+
+class ImageNetEdgesValidation(ImageNetEdges):
+ def get_base(self):
+ return ImageNetValidation()
diff --git a/repositories/taming/data/open_images_helper.py b/repositories/taming/data/open_images_helper.py
new file mode 100644
index 000000000..8feb7c6e7
--- /dev/null
+++ b/repositories/taming/data/open_images_helper.py
@@ -0,0 +1,379 @@
+open_images_unify_categories_for_coco = {
+ '/m/03bt1vf': '/m/01g317',
+ '/m/04yx4': '/m/01g317',
+ '/m/05r655': '/m/01g317',
+ '/m/01bl7v': '/m/01g317',
+ '/m/0cnyhnx': '/m/01xq0k1',
+ '/m/01226z': '/m/018xm',
+ '/m/05ctyq': '/m/018xm',
+ '/m/058qzx': '/m/04ctx',
+ '/m/06pcq': '/m/0l515',
+ '/m/03m3pdh': '/m/02crq1',
+ '/m/046dlr': '/m/01x3z',
+ '/m/0h8mzrc': '/m/01x3z',
+}
+
+
+top_300_classes_plus_coco_compatibility = [
+ ('Man', 1060962),
+ ('Clothing', 986610),
+ ('Tree', 748162),
+ ('Woman', 611896),
+ ('Person', 610294),
+ ('Human face', 442948),
+ ('Girl', 175399),
+ ('Building', 162147),
+ ('Car', 159135),
+ ('Plant', 155704),
+ ('Human body', 137073),
+ ('Flower', 133128),
+ ('Window', 127485),
+ ('Human arm', 118380),
+ ('House', 114365),
+ ('Wheel', 111684),
+ ('Suit', 99054),
+ ('Human hair', 98089),
+ ('Human head', 92763),
+ ('Chair', 88624),
+ ('Boy', 79849),
+ ('Table', 73699),
+ ('Jeans', 57200),
+ ('Tire', 55725),
+ ('Skyscraper', 53321),
+ ('Food', 52400),
+ ('Footwear', 50335),
+ ('Dress', 50236),
+ ('Human leg', 47124),
+ ('Toy', 46636),
+ ('Tower', 45605),
+ ('Boat', 43486),
+ ('Land vehicle', 40541),
+ ('Bicycle wheel', 34646),
+ ('Palm tree', 33729),
+ ('Fashion accessory', 32914),
+ ('Glasses', 31940),
+ ('Bicycle', 31409),
+ ('Furniture', 30656),
+ ('Sculpture', 29643),
+ ('Bottle', 27558),
+ ('Dog', 26980),
+ ('Snack', 26796),
+ ('Human hand', 26664),
+ ('Bird', 25791),
+ ('Book', 25415),
+ ('Guitar', 24386),
+ ('Jacket', 23998),
+ ('Poster', 22192),
+ ('Dessert', 21284),
+ ('Baked goods', 20657),
+ ('Drink', 19754),
+ ('Flag', 18588),
+ ('Houseplant', 18205),
+ ('Tableware', 17613),
+ ('Airplane', 17218),
+ ('Door', 17195),
+ ('Sports uniform', 17068),
+ ('Shelf', 16865),
+ ('Drum', 16612),
+ ('Vehicle', 16542),
+ ('Microphone', 15269),
+ ('Street light', 14957),
+ ('Cat', 14879),
+ ('Fruit', 13684),
+ ('Fast food', 13536),
+ ('Animal', 12932),
+ ('Vegetable', 12534),
+ ('Train', 12358),
+ ('Horse', 11948),
+ ('Flowerpot', 11728),
+ ('Motorcycle', 11621),
+ ('Fish', 11517),
+ ('Desk', 11405),
+ ('Helmet', 10996),
+ ('Truck', 10915),
+ ('Bus', 10695),
+ ('Hat', 10532),
+ ('Auto part', 10488),
+ ('Musical instrument', 10303),
+ ('Sunglasses', 10207),
+ ('Picture frame', 10096),
+ ('Sports equipment', 10015),
+ ('Shorts', 9999),
+ ('Wine glass', 9632),
+ ('Duck', 9242),
+ ('Wine', 9032),
+ ('Rose', 8781),
+ ('Tie', 8693),
+ ('Butterfly', 8436),
+ ('Beer', 7978),
+ ('Cabinetry', 7956),
+ ('Laptop', 7907),
+ ('Insect', 7497),
+ ('Goggles', 7363),
+ ('Shirt', 7098),
+ ('Dairy Product', 7021),
+ ('Marine invertebrates', 7014),
+ ('Cattle', 7006),
+ ('Trousers', 6903),
+ ('Van', 6843),
+ ('Billboard', 6777),
+ ('Balloon', 6367),
+ ('Human nose', 6103),
+ ('Tent', 6073),
+ ('Camera', 6014),
+ ('Doll', 6002),
+ ('Coat', 5951),
+ ('Mobile phone', 5758),
+ ('Swimwear', 5729),
+ ('Strawberry', 5691),
+ ('Stairs', 5643),
+ ('Goose', 5599),
+ ('Umbrella', 5536),
+ ('Cake', 5508),
+ ('Sun hat', 5475),
+ ('Bench', 5310),
+ ('Bookcase', 5163),
+ ('Bee', 5140),
+ ('Computer monitor', 5078),
+ ('Hiking equipment', 4983),
+ ('Office building', 4981),
+ ('Coffee cup', 4748),
+ ('Curtain', 4685),
+ ('Plate', 4651),
+ ('Box', 4621),
+ ('Tomato', 4595),
+ ('Coffee table', 4529),
+ ('Office supplies', 4473),
+ ('Maple', 4416),
+ ('Muffin', 4365),
+ ('Cocktail', 4234),
+ ('Castle', 4197),
+ ('Couch', 4134),
+ ('Pumpkin', 3983),
+ ('Computer keyboard', 3960),
+ ('Human mouth', 3926),
+ ('Christmas tree', 3893),
+ ('Mushroom', 3883),
+ ('Swimming pool', 3809),
+ ('Pastry', 3799),
+ ('Lavender (Plant)', 3769),
+ ('Football helmet', 3732),
+ ('Bread', 3648),
+ ('Traffic sign', 3628),
+ ('Common sunflower', 3597),
+ ('Television', 3550),
+ ('Bed', 3525),
+ ('Cookie', 3485),
+ ('Fountain', 3484),
+ ('Paddle', 3447),
+ ('Bicycle helmet', 3429),
+ ('Porch', 3420),
+ ('Deer', 3387),
+ ('Fedora', 3339),
+ ('Canoe', 3338),
+ ('Carnivore', 3266),
+ ('Bowl', 3202),
+ ('Human eye', 3166),
+ ('Ball', 3118),
+ ('Pillow', 3077),
+ ('Salad', 3061),
+ ('Beetle', 3060),
+ ('Orange', 3050),
+ ('Drawer', 2958),
+ ('Platter', 2937),
+ ('Elephant', 2921),
+ ('Seafood', 2921),
+ ('Monkey', 2915),
+ ('Countertop', 2879),
+ ('Watercraft', 2831),
+ ('Helicopter', 2805),
+ ('Kitchen appliance', 2797),
+ ('Personal flotation device', 2781),
+ ('Swan', 2739),
+ ('Lamp', 2711),
+ ('Boot', 2695),
+ ('Bronze sculpture', 2693),
+ ('Chicken', 2677),
+ ('Taxi', 2643),
+ ('Juice', 2615),
+ ('Cowboy hat', 2604),
+ ('Apple', 2600),
+ ('Tin can', 2590),
+ ('Necklace', 2564),
+ ('Ice cream', 2560),
+ ('Human beard', 2539),
+ ('Coin', 2536),
+ ('Candle', 2515),
+ ('Cart', 2512),
+ ('High heels', 2441),
+ ('Weapon', 2433),
+ ('Handbag', 2406),
+ ('Penguin', 2396),
+ ('Rifle', 2352),
+ ('Violin', 2336),
+ ('Skull', 2304),
+ ('Lantern', 2285),
+ ('Scarf', 2269),
+ ('Saucer', 2225),
+ ('Sheep', 2215),
+ ('Vase', 2189),
+ ('Lily', 2180),
+ ('Mug', 2154),
+ ('Parrot', 2140),
+ ('Human ear', 2137),
+ ('Sandal', 2115),
+ ('Lizard', 2100),
+ ('Kitchen & dining room table', 2063),
+ ('Spider', 1977),
+ ('Coffee', 1974),
+ ('Goat', 1926),
+ ('Squirrel', 1922),
+ ('Cello', 1913),
+ ('Sushi', 1881),
+ ('Tortoise', 1876),
+ ('Pizza', 1870),
+ ('Studio couch', 1864),
+ ('Barrel', 1862),
+ ('Cosmetics', 1841),
+ ('Moths and butterflies', 1841),
+ ('Convenience store', 1817),
+ ('Watch', 1792),
+ ('Home appliance', 1786),
+ ('Harbor seal', 1780),
+ ('Luggage and bags', 1756),
+ ('Vehicle registration plate', 1754),
+ ('Shrimp', 1751),
+ ('Jellyfish', 1730),
+ ('French fries', 1723),
+ ('Egg (Food)', 1698),
+ ('Football', 1697),
+ ('Musical keyboard', 1683),
+ ('Falcon', 1674),
+ ('Candy', 1660),
+ ('Medical equipment', 1654),
+ ('Eagle', 1651),
+ ('Dinosaur', 1634),
+ ('Surfboard', 1630),
+ ('Tank', 1628),
+ ('Grape', 1624),
+ ('Lion', 1624),
+ ('Owl', 1622),
+ ('Ski', 1613),
+ ('Waste container', 1606),
+ ('Frog', 1591),
+ ('Sparrow', 1585),
+ ('Rabbit', 1581),
+ ('Pen', 1546),
+ ('Sea lion', 1537),
+ ('Spoon', 1521),
+ ('Sink', 1512),
+ ('Teddy bear', 1507),
+ ('Bull', 1495),
+ ('Sofa bed', 1490),
+ ('Dragonfly', 1479),
+ ('Brassiere', 1478),
+ ('Chest of drawers', 1472),
+ ('Aircraft', 1466),
+ ('Human foot', 1463),
+ ('Pig', 1455),
+ ('Fork', 1454),
+ ('Antelope', 1438),
+ ('Tripod', 1427),
+ ('Tool', 1424),
+ ('Cheese', 1422),
+ ('Lemon', 1397),
+ ('Hamburger', 1393),
+ ('Dolphin', 1390),
+ ('Mirror', 1390),
+ ('Marine mammal', 1387),
+ ('Giraffe', 1385),
+ ('Snake', 1368),
+ ('Gondola', 1364),
+ ('Wheelchair', 1360),
+ ('Piano', 1358),
+ ('Cupboard', 1348),
+ ('Banana', 1345),
+ ('Trumpet', 1335),
+ ('Lighthouse', 1333),
+ ('Invertebrate', 1317),
+ ('Carrot', 1268),
+ ('Sock', 1260),
+ ('Tiger', 1241),
+ ('Camel', 1224),
+ ('Parachute', 1224),
+ ('Bathroom accessory', 1223),
+ ('Earrings', 1221),
+ ('Headphones', 1218),
+ ('Skirt', 1198),
+ ('Skateboard', 1190),
+ ('Sandwich', 1148),
+ ('Saxophone', 1141),
+ ('Goldfish', 1136),
+ ('Stool', 1104),
+ ('Traffic light', 1097),
+ ('Shellfish', 1081),
+ ('Backpack', 1079),
+ ('Sea turtle', 1078),
+ ('Cucumber', 1075),
+ ('Tea', 1051),
+ ('Toilet', 1047),
+ ('Roller skates', 1040),
+ ('Mule', 1039),
+ ('Bust', 1031),
+ ('Broccoli', 1030),
+ ('Crab', 1020),
+ ('Oyster', 1019),
+ ('Cannon', 1012),
+ ('Zebra', 1012),
+ ('French horn', 1008),
+ ('Grapefruit', 998),
+ ('Whiteboard', 997),
+ ('Zucchini', 997),
+ ('Crocodile', 992),
+
+ ('Clock', 960),
+ ('Wall clock', 958),
+
+ ('Doughnut', 869),
+ ('Snail', 868),
+
+ ('Baseball glove', 859),
+
+ ('Panda', 830),
+ ('Tennis racket', 830),
+
+ ('Pear', 652),
+
+ ('Bagel', 617),
+ ('Oven', 616),
+ ('Ladybug', 615),
+ ('Shark', 615),
+ ('Polar bear', 614),
+ ('Ostrich', 609),
+
+ ('Hot dog', 473),
+ ('Microwave oven', 467),
+ ('Fire hydrant', 20),
+ ('Stop sign', 20),
+ ('Parking meter', 20),
+ ('Bear', 20),
+ ('Flying disc', 20),
+ ('Snowboard', 20),
+ ('Tennis ball', 20),
+ ('Kite', 20),
+ ('Baseball bat', 20),
+ ('Kitchen knife', 20),
+ ('Knife', 20),
+ ('Submarine sandwich', 20),
+ ('Computer mouse', 20),
+ ('Remote control', 20),
+ ('Toaster', 20),
+ ('Sink', 20),
+ ('Refrigerator', 20),
+ ('Alarm clock', 20),
+ ('Wall clock', 20),
+ ('Scissors', 20),
+ ('Hair dryer', 20),
+ ('Toothbrush', 20),
+ ('Suitcase', 20)
+]
diff --git a/repositories/taming/data/sflckr.py b/repositories/taming/data/sflckr.py
new file mode 100644
index 000000000..91101be59
--- /dev/null
+++ b/repositories/taming/data/sflckr.py
@@ -0,0 +1,91 @@
+import os
+import numpy as np
+import cv2
+import albumentations
+from PIL import Image
+from torch.utils.data import Dataset
+
+
+class SegmentationBase(Dataset):
+ def __init__(self,
+ data_csv, data_root, segmentation_root,
+ size=None, random_crop=False, interpolation="bicubic",
+ n_labels=182, shift_segmentation=False,
+ ):
+ self.n_labels = n_labels
+ self.shift_segmentation = shift_segmentation
+ self.data_csv = data_csv
+ self.data_root = data_root
+ self.segmentation_root = segmentation_root
+ with open(self.data_csv, "r") as f:
+ self.image_paths = f.read().splitlines()
+ self._length = len(self.image_paths)
+ self.labels = {
+ "relative_file_path_": [l for l in self.image_paths],
+ "file_path_": [os.path.join(self.data_root, l)
+ for l in self.image_paths],
+ "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
+ for l in self.image_paths]
+ }
+
+ size = None if size is not None and size<=0 else size
+ self.size = size
+ if self.size is not None:
+ self.interpolation = interpolation
+ self.interpolation = {
+ "nearest": cv2.INTER_NEAREST,
+ "bilinear": cv2.INTER_LINEAR,
+ "bicubic": cv2.INTER_CUBIC,
+ "area": cv2.INTER_AREA,
+ "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
+ self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=self.interpolation)
+ self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
+ interpolation=cv2.INTER_NEAREST)
+ self.center_crop = not random_crop
+ if self.center_crop:
+ self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
+ else:
+ self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
+ self.preprocessor = self.cropper
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, i):
+ example = dict((k, self.labels[k][i]) for k in self.labels)
+ image = Image.open(example["file_path_"])
+ if not image.mode == "RGB":
+ image = image.convert("RGB")
+ image = np.array(image).astype(np.uint8)
+ if self.size is not None:
+ image = self.image_rescaler(image=image)["image"]
+ segmentation = Image.open(example["segmentation_path_"])
+ assert segmentation.mode == "L", segmentation.mode
+ segmentation = np.array(segmentation).astype(np.uint8)
+ if self.shift_segmentation:
+ # used to support segmentations containing unlabeled==255 label
+ segmentation = segmentation+1
+ if self.size is not None:
+ segmentation = self.segmentation_rescaler(image=segmentation)["image"]
+ if self.size is not None:
+ processed = self.preprocessor(image=image,
+ mask=segmentation
+ )
+ else:
+ processed = {"image": image,
+ "mask": segmentation
+ }
+ example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
+ segmentation = processed["mask"]
+ onehot = np.eye(self.n_labels)[segmentation]
+ example["segmentation"] = onehot
+ return example
+
+
+class Examples(SegmentationBase):
+ def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
+ super().__init__(data_csv="data/sflckr_examples.txt",
+ data_root="data/sflckr_images",
+ segmentation_root="data/sflckr_segmentations",
+ size=size, random_crop=random_crop, interpolation=interpolation)
diff --git a/repositories/taming/data/utils.py b/repositories/taming/data/utils.py
new file mode 100644
index 000000000..2b3c3d53c
--- /dev/null
+++ b/repositories/taming/data/utils.py
@@ -0,0 +1,169 @@
+import collections
+import os
+import tarfile
+import urllib
+import zipfile
+from pathlib import Path
+
+import numpy as np
+import torch
+from taming.data.helper_types import Annotation
+try:
+ from torch._six import string_classes
+except ImportError: # torch._six was removed in recent PyTorch releases
+ string_classes = (str, bytes)
+from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
+from tqdm import tqdm
+
+
+def unpack(path):
+ if path.endswith("tar.gz"):
+ with tarfile.open(path, "r:gz") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("tar"):
+ with tarfile.open(path, "r:") as tar:
+ tar.extractall(path=os.path.split(path)[0])
+ elif path.endswith("zip"):
+ with zipfile.ZipFile(path, "r") as f:
+ f.extractall(path=os.path.split(path)[0])
+ else:
+ raise NotImplementedError(
+ "Unknown file extension: {}".format(os.path.splitext(path)[1])
+ )
+
+
+def reporthook(bar):
+ """tqdm progress bar for downloads."""
+
+ def hook(b=1, bsize=1, tsize=None):
+ if tsize is not None:
+ bar.total = tsize
+ bar.update(b * bsize - bar.n)
+
+ return hook
+
+
+def get_root(name):
+ base = "data/"
+ root = os.path.join(base, name)
+ os.makedirs(root, exist_ok=True)
+ return root
+
+
+def is_prepared(root):
+ return Path(root).joinpath(".ready").exists()
+
+
+def mark_prepared(root):
+ Path(root).joinpath(".ready").touch()
+
+
+def prompt_download(file_, source, target_dir, content_dir=None):
+ targetpath = os.path.join(target_dir, file_)
+ while not os.path.exists(targetpath):
+ if content_dir is not None and os.path.exists(
+ os.path.join(target_dir, content_dir)
+ ):
+ break
+ print(
+ "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath)
+ )
+ if content_dir is not None:
+ print(
+ "Or place its content into '{}'.".format(
+ os.path.join(target_dir, content_dir)
+ )
+ )
+ input("Press Enter when done...")
+ return targetpath
+
+
+def download_url(file_, url, target_dir):
+ targetpath = os.path.join(target_dir, file_)
+ os.makedirs(target_dir, exist_ok=True)
+ with tqdm(
+ unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_
+ ) as bar:
+ urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar))
+ return targetpath
+
+
+def download_urls(urls, target_dir):
+ paths = dict()
+ for fname, url in urls.items():
+ outpath = download_url(fname, url, target_dir)
+ paths[fname] = outpath
+ return paths
+
+
+def quadratic_crop(x, bbox, alpha=1.0):
+ """bbox is xmin, ymin, xmax, ymax"""
+ im_h, im_w = x.shape[:2]
+ bbox = np.array(bbox, dtype=np.float32)
+ bbox = np.clip(bbox, 0, max(im_h, im_w))
+ center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3])
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ l = int(alpha * max(w, h))
+ l = max(l, 2)
+
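+ # reflect-pad the image if the square crop around the bbox center would extend beyond the borders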
+ required_padding = -1 * min(
+ center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l)
+ )
+ required_padding = int(np.ceil(required_padding))
+ if required_padding > 0:
+ padding = [
+ [required_padding, required_padding],
+ [required_padding, required_padding],
+ ]
+ padding += [[0, 0]] * (len(x.shape) - 2)
+ x = np.pad(x, padding, "reflect")
+ center = center[0] + required_padding, center[1] + required_padding
+ xmin = int(center[0] - l / 2)
+ ymin = int(center[1] - l / 2)
+ return np.array(x[ymin : ymin + l, xmin : xmin + l, ...])
+
+
+def custom_collate(batch):
+ r"""source: pytorch 1.9.0, only one modification to original code """
+
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return custom_collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: custom_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(custom_collate(samples) for samples in zip(*batch)))
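+ # modification w.r.t. the stock collate: lists of Annotation namedtuples are returned as-is instead of being collated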
+ if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added
+ return batch # added
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError('each element in list of batch should be of equal size')
+ transposed = zip(*batch)
+ return [custom_collate(samples) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
diff --git a/repositories/taming/lr_scheduler.py b/repositories/taming/lr_scheduler.py
new file mode 100644
index 000000000..e598ed120
--- /dev/null
+++ b/repositories/taming/lr_scheduler.py
@@ -0,0 +1,34 @@
+import numpy as np
+
+
+class LambdaWarmUpCosineScheduler:
+ """
+ note: use with a base_lr of 1.0
+ """
+ def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
+ self.lr_warm_up_steps = warm_up_steps
+ self.lr_start = lr_start
+ self.lr_min = lr_min
+ self.lr_max = lr_max
+ self.lr_max_decay_steps = max_decay_steps
+ self.last_lr = 0.
+ self.verbosity_interval = verbosity_interval
+
+ def schedule(self, n):
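+ # linear warm-up from lr_start to lr_max, then cosine decay towards lr_min over max_decay_steps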
+ if self.verbosity_interval > 0:
+ if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+ if n < self.lr_warm_up_steps:
+ lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
+ self.last_lr = lr
+ return lr
+ else:
+ t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
+ t = min(t, 1.0)
+ lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
+ 1 + np.cos(t * np.pi))
+ self.last_lr = lr
+ return lr
+
+ def __call__(self, n):
+ return self.schedule(n)
+
diff --git a/repositories/taming/models/cond_transformer.py b/repositories/taming/models/cond_transformer.py
new file mode 100644
index 000000000..e4c63730f
--- /dev/null
+++ b/repositories/taming/models/cond_transformer.py
@@ -0,0 +1,352 @@
+import os, math
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+from taming.modules.util import SOSProvider
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class Net2NetTransformer(pl.LightningModule):
+ def __init__(self,
+ transformer_config,
+ first_stage_config,
+ cond_stage_config,
+ permuter_config=None,
+ ckpt_path=None,
+ ignore_keys=[],
+ first_stage_key="image",
+ cond_stage_key="depth",
+ downsample_cond_size=-1,
+ pkeep=1.0,
+ sos_token=0,
+ unconditional=False,
+ ):
+ super().__init__()
+ self.be_unconditional = unconditional
+ self.sos_token = sos_token
+ self.first_stage_key = first_stage_key
+ self.cond_stage_key = cond_stage_key
+ self.init_first_stage_from_ckpt(first_stage_config)
+ self.init_cond_stage_from_ckpt(cond_stage_config)
+ if permuter_config is None:
+ permuter_config = {"target": "taming.modules.transformer.permuter.Identity"}
+ self.permuter = instantiate_from_config(config=permuter_config)
+ self.transformer = instantiate_from_config(config=transformer_config)
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.downsample_cond_size = downsample_cond_size
+ self.pkeep = pkeep
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ for k in list(sd.keys()):
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ self.print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def init_first_stage_from_ckpt(self, config):
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.first_stage_model = model
+
+ def init_cond_stage_from_ckpt(self, config):
+ if config == "__is_first_stage__":
+ print("Using first stage also as cond stage.")
+ self.cond_stage_model = self.first_stage_model
+ elif config == "__is_unconditional__" or self.be_unconditional:
+ print(f"Using no cond stage. Assuming the training is intended to be unconditional. "
+ f"Prepending {self.sos_token} as a sos token.")
+ self.be_unconditional = True
+ self.cond_stage_key = self.first_stage_key
+ self.cond_stage_model = SOSProvider(self.sos_token)
+ else:
+ model = instantiate_from_config(config)
+ model = model.eval()
+ model.train = disabled_train
+ self.cond_stage_model = model
+
+ def forward(self, x, c):
+ # one step to produce the logits
+ _, z_indices = self.encode_to_z(x)
+ _, c_indices = self.encode_to_c(c)
+
+ if self.training and self.pkeep < 1.0:
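+ # keep each target token with probability pkeep and replace the rest with random vocabulary indices (training-time corruption)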
+ mask = torch.bernoulli(self.pkeep*torch.ones(z_indices.shape,
+ device=z_indices.device))
+ mask = mask.round().to(dtype=torch.int64)
+ r_indices = torch.randint_like(z_indices, self.transformer.config.vocab_size)
+ a_indices = mask*z_indices+(1-mask)*r_indices
+ else:
+ a_indices = z_indices
+
+ cz_indices = torch.cat((c_indices, a_indices), dim=1)
+
+ # target includes all sequence elements (no need to handle first one
+ # differently because we are conditioning)
+ target = z_indices
+ # make the prediction
+ logits, _ = self.transformer(cz_indices[:, :-1])
+ # cut off conditioning outputs - output i corresponds to p(z_i | z_{<i}, c)
+ logits = logits[:, c_indices.shape[1]-1:]
+
+ return logits, target
+
+ @torch.no_grad()
+ def encode_to_z(self, x):
+ quant_z, _, info = self.first_stage_model.encode(x)
+ indices = info[2].view(quant_z.shape[0], -1)
+ indices = self.permuter(indices)
+ return quant_z, indices
+
+ @torch.no_grad()
+ def encode_to_c(self, c):
+ if self.downsample_cond_size > -1:
+ c = F.interpolate(c, size=(self.downsample_cond_size, self.downsample_cond_size))
+ quant_c, _, [_,_,indices] = self.cond_stage_model.encode(c)
+ if len(indices.shape) > 2:
+ indices = indices.view(c.shape[0], -1)
+ return quant_c, indices
+
+ @torch.no_grad()
+ def decode_to_img(self, index, zshape):
+ index = self.permuter(index, reverse=True)
+ bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
+ quant_z = self.first_stage_model.quantize.get_codebook_entry(
+ index.reshape(-1), shape=bhwc)
+ x = self.first_stage_model.decode(quant_z)
+ return x
+
+ @torch.no_grad()
+ def log_images(self, batch, temperature=None, top_k=None, callback=None, lr_interface=False, **kwargs):
+ log = dict()
+
+ N = 4
+ if lr_interface:
+ x, c = self.get_xc(batch, N, diffuse=False, upsample_factor=8)
+ else:
+ x, c = self.get_xc(batch, N)
+ x = x.to(device=self.device)
+ c = c.to(device=self.device)
+
+ quant_z, z_indices = self.encode_to_z(x)
+ quant_c, c_indices = self.encode_to_c(c)
+
+ # create a "half"" sample
+ z_start_indices = z_indices[:,:z_indices.shape[1]//2]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1]-z_start_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample = self.decode_to_img(index_sample, quant_z.shape)
+
+ # sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ temperature=temperature if temperature is not None else 1.0,
+ sample=True,
+ top_k=top_k if top_k is not None else 100,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_nopix = self.decode_to_img(index_sample, quant_z.shape)
+
+ # det sample
+ z_start_indices = z_indices[:, :0]
+ index_sample = self.sample(z_start_indices, c_indices,
+ steps=z_indices.shape[1],
+ sample=False,
+ callback=callback if callback is not None else lambda k: None)
+ x_sample_det = self.decode_to_img(index_sample, quant_z.shape)
+
+ # reconstruction
+ x_rec = self.decode_to_img(z_indices, quant_z.shape)
+
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+
+ if self.cond_stage_key in ["objects_bbox", "objects_center_points"]:
+ figure_size = (x_rec.shape[2], x_rec.shape[3])
+ dataset = kwargs["pl_module"].trainer.datamodule.datasets["validation"]
+ label_for_category_no = dataset.get_textual_label_for_category_no
+ plotter = dataset.conditional_builders[self.cond_stage_key].plot
+ log["conditioning"] = torch.zeros_like(log["reconstructions"])
+ for i in range(quant_c.shape[0]):
+ log["conditioning"][i] = plotter(quant_c[i], label_for_category_no, figure_size)
+ log["conditioning_rec"] = log["conditioning"]
+ elif self.cond_stage_key != "image":
+ cond_rec = self.cond_stage_model.decode(quant_c)
+ if self.cond_stage_key == "segmentation":
+ # get image from segmentation mask
+ num_classes = cond_rec.shape[1]
+
+ c = torch.argmax(c, dim=1, keepdim=True)
+ c = F.one_hot(c, num_classes=num_classes)
+ c = c.squeeze(1).permute(0, 3, 1, 2).float()
+ c = self.cond_stage_model.to_rgb(c)
+
+ cond_rec = torch.argmax(cond_rec, dim=1, keepdim=True)
+ cond_rec = F.one_hot(cond_rec, num_classes=num_classes)
+ cond_rec = cond_rec.squeeze(1).permute(0, 3, 1, 2).float()
+ cond_rec = self.cond_stage_model.to_rgb(cond_rec)
+ log["conditioning_rec"] = cond_rec
+ log["conditioning"] = c
+
+ log["samples_half"] = x_sample
+ log["samples_nopix"] = x_sample_nopix
+ log["samples_det"] = x_sample_det
+ return log
+
+ def get_input(self, key, batch):
+ x = batch[key]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ if len(x.shape) == 4:
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ if x.dtype == torch.double:
+ x = x.float()
+ return x
+
+ def get_xc(self, batch, N=None):
+ x = self.get_input(self.first_stage_key, batch)
+ c = self.get_input(self.cond_stage_key, batch)
+ if N is not None:
+ x = x[:N]
+ c = c[:N]
+ return x, c
+
+ def shared_step(self, batch, batch_idx):
+ x, c = self.get_xc(batch)
+ logits, target = self(x, c)
+ loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
+ return loss
+
+ def training_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("train/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def validation_step(self, batch, batch_idx):
+ loss = self.shared_step(batch, batch_idx)
+ self.log("val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def configure_optimizers(self):
+ """
+ Following minGPT:
+ This long function is unfortunately doing something very simple and is being very defensive:
+ We are separating out all parameters of the model into two buckets: those that will experience
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+ We are then returning the PyTorch optimizer object.
+ """
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear, )
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in self.transformer.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # special case the position embedding parameter in the root GPT module as not decayed
+ no_decay.add('pos_emb')
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in self.transformer.named_parameters()}
+ inter_params = decay & no_decay
+ union_params = decay | no_decay
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+ assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+ % (str(param_dict.keys() - union_params), )
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.01},
+ {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
+ ]
+ optimizer = torch.optim.AdamW(optim_groups, lr=self.learning_rate, betas=(0.9, 0.95))
+ return optimizer
diff --git a/repositories/taming/models/dummy_cond_stage.py b/repositories/taming/models/dummy_cond_stage.py
new file mode 100644
index 000000000..6e1993807
--- /dev/null
+++ b/repositories/taming/models/dummy_cond_stage.py
@@ -0,0 +1,22 @@
+from torch import Tensor
+
+
+class DummyCondStage:
+ def __init__(self, conditional_key):
+ self.conditional_key = conditional_key
+ self.train = None
+
+ def eval(self):
+ return self
+
+ @staticmethod
+ def encode(c: Tensor):
+ return c, None, (None, None, c)
+
+ @staticmethod
+ def decode(c: Tensor):
+ return c
+
+ @staticmethod
+ def to_rgb(c: Tensor):
+ return c
diff --git a/repositories/taming/models/vqgan.py b/repositories/taming/models/vqgan.py
new file mode 100644
index 000000000..a6950baa5
--- /dev/null
+++ b/repositories/taming/models/vqgan.py
@@ -0,0 +1,404 @@
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from main import instantiate_from_config
+
+from taming.modules.diffusionmodules.model import Encoder, Decoder
+from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
+from taming.modules.vqvae.quantize import GumbelQuantize
+from taming.modules.vqvae.quantize import EMAVectorQuantizer
+
+class VQModel(pl.LightningModule):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
+ remap=remap, sane_index_shape=sane_index_shape)
+ self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ self.image_key = image_key
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels)==int
+ self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, emb_loss, info = self.quantize(h)
+ return quant, emb_loss, info
+
+ def decode(self, quant):
+ quant = self.post_quant_conv(quant)
+ dec = self.decoder(quant)
+ return dec
+
+ def decode_code(self, code_b):
+ quant_b = self.quantize.embed_code(code_b)
+ dec = self.decode(quant_b)
+ return dec
+
+ def forward(self, input):
+ quant, diff, _ = self.encode(input)
+ dec = self.decode(quant)
+ return dec, diff
+
+ def get_input(self, batch, k):
+ x = batch[k]
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
+ return x.float()
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log("train/aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log("train/discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
+ return x
+
+
+class VQSegmentationModel(VQModel):
+ def __init__(self, n_labels, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.register_buffer("colorize", torch.randn(3, n_labels, 1, 1))
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ return opt_ae
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="train")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, split="val")
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ total_loss = log_dict_ae["val/total_loss"]
+ self.log("val/total_loss", total_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
+ return aeloss
+
+ @torch.no_grad()
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ xrec, _ = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ # convert logits to indices
+ xrec = torch.argmax(xrec, dim=1, keepdim=True)
+ xrec = F.one_hot(xrec, num_classes=x.shape[1])
+ xrec = xrec.squeeze(1).permute(0, 3, 1, 2).float()
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["inputs"] = x
+ log["reconstructions"] = xrec
+ return log
+
+
+class VQNoDiscModel(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None
+ ):
+ super().__init__(ddconfig=ddconfig, lossconfig=lossconfig, n_embed=n_embed, embed_dim=embed_dim,
+ ckpt_path=ckpt_path, ignore_keys=ignore_keys, image_key=image_key,
+ colorize_nlabels=colorize_nlabels)
+
+ def training_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")
+ # pl.TrainResult was removed in PyTorch Lightning >= 1.0; log directly instead
+ self.log("train/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ # pl.EvalResult was removed in PyTorch Lightning >= 1.0; log directly instead
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log_dict_ae)
+
+ return aeloss
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quantize.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=self.learning_rate, betas=(0.5, 0.9))
+ return optimizer
+
+
+class GumbelVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ temperature_scheduler_config,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ kl_weight=1e-8,
+ remap=None,
+ ):
+
+ z_channels = ddconfig["z_channels"]
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+
+ self.loss.n_classes = n_embed
+ self.vocab_size = n_embed
+
+ self.quantize = GumbelQuantize(z_channels, embed_dim,
+ n_embed=n_embed,
+ kl_weight=kl_weight, temp_init=1.0,
+ remap=remap)
+
+ self.temperature_scheduler = instantiate_from_config(temperature_scheduler_config) # annealing of temp
+
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+ def temperature_scheduling(self):
+ self.quantize.temperature = self.temperature_scheduler(self.global_step)
+
+ def encode_to_prequant(self, x):
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ return h
+
+ def decode_code(self, code_b):
+ raise NotImplementedError
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ self.temperature_scheduling()
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)
+
+ if optimizer_idx == 0:
+ # autoencode
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+
+ self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ self.log("temperature", self.quantize.temperature, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # discriminator
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
+ last_layer=self.get_last_layer(), split="train")
+ self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ x = self.get_input(batch, self.image_key)
+ xrec, qloss = self(x)  # VQModel.forward does not accept return_pred_indices
+ aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+
+ discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, self.global_step,
+ last_layer=self.get_last_layer(), split="val")
+ rec_loss = log_dict_ae["val/rec_loss"]
+ self.log("val/rec_loss", rec_loss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log("val/aeloss", aeloss,
+ prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def log_images(self, batch, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ # encode
+ h = self.encoder(x)
+ h = self.quant_conv(h)
+ quant, _, _ = self.quantize(h)
+ # decode
+ x_rec = self.decode(quant)
+ log["inputs"] = x
+ log["reconstructions"] = x_rec
+ return log
+
+
+class EMAVQ(VQModel):
+ def __init__(self,
+ ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ remap=None,
+ sane_index_shape=False, # tell vector quantizer to return indices as bhw
+ ):
+ super().__init__(ddconfig,
+ lossconfig,
+ n_embed,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=ignore_keys,
+ image_key=image_key,
+ colorize_nlabels=colorize_nlabels,
+ monitor=monitor,
+ )
+ self.quantize = EMAVectorQuantizer(n_embed=n_embed,
+ embedding_dim=embed_dim,
+ beta=0.25,
+ remap=remap)
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ # Remove self.quantize from parameter list since it is updated via EMA
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
+ list(self.decoder.parameters())+
+ list(self.quant_conv.parameters())+
+ list(self.post_quant_conv.parameters()),
+ lr=lr, betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr, betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
\ No newline at end of file
diff --git a/repositories/taming/modules/diffusionmodules/model.py b/repositories/taming/modules/diffusionmodules/model.py
new file mode 100644
index 000000000..d3a5db6aa
--- /dev/null
+++ b/repositories/taming/modules/diffusionmodules/model.py
@@ -0,0 +1,776 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+ """
+ This matches the implementation in Denoising Diffusion Probabilistic Models:
+ From Fairseq.
+ Build sinusoidal embeddings.
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0,1,0,0))
+ return emb
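+
+# Illustrative usage (editorial, not part of the original file):
+#   get_timestep_embedding(torch.tensor([0, 1, 2, 3]), 64) returns a tensor of shape (4, 64),
+#   i.e. sin/cos pairs over 32 geometrically spaced frequencies per timestep.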
+
+
+def nonlinearity(x):
+ # swish
+ return x*torch.sigmoid(x)
+
+
+def Normalize(in_channels):
+ return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0,1,0,1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class ResnetBlock(nn.Module):
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
+ dropout, temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels,
+ out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x+h
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b,c,h,w = q.shape
+ q = q.reshape(b,c,h*w)
+ q = q.permute(0,2,1) # b,hw,c
+ k = k.reshape(b,c,h*w) # b,c,hw
+ w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b,c,h*w)
+ w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b,c,h,w)
+
+ h_ = self.proj_out(h_)
+
+ return x+h_
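+
+# Editorial note (illustrative): AttnBlock is single-head scaled dot-product self-attention
+# over the h*w spatial positions, with 1x1 convolutions serving as the q/k/v and output
+# projections and a residual connection around the whole block.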
+
+
+class Model(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, use_timestep=True):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, t=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Encoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, double_z=True, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2*z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x):
+ #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+ # timestep embedding
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class Decoder(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
+ resolution, z_channels, give_pre_end=False, **ignorekwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1,)+tuple(ch_mult)
+ block_in = ch*ch_mult[self.num_resolutions-1]
+ curr_res = resolution // 2**(self.num_resolutions-1)
+ self.z_shape = (1,z_channels,curr_res,curr_res)
+ print("Working with z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class VUNet(nn.Module):
+ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
+ attn_resolutions, dropout=0.0, resamp_with_conv=True,
+ in_channels, c_channels,
+ resolution, z_channels, use_timestep=False, **ignore_kwargs):
+ super().__init__()
+ self.ch = ch
+ self.temb_ch = self.ch*4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch,
+ self.temb_ch),
+ torch.nn.Linear(self.temb_ch,
+ self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(c_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1,)+tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch*in_ch_mult[i_level]
+ block_out = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions-1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ self.z_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = AttnBlock(block_in)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch*ch_mult[i_level]
+ skip_in = ch*ch_mult[i_level]
+ for i_block in range(self.num_res_blocks+1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch*in_ch_mult[i_level]
+ block.append(ResnetBlock(in_channels=block_in+skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(AttnBlock(block_in))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+
+ def forward(self, x, z, t=None):  # t is required only when use_timestep=True
+ #assert x.shape[2] == x.shape[3] == self.resolution
+
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions-1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ z = self.z_in(z)
+ h = torch.cat((h,z),dim=1)
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h = self.up[i_level].block[i_block](
+ torch.cat([h, hs.pop()], dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0, dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0, dropout=0.0),
+ nn.Conv2d(2*in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1,2,3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+ def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
+ ch_mult=(2,2), dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
diff --git a/repositories/taming/modules/discriminator/model.py b/repositories/taming/modules/discriminator/model.py
new file mode 100644
index 000000000..2aaa3110d
--- /dev/null
+++ b/repositories/taming/modules/discriminator/model.py
@@ -0,0 +1,67 @@
+import functools
+import torch.nn as nn
+
+
+from taming.modules.util import ActNorm
+
+
+def weights_init(m):
+ classname = m.__class__.__name__
+ if classname.find('Conv') != -1:
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
+ elif classname.find('BatchNorm') != -1:
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
+ nn.init.constant_(m.bias.data, 0)
+
+
+class NLayerDiscriminator(nn.Module):
+ """Defines a PatchGAN discriminator as in Pix2Pix
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
+ """
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
+ """Construct a PatchGAN discriminator
+ Parameters:
+ input_nc (int) -- the number of channels in input images
+ ndf (int) -- the number of filters in the last conv layer
+ n_layers (int) -- the number of conv layers in the discriminator
+ use_actnorm (bool) -- use ActNorm instead of BatchNorm as the normalization layer
+ """
+ super(NLayerDiscriminator, self).__init__()
+ if not use_actnorm:
+ norm_layer = nn.BatchNorm2d
+ else:
+ norm_layer = ActNorm
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
+ use_bias = norm_layer.func != nn.BatchNorm2d
+ else:
+ use_bias = norm_layer != nn.BatchNorm2d
+
+ kw = 4
+ padw = 1
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers): # gradually increase the number of filters
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True)
+ ]
+
+ sequence += [
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
+ self.main = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ """Standard forward."""
+ return self.main(input)
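+
+# Illustrative note (editorial assumption): with the defaults (input_nc=3, ndf=64, n_layers=3)
+# and a 256x256 input, forward() returns a (B, 1, 30, 30) map of patch-level real/fake logits.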
diff --git a/repositories/taming/modules/losses/__init__.py b/repositories/taming/modules/losses/__init__.py
new file mode 100644
index 000000000..d09caf9eb
--- /dev/null
+++ b/repositories/taming/modules/losses/__init__.py
@@ -0,0 +1,2 @@
+from taming.modules.losses.vqperceptual import DummyLoss
+
diff --git a/repositories/taming/modules/losses/lpips.py b/repositories/taming/modules/losses/lpips.py
new file mode 100644
index 000000000..a72804476
--- /dev/null
+++ b/repositories/taming/modules/losses/lpips.py
@@ -0,0 +1,123 @@
+"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
+
+import torch
+import torch.nn as nn
+from torchvision import models
+from collections import namedtuple
+
+from taming.util import get_ckpt_path
+
+
+class LPIPS(nn.Module):
+ # Learned perceptual metric
+ def __init__(self, use_dropout=True):
+ super().__init__()
+ self.scaling_layer = ScalingLayer()
+ self.chns = [64, 128, 256, 512, 512] # vgg16 features
+ self.net = vgg16(pretrained=True, requires_grad=False)
+ self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
+ self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
+ self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
+ self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
+ self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
+ self.load_from_pretrained()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def load_from_pretrained(self, name="vgg_lpips"):
+ ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
+ self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ print("loaded pretrained LPIPS loss from {}".format(ckpt))
+
+ @classmethod
+ def from_pretrained(cls, name="vgg_lpips"):
+ if name != "vgg_lpips":
+ raise NotImplementedError
+ model = cls()
+ ckpt = get_ckpt_path(name)
+ model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
+ return model
+
+ def forward(self, input, target):
+ in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
+ outs0, outs1 = self.net(in0_input), self.net(in1_input)
+ feats0, feats1, diffs = {}, {}, {}
+ lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
+ for kk in range(len(self.chns)):
+ feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
+ diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
+
+ res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
+ val = res[0]
+ for l in range(1, len(self.chns)):
+ val += res[l]
+ return val
+
+
+class ScalingLayer(nn.Module):
+ def __init__(self):
+ super(ScalingLayer, self).__init__()
+ self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
+ self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
+
+ def forward(self, inp):
+ return (inp - self.shift) / self.scale
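+
+# Editorial note: the shift/scale buffers above are the ImageNet channel statistics remapped
+# to [-1, 1] inputs (shift = 2*mean - 1, scale = 2*std), so LPIPS expects images in [-1, 1].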
+
+
+class NetLinLayer(nn.Module):
+ """ A single linear layer which does a 1x1 conv """
+ def __init__(self, chn_in, chn_out=1, use_dropout=False):
+ super(NetLinLayer, self).__init__()
+ layers = [nn.Dropout(), ] if (use_dropout) else []
+ layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
+ self.model = nn.Sequential(*layers)
+
+
+class vgg16(torch.nn.Module):
+ def __init__(self, requires_grad=False, pretrained=True):
+ super(vgg16, self).__init__()
+ vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
+ self.slice1 = torch.nn.Sequential()
+ self.slice2 = torch.nn.Sequential()
+ self.slice3 = torch.nn.Sequential()
+ self.slice4 = torch.nn.Sequential()
+ self.slice5 = torch.nn.Sequential()
+ self.N_slices = 5
+ for x in range(4):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(4, 9):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(9, 16):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(16, 23):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(23, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h = self.slice1(X)
+ h_relu1_2 = h
+ h = self.slice2(h)
+ h_relu2_2 = h
+ h = self.slice3(h)
+ h_relu3_3 = h
+ h = self.slice4(h)
+ h_relu4_3 = h
+ h = self.slice5(h)
+ h_relu5_3 = h
+ vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
+ out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
+ return out
+
+
+def normalize_tensor(x,eps=1e-10):
+ norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
+ return x/(norm_factor+eps)
+
+
+def spatial_average(x, keepdim=True):
+ return x.mean([2,3],keepdim=keepdim)
+
diff --git a/repositories/taming/modules/losses/segmentation.py b/repositories/taming/modules/losses/segmentation.py
new file mode 100644
index 000000000..4ba77deb5
--- /dev/null
+++ b/repositories/taming/modules/losses/segmentation.py
@@ -0,0 +1,22 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class BCELoss(nn.Module):
+ def forward(self, prediction, target):
+ loss = F.binary_cross_entropy_with_logits(prediction,target)
+ return loss, {}
+
+
+class BCELossWithQuant(nn.Module):
+ def __init__(self, codebook_weight=1.):
+ super().__init__()
+ self.codebook_weight = codebook_weight
+
+ def forward(self, qloss, target, prediction, split):
+ bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
+ loss = bce_loss + self.codebook_weight*qloss
+ return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/bce_loss".format(split): bce_loss.detach().mean(),
+ "{}/quant_loss".format(split): qloss.detach().mean()
+ }
diff --git a/repositories/taming/modules/losses/vqperceptual.py b/repositories/taming/modules/losses/vqperceptual.py
new file mode 100644
index 000000000..c2febd445
--- /dev/null
+++ b/repositories/taming/modules/losses/vqperceptual.py
@@ -0,0 +1,136 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from taming.modules.losses.lpips import LPIPS
+from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
+
+
+class DummyLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+
+def adopt_weight(weight, global_step, threshold=0, value=0.):
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
+def hinge_d_loss(logits_real, logits_fake):
+ loss_real = torch.mean(F.relu(1. - logits_real))
+ loss_fake = torch.mean(F.relu(1. + logits_fake))
+ d_loss = 0.5 * (loss_real + loss_fake)
+ return d_loss
+
+
+def vanilla_d_loss(logits_real, logits_fake):
+ d_loss = 0.5 * (
+ torch.mean(torch.nn.functional.softplus(-logits_real)) +
+ torch.mean(torch.nn.functional.softplus(logits_fake)))
+ return d_loss
+
+
+class VQLPIPSWithDiscriminator(nn.Module):
+ def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
+ disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
+ perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
+ disc_ndf=64, disc_loss="hinge"):
+ super().__init__()
+ assert disc_loss in ["hinge", "vanilla"]
+ self.codebook_weight = codebook_weight
+ self.pixel_weight = pixelloss_weight
+ self.perceptual_loss = LPIPS().eval()
+ self.perceptual_weight = perceptual_weight
+
+ self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
+ n_layers=disc_num_layers,
+ use_actnorm=use_actnorm,
+ ndf=disc_ndf
+ ).apply(weights_init)
+ self.discriminator_iter_start = disc_start
+ if disc_loss == "hinge":
+ self.disc_loss = hinge_d_loss
+ elif disc_loss == "vanilla":
+ self.disc_loss = vanilla_d_loss
+ else:
+ raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
+ print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
+ self.disc_factor = disc_factor
+ self.discriminator_weight = disc_weight
+ self.disc_conditional = disc_conditional
+
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
+ if last_layer is not None:
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
+ else:
+ nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
+ g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
+
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
+ d_weight = d_weight * self.discriminator_weight
+ return d_weight
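+
+ # Editorial note: this is the adaptive balancing weight from the taming-transformers paper,
+ # lambda = ||grad(L_rec)|| / (||grad(L_GAN)|| + 1e-4) taken w.r.t. the last decoder layer,
+ # clamped to [0, 1e4] and scaled by disc_weight.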
+
+ def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
+ global_step, last_layer=None, cond=None, split="train"):
+ rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
+ if self.perceptual_weight > 0:
+ p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
+ rec_loss = rec_loss + self.perceptual_weight * p_loss
+ else:
+ p_loss = torch.tensor([0.0])
+
+ nll_loss = rec_loss
+ #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
+ nll_loss = torch.mean(nll_loss)
+
+ # now the GAN part
+ if optimizer_idx == 0:
+ # generator update
+ if cond is None:
+ assert not self.disc_conditional
+ logits_fake = self.discriminator(reconstructions.contiguous())
+ else:
+ assert self.disc_conditional
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
+ g_loss = -torch.mean(logits_fake)
+
+ try:
+ d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
+ except RuntimeError:
+ assert not self.training
+ d_weight = torch.tensor(0.0)
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
+
+ log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
+ "{}/quant_loss".format(split): codebook_loss.detach().mean(),
+ "{}/nll_loss".format(split): nll_loss.detach().mean(),
+ "{}/rec_loss".format(split): rec_loss.detach().mean(),
+ "{}/p_loss".format(split): p_loss.detach().mean(),
+ "{}/d_weight".format(split): d_weight.detach(),
+ "{}/disc_factor".format(split): torch.tensor(disc_factor),
+ "{}/g_loss".format(split): g_loss.detach().mean(),
+ }
+ return loss, log
+
+ if optimizer_idx == 1:
+ # second pass for discriminator update
+ if cond is None:
+ logits_real = self.discriminator(inputs.contiguous().detach())
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
+ else:
+ logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
+ logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
+
+ disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
+
+ log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
+ "{}/logits_real".format(split): logits_real.detach().mean(),
+ "{}/logits_fake".format(split): logits_fake.detach().mean()
+ }
+ return d_loss, log
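+
+# Illustrative note (editorial): this loss is called twice per training step by VQModel,
+# once with optimizer_idx=0 (autoencoder/generator objective, including the adaptive weight)
+# and once with optimizer_idx=1 (discriminator hinge or vanilla loss), matching the two
+# optimizers returned by VQModel.configure_optimizers().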
diff --git a/repositories/taming/modules/misc/coord.py b/repositories/taming/modules/misc/coord.py
new file mode 100644
index 000000000..ee69b0c89
--- /dev/null
+++ b/repositories/taming/modules/misc/coord.py
@@ -0,0 +1,31 @@
+import torch
+
+class CoordStage(object):
+ def __init__(self, n_embed, down_factor):
+ self.n_embed = n_embed
+ self.down_factor = down_factor
+
+ def eval(self):
+ return self
+
+ def encode(self, c):
+ """fake vqmodel interface"""
+ assert 0.0 <= c.min() and c.max() <= 1.0
+ b,ch,h,w = c.shape
+ assert ch == 1
+
+ c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
+ mode="area")
+ c = c.clamp(0.0, 1.0)
+ c = self.n_embed*c
+ c_quant = c.round()
+ c_ind = c_quant.to(dtype=torch.long)
+
+ info = None, None, c_ind
+ return c_quant, None, info
+
+ def decode(self, c):
+ c = c/self.n_embed
+ c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
+ mode="nearest")
+ return c
diff --git a/repositories/taming/modules/transformer/mingpt.py b/repositories/taming/modules/transformer/mingpt.py
new file mode 100644
index 000000000..d14b7b681
--- /dev/null
+++ b/repositories/taming/modules/transformer/mingpt.py
@@ -0,0 +1,415 @@
+"""
+taken from: https://github.com/karpathy/minGPT/
+GPT model:
+- the initial stem consists of a combination of token encoding and a positional encoding
+- the meat of it is a uniform sequence of Transformer blocks
+ - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
+ - all blocks feed into a central residual pathway similar to resnets
+- the final decoder is a linear projection into a vanilla Softmax classifier
+"""
+
+import math
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from transformers import top_k_top_p_filtering
+
+logger = logging.getLogger(__name__)
+
+
+class GPTConfig:
+ """ base GPT config, params common to all GPT versions """
+ embd_pdrop = 0.1
+ resid_pdrop = 0.1
+ attn_pdrop = 0.1
+
+ def __init__(self, vocab_size, block_size, **kwargs):
+ self.vocab_size = vocab_size
+ self.block_size = block_size
+ for k,v in kwargs.items():
+ setattr(self, k, v)
+
+
+class GPT1Config(GPTConfig):
+ """ GPT-1 like network roughly 125M params """
+ n_layer = 12
+ n_head = 12
+ n_embd = 768
+
+
+class CausalSelfAttention(nn.Module):
+ """
+ A vanilla multi-head masked self-attention layer with a projection at the end.
+ It is possible to use torch.nn.MultiheadAttention here but I am including an
+ explicit implementation to show that there is nothing too scary here.
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ assert config.n_embd % config.n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(config.n_embd, config.n_embd)
+ self.query = nn.Linear(config.n_embd, config.n_embd)
+ self.value = nn.Linear(config.n_embd, config.n_embd)
+ # regularization
+ self.attn_drop = nn.Dropout(config.attn_pdrop)
+ self.resid_drop = nn.Dropout(config.resid_pdrop)
+ # output projection
+ self.proj = nn.Linear(config.n_embd, config.n_embd)
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ mask = torch.tril(torch.ones(config.block_size,
+ config.block_size))
+ if hasattr(config, "n_unmasked"):
+ mask[:config.n_unmasked, :config.n_unmasked] = 1
+ self.register_buffer("mask", mask.view(1, 1, config.block_size, config.block_size))
+ self.n_head = config.n_head
+
+ def forward(self, x, layer_past=None):
+ B, T, C = x.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+ v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+ present = torch.stack((k, v))
+ if layer_past is not None:
+ past_key, past_value = layer_past
+ k = torch.cat((past_key, k), dim=-2)
+ v = torch.cat((past_value, v), dim=-2)
+
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+ if layer_past is None:
+ att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
+
+ att = F.softmax(att, dim=-1)
+ att = self.attn_drop(att)
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+ # output projection
+ y = self.resid_drop(self.proj(y))
+ return y, present # TODO: check that this does not break anything
+
+
+class Block(nn.Module):
+ """ an unassuming Transformer block """
+ def __init__(self, config):
+ super().__init__()
+ self.ln1 = nn.LayerNorm(config.n_embd)
+ self.ln2 = nn.LayerNorm(config.n_embd)
+ self.attn = CausalSelfAttention(config)
+ self.mlp = nn.Sequential(
+ nn.Linear(config.n_embd, 4 * config.n_embd),
+ nn.GELU(), # nice
+ nn.Linear(4 * config.n_embd, config.n_embd),
+ nn.Dropout(config.resid_pdrop),
+ )
+
+ def forward(self, x, layer_past=None, return_present=False):
+ # TODO: check that training still works
+ if return_present: assert not self.training
+ # layer past: tuple of length two with B, nh, T, hs
+ attn, present = self.attn(self.ln1(x), layer_past=layer_past)
+
+ x = x + attn
+ x = x + self.mlp(self.ln2(x))
+ if layer_past is not None or return_present:
+ return x, present
+ return x
+
+
+class GPT(nn.Module):
+ """ the full GPT language model, with a context size of block_size """
+ def __init__(self, vocab_size, block_size, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+ x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
+
+ def forward_with_past(self, idx, embeddings=None, targets=None, past=None, past_length=None):
+ # inference only
+ assert not self.training
+ token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ if past is not None:
+ assert past_length is not None
+ past = torch.cat(past, dim=-2) # n_layer, 2, b, nh, len_past, dim_head
+ past_shape = list(past.shape)
+ expected_shape = [self.config.n_layer, 2, idx.shape[0], self.config.n_head, past_length, self.config.n_embd//self.config.n_head]
+ assert past_shape == expected_shape, f"{past_shape} =/= {expected_shape}"
+ position_embeddings = self.pos_emb[:, past_length, :] # each position maps to a (learnable) vector
+ else:
+ position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]
+
+ x = self.drop(token_embeddings + position_embeddings)
+ presents = [] # accumulate over layers
+ for i, block in enumerate(self.blocks):
+ x, present = block(x, layer_past=past[i, ...] if past is not None else None, return_present=True)
+ presents.append(present)
+
+ x = self.ln_f(x)
+ logits = self.head(x)
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss, torch.stack(presents) # _, _, n_layer, 2, b, nh, 1, dim_head
+
+
+class DummyGPT(nn.Module):
+ # for debugging
+ def __init__(self, add_value=1):
+ super().__init__()
+ self.add_value = add_value
+
+ def forward(self, idx):
+ return idx + self.add_value, None
+
+
+class CodeGPT(nn.Module):
+ """Takes in semi-embeddings"""
+ def __init__(self, vocab_size, block_size, in_channels, n_layer=12, n_head=8, n_embd=256,
+ embd_pdrop=0., resid_pdrop=0., attn_pdrop=0., n_unmasked=0):
+ super().__init__()
+ config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
+ embd_pdrop=embd_pdrop, resid_pdrop=resid_pdrop, attn_pdrop=attn_pdrop,
+ n_layer=n_layer, n_head=n_head, n_embd=n_embd,
+ n_unmasked=n_unmasked)
+ # input embedding stem
+ self.tok_emb = nn.Linear(in_channels, config.n_embd)
+ self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
+ self.drop = nn.Dropout(config.embd_pdrop)
+ # transformer
+ self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
+ # decoder head
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+ self.block_size = config.block_size
+ self.apply(self._init_weights)
+ self.config = config
+ logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))
+
+ def get_block_size(self):
+ return self.block_size
+
+ def _init_weights(self, module):
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def forward(self, idx, embeddings=None, targets=None):
+ # forward the GPT model
+        token_embeddings = self.tok_emb(idx) # project the continuous inputs into the embedding space
+
+ if embeddings is not None: # prepend explicit embeddings
+ token_embeddings = torch.cat((embeddings, token_embeddings), dim=1)
+
+ t = token_embeddings.shape[1]
+ assert t <= self.block_size, "Cannot forward, model block size is exhausted."
+ position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
+ x = self.drop(token_embeddings + position_embeddings)
+ x = self.blocks(x)
+        x = self.ln_f(x)
+ logits = self.head(x)
+
+ # if we are given some desired targets also calculate the loss
+ loss = None
+ if targets is not None:
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+ return logits, loss
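+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# CodeGPT consumes continuous per-position features ("semi-embeddings") rather
+# than token indices and still predicts a categorical distribution per position.
+def _codegpt_usage_sketch():
+    model = CodeGPT(vocab_size=1024, block_size=64, in_channels=256,
+                    n_layer=2, n_head=4, n_embd=128)
+    feats = torch.randn(2, 16, 256)     # (b, t, in_channels)
+    logits, loss = model(feats)         # logits: (2, 16, 1024), loss: None
+    return logits, loss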
+
+
+
+#### sampling utils
+
+def top_k_logits(logits, k):
+ v, ix = torch.topk(logits, k)
+ out = logits.clone()
+ out[out < v[:, [-1]]] = -float('Inf')
+ return out
+
+@torch.no_grad()
+def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
+ """
+ take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
+ the sequence, feeding the predictions back into the model each time. Clearly the sampling
+ has quadratic complexity unlike an RNN that is only linear, and has a finite context window
+ of block_size, unlike an RNN that has an infinite context window.
+ """
+ block_size = model.get_block_size()
+ model.eval()
+ for k in range(steps):
+ x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
+ logits, _ = model(x_cond)
+ # pluck the logits at the final step and scale by temperature
+ logits = logits[:, -1, :] / temperature
+ # optionally crop probabilities to only the top k options
+ if top_k is not None:
+ logits = top_k_logits(logits, top_k)
+ # apply softmax to convert to probabilities
+ probs = F.softmax(logits, dim=-1)
+ # sample from the distribution or take the most likely
+ if sample:
+ ix = torch.multinomial(probs, num_samples=1)
+ else:
+ _, ix = torch.topk(probs, k=1, dim=-1)
+ # append to the sequence and continue
+ x = torch.cat((x, ix), dim=1)
+
+ return x
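+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# `sample` only needs a model that exposes get_block_size() and returns
+# (logits, loss), so a tiny stand-in is enough to show the autoregressive loop.
+def _sample_usage_sketch(vocab_size=16, block_size=8):
+    class _ToyLM(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.emb = nn.Embedding(vocab_size, vocab_size)
+
+        def get_block_size(self):
+            return block_size
+
+        def forward(self, idx):
+            return self.emb(idx), None  # (b, t, vocab_size) logits, no loss
+
+    start = torch.randint(0, vocab_size, size=(1, 1))  # (b, t) conditioning indices
+    out = sample(_ToyLM(), start, steps=4, temperature=1.0, sample=True, top_k=3)
+    return out  # (1, 5): the conditioning token followed by 4 sampled tokens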
+
+
+@torch.no_grad()
+def sample_with_past(x, model, steps, temperature=1., sample_logits=True,
+ top_k=None, top_p=None, callback=None):
+ # x is conditioning
+ sample = x
+ cond_len = x.shape[1]
+ past = None
+ for n in range(steps):
+ if callback is not None:
+ callback(n)
+ logits, _, present = model.forward_with_past(x, past=past, past_length=(n+cond_len-1))
+ if past is None:
+ past = [present]
+ else:
+ past.append(present)
+ logits = logits[:, -1, :] / temperature
+ if top_k is not None:
+ logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
+
+ probs = F.softmax(logits, dim=-1)
+ if not sample_logits:
+ _, x = torch.topk(probs, k=1, dim=-1)
+ else:
+ x = torch.multinomial(probs, num_samples=1)
+ # append to the sequence and continue
+ sample = torch.cat((sample, x), dim=1)
+ del past
+ sample = sample[:, cond_len:] # cut conditioning off
+ return sample
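+
+
+# Illustrative note (not from the upstream taming-transformers code): each
+# `present` returned by forward_with_past has shape
+# (n_layer, 2, b, nh, 1, dim_head); collecting one per generated token and
+# concatenating along the length axis yields the `past` cache expected on the
+# next step. A shape-only sketch:
+def _past_cache_shape_sketch(n_layer=2, b=1, nh=4, dim_head=8, n_steps=3):
+    presents = [torch.zeros(n_layer, 2, b, nh, 1, dim_head) for _ in range(n_steps)]
+    past = torch.cat(presents, dim=-2)
+    return past.shape  # (n_layer, 2, b, nh, n_steps, dim_head)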
+
+
+#### clustering utils
+
+class KMeans(nn.Module):
+ def __init__(self, ncluster=512, nc=3, niter=10):
+ super().__init__()
+ self.ncluster = ncluster
+ self.nc = nc
+ self.niter = niter
+ self.shape = (3,32,32)
+ self.register_buffer("C", torch.zeros(self.ncluster,nc))
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def is_initialized(self):
+ return self.initialized.item() == 1
+
+ @torch.no_grad()
+ def initialize(self, x):
+ N, D = x.shape
+ assert D == self.nc, D
+ c = x[torch.randperm(N)[:self.ncluster]] # init clusters at random
+ for i in range(self.niter):
+ # assign all pixels to the closest codebook element
+ a = ((x[:, None, :] - c[None, :, :])**2).sum(-1).argmin(1)
+            # move each codebook element to be the mean of the pixels assigned to it
+ c = torch.stack([x[a==k].mean(0) for k in range(self.ncluster)])
+ # re-assign any poorly positioned codebook elements
+ nanix = torch.any(torch.isnan(c), dim=1)
+ ndead = nanix.sum().item()
+ print('done step %d/%d, re-initialized %d dead clusters' % (i+1, self.niter, ndead))
+ c[nanix] = x[torch.randperm(N)[:ndead]] # re-init dead clusters
+
+ self.C.copy_(c)
+ self.initialized.fill_(1)
+
+
+ def forward(self, x, reverse=False, shape=None):
+ if not reverse:
+ # flatten
+ bs,c,h,w = x.shape
+ assert c == self.nc
+ x = x.reshape(bs,c,h*w,1)
+ C = self.C.permute(1,0)
+ C = C.reshape(1,c,1,self.ncluster)
+ a = ((x-C)**2).sum(1).argmin(-1) # bs, h*w indices
+ return a
+ else:
+ # flatten
+ bs, HW = x.shape
+ """
+ c = self.C.reshape( 1, self.nc, 1, self.ncluster)
+ c = c[bs*[0],:,:,:]
+ c = c[:,:,HW*[0],:]
+ x = x.reshape(bs, 1, HW, 1)
+ x = x[:,3*[0],:,:]
+ x = torch.gather(c, dim=3, index=x)
+ """
+ x = self.C[x]
+ x = x.permute(0,2,1)
+ shape = shape if shape is not None else self.shape
+ x = x.reshape(bs, *shape)
+
+ return x
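+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# KMeans maps pixels to codebook indices (forward) and indices back to colours
+# (reverse=True).
+def _kmeans_usage_sketch():
+    km = KMeans(ncluster=16, nc=3, niter=2)
+    km.initialize(torch.rand(4096, 3))              # fit on flattened RGB pixels
+    imgs = torch.rand(2, 3, 32, 32)
+    idx = km(imgs)                                  # (2, 1024) codebook indices
+    rec = km(idx, reverse=True, shape=(3, 32, 32))  # (2, 3, 32, 32) reconstruction
+    return idx, rec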
diff --git a/repositories/taming/modules/transformer/permuter.py b/repositories/taming/modules/transformer/permuter.py
new file mode 100644
index 000000000..0d43bb135
--- /dev/null
+++ b/repositories/taming/modules/transformer/permuter.py
@@ -0,0 +1,248 @@
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+class AbstractPermuter(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ def forward(self, x, reverse=False):
+ raise NotImplementedError
+
+
+class Identity(AbstractPermuter):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, reverse=False):
+ return x
+
+
+class Subsample(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ C = 1
+ indices = np.arange(H*W).reshape(C,H,W)
+ while min(H, W) > 1:
+ indices = indices.reshape(C,H//2,2,W//2,2)
+ indices = indices.transpose(0,2,4,1,3)
+ indices = indices.reshape(C*4,H//2, W//2)
+ H = H//2
+ W = W//2
+ C = C*4
+ assert H == W == 1
+ idx = torch.tensor(indices.ravel())
+        self.register_buffer('forward_shuffle_idx', idx)
+        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+def mortonify(i, j):
+ """(i,j) index to linear morton code"""
+ i = np.uint64(i)
+ j = np.uint64(j)
+
+    z = np.uint64(0)
+
+ for pos in range(32):
+ z = (z |
+ ((j & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos)) |
+ ((i & (np.uint64(1) << np.uint64(pos))) << np.uint64(pos+1))
+ )
+ return z
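+
+
+# Illustrative check (not from the upstream taming-transformers code):
+# mortonify interleaves the bits of (i, j), placing j on the even bit
+# positions and i on the odd ones.
+def _mortonify_sketch():
+    assert mortonify(0, 1) == 1    # j bit 0 -> output bit 0
+    assert mortonify(1, 0) == 2    # i bit 0 -> output bit 1
+    assert mortonify(3, 3) == 15   # 0b11 and 0b11 interleave to 0b1111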
+
+
+class ZCurve(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ reverseidx = [np.int64(mortonify(i,j)) for i in range(H) for j in range(W)]
+ idx = np.argsort(reverseidx)
+ idx = torch.tensor(idx)
+ reverseidx = torch.tensor(reverseidx)
+ self.register_buffer('forward_shuffle_idx',
+ idx)
+ self.register_buffer('backward_shuffle_idx',
+ reverseidx)
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralOut(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class SpiralIn(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ assert H == W
+ size = W
+ indices = np.arange(size*size).reshape(size,size)
+
+ i0 = size//2
+ j0 = size//2-1
+
+ i = i0
+ j = j0
+
+ idx = [indices[i0, j0]]
+ step_mult = 0
+ for c in range(1, size//2+1):
+ step_mult += 1
+ # steps left
+ for k in range(step_mult):
+ i = i - 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step down
+ for k in range(step_mult):
+ i = i
+ j = j + 1
+ idx.append(indices[i, j])
+
+ step_mult += 1
+ if c < size//2:
+ # step right
+ for k in range(step_mult):
+ i = i + 1
+ j = j
+ idx.append(indices[i, j])
+
+ # step up
+ for k in range(step_mult):
+ i = i
+ j = j - 1
+ idx.append(indices[i, j])
+ else:
+ # end reached
+ for k in range(step_mult-1):
+ i = i + 1
+ idx.append(indices[i, j])
+
+ assert len(idx) == size*size
+ idx = idx[::-1]
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class Random(nn.Module):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.random.RandomState(1).permutation(H*W)
+ idx = torch.tensor(indices.ravel())
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+class AlternateParsing(AbstractPermuter):
+ def __init__(self, H, W):
+ super().__init__()
+ indices = np.arange(W*H).reshape(H,W)
+ for i in range(1, H, 2):
+ indices[i, :] = indices[i, ::-1]
+ idx = indices.flatten()
+ assert len(idx) == H*W
+ idx = torch.tensor(idx)
+ self.register_buffer('forward_shuffle_idx', idx)
+ self.register_buffer('backward_shuffle_idx', torch.argsort(idx))
+
+ def forward(self, x, reverse=False):
+ if not reverse:
+ return x[:, self.forward_shuffle_idx]
+ else:
+ return x[:, self.backward_shuffle_idx]
+
+
+if __name__ == "__main__":
+ p0 = AlternateParsing(16, 16)
+ print(p0.forward_shuffle_idx)
+ print(p0.backward_shuffle_idx)
+
+ x = torch.randint(0, 768, size=(11, 256))
+ y = p0(x)
+ xre = p0(y, reverse=True)
+ assert torch.equal(x, xre)
+
+ p1 = SpiralOut(2, 2)
+ print(p1.forward_shuffle_idx)
+ print(p1.backward_shuffle_idx)
diff --git a/repositories/taming/modules/util.py b/repositories/taming/modules/util.py
new file mode 100644
index 000000000..9ee16385d
--- /dev/null
+++ b/repositories/taming/modules/util.py
@@ -0,0 +1,130 @@
+import torch
+import torch.nn as nn
+
+
+def count_params(model):
+ total_params = sum(p.numel() for p in model.parameters())
+ return total_params
+
+
+class ActNorm(nn.Module):
+ def __init__(self, num_features, logdet=False, affine=True,
+ allow_reverse_init=False):
+ assert affine
+ super().__init__()
+ self.logdet = logdet
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
+ self.allow_reverse_init = allow_reverse_init
+
+ self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
+
+ def initialize(self, input):
+ with torch.no_grad():
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
+ mean = (
+ flatten.mean(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+ std = (
+ flatten.std(1)
+ .unsqueeze(1)
+ .unsqueeze(2)
+ .unsqueeze(3)
+ .permute(1, 0, 2, 3)
+ )
+
+ self.loc.data.copy_(-mean)
+ self.scale.data.copy_(1 / (std + 1e-6))
+
+ def forward(self, input, reverse=False):
+ if reverse:
+ return self.reverse(input)
+ if len(input.shape) == 2:
+ input = input[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ _, _, height, width = input.shape
+
+ if self.training and self.initialized.item() == 0:
+ self.initialize(input)
+ self.initialized.fill_(1)
+
+ h = self.scale * (input + self.loc)
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+
+ if self.logdet:
+ log_abs = torch.log(torch.abs(self.scale))
+ logdet = height*width*torch.sum(log_abs)
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
+ return h, logdet
+
+ return h
+
+ def reverse(self, output):
+ if self.training and self.initialized.item() == 0:
+ if not self.allow_reverse_init:
+ raise RuntimeError(
+ "Initializing ActNorm in reverse direction is "
+ "disabled by default. Use allow_reverse_init=True to enable."
+ )
+ else:
+ self.initialize(output)
+ self.initialized.fill_(1)
+
+ if len(output.shape) == 2:
+ output = output[:,:,None,None]
+ squeeze = True
+ else:
+ squeeze = False
+
+ h = output / self.scale - self.loc
+
+ if squeeze:
+ h = h.squeeze(-1).squeeze(-1)
+ return h
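+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# ActNorm initialises its scale and offset from the first batch seen in
+# training mode, so its output starts out roughly zero-mean / unit-variance
+# per channel.
+def _actnorm_usage_sketch():
+    norm = ActNorm(num_features=8)
+    norm.train()
+    x = torch.randn(4, 8, 16, 16) * 3.0 + 1.0
+    y = norm(x)                                  # first call triggers data-dependent init
+    return y.mean().item(), y.std().item()       # roughly 0 and 1 respectively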
+
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class Labelator(AbstractEncoder):
+ """Net2Net Interface for Class-Conditional Model"""
+ def __init__(self, n_classes, quantize_interface=True):
+ super().__init__()
+ self.n_classes = n_classes
+ self.quantize_interface = quantize_interface
+
+ def encode(self, c):
+ c = c[:,None]
+ if self.quantize_interface:
+ return c, None, [None, None, c.long()]
+ return c
+
+
+class SOSProvider(AbstractEncoder):
+ # for unconditional training
+ def __init__(self, sos_token, quantize_interface=True):
+ super().__init__()
+ self.sos_token = sos_token
+ self.quantize_interface = quantize_interface
+
+ def encode(self, x):
+ # get batch size from data and replicate sos_token
+ c = torch.ones(x.shape[0], 1)*self.sos_token
+ c = c.long().to(x.device)
+ if self.quantize_interface:
+ return c, None, [None, None, c]
+ return c
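+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# both conditioning providers return (c, None, [None, None, indices]) so they
+# can stand in wherever a VQ encoder's quantize interface is expected.
+def _cond_provider_sketch():
+    labels = torch.tensor([3, 7])
+    c, _, (_, _, idx) = Labelator(n_classes=10).encode(labels)    # c: (2, 1)
+    x = torch.zeros(2, 5)
+    sos, _, (_, _, sos_idx) = SOSProvider(sos_token=0).encode(x)  # sos: (2, 1)
+    return idx, sos_idx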
diff --git a/repositories/taming/modules/vqvae/quantize.py b/repositories/taming/modules/vqvae/quantize.py
new file mode 100644
index 000000000..d75544e41
--- /dev/null
+++ b/repositories/taming/modules/vqvae/quantize.py
@@ -0,0 +1,445 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from torch import einsum
+from einops import rearrange
+
+
+class VectorQuantizer(nn.Module):
+ """
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
+ ____________________________________________
+ Discretization bottleneck part of the VQ-VAE.
+ Inputs:
+ - n_e : number of embeddings
+ - e_dim : dimension of embedding
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
+ _____________________________________________
+ """
+
+ # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for
+ # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be
+ # used wherever VectorQuantizer has been used before and is additionally
+ # more efficient.
+ def __init__(self, n_e, e_dim, beta):
+ super(VectorQuantizer, self).__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ def forward(self, z):
+ """
+ Inputs the output of the encoder network z and maps it to a discrete
+ one-hot vector that is the index of the closest embedding vector e_j
+ z (continuous) -> z_q (discrete)
+ z.shape = (batch, channel, height, width)
+ quantization pipeline:
+ 1. get encoder input (B,C,H,W)
+ 2. flatten input to (B*H*W,C)
+ """
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = z.permute(0, 2, 3, 1).contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.matmul(z_flattened, self.embedding.weight.t())
+
+        ## could possibly replace this here
+ # #\start...
+ # find closest encodings
+ min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
+
+ min_encodings = torch.zeros(
+ min_encoding_indices.shape[0], self.n_e).to(z)
+ min_encodings.scatter_(1, min_encoding_indices, 1)
+
+ # dtype min encodings: torch.float32
+ # min_encodings shape: torch.Size([2048, 512])
+ # min_encoding_indices.shape: torch.Size([2048, 1])
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
+ #.........\end
+
+ # with:
+ # .........\start
+ #min_encoding_indices = torch.argmin(d, dim=1)
+ #z_q = self.embedding(min_encoding_indices)
+ # ......\end......... (TODO)
+
+ # compute loss for embedding
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # perplexity
+ e_mean = torch.mean(min_encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ # TODO: check for more easy handling with nn.Embedding
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
+ min_encodings.scatter_(1, indices[:,None], 1)
+
+ # get quantized latent vectors
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
+
+
+class GumbelQuantize(nn.Module):
+ """
+ credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!)
+ Gumbel Softmax trick quantizer
+ Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
+ https://arxiv.org/abs/1611.01144
+ """
+ def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True,
+ kl_weight=5e-4, temp_init=1.0, use_vqinterface=True,
+ remap=None, unknown_index="random"):
+ super().__init__()
+
+ self.embedding_dim = embedding_dim
+ self.n_embed = n_embed
+
+ self.straight_through = straight_through
+ self.temperature = temp_init
+ self.kl_weight = kl_weight
+
+ self.proj = nn.Conv2d(num_hiddens, n_embed, 1)
+ self.embed = nn.Embedding(n_embed, embedding_dim)
+
+ self.use_vqinterface = use_vqinterface
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, return_logits=False):
+ # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work
+ hard = self.straight_through if self.training else True
+ temp = self.temperature if temp is None else temp
+
+ logits = self.proj(z)
+ if self.remap is not None:
+ # continue only with used logits
+ full_zeros = torch.zeros_like(logits)
+ logits = logits[:,self.used,...]
+
+ soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
+ if self.remap is not None:
+ # go back to all entries but unused set to zero
+ full_zeros[:,self.used,...] = soft_one_hot
+ soft_one_hot = full_zeros
+ z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight)
+
+ # + kl divergence to the prior loss
+ qy = F.softmax(logits, dim=1)
+ diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean()
+
+ ind = soft_one_hot.argmax(dim=1)
+ if self.remap is not None:
+ ind = self.remap_to_used(ind)
+ if self.use_vqinterface:
+ if return_logits:
+ return z_q, diff, (None, None, ind), logits
+ return z_q, diff, (None, None, ind)
+ return z_q, diff, ind
+
+ def get_codebook_entry(self, indices, shape):
+ b, h, w, c = shape
+ assert b*h*w == indices.shape[0]
+ indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w)
+ if self.remap is not None:
+ indices = self.unmap_to_all(indices)
+ one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float()
+ z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight)
+ return z_q
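+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# the Gumbel quantizer turns encoder features into codebook vectors via a
+# (hard) Gumbel-Softmax over per-position codebook logits.
+def _gumbel_quantize_sketch():
+    quantizer = GumbelQuantize(num_hiddens=32, embedding_dim=4, n_embed=512)
+    h = torch.randn(2, 32, 8, 8)               # encoder features (b, num_hiddens, h, w)
+    z_q, kl_loss, (_, _, ind) = quantizer(h)   # z_q: (2, 4, 8, 8), ind: (2, 8, 8)
+    return z_q, kl_loss, ind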
+
+
+class VectorQuantizer2(nn.Module):
+ """
+ Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
+ avoids costly matrix multiplications and allows for post-hoc remapping of indices.
+ """
+ # NOTE: due to a bug the beta term was applied to the wrong term. for
+ # backwards compatibility we use the buggy version by default, but you can
+ # specify legacy=False to fix it.
+ def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
+ sane_index_shape=False, legacy=True):
+ super().__init__()
+ self.n_e = n_e
+ self.e_dim = e_dim
+ self.beta = beta
+ self.legacy = legacy
+
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_e
+
+ self.sane_index_shape = sane_index_shape
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
+ assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
+ assert rescale_logits==False, "Only for interface compatible with Gumbel"
+ assert return_logits==False, "Only for interface compatible with Gumbel"
+ # reshape z -> (batch, height, width, channel) and flatten
+ z = rearrange(z, 'b c h w -> b h w c').contiguous()
+ z_flattened = z.view(-1, self.e_dim)
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
+
+ min_encoding_indices = torch.argmin(d, dim=1)
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
+ perplexity = None
+ min_encodings = None
+
+ # compute loss for embedding
+ if not self.legacy:
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
+ torch.mean((z_q - z.detach()) ** 2)
+ else:
+ loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
+ torch.mean((z_q - z.detach()) ** 2)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
+
+ if self.remap is not None:
+ min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
+ min_encoding_indices = self.remap_to_used(min_encoding_indices)
+ min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
+
+ if self.sane_index_shape:
+ min_encoding_indices = min_encoding_indices.reshape(
+ z_q.shape[0], z_q.shape[2], z_q.shape[3])
+
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
+
+ def get_codebook_entry(self, indices, shape):
+ # shape specifying (batch, height, width, channel)
+ if self.remap is not None:
+ indices = indices.reshape(shape[0],-1) # add batch axis
+ indices = self.unmap_to_all(indices)
+ indices = indices.reshape(-1) # flatten again
+
+ # get quantized latent vectors
+ z_q = self.embedding(indices)
+
+ if shape is not None:
+ z_q = z_q.view(shape)
+ # reshape back to match original input shape
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
+
+ return z_q
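+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# quantize a latent with VectorQuantizer2 and rebuild it from the returned
+# indices.
+def _vq2_usage_sketch():
+    quantizer = VectorQuantizer2(n_e=512, e_dim=4, beta=0.25, sane_index_shape=True)
+    z = torch.randn(2, 4, 8, 8)                # encoder output (b, c, h, w)
+    z_q, loss, (_, _, idx) = quantizer(z)      # idx: (2, 8, 8) codebook indices
+    z_rec = quantizer.get_codebook_entry(idx.reshape(-1), shape=(2, 8, 8, 4))
+    assert torch.allclose(z_q, z_rec, atol=1e-5)
+    return z_q, loss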
+
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5):
+ super().__init__()
+ self.decay = decay
+ self.eps = eps
+ weight = torch.randn(num_tokens, codebook_dim)
+ self.weight = nn.Parameter(weight, requires_grad = False)
+ self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False)
+ self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False)
+ self.update = True
+
+ def forward(self, embed_id):
+ return F.embedding(embed_id, self.weight)
+
+ def cluster_size_ema_update(self, new_cluster_size):
+ self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay)
+
+ def embed_avg_ema_update(self, new_embed_avg):
+ self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay)
+
+ def weight_update(self, num_tokens):
+ n = self.cluster_size.sum()
+ smoothed_cluster_size = (
+ (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n
+ )
+ #normalize embedding average with smoothed cluster size
+ embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1)
+ self.weight.data.copy_(embed_normalized)
+
+
+class EMAVectorQuantizer(nn.Module):
+ def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5,
+ remap=None, unknown_index="random"):
+ super().__init__()
+        self.codebook_dim = embedding_dim
+        self.num_tokens = n_embed
+        self.n_embed = n_embed  # kept for the remap log message below
+ self.beta = beta
+ self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps)
+
+ self.remap = remap
+ if self.remap is not None:
+ self.register_buffer("used", torch.tensor(np.load(self.remap)))
+ self.re_embed = self.used.shape[0]
+ self.unknown_index = unknown_index # "random" or "extra" or integer
+ if self.unknown_index == "extra":
+ self.unknown_index = self.re_embed
+ self.re_embed = self.re_embed+1
+ print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. "
+ f"Using {self.unknown_index} for unknown indices.")
+ else:
+ self.re_embed = n_embed
+
+ def remap_to_used(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ match = (inds[:,:,None]==used[None,None,...]).long()
+ new = match.argmax(-1)
+ unknown = match.sum(2)<1
+ if self.unknown_index == "random":
+ new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
+ else:
+ new[unknown] = self.unknown_index
+ return new.reshape(ishape)
+
+ def unmap_to_all(self, inds):
+ ishape = inds.shape
+ assert len(ishape)>1
+ inds = inds.reshape(ishape[0],-1)
+ used = self.used.to(inds)
+ if self.re_embed > self.used.shape[0]: # extra token
+ inds[inds>=self.used.shape[0]] = 0 # simply set to zero
+ back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
+ return back.reshape(ishape)
+
+ def forward(self, z):
+ # reshape z -> (batch, height, width, channel) and flatten
+ #z, 'b c h w -> b h w c'
+ z = rearrange(z, 'b c h w -> b h w c')
+ z_flattened = z.reshape(-1, self.codebook_dim)
+
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
+ d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \
+ self.embedding.weight.pow(2).sum(dim=1) - 2 * \
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n'
+
+
+ encoding_indices = torch.argmin(d, dim=1)
+
+ z_q = self.embedding(encoding_indices).view(z.shape)
+ encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype)
+ avg_probs = torch.mean(encodings, dim=0)
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ if self.training and self.embedding.update:
+ #EMA cluster size
+ encodings_sum = encodings.sum(0)
+ self.embedding.cluster_size_ema_update(encodings_sum)
+ #EMA embedding average
+ embed_sum = encodings.transpose(0,1) @ z_flattened
+ self.embedding.embed_avg_ema_update(embed_sum)
+ #normalize embed_avg and update weight
+ self.embedding.weight_update(self.num_tokens)
+
+ # compute loss for embedding
+ loss = self.beta * F.mse_loss(z_q.detach(), z)
+
+ # preserve gradients
+ z_q = z + (z_q - z).detach()
+
+ # reshape back to match original input shape
+ #z_q, 'b h w c -> b c h w'
+ z_q = rearrange(z_q, 'b h w c -> b c h w')
+ return z_q, loss, (perplexity, encodings, encoding_indices)
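+
+
+# Illustrative usage sketch (not from the upstream taming-transformers code):
+# in training mode the EMA quantizer updates its codebook from batch statistics
+# rather than gradients, so only the commitment loss flows back to the encoder.
+def _ema_vq_usage_sketch():
+    quantizer = EMAVectorQuantizer(n_embed=512, embedding_dim=4, beta=0.25)
+    quantizer.train()
+    z = torch.randn(2, 4, 8, 8)
+    z_q, commit_loss, (perplexity, _, idx) = quantizer(z)  # codebook updated in-place via EMA
+    return z_q, commit_loss, perplexity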
diff --git a/repositories/taming/util.py b/repositories/taming/util.py
new file mode 100644
index 000000000..06053e5de
--- /dev/null
+++ b/repositories/taming/util.py
@@ -0,0 +1,157 @@
+import os, hashlib
+import requests
+from tqdm import tqdm
+
+URL_MAP = {
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
+}
+
+CKPT_MAP = {
+ "vgg_lpips": "vgg.pth"
+}
+
+MD5_MAP = {
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
+}
+
+
+def download(url, local_path, chunk_size=1024):
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
+ with requests.get(url, stream=True) as r:
+ total_size = int(r.headers.get("content-length", 0))
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
+ with open(local_path, "wb") as f:
+ for data in r.iter_content(chunk_size=chunk_size):
+ if data:
+ f.write(data)
+                        pbar.update(len(data))
+
+
+def md5_hash(path):
+ with open(path, "rb") as f:
+ content = f.read()
+ return hashlib.md5(content).hexdigest()
+
+
+def get_ckpt_path(name, root, check=False):
+ assert name in URL_MAP
+ path = os.path.join(root, CKPT_MAP[name])
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
+ download(URL_MAP[name], path)
+ md5 = md5_hash(path)
+ assert md5 == MD5_MAP[name], md5
+ return path
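+
+
+# Illustrative note (not from the upstream taming-transformers code): typical
+# use downloads the LPIPS VGG weights once into a local cache and verifies them
+# against MD5_MAP,
+#   ckpt = get_ckpt_path("vgg_lpips", root="cache/lpips", check=True)
+# where "cache/lpips" is an arbitrary choice of target directory.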
+
+
+class KeyNotFoundError(Exception):
+ def __init__(self, cause, keys=None, visited=None):
+ self.cause = cause
+ self.keys = keys
+ self.visited = visited
+ messages = list()
+ if keys is not None:
+ messages.append("Key not found: {}".format(keys))
+ if visited is not None:
+ messages.append("Visited: {}".format(visited))
+ messages.append("Cause:\n{}".format(cause))
+ message = "\n".join(messages)
+ super().__init__(message)
+
+
+def retrieve(
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
+):
+ """Given a nested list or dict return the desired value at key expanding
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
+ is done in-place.
+
+ Parameters
+ ----------
+ list_or_dict : list or dict
+ Possibly nested list or dictionary.
+ key : str
+ key/to/value, path like string describing all keys necessary to
+ consider to get to the desired value. List indices can also be
+ passed here.
+ splitval : str
+ String that defines the delimiter between keys of the
+ different depth levels in `key`.
+ default : obj
+ Value returned if :attr:`key` is not found.
+ expand : bool
+ Whether to expand callable nodes on the path or not.
+
+ Returns
+ -------
+ The desired value or if :attr:`default` is not ``None`` and the
+ :attr:`key` is not found returns ``default``.
+
+ Raises
+ ------
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
+ ``None``.
+ """
+
+ keys = key.split(splitval)
+
+ success = True
+ try:
+ visited = []
+ parent = None
+ last_key = None
+ for key in keys:
+ if callable(list_or_dict):
+ if not expand:
+ raise KeyNotFoundError(
+ ValueError(
+ "Trying to get past callable node with expand=False."
+ ),
+ keys=keys,
+ visited=visited,
+ )
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+
+ last_key = key
+ parent = list_or_dict
+
+ try:
+ if isinstance(list_or_dict, dict):
+ list_or_dict = list_or_dict[key]
+ else:
+ list_or_dict = list_or_dict[int(key)]
+ except (KeyError, IndexError, ValueError) as e:
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
+
+ visited += [key]
+ # final expansion of retrieved value
+ if expand and callable(list_or_dict):
+ list_or_dict = list_or_dict()
+ parent[last_key] = list_or_dict
+ except KeyNotFoundError as e:
+ if default is None:
+ raise e
+ else:
+ list_or_dict = default
+ success = False
+
+ if not pass_success:
+ return list_or_dict
+ else:
+ return list_or_dict, success
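+
+
+# Illustrative note (not from the upstream taming-transformers code):
+# `retrieve` walks path-like keys through nested containers and can fall back
+# to a default, e.g.
+#   retrieve({"a": {"b": [10, 20]}}, "a/b/1")                 -> 20
+#   retrieve({"a": {"b": [10, 20]}}, "a/missing", default=0)  -> 0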
+
+
+if __name__ == "__main__":
+ config = {"keya": "a",
+ "keyb": "b",
+ "keyc":
+ {"cc1": 1,
+ "cc2": 2,
+ }
+ }
+ from omegaconf import OmegaConf
+ config = OmegaConf.create(config)
+ print(config)
+ retrieve(config, "keya")
+