Skip to content

Commit

Permalink
revert spmv to original jax implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
hardik01shah committed Feb 3, 2025
1 parent 1290fc6 commit 5595a87
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions npbench/benchmarks/spmv/spmv_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Sparse Matrix-Vector Multiplication (SpMV)
import jax.numpy as jnp
import jax
from jax import lax

# Matrix-Vector Multiplication with the matrix given in Compressed Sparse Row
# (CSR) format
@jax.jit
def spmv(A_row, A_col, A_val, x):
y = jnp.empty(A_row.size - 1, dtype=A_val.dtype)

def row_update(i, y):

mask = (jnp.arange(A_col.size) >= A_row[i]) & (jnp.arange(A_col.size) < A_row[i + 1])
cols = jnp.where(mask, A_col, 0)
vals = jnp.where(mask, A_val, 0)
y = y.at[i].set(vals @ x[cols])

return y

y = lax.fori_loop(0, A_row.size - 1, row_update, y)
return y

0 comments on commit 5595a87

Please sign in to comment.