Skip to content

Commit

Permalink
Skip layer guidance now works on stable audio model.
Browse files: browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 20, 2024
1 parent 8986151 commit 22535d0
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions comfy/ldm/audio/dit.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,9 @@ def forward(
return_info = False,
**kwargs
):
patches_replace = kwargs.get("transformer_options", {}).get("patches_replace", {})
batch, seq, device = *x.shape[:2], x.device
context = kwargs["context"]

info = {
"hidden_states": [],
Expand Down Expand Up @@ -643,9 +645,19 @@ def forward(
if self.use_sinusoidal_emb or self.use_abs_pos_emb:
x = x + self.pos_emb(x)

blocks_replace = patches_replace.get("dit", {})
# Iterate over the transformer layers
for layer in self.layers:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
for i, layer in enumerate(self.layers):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = layer(args["img"], rotary_pos_emb=args["pe"], global_cond=args["vec"], context=args["txt"])
return out

out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb}, {"original_block": block_wrap})
x = out["img"]
else:
x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context)
# x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

if return_info:
Expand Down Expand Up @@ -874,7 +886,6 @@ def forward(
mask=None,
return_info=False,
control=None,
transformer_options={},
**kwargs):
return self._forward(
x,
Expand Down

0 comments on commit 22535d0

Please sign in to comment.