Cleaner handling of attention mask in ltxv model code.
comfyanonymous committed Jan 9, 2025
1 parent 2307ff6 commit ff83865
Showing 1 changed file with 2 additions and 3 deletions.
comfy/ldm/lightricks/model.py (5 changes: 2 additions & 3 deletions)
@@ -456,9 +456,8 @@ def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_l
         x = self.patchify_proj(x)
         timestep = timestep * 1000.0
 
-        attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
-        attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
-        # attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
+        if attention_mask is not None and not torch.is_floating_point(attention_mask):
+            attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
 
         pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
 
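The removed code built the mask in two steps (invert the 0/1 mask, then masked_fill the masked positions with float("-inf")); the new branch collapses this into one arithmetic expression and only runs it when the incoming mask is not already a float bias. For a 0/1 integer mask, (attention_mask - 1) maps kept positions to 0 and masked positions to -1, and multiplying by torch.finfo(x.dtype).max turns the -1 entries into a very large negative additive bias that softmax drives to ~0, without using a literal -inf (which can yield NaNs when an entire row is masked). A minimal standalone sketch of the same transform; the to_additive_mask name and the example values are illustrative, not part of the commit:

import torch

def to_additive_mask(attention_mask, dtype=torch.float32):
    # Same logic as the added lines: leave float masks alone (assumed to
    # already be additive biases) and convert integer 0/1 masks.
    if attention_mask is not None and not torch.is_floating_point(attention_mask):
        # (mask - 1): 1 -> 0 (keep), 0 -> -1 (drop); scaling by finfo.max
        # turns the dropped positions into a huge negative logit bias.
        attention_mask = (attention_mask - 1).to(dtype).reshape(
            (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
        ) * torch.finfo(dtype).max
    return attention_mask

mask = torch.tensor([[1, 1, 0]])  # one sequence of three tokens, last masked
print(to_additive_mask(mask))
# tensor([[[[ 0.0000e+00,  0.0000e+00, -3.4028e+38]]]]) with shape
# (1, 1, 1, 3), ready to broadcast over (batch, heads, q_len, k_len) logits.

One caveat: a torch.bool mask would need an integer cast first, since bool tensors do not support the - operator in PyTorch.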
