Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use einops #183

Open
wants to merge 8 commits into
base: test
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/pytti/ImageGuide.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from collections import Counter

from pytti import (
format_input,
named_rearrange,
set_t,
print_vram_usage,
freeze_vram_usage,
Expand Down Expand Up @@ -279,7 +279,7 @@ def train(
losses = []

aug_losses = {
aug: aug(format_input(z, self.image_rep, aug), self.image_rep)
aug: aug(named_rearrange(z, self.image_rep, aug), self.image_rep)
for aug in loss_augs
}

Expand All @@ -300,19 +300,19 @@ def train(
t = i / interp_steps
interp_losses = [
prompt(
format_input(image_embeds, self.embedder, prompt),
format_input(offsets, self.embedder, prompt),
format_input(sizes, self.embedder, prompt),
named_rearrange(image_embeds, self.embedder, prompt),
named_rearrange(offsets, self.embedder, prompt),
named_rearrange(sizes, self.embedder, prompt),
)[0]
* (1 - t)
for prompt in interp_prompts
]

prompt_losses = {
prompt: prompt(
format_input(image_embeds, self.embedder, prompt),
format_input(offsets, self.embedder, prompt),
format_input(sizes, self.embedder, prompt),
named_rearrange(image_embeds, self.embedder, prompt),
named_rearrange(offsets, self.embedder, prompt),
named_rearrange(sizes, self.embedder, prompt),
)
for prompt in prompts
}
Expand Down
4 changes: 2 additions & 2 deletions src/pytti/Perceptor/Embedder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Tuple

import pytti
from pytti import DEVICE, format_input, cat_with_pad, format_module, normalize
from pytti import DEVICE, named_rearrange, cat_with_pad, format_module, normalize

# from pytti.ImageGuide import DirectImageGuide
from pytti.image_models import DifferentiableImage
Expand Down Expand Up @@ -106,7 +106,7 @@ def forward(
device=device, memory_format=torch.channels_last
)
else:
input = format_input(input, diff_image, self).to(
input = named_rearrange(input, diff_image, self).to(
device=device, memory_format=torch.channels_last
)
max_size = min(side_x, side_y)
Expand Down
14 changes: 7 additions & 7 deletions src/pytti/Perceptor/Prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import pytti
from pytti import (
DEVICE,
format_input,
named_rearrange,
cat_with_pad,
replace_grad,
fetch,
Expand Down Expand Up @@ -310,16 +310,16 @@ def __init__(
):
self.input_axes = ("c", "n", "i")
super().__init__(
format_input(embeds, embedder, self),
named_rearrange(embeds, embedder, self),
weight,
stop,
text + " (semantic)",
prompt_string,
mask=mask,
)
self.input_axes = ("c", "n", "i")
self.register_buffer("positions", format_input(positions, embedder, self))
self.register_buffer("sizes", format_input(sizes, embedder, self))
self.register_buffer("positions", named_rearrange(positions, embedder, self))
self.register_buffer("sizes", named_rearrange(sizes, embedder, self))

@torch.no_grad()
@vram_usage_mode("Image Prompts")
Expand All @@ -335,9 +335,9 @@ def set_image(self, embedder, pil_image):
img.encode_image(pil_image)
embeds, positions, sizes = embedder(img)
embeds = embeds.clone()
self.positions.set_(format_input(positions, embedder, self))
self.sizes.set_(format_input(sizes, embedder, self))
self.embeds.set_(format_input(embeds, embedder, self))
self.positions.set_(named_rearrange(positions, embedder, self))
self.sizes.set_(named_rearrange(sizes, embedder, self))
self.embeds.set_(named_rearrange(embeds, embedder, self))


def minimize_average_distance(tensor_a, tensor_b, device=DEVICE):
Expand Down
3 changes: 3 additions & 0 deletions src/pytti/Perceptor/cutouts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .dango import MakeCutouts

test = MakeCutouts(1)
24 changes: 24 additions & 0 deletions src/pytti/Perceptor/cutouts/augs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import kornia.augmentation as K
import torch
from torch import nn
from torchvision import transforms as T


def pytti_classic():
Expand All @@ -16,3 +18,25 @@ def pytti_classic():
),
nn.Identity(),
)


def dango():
return T.Compose(
[
# T.RandomHorizontalFlip(p=0.5),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomAffine(
degrees=0,
translate=(0.05, 0.05),
# scale=(0.9,0.95),
fill=-1,
interpolation=T.InterpolationMode.BILINEAR,
),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
# T.RandomPerspective(p=1, interpolation = T.InterpolationMode.BILINEAR, fill=-1,distortion_scale=0.2),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.RandomGrayscale(p=0.1),
T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
T.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05),
]
)
101 changes: 101 additions & 0 deletions src/pytti/Perceptor/cutouts/dango.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# via https://github.com/multimodalart/majesty-diffusion/blob/main/latent.ipynb

# !pip install resize-right
# TO DO: add resize-right to setup instructions and notebook
from resize_right import resize

import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import transforms as T
from torchvision.transforms import functional as TF

from . import augs as cutouts_augs

padargs = {"mode": "constant", "value": -1}


class MakeCutouts(nn.Module):
def __init__(
self,
cut_size,
Overview=4,
WholeCrop=0,
WC_Allowance=10,
WC_Grey_P=0.2,
InnerCrop=0,
IC_Size_Pow=0.5,
IC_Grey_P=0.2,
aug=True,
cutout_debug=False,
):
super().__init__()
self.cut_size = cut_size
self.Overview = Overview
self.WholeCrop = WholeCrop
self.WC_Allowance = WC_Allowance
self.WC_Grey_P = WC_Grey_P
self.InnerCrop = InnerCrop
self.IC_Size_Pow = IC_Size_Pow
self.IC_Grey_P = IC_Grey_P
self.augs = cutouts_augs.dango
self._aug = aug
self.cutout_debug = cutout_debug

def forward(self, input):
gray = transforms.Grayscale(
3
) # this is possibly a performance improvement? 1 channel instead of 3. but also means we can't use color augs...
sideY, sideX = input.shape[2:4]
max_size = min(sideX, sideY)
min_size = min(sideX, sideY, self.cut_size)
l_size = max(sideX, sideY)
output_shape = [input.shape[0], 3, self.cut_size, self.cut_size]
output_shape_2 = [input.shape[0], 3, self.cut_size + 2, self.cut_size + 2]
pad_input = F.pad(
input,
(
(sideY - max_size) // 2 + round(max_size * 0.055),
(sideY - max_size) // 2 + round(max_size * 0.055),
(sideX - max_size) // 2 + round(max_size * 0.055),
(sideX - max_size) // 2 + round(max_size * 0.055),
),
**padargs
)
cutouts_list = []

if self.Overview > 0:
cutouts = []
cutout = resize(pad_input, out_shape=output_shape, antialiasing=True)
output_shape_all = list(output_shape)
output_shape_all[0] = self.Overview * input.shape[0]
pad_input = pad_input.repeat(input.shape[0], 1, 1, 1)
cutout = resize(pad_input, out_shape=output_shape_all)
if self._aug:
cutout = self.augs(cutout)
cutouts_list.append(cutout)

if self.InnerCrop > 0:
cutouts = []
for i in range(self.InnerCrop):
size = int(
torch.rand([]) ** self.IC_Size_Pow * (max_size - min_size)
+ min_size
)
offsetx = torch.randint(0, sideX - size + 1, ())
offsety = torch.randint(0, sideY - size + 1, ())
cutout = input[:, :, offsety : offsety + size, offsetx : offsetx + size]
if i <= int(self.IC_Grey_P * self.InnerCrop):
cutout = gray(cutout)
cutout = resize(cutout, out_shape=output_shape)
cutouts.append(cutout)
if self.cutout_debug:
TF.to_pil_image(cutouts[-1].add(1).div(2).clamp(0, 1).squeeze(0)).save(
"content/diff/cutouts/cutout_InnerCrop.jpg", quality=99
)
cutouts_tensor = torch.cat(cutouts)
cutouts = []
cutouts_list.append(cutouts_tensor)
cutouts = torch.cat(cutouts_list)
return cutouts
2 changes: 0 additions & 2 deletions src/pytti/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from pytti.tensor_tools import (
named_rearrange,
format_input,
pad_tensor,
cat_with_pad,
format_module,
Expand All @@ -29,7 +28,6 @@
__all__ = [
"DEVICE",
"named_rearrange",
"format_input",
"pad_tensor",
"cat_with_pad",
"format_module",
Expand Down
34 changes: 30 additions & 4 deletions src/pytti/tensor_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
from torchvision import transforms
from PIL import Image as PIL_Image

from loguru import logger
from einops import rearrange

normalize = transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
)

# Stuff like this could probably be replaced with einops
def named_rearrange(tensor, axes, new_positions) -> torch.tensor:
# ....sheeeeesh... yeah, let's squash this ugliness
def named_rearrange__OLD(tensor, axes, new_positions) -> torch.tensor:
"""
Permute and unsqueeze tensor to match target dimensional arrangement
tensor: (Tensor) input
Expand Down Expand Up @@ -42,8 +46,22 @@ def named_rearrange(tensor, axes, new_positions) -> torch.tensor:
return tensor.permute(*permutation)


def format_input(tensor, source, dest) -> torch.tensor:
return named_rearrange(tensor, source.output_axes, dest.input_axes)
def named_rearrange(tensor, source, dest) -> torch.tensor:
"""
Takes a tensor and two layers, and returns the tensor in the format that the second layer expects

:param tensor: the tensor to be formatted
:param source: the source model
:param dest: the destination tensor
:return: A tensor with the same data as the input tensor, but with the axes reordered.
"""
# logger.debug(f"Formatting {tensor.shape} from {source} to {dest}")
# logger.debug(f"source.output_axes: {source.output_axes}")
# logger.debug(f"dest.input_axes: {dest.input_axes}")
einstein_notation = f"{' '.join(source.output_axes)} -> {' '.join(dest.input_axes)}"
# logger.debug(einstein_notation)
# return named_rearrange(tensor, source.output_axes, dest.input_axes)
return rearrange(tensor, einstein_notation)


def pad_tensor(tensor, target_len) -> torch.tensor:
Expand All @@ -59,10 +77,18 @@ def cat_with_pad(tensors):


def format_module(module, dest, *args, **kwargs) -> torch.tensor:
"""
Takes a module, a destination, and any number of arguments and keyword arguments, and returns the
output of the module, formatted for the destination

:param module: the module to be formatted
:param dest: the destination of the output. This is a tuple of the form (module, index)
:return: The output of the module, formatted for the destination.
"""
output = module(*args, **kwargs)
if isinstance(output, tuple):
output = output[0]
return format_input(output, module, dest)
return named_rearrange(output, module, dest)


class ReplaceGrad(torch.autograd.Function):
Expand Down