Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flash attention puzzle #31

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
803 changes: 803 additions & 0 deletions Flash_attention_puzzle.ipynb

Large diffs are not rendered by default.

Binary file added flash_attn_forward_algo.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
179 changes: 134 additions & 45 deletions lib.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
from dataclasses import dataclass
import numpy as np
from chalk import *
from chalk import (
concat, rectangle, text, hcat, vcat, circle, arc_between, empty, place_at, vstrut, hstrut, image
)
from chalk.transform import unit_y, P2, V2
from chalk.core import set_svg_height
from colour import Color
import chalk
from dataclasses import dataclass
from typing import List, Any
from collections import Counter
from numba import cuda
import numba
import random
from typing import Tuple
from unittest import mock
from functools import reduce

@dataclass
class ScalarHistory:
Expand All @@ -23,26 +28,77 @@ def __add__(self, b):
return self
if isinstance(b, Scalar):
return ScalarHistory(self.last_fn, self.inputs + [b])
return ScalarHistory(self.last_fn, self.inputs + b.inputs)

def __rsub__(self, b):
return self - b

def __sub__(self, b):
if isinstance(b, (float, int)):
return self
if isinstance(b, Scalar):
return ScalarHistory(self.last_fn, self.inputs + [b])

return ScalarHistory(self.last_fn, self.inputs + b.inputs)

def __rmul__(self, b):
return self * b

def __mul__(self, b):
if isinstance(b, (float, int)):
return self
if isinstance(b, Scalar):
return ScalarHistory(self.last_fn, self.inputs + [b])
return ScalarHistory(self.last_fn, self.inputs + b.inputs)

def __rtruediv__(self, b):
return self / b

def __truediv__(self, b):
if isinstance(b, (float, int)):
return self
if isinstance(b, Scalar):
return ScalarHistory(self.last_fn, self.inputs + [b])
return ScalarHistory(self.last_fn, self.inputs + b.inputs)


class Scalar:
def __init__(self, location):
self.location = location

def __rmul__(self, b):
return self * b

def __mul__(self, b):
if isinstance(b, (float, int)):
return ScalarHistory("id", [self])
return ScalarHistory("*", [self, b])


def __radd__(self, b):
return self + b

def __add__(self, b):
if isinstance(b, (float, int)):
return ScalarHistory("id", [self])
return ScalarHistory("+", [self, b])

def __rsub__(self, b):
return self - b

def __sub__(self, b):
if isinstance(b, (float, int)):
return ScalarHistory("id", [self])
return ScalarHistory("-", [self, b])

def __gt__(self, b):
if isinstance(b, (float, int)):
return ScalarHistory("id", [self])
return ScalarHistory(">", [self, b])

def __lt__(self, b):
if isinstance(b, (float, int)):
return ScalarHistory("id", [self])
return ScalarHistory("<", [self, b])

def __iadd__(self, other):
assert False, "Instead of `out[] +=` use a local variable `acc + =`"
Expand All @@ -53,14 +109,14 @@ def __init__(self, name, array):
self.incoming = []
self.array = array

self.size = array.shape
self.shape = array.shape

def __getitem__(self, index):
self.array[index]
if isinstance(index, int):
index = (index,)
assert len(index) == len(self.size), "Wrong number of indices"
if index[0] >= self.size[0]:
assert len(index) == len(self.shape), "Wrong number of indices"
if index[0] >= self.shape[0]:
assert False, "bad size"

return Scalar((self.name,) + index)
Expand All @@ -69,8 +125,8 @@ def __setitem__(self, index, val):
self.array[index]
if isinstance(index, int):
index = (index,)
assert len(index) == len(self.size), "Wrong number of indices"
if index[0] >= self.size[0]:
assert len(index) == len(self.shape), "Wrong number of indices"
if index[0] >= self.shape[0]:
assert False, "bad size"
if isinstance(val, Scalar):
val = ScalarHistory("id", [val])
Expand All @@ -95,8 +151,8 @@ def tuple(self):


