refactor for multiframework (#62)
Make it easier to support other frameworks by moving all framework-specific
details into one location, and by integrating and simplifying the relevant code.

See e.g. #50

- [x] Move framework-specific code into one location
- [x] Make tests use all backends
- [x] Remove backend-specific marks; instead use '-k' filtering (sketched below)
- [x] Add numpy to relevant tests & CI (and debug any problems that come up)
- [ ] Add a guide for adding an additional framework (to docs/contributing, or somewhere else?)
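
For illustration only (not code from this commit; the test name and assertion are hypothetical), backend parametrization plus `-k` filtering could look like the following, using the `all_backends` list and `get_xnp` added in `cola/backends`:

```python
# Hypothetical sketch: parametrize tests over every backend exposed by
# cola.backends, so `pytest -k torch` (or `-k jax`, `-k numpy`) selects a
# framework by test id instead of a backend-specific mark.
import pytest
from cola.backends import all_backends, get_xnp


@pytest.mark.parametrize("backend", all_backends)
def test_backend_has_control_flow(backend):
    try:
        xnp = get_xnp(backend)
    except RuntimeError:
        pytest.skip(f"{backend} is not installed")
    assert hasattr(xnp, "while_loop")  # illustrative check on the backend namespace
```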

---------

Co-authored-by: AndPotap <[email protected]>
mfinzi and AndPotap authored Sep 19, 2023
1 parent e3db6b7 commit 17dbae6
Showing 45 changed files with 470 additions and 599 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/run_test_suite.yml
@@ -15,8 +15,6 @@ jobs:
            jax: "stable"
            codecov: "upload" # Upload to codecov only if we're testing both jax and torch
        exclude:
-           - torch: "none"
-             jax: "none"
            - torch: "latest"
              jax: "stable"
            #- torch: "stable"
@@ -64,7 +62,7 @@ jobs:
        else
            CODECOV_ARGS="";
        fi
-       cmd="pytest -m '${MARK}'${CODECOV_ARGS} tests/"
+       cmd="pytest -m '${MARK}' ${CODECOV_ARGS} tests/"
        echo $cmd
        eval $cmd
6 changes: 5 additions & 1 deletion README.md
@@ -22,7 +22,7 @@
<!-- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/wilson-labs/cola/blob/master/docs/notebooks/colabs/all.ipynb) -->

CoLA is a framework for scalable linear algebra, automatically exploiting the structure often found in machine learning problems and beyond.
- CoLA supports both PyTorch and JAX.
+ CoLA natively supports PyTorch, Jax, as well as (limited) Numpy if Jax is not installed.

## Installation
```shell
@@ -152,6 +152,10 @@ If you use CoLA, please cite the following paper:
| **Implementation**||||


| Backends | PyTorch | Jax | Numpy |
|:----------------:|:-----------:|:---:|:-------:|
| **Implementation**|||Most operations|
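
A minimal sketch of the limited NumPy path described above (editor illustration, not part of this diff): `lazify` is imported from `cola` elsewhere in this changeset, and `tracing_backends` in `cola/backends/__init__.py` indicates that anything requiring tracing still needs PyTorch or JAX.

```python
# Hedged illustration: wrap a plain NumPy array as a CoLA LinearOperator and
# apply it to a vector; heavier functionality may still require torch/jax.
import numpy as np
import cola

A = cola.lazify(np.eye(3, dtype=np.float32))  # dense NumPy matrix as a LinearOperator
x = np.ones(3, dtype=np.float32)
print(A @ x)  # matrix-vector product through the NumPy backend
```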

## Contributing
See the contributing guidelines [docs/CONTRIBUTING.md](https://cola.readthedocs.io/en/latest/contributing.html) for information on submitting issues
and pull requests.
5 changes: 2 additions & 3 deletions cola/algorithms/arnoldi.py
@@ -1,7 +1,6 @@
from cola.ops import LinearOperator
from cola.ops import Array
from cola.ops import Householder, Product
- from cola.utils.control_flow import for_loop
from cola.utils import export
import cola
from cola import Stiefel, lazify
@@ -111,7 +110,7 @@ def ArnoldiDecomposition(A: LinearOperator, start_vector=None, max_iters=100, to

def get_householder_vec_simple(x, idx, xnp):
    indices = xnp.arange(x.shape[0])
-   vec = xnp.where(indices >= idx, x=x, y=0.)
+   vec = xnp.where(indices >= idx, x, 0.)
    x_norm = xnp.norm(vec)
    vec = xnp.update_array(vec, vec[idx] - x_norm, idx)
    beta = xnp.nan_to_num(2. / xnp.norm(vec)**2., posinf=0., neginf=0., nan=0.)
@@ -170,7 +169,7 @@ def last_iter_fun(state):

    init_val = initialize_householder_arnoldi(xnp, rhs, max_iters=max_iters, dtype=A.dtype)
    # state = xnp.while_loop(cond_fun, body_fun, init_val)
-   state = for_loop(1, max_iters + 1, body_fun, init_val)
+   state = xnp.for_loop(1, max_iters + 1, body_fun, init_val)
    state = last_iter_fun(state)
    Q, H, *_ = state
    infodict = {}
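The `for_loop` used above now comes from the backend namespace (`xnp.for_loop`) instead of `cola.utils.control_flow`. As a hedged sketch of the assumed contract (the call matches JAX's `fori_loop` convention; the real backend versions may compile the loop), an eager-mode equivalent is:

```python
# Assumed eager semantics of xnp.for_loop(lower, upper, body_fun, init_val);
# illustrative only, patterned after jax.lax.fori_loop.
def for_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)  # body receives the loop index and the carried state
    return val
```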
4 changes: 2 additions & 2 deletions cola/algorithms/cg.py
@@ -139,15 +139,15 @@ def update_alpha(gamma, p, Ap, has_converged, xnp):
    denom = xnp.sum(xnp.conj(p) * Ap, axis=-2, keepdims=True)
    alpha = do_safe_div(gamma, denom, xnp=xnp)
    device = xnp.get_device(p)
-   alpha = xnp.where(has_converged, x=xnp.array(0.0, dtype=p.dtype, device=device), y=alpha)
+   alpha = xnp.where(has_converged, xnp.array(0.0, dtype=p.dtype, device=device), alpha)
    return alpha


def update_gamma_beta(r, z, gamma0, has_converged, xnp):
    gamma1 = xnp.sum(xnp.conj(r) * z, axis=-2, keepdims=True)
    beta = do_safe_div(gamma1, gamma0, xnp=xnp)
    device = xnp.get_device(r)
-   beta = xnp.where(has_converged, x=xnp.array(0.0, dtype=r.dtype, device=device), y=beta)
+   beta = xnp.where(has_converged, xnp.array(0.0, dtype=r.dtype, device=device), beta)
    return gamma1, beta


2 changes: 1 addition & 1 deletion cola/algorithms/lanczos.py
@@ -2,7 +2,7 @@
from cola.fns import lazify
from cola.ops import LinearOperator
from cola.ops import Array
- from cola.ops import get_library_fns
+ from cola.backends import get_library_fns
from cola.utils import export
import cola

3 changes: 1 addition & 2 deletions cola/algorithms/svrg.py
@@ -3,7 +3,6 @@
# from cola.linalg.eigs import eigmax
from cola.ops import Sum, Product, Dense
from cola.ops import I_like
- from cola.utils.control_flow import while_loop
from cola.utils import export
# import standard Union type

@@ -99,7 +98,7 @@ def cond(state):
    inf = {}
    # while_loop, inf = xnp.while_loop_winfo(lambda s: s[-3], tol, pbar=pbar)
    # while_loop = xnp.while_loop
-   _, _, anchor_w, _, residual, _, _ = while_loop(cond, body, state)
+   _, _, anchor_w, _, residual, _, _ = xnp.while_loop(cond, body, state)
    return anchor_w, inf if info else anchor_w


8 changes: 8 additions & 0 deletions cola/backends/__init__.py
@@ -0,0 +1,8 @@
""" CoLA Backends"""
from cola.utils import import_from_all

__all__ = []
import_from_all("backends", globals(), __all__, __name__)
all_backends = ["torch", "jax", "numpy"]
tracing_backends = ["torch", "jax"]
__all__ += ["all_backends", "tracing_backends"]
77 changes: 77 additions & 0 deletions cola/backends/backends.py
@@ -0,0 +1,77 @@
from cola.utils import export
from types import ModuleType
import numpy as np


@export
def get_library_fns(dtype):
""" Given a dtype e.g. jnp.float32 or torch.complex64, returns the appropriate
namespace for standard array functionality (either torch_fns or jax_fns)."""
try:
from jax import numpy as jnp
if dtype in [jnp.float32, jnp.float64, jnp.complex64, jnp.complex128, jnp.int32, jnp.int64]:
from cola.backends import jax_fns as fns
return fns
except ImportError:
pass
try:
import torch
if dtype in [torch.float32, torch.float64, torch.complex64, torch.complex128, torch.int32, torch.int64]:
from cola.backends import torch_fns as fns
return fns
except ImportError:
pass

if dtype in [np.float32, np.float64, np.complex64, np.complex128, np.int32, np.int64]:
from cola.backends import np_fns as fns
return fns
raise ImportError("No supported array library found")


@export
def get_xnp(backend: str) -> ModuleType:
    try:
        match backend:
            case "torch":
                from cola.backends import torch_fns as fns
                return fns
            case "jax":
                from cola.backends import jax_fns as fns
                from jax.config import config
                config.update('jax_platform_name', 'cpu')  # Force tests to run tests on CPU
                # do we actually want this here?
                return fns
            case "numpy":
                from cola.backends import np_fns as fns
                return fns
            case _:
                raise ValueError(f"Unknown backend {backend}.")
    except ImportError:
        raise RuntimeError(f"Could not import {backend}. It is likely not installed.")


@export
class AutoRegisteringPyTree(type):
    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        cls._dynamic = cls._dynamic.copy()
        import optree
        optree.register_pytree_node_class(cls, namespace='cola')
        try:
            import jax
            jax.tree_util.register_pytree_node_class(cls)
        except ImportError:
            pass
        try:
            # TODO: when pytorch migrates to optree, switch as well
            import torch

            def tree_flatten(self):
                return self.tree_flatten()

            def tree_unflatten(ctx, children):
                return cls.tree_unflatten(children, ctx)

            torch.utils._pytree._register_pytree_node(cls, tree_flatten, tree_unflatten)
        except ImportError:
            pass
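
For orientation, a short usage sketch of the two entry points defined above (editor illustration, assuming PyTorch is installed):

```python
# Illustrative only: the two ways this module resolves a backend namespace.
import torch
from cola.backends import get_xnp, get_library_fns

xnp = get_xnp("torch")                # select a backend explicitly by name
fns = get_library_fns(torch.float32)  # or dispatch on an array dtype
# Both are expected to resolve to cola.backends.torch_fns here.
```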
10 changes: 7 additions & 3 deletions cola/jax_fns.py → cola/backends/jax_fns.py
@@ -19,7 +19,6 @@
from jax.scipy.linalg import lu as lu_lax
from jax.scipy.linalg import solve_triangular as solvetri
from cola.utils.jax_tqdm import pbar_while, while_loop_winfo
- from cola.utils.control_flow import while_loop as _while_loop_no_jit

cos = jnp.cos
sin = jnp.sin
@@ -67,7 +66,6 @@
qr = qr
clip = jnp.clip
while_loop = _while_loop
- while_loop_no_jit = _while_loop_no_jit
for_loop = _for_loop
min = jnp.min
max = jnp.max
@@ -106,6 +104,13 @@ def iscomplexobj(x):
    return jnp.iscomplex(x).any()


+ def while_loop_no_jit(cond_fun, body_fun, init_val):
+     val = init_val
+     while cond_fun(val):
+         val = body_fun(val)
+     return val


def get_array_device(array):
    return array.device()

@@ -222,7 +227,6 @@ def next_key(key):
def randn(*shape, dtype, device, key=None):
    del device
    if key is None:
-       print('Non keyed randn used. To be deprecated soon.')
        logging.warning('Non keyed randn used. To be deprecated soon.')
        out = np.random.randn(*shape)
        if dtype is not None: