Commit f63d59a

Updated split implementations to include optimization/compilation code.
alexnick83 committed Nov 14, 2024
1 parent f9d0b41 commit f63d59a
Showing 4 changed files with 79 additions and 51 deletions.
FlashAttention/dace/bench_flash_attention_dace_cpu.py (14 changes: 8 additions & 6 deletions)

@@ -6,7 +6,8 @@
 warnings.filterwarnings("ignore")

 from timeit import repeat
-from flash_attention_dace_cpu import flash_attention_dace_4
+# from flash_attention_dace_cpu import flash_attention_dace_4
+from flash_attention_dace_cpu import get_flash_attention_dace_cpu
 from dace.transformation.auto.auto_optimize import auto_optimize

@@ -16,11 +17,12 @@

 if __name__ == "__main__":

-    # Flash Attention
-    fa_sdfg = flash_attention_dace_4.to_sdfg(simplify=False)
-    fa_sdfg.simplify()
-    auto_optimize(fa_sdfg, dace.DeviceType.CPU)
-    fa_func = fa_sdfg.compile()
+    # # Flash Attention
+    # fa_sdfg = flash_attention_dace_4.to_sdfg(simplify=False)
+    # fa_sdfg.simplify()
+    # auto_optimize(fa_sdfg, dace.DeviceType.CPU)
+    # fa_func = fa_sdfg.compile()
+    fa_func = get_flash_attention_dace_cpu()

     rng = np.random.default_rng(42)
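For reference, a minimal sketch of how the compiled function returned by get_flash_attention_dace_cpu() could be called and timed with timeit.repeat. The sizes N, d, Ti, Tj below are placeholders rather than the benchmark's actual parameters, and the call assumes DaCe's convention of passing arrays and symbol values as keyword arguments to a compiled SDFG:

import numpy as np
from timeit import repeat

from flash_attention_dace_cpu import get_flash_attention_dace_cpu

fa_func = get_flash_attention_dace_cpu()

# Placeholder sizes: sequence length, head dimension, and row/column tile sizes.
N, d, Ti, Tj = 4096, 64, 128, 128

rng = np.random.default_rng(42)
Q = rng.random((N, d), dtype=np.float32)
K = rng.random((N, d), dtype=np.float32)
V = rng.random((N, d), dtype=np.float32)
O = np.empty((N, d), dtype=np.float32)

# Warm-up call, then timing; arrays and symbol values are passed by keyword.
fa_func(Q=Q, K=K, V=V, O=O, N=N, d=d, Ti=Ti, Tj=Tj)
runtimes = repeat(lambda: fa_func(Q=Q, K=K, V=V, O=O, N=N, d=d, Ti=Ti, Tj=Tj),
                  number=1, repeat=10)
print(f"median runtime: {np.median(runtimes) * 1000:.3f} ms")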
FlashAttention/dace/bench_flash_attention_dace_gpu.py (30 changes: 16 additions & 14 deletions)

@@ -7,7 +7,8 @@
 warnings.filterwarnings("ignore")

 from timeit import repeat
-from flash_attention_dace_gpu import custom_attention_dace
+# from flash_attention_dace_gpu import custom_attention_dace
+from flash_attention_dace_gpu import get_flash_attention_dace_gpu
 from dace.transformation.auto.auto_optimize import auto_optimize, apply_gpu_storage

@@ -17,19 +18,20 @@

 if __name__ == "__main__":

-    # Flash Attention
-    sdfg = custom_attention_dace.to_sdfg(simplify=False)
-    apply_gpu_storage(sdfg)
-    for sd in sdfg.all_sdfgs_recursive():
-        if sd.parent_sdfg is not None and sd.parent_sdfg is sdfg:
-            sd.simplify()
-            auto_optimize(sd, dace.DeviceType.GPU, use_gpu_storage=True)
-    for state in sdfg.states():
-        for node in state.nodes():
-            if isinstance(node, dace.nodes.MapEntry):
-                node.schedule = dace.ScheduleType.Sequential
-    sdfg.simplify()
-    fa_func = sdfg.compile()
+    # # Flash Attention
+    # sdfg = custom_attention_dace.to_sdfg(simplify=False)
+    # apply_gpu_storage(sdfg)
+    # for sd in sdfg.all_sdfgs_recursive():
+    #     if sd.parent_sdfg is not None and sd.parent_sdfg is sdfg:
+    #         sd.simplify()
+    #         auto_optimize(sd, dace.DeviceType.GPU, use_gpu_storage=True)
+    # for state in sdfg.states():
+    #     for node in state.nodes():
+    #         if isinstance(node, dace.nodes.MapEntry):
+    #             node.schedule = dace.ScheduleType.Sequential
+    # sdfg.simplify()
+    # fa_func = sdfg.compile()
+    fa_func = get_flash_attention_dace_gpu()

     rng = np.random.default_rng(42)
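On the GPU side, the main difference for a caller is that the inputs should already live in GPU memory, since the SDFG is built with use_gpu_storage=True. A hedged sketch, assuming CuPy arrays can be passed for the GPU-resident data and again using placeholder sizes:

import cupy as cp
import numpy as np
from timeit import repeat

from flash_attention_dace_gpu import get_flash_attention_dace_gpu

fa_func = get_flash_attention_dace_gpu()

# Placeholder sizes; the benchmark's actual parameters are not shown here.
N, d, Ti, Tj = 4096, 64, 128, 128

rng = np.random.default_rng(42)
# The optimized SDFG expects GPU-resident arrays, so allocate them with CuPy.
Q = cp.asarray(rng.random((N, d), dtype=np.float32))
K = cp.asarray(rng.random((N, d), dtype=np.float32))
V = cp.asarray(rng.random((N, d), dtype=np.float32))
O = cp.empty((N, d), dtype=cp.float32)

def run():
    fa_func(Q=Q, K=K, V=V, O=O, N=N, d=d, Ti=Ti, Tj=Tj)
    cp.cuda.runtime.deviceSynchronize()  # make sure all kernels finished before the timer stops

run()  # warm-up
runtimes = repeat(run, number=1, repeat=10)
print(f"median runtime: {np.median(runtimes) * 1000:.3f} ms")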
FlashAttention/dace/flash_attention_dace_cpu.py (52 changes: 30 additions & 22 deletions)

@@ -1,34 +1,42 @@
 import dace
 import numpy as np

-N, d, Ti, Tj = (dace.symbol(s) for s in ('N', 'd', 'Ti', 'Tj'))
+from dace.transformation.auto import auto_optimize


-@dace.program
-def flash_attention_dace_4(Q: dace.float32[N, d], K: dace.float32[N, d], V: dace.float32[N, d], O: dace.float32[N, d]):
+def get_flash_attention_dace_cpu():
+
+    N, d, Ti, Tj = (dace.symbol(s) for s in ('N', 'd', 'Ti', 'Tj'))
+
+    @dace.program
+    def flash_attention_dace_4(Q: dace.float32[N, d], K: dace.float32[N, d], V: dace.float32[N, d], O: dace.float32[N, d]):

-    for ti in dace.map[0:N:Ti]:
+        for ti in dace.map[0:N:Ti]:

-        m = np.full([Ti], -np.inf, Q.dtype)
-        l = np.zeros([Ti], Q.dtype)
-        S = np.empty([Ti, Tj], Q.dtype)
-        Oi = np.zeros([Ti, d], Q.dtype)
+            m = np.full([Ti], -np.inf, Q.dtype)
+            l = np.zeros([Ti], Q.dtype)
+            S = np.empty([Ti, Tj], Q.dtype)
+            Oi = np.zeros([Ti, d], Q.dtype)

-        Qi = Q[ti:ti+Ti, :]
+            Qi = Q[ti:ti+Ti, :]

-        for tj in range(0, N, Tj):
+            for tj in range(0, N, Tj):

-            S[:] = Qi @ np.transpose(K[tj:tj+Tj, :])
+                S[:] = Qi @ np.transpose(K[tj:tj+Tj, :])

-            max_row = np.max(S, axis=1)
-            m_new = np.maximum(m, max_row)
-            p_tilde = np.exp(S - m_new[:, np.newaxis])
-            sum_row = np.sum(p_tilde, axis=1)
-            l_tmp = l * np.exp(m - m_new)
-            l_new = l_tmp + sum_row
-            Oi[:] = (Oi * l_tmp[:, np.newaxis] + p_tilde @ V[tj:tj+Tj, :]) / l_new[:, np.newaxis]
-            m[:] = m_new
-            l[:] = l_new
+                max_row = np.max(S, axis=1)
+                m_new = np.maximum(m, max_row)
+                p_tilde = np.exp(S - m_new[:, np.newaxis])
+                sum_row = np.sum(p_tilde, axis=1)
+                l_tmp = l * np.exp(m - m_new)
+                l_new = l_tmp + sum_row
+                Oi[:] = (Oi * l_tmp[:, np.newaxis] + p_tilde @ V[tj:tj+Tj, :]) / l_new[:, np.newaxis]
+                m[:] = m_new
+                l[:] = l_new

-        O[ti:ti+Ti, :] = Oi
+            O[ti:ti+Ti, :] = Oi
+
+    fa_sdfg = flash_attention_dace_4.to_sdfg(simplify=False)
+    fa_sdfg.simplify()
+    auto_optimize.auto_optimize(fa_sdfg, dace.DeviceType.CPU)
+    return fa_sdfg.compile()
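The loop nest above is the online-softmax formulation of attention: for each row tile, the running maximum m and running normalizer l are updated as every key/value tile is visited, and the partial output Oi is rescaled accordingly, so the result equals a direct softmax(Q @ K.T) @ V without materializing the full N x N score matrix. A small NumPy-only sketch (independent of DaCe, with arbitrary example sizes where Ti and Tj divide N) that mirrors the same update rule and checks it against a straightforward reference:

import numpy as np

def reference_attention(Q, K, V):
    # Direct (unscaled) attention: row-wise softmax of Q @ K.T, then weight V.
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def flash_attention_numpy(Q, K, V, Ti=32, Tj=32):
    # Same tiling and online-softmax updates as flash_attention_dace_4, in plain NumPy.
    N, d = Q.shape
    O = np.empty_like(Q)
    for ti in range(0, N, Ti):
        m = np.full(Ti, -np.inf, Q.dtype)
        l = np.zeros(Ti, Q.dtype)
        Oi = np.zeros((Ti, d), Q.dtype)
        Qi = Q[ti:ti+Ti]
        for tj in range(0, N, Tj):
            S = Qi @ K[tj:tj+Tj].T
            m_new = np.maximum(m, S.max(axis=1))
            p_tilde = np.exp(S - m_new[:, None])
            l_tmp = l * np.exp(m - m_new)
            l_new = l_tmp + p_tilde.sum(axis=1)
            Oi = (Oi * l_tmp[:, None] + p_tilde @ V[tj:tj+Tj]) / l_new[:, None]
            m, l = m_new, l_new
        O[ti:ti+Ti] = Oi
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.random((128, 64), dtype=np.float32) for _ in range(3))
assert np.allclose(flash_attention_numpy(Q, K, V), reference_attention(Q, K, V), atol=1e-5)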
FlashAttention/dace/flash_attention_dace_gpu.py (34 changes: 25 additions & 9 deletions)

@@ -1,17 +1,33 @@
 import dace
 import numpy as np

-N, d, Ti, Tj = (dace.symbol(s) for s in ('N', 'd', 'Ti', 'Tj'))
+from dace.transformation.auto import auto_optimize


-@dace.program
-def custom_attention_dace(Q: dace.float32[N, d], K: dace.float32[N, d], V: dace.float32[N, d], O: dace.float32[N, d]):
+def get_flash_attention_dace_gpu():
+
+    N, d, Ti, Tj = (dace.symbol(s) for s in ('N', 'd', 'Ti', 'Tj'))
+
+    @dace.program
+    def custom_attention_dace(Q: dace.float32[N, d], K: dace.float32[N, d], V: dace.float32[N, d], O: dace.float32[N, d]):

-    for ti in dace.map[0:N:Ti]:
+        for ti in dace.map[0:N:Ti]:

-        S = Q[ti:ti+Ti, :] @ np.transpose(K)
-        m = np.max(S, axis=1)
-        p_tilde = np.exp(S - m[:, np.newaxis])
-        l = np.sum(p_tilde, axis=1)
-        O[ti:ti+Ti, :] = (p_tilde @ V) / l[:, np.newaxis]
+            S = Q[ti:ti+Ti, :] @ np.transpose(K)
+            m = np.max(S, axis=1)
+            p_tilde = np.exp(S - m[:, np.newaxis])
+            l = np.sum(p_tilde, axis=1)
+            O[ti:ti+Ti, :] = (p_tilde @ V) / l[:, np.newaxis]

+    sdfg = custom_attention_dace.to_sdfg(simplify=False)
+    auto_optimize.apply_gpu_storage(sdfg)
+    for sd in sdfg.all_sdfgs_recursive():
+        if sd.parent_sdfg is not None and sd.parent_sdfg is sdfg:
+            sd.simplify()
+            auto_optimize.auto_optimize(sd, dace.DeviceType.GPU, use_gpu_storage=True)
+    for state in sdfg.states():
+        for node in state.nodes():
+            if isinstance(node, dace.nodes.MapEntry):
+                node.schedule = dace.ScheduleType.Sequential
+    sdfg.simplify()
+    return sdfg.compile()
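The GPU getter applies its pipeline in a fixed order: move arrays to GPU storage, auto-optimize only the nested SDFGs, force the remaining top-level maps to a sequential schedule, simplify, and compile. A toy-sized sketch of the same apply_gpu_storage / auto_optimize steps followed by a schedule inspection; it uses a stand-in program (the real one is nested inside the getter and not importable), auto-optimizes the whole SDFG instead of only the nested ones, and stops before compiling, which would require a CUDA toolchain:

import dace

from dace.transformation.auto import auto_optimize

# Stand-in program for illustration only.
N = dace.symbol('N')

@dace.program
def scale(A: dace.float32[N], B: dace.float32[N]):
    for i in dace.map[0:N]:
        B[i] = 2 * A[i]

sdfg = scale.to_sdfg(simplify=False)
auto_optimize.apply_gpu_storage(sdfg)  # arrays become GPU-resident
auto_optimize.auto_optimize(sdfg, dace.DeviceType.GPU, use_gpu_storage=True)
sdfg.simplify()

# List every map and the schedule it ended up with (e.g. GPU_Device or Sequential).
for state in sdfg.states():
    for node in state.nodes():
        if isinstance(node, dace.nodes.MapEntry):
            print(node, node.schedule)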
