diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 442a814c3f1..ec4119722d3 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -93,9 +93,9 @@ def apply_rotary_emb( and key tensor with rotary embeddings. """ - t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2).float() + t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2) t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x_in.shape).type_as(x_in) + return t_out.reshape(*x_in.shape) def forward( self, @@ -552,7 +552,7 @@ def patchify_and_embed( position_ids[i, cap_len:cap_len+img_len, 1] = row_ids position_ids[i, cap_len:cap_len+img_len, 2] = col_ids - freqs_cis = self.rope_embedder(position_ids).movedim(1, 2) + freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype) # build freqs_cis for cap and image individually cap_freqs_cis_shape = list(freqs_cis.shape)