Skip to content

Commit

Permalink
Use regular (real-valued) numbers instead of complex numbers for RoPE in the Lumina model.
Browse files (browse the repository at this point in the history)
  • Loading branch information
comfyanonymous committed Feb 5, 2025
1 parent a57d635 commit 6065300
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions comfy/ldm/lumina/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, RMSNorm
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.ldm.flux.layers import EmbedND


def modulate(x, scale):
Expand Down Expand Up @@ -92,10 +93,9 @@ def apply_rotary_emb(
and key tensor with rotary embeddings.
"""

x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in)
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float()
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x_in.shape).type_as(x_in)

def forward(
self,
Expand Down Expand Up @@ -130,6 +130,7 @@ def forward(

xq = self.q_norm(xq)
xk = self.k_norm(xk)

xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)

Expand Down Expand Up @@ -480,7 +481,8 @@ def __init__(
assert (dim // n_heads) == sum(axes_dims)
self.axes_dims = axes_dims
self.axes_lens = axes_lens
self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
# self.rope_embedder = RopeEmbedder(axes_dims=axes_dims, axes_lens=axes_lens)
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
self.dim = dim
self.n_heads = n_heads

Expand Down Expand Up @@ -550,7 +552,7 @@ def patchify_and_embed(
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

freqs_cis = self.rope_embedder(position_ids)
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2)

# build freqs_cis for cap and image individually
cap_freqs_cis_shape = list(freqs_cis.shape)
Expand Down

0 comments on commit 6065300

Please sign in to comment.