-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
revert spmv to original jax implementation
- Loading branch information
1 parent
1290fc6
commit 5595a87
Showing
1 changed file
with
22 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Sparse Matrix-Vector Multiplication (SpMV) | ||
import jax.numpy as jnp | ||
import jax | ||
from jax import lax | ||
|
||
# Matrix-Vector Multiplication with the matrix given in Compressed Sparse Row | ||
# (CSR) format | ||
@jax.jit | ||
def spmv(A_row, A_col, A_val, x): | ||
y = jnp.empty(A_row.size - 1, dtype=A_val.dtype) | ||
|
||
def row_update(i, y): | ||
|
||
mask = (jnp.arange(A_col.size) >= A_row[i]) & (jnp.arange(A_col.size) < A_row[i + 1]) | ||
cols = jnp.where(mask, A_col, 0) | ||
vals = jnp.where(mask, A_val, 0) | ||
y = y.at[i].set(vals @ x[cols]) | ||
|
||
return y | ||
|
||
y = lax.fori_loop(0, A_row.size - 1, row_update, y) | ||
return y |