Possible speedup for linalg.solve #79

Open
ksnxr opened this issue Nov 18, 2023 · 2 comments


ksnxr commented Nov 18, 2023

Looking at the code, linalg.solve is implemented by explicitly computing the inverse and multiplying it by a vector. Is there a reason not to use something similar to scipy.linalg.solve to achieve this?

import cola
import numpy as np
import scipy as sp
import timeit

A = np.random.randn(2,2)
A = A @ A.T
v = np.random.randn(2)

cola_A = cola.ops.Dense(A)

print(timeit.timeit(lambda: cola.linalg.solve(cola_A, v), number=10000))
print(timeit.timeit(lambda: sp.linalg.solve(A, v), number=10000))

The outputs (total seconds for 10000 calls; cola first, then scipy) are

14.14389366703108
0.0880865000654012

I could also try to submit a pull request later. Thanks
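One way to check whether a gap like this is fixed per-call overhead or grows with problem size is to repeat the timing over several matrix sizes. The sketch below does this for NumPy's dense solver only; the helper name `time_solver`, the sizes, and the repeat count are assumptions for illustration, and any other solver with the same `(A, v)` call shape could be passed in the same way.

```python
import timeit

import numpy as np


def time_solver(solver, n, number=200, seed=0):
    """Return the mean seconds per call of solver(A, v) on an n x n SPD system."""
    rng = np.random.default_rng(seed)
    M = rng.standard_normal((n, n))
    A = M @ M.T + n * np.eye(n)  # symmetric positive definite, as in the snippet above
    v = rng.standard_normal(n)
    total = timeit.timeit(lambda: solver(A, v), number=number)
    return total / number


# Mean time per call for a few sizes (hypothetical choice of sizes).
timings = {n: time_solver(np.linalg.solve, n) for n in (2, 16, 128)}
for n, seconds in timings.items():
    print(f"n={n:4d}: {seconds * 1e6:.1f} microseconds per solve")
```

If the per-call time of a solver stays roughly constant as n grows from 2 to 128, the cost is most likely dispatch overhead rather than the linear algebra itself.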

Collaborator

mfinzi commented Nov 20, 2023

Hi @ksnxr ,

Thanks for opening the issue.

solve(A) (with the default Auto algorithm) uses one of Cholesky, LU, CG, or GMRES to solve the linear system (see inverse/inv.py#L71-L89).
None of them forms the inverse matrix explicitly; for the 2x2 example, cola will call the LU solve.

The matrix listed is very small, but the runtime difference is a bit surprising, so we should look into it.
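For reference, the difference between factorization-based, iterative, and explicit-inverse solves can be sketched in plain NumPy. This is only an illustration and not CoLA's implementation: the `conjugate_gradient` helper, its parameters, and the test matrix are assumptions made for the example.

```python
import numpy as np


def conjugate_gradient(matvec, b, tol=1e-10, max_iters=100):
    """Minimal CG for a symmetric positive definite operator given only as a matvec."""
    x = np.zeros_like(b)
    r = b.copy()  # residual b - A x, with x = 0
    p = r.copy()  # search direction
    rs = r @ r
    for _ in range(max_iters):
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


rng = np.random.default_rng(0)
M = rng.standard_normal((4, 4))
A = M @ M.T + 4 * np.eye(4)  # symmetric positive definite
b = rng.standard_normal(4)

# LU-based solve: np.linalg.solve factorizes A and never forms the inverse.
x_lu = np.linalg.solve(A, b)

# Cholesky-based solve: A = L L^T, then solve with L and with L^T.
# (np.linalg.solve does not exploit the triangular structure; this only shows the steps.)
L = np.linalg.cholesky(A)
y = np.linalg.solve(L, b)
x_chol = np.linalg.solve(L.T, y)

# Iterative solve: only products A @ p are needed, so A could be any linear operator.
x_cg = conjugate_gradient(lambda p: A @ p, b)

# Explicit inverse, shown only for comparison.
x_inv = np.linalg.inv(A) @ b
```

All four approaches agree up to rounding on a well-conditioned system; they differ in cost and in whether they need the dense matrix at all.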

Author

ksnxr commented Nov 23, 2023

I played around a bit. Each backend defines its own "solve", e.g.

solve = jnp.linalg.solve

Directly replacing the solve function with the version below seems to make the running times comparable to the scipy version.

@export
def solve(A, b, alg=Auto()):
    """ Computes Linear solve of a linear operator. Equivalent to cola.inv

    Args:
        A (LinearOperator): The linear operator to compute the inverse of.
        b (Array): The right hand side of the linear system of shape (d, ) or (d, k)
        alg (Algorithm, optional): The algorithm to use for the solves.

    Returns:
        Array: The solution of the linear system of shape (d, ) or (d, k)

    Example:
        >>> A = MyLinearOperator()
        >>> x = cola.solve(A, b, alg=Auto(max_iters=10, pbar=True))
    """
    xnp = A.xnp
    # Densify the operator and call the backend's dense solver directly; alg is ignored here.
    return xnp.solve(A.to_dense(), b)

Looking at the code for inv, it seems complicated to treat the individual cases properly, so I might leave it for now.
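One possible middle ground, sketched here in plain NumPy, is to take the direct dense path only for small systems and otherwise defer to whatever dispatch already exists. Everything in this sketch is hypothetical: the function name `solve_with_dense_fast_path`, the `max_dense_dim` threshold, and the `fallback` callable are illustrative stand-ins and are not part of CoLA.

```python
import numpy as np


def solve_with_dense_fast_path(A, b, fallback, max_dense_dim=512):
    """Use a direct dense solve for small systems, otherwise call fallback(A, b).

    Hypothetical helper: fallback stands in for an existing algorithm dispatch.
    """
    A = np.asarray(A)
    if A.shape[0] <= max_dense_dim:
        return np.linalg.solve(A, b)
    return fallback(A, b)


calls = []


def recording_fallback(A, b):
    # Stand-in for the general path; records that it was used.
    calls.append(A.shape)
    return np.linalg.lstsq(A, b, rcond=None)[0]


A_small = np.array([[2.0, 1.0], [1.0, 3.0]])
b_small = np.array([1.0, 2.0])

# Small system: takes the dense fast path, fallback is not called.
x_small = solve_with_dense_fast_path(A_small, b_small, recording_fallback)

# Lowering the threshold forces the fallback path for the same system.
x_forced = solve_with_dense_fast_path(A_small, b_small, recording_fallback, max_dense_dim=1)
```

A threshold like this avoids densifying large structured operators, which an unconditional to_dense() call would do.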
