diff --git a/src/pytti/ImageGuide.py b/src/pytti/ImageGuide.py
index 36122ed..7d8c06e 100644
--- a/src/pytti/ImageGuide.py
+++ b/src/pytti/ImageGuide.py
@@ -13,7 +13,7 @@ from collections import Counter
 from pytti import (
-    format_input,
+    named_rearrange,
     set_t,
     print_vram_usage,
     freeze_vram_usage,
@@ -279,7 +279,7 @@ def train(
         losses = []
         aug_losses = {
-            aug: aug(format_input(z, self.image_rep, aug), self.image_rep)
+            aug: aug(named_rearrange(z, self.image_rep, aug), self.image_rep)
             for aug in loss_augs
         }
@@ -300,9 +300,9 @@ def train(
             t = i / interp_steps
             interp_losses = [
                 prompt(
-                    format_input(image_embeds, self.embedder, prompt),
-                    format_input(offsets, self.embedder, prompt),
-                    format_input(sizes, self.embedder, prompt),
+                    named_rearrange(image_embeds, self.embedder, prompt),
+                    named_rearrange(offsets, self.embedder, prompt),
+                    named_rearrange(sizes, self.embedder, prompt),
                 )[0]
                 * (1 - t)
                 for prompt in interp_prompts
@@ -310,9 +310,9 @@ def train(
         prompt_losses = {
             prompt: prompt(
-                format_input(image_embeds, self.embedder, prompt),
-                format_input(offsets, self.embedder, prompt),
-                format_input(sizes, self.embedder, prompt),
+                named_rearrange(image_embeds, self.embedder, prompt),
+                named_rearrange(offsets, self.embedder, prompt),
+                named_rearrange(sizes, self.embedder, prompt),
             )
             for prompt in prompts
         }
diff --git a/src/pytti/Perceptor/Embedder.py b/src/pytti/Perceptor/Embedder.py
index b44b91a..a71b238 100644
--- a/src/pytti/Perceptor/Embedder.py
+++ b/src/pytti/Perceptor/Embedder.py
@@ -1,7 +1,7 @@
 from typing import Tuple
 
 import pytti
-from pytti import DEVICE, format_input, cat_with_pad, format_module, normalize
+from pytti import DEVICE, named_rearrange, cat_with_pad, format_module, normalize
 
 # from pytti.ImageGuide import DirectImageGuide
 from pytti.image_models import DifferentiableImage
@@ -106,7 +106,7 @@ def forward(
                 device=device, memory_format=torch.channels_last
             )
         else:
-            input = format_input(input, diff_image, self).to(
+            input = named_rearrange(input, diff_image, self).to(
                 device=device, memory_format=torch.channels_last
             )
         max_size = min(side_x, side_y)
diff --git a/src/pytti/Perceptor/Prompt.py b/src/pytti/Perceptor/Prompt.py
index f8ff836..f529f5d 100644
--- a/src/pytti/Perceptor/Prompt.py
+++ b/src/pytti/Perceptor/Prompt.py
@@ -18,7 +18,7 @@ import pytti
 from pytti import (
     DEVICE,
-    format_input,
+    named_rearrange,
     cat_with_pad,
     replace_grad,
     fetch,
@@ -310,7 +310,7 @@ def __init__(
     ):
         self.input_axes = ("c", "n", "i")
         super().__init__(
-            format_input(embeds, embedder, self),
+            named_rearrange(embeds, embedder, self),
             weight,
             stop,
             text + " (semantic)",
@@ -318,8 +318,8 @@ def __init__(
             mask=mask,
         )
         self.input_axes = ("c", "n", "i")
-        self.register_buffer("positions", format_input(positions, embedder, self))
-        self.register_buffer("sizes", format_input(sizes, embedder, self))
+        self.register_buffer("positions", named_rearrange(positions, embedder, self))
+        self.register_buffer("sizes", named_rearrange(sizes, embedder, self))
 
     @torch.no_grad()
     @vram_usage_mode("Image Prompts")
@@ -335,9 +335,9 @@ def set_image(self, embedder, pil_image):
         img.encode_image(pil_image)
         embeds, positions, sizes = embedder(img)
         embeds = embeds.clone()
-        self.positions.set_(format_input(positions, embedder, self))
-        self.sizes.set_(format_input(sizes, embedder, self))
-        self.embeds.set_(format_input(embeds, embedder, self))
+        self.positions.set_(named_rearrange(positions, embedder, self))
+        self.sizes.set_(named_rearrange(sizes, embedder, self))
+        self.embeds.set_(named_rearrange(embeds, embedder, self))
 
 
 def minimize_average_distance(tensor_a, tensor_b, device=DEVICE):
diff --git a/src/pytti/Perceptor/cutouts/__init__.py b/src/pytti/Perceptor/cutouts/__init__.py
index e69de29..94b5630 100644
--- a/src/pytti/Perceptor/cutouts/__init__.py
+++ b/src/pytti/Perceptor/cutouts/__init__.py
@@ -0,0 +1,3 @@
+from .dango import MakeCutouts
+
+test = MakeCutouts(1)
diff --git a/src/pytti/Perceptor/cutouts/augs.py b/src/pytti/Perceptor/cutouts/augs.py
index 32873dd..101b675 100644
--- a/src/pytti/Perceptor/cutouts/augs.py
+++ b/src/pytti/Perceptor/cutouts/augs.py
@@ -1,5 +1,7 @@
 import kornia.augmentation as K
+import torch
 from torch import nn
+from torchvision import transforms as T
 
 
 def pytti_classic():
@@ -16,3 +18,25 @@ def pytti_classic():
         ),
         nn.Identity(),
     )
+
+
+def dango():
+    return T.Compose(
+        [
+            # T.RandomHorizontalFlip(p=0.5),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            T.RandomAffine(
+                degrees=0,
+                translate=(0.05, 0.05),
+                # scale=(0.9,0.95),
+                fill=-1,
+                interpolation=T.InterpolationMode.BILINEAR,
+            ),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            # T.RandomPerspective(p=1, interpolation=T.InterpolationMode.BILINEAR, fill=-1, distortion_scale=0.2),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            T.RandomGrayscale(p=0.1),
+            T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
+            T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05),
+        ]
+    )
diff --git a/src/pytti/Perceptor/cutouts/dango.py b/src/pytti/Perceptor/cutouts/dango.py
new file mode 100644
index 0000000..ae94077
--- /dev/null
+++ b/src/pytti/Perceptor/cutouts/dango.py
@@ -0,0 +1,101 @@
+# via https://github.com/multimodalart/majesty-diffusion/blob/main/latent.ipynb
+
+# !pip install resize-right
+# TO DO: add resize-right to setup instructions and notebook
+from resize_right import resize
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torchvision import transforms
+from torchvision import transforms as T
+from torchvision.transforms import functional as TF
+
+from . import augs as cutouts_augs
+
+padargs = {"mode": "constant", "value": -1}
+
+
+class MakeCutouts(nn.Module):
+    def __init__(
+        self,
+        cut_size,
+        Overview=4,
+        WholeCrop=0,
+        WC_Allowance=10,
+        WC_Grey_P=0.2,
+        InnerCrop=0,
+        IC_Size_Pow=0.5,
+        IC_Grey_P=0.2,
+        aug=True,
+        cutout_debug=False,
+    ):
+        super().__init__()
+        self.cut_size = cut_size
+        self.Overview = Overview
+        self.WholeCrop = WholeCrop
+        self.WC_Allowance = WC_Allowance
+        self.WC_Grey_P = WC_Grey_P
+        self.InnerCrop = InnerCrop
+        self.IC_Size_Pow = IC_Size_Pow
+        self.IC_Grey_P = IC_Grey_P
+        self.augs = cutouts_augs.dango()  # NB: call the factory; dango() returns the composed transforms that forward() applies
+        self._aug = aug
+        self.cutout_debug = cutout_debug
+
+    def forward(self, input):
+        # Grayscale(3) keeps 3 channels; collapsing to 1 channel might be a
+        # performance improvement, but would also mean we can't use color augs...
+        gray = transforms.Grayscale(3)
+        sideY, sideX = input.shape[2:4]
+        max_size = min(sideX, sideY)
+        min_size = min(sideX, sideY, self.cut_size)
+        l_size = max(sideX, sideY)
+        output_shape = [input.shape[0], 3, self.cut_size, self.cut_size]
+        output_shape_2 = [input.shape[0], 3, self.cut_size + 2, self.cut_size + 2]
+        pad_input = F.pad(
+            input,
+            (
+                (sideY - max_size) // 2 + round(max_size * 0.055),
+                (sideY - max_size) // 2 + round(max_size * 0.055),
+                (sideX - max_size) // 2 + round(max_size * 0.055),
+                (sideX - max_size) // 2 + round(max_size * 0.055),
+            ),
+            **padargs
+        )
+        cutouts_list = []
+
+        if self.Overview > 0:
+            cutouts = []
+            # NB: this first resize is immediately overwritten below (carried over from the upstream notebook)
+            cutout = resize(pad_input, out_shape=output_shape, antialiasing=True)
+            output_shape_all = list(output_shape)
+            output_shape_all[0] = self.Overview * input.shape[0]
+            # NB: repeats by batch size rather than self.Overview; with batch size 1
+            # the resize below broadcasts the single padded image across the output batch
+            pad_input = pad_input.repeat(input.shape[0], 1, 1, 1)
+            cutout = resize(pad_input, out_shape=output_shape_all)
+            if self._aug:
+                cutout = self.augs(cutout)
+            cutouts_list.append(cutout)
+
+        if self.InnerCrop > 0:
+            cutouts = []
+            for i in range(self.InnerCrop):
+                size = int(
+                    torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size)
+                    + min_size
+                )
+                offsetx = torch.randint(0, sideX - size + 1, ())
+                offsety = torch.randint(0, sideY - size + 1, ())
+                cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
+                if i <= int(self.IC_Grey_P * self.InnerCrop):
+                    cutout = gray(cutout)
+                cutout = resize(cutout, out_shape=output_shape)
+                cutouts.append(cutout)
+            if self.cutout_debug:
+                TF.to_pil_image(cutouts[-1].add(1).div(2).clamp(0, 1).squeeze(0)).save(
+                    "content/diff/cutouts/cutout_InnerCrop.jpg", quality=99
+                )
+            cutouts_tensor = torch.cat(cutouts)
+            cutouts = []
+            cutouts_list.append(cutouts_tensor)
+        cutouts = torch.cat(cutouts_list)
+        return cutouts
diff --git a/src/pytti/__init__.py b/src/pytti/__init__.py
index 4efc1c5..a59f33e 100644
--- a/src/pytti/__init__.py
+++ b/src/pytti/__init__.py
@@ -13,7 +13,6 @@
 from pytti.tensor_tools import (
     named_rearrange,
-    format_input,
     pad_tensor,
     cat_with_pad,
     format_module,
@@ -29,7 +28,6 @@
 __all__ = [
     "DEVICE",
     "named_rearrange",
-    "format_input",
     "pad_tensor",
     "cat_with_pad",
     "format_module",
diff --git a/src/pytti/tensor_tools.py b/src/pytti/tensor_tools.py
index 6019c11..0b90cf9 100644
--- a/src/pytti/tensor_tools.py
+++ b/src/pytti/tensor_tools.py
@@ -3,12 +3,16 @@
 from torchvision import transforms
 from PIL import Image as PIL_Image
 
+from loguru import logger
+from einops import rearrange
+
 normalize = transforms.Normalize(
     mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
 )
 
 # Stuff like this could probably be replaced with einops
-def named_rearrange(tensor, axes, new_positions) -> torch.tensor:
+# ....sheeeeesh... yeah, let's squash this ugliness
+def named_rearrange__OLD(tensor, axes, new_positions) -> torch.tensor:
     """
     Permute and unsqueeze tensor to match target dimensional arrangement
     tensor: (Tensor) input
@@ -42,8 +46,22 @@
     return tensor.permute(*permutation)
 
 
-def format_input(tensor, source, dest) -> torch.tensor:
-    return named_rearrange(tensor, source.output_axes, dest.input_axes)
+def named_rearrange(tensor, source, dest) -> torch.tensor:
+    """
+    Takes a tensor and two layers, and returns the tensor in the format that the second layer expects
+
+    :param tensor: the tensor to be formatted
+    :param source: the layer that produced the tensor (must expose output_axes)
+    :param dest: the layer that will consume the tensor (must expose input_axes)
+    :return: A tensor with the same data as the input tensor, but with the axes reordered.
+ """ + # logger.debug(f"Formatting {tensor.shape} from {source} to {dest}") + # logger.debug(f"source.output_axes: {source.output_axes}") + # logger.debug(f"dest.input_axes: {dest.input_axes}") + einstein_notation = f"{' '.join(source.output_axes)} -> {' '.join(dest.input_axes)}" + # logger.debug(einstein_notation) + # return named_rearrange(tensor, source.output_axes, dest.input_axes) + return rearrange(tensor, einstein_notation) def pad_tensor(tensor, target_len) -> torch.tensor: @@ -59,10 +77,18 @@ def cat_with_pad(tensors): def format_module(module, dest, *args, **kwargs) -> torch.tensor: + """ + Takes a module, a destination, and any number of arguments and keyword arguments, and returns the + output of the module, formatted for the destination + + :param module: the module to be formatted + :param dest: the destination of the output. This is a tuple of the form (module, index) + :return: The output of the module, formatted for the destination. + """ output = module(*args, **kwargs) if isinstance(output, tuple): output = output[0] - return format_input(output, module, dest) + return named_rearrange(output, module, dest) class ReplaceGrad(torch.autograd.Function):