From cb5d2d2b19b77eb9654bd6ac3dfde92a25a02541 Mon Sep 17 00:00:00 2001
From: Raphael Tang
Date: Thu, 8 Dec 2022 22:13:02 -0500
Subject: [PATCH] Add CPU support (#21)

* Default to CPU when CUDA doesn't exist

* Bump version
---
 daam/_version.py     |  2 +-
 daam/experiment.py   |  3 ++-
 daam/heatmap.py      |  7 +++----
 daam/run/demo.py     |  4 ++--
 daam/run/generate.py |  6 +++---
 daam/trace.py        |  6 +++---
 daam/utils.py        | 25 +++++++++++++++++++++++--
 7 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/daam/_version.py b/daam/_version.py
index 6820f36..ad3cf1d 100644
--- a/daam/_version.py
+++ b/daam/_version.py
@@ -1 +1 @@
-__version__ = '0.0.10'
+__version__ = '0.0.11'
diff --git a/daam/experiment.py b/daam/experiment.py
index 31e03ad..4465a30 100644
--- a/daam/experiment.py
+++ b/daam/experiment.py
@@ -8,6 +8,7 @@
 import numpy as np
 import torch
 
+from .utils import auto_autocast
 from .evaluate import load_mask
 
 
@@ -233,7 +234,7 @@ def save_heat_map(
         if tokenizer is None:
             tokenizer = self.tokenizer
 
-        with torch.cuda.amp.autocast(dtype=torch.float32):
+        with auto_autocast(dtype=torch.float32):
             path = self.path / self.subtype / f'{output_prefix}{word.lower()}.heat_map.png'
             heat_map = GlobalHeatMap(tokenizer, self.prompt, self.global_heat_map)
             heat_map.compute_word_heat_map(word).expand_as(self.image, color_normalize=not absolute, out_file=path, plot=True)
diff --git a/daam/heatmap.py b/daam/heatmap.py
index b3b20fe..0fef9b1 100644
--- a/daam/heatmap.py
+++ b/daam/heatmap.py
@@ -12,8 +12,7 @@
 import torch.nn.functional as F
 
 from .evaluate import compute_ioa
-from .utils import compute_token_merge_indices, cached_nlp
-
+from .utils import compute_token_merge_indices, cached_nlp, auto_autocast
 
 __all__ = ['GlobalHeatMap', 'RawHeatMapCollection', 'WordHeatMap', 'ParsedHeatMap', 'SyntacticHeatMapPair']
 
@@ -27,7 +26,7 @@ def plot_overlay_heat_map(im, heat_map, word=None, out_file=None, crop=None, col
     else:
         plt_ = ax
 
-    with torch.cuda.amp.autocast(dtype=torch.float32):
+    with auto_autocast(dtype=torch.float32):
         im = np.array(im)
 
         if crop is not None:
@@ -152,7 +151,7 @@ def __init__(self):
         self.ids_to_num_maps: Dict[RawHeatMapKey, int] = defaultdict(lambda: 0)
 
     def update(self, factor: int, layer_idx: int, head_idx: int, heatmap: torch.Tensor):
-        with torch.cuda.amp.autocast(dtype=torch.float32):
+        with auto_autocast(dtype=torch.float32):
             key = (factor, layer_idx, head_idx)
             self.ids_to_heatmaps[key] = self.ids_to_heatmaps[key] + heatmap
 
diff --git a/daam/run/demo.py b/daam/run/demo.py
index c041ed7..effb259 100644
--- a/daam/run/demo.py
+++ b/daam/run/demo.py
@@ -12,7 +12,7 @@
 from spacy import displacy
 
 from daam import trace
-from daam.utils import set_seed, cached_nlp
+from daam.utils import set_seed, cached_nlp, auto_autocast
 
 
 def dependency(text):
@@ -83,7 +83,7 @@ def plot(prompt, choice, replaced_word, inf_steps, is_random_seed):
         new_prompt = ' '.join(new_prompt)
         merge_idxs, words = get_tokenizing_mapping(prompt, pipe.tokenizer)
 
-        with torch.cuda.amp.autocast(dtype=torch.float16), lock:
+        with auto_autocast(dtype=torch.float16), lock:
             try:
                 plt.close('all')
                 plt.clf()
diff --git a/daam/run/generate.py b/daam/run/generate.py
index 395c0ec..be7dc99 100644
--- a/daam/run/generate.py
+++ b/daam/run/generate.py
@@ -15,7 +15,7 @@
 
 from daam import trace
 from daam.experiment import GenerationExperiment, build_word_list_coco80
-from daam.utils import set_seed, cached_nlp
+from daam.utils import set_seed, cached_nlp, auto_device, auto_autocast
 
 
 def main():
@@ -191,9 +191,9 @@ def main():
     prompts = prompts[:args.gen_limit]
 
     pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
-    pipe = pipe.to('cuda')
+    pipe = auto_device(pipe)
 
-    with torch.cuda.amp.autocast(dtype=torch.float16), torch.no_grad():
+    with auto_autocast(dtype=torch.float16), torch.no_grad():
         for gen_idx, (prompt_id, prompt) in enumerate(tqdm(prompts)):
             seed = int(time.time()) if args.random_seed else args.seed
             prompt = prompt.replace(',', ' ,').replace('.', ' .').strip()
diff --git a/daam/trace.py b/daam/trace.py
index 7b830b1..388c76f 100644
--- a/daam/trace.py
+++ b/daam/trace.py
@@ -9,7 +9,7 @@
 import torch
 import torch.nn.functional as F
 
-from . import cache_dir
+from .utils import cache_dir, auto_autocast
 from .experiment import GenerationExperiment
 from .heatmap import RawHeatMapCollection, GlobalHeatMap
 from .hook import ObjectHooker, AggregateHooker, UNetCrossAttentionLocator
@@ -105,7 +105,7 @@ def compute_global_heat_map(self, prompt=None, factors=None, head_idx=None, laye
         all_merges = []
         x = int(np.sqrt(self.latent_hw))
 
-        with torch.cuda.amp.autocast(dtype=torch.float32):
+        with auto_autocast(dtype=torch.float32):
             for (factor, layer, head), heat_map in heat_maps:
                 if factor in factors and (head_idx is None or head_idx == head) and (layer_idx is None or layer_idx == layer):
                     heat_map = heat_map.unsqueeze(1)
@@ -209,7 +209,7 @@ def _unravel_attn(self, x):
         maps = []
         x = x.permute(2, 0, 1)
 
-        with torch.cuda.amp.autocast(dtype=torch.float32):
+        with auto_autocast(dtype=torch.float32):
             for map_ in x:
                 map_ = map_.view(map_.size(0), h, w)
                 map_ = map_[map_.size(0) // 2:]  # Filter out unconditional
diff --git a/daam/utils.py b/daam/utils.py
index 118a046..f61d739 100644
--- a/daam/utils.py
+++ b/daam/utils.py
@@ -4,6 +4,7 @@
 import os
 import sys
 import random
+from typing import TypeVar
 
 import PIL.Image
 import matplotlib.pyplot as plt
@@ -13,7 +14,27 @@
 import torch.nn.functional as F
 
 
-__all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir']
+__all__ = ['set_seed', 'compute_token_merge_indices', 'plot_mask_heat_map', 'cached_nlp', 'cache_dir', 'auto_device', 'auto_autocast']
+
+
+T = TypeVar('T')
+
+
+def auto_device(obj: T = torch.device('cpu')) -> T:
+    if isinstance(obj, torch.device):
+        return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+    if torch.cuda.is_available():
+        return obj.to('cuda')
+
+    return obj
+
+
+def auto_autocast(*args, **kwargs):
+    if not torch.cuda.is_available():
+        kwargs['enabled'] = False
+
+    return torch.cuda.amp.autocast(*args, **kwargs)
 
 
 def plot_mask_heat_map(im: PIL.Image.Image, heat_map: torch.Tensor, threshold: float = 0.4):
@@ -29,7 +50,7 @@ def set_seed(seed: int) -> torch.Generator:
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
-    gen = torch.Generator(device='cuda')
+    gen = torch.Generator(device=auto_device())
     gen.manual_seed(seed)
 
     return gen
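
Usage sketch (not part of the patch): the snippet below shows how the new auto_device and auto_autocast helpers are meant to be called, mirroring the daam/run/generate.py hunk above. The model id, prompt, inference-step count, and the trace(...) tracing calls are illustrative assumptions rather than code from this commit.

import torch
from diffusers import StableDiffusionPipeline

from daam import trace
from daam.utils import auto_autocast, auto_device, set_seed

# Illustrative checkpoint; any diffusers Stable Diffusion checkpoint should work.
pipe = StableDiffusionPipeline.from_pretrained('CompVis/stable-diffusion-v1-4', use_auth_token=True)
pipe = auto_device(pipe)  # moves the pipeline to CUDA only if a GPU is present; otherwise it stays on CPU

# auto_autocast forwards to torch.cuda.amp.autocast but passes enabled=False when CUDA is
# unavailable, so this block runs unchanged on CPU-only machines.
with auto_autocast(dtype=torch.float16), torch.no_grad():
    with trace(pipe) as tc:
        out = pipe('a dog riding a skateboard', num_inference_steps=30, generator=set_seed(0))
        word_map = tc.compute_global_heat_map().compute_word_heat_map('dog')
        word_map.expand_as(out.images[0], plot=True)  # overlay the 'dog' heat map on the generated image

On a CUDA machine this behaves like the previous pipe.to('cuda') and torch.cuda.amp.autocast path; on a CPU-only machine auto_device leaves the pipeline where it is, auto_autocast disables mixed precision, and set_seed builds its generator on whichever device auto_device() reports.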