-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #31 from hardik01shah/main
Add JAX Support to NPBench and Implement JAX Benchmarks
- Loading branch information
Showing
64 changed files
with
2,105 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -127,3 +127,6 @@ dmypy.json | |
|
||
# Pyre type checker | ||
.pyre/ | ||
|
||
# dace | ||
.dacecache/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
{ | ||
"framework": { | ||
"simple_name": "jax", | ||
"full_name": "Jax", | ||
"prefix": "jax", | ||
"postfix": "jax", | ||
"class": "JaxFramework", | ||
"arch": "cpu" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2014 Jérôme Kieffer et al. | ||
# This is an open-access article distributed under the terms of the | ||
# Creative Commons Attribution License, which permits unrestricted use, | ||
# distribution, and reproduction in any medium, provided the original author | ||
# and source are credited. | ||
# http://creativecommons.org/licenses/by/3.0/ | ||
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for | ||
# high performance azimuthal integration on gpu, 2014. In Proceedings of the | ||
# 7th European Conference on Python in Science (EuroSciPy 2014). | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from functools import partial | ||
|
||
@partial(jax.jit, static_argnums=(2,)) | ||
def azimint_hist(data: jax.Array, radius: jax.Array, npt): | ||
histu = jnp.histogram(radius, npt)[0] | ||
histw = jnp.histogram(radius, npt, weights=data)[0] | ||
return histw / histu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright 2014 Jérôme Kieffer et al. | ||
# This is an open-access article distributed under the terms of the | ||
# Creative Commons Attribution License, which permits unrestricted use, | ||
# distribution, and reproduction in any medium, provided the original author | ||
# and source are credited. | ||
# http://creativecommons.org/licenses/by/3.0/ | ||
# Jérôme Kieffer and Giannis Ashiotis. Pyfai: a python library for | ||
# high performance azimuthal integration on gpu, 2014. In Proceedings of the | ||
# 7th European Conference on Python in Science (EuroSciPy 2014). | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import lax | ||
from functools import partial | ||
|
||
|
||
@partial(jax.jit, static_argnums=(2,)) | ||
def azimint_naive(data, radius, npt): | ||
rmax = radius.max() | ||
res = jnp.zeros(npt, dtype=jnp.float64) | ||
|
||
def loop_body(i, res): | ||
r1 = rmax * i / npt | ||
r2 = rmax * (i + 1) / npt | ||
mask_r12 = jnp.logical_and((r1 <= radius), (radius < r2)) | ||
mean = jnp.where(mask_r12, data, 0).mean(where=mask_r12) | ||
res = res.at[i].set(mean) | ||
return res | ||
|
||
res = lax.fori_loop(0, npt, loop_body, res) | ||
|
||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Barba, Lorena A., and Forsyth, Gilbert F. (2018). | ||
# CFD Python: the 12 steps to Navier-Stokes equations. | ||
# Journal of Open Source Education, 1(9), 21, | ||
# https://doi.org/10.21105/jose.00021 | ||
# TODO: License | ||
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth. | ||
# All content is under Creative Commons Attribution CC-BY 4.0, | ||
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018). | ||
|
||
import jax.numpy as jnp | ||
import jax | ||
from jax import lax | ||
from functools import partial | ||
|
||
|
||
@partial(jax.jit, static_argnums=(1,)) | ||
def build_up_b(b, rho, dt, u, v, dx, dy): | ||
|
||
b = b.at[1:-1, | ||
1:-1].set(rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + | ||
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - | ||
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * | ||
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * | ||
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - | ||
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2)) | ||
|
||
return b | ||
|
||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def pressure_poisson(nit, p, dx, dy, b): | ||
def body_func(p, _): | ||
pn = p.copy() | ||
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + | ||
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / | ||
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 / | ||
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) | ||
|
||
p = p.at[:, -1].set(p[:, -2]) # dp/dx = 0 at x = 2 | ||
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 | ||
p = p.at[:, 0].set(p[:, 1]) # dp/dx = 0 at x = 0 | ||
p = p.at[-1, :].set(0) # p = 0 at y = 2 | ||
|
||
return p, None | ||
|
||
p, _ = lax.scan(body_func, p, jnp.arange(nit)) | ||
|
||
return p | ||
|
||
|
||
@partial(jax.jit, static_argnums=(0,1,2,3,10,11,)) | ||
def cavity_flow(nx, ny, nt, nit, u, v, dt, dx, dy, p, rho, nu): | ||
b = jnp.zeros((ny, nx)) | ||
array_vals = (u, v, p, b) | ||
|
||
def body_func(array_vals, _): | ||
|
||
u, v, p, b = array_vals | ||
|
||
un = u.copy() | ||
vn = v.copy() | ||
|
||
b = build_up_b(b, rho, dt, u, v, dx, dy) | ||
p = pressure_poisson(nit, p, dx, dy, b) | ||
|
||
u = u.at[1:-1, | ||
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * | ||
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) - | ||
vn[1:-1, 1:-1] * dt / dy * | ||
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * | ||
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * | ||
(dt / dx**2 * | ||
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + | ||
dt / dy**2 * | ||
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1]))) | ||
|
||
v = v.at[1:-1, | ||
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * | ||
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - | ||
vn[1:-1, 1:-1] * dt / dy * | ||
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * | ||
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * | ||
(dt / dx**2 * | ||
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + | ||
dt / dy**2 * | ||
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) | ||
|
||
u = u.at[0, :].set(0) | ||
u = u.at[:, 0].set(0) | ||
u = u.at[:, -1].set(0) | ||
u = u.at[-1, :].set(1) # set velocity on cavity lid equal to 1 | ||
v = v.at[0, :].set(0) | ||
v = v.at[-1, :].set(0) | ||
v = v.at[:, 0].set(0) | ||
v = v.at[:, -1].set(0) | ||
|
||
return (u, v, p, b), None | ||
|
||
out_vals, _ = lax.scan(body_func, array_vals, jnp.arange(nt)) | ||
u, v, p, b = out_vals | ||
|
||
return u, v, p |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Barba, Lorena A., and Forsyth, Gilbert F. (2018). | ||
# CFD Python: the 12 steps to Navier-Stokes equations. | ||
# Journal of Open Source Education, 1(9), 21, | ||
# https://doi.org/10.21105/jose.00021 | ||
# TODO: License | ||
# (c) 2017 Lorena A. Barba, Gilbert F. Forsyth. | ||
# All content is under Creative Commons Attribution CC-BY 4.0, | ||
# and all code is under BSD-3 clause (previously under MIT, and changed on March 8, 2018). | ||
|
||
import jax.numpy as jnp | ||
import jax | ||
from jax import lax | ||
from functools import partial | ||
|
||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def build_up_b(rho, dt, dx, dy, u, v): | ||
b = jnp.zeros_like(u) | ||
b = b.at[1:-1, | ||
1:-1].set((rho * (1 / dt * ((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx) + | ||
(v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy)) - | ||
((u[1:-1, 2:] - u[1:-1, 0:-2]) / (2 * dx))**2 - 2 * | ||
((u[2:, 1:-1] - u[0:-2, 1:-1]) / (2 * dy) * | ||
(v[1:-1, 2:] - v[1:-1, 0:-2]) / (2 * dx)) - | ||
((v[2:, 1:-1] - v[0:-2, 1:-1]) / (2 * dy))**2))) | ||
|
||
# Periodic BC Pressure @ x = 2 | ||
b = b.at[1:-1, -1].set((rho * (1 / dt * ((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx) + | ||
(v[2:, -1] - v[0:-2, -1]) / (2 * dy)) - | ||
((u[1:-1, 0] - u[1:-1, -2]) / (2 * dx))**2 - 2 * | ||
((u[2:, -1] - u[0:-2, -1]) / (2 * dy) * | ||
(v[1:-1, 0] - v[1:-1, -2]) / (2 * dx)) - | ||
((v[2:, -1] - v[0:-2, -1]) / (2 * dy))**2))) | ||
|
||
# Periodic BC Pressure @ x = 0 | ||
b = b.at[1:-1, 0].set((rho * (1 / dt * ((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx) + | ||
(v[2:, 0] - v[0:-2, 0]) / (2 * dy)) - | ||
((u[1:-1, 1] - u[1:-1, -1]) / (2 * dx))**2 - 2 * | ||
((u[2:, 0] - u[0:-2, 0]) / (2 * dy) * | ||
(v[1:-1, 1] - v[1:-1, -1]) / | ||
(2 * dx)) - ((v[2:, 0] - v[0:-2, 0]) / (2 * dy))**2))) | ||
|
||
return b | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def pressure_poisson_periodic(nit, p, dx, dy, b): | ||
|
||
def body_func(p, q): | ||
pn = p.copy() | ||
p = p.at[1:-1, 1:-1].set(((pn[1:-1, 2:] + pn[1:-1, 0:-2]) * dy**2 + | ||
(pn[2:, 1:-1] + pn[0:-2, 1:-1]) * dx**2) / | ||
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 / | ||
(2 * (dx**2 + dy**2)) * b[1:-1, 1:-1]) | ||
|
||
# Periodic BC Pressure @ x = 2 | ||
p = p.at[1:-1, -1].set(((pn[1:-1, 0] + pn[1:-1, -2]) * dy**2 + | ||
(pn[2:, -1] + pn[0:-2, -1]) * dx**2) / | ||
(2 * (dx**2 + dy**2)) - dx**2 * dy**2 / | ||
(2 * (dx**2 + dy**2)) * b[1:-1, -1]) | ||
|
||
# Periodic BC Pressure @ x = 0 | ||
p = p.at[1:-1, | ||
0].set((((pn[1:-1, 1] + pn[1:-1, -1]) * dy**2 + | ||
(pn[2:, 0] + pn[0:-2, 0]) * dx**2) / (2 * (dx**2 + dy**2)) - | ||
dx**2 * dy**2 / (2 * (dx**2 + dy**2)) * b[1:-1, 0])) | ||
|
||
# Wall boundary conditions, pressure | ||
p = p.at[-1, :].set(p[-2, :]) # dp/dy = 0 at y = 2 | ||
p = p.at[0, :].set(p[1, :]) # dp/dy = 0 at y = 0 | ||
|
||
return p, None | ||
|
||
p, _ = lax.scan(body_func, p, jnp.arange(nit)) | ||
|
||
|
||
@partial(jax.jit, static_argnums=(0,7,8,9)) | ||
def channel_flow(nit, u, v, dt, dx, dy, p, rho, nu, F): | ||
udiff = 1 | ||
stepcount = 0 | ||
|
||
array_vals = (udiff, stepcount, u, v, p) | ||
|
||
def conf_func(array_vals): | ||
udiff, _, _, _ , _ = array_vals | ||
return udiff > .001 | ||
|
||
def body_func(array_vals): | ||
_, stepcount, u, v, p = array_vals | ||
|
||
un = u.copy() | ||
vn = v.copy() | ||
|
||
b = build_up_b(rho, dt, dx, dy, u, v) | ||
pressure_poisson_periodic(nit, p, dx, dy, b) | ||
|
||
u = u.at[1:-1, | ||
1:-1].set(un[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * | ||
(un[1:-1, 1:-1] - un[1:-1, 0:-2]) - | ||
vn[1:-1, 1:-1] * dt / dy * | ||
(un[1:-1, 1:-1] - un[0:-2, 1:-1]) - dt / (2 * rho * dx) * | ||
(p[1:-1, 2:] - p[1:-1, 0:-2]) + nu * | ||
(dt / dx**2 * | ||
(un[1:-1, 2:] - 2 * un[1:-1, 1:-1] + un[1:-1, 0:-2]) + | ||
dt / dy**2 * | ||
(un[2:, 1:-1] - 2 * un[1:-1, 1:-1] + un[0:-2, 1:-1])) + | ||
F * dt) | ||
|
||
v = v.at[1:-1, | ||
1:-1].set(vn[1:-1, 1:-1] - un[1:-1, 1:-1] * dt / dx * | ||
(vn[1:-1, 1:-1] - vn[1:-1, 0:-2]) - | ||
vn[1:-1, 1:-1] * dt / dy * | ||
(vn[1:-1, 1:-1] - vn[0:-2, 1:-1]) - dt / (2 * rho * dy) * | ||
(p[2:, 1:-1] - p[0:-2, 1:-1]) + nu * | ||
(dt / dx**2 * | ||
(vn[1:-1, 2:] - 2 * vn[1:-1, 1:-1] + vn[1:-1, 0:-2]) + | ||
dt / dy**2 * | ||
(vn[2:, 1:-1] - 2 * vn[1:-1, 1:-1] + vn[0:-2, 1:-1]))) | ||
|
||
# Periodic BC u @ x = 2 | ||
u = u.at[1:-1, -1].set( | ||
un[1:-1, -1] - un[1:-1, -1] * dt / dx * | ||
(un[1:-1, -1] - un[1:-1, -2]) - vn[1:-1, -1] * dt / dy * | ||
(un[1:-1, -1] - un[0:-2, -1]) - dt / (2 * rho * dx) * | ||
(p[1:-1, 0] - p[1:-1, -2]) + nu * | ||
(dt / dx**2 * | ||
(un[1:-1, 0] - 2 * un[1:-1, -1] + un[1:-1, -2]) + dt / dy**2 * | ||
(un[2:, -1] - 2 * un[1:-1, -1] + un[0:-2, -1])) + F * dt) | ||
|
||
# Periodic BC u @ x = 0 | ||
u = u.at[1:-1, | ||
0].set(un[1:-1, 0] - un[1:-1, 0] * dt / dx * | ||
(un[1:-1, 0] - un[1:-1, -1]) - vn[1:-1, 0] * dt / dy * | ||
(un[1:-1, 0] - un[0:-2, 0]) - dt / (2 * rho * dx) * | ||
(p[1:-1, 1] - p[1:-1, -1]) + nu * | ||
(dt / dx**2 * | ||
(un[1:-1, 1] - 2 * un[1:-1, 0] + un[1:-1, -1]) + dt / dy**2 * | ||
(un[2:, 0] - 2 * un[1:-1, 0] + un[0:-2, 0])) + F * dt) | ||
|
||
# Periodic BC v @ x = 2 | ||
v = v.at[1:-1, -1].set( | ||
vn[1:-1, -1] - un[1:-1, -1] * dt / dx * | ||
(vn[1:-1, -1] - vn[1:-1, -2]) - vn[1:-1, -1] * dt / dy * | ||
(vn[1:-1, -1] - vn[0:-2, -1]) - dt / (2 * rho * dy) * | ||
(p[2:, -1] - p[0:-2, -1]) + nu * | ||
(dt / dx**2 * | ||
(vn[1:-1, 0] - 2 * vn[1:-1, -1] + vn[1:-1, -2]) + dt / dy**2 * | ||
(vn[2:, -1] - 2 * vn[1:-1, -1] + vn[0:-2, -1]))) | ||
|
||
# Periodic BC v @ x = 0 | ||
v = v.at[1:-1, | ||
0].set(vn[1:-1, 0] - un[1:-1, 0] * dt / dx * | ||
(vn[1:-1, 0] - vn[1:-1, -1]) - vn[1:-1, 0] * dt / dy * | ||
(vn[1:-1, 0] - vn[0:-2, 0]) - dt / (2 * rho * dy) * | ||
(p[2:, 0] - p[0:-2, 0]) + nu * | ||
(dt / dx**2 * | ||
(vn[1:-1, 1] - 2 * vn[1:-1, 0] + vn[1:-1, -1]) + dt / dy**2 * | ||
(vn[2:, 0] - 2 * vn[1:-1, 0] + vn[0:-2, 0]))) | ||
|
||
# Wall BC: u,v = 0 @ y = 0,2 | ||
u = u.at[0, :].set(0) | ||
u = u.at[-1, :].set(0) | ||
v = v.at[0, :].set(0) | ||
v = v.at[-1, :].set(0) | ||
|
||
udiff = (jnp.sum(u) - jnp.sum(un)) / jnp.sum(u) | ||
stepcount += 1 | ||
|
||
return (udiff, stepcount, u, v, p) | ||
|
||
_, stepcount, _, _, _ = lax.while_loop(conf_func, body_func, array_vals) | ||
|
||
return stepcount |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# https://cython.readthedocs.io/en/latest/src/userguide/numpy_tutorial.html | ||
|
||
import jax.numpy as jnp | ||
import jax | ||
|
||
@jax.jit | ||
def compute(array_1, array_2, a, b, c): | ||
return jnp.clip(array_1, 2, 10) * a + array_2 * b + c |
Oops, something went wrong.