Skip to content

Commit

Permalink
Merge pull request #31 from hardik01shah/main
Browse files Browse the repository at this point in the history
Add JAX Support to NPBench and Implement JAX Benchmarks
  • Loading branch information
alexnick83 authored Feb 5, 2025
2 parents 1a9e6f8 + 5595a87 commit bbba077
Show file tree
Hide file tree
Showing 64 changed files with 2,105 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# dace
.dacecache/
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ python plot_results.py
Currently, the following frameworks are supported (in alphabetical order):
- CuPy
- DaCe
- JAX
- Numba
- NumPy
- Pythran
Expand Down Expand Up @@ -55,6 +56,24 @@ However, you may want to install the latest version from the [GitHub repository]
To run NPBench with DaCe, you have to select as framework (see details below)
either `dace_cpu` or `dace_gpu`.

### Jax

JAX can be installed with pip:
- CPU-only (Linux/macOS/Windows)
```sh
pip install -U jax
```
- GPU (NVIDIA, CUDA 12)
```sh
pip install -U "jax[cuda12]"
```
- TPU (Google Cloud TPU VM)
```sh
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
For more installation options, please consult the JAX [installation guide](https://jax.readthedocs.io/en/latest/installation.html#installation).


### Numba

Numba can be installed with pip:
Expand Down
10 changes: 10 additions & 0 deletions framework_info/jax.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"framework": {
"simple_name": "jax",
"full_name": "Jax",
"prefix": "jax",
"postfix": "jax",
"class": "JaxFramework",
"arch": "cpu"
}
}
19 changes: 19 additions & 0 deletions npbench/benchmarks/azimint_hist/azimint_hist_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright 2014 Jérôme Kieffer et al.
# This is an open-access article distributed under the terms of the
# Creative Commons Attribution License, which permits unrestricted use,
# distribution, and reproduction in any medium, provided the original author
# and source are credited.
# http://creativecommons.org/licenses/by/3.0/
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for
# high performance azimuthal integration on gpu, 2014. In Proceedings of the
# 7th European Conference on Python in Science (EuroSciPy 2014).

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(2,))
def azimint_hist(data: jax.Array, radius: jax.Array, npt):
histu = jnp.histogram(radius, npt)[0]
histw = jnp.histogram(radius, npt, weights=data)[0]
return histw / histu
32 changes: 32 additions & 0 deletions npbench/benchmarks/azimint_naive/azimint_naive_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2014 Jérôme Kieffer et al.
# This is an open-access article distributed under the terms of the
# Creative Commons Attribution License, which permits unrestricted use,
# distribution, and reproduction in any medium, provided the original author
# and source are credited.
# http://creativecommons.org/licenses/by/3.0/
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for
# high performance azimuthal integration on gpu, 2014. In Proceedings of the
# 7th European Conference on Python in Science (EuroSciPy 2014).

import jax
import jax.numpy as jnp
from jax import lax
from functools import partial


@partial(jax.jit, static_argnums=(2,))
def azimint_naive(data, radius, npt):
rmax = radius.max()
res = jnp.zeros(npt, dtype=jnp.float64)

def loop_body(i, res):
r1 = rmax * i / npt
r2 = rmax * (i + 1) / npt
mask_r12 = jnp.logical_and((r1 <= radius), (radius < r2))
mean = jnp.where(mask_r12, data, 0).mean(where=mask_r12)
res = res.at[i].set(mean)
return res

res = lax.fori_loop(0, npt, loop_body, res)

return res
102 changes: 102 additions & 0 deletions npbench/benchmarks/cavity_flow/cavity_flow_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# Barba, Lorena A., and Forsyth, Gilbert F. (2018).
# CFD Python: the 12 steps to Navier-Stokes equations.
# Journal of Open Source Education, 1(9), 21,
# https://doi.org/10.21105/jose.00021
# TODO: License
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth.
# All content is under Creative Commons Attribution CC-BY 4.0,
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018).

import jax.numpy as jnp
import jax
from jax import lax
from functools import partial


@partial(jax.jit, static_argnums=(1,))
def build_up_b(b, rho, dt, u, v, dx, dy):

b = b.at[1:-1,
1:-1].set(rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) +
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) -
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 *
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) *
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) -
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))

return b


@partial(jax.jit, static_argnums=(0,))
def pressure_poisson(nit, p, dx, dy, b):
def body_func(p, _):
pn = p.copy()
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 +
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) /
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])

p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0
p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0
p = p.at[-1, :].set(0) # p = 0 at y = 2

return p, None

p, _ = lax.scan(body_func, p, jnp.arange(nit))

return p


@partial(jax.jit, static_argnums=(0,1,2,3,10,11,))
def cavity_flow(nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu):
b = jnp.zeros((ny, nx))
array_vals = (u, v, p, b)

def body_func(array_vals, _):

u, v, p, b = array_vals

un = u.copy()
vn = v.copy()

b = build_up_b(b, rho, dt, u, v, dx, dy)
p = pressure_poisson(nit, p, dx, dy, b)

u = u.at[1:-1,
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) *
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu *
(dt / dx**2 *
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) +
dt / dy**2 *
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])))

v = v.at[1:-1,
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) *
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu *
(dt / dx**2 *
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) +
dt / dy**2 *
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))

u = u.at[0, :].set(0)
u = u.at[:, 0].set(0)
u = u.at[:, -1].set(0)
u = u.at[-1, :].set(1) # set velocity on cavity lid equal to 1
v = v.at[0, :].set(0)
v = v.at[-1, :].set(0)
v = v.at[:, 0].set(0)
v = v.at[:, -1].set(0)

return (u, v, p, b), None

out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt))
u, v, p, b = out_vals

return u, v, p
172 changes: 172 additions & 0 deletions npbench/benchmarks/channel_flow/channel_flow_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Barba, Lorena A., and Forsyth, Gilbert F. (2018).
# CFD Python: the 12 steps to Navier-Stokes equations.
# Journal of Open Source Education, 1(9), 21,
# https://doi.org/10.21105/jose.00021
# TODO: License
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth.
# All content is under Creative Commons Attribution CC-BY 4.0,
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018).

import jax.numpy as jnp
import jax
from jax import lax
from functools import partial


@partial(jax.jit, static_argnums=(0,))
def build_up_b(rho, dt, dx, dy, u, v):
b = jnp.zeros_like(u)
b = b.at[1:-1,
1:-1].set((rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) +
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) -
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 *
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) *
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) -
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)))

# Periodic BC Pressure @ x = 2
b = b.at[1:-1, -1].set((rho * (1 / dt * ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) +
(v[2:, -1] - v[0:-2, -1]) / (2 * dy)) -
((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx))**2 - 2 *
((u[2:, -1] - u[0:-2, -1]) / (2 * dy) *
(v[1:-1, 0] - v[1:-1, -2]) / (2 * dx)) -
((v[2:, -1] - v[0:-2, -1]) / (2 * dy))**2)))

# Periodic BC Pressure @ x = 0
b = b.at[1:-1, 0].set((rho * (1 / dt * ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx) +
(v[2:, 0] - v[0:-2, 0]) / (2 * dy)) -
((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx))**2 - 2 *
((u[2:, 0] - u[0:-2, 0]) / (2 * dy) *
(v[1:-1, 1] - v[1:-1, -1]) /
(2 * dx)) - ((v[2:, 0] - v[0:-2, 0]) / (2 * dy))**2)))

return b

@partial(jax.jit, static_argnums=(0,))
def pressure_poisson_periodic(nit, p, dx, dy, b):

def body_func(p, q):
pn = p.copy()
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 +
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) /
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1])

# Periodic BC Pressure @ x = 2
p = p.at[1:-1, -1].set(((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 +
(pn[2:, -1] + pn[0:-2, -1]) * dx**2) /
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 /
(2 * (dx**2 + dy**2)) * b[1:-1, -1])

# Periodic BC Pressure @ x = 0
p = p.at[1:-1,
0].set((((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 +
(pn[2:, 0] + pn[0:-2, 0]) * dx**2) / (2 * (dx**2 + dy**2)) -
dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0]))

# Wall boundary conditions, pressure
p = p.at[-1, :].set(p[-2, :]) # dp/dy = 0 at y = 2
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0

return p, None

p, _ = lax.scan(body_func, p, jnp.arange(nit))


@partial(jax.jit, static_argnums=(0,7,8,9))
def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F):
udiff = 1
stepcount = 0

array_vals = (udiff, stepcount, u, v, p)

def conf_func(array_vals):
udiff, _, _, _ , _ = array_vals
return udiff > .001

def body_func(array_vals):
_, stepcount, u, v, p = array_vals

un = u.copy()
vn = v.copy()

b = build_up_b(rho, dt, dx, dy, u, v)
pressure_poisson_periodic(nit, p, dx, dy, b)

u = u.at[1:-1,
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) *
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu *
(dt / dx**2 *
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) +
dt / dy**2 *
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])) +
F * dt)

v = v.at[1:-1,
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx *
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) -
vn[1:-1, 1:-1] * dt / dy *
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) *
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu *
(dt / dx**2 *
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) +
dt / dy**2 *
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1])))

# Periodic BC u @ x = 2
u = u.at[1:-1, -1].set(
un[1:-1, -1] - un[1:-1, -1] * dt / dx *
(un[1:-1, -1] - un[1:-1, -2]) - vn[1:-1, -1] * dt / dy *
(un[1:-1, -1] - un[0:-2, -1]) - dt / (2 * rho * dx) *
(p[1:-1, 0] - p[1:-1, -2]) + nu *
(dt / dx**2 *
(un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2]) + dt / dy**2 *
(un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])) + F * dt)

# Periodic BC u @ x = 0
u = u.at[1:-1,
0].set(un[1:-1, 0] - un[1:-1, 0] * dt / dx *
(un[1:-1, 0] - un[1:-1, -1]) - vn[1:-1, 0] * dt / dy *
(un[1:-1, 0] - un[0:-2, 0]) - dt / (2 * rho * dx) *
(p[1:-1, 1] - p[1:-1, -1]) + nu *
(dt / dx**2 *
(un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1]) + dt / dy**2 *
(un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])) + F * dt)

# Periodic BC v @ x = 2
v = v.at[1:-1, -1].set(
vn[1:-1, -1] - un[1:-1, -1] * dt / dx *
(vn[1:-1, -1] - vn[1:-1, -2]) - vn[1:-1, -1] * dt / dy *
(vn[1:-1, -1] - vn[0:-2, -1]) - dt / (2 * rho * dy) *
(p[2:, -1] - p[0:-2, -1]) + nu *
(dt / dx**2 *
(vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2]) + dt / dy**2 *
(vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1])))

# Periodic BC v @ x = 0
v = v.at[1:-1,
0].set(vn[1:-1, 0] - un[1:-1, 0] * dt / dx *
(vn[1:-1, 0] - vn[1:-1, -1]) - vn[1:-1, 0] * dt / dy *
(vn[1:-1, 0] - vn[0:-2, 0]) - dt / (2 * rho * dy) *
(p[2:, 0] - p[0:-2, 0]) + nu *
(dt / dx**2 *
(vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1]) + dt / dy**2 *
(vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0])))

# Wall BC: u,v = 0 @ y = 0,2
u = u.at[0, :].set(0)
u = u.at[-1, :].set(0)
v = v.at[0, :].set(0)
v = v.at[-1, :].set(0)

udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u)
stepcount += 1

return (udiff, stepcount, u, v, p)

_, stepcount, _, _, _ = lax.while_loop(conf_func, body_func, array_vals)

return stepcount
8 changes: 8 additions & 0 deletions npbench/benchmarks/compute/compute_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html

import jax.numpy as jnp
import jax

@jax.jit
def compute(array_1, array_2, a, b, c):
return jnp.clip(array_1, 2, 10) * a + array_2 * b + c
Loading

0 comments on commit bbba077

Please sign in to comment.