fix: floating point zeros for old torch.where
yoyolicoris authored Jun 28, 2024
1 parent 6ed632a commit cbe46c5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions torchcomp/core.py
@@ -149,9 +149,9 @@ def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], ...
             x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
         )
         if ctx.needs_input_grad[2]:
-            grad_at = torch.where(at_mask, grad_combined, 0).sum(1)
+            grad_at = torch.where(at_mask, grad_combined, 0.0).sum(1)
         if ctx.needs_input_grad[3]:
-            grad_rt = torch.where(~at_mask, grad_combined, 0).sum(1)
+            grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1)
 
         return grad_x, grad_zi, grad_at, grad_rt

@@ -173,8 +173,8 @@ def jvp(
         else:
             grad_beta = torch.where(
                 at_mask,
-                0 if grad_at is None else grad_at.unsqueeze(1),
-                0 if grad_rt is None else grad_rt.unsqueeze(1),
+                0.0 if grad_at is None else grad_at.unsqueeze(1),
+                0.0 if grad_rt is None else grad_rt.unsqueeze(1),
             )
             fwd_combined = fwd_x + grad_beta * (
                 x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)
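
Context for the change (not stated in the commit itself, so treat it as an assumption): the edit swaps the Python integer literal 0 for the float literal 0.0 in the torch.where calls, presumably because some older PyTorch releases do not promote an integer scalar against a floating-point tensor branch inside torch.where. A minimal, hypothetical sketch of the pattern, with illustrative values not taken from the repository:

    import torch

    mask = torch.tensor([True, False, True])
    grad = torch.tensor([0.5, -1.0, 2.0])

    # May raise a dtype error on older PyTorch versions (assumed behaviour):
    # torch.where(mask, grad, 0)

    # The float scalar matches the tensor's floating dtype, so it behaves
    # the same across versions:
    out = torch.where(mask, grad, 0.0)
    print(out)  # tensor([0.5000, 0.0000, 2.0000])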