class RefList:
def __init__(self):
self.refs = []
def __init__(self, refs=None):
self.refs = refs or []

def __getitem__(self, index):
return self.refs[-1][index]
Expand All @@ -112,13 +168,15 @@ def __init__(self, cuda):
def array(self, size, ig):
if isinstance(size, int):
size = (size,)
s = np.zeros(size)
cache = Table("S" + str(len(self.cuda.caches)), s)
table = Table(
name="S" + str(len(self.cuda.caches)),
array=np.zeros(size)
)
# self.caches.append(cache)
self.cuda.caches.append(RefList())
self.cuda.caches[-1].refs = [cache]
reflist = RefList(refs=[table])
self.cuda.caches.append(reflist)
self.cuda.saved.append([])
return self.cuda.caches[-1]
return reflist


class Cuda:
Expand All @@ -144,9 +202,7 @@ def syncthreads(self):
temp = old_cache.incoming
old_cache.incoming = self.saved[i]
self.saved[i] = temp
cache = Table(old_cache.name + "'", old_cache.array)

c.refs.append(cache)
c.refs.append(Table(old_cache.name + "'", old_cache.array))

def finish(self):
for i, c in enumerate(self.caches):
Expand Down Expand Up @@ -207,10 +263,10 @@ def myconnect(diagram, loc, color, con, name1, name2):

def draw_table(tab):
t = text(tab.name, 0.5).fill_color(black).line_width(0.0)
if len(tab.size) == 1:
tab = table(tab.name, 0, *tab.size)
if len(tab.shape) == 1:
tab = table(tab.name, 0, *tab.shape)
else:
tab = table(tab.name, *tab.size)
tab = table(tab.name, *tab.shape)
tab = tab.line_width(0.05)
return tab.beside((t + vstrut(0.5)), -unit_y)

Expand All @@ -221,6 +277,7 @@ def draw_connect(tab, dia, loc2, color, con):
myconnect(dia, loc2, color, con, (tab.name,) + loc, inp.location)
for (loc, val) in tab.incoming
for inp in val.inputs
if not isinstance(inp, ScalarHistory)
]
)

Expand All @@ -235,7 +292,10 @@ def draw_base(_, a, c, out):
return hcat([inputs, shareds, outputs], 2.0)


def draw_coins(tpbx, tpby):
def draw_coins(tpbx, tpby, colors=None):
colors = colors or list(
Color("red").range_to(Color("blue"), sum(1 for _ in Coord(tpbx, tpby).enumerate()))
)
return concat(
[
(circle(0.5).fill_color(colors[tt]).fill_opacity(0.7) + im).translate(
Expand All @@ -253,14 +313,19 @@ def label(dia, content):



def draw_results(results, name, tpbx, tpby, sparse=False):
def draw_results(
results, name, tpbx, tpby, sparse=False, svg_height=500, svg_height_factor=50, colors=None
):
full = empty()
blocks = []
locations = []
base = draw_base(*results[Coord(0, 0)][Coord(0, 0)])
for block, inner in results.items():
dia = base
for pos, (tt, a, c, out) in inner.items():
colors = colors or list(
Color("red").range_to(Color("blue"), len(inner))
)
for pos, (tt, input_tables, cuda_obj, out) in inner.items():
loc = (
pos.x / tpbx + (1 / (2 * tpbx)),
(pos.y / tpby)
Expand All @@ -274,13 +339,16 @@ def draw_results(results, name, tpbx, tpby, sparse=False):
pos.x == (tpbx - 1)
and pos.y == (tpby - 1)
)
all_tabs = (
a + [c2.refs[i] for i in range(1, c.rounds()) for c2 in c.caches] + [out]
all_tables = (
input_tables
+ [
cache.refs[i] for i in range(1, cuda_obj.rounds()) for cache in cuda_obj.caches
]
+ [out]
)
dia = dia + concat(
draw_connect(t, dia, loc, color, lines) for t in all_tabs
draw_connect(tab, dia, loc, color, lines) for tab in all_tables
)
height = dia.get_envelope().height

# Label block and surround
dia = hstrut(1) | (label(dia, f"Block {block.x} {block.y}")) | hstrut(1)
Expand Down Expand Up @@ -311,10 +379,10 @@ def draw_results(results, name, tpbx, tpby, sparse=False):
)
full = full.pad(1.1).center_xy()
env = full.get_envelope()
set_svg_height(50 * env.height)
set_svg_height(svg_height_factor * env.height)


chalk.core.set_svg_output_height(500)
chalk.core.set_svg_output_height(svg_height)
return rectangle(env.width, env.height).fill_color(white) + full


Expand All @@ -330,6 +398,7 @@ class CudaProblem:
blockspergrid: Coord = Coord(1, 1)
threadsperblock: Coord = Coord(1, 1)
spec: Any = None
input_names: List[str] = ("a", "b", "c", "d")

def run_cuda(self):
fn = self.fn
Expand All @@ -340,38 +409,44 @@ def run_cuda(self):
)
return self.out

@mock.patch("math.exp", lambda x: x)
def run_python(self):
results = {}
fn = self.fn
for _, block in self.blockspergrid.enumerate():
results[block] = {}
for tt, pos in self.threadsperblock.enumerate():
a = []
args = ["a", "b", "c", "d"]
input_tables = []
for i, inp in enumerate(self.inputs):
a.append(Table(args[i], inp))
input_tables.append(Table(self.input_names[i], inp))
out = Table("out", self.out)

c = Cuda(block, self.threadsperblock, pos)
fn(c)(out, *a, *self.args)
c.finish()
results[block][pos] = (tt, a, c, out)
cuda_obj = Cuda(block, self.threadsperblock, pos)
fn(cuda_obj)(out, *input_tables, *self.args)
cuda_obj.finish()
results[block][pos] = (tt, input_tables, cuda_obj, out)
return results

def score(self, results):

total = 0
full = Counter()
for pos, (tt, a, c, out) in results[Coord(0, 0)].items():
for pos, (tt, a, cuda_obj, out) in results[Coord(0, 0)].items():
total += 1
count = Counter()
for out, tab in [(False, c2.refs[i]) for i in range(1, c.rounds()) for c2 in c.caches] + [(True, out)]:
for out, tab in [
(False, cache.refs[i])
for i in range(1, cuda_obj.rounds())
for cache in cuda_obj.caches
] + [(True, out)]:
for inc in tab.incoming:
if out:
count["out_writes"] += 1
else:
count["shared_writes"] += 1
for ins in inc[1].inputs:
if isinstance(ins, ScalarHistory):
continue
if ins.location[0].startswith("S"):
count["shared_reads"] += 1
else:
Expand All @@ -386,17 +461,31 @@ def score(self, results):
| {full['in_reads']:>13} | {full['out_writes']:>13} | {full['shared_reads']:>13} | {full['shared_writes']:>13} |
""")

def show(self, sparse=False):
def show(self, sparse=False, svg_height_factor=50):
results = self.run_python()
self.score(results)
return draw_results(results, self.name,
self.threadsperblock.x, self.threadsperblock.y, sparse)
colors = [
*Color("red").range_to(
Color("blue"),
reduce(int.__mul__, self.threadsperblock.tuple())
)
]
return draw_results(
results,
self.name,
self.threadsperblock.x,
self.threadsperblock.y,
sparse,
svg_height=500,
svg_height_factor=svg_height_factor,
colors=colors,
)

def check(self):
def check(self, atol=0, rtol=1e-7):
x = self.run_cuda()
y = self.spec(*self.inputs)
try:
np.testing.assert_allclose(x, y)
np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)
print("Passed Tests!")
from IPython.display import HTML
pups = [
Expand Down