diff --git a/torchcomp/core.py b/torchcomp/core.py index f46c48b..d1e5404 100644 --- a/torchcomp/core.py +++ b/torchcomp/core.py @@ -149,9 +149,9 @@ def backward(ctx: Any, grad_y: torch.Tensor) -> Tuple[Optional[torch.Tensor], .. x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1) ) if ctx.needs_input_grad[2]: - grad_at = torch.where(at_mask, grad_combined, 0).sum(1) + grad_at = torch.where(at_mask, grad_combined, 0.0).sum(1) if ctx.needs_input_grad[3]: - grad_rt = torch.where(~at_mask, grad_combined, 0).sum(1) + grad_rt = torch.where(~at_mask, grad_combined, 0.0).sum(1) return grad_x, grad_zi, grad_at, grad_rt @@ -173,8 +173,8 @@ def jvp( else: grad_beta = torch.where( at_mask, - 0 if grad_at is None else grad_at.unsqueeze(1), - 0 if grad_rt is None else grad_rt.unsqueeze(1), + 0.0 if grad_at is None else grad_at.unsqueeze(1), + 0.0 if grad_rt is None else grad_rt.unsqueeze(1), ) fwd_combined = fwd_x + grad_beta * ( x - torch.cat([zi.unsqueeze(1), y[:, :-1]], dim=1)