Skip to content

Commit

Permalink
Add CPU support (#21)
Browse files Browse the repository at this point in the history
* Default to CPU when CUDA deosn't exist

* Bump version
  • Loading branch information
daemon authored Dec 9, 2022
1 parent 378df6a commit cb5d2d2
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 16 deletions.
2 changes: 1 addition & 1 deletion daam/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.10'
__version__ = '0.0.11'
3 changes: 2 additions & 1 deletion daam/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import torch

from .utils import auto_autocast
from .evaluate import load_mask


Expand Down Expand Up @@ -233,7 +234,7 @@ def save_heat_map(
if tokenizer is None:
tokenizer = self.tokenizer

with torch.cuda.amp.autocast(dtype=torch.float32):
with auto_autocast(dtype=torch.float32):
path = self.path / self.subtype / f'{output_prefix}{word.lower()}.heat_map.png'
heat_map = GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map)
heat_map.compute_word_heat_map(word).expand_as(self.image, color_normalize=not absolute, out_file=path, plot=True)
Expand Down
7 changes: 3 additions & 4 deletions daam/heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
import torch.nn.functional as F

from .evaluate import compute_ioa
from .utils import compute_token_merge_indices, cached_nlp

from .utils import compute_token_merge_indices, cached_nlp, auto_autocast

__all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMap', 'SyntacticHeatMapPair']

Expand All @@ -27,7 +26,7 @@ def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, col
else:
plt_ = ax

with torch.cuda.amp.autocast(dtype=torch.float32):
with auto_autocast(dtype=torch.float32):
im = np.array(im)

if crop is not None:
Expand Down Expand Up @@ -152,7 +151,7 @@ def __init__(self):
self.ids_to_num_maps: Dict[RawHeatMapKey, int] = defaultdict(lambda: 0)

def update(self, factor: int, layer_idx: int, head_idx: int, heatmap: torch.Tensor):
with torch.cuda.amp.autocast(dtype=torch.float32):
with auto_autocast(dtype=torch.float32):
key = (factor, layer_idx, head_idx)
self.ids_to_heatmaps[key] = self.ids_to_heatmaps[key] + heatmap

Expand Down
4 changes: 2 additions & 2 deletions daam/run/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from spacy import displacy

from daam import trace
from daam.utils import set_seed, cached_nlp
from daam.utils import set_seed, cached_nlp, auto_autocast


def dependency(text):
Expand Down Expand Up @@ -83,7 +83,7 @@ def plot(prompt, choice, replaced_word, inf_steps, is_random_seed):
new_prompt = ' '.join(new_prompt)

merge_idxs, words = get_tokenizing_mapping(prompt, pipe.tokenizer)
with torch.cuda.amp.autocast(dtype=torch.float16), lock:
with auto_autocast(dtype=torch.float16), lock:
try:
plt.close('all')
plt.clf()
Expand Down
6 changes: 3 additions & 3 deletions daam/run/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from daam import trace
from daam.experiment import GenerationExperiment, build_word_list_coco80
from daam.utils import set_seed, cached_nlp
from daam.utils import set_seed, cached_nlp, auto_device, auto_autocast


def main():
Expand Down Expand Up @@ -191,9 +191,9 @@ def main():

prompts = prompts[:args.gen_limit]
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
pipe = pipe.to('cuda')
pipe = auto_device(pipe)

with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
with auto_autocast(dtype=torch.float16), torch.no_grad():
for gen_idx, (prompt_id, prompt) in enumerate(tqdm(prompts)):
seed = int(time.time()) if args.random_seed else args.seed
prompt = prompt.replace(',', ' ,').replace('.', ' .').strip()
Expand Down
6 changes: 3 additions & 3 deletions daam/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.nn.functional as F

from . import cache_dir
from .utils import cache_dir, auto_autocast
from .experiment import GenerationExperiment
from .heatmap import RawHeatMapCollection, GlobalHeatMap
from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator
Expand Down Expand Up @@ -105,7 +105,7 @@ def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, laye
all_merges = []
x = int(np.sqrt(self.latent_hw))

with torch.cuda.amp.autocast(dtype=torch.float32):
with auto_autocast(dtype=torch.float32):
for (factor, layer, head), heat_map in heat_maps:
if factor in factors and (head_idx is None or head_idx == head) and (layer_idx is None or layer_idx == layer):
heat_map = heat_map.unsqueeze(1)
Expand Down Expand Up @@ -209,7 +209,7 @@ def _unravel_attn(self, x):
maps = []
x = x.permute(2, 0, 1)

with torch.cuda.amp.autocast(dtype=torch.float32):
with auto_autocast(dtype=torch.float32):
for map_ in x:
map_ = map_.view(map_.size(0), h, w)
map_ = map_[map_.size(0) // 2:] # Filter out unconditional
Expand Down
25 changes: 23 additions & 2 deletions daam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import sys
import random
from typing import TypeVar

import PIL.Image
import matplotlib.pyplot as plt
Expand All @@ -13,7 +14,27 @@
import torch.nn.functional as F


__all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir']
__all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir', 'auto_device', 'auto_autocast']


T = TypeVar('T')


def auto_device(obj: T = torch.device('cpu')) -> T:
if isinstance(obj, torch.device):
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
return obj.to('cuda')

return obj


def auto_autocast(*args, **kwargs):
if not torch.cuda.is_available():
kwargs['enabled'] = False

return torch.cuda.amp.autocast(*args, **kwargs)


def plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4):
Expand All @@ -29,7 +50,7 @@ def set_seed(seed: int) -> torch.Generator:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

gen = torch.Generator(device='cuda')
gen = torch.Generator(device=auto_device())
gen.manual_seed(seed)

return gen
Expand Down

0 comments on commit cb5d2d2

Please sign in to comment.