improved rope - experimental
jfacevedo-google committed Mar 5, 2025
1 parent 296e956 commit 8c5c0b5
Showing 3 changed files with 22 additions and 15 deletions.
19 changes: 13 additions & 6 deletions src/maxdiffusion/models/attention_flax.py
@@ -494,13 +494,21 @@ def setup(self):
     )
 
   def apply_rope(self, xq: Array, xk: Array, freqs_cis: Array) -> tuple[Array, Array]:
-    xq_ = xq.reshape(*xq.shape[:-1], -1, 1, 2)
-    xk_ = xk.reshape(*xk.shape[:-1], -1, 1, 2)
+    dtype = xq.dtype
+    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
+    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
+
+    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
+    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
 
-    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
-    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
+    freqs_cis = freqs_cis[None, None, ...]
+    xq_out_complex = xq_ * freqs_cis
+    xk_out_complex = xk_ * freqs_cis
 
-    return xq_out.reshape(*xq.shape).astype(xq.dtype), xk_out.reshape(*xk.shape).astype(xk.dtype)
+    xq_out = jnp.stack([jnp.real(xq_out_complex), jnp.imag(xq_out_complex)], axis=-1).reshape(xq.shape).astype(dtype)
+    xk_out = jnp.stack([jnp.real(xk_out_complex), jnp.imag(xk_out_complex)], axis=-1).reshape(xk.shape).astype(dtype)
+
+    return xq_out, xk_out
 
   def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None):
@@ -534,7 +542,6 @@ def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None
     key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
     value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)
 
-    image_rotary_emb = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
     query_proj, key_proj = self.apply_rope(query_proj, key_proj, image_rotary_emb)
 
     query_proj = query_proj.transpose(0, 2, 1, 3).reshape(query_proj.shape[0], query_proj.shape[2], -1)
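For readers comparing the two formulations: the new apply_rope forms a complex number from each (even, odd) channel pair and multiplies by freqs_cis = cos + i*sin, which is the same rotation the removed code expressed with the explicit [cos, -sin, sin, cos] matrix. Below is a minimal standalone sketch of that equivalence; the (batch, heads, seq, head_dim) shapes, toy angles, and function names are illustrative assumptions, not code from this commit.

# Standalone sketch: complex-multiplication RoPE (new path) vs. explicit cos/sin
# rotation (old path). Shapes (B, H, S, D) and the toy angles are assumptions.
import jax
import jax.numpy as jnp

def rope_complex(x, freqs_cis):
  # x: (B, H, S, D) real; freqs_cis: (S, D // 2) complex, equal to cos + i*sin.
  dtype = x.dtype
  pairs = x.astype(jnp.float32).reshape(*x.shape[:-1], -1, 2)      # (B, H, S, D/2, 2)
  x_complex = jax.lax.complex(pairs[..., 0], pairs[..., 1])        # (B, H, S, D/2)
  rotated = x_complex * freqs_cis[None, None, ...]                 # broadcast over B and H
  out = jnp.stack([jnp.real(rotated), jnp.imag(rotated)], axis=-1)
  return out.reshape(x.shape).astype(dtype)

def rope_real(x, cos, sin):
  # The same rotation written out per (even, odd) channel pair.
  pairs = x.astype(jnp.float32).reshape(*x.shape[:-1], -1, 2)
  x0, x1 = pairs[..., 0], pairs[..., 1]
  out = jnp.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], axis=-1)
  return out.reshape(x.shape).astype(x.dtype)

B, H, S, D = 1, 2, 8, 16
x = jax.random.normal(jax.random.PRNGKey(0), (B, H, S, D))
angles = jax.random.uniform(jax.random.PRNGKey(1), (S, D // 2))
freqs_cis = jnp.cos(angles) + 1j * jnp.sin(angles)                 # complex64
assert jnp.allclose(rope_complex(x, freqs_cis),
                    rope_real(x, jnp.cos(angles), jnp.sin(angles)), atol=1e-5)

One consequence of the complex form is that the rearrange into 2x2 matrices is no longer needed at the call sites, which is why the rearrange line above is removed.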
5 changes: 2 additions & 3 deletions src/maxdiffusion/models/embeddings_flax.py
@@ -116,9 +116,8 @@ def get_1d_rotary_pos_embed(
   freqs = jnp.outer(pos, freqs)
   freqs_cos = jnp.cos(freqs)
   freqs_sin = jnp.sin(freqs)
-  out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1)
-
-  return out
+  freq_cis = jnp.complex64(freqs_cos + 1j * freqs_sin)
+  return freq_cis
 
 
 class PixArtAlphaTextProjection(nn.Module):
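get_1d_rotary_pos_embed now returns the rotation as a single complex tensor, cos + i*sin, instead of the stacked [cos, -sin, sin, cos] layout that apply_rope previously unpacked into 2x2 matrices. A small sketch of how the two layouts relate; the frequency schedule below is an assumed toy example, not the repository's, and the jnp.complex64 cast in the diff only pins the dtype that float32 cos/sin plus 1j already produces.

# Sketch relating the old stacked layout to the new complex layout (toy inputs).
import jax.numpy as jnp

pos = jnp.arange(8, dtype=jnp.float32)
base_freqs = 1.0 / (10000.0 ** (jnp.arange(4, dtype=jnp.float32) / 4.0))  # assumed schedule
angles = jnp.outer(pos, base_freqs)                  # corresponds to jnp.outer(pos, freqs) above
cos, sin = jnp.cos(angles), jnp.sin(angles)

old = jnp.stack([cos, -sin, sin, cos], axis=-1)      # previous return: flattened 2x2 rotation matrices
new = cos + 1j * sin                                 # new return (complex64), i.e. e^{i * angles}

# Same rotation, different encoding: real/imag parts recover the cos/sin entries.
assert jnp.allclose(old[..., 0], jnp.real(new)) and jnp.allclose(old[..., 2], jnp.imag(new))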
13 changes: 7 additions & 6 deletions (third changed file)
@@ -130,8 +130,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
     if image_rotary_emb is not None:
       # since this function returns image_rotary_emb and passes it between layers,
       # we do not want to modify it
-      image_rotary_emb_reordered = rearrange(image_rotary_emb, "n d (i j) -> n d i j", i=2, j=2)
-      q, k = self.attn.apply_rope(q, k, image_rotary_emb_reordered)
+      q, k = self.attn.apply_rope(q, k, image_rotary_emb)
 
     q = q.transpose(0, 2, 1, 3).reshape(q.shape[0], q.shape[2], -1)
     k = k.transpose(0, 2, 1, 3).reshape(k.shape[0], k.shape[2], -1)
@@ -147,7 +146,7 @@ def __call__(self, hidden_states, temb, image_rotary_emb=None):
     if hidden_states.dtype == jnp.float16:
       hidden_states = jnp.clip(hidden_states, -65504, 65504)
 
-    return hidden_states
+    return hidden_states, temb, image_rotary_emb
 
 
 class FluxTransformerBlock(nn.Module):
@@ -296,7 +295,7 @@ def __call__(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=
     encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
     if encoder_hidden_states.dtype == jnp.float16:
       encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
-    return hidden_states, encoder_hidden_states
+    return hidden_states, encoder_hidden_states, temb, image_rotary_emb
 
 
 @flax_register_to_config
@@ -504,7 +503,7 @@ def __call__(
     image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))
 
     for double_block in self.double_blocks:
-      hidden_states, encoder_hidden_states = double_block(
+      hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block(
           hidden_states=hidden_states,
           encoder_hidden_states=encoder_hidden_states,
           temb=temb,
@@ -513,7 +512,9 @@ def __call__(
     hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
     for single_block in self.single_blocks:
-      hidden_states = single_block(hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb)
+      hidden_states, temb, image_rotary_emb = single_block(
+          hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
+      )
     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
 
     hidden_states = self.norm_out(hidden_states, temb)
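With these changes the double and single blocks return temb and image_rotary_emb unchanged alongside the hidden states, so every block in the loops above consumes and produces the same tuple. The commit itself does not say why; one common reason for a symmetric carry like this in JAX/Flax code is that it lets a layer stack be driven by scan-style iteration. The following is only a hypothetical sketch of that pattern with stand-in shapes and a toy block, not code from this repository.

# Hypothetical sketch: a uniform (hidden_states, temb, image_rotary_emb) carry
# lets a stack of identical blocks be expressed as a jax.lax.scan.
import jax
import jax.numpy as jnp

def block(carry, _):
  hidden_states, temb, image_rotary_emb = carry
  # Stand-in for a transformer block: update hidden_states, pass the rest through.
  hidden_states = hidden_states + jnp.tanh(hidden_states) * temb[:, None, :]
  return (hidden_states, temb, image_rotary_emb), None

B, S, C = 2, 16, 8
carry = (
    jnp.ones((B, S, C)),                         # hidden_states
    jnp.ones((B, C)),                            # temb
    jnp.ones((S, C // 2), dtype=jnp.complex64),  # image_rotary_emb (complex, as above)
)
carry, _ = jax.lax.scan(block, carry, xs=None, length=4)  # 4 stacked "blocks"
hidden_states, temb, image_rotary_emb = carry
print(hidden_states.shape)  # (2, 16, 8)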
