Commit

Use flash attention when available.
nshepperd committed Feb 24, 2024
1 parent 37e6258 commit d13ef08
Showing 2 changed files with 93 additions and 15 deletions.
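In brief, the change imports flash_mha from the optional flash_attn_jax package and routes attention through it whenever the package is importable and the activations are float16/bfloat16; the reference einsum attention remains the fallback. The gate is a module-level flag plus a dtype check, so callers can opt out at runtime, as the updated test does. A minimal sketch of the toggle, assuming the module is importable as clip_jaxtorch.vit (the import path itself is not part of this diff):

# Hedged sketch: toggling the flash path at runtime.
# Assumption: the module is importable as clip_jaxtorch.vit.
import clip_jaxtorch.vit as vit

# Force the reference einsum attention even if flash_attn_jax is installed:
vit.use_flash_attention = False

# Re-enable it; note the flash kernel is still only used for float16/bfloat16
# activations, so float32 inputs take the reference path regardless of the flag.
vit.use_flash_attention = True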
31 changes: 23 additions & 8 deletions clip_jaxtorch/vit.py
@@ -5,6 +5,13 @@
 import einops
 from collections import OrderedDict
 from functools import partial
+import math
+
+try:
+    from flash_attn_jax import flash_mha
+    use_flash_attention = True
+except ImportError:
+    use_flash_attention = False
 
 class QuickGELU(nn.Module):
     def forward(self, cx, x):
@@ -44,13 +51,18 @@ def __init__(self, d_model, n_head, attn_mask=None):
     def forward(self, cx, x):
         # x : n k c
         qkv = jnp.einsum('nkc,ic->nki', x, cx[self.in_proj_weight]) + cx[self.in_proj_bias]
-        q, k, v = qkv.rearrange('n k (p h c) -> p n h k c', p=3, h=self.n_head)
-        qk = jnp.einsum('nhqc,nhkc->nhqk', q, k) / jnp.sqrt(q.shape[-1])
-        if self.attn_mask is not None:
-            qk = self.attn_mask(cx, qk)
-        qk = jax.nn.softmax(qk, axis=-1)
-        out = jnp.einsum('nhqk,nhkc->nhqc', qk, v)
-        out = out.rearrange('n h k c -> n k (h c)')
+        if use_flash_attention and x.dtype in [jnp.float16, jnp.bfloat16]:
+            q, k, v = qkv.rearrange('n k (p h c) -> p n k h c', p=3, h=self.n_head)
+            out = flash_mha(q, k, v, softmax_scale=1 / math.sqrt(q.shape[-1]), is_causal=self.attn_mask == 'causal')
+            out = out.rearrange('n k h c -> n k (h c)')
+        else:
+            q, k, v = qkv.rearrange('n k (p h c) -> p n h k c', p=3, h=self.n_head)
+            qk = jnp.einsum('nhqc,nhkc->nhqk', q, k) / jnp.sqrt(q.shape[-1])
+            if self.attn_mask == 'causal':
+                qk = causal(cx, qk)
+            qk = jax.nn.softmax(qk, axis=-1)
+            out = jnp.einsum('nhqk,nhkc->nhqc', qk, v)
+            out = out.rearrange('n h k c -> n k (h c)')
         out = self.out_proj(cx, out)
         return out

@@ -125,7 +137,7 @@ def gather_bi(x, i):
 class CLIPText(nn.Module):
     def __init__(self, n_dim=512, n_layers=12, n_heads=8, d_out=512):
         super().__init__()
-        self.transformer = Transformer(n_dim, n_layers, heads=n_heads, attn_mask=causal)
+        self.transformer = Transformer(n_dim, n_layers, heads=n_heads, attn_mask='causal')
         self.token_embedding = nn.Embedding(49408, n_dim)
         self.ln_final = nn.LayerNorm(n_dim)
         self.positional_embedding = init.normal(77, n_dim)
@@ -145,6 +157,9 @@ def encode_text(self, cx, text):
 
         return x
 
+    def encode_image(self, cx, image):
+        return self.visual(cx, image)
+
 class VITB32(CLIPText):
     def __init__(self):
         super().__init__(512, 12, 8, 512)
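The two branches of the attention change above compute the same thing in different layouts: flash_mha takes q, k, v as [batch, seq, heads, head_dim] ('n k h c'), while the einsum fallback works in [batch, heads, seq, head_dim] ('n h k c'). A hedged sanity check, not part of this commit, comparing flash_mha against a plain softmax attention in the flash layout (assumes an environment where flash_attn_jax actually runs, e.g. a CUDA GPU with half-precision inputs):

# Hedged sketch: compare flash_mha with a reference softmax attention.
# Shapes follow the flash branch above: [batch, seq, heads, head_dim].
import math
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

def reference_mha(q, k, v, scale):
    qk = jnp.einsum('nqhc,nkhc->nhqk', q, k) * scale   # attention logits
    w = jax.nn.softmax(qk, axis=-1)
    return jnp.einsum('nhqk,nkhc->nqhc', w, v)          # back to [n, seq, h, c]

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
shape = [1, 64, 8, 64]  # batch, seq, heads, head_dim
q = jax.random.normal(kq, shape, dtype=jnp.float16)
k = jax.random.normal(kk, shape, dtype=jnp.float16)
v = jax.random.normal(kv, shape, dtype=jnp.float16)

scale = 1 / math.sqrt(q.shape[-1])
out_flash = flash_mha(q, k, v, softmax_scale=scale, is_causal=False)
out_ref = reference_mha(q, k, v, scale)
print(jnp.abs(out_flash - out_ref).max())  # expect only fp16-level differences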
77 changes: 70 additions & 7 deletions test/test_vit.py
@@ -23,27 +23,90 @@ def norm1(x):
     return x / x.square().sum(axis=-1,keepdims=True).sqrt()
 
 @pytest.mark.parametrize('model_name', ['ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'])
+@pytest.mark.parametrize('flash', [False, True])
 @torch.no_grad()
-def test_vit(model_name):
+def test_vit(model_name, flash):
+    vit.use_flash_attention = flash
+
     rng = PRNG(jax.random.PRNGKey(1))
 
     clip_model, _ = torch_clip.load(model_name, device='cpu')
     jax_model, params = jax_clip.load(model_name)
-    params = tree_map(lambda x: x.astype(jnp.float32), params)
 
+    if flash:
+        dtype = jnp.float16
+    else:
+        dtype = jnp.float32
+
+    params = tree_map(lambda x: x.astype(dtype), params)
+
     size = jax_model.visual.input_resolution
 
+    if flash:
+        # Error with float16 is a little higher
+        tol = dict(atol=0.007, rtol=0.003)
+    else:
+        tol = dict(atol=0.002, rtol=0.001)
+
     image = jax.random.normal(rng.split(), [1, 3, size, size])
-    out_jax = jax_model.visual(Context(params, None), image)
+    out_jax = jax_model.visual(Context(params, None), image.astype(dtype))
     out_torch = fromtorch(clip_model.visual(fromjax(image)))
-    assert jnp.allclose(out_torch, out_jax, atol=0.002, rtol=0.001)
+    assert jnp.allclose(out_torch, out_jax, **tol), ((out_torch-out_jax).abs().max(), out_torch.abs().max())
 
     image = jnp.zeros([1, 3, size, size])
-    out_jax = jax_model.visual(Context(params, None), image)
+    out_jax = jax_model.visual(Context(params, None), image.astype(dtype))
     out_torch = fromtorch(clip_model.visual(fromjax(image)))
-    assert jnp.allclose(out_torch, out_jax, atol=0.002, rtol=0.001)
+    assert jnp.allclose(out_torch, out_jax, **tol), ((out_torch-out_jax).abs().max(), out_torch.abs().max())
 
     text = fromtorch(torch_clip.tokenize("hello world"))
     out_jax = jax_model.encode_text(Context(params, None), text)
     out_torch = fromtorch(clip_model.encode_text(fromjax(text)))
-    assert jnp.allclose(out_torch, out_jax, atol=0.002, rtol=0.001)
+    assert jnp.allclose(out_torch, out_jax, **tol), ((out_torch-out_jax).abs().max(), ((out_torch-out_jax).abs().max()/out_torch.abs().max()))
+
+if __name__ == '__main__':
+    key = jax.random.PRNGKey(1)
+    model_name = 'ViT-B/32'
+    # clip_model, _ = torch_clip.load(model_name, device='cpu')
+    jax_model, params_f32 = jax_clip.load(model_name)
+    params_f32 = tree_map(lambda x: x.astype(jnp.float32), params_f32)
+    params_f16 = tree_map(lambda x: x.astype(jnp.float16), params_f32)
+
+    from functools import partial, wraps
+
+    def capture(name, fwd):
+        @wraps(fwd)
+        def forward(cx, *args, **kwargs):
+            out = fwd(cx, *args, **kwargs)
+            cx.tmp['out'][name] = out
+            return out
+        return forward
+
+    for mod in jax_model.modules():
+        mod.forward = capture(mod.name, mod.forward)
+
+    vit.use_flash_attention = True
+
+    @jax.jit
+    def compare(params_f32, params_f16, text):
+        out32 = {}
+        out16 = {}
+
+        cx = Context(params_f32, None)
+        cx.tmp['out'] = out32
+        jax_model.encode_text(cx, text)  # image.astype(jnp.float32))
+        cx = Context(params_f16, None)
+        cx.tmp['out'] = out16
+        jax_model.encode_text(cx, text)  # image.astype(jnp.float16))
+
+        rtol = {}
+        for name in out32.keys():
+            rtol[name] = (out16[name] - out32[name]).max()/out32[name].max()
+        return rtol
+
+    size = jax_model.visual.input_resolution
+    text = fromtorch(torch_clip.tokenize("hello world"))
+    # image = jax.random.normal(key, [1, 3, size, size])
+    rtol = compare(params_f32, params_f16, text)
+    for mod in jax_model.modules():
+        if mod.name in rtol.keys():
+            print(mod.name, rtol[mod.name])
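For completeness, a hedged end-to-end sketch of the half-precision path this commit enables, using the new encode_image method. The helper names mirror test/test_vit.py; its import lines are outside this diff, so the jax_clip and Context imports below are assumptions rather than the repository's actual paths:

# Hedged usage sketch: fp16 CLIP encoding through the flash attention path.
# Assumed imports -- the test's own import lines are outside this diff.
import jax
import jax.numpy as jnp
import clip as torch_clip                   # OpenAI CLIP, for the tokenizer
from jaxtorch import Context                # assumption
from clip_jaxtorch import clip as jax_clip  # assumption

jax_model, params = jax_clip.load('ViT-B/32')
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.float16), params)  # fp16 weights -> flash_mha branch

size = jax_model.visual.input_resolution
image = jnp.zeros([1, 3, size, size], dtype=jnp.float16)
text = jnp.asarray(torch_clip.tokenize("hello world").numpy())

image_features = jax_model.encode_image(Context(params, None), image)
text_features = jax_model.encode_text(Context(params, None), text)

# Cosine similarity between the two embeddings (cf. norm1 in the test).
def norm1(x):
    return x / jnp.sqrt((x * x).sum(axis=-1, keepdims=True))

print((norm1(image_features) * norm1(text_features)).sum(axis=-1))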
