
FlashAttention

Stable Dev Build Status Coverage

This is a Julia implementation of the Flash Attention algorithm.

Usage

using FlashAttention, CUDA

Q = CUDA.randn(Float16, 64, 1024, 48, 3);
K = CUDA.randn(Float16, 64, 1024, 48, 3);
V = CUDA.randn(Float16, 64, 1024, 48, 3);

flash_attention(Q, K, V)
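
For checking correctness on small inputs, a naive dense-attention reference can be computed on the CPU and compared against the kernel's output. The sketch below is only an illustration, not part of the package; it assumes the layout inferred from the sizes above, (head_dim, seq_len, heads, batch), and the naive_attention helper is a hypothetical name.

using LinearAlgebra

function naive_attention(Q, K, V)
    d, N, H, B = size(Q)
    O = similar(Q)
    scale = inv(sqrt(Float32(d)))
    for b in 1:B, h in 1:H
        q = Float32.(Q[:, :, h, b])         # head_dim × seq_len
        k = Float32.(K[:, :, h, b])
        v = Float32.(V[:, :, h, b])
        S = (k' * q) .* scale               # scores: keys × queries
        P = exp.(S .- maximum(S; dims=1))   # softmax over keys for each query
        P ./= sum(P; dims=1)
        O[:, :, h, b] = eltype(Q).(v * P)   # weighted sum of values
    end
    return O
end

# Compare on host copies of (small) inputs, e.g.
# isapprox(naive_attention(Array(Q), Array(K), Array(V)),
#          Array(flash_attention(Q, K, V)); atol=1e-2)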

Profiling

Please refer to the file flash_attention.ncu-rep. This is not the fastest possible implementation, for three reasons:

  1. we do not use tensor cores as in the C++ implementation,
  2. CUDA.jl does not yet support asynchronous copies from global memory to shared memory, and
  3. this kernel's theoretical occupancy (12.5%) is limited by the required amount of shared memory.
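
A minimal timing sketch, assuming BenchmarkTools.jl is available; CUDA.@sync ensures the asynchronous kernel launch is included in the measurement:

using BenchmarkTools, CUDA, FlashAttention

# Time the full kernel; CUDA.@sync blocks until the GPU work finishes.
@btime CUDA.@sync flash_attention($Q, $K, $V)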

Future work

I plan to reimplement the kernel using MoYe.jl to achieve competitive performance.
