diff --git a/tedeous/optimizers/ngd.py b/tedeous/optimizers/ngd.py
index 7de51907..86389e92 100644
--- a/tedeous/optimizers/ngd.py
+++ b/tedeous/optimizers/ngd.py
@@ -73,6 +73,27 @@ def jacobian() -> torch.Tensor:
         J = jacobian()
         return 1.0 / len(residuals) * J.T @ J
+
+    def torch_cuda_lstsq(self, A: torch.Tensor, B: torch.Tensor, tol: float = None) -> torch.Tensor:
+        """ Find lstsq (least-squares solution) for torch.tensor cuda.
+
+        Args:
+            A (torch.Tensor): lhs tensor of shape (*, m, n) where * is zero or more batch dimensions.
+            B (torch.Tensor): rhs tensor of shape (*, m, k) where * is zero or more batch dimensions.
+            tol (float): used to determine the effective rank of A. By default set to the machine precision of the dtype of A.
+
+        Returns:
+            torch.Tensor: solution for A and B.
+        """
+        tol = torch.finfo(A.dtype).eps if tol is None else tol
+        U, S, Vh = torch.linalg.svd(A, full_matrices=False)
+        Spinv = torch.zeros_like(S)
+        Spinv[S>tol] = 1/S[S>tol]
+        UhB = U.adjoint() @ B
+        if Spinv.ndim!=UhB.ndim:
+            Spinv = Spinv.unsqueeze(-1)
+        SpinvUhB = Spinv * UhB
+        return Vh.adjoint() @ SpinvUhB
 
     def step(self, closure=None) -> torch.Tensor:
         """ It runs ONE step on the natural gradient descent.
 
@@ -98,10 +119,7 @@ def step(self, closure=None) -> torch.Tensor:
         G = torch.min(torch.tensor([loss, 0.0])) * Id + G
 
         # compute natural gradient
-        G = np.array(G.detach().cpu().numpy(), dtype=np.float32)
-        f_grads = np.array(f_grads.detach().cpu().numpy(), dtype=np.float32)
-        f_nat_grad = lstsq(G, f_grads)[0]
-        f_nat_grad = torch.from_numpy(np.array(f_nat_grad)).to(torch.float32).to('cuda')
+        f_nat_grad = self.torch_cuda_lstsq(G, f_grads)
 
         # one step of NGD
         self.grid_line_search_update(loss_function, f_nat_grad)
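For reviewers: the new torch_cuda_lstsq replaces the NumPy/CPU round-trip with a minimum-norm least-squares solve kept on the GPU, computed via the thin-SVD pseudoinverse (A+ B = V S+ U^H B) with singular values at or below tol discarded. The snippet below is a standalone sketch of that idea plus a CPU cross-check against torch.linalg.lstsq; the helper name svd_lstsq and the random test problem are illustrative only and are not part of the patch.

# Standalone sketch of the SVD-based least-squares solve introduced above.
# Assumptions: the name svd_lstsq and the random test data are illustrative;
# only the solve logic mirrors the patch.
import torch

def svd_lstsq(A: torch.Tensor, B: torch.Tensor, tol: float = None) -> torch.Tensor:
    # Thin SVD: A = U diag(S) Vh, so the pseudoinverse solution is V diag(S^+) U^H B.
    tol = torch.finfo(A.dtype).eps if tol is None else tol
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    Spinv = torch.zeros_like(S)
    Spinv[S > tol] = 1 / S[S > tol]      # invert only singular values above tol
    UhB = U.adjoint() @ B
    if Spinv.ndim != UhB.ndim:           # broadcast S^+ across the columns of U^H B
        Spinv = Spinv.unsqueeze(-1)
    return Vh.adjoint() @ (Spinv * UhB)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
A = torch.randn(200, 50, device=device)
B = torch.randn(200, 1, device=device)

x = svd_lstsq(A, B)
x_ref = torch.linalg.lstsq(A.cpu(), B.cpu()).solution   # CPU reference solve
print(torch.allclose(x.cpu(), x_ref, atol=1e-4))         # expect True

A side note on the design choice: on CUDA inputs torch.linalg.lstsq only supports the gels driver, which assumes A has full rank, whereas the SVD pseudoinverse also handles a rank-deficient Gram matrix G; that is presumably why the patch solves via SVD instead of calling torch.linalg.lstsq directly on the GPU.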