Replace util.get_attr with ModelPatcher.get_model_object (#120)
huchenlei authored Jan 5, 2025
1 parent 8c745b7 commit 2c204e7
Showing 1 changed file with 6 additions and 39 deletions.
lib_layerdiffusion/attention_sharing.py (6 additions, 39 deletions)
@@ -4,8 +4,9 @@
 import torch
 import einops
 
-from comfy import model_management, utils
+from comfy import model_management
 from comfy.ldm.modules.attention import optimized_attention
+from comfy.model_patcher import ModelPatcher
 
 
 module_mapping_sd15 = {
@@ -324,53 +325,19 @@ def __init__(self, layer_list):
         self.layers = torch.nn.ModuleList(layer_list)
 
 
-def unload_model_clones(model, unload_weights_only=True, force_unload=True):
-    current_loaded_models = model_management.current_loaded_models
-
-    to_unload = []
-    for i, m in enumerate(current_loaded_models):
-        if model.is_clone(m.model):
-            to_unload = [i] + to_unload
-
-    if len(to_unload) == 0:
-        return True
-
-    same_weights = 0
-    for i in to_unload:
-        if model.clone_has_same_weights(current_loaded_models[i].model):
-            same_weights += 1
-
-    if same_weights == len(to_unload):
-        unload_weight = False
-    else:
-        unload_weight = True
-
-    if not force_unload:
-        if unload_weights_only and unload_weight is False:
-            return None
-        else:
-            unload_weight = True
-
-    for i in to_unload:
-        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
-
-    return unload_weight
-
-
 class AttentionSharingPatcher(torch.nn.Module):
-    def __init__(self, unet, frames=2, use_control=True, rank=256):
+    def __init__(self, unet: ModelPatcher, frames=2, use_control=True, rank=256):
         super().__init__()
-        unload_model_clones(unet)
 
         units = []
         for i in range(32):
-            real_key = module_mapping_sd15[i]
-            attn_module = utils.get_attr(unet.model.diffusion_model, real_key)
+            key = "diffusion_model." + module_mapping_sd15[i]
+            attn_module = unet.get_model_object(key)
             u = AttentionSharingUnit(
                 attn_module, frames=frames, use_control=use_control, rank=rank
             )
             units.append(u)
-            unet.add_object_patch("diffusion_model." + real_key, u)
+            unet.add_object_patch(key, u)
 
         self.hookers = HookerLayers(units)
 
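For readers following the change, here is a minimal sketch of the access pattern this commit adopts: resolving the attention submodule through ComfyUI's ModelPatcher and registering the replacement as an object patch, rather than reaching into unet.model with utils.get_attr and unloading model clones by hand. The helper name patch_attention_unit and its make_unit argument are illustrative, not part of this repository.

from comfy.model_patcher import ModelPatcher


def patch_attention_unit(unet: ModelPatcher, real_key: str, make_unit):
    # Illustrative helper (not part of the repo): swap one attention module.
    key = "diffusion_model." + real_key

    # get_model_object() walks the dotted path on the patcher's model and
    # honors object patches already applied to this clone, which the old
    # utils.get_attr(unet.model.diffusion_model, real_key) call bypassed.
    attn_module = unet.get_model_object(key)

    # Register the replacement as an object patch; ComfyUI applies it when
    # the model is loaded and reverts it on unload, so the manual
    # unload_model_clones() bookkeeping deleted above is no longer needed.
    unit = make_unit(attn_module)
    unet.add_object_patch(key, unit)
    return unit

In the patched __init__ this corresponds to make_unit being AttentionSharingUnit with the frames, use_control, and rank arguments bound.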
