Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark llama #23

Merged
merged 1 commit into from
Dec 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 106 additions & 12 deletions test/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax.lax
import enzyme_ad.jax as enzyme_jax
import numpy as np
import timeit


def rmsnorm(x, weight):
Expand Down Expand Up @@ -289,13 +290,40 @@ def jfunc(x, weights, key_cache, value_cache):
# Thin wrapper that forwards its four arguments, unchanged and in order, to
# `func` (defined outside this excerpt) and returns that call's result.
# Its output is what the surrounding test prints as "Enzyme primal".
def efunc(x, weights, key_cache, value_cache):
return func(x, weights, key_cache, value_cache)

# eres = efunc(x, weights, key_cache, value_cache)
# print("Enzyme primal", eres)
# res = func(x, weights, key_cache, value_cache)
# print("Jax primal", res)
# print (" max error", jnp.max(jnp.abs(eres-res)))
# assert (jnp.abs(eres - res) < 1e-3).all()

# Correctness check: the result of efunc must agree element-wise with the
# result of jfunc to within an absolute tolerance of 1e-3.
eres = efunc(x, weights, key_cache, value_cache)
print("Enzyme primal", eres)
res = jfunc(x, weights, key_cache, value_cache)
print("Jax primal", res)
print(" max error", jnp.max(jnp.abs(eres - res)))
assert (jnp.abs(eres - res) < 1e-3).all()

# Number of timed calls per benchmark; the later benchmarks reuse this name.
number = 1000

# JAX dispatches work asynchronously: a call can return before the device has
# finished computing. Wrapping each timed call in jax.block_until_ready makes
# timeit measure the full execution instead of only the dispatch overhead.
# Both functions were already invoked once above, so the timed loop does not
# include that first call.
print(
    "Enzyme primal",
    timeit.Timer(
        "jax.block_until_ready(efunc(x, weights, key_cache, value_cache))",
        globals={
            "jax": jax,
            "efunc": efunc,
            "x": x,
            "weights": weights,
            "key_cache": key_cache,
            "value_cache": value_cache,
        },
    ).timeit(number),
)
print(
    "JaX primal",
    timeit.Timer(
        "jax.block_until_ready(jfunc(x, weights, key_cache, value_cache))",
        globals={
            "jax": jax,
            "jfunc": jfunc,
            "x": x,
            "weights": weights,
            "key_cache": key_cache,
            "value_cache": value_cache,
        },
    ).timeit(number),
)
# jfunc = jax.jit(partial(forward, config))
# mlir = jax.jit(partial(forward, config)).lower(1, weights, key_cache, value_cache).compiler_ir(dialect="mhlo")

Expand All @@ -307,11 +335,44 @@ def jfwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
# Forward-mode (JVP) evaluation of efunc: calls jax.jvp with the primals
# (x, weights, kc, vc) and returns jax.jvp's result unchanged.
# NOTE(review): dx and dweights are accepted but never used -- the tangents
# passed for x and weights are x and weights themselves, while dkc and dvc are
# used as given. Confirm this is intentional and matches what jfwd does.
def efwd(x, dx, weights, dweights, kc, dkc, vc, dvc):
return jax.jvp(efunc, (x, weights, kc, vc), (x, weights, dkc, dvc))

# print("pre fwd diff")
# eres = efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Enzyme fwd", eres)
# jres = jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache)
# print("Jax fwd", jres)
# Run both forward-mode functions once and print their results. key_cache and
# value_cache are passed as both the primal and the tangent for the caches.
eres = efwd(
    x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache
)
print("Enzyme fwd", eres)
jres = jfwd(
    x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache
)
print("Jax fwd", jres)

# JAX dispatches work asynchronously, so each timed call is wrapped in
# jax.block_until_ready (which also accepts tuples/pytrees of arrays); without
# it timeit would mostly measure dispatch overhead rather than execution time.
# `number` is the repetition count defined by the primal benchmark above.
print(
    "Enzyme fwd",
    timeit.Timer(
        "jax.block_until_ready(efwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache))",
        globals={
            "jax": jax,
            "efwd": efwd,
            "x": x,
            "dx": dx,
            "weights": weights,
            "dweights": dweights,
            "key_cache": key_cache,
            "value_cache": value_cache,
        },
    ).timeit(number),
)
print(
    "JaX fwd",
    timeit.Timer(
        "jax.block_until_ready(jfwd(x, dx, weights, dweights, key_cache, key_cache, value_cache, value_cache))",
        globals={
            "jax": jax,
            "jfwd": jfwd,
            "x": x,
            "dx": dx,
            "weights": weights,
            "dweights": dweights,
            "key_cache": key_cache,
            "value_cache": value_cache,
        },
    ).timeit(number),
)

@jax.jit
def jrev(x, weights, kc, vc, dx, dkc, dvc):
Expand All @@ -328,6 +389,39 @@ def erev(x, weights, kc, vc, dx, dkc, dvc):
# Run the JAX reverse-mode function once and print its result before timing.
jres = jrev(x, weights, key_cache, value_cache, dx, dkc, dvc)
print("Jax rev", jres)

# JAX dispatches work asynchronously, so each timed call is wrapped in
# jax.block_until_ready (which also accepts tuples/pytrees of arrays); without
# it timeit would mostly measure dispatch overhead rather than execution time.
# `number` is the repetition count defined by the primal benchmark earlier.
print(
    "Enzyme rev",
    timeit.Timer(
        "jax.block_until_ready(erev(x, weights, key_cache, value_cache, dx, dkc, dvc))",
        globals={
            "jax": jax,
            "erev": erev,
            "x": x,
            "weights": weights,
            "key_cache": key_cache,
            "value_cache": value_cache,
            "dx": dx,
            "dkc": dkc,
            "dvc": dvc,
        },
    ).timeit(number),
)
print(
    "JaX rev",
    timeit.Timer(
        "jax.block_until_ready(jrev(x, weights, key_cache, value_cache, dx, dkc, dvc))",
        globals={
            "jax": jax,
            "jrev": jrev,
            "x": x,
            "weights": weights,
            "key_cache": key_cache,
            "value_cache": value_cache,
            "dx": dx,
            "dkc": dkc,
            "dvc": dvc,
        },
    ).timeit(number),
)


# Script entry point: when this file is executed directly (not imported),
# hand control to absltest's test runner.
if __name__ == "__main__":
absltest.main()
Loading