diff --git a/clip_jaxtorch/vit.py b/clip_jaxtorch/vit.py
index 6711ffb..6f8ff79 100644
--- a/clip_jaxtorch/vit.py
+++ b/clip_jaxtorch/vit.py
@@ -5,6 +5,13 @@ import einops
 from collections import OrderedDict
 from functools import partial
 
+import math
+
+try:
+    from flash_attn_jax import flash_mha
+    use_flash_attention = True
+except ImportError:
+    use_flash_attention = False
 
 class QuickGELU(nn.Module):
     def forward(self, cx, x):
@@ -44,13 +51,18 @@ def __init__(self, d_model, n_head, attn_mask=None):
     def forward(self, cx, x):
         # x : n k c
         qkv = jnp.einsum('nkc,ic->nki', x, cx[self.in_proj_weight]) + cx[self.in_proj_bias]
-        q, k, v = qkv.rearrange('n k (p h c) -> p n h k c', p=3, h=self.n_head)
-        qk = jnp.einsum('nhqc,nhkc->nhqk', q, k) / jnp.sqrt(q.shape[-1])
-        if self.attn_mask is not None:
-            qk = self.attn_mask(cx, qk)
-        qk = jax.nn.softmax(qk, axis=-1)
-        out = jnp.einsum('nhqk,nhkc->nhqc', qk, v)
-        out = out.rearrange('n h k c -> n k (h c)')
+        if use_flash_attention and x.dtype in [jnp.float16, jnp.bfloat16]:
+            q,k,v = qkv.rearrange('n k (p h c) -> p n k h c', p=3, h=self.n_head)
+            out = flash_mha(q,k,v, softmax_scale = 1 / math.sqrt(q.shape[-1]), is_causal = self.attn_mask=='causal')
+            out = out.rearrange('n k h c -> n k (h c)')
+        else:
+            q, k, v = qkv.rearrange('n k (p h c) -> p n h k c', p=3, h=self.n_head)
+            qk = jnp.einsum('nhqc,nhkc->nhqk', q, k) / jnp.sqrt(q.shape[-1])
+            if self.attn_mask == 'causal':
+                qk = causal(cx, qk)
+            qk = jax.nn.softmax(qk, axis=-1)
+            out = jnp.einsum('nhqk,nhkc->nhqc', qk, v)
+            out = out.rearrange('n h k c -> n k (h c)')
         out = self.out_proj(cx, out)
         return out
 
@@ -125,7 +137,7 @@ def gather_bi(x, i):
 class CLIPText(nn.Module):
     def __init__(self, n_dim=512, n_layers=12, n_heads=8, d_out=512):
         super().__init__()
-        self.transformer = Transformer(n_dim, n_layers, heads=n_heads, attn_mask=causal)
+        self.transformer = Transformer(n_dim, n_layers, heads=n_heads, attn_mask='causal')
         self.token_embedding = nn.Embedding(49408, n_dim)
         self.ln_final = nn.LayerNorm(n_dim)
         self.positional_embedding = init.normal(77, n_dim)
@@ -145,6 +157,9 @@ def encode_text(self, cx, text):
 
         return x
 
+    def encode_image(self, cx, image):
+        return self.visual(cx, image)
+
 class VITB32(CLIPText):
     def __init__(self):
         super().__init__(512, 12, 8, 512)
diff --git a/test/test_vit.py b/test/test_vit.py
index a488e46..cd62dd3 100644
--- a/test/test_vit.py
+++ b/test/test_vit.py
@@ -23,27 +23,90 @@ def norm1(x):
     return x / x.square().sum(axis=-1,keepdims=True).sqrt()
 
 
 @pytest.mark.parametrize('model_name', ['ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px'])
+@pytest.mark.parametrize('flash', [False, True])
 @torch.no_grad()
-def test_vit(model_name):
+def test_vit(model_name, flash):
+    vit.use_flash_attention = flash
+
     rng = PRNG(jax.random.PRNGKey(1))
     clip_model, _ = torch_clip.load(model_name, device='cpu')
     jax_model, params = jax_clip.load(model_name)
-    params = tree_map(lambda x: x.astype(jnp.float32), params)
+
+    if flash:
+        dtype = jnp.float16
+    else:
+        dtype = jnp.float32
+
+    params = tree_map(lambda x: x.astype(dtype), params)
 
     size = jax_model.visual.input_resolution
 
+    if flash:
+        # Error with float16 is a little higher
+        tol = dict(atol=0.007, rtol=0.003)
+    else:
+        tol = dict(atol=0.002, rtol=0.001)
+
     image = jax.random.normal(rng.split(), [1, 3, size, size])
-    out_jax = jax_model.visual(Context(params, None), image)
+    out_jax = jax_model.visual(Context(params, None), image.astype(dtype))
     out_torch = fromtorch(clip_model.visual(fromjax(image)))
-    assert jnp.allclose(out_torch, out_jax, atol=0.002, rtol=0.001)
+    assert jnp.allclose(out_torch, out_jax, **tol), ((out_torch-out_jax).abs().max(), out_torch.abs().max())
 
     image = jnp.zeros([1, 3, size, size])
-    out_jax = jax_model.visual(Context(params, None), image)
+    out_jax = jax_model.visual(Context(params, None), image.astype(dtype))
     out_torch = fromtorch(clip_model.visual(fromjax(image)))
-    assert jnp.allclose(out_torch, out_jax, atol=0.002, rtol=0.001)
+    assert jnp.allclose(out_torch, out_jax, **tol), ((out_torch-out_jax).abs().max(), out_torch.abs().max())
 
     text = fromtorch(torch_clip.tokenize("hello world"))
     out_jax = jax_model.encode_text(Context(params, None), text)
     out_torch = fromtorch(clip_model.encode_text(fromjax(text)))
-    assert jnp.allclose(out_torch, out_jax, atol=0.002, rtol=0.001)
+    assert jnp.allclose(out_torch, out_jax, **tol), ((out_torch-out_jax).abs().max(), ((out_torch-out_jax).abs().max()/out_torch.abs().max()))
+
+if __name__ == '__main__':
+    key = jax.random.PRNGKey(1)
+    model_name = 'ViT-B/32'
+    # clip_model, _ = torch_clip.load(model_name, device='cpu')
+    jax_model, params_f32 = jax_clip.load(model_name)
+    params_f32 = tree_map(lambda x: x.astype(jnp.float32), params_f32)
+    params_f16 = tree_map(lambda x: x.astype(jnp.float16), params_f32)
+
+    from functools import partial, wraps
+
+    def capture(name, fwd):
+        @wraps(fwd)
+        def forward(cx, *args, **kwargs):
+            out = fwd(cx, *args, **kwargs)
+            cx.tmp['out'][name] = out
+            return out
+        return forward
+
+    for mod in jax_model.modules():
+        mod.forward = capture(mod.name, mod.forward)
+
+    vit.use_flash_attention = True
+
+    @jax.jit
+    def compare(params_f32, params_f16, text):
+        out32 = {}
+        out16 = {}
+
+        cx = Context(params_f32, None)
+        cx.tmp['out'] = out32
+        jax_model.encode_text(cx, text)#image.astype(jnp.float32))
+        cx = Context(params_f16, None)
+        cx.tmp['out'] = out16
+        jax_model.encode_text(cx, text)#image.astype(jnp.float16))
+
+        rtol = {}
+        for name in out32.keys():
+            rtol[name] = (out16[name] - out32[name]).max()/out32[name].max()
+        return rtol
+
+    size = jax_model.visual.input_resolution
+    text = fromtorch(torch_clip.tokenize("hello world"))
+    # image = jax.random.normal(key, [1, 3, size, size])
+    rtol = compare(params_f32, params_f16, text)
+    for mod in jax_model.modules():
+        if mod.name in rtol.keys():
+            print(mod.name, rtol[mod.name])
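
Not part of the patch: a minimal, self-contained sketch of the head-layout change the attention rewrite above hinges on. The reference einsum path keeps heads as a separate axis ('n h k c'), while flash_mha expects heads after the sequence axis ('n k h c'). The snippet below re-implements both layouts with plain jax and einops (standing in for flash_mha and for jaxtorch's monkey-patched .rearrange), so it runs on CPU without flash_attn_jax installed, and checks that the two layouts give the same attention output.

# Sketch only -- not part of the patch. flash_mha is replaced by plain softmax
# attention; einops.rearrange stands in for jaxtorch's .rearrange method.
import math
import jax
import jax.numpy as jnp
from einops import rearrange

def attention_nhkc(qkv, n_head):
    # Reference path from vit.py: heads as a separate leading axis ('n h k c').
    q, k, v = rearrange(qkv, 'n k (p h c) -> p n h k c', p=3, h=n_head)
    qk = jnp.einsum('nhqc,nhkc->nhqk', q, k) / jnp.sqrt(q.shape[-1])
    qk = jax.nn.softmax(qk, axis=-1)
    out = jnp.einsum('nhqk,nhkc->nhqc', qk, v)
    return rearrange(out, 'n h k c -> n k (h c)')

def attention_nkhc(qkv, n_head):
    # Same math in the 'n k h c' layout that flash_mha consumes.
    q, k, v = rearrange(qkv, 'n k (p h c) -> p n k h c', p=3, h=n_head)
    scale = 1 / math.sqrt(q.shape[-1])
    qk = jnp.einsum('nqhc,nkhc->nhqk', q, k) * scale
    qk = jax.nn.softmax(qk, axis=-1)
    out = jnp.einsum('nhqk,nkhc->nqhc', qk, v)
    return rearrange(out, 'n k h c -> n k (h c)')

qkv = jax.random.normal(jax.random.PRNGKey(0), (2, 77, 3 * 8 * 64))
a = attention_nhkc(qkv, n_head=8)
b = attention_nkhc(qkv, n_head=8)
print(jnp.abs(a - b).max())  # float32 rounding noise, around 1e-6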