Skip to content

Commit

Permalink
Merge branch 'main' of github.com:jwohlwend/boltz
Browse files Browse the repository at this point in the history
  • Loading branch information
jwohlwend committed Nov 28, 2024
2 parents f2901a9 + 3ad510a commit 2ae8d3a
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions src/boltz/model/layers/outer_product_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def forward(self, m: Tensor, mask: Tensor, chunk_size: int = None) -> Tensor:
z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
return z_out
else:
mask = mask[:, :, None, :] * mask[:, :, :, None]
num_mask = mask.sum(1).clamp(min=1)
z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
z = z.reshape(*z.shape[:3], -1)
Expand Down

0 comments on commit 2ae8d3a

Please sign in to comment.