Skip to content

Commit

Permalink
Merge pull request #6 from AlexKoff88/ak/fixes
Browse files Browse the repository at this point in the history
Fixes
  • Loading branch information
AlexKoff88 authored Jun 8, 2023
2 parents 1f4ced5 + f34b3a6 commit 124f380
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 58 deletions.
86 changes: 32 additions & 54 deletions demo.ipynb

Large diffs are not rendered by default.

8 changes: 5 additions & 3 deletions tomeov/import_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
def is_diffusers_available():
try:
import stable_diffusion
import utils
import diffusers
return True
except ImportError:
print("diffusers library is not available. Please install it to use Token Merging.")
return False

def is_openclip_available():
try:
import open_clip_torch
import open_clip
return True
except ImportError:
print("OpenCLIP library is not available. Please install it to use Token Merging.")
return False

def is_timm_available():
try:
import timm
return True
except ImportError:
print("Timm library is not available. Please install it to use Token Merging.")
return False
18 changes: 17 additions & 1 deletion tomeov/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from . import merge
from .utils import isinstance_str
from .merge import merge_wavg



def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable, ...]:
Expand Down Expand Up @@ -54,6 +54,22 @@ def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tenso
return ToMeBlock


def merge_wavg(
merge: Callable, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Applies the merge function by taking a weighted average based on token size.
Returns the merged tensor and the new token sizes.
"""
size = torch.ones_like(x[..., 0, None])

x = merge(x * size, mode="sum")
size = merge(size, mode="sum")

x = x / size
return x


def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
"""
Make a patched class for a diffusers model.
Expand Down

0 comments on commit 124f380

Please sign in to comment.