-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #97 from chrhansk/feature-matrix-intertia
Add inertia control
- Loading branch information
Showing
24 changed files
with
622 additions
and
118 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import scipy as sp | ||
|
||
from pygradflow.params import LinearSolverType | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
def linear_solver( | ||
mat: sp.sparse.spmatrix, solver_type: LinearSolverType, symmetric=False | ||
) -> LinearSolver: | ||
if solver_type == LinearSolverType.LU: | ||
from .lu_solver import LUSolver | ||
|
||
return LUSolver(mat, symmetric=symmetric) | ||
elif solver_type == LinearSolverType.MINRES: | ||
from .minres_solver import MINRESSolver | ||
|
||
return MINRESSolver(mat, symmetric=symmetric) | ||
elif solver_type == LinearSolverType.Cholesky: | ||
from .cholesky_solver import CholeskySolver | ||
|
||
return CholeskySolver(mat, symmetric=symmetric) | ||
elif solver_type == LinearSolverType.MA57: | ||
from .ma57_solver import MA57Solver | ||
|
||
return MA57Solver(mat, symmetric=symmetric) | ||
elif solver_type == LinearSolverType.MUMPS: | ||
from .mumps_solver import MUMPSSolver | ||
|
||
return MUMPSSolver(mat, symmetric=symmetric) | ||
elif solver_type == LinearSolverType.SSIDS: | ||
from .ssids_solver import SSIDSSolver | ||
|
||
return SSIDSSolver(mat, symmetric=symmetric) | ||
else: | ||
from .gmres_solver import GMRESSolver | ||
|
||
assert solver_type == LinearSolverType.GMRES | ||
return GMRESSolver(mat, symmetric=symmetric) | ||
|
||
|
||
__all__ = ["linear_solver", "LinearSolverError", "LinearSolver"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import sksparse.cholmod | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
class CholeskySolver(LinearSolver): | ||
def __init__(self, mat, symmetric=False): | ||
super().__init__(mat, symmetric=symmetric) | ||
|
||
try: | ||
self.factor = sksparse.cholmod.cholesky(mat) | ||
self.factor.L() | ||
except sksparse.cholmod.CholmodNotPositiveDefiniteError as e: | ||
raise LinearSolverError() from e | ||
|
||
def solve(self, rhs, trans=False, initial_sol=None): | ||
assert not trans | ||
|
||
return self.factor.solve_A(rhs) | ||
|
||
def num_neg_eigvals(self): | ||
return 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import numpy as np | ||
import scipy as sp | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
class GMRESSolver(LinearSolver): | ||
def __init__(self, mat: sp.sparse.spmatrix, symmetric=False) -> None: | ||
super().__init__(mat, symmetric=symmetric) | ||
self.mat = mat | ||
|
||
def solve(self, rhs, trans=False, initial_sol=None): | ||
mat = self.mat.T if trans else self.mat | ||
|
||
if initial_sol is not None: | ||
initial_sol = initial_sol() | ||
|
||
(n, _) = mat.shape | ||
|
||
atol = 1e-8 | ||
|
||
# Workaround for scipy bug | ||
if initial_sol is not None: | ||
res = rhs - mat @ initial_sol | ||
if np.linalg.norm(res, ord=np.inf) < atol: | ||
return initial_sol | ||
|
||
result = sp.sparse.linalg.gmres(mat, rhs, maxiter=n, x0=initial_sol, atol=atol) | ||
|
||
(sol, info) = result | ||
|
||
if info != 0: | ||
raise LinearSolverError("GMRES failed with error code {}".format(info)) | ||
|
||
return sol |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from abc import ABC, abstractmethod | ||
from typing import Optional | ||
|
||
import scipy as sp | ||
from numpy import ndarray | ||
|
||
|
||
class LinearSolverError(Exception): | ||
""" | ||
Error signaling that the linear solver failed, e.g. because the | ||
matrix is (near) singular. The solver attempts to recover by | ||
reducing the step size | ||
""" | ||
|
||
pass | ||
|
||
|
||
class LinearSolver(ABC): | ||
|
||
def __init__(self, matrix: sp.sparse.spmatrix, symmetric=False): | ||
self.symmetric = symmetric | ||
|
||
@abstractmethod | ||
def solve(self, rhs: ndarray, trans: bool = False, initial_sol=False) -> ndarray: | ||
raise NotImplementedError() | ||
|
||
def num_neg_eigvals(self) -> Optional[int]: | ||
return None | ||
|
||
def rcond(self) -> Optional[float]: | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import scipy as sp | ||
|
||
from pygradflow.log import logger | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
class LUSolver(LinearSolver): | ||
def __init__(self, mat: sp.sparse.spmatrix, symmetric=False) -> None: | ||
super().__init__(mat, symmetric=symmetric) | ||
|
||
self.mat = mat | ||
try: | ||
self.solver = sp.sparse.linalg.splu(mat) | ||
except RuntimeError as err: | ||
logger.warn("LU decomposition failed: %s", err) | ||
raise LinearSolverError("LU decomposition failed") | ||
|
||
def solve(self, rhs, trans=False, initial_sol=None): | ||
trans_str = "T" if trans else "N" | ||
return self.solver.solve(rhs, trans=trans_str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from typing import Optional | ||
|
||
from pyomo.contrib.pynumero.linalg.base import LinearSolverStatus | ||
from pyomo.contrib.pynumero.linalg.ma57_interface import MA57 | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
class MA57Solver(LinearSolver): | ||
def __init__(self, mat, symmetric=False, report_rcond=False): | ||
super().__init__(mat, symmetric=symmetric) | ||
|
||
self.mat = mat.tocoo() | ||
self.solver = MA57() | ||
self.report_rcond = report_rcond | ||
|
||
# if report_rcond: | ||
# self.solver.set_icntl(10, 1) | ||
|
||
status = self.solver.do_symbolic_factorization(self.mat) | ||
|
||
if status.status != LinearSolverStatus.successful: | ||
raise LinearSolverError("Failed to compute symbolic factorization") | ||
|
||
status = self.solver.do_numeric_factorization(self.mat) | ||
|
||
if status.status != LinearSolverStatus.successful: | ||
raise LinearSolverError("Failed to compute numeric factorization") | ||
|
||
def solve(self, rhs, trans=False, initial_sol=None): | ||
from pyomo.contrib.pynumero.linalg.base import LinearSolverStatus | ||
|
||
assert not trans | ||
|
||
x, status = self.solver.do_back_solve(rhs) | ||
|
||
if status.status != LinearSolverStatus.successful: | ||
raise LinearSolverError("Failed to compute solution") | ||
|
||
return x | ||
|
||
def num_neg_eigvals(self): | ||
return self.solver.get_info(24) | ||
|
||
def rcond(self) -> Optional[float]: | ||
# TODO: Find out how to get rcond from MA57 | ||
# if self.report_rcond: | ||
# return self.solver.get_info(27) | ||
|
||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import scipy as sp | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
class MINRESSolver(LinearSolver): | ||
def __init__(self, mat, symmetric=False): | ||
super().__init__(mat, symmetric=symmetric) | ||
assert symmetric, "MINRES requires a symmetric matrix" | ||
self.mat = mat | ||
|
||
def solve(self, rhs, trans=False, initial_sol=None): | ||
# matrix should be symmetric anyways | ||
if initial_sol is not None: | ||
initial_sol = initial_sol() | ||
|
||
result = sp.sparse.linalg.minres(self.mat, rhs, x0=initial_sol) | ||
|
||
(sol, info) = result | ||
|
||
if info != 0: | ||
raise LinearSolverError("MINRES failed with error code {}".format(info)) | ||
|
||
return sol |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import warnings | ||
|
||
import mumps | ||
import numpy as np | ||
import scipy as sp | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
class MUMPSSolver(LinearSolver): | ||
def __init__(self, mat: sp.sparse.spmatrix, symmetric=False) -> None: | ||
super().__init__(mat, symmetric=symmetric) | ||
|
||
sym = 2 if symmetric else 0 | ||
|
||
self.ctx = mumps.DMumpsContext(sym=sym, par=1) | ||
|
||
self.ctx.set_icntl(13, 1) | ||
|
||
self.mat = mat | ||
|
||
if mat.format != "coo": | ||
warnings.warn( | ||
"Converting matrix to COO format", sp.sparse.SparseEfficiencyWarning | ||
) | ||
self.mat = mat.tocoo() | ||
|
||
rows = self.mat.row | ||
cols = self.mat.col | ||
data = self.mat.data | ||
|
||
if symmetric: | ||
filter = rows >= cols | ||
rows = rows[filter] | ||
cols = cols[filter] | ||
data = data[filter] | ||
|
||
if data.dtype != np.float64: | ||
warnings.warn( | ||
"Converting matrix data to float64", sp.sparse.SparseEfficiencyWarning | ||
) | ||
data = data.astype(np.float64) | ||
|
||
self.ctx.set_shape(self.mat.shape[0]) | ||
self.ctx.set_centralized_assembled(rows + 1, cols + 1, data) | ||
|
||
# Analysis | ||
self.ctx.run(job=1) | ||
|
||
# Factorization | ||
self.ctx.run(job=2) | ||
|
||
def solve(self, rhs, trans=False, initial_sol=None): | ||
sol = np.copy(rhs) | ||
|
||
if sol.dtype != np.float64: | ||
warnings.warn( | ||
"Converting rhs to float64", sp.sparse.SparseEfficiencyWarning | ||
) | ||
sol = sol.astype(np.float64) | ||
|
||
if not self.symmetric: | ||
if trans: | ||
self.ctx.set_icntl(9, 0) | ||
else: | ||
self.ctx.set_icntl(9, 1) | ||
|
||
self.ctx.set_rhs(sol) | ||
|
||
try: | ||
# Solution | ||
self.ctx.run(job=3) | ||
except RuntimeError as e: | ||
raise LinearSolverError from e | ||
|
||
return sol.astype(rhs.dtype) | ||
|
||
def num_neg_eigvals(self): | ||
return self.ctx.id.infog[11] | ||
|
||
def rcond(self): | ||
return None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import pyspral.ssids as ssids | ||
|
||
from .linear_solver import LinearSolver, LinearSolverError | ||
|
||
|
||
class SSIDSSolver(LinearSolver): | ||
def __init__(self, mat, symmetric=False, report_rcond=False): | ||
super().__init__(mat, symmetric=symmetric) | ||
|
||
try: | ||
self.symbolic_factor = ssids.analyze(mat, check=True) | ||
self.numeric_factor = self.symbolic_factor.factor(posdef=False) | ||
except ssids.SSIDSError as e: | ||
raise LinearSolverError() from e | ||
|
||
def solve(self, rhs, trans=False, initial_sol=None): | ||
try: | ||
return self.numeric_factor.solve(rhs, inplace=False) | ||
except ssids.SSIDSError as e: | ||
raise LinearSolverError() from e | ||
|
||
def num_neg_eigvals(self): | ||
return self.numeric_factor.inform.num_neg |
Oops, something went wrong.