Skip to content

Commit

Permalink
add functuion to compute lstsq on cuda torch.tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
florentiner authored and SuperSashka committed Aug 26, 2024
1 parent c2dd5fb commit 3701592
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions tedeous/optimizers/ngd.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,27 @@ def jacobian() -> torch.Tensor:

J = jacobian()
return 1.0 / len(residuals) * J.T @ J

def torch_cuda_lstsq(self, A: torch.Tensor, B: torch.Tensor, tol: float = None) -> torch.Tensor:
""" Find lstsq (least-squares solution) for torch.tensor cuda.
Args:
A (torch.Tensor): lhs tensor of shape (*, m, n) where * is zero or more batch dimensions.
B (torch.Tensor): rhs tensor of shape (*, m, k) where * is zero or more batch dimensions.
tol (float): used to determine the effective rank of A. By default set to the machine precision of the dtype of A.
Returns:
torch.Tensor: solution for A and B.
"""
tol = torch.finfo(A.dtype).eps if tol is None else tol
U, S, Vh = torch.linalg.svd(A, full_matrices=False)
Spinv = torch.zeros_like(S)
Spinv[S>tol] = 1/S[S>tol]
UhB = U.adjoint() @ B
if Spinv.ndim!=UhB.ndim:
Spinv = Spinv.unsqueeze(-1)
SpinvUhB = Spinv * UhB
return Vh.adjoint() @ SpinvUhB

def step(self, closure=None) -> torch.Tensor:
""" It runs ONE step on the natural gradient descent.
Expand All @@ -98,10 +119,7 @@ def step(self, closure=None) -> torch.Tensor:
G = torch.min(torch.tensor([loss, 0.0])) * Id + G

# compute natural gradient
G = np.array(G.detach().cpu().numpy(), dtype=np.float32)
f_grads = np.array(f_grads.detach().cpu().numpy(), dtype=np.float32)
f_nat_grad = lstsq(G, f_grads)[0]
f_nat_grad = torch.from_numpy(np.array(f_nat_grad)).to(torch.float32).to('cuda')
f_nat_grad = self.torch_cuda_lstsq(G, f_grads)

# one step of NGD
self.grid_line_search_update(loss_function, f_nat_grad)
Expand Down

0 comments on commit 3701592

Please sign in to comment.