Skip to content

Commit

Permalink
Add generation experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
daemon committed Nov 12, 2022
1 parent 651a5c7 commit 84a9a33
Show file tree
Hide file tree
Showing 71 changed files with 446 additions and 15 deletions.
2 changes: 2 additions & 0 deletions daam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .hook import *
from .trace import *
from .utils import *
from .experiment import *
53 changes: 47 additions & 6 deletions daam/experiment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any
from dataclasses import dataclass
import json

import PIL.Image
import numpy as np
Expand All @@ -25,6 +26,9 @@
]


UNUSED_LABELS: List[str] = [f'__unused_{i}__' for i in range(1, 200)]


COCOSTUFF27_LABELS: List[str] = [
'electronic', 'appliance', 'food', 'furniture', 'indoor', 'kitchen', 'accessory', 'animal', 'outdoor', 'person',
'sports', 'vehicle', 'ceiling', 'floor', 'food', 'furniture', 'rawmaterial', 'textile', 'wall', 'window',
Expand Down Expand Up @@ -77,6 +81,7 @@ class GenerationExperiment:
path: Optional[Path] = None
truth_masks: Optional[Dict[str, torch.Tensor]] = None
prediction_masks: Optional[Dict[str, torch.Tensor]] = None
annotations: Optional[Dict[str, Any]] = None

def save(self, path: str = None):
if path is None:
Expand All @@ -99,6 +104,16 @@ def save(self, path: str = None):
im = PIL.Image.fromarray((mask * 255).unsqueeze(-1).expand(-1, -1, 4).byte().numpy())
im.save(path / f'{name.lower()}.gt.png')

self.save_annotations()

def save_annotations(self, path: Path = None):
if path is None:
path = self.path

if self.annotations is not None:
with (path / 'annotations.json').open('w') as f:
json.dump(self.annotations, f)

def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]:
masks = {}

Expand All @@ -110,11 +125,11 @@ def _load_truth_masks(self, simplify80: bool = False) -> Dict[str, torch.Tensor]
return masks

def _load_pred_masks(self, pred_prefix, composite=False, simplify80=False, vocab=None):
# type: (str, bool, bool, List[str]) -> Dict[str, torch.Tensor]
# type: (str, bool, bool, List[str] | None) -> Dict[str, torch.Tensor]
masks = {}

if vocab is None:
vocab = COCOSTUFF27_LABELS
vocab = UNUSED_LABELS

if composite:
im = PIL.Image.open(self.path / f'composite.{pred_prefix}.pred.png')
Expand All @@ -136,16 +151,42 @@ def save_prediction_mask(self, mask: torch.Tensor, word: str, name: str):
im.save(self.path / f'{word.lower()}.{name}.pred.png')

@staticmethod
def contains_truth_mask(path: str) -> bool:
return any(Path(path).glob('*.gt.png'))
def contains_truth_mask(path: str | Path, prompt_id: str = None) -> bool:
if prompt_id is None:
return any(Path(path).glob('*.gt.png'))
else:
return any((Path(path) / prompt_id).glob('*.gt.png'))

@staticmethod
def has_annotations(path: str | Path) -> bool:
return Path(path).joinpath('annotations.json').exists()

@staticmethod
def has_experiment(path: str | Path, prompt_id: str) -> bool:
return (Path(path) / prompt_id / 'generation.pt').exists()

def _try_load_annotations(self):
if not (self.path / 'annotations.json').exists():
return None

return json.load((self.path / 'annotations.json').open())

def annotate(self, key: str, value: Any) -> 'GenerationExperiment':
if self.annotations is None:
self.annotations = {}

self.annotations[key] = value

return self

@classmethod
def load(cls, path, pred_prefix='daam', composite=False, simplify80=False, vocab=None):
# type: (str, str, bool, bool, List[str]) -> GenerationExperiment
# type: (str, str, bool, bool, List[str] | None) -> GenerationExperiment
path = Path(path)
exp = torch.load(path / 'generation.pt')
exp.path = path
exp.truth_masks = exp._load_truth_masks(simplify80=simplify80)
exp.prediction_masks = exp._load_pred_masks(pred_prefix, composite=composite, simplify80=simplify80, vocab=vocab)
exp.annotations = exp._try_load_annotations()

return exp
47 changes: 47 additions & 0 deletions daam/run/annotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from pathlib import Path
import argparse

from diffusers import StableDiffusionPipeline
from matplotlib import pyplot as plt

from daam import GenerationExperiment, HeatMap, plot_overlay_heat_map, expand_image


def main():
parser = argparse.ArgumentParser()
parser.add_argument('--input-folder', '-i', type=str, required=True)
parser.add_argument('--pred-prefix', '-p', type=str, default='daam')
parser.add_argument('--model', type=str, default='CompVis/stable-diffusion-v1-4')
args = parser.parse_args()

input_folder = Path(args.input_folder)

pipe = StableDiffusionPipeline.from_pretrained(args.model, use_auth_token=True)
tokenizer = pipe.tokenizer
del pipe

for path in input_folder.iterdir():
if not path.is_dir():
continue

exp = GenerationExperiment.load(str(path), args.pred_prefix)

if exp.annotations is not None and 'num_objects' in exp.annotations:
continue

plt.clf()
print(exp.prompt)
exp.image.show()

heat_map = HeatMap(tokenizer, exp.prompt, exp.global_heat_map)
plt.clf()
plot_overlay_heat_map(exp.image, expand_image(heat_map.compute_word_heat_map(exp.prompt.split()[0])))
plt.show()

num_objects = int(input('Number of objects: '))
exp.annotate('num_objects', num_objects)
exp.save_annotations()


if __name__ == '__main__':
main()
23 changes: 22 additions & 1 deletion daam/run/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,32 @@ def main():

evaluator = MeanEvaluator() if args.eval_type != 'hungarian' else UnsupervisedEvaluator()
simplify80 = False
vocab = None
vocab = []

if args.restrict_set == 'coco27':
simplify80 = True
vocab = COCOSTUFF27_LABELS
elif args.restrict_set == 'coco80':
vocab = COCO80_LABELS

if not vocab:
for path in tqdm(Path(args.input_folder).glob('*')):
if not path.is_dir() or not GenerationExperiment.contains_truth_mask(path):
continue

exp = GenerationExperiment.load(
path,
args.pred_prefix,
composite=args.mask_type == 'composite',
simplify80=simplify80
)

vocab.extend(exp.truth_masks)
vocab.extend(exp.prediction_masks)

vocab = list(set(vocab))
vocab.sort()

for path in tqdm(Path(args.input_folder).glob('*')):
if not path.is_dir() or not GenerationExperiment.contains_truth_mask(path):
continue
Expand All @@ -40,6 +58,9 @@ def main():

if args.eval_type == 'labeled':
for word, mask in exp.truth_masks.items():
if word not in vocab and args.restrict_set != 'none':
continue

try:
evaluator.log_iou(exp.prediction_masks[word], mask)
evaluator.log_intensity(exp.prediction_masks[word])
Expand Down
80 changes: 75 additions & 5 deletions daam/run/generate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from collections import defaultdict
from pathlib import Path
import argparse
import json
import pandas as pd
import random

from diffusers import StableDiffusionPipeline
from tqdm import tqdm
import inflect
import torch

from daam import trace
Expand All @@ -14,21 +17,77 @@

def main():
parser = argparse.ArgumentParser()
parser.add_argument('--action', type=str, default='prompt', choices=['prompt', 'coco'])
parser.add_argument('--action', type=str, default='prompt', choices=['prompt', 'coco', 'template', 'cconj'])
parser.add_argument('--output-folder', '-o', type=str, default='output')
parser.add_argument('--input-folder', '-i', type=str, default='input')
parser.add_argument('--seed', '-s', type=int, default=0)
parser.add_argument('--gen-limit', type=int, default=1000)
parser.add_argument('--template', type=str, default='{numeral} {noun}')
parser.add_argument('--template-data-file', '-tdf', type=str, default='template.tsv')
parser.add_argument('--regenerate', action='store_true')
args = parser.parse_args()

gen = set_seed(args.seed)
eng = inflect.engine()

if args.action == 'coco':
with (Path(args.input_folder) / 'captions_val2014.json').open() as f:
captions = json.load(f)['annotations'][:args.gen_limit]
captions = json.load(f)['annotations']

random.shuffle(captions)
captions = captions[:args.gen_limit]
prompts = [(caption['id'], caption['caption']) for caption in captions]
elif args.action == 'template':
template_df = pd.read_csv(args.template_data_file, sep='\t', quoting=3)
sample_dict = defaultdict(list)

for name, df in template_df.groupby('pos'):
sample_dict[name].extend(df['word'].tolist())

prompts = []
template_words = args.template.split()
plural_numerals = {'0', '2', '3', '4', '5', '6', '7', '8', '9', 'zero', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine'}

for prompt_id in range(args.gen_limit):
words = []
pluralize = False

for word in template_words:
if word.startswith('{'):
pos = word[1:-1]
word = random.choice(sample_dict[pos])

if pos == 'noun' and pluralize:
word = eng.plural(word)

words.append(word)
pluralize = word in plural_numerals

prompt_id = str(prompt_id)
prompts.append((prompt_id, ' '.join(words)))
tqdm.write(str(prompts[-1]))
elif args.action == 'cconj':
template_df = pd.read_csv(args.template_data_file, sep='\t', quoting=3)
sample_dict = defaultdict(list)

for name, df in template_df.groupby('pos'):
sample_dict[name].extend(df['word'].tolist())

prompts = []
prompt_id = 0

for _ in range(args.gen_limit):

for word1 in tqdm(sample_dict['noun']):
for word2 in sample_dict['noun']:
if word1 == word2:
continue

prompt = f'a {word1} and a {word2}'
print(prompt)

prompts.append((str(prompt_id), prompt))
prompt_id += 1
else:
prompts = [('prompt', input('> '))]

Expand All @@ -37,15 +96,26 @@ def main():

pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
pipe = pipe.to(device)
seed = args.seed

with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
for prompt_id, prompt in tqdm(prompts):
if args.action == 'template' or args.action == 'cconj':
gen = set_seed(int(prompt_id))
seed = int(prompt_id)

prompt_id = str(prompt_id)

if args.regenerate and not GenerationExperiment.contains_truth_mask(args.output_folder, prompt_id):
print(f'Skipping {prompt_id}')
continue

with trace(pipe, weighted=True) as tc:
out = pipe(prompt, num_inference_steps=30, generator=gen)
out = pipe(prompt, num_inference_steps=20, generator=gen)
exp = GenerationExperiment(
id=str(prompt_id),
id=prompt_id,
global_heat_map=tc.compute_global_heat_map(prompt).heat_maps,
seed=args.seed,
seed=seed,
prompt=prompt,
image=out.images[0]
)
Expand Down
Loading

0 comments on commit 84a9a33

Please sign in to comment.