Updated split implementations to include optimization/compilation code.
Commit f63d59a (1 parent: f9d0b41)
Showing 4 changed files with 79 additions and 51 deletions.
First changed file (the CPU FlashAttention kernel), hunk @@ -1,34 +1,42 @@. The previously module-level @dace.program is moved inside a get_flash_attention_dace_cpu() factory, which also lowers the program to an SDFG, simplifies it, auto-optimizes it for CPU, and returns the compiled result. The file now reads:

```python
import dace
import numpy as np

from dace.transformation.auto import auto_optimize

N, d, Ti, Tj = (dace.symbol(s) for s in ('N', 'd', 'Ti', 'Tj'))


def get_flash_attention_dace_cpu():

    @dace.program
    def flash_attention_dace_4(Q: dace.float32[N, d], K: dace.float32[N, d], V: dace.float32[N, d], O: dace.float32[N, d]):

        # Tile the query rows; each map iteration processes Ti rows.
        for ti in dace.map[0:N:Ti]:

            # Running row maxima, softmax denominators, score tile,
            # and output accumulator for this block of query rows.
            m = np.full([Ti], -np.inf, Q.dtype)
            l = np.zeros([Ti], Q.dtype)
            S = np.empty([Ti, Tj], Q.dtype)
            Oi = np.zeros([Ti, d], Q.dtype)

            Qi = Q[ti:ti+Ti, :]

            # Stream over key/value tiles of Tj rows each.
            for tj in range(0, N, Tj):

                S[:] = Qi @ np.transpose(K[tj:tj+Tj, :])

                # Online-softmax update: rescale the running sum and the
                # accumulated output by exp(m - m_new) so that earlier
                # tiles stay consistent with the new row maximum.
                max_row = np.max(S, axis=1)
                m_new = np.maximum(m, max_row)
                p_tilde = np.exp(S - m_new[:, np.newaxis])
                sum_row = np.sum(p_tilde, axis=1)
                l_tmp = l * np.exp(m - m_new)
                l_new = l_tmp + sum_row
                Oi[:] = (Oi * l_tmp[:, np.newaxis] + p_tilde @ V[tj:tj+Tj, :]) / l_new[:, np.newaxis]
                m[:] = m_new
                l[:] = l_new

            O[ti:ti+Ti, :] = Oi

    fa_sdfg = flash_attention_dace_4.to_sdfg(simplify=False)
    fa_sdfg.simplify()
    auto_optimize.auto_optimize(fa_sdfg, dace.DeviceType.CPU)
    return fa_sdfg.compile()
```
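A minimal usage sketch, not part of the commit: the problem sizes, tile sizes, and the compiled_cpu name below are illustrative assumptions. A compiled SDFG is invoked with its arrays and free symbols as keyword arguments:

```python
import numpy as np

compiled_cpu = get_flash_attention_dace_cpu()

# Assumed problem sizes; Ti and Tj are tile sizes and should divide N evenly.
N, d, Ti, Tj = 1024, 64, 128, 128

rng = np.random.default_rng(0)
Q = rng.standard_normal((N, d), dtype=np.float32)
K = rng.standard_normal((N, d), dtype=np.float32)
V = rng.standard_normal((N, d), dtype=np.float32)
O = np.zeros((N, d), dtype=np.float32)

# Arrays and free symbols are passed by name to the compiled SDFG.
compiled_cpu(Q=Q, K=K, V=V, O=O, N=N, d=d, Ti=Ti, Tj=Tj)
```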
Second changed file (the GPU attention kernel), hunk @@ -1,17 +1,33 @@. The same restructuring applies: the @dace.program moves into a get_flash_attention_dace_gpu() factory, followed by GPU-specific optimization and compilation. The file now reads:

```python
import dace
import numpy as np

from dace.transformation.auto import auto_optimize

N, d, Ti, Tj = (dace.symbol(s) for s in ('N', 'd', 'Ti', 'Tj'))


def get_flash_attention_dace_gpu():

    @dace.program
    def custom_attention_dace(Q: dace.float32[N, d], K: dace.float32[N, d], V: dace.float32[N, d], O: dace.float32[N, d]):

        # Each map iteration handles Ti query rows against the full K and V,
        # using a numerically stable (max-shifted) softmax.
        for ti in dace.map[0:N:Ti]:

            S = Q[ti:ti+Ti, :] @ np.transpose(K)
            m = np.max(S, axis=1)
            p_tilde = np.exp(S - m[:, np.newaxis])
            l = np.sum(p_tilde, axis=1)
            O[ti:ti+Ti, :] = (p_tilde @ V) / l[:, np.newaxis]

    sdfg = custom_attention_dace.to_sdfg(simplify=False)
    auto_optimize.apply_gpu_storage(sdfg)
    # Simplify and auto-optimize only the nested SDFGs that sit directly
    # below the top level.
    for sd in sdfg.all_sdfgs_recursive():
        if sd.parent_sdfg is not None and sd.parent_sdfg is sdfg:
            sd.simplify()
            auto_optimize.auto_optimize(sd, dace.DeviceType.GPU, use_gpu_storage=True)
    # Force the maps of the top-level SDFG to a sequential schedule.
    for state in sdfg.states():
        for node in state.nodes():
            if isinstance(node, dace.nodes.MapEntry):
                node.schedule = dace.ScheduleType.Sequential
    sdfg.simplify()
    return sdfg.compile()
```
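A hedged correctness check, also not part of the commit; reference_attention, the sizes, and the tolerance below are hypothetical:

```python
import numpy as np

def reference_attention(Q, K, V):
    # Dense attention with a max-shifted softmax, as a NumPy reference.
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

compiled_gpu = get_flash_attention_dace_gpu()

N, d, Ti = 512, 64, 128  # assumed sizes; Ti should divide N evenly
rng = np.random.default_rng(1)
Q = rng.standard_normal((N, d), dtype=np.float32)
K = rng.standard_normal((N, d), dtype=np.float32)
V = rng.standard_normal((N, d), dtype=np.float32)
O = np.zeros((N, d), dtype=np.float32)

# Tj is declared but unused by this kernel, so it is assumed to have been
# simplified away; pass Tj as well if the compiled SDFG still expects it.
compiled_gpu(Q=Q, K=K, V=V, O=O, N=N, d=d, Ti=Ti)
assert np.allclose(O, reference_attention(Q, K, V), atol=1e-4)
```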