From a83ca6502737165fdfdf8e58287f67618cf5666a Mon Sep 17 00:00:00 2001
From: David Marx
Date: Thu, 26 May 2022 22:26:27 -0700
Subject: [PATCH 1/8] cannibalized dango cutouts

---
 src/pytti/Perceptor/cutouts/dango.py | 98 ++++++++++++++++++++++++++++
 1 file changed, 98 insertions(+)
 create mode 100644 src/pytti/Perceptor/cutouts/dango.py

diff --git a/src/pytti/Perceptor/cutouts/dango.py b/src/pytti/Perceptor/cutouts/dango.py
new file mode 100644
index 0000000..c2501bd
--- /dev/null
+++ b/src/pytti/Perceptor/cutouts/dango.py
@@ -0,0 +1,98 @@
+# via https://github.com/multimodalart/majesty-diffusion/blob/main/latent.ipynb
+
+
+class MakeCutouts(nn.Module):
+    def __init__(
+        self,
+        cut_size,
+        Overview=4,
+        WholeCrop=0,
+        WC_Allowance=10,
+        WC_Grey_P=0.2,
+        InnerCrop=0,
+        IC_Size_Pow=0.5,
+        IC_Grey_P=0.2,
+    ):
+        super().__init__()
+        self.cut_size = cut_size
+        self.Overview = Overview
+        self.WholeCrop = WholeCrop
+        self.WC_Allowance = WC_Allowance
+        self.WC_Grey_P = WC_Grey_P
+        self.InnerCrop = InnerCrop
+        self.IC_Size_Pow = IC_Size_Pow
+        self.IC_Grey_P = IC_Grey_P
+        self.augs = T.Compose(
+            [
+                # T.RandomHorizontalFlip(p=0.5),
+                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+                T.RandomAffine(
+                    degrees=0,
+                    translate=(0.05, 0.05),
+                    # scale=(0.9,0.95),
+                    fill=-1,
+                    interpolation=T.InterpolationMode.BILINEAR,
+                ),
+                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+                # T.RandomPerspective(p=1, interpolation = T.InterpolationMode.BILINEAR, fill=-1,distortion_scale=0.2),
+                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+                T.RandomGrayscale(p=0.1),
+                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+                T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05),
+            ]
+        )
+
+    def forward(self, input):
+        gray = transforms.Grayscale(3)
+        sideY, sideX = input.shape[2:4]
+        max_size = min(sideX, sideY)
+        min_size = min(sideX, sideY, self.cut_size)
+        l_size = max(sideX, sideY)
+        output_shape = [input.shape[0], 3, self.cut_size, self.cut_size]
+        output_shape_2 = [input.shape[0], 3, self.cut_size + 2, self.cut_size + 2]
+        pad_input = F.pad(
+            input,
+            (
+                (sideY - max_size) // 2 + round(max_size * 0.055),
+                (sideY - max_size) // 2 + round(max_size * 0.055),
+                (sideX - max_size) // 2 + round(max_size * 0.055),
+                (sideX - max_size) // 2 + round(max_size * 0.055),
+            ),
+            **padargs
+        )
+        cutouts_list = []
+
+        if self.Overview > 0:
+            cutouts = []
+            cutout = resize(pad_input, out_shape=output_shape, antialiasing=True)
+            output_shape_all = list(output_shape)
+            output_shape_all[0] = self.Overview * input.shape[0]
+            pad_input = pad_input.repeat(input.shape[0], 1, 1, 1)
+            cutout = resize(pad_input, out_shape=output_shape_all)
+            if aug:
+                cutout = self.augs(cutout)
+            cutouts_list.append(cutout)
+
+        if self.InnerCrop > 0:
+            cutouts = []
+            for i in range(self.InnerCrop):
+                size = int(
+                    torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size)
+                    + min_size
+                )
+                offsetx = torch.randint(0, sideX - size + 1, ())
+                offsety = torch.randint(0, sideY - size + 1, ())
+                cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
+                if i <= int(self.IC_Grey_P * self.InnerCrop):
+                    cutout = gray(cutout)
+                cutout = resize(cutout, out_shape=output_shape)
+                cutouts.append(cutout)
+            if cutout_debug:
+                TF.to_pil_image(cutouts[-1].add(1).div(2).clamp(0, 1).squeeze(0)).save(
+                    "content/diff/cutouts/cutout_InnerCrop.jpg", quality=99
+                )
+            cutouts_tensor = torch.cat(cutouts)
+            cutouts = []
+            cutouts_list.append(cutouts_tensor)
+        cutouts = torch.cat(cutouts_list)
+        return cutouts
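The class above leans on names this first patch does not define (`nn`, `T`, `F`, `resize`, `padargs`, `aug`, `cutout_debug`); the rest of the series wires those up. A minimal sketch of how the module might be exercised once the later patches land; the 224 cutout size and the roughly [-1, 1] input range are assumptions for illustration, not anything this patch pins down:

    import torch
    from pytti.Perceptor.cutouts.dango import MakeCutouts

    # 224 matches the usual CLIP visual input resolution (assumption).
    make_cutouts = MakeCutouts(cut_size=224, Overview=4, InnerCrop=2)

    # One RGB image, scaled to roughly [-1, 1] (assumption, consistent with the
    # .add(1).div(2) denormalization in the debug branch above).
    image = torch.rand(1, 3, 512, 384) * 2 - 1

    cutouts = make_cutouts(image)
    print(cutouts.shape)  # e.g. 4 overview + 2 inner-crop cutouts, each 3x224x224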
From d29239d62b7db79a5caaeb40c19841a16012055b Mon Sep 17 00:00:00 2001
From: David Marx
Date: Thu, 26 May 2022 22:33:22 -0700
Subject: [PATCH 2/8] dango imports

---
 src/pytti/Perceptor/cutouts/dango.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/pytti/Perceptor/cutouts/dango.py b/src/pytti/Perceptor/cutouts/dango.py
index c2501bd..26f7c9d 100644
--- a/src/pytti/Perceptor/cutouts/dango.py
+++ b/src/pytti/Perceptor/cutouts/dango.py
@@ -1,6 +1,14 @@
 # via https://github.com/multimodalart/majesty-diffusion/blob/main/latent.ipynb
 
 
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from torchvision import transforms as T
+from torchvision.transforms import functional as TF
+
+
 class MakeCutouts(nn.Module):
     def __init__(
         self,

From 6fdcc6d498a85bc91fd2f9b060c0d8f9ea82cee6 Mon Sep 17 00:00:00 2001
From: David Marx
Date: Thu, 26 May 2022 22:51:12 -0700
Subject: [PATCH 3/8] dango passes acceptance tests

---
 src/pytti/Perceptor/cutouts/__init__.py |  1 +
 src/pytti/Perceptor/cutouts/dango.py    | 24 +++++-------------------
 2 files changed, 6 insertions(+), 19 deletions(-)

diff --git a/src/pytti/Perceptor/cutouts/__init__.py b/src/pytti/Perceptor/cutouts/__init__.py
index e69de29..715e8db 100644
--- a/src/pytti/Perceptor/cutouts/__init__.py
+++ b/src/pytti/Perceptor/cutouts/__init__.py
@@ -0,0 +1 @@
+from .dango import MakeCutouts
diff --git a/src/pytti/Perceptor/cutouts/dango.py b/src/pytti/Perceptor/cutouts/dango.py
index 26f7c9d..c4f0f44 100644
--- a/src/pytti/Perceptor/cutouts/dango.py
+++ b/src/pytti/Perceptor/cutouts/dango.py
@@ -8,6 +8,10 @@
 from torchvision import transforms as T
 from torchvision.transforms import functional as TF
 
+from . import augs as cutouts_augs
+
+padargs = {"mode": "constant", "value": -1}
+
 
 class MakeCutouts(nn.Module):
     def __init__(
         self,
@@ -30,25 +34,7 @@ def __init__(
         self.InnerCrop = InnerCrop
         self.IC_Size_Pow = IC_Size_Pow
         self.IC_Grey_P = IC_Grey_P
-        self.augs = T.Compose(
-            [
-                # T.RandomHorizontalFlip(p=0.5),
-                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
-                T.RandomAffine(
-                    degrees=0,
-                    translate=(0.05, 0.05),
-                    # scale=(0.9,0.95),
-                    fill=-1,
-                    interpolation=T.InterpolationMode.BILINEAR,
-                ),
-                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
-                # T.RandomPerspective(p=1, interpolation = T.InterpolationMode.BILINEAR, fill=-1,distortion_scale=0.2),
-                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
-                T.RandomGrayscale(p=0.1),
-                T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
-                T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05),
-            ]
-        )
+        self.augs = cutouts_augs.dango
 
     def forward(self, input):
         gray = transforms.Grayscale(3)
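Patch 3 hoists the pad settings into a module-level `padargs = {"mode": "constant", "value": -1}`. The -1 fill lines up with `fill=-1` in the affine aug and with the `.add(1).div(2)` in the debug save, which suggests the pipeline works on images scaled to roughly [-1, 1], so constant -1 padding amounts to a black border. A tiny sketch of what the `F.pad(..., **padargs)` call in `forward` does, with made-up sizes:

    import torch
    import torch.nn.functional as F

    padargs = {"mode": "constant", "value": -1}

    x = torch.zeros(1, 3, 100, 80)  # N, C, H, W
    # the pad tuple is (left, right, top, bottom) over the last two dims
    padded = F.pad(x, (10, 10, 5, 5), **padargs)
    print(padded.shape)  # torch.Size([1, 3, 110, 100])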
From 28fa33a622071852f1a6576bcbac9b6df32d73e7 Mon Sep 17 00:00:00 2001
From: David Marx
Date: Thu, 26 May 2022 23:51:46 -0700
Subject: [PATCH 4/8] dango passes preliminary integration tests

---
 src/pytti/Perceptor/cutouts/__init__.py |  2 ++
 src/pytti/Perceptor/cutouts/augs.py     | 24 ++++++++++++++++++++++++
 src/pytti/Perceptor/cutouts/dango.py    | 13 +++++++++++--
 3 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/src/pytti/Perceptor/cutouts/__init__.py b/src/pytti/Perceptor/cutouts/__init__.py
index 715e8db..94b5630 100644
--- a/src/pytti/Perceptor/cutouts/__init__.py
+++ b/src/pytti/Perceptor/cutouts/__init__.py
@@ -1 +1,3 @@
 from .dango import MakeCutouts
+
+test = MakeCutouts(1)
diff --git a/src/pytti/Perceptor/cutouts/augs.py b/src/pytti/Perceptor/cutouts/augs.py
index 32873dd..101b675 100644
--- a/src/pytti/Perceptor/cutouts/augs.py
+++ b/src/pytti/Perceptor/cutouts/augs.py
@@ -1,5 +1,7 @@
 import kornia.augmentation as K
+import torch
 from torch import nn
+from torchvision import transforms as T
 
 
 def pytti_classic():
@@ -16,3 +18,25 @@ def pytti_classic():
         ),
         nn.Identity(),
     )
+
+
+def dango():
+    return T.Compose(
+        [
+            # T.RandomHorizontalFlip(p=0.5),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            T.RandomAffine(
+                degrees=0,
+                translate=(0.05, 0.05),
+                # scale=(0.9,0.95),
+                fill=-1,
+                interpolation=T.InterpolationMode.BILINEAR,
+            ),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            # T.RandomPerspective(p=1, interpolation = T.InterpolationMode.BILINEAR, fill=-1,distortion_scale=0.2),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            T.RandomGrayscale(p=0.1),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05),
+        ]
+    )
diff --git a/src/pytti/Perceptor/cutouts/dango.py b/src/pytti/Perceptor/cutouts/dango.py
index c4f0f44..3ef0392 100644
--- a/src/pytti/Perceptor/cutouts/dango.py
+++ b/src/pytti/Perceptor/cutouts/dango.py
@@ -1,5 +1,8 @@
 # via https://github.com/multimodalart/majesty-diffusion/blob/main/latent.ipynb
 
+# !pip install resize-right
+# TO DO: add resize-right to setup instructions and notebook
+from resize_right import resize
 
 import torch
 from torch import nn
@@ -24,6 +27,8 @@ def __init__(
         InnerCrop=0,
         IC_Size_Pow=0.5,
         IC_Grey_P=0.2,
+        aug=True,
+        cutout_debug=False,
     ):
         super().__init__()
         self.cut_size = cut_size
@@ -35,9 +40,13 @@ def __init__(
         self.InnerCrop = InnerCrop
         self.IC_Size_Pow = IC_Size_Pow
         self.IC_Grey_P = IC_Grey_P
         self.augs = cutouts_augs.dango
+        self._aug = aug
+        self.cutout_debug = cutout_debug
 
     def forward(self, input):
-        gray = transforms.Grayscale(3)
+        gray = transforms.Grayscale(
+            3
+        )  # this is possibly a performance improvement? 1 channel instead of 3. but also means we can't use color augs...
         sideY, sideX = input.shape[2:4]
         max_size = min(sideX, sideY)
         min_size = min(sideX, sideY, self.cut_size)
@@ -63,7 +72,7 @@ def forward(self, input):
             output_shape_all[0] = self.Overview * input.shape[0]
             pad_input = pad_input.repeat(input.shape[0], 1, 1, 1)
             cutout = resize(pad_input, out_shape=output_shape_all)
-            if aug:
+            if self._aug:
                 cutout = self.augs(cutout)
             cutouts_list.append(cutout)
 
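Patch 4 turns the augmentation stack into a factory, `augs.dango()`, which returns a `torchvision.transforms.Compose`. Note that patch 3 assigned `cutouts_augs.dango` (the factory itself) to `self.augs`, so presumably the call site is meant to hold `dango()` before `self.augs(cutout)` will work; the sketch below builds the pipeline explicitly and applies it to a batch of cutouts. Batched tensor input and the roughly [-1, 1] value range are assumptions that mirror how `self.augs(cutout)` is used above, and they require a torchvision recent enough to accept batched tensor images:

    import torch
    from pytti.Perceptor.cutouts.augs import dango

    augment = dango()  # build the Compose once
    batch = torch.rand(6, 3, 224, 224) * 2 - 1  # six cutouts, roughly [-1, 1]
    augmented = augment(batch)
    print(augmented.shape)  # same shape; noised, jittered, occasionally grayscaled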
From 3f88e5b0780ab88731e9414d244faac62253aa78 Mon Sep 17 00:00:00 2001
From: David Marx
Date: Thu, 26 May 2022 23:54:09 -0700
Subject: [PATCH 5/8] fixed integration bug

---
 src/pytti/Perceptor/cutouts/dango.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytti/Perceptor/cutouts/dango.py b/src/pytti/Perceptor/cutouts/dango.py
index 3ef0392..ae94077 100644
--- a/src/pytti/Perceptor/cutouts/dango.py
+++ b/src/pytti/Perceptor/cutouts/dango.py
@@ -90,7 +90,7 @@ def forward(self, input):
                     cutout = gray(cutout)
                 cutout = resize(cutout, out_shape=output_shape)
                 cutouts.append(cutout)
-            if cutout_debug:
+            if self.cutout_debug:
                 TF.to_pil_image(cutouts[-1].add(1).div(2).clamp(0, 1).squeeze(0)).save(
                     "content/diff/cutouts/cutout_InnerCrop.jpg", quality=99
                 )
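The `resize` calls throughout `forward` come from the resize-right package that patch 4 pulls in (and flags as still missing from the setup instructions and notebook). A small sketch of that call pattern in isolation; the sizes are placeholders, not values taken from the code:

    # pip install resize-right
    import torch
    from resize_right import resize

    x = torch.rand(1, 3, 512, 384)

    # Same pattern as the Overview branch: squash the (padded) frame down to a
    # fixed cutout shape; antialiasing=True is what the first resize above passes.
    cutout = resize(x, out_shape=[1, 3, 224, 224], antialiasing=True)
    print(cutout.shape)  # torch.Size([1, 3, 224, 224])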
+ """ + # logger.debug(f"Formatting {tensor.shape} from {source} to {dest}") + # logger.debug(f"source.output_axes: {source.output_axes}") + # logger.debug(f"dest.input_axes: {dest.input_axes}") + einstein_notation = f"{' '.join(source.output_axes)} -> {' '.join(dest.input_axes)}" + # logger.debug(einstein_notation) + # return named_rearrange(tensor, source.output_axes, dest.input_axes) + return rearrange(tensor, einstein_notation) def pad_tensor(tensor, target_len) -> torch.tensor: @@ -59,6 +81,14 @@ def cat_with_pad(tensors): def format_module(module, dest, *args, **kwargs) -> torch.tensor: + """ + Takes a module, a destination, and any number of arguments and keyword arguments, and returns the + output of the module, formatted for the destination + + :param module: the module to be formatted + :param dest: the destination of the output. This is a tuple of the form (module, index) + :return: The output of the module, formatted for the destination. + """ output = module(*args, **kwargs) if isinstance(output, tuple): output = output[0] From b3352efa6d82ed60b702004c8b893f93c351aee6 Mon Sep 17 00:00:00 2001 From: David Marx Date: Fri, 27 May 2022 01:08:10 -0700 Subject: [PATCH 7/8] deprecated format_input: redundant function --- src/pytti/ImageGuide.py | 16 ++++++++-------- src/pytti/Perceptor/Embedder.py | 4 ++-- src/pytti/Perceptor/Prompt.py | 14 +++++++------- src/pytti/__init__.py | 2 -- src/pytti/tensor_tools.py | 8 ++------ 5 files changed, 19 insertions(+), 25 deletions(-) diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py index 36122ed..bbc8f4a 100644 --- a/src/pytti/ImageGuide.py +++ b/src/pytti/ImageGuide.py @@ -13,7 +13,7 @@ from collections import Counter from pytti import ( - format_input, + # named_rearrange, set_t, print_vram_usage, freeze_vram_usage, @@ -279,7 +279,7 @@ def train( losses = [] aug_losses = { - aug: aug(format_input(z, self.image_rep, aug), self.image_rep) + aug: aug(named_rearrange(z, self.image_rep, aug), self.image_rep) for aug in loss_augs } @@ -300,9 +300,9 @@ def train( t = i / interp_steps interp_losses = [ prompt( - format_input(image_embeds, self.embedder, prompt), - format_input(offsets, self.embedder, prompt), - format_input(sizes, self.embedder, prompt), + named_rearrange(image_embeds, self.embedder, prompt), + named_rearrange(offsets, self.embedder, prompt), + named_rearrange(sizes, self.embedder, prompt), )[0] * (1 - t) for prompt in interp_prompts @@ -310,9 +310,9 @@ def train( prompt_losses = { prompt: prompt( - format_input(image_embeds, self.embedder, prompt), - format_input(offsets, self.embedder, prompt), - format_input(sizes, self.embedder, prompt), + named_rearrange(image_embeds, self.embedder, prompt), + named_rearrange(offsets, self.embedder, prompt), + named_rearrange(sizes, self.embedder, prompt), ) for prompt in prompts } diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py index b44b91a..a71b238 100644 --- a/src/pytti/Perceptor/Embedder.py +++ b/src/pytti/Perceptor/Embedder.py @@ -1,7 +1,7 @@ from typing import Tuple import pytti -from pytti import DEVICE, format_input, cat_with_pad, format_module, normalize +from pytti import DEVICE, named_rearrange, cat_with_pad, format_module, normalize # from pytti.ImageGuide import DirectImageGuide from pytti.image_models import DifferentiableImage @@ -106,7 +106,7 @@ def forward( device=device, memory_format=torch.channels_last ) else: - input = format_input(input, diff_image, self).to( + input = named_rearrange(input, diff_image, self).to( device=device, 
From b3352efa6d82ed60b702004c8b893f93c351aee6 Mon Sep 17 00:00:00 2001
From: David Marx
Date: Fri, 27 May 2022 01:08:10 -0700
Subject: [PATCH 7/8] deprecated format_input: redundant function

---
 src/pytti/ImageGuide.py         | 16 ++++++++--------
 src/pytti/Perceptor/Embedder.py |  4 ++--
 src/pytti/Perceptor/Prompt.py   | 14 +++++++-------
 src/pytti/__init__.py           |  2 --
 src/pytti/tensor_tools.py       |  8 ++------
 5 files changed, 19 insertions(+), 25 deletions(-)

diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py
index 36122ed..bbc8f4a 100644
--- a/src/pytti/ImageGuide.py
+++ b/src/pytti/ImageGuide.py
@@ -13,7 +13,7 @@
 from collections import Counter
 
 from pytti import (
-    format_input,
+    # named_rearrange,
     set_t,
     print_vram_usage,
     freeze_vram_usage,
@@ -279,7 +279,7 @@ def train(
         losses = []
 
         aug_losses = {
-            aug: aug(format_input(z, self.image_rep, aug), self.image_rep)
+            aug: aug(named_rearrange(z, self.image_rep, aug), self.image_rep)
             for aug in loss_augs
         }
 
@@ -300,9 +300,9 @@ def train(
             t = i / interp_steps
             interp_losses = [
                 prompt(
-                    format_input(image_embeds, self.embedder, prompt),
-                    format_input(offsets, self.embedder, prompt),
-                    format_input(sizes, self.embedder, prompt),
+                    named_rearrange(image_embeds, self.embedder, prompt),
+                    named_rearrange(offsets, self.embedder, prompt),
+                    named_rearrange(sizes, self.embedder, prompt),
                 )[0]
                 * (1 - t)
                 for prompt in interp_prompts
@@ -310,9 +310,9 @@ def train(
 
         prompt_losses = {
             prompt: prompt(
-                format_input(image_embeds, self.embedder, prompt),
-                format_input(offsets, self.embedder, prompt),
-                format_input(sizes, self.embedder, prompt),
+                named_rearrange(image_embeds, self.embedder, prompt),
+                named_rearrange(offsets, self.embedder, prompt),
+                named_rearrange(sizes, self.embedder, prompt),
             )
             for prompt in prompts
         }
diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py
index b44b91a..a71b238 100644
--- a/src/pytti/Perceptor/Embedder.py
+++ b/src/pytti/Perceptor/Embedder.py
@@ -1,7 +1,7 @@
 from typing import Tuple
 
 import pytti
-from pytti import DEVICE, format_input, cat_with_pad, format_module, normalize
+from pytti import DEVICE, named_rearrange, cat_with_pad, format_module, normalize
 
 # from pytti.ImageGuide import DirectImageGuide
 from pytti.image_models import DifferentiableImage
@@ -106,7 +106,7 @@ def forward(
                 device=device, memory_format=torch.channels_last
             )
         else:
-            input = format_input(input, diff_image, self).to(
+            input = named_rearrange(input, diff_image, self).to(
                 device=device, memory_format=torch.channels_last
             )
         max_size = min(side_x, side_y)
diff --git a/src/pytti/Perceptor/Prompt.py b/src/pytti/Perceptor/Prompt.py
index f8ff836..f529f5d 100644
--- a/src/pytti/Perceptor/Prompt.py
+++ b/src/pytti/Perceptor/Prompt.py
@@ -18,7 +18,7 @@
 import pytti
 from pytti import (
     DEVICE,
-    format_input,
+    named_rearrange,
     cat_with_pad,
     replace_grad,
     fetch,
@@ -310,7 +310,7 @@ def __init__(
     ):
         self.input_axes = ("c", "n", "i")
         super().__init__(
-            format_input(embeds, embedder, self),
+            named_rearrange(embeds, embedder, self),
             weight,
             stop,
             text + " (semantic)",
@@ -318,8 +318,8 @@ def __init__(
             mask=mask,
         )
         self.input_axes = ("c", "n", "i")
-        self.register_buffer("positions", format_input(positions, embedder, self))
-        self.register_buffer("sizes", format_input(sizes, embedder, self))
+        self.register_buffer("positions", named_rearrange(positions, embedder, self))
+        self.register_buffer("sizes", named_rearrange(sizes, embedder, self))
 
     @torch.no_grad()
     @vram_usage_mode("Image Prompts")
@@ -335,9 +335,9 @@ def set_image(self, embedder, pil_image):
         img.encode_image(pil_image)
         embeds, positions, sizes = embedder(img)
         embeds = embeds.clone()
-        self.positions.set_(format_input(positions, embedder, self))
-        self.sizes.set_(format_input(sizes, embedder, self))
-        self.embeds.set_(format_input(embeds, embedder, self))
+        self.positions.set_(named_rearrange(positions, embedder, self))
+        self.sizes.set_(named_rearrange(sizes, embedder, self))
+        self.embeds.set_(named_rearrange(embeds, embedder, self))
 
 
 def minimize_average_distance(tensor_a, tensor_b, device=DEVICE):
diff --git a/src/pytti/__init__.py b/src/pytti/__init__.py
index 4efc1c5..a59f33e 100644
--- a/src/pytti/__init__.py
+++ b/src/pytti/__init__.py
@@ -13,7 +13,6 @@
 
 from pytti.tensor_tools import (
     named_rearrange,
-    format_input,
     pad_tensor,
     cat_with_pad,
     format_module,
@@ -29,7 +28,6 @@
 __all__ = [
     "DEVICE",
     "named_rearrange",
-    "format_input",
     "pad_tensor",
     "cat_with_pad",
     "format_module",
diff --git a/src/pytti/tensor_tools.py b/src/pytti/tensor_tools.py
index e3764ba..0b90cf9 100644
--- a/src/pytti/tensor_tools.py
+++ b/src/pytti/tensor_tools.py
@@ -46,11 +46,7 @@ def named_rearrange__OLD(tensor, axes, new_positions) -> torch.tensor:
     return tensor.permute(*permutation)
 
 
-def named_rearrange(tensor, axes, new_positions) -> torch.tensor:
-    return format_input(tensor, source, dest)
-
-
-def format_input(tensor, source, dest) -> torch.tensor:
+def named_rearrange(tensor, source, dest) -> torch.tensor:
     """
     Takes a tensor and two layers, and returns the tensor in the format that the second layer expects
 
@@ -92,7 +88,7 @@ def format_module(module, dest, *args, **kwargs) -> torch.tensor:
     output = module(*args, **kwargs)
     if isinstance(output, tuple):
         output = output[0]
-    return format_input(output, module, dest)
+    return named_rearrange(output, module, dest)
 
 
 class ReplaceGrad(torch.autograd.Function):

From 0275705524fae9bfaa5458634fb47f1ce5ace337 Mon Sep 17 00:00:00 2001
From: David Marx
Date: Fri, 27 May 2022 01:10:55 -0700
Subject: [PATCH 8/8] fixed accidentally introduced bug

---
 src/pytti/ImageGuide.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py
index bbc8f4a..7d8c06e 100644
--- a/src/pytti/ImageGuide.py
+++ b/src/pytti/ImageGuide.py
@@ -13,7 +13,7 @@
 from collections import Counter
 
 from pytti import (
-    # named_rearrange,
+    named_rearrange,
     set_t,
     print_vram_usage,
     freeze_vram_usage,