-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy paththunderkittens.cpp
218 lines (182 loc) · 7.67 KB
/
thunderkittens.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <vector>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
/*
HOW TO REGISTER YOUR OWN, CUSTOM SET OF KERNELS:
1. Decide on the identifier that will go in config.py. For example, "attn" is the
   identifier for the set guarded by TK_COMPILE_ATTN below.
2. Add the identifier to the dict of sources in config.py.
3. Add the identifier to the list of kernels you want compiled.
4. When that kernel set is compiled, the macro "TK_COMPILE_{IDENTIFIER_IN_ALL_CAPS}"
   is defined. You need to add two chunks to this file, each wrapped in
   #ifdef TK_COMPILE_{IDENTIFIER_IN_ALL_CAPS} ... #endif:
   4a. the extern declaration of your function, near the top of the file;
   4b. the registration of the function into the module, e.g.
       m.def("my_kernel_forward", &my_kernel_forward, "Docstring shown in Python.");
*/
#ifdef TK_COMPILE_TTT_LINEAR_FORWARD
// TTT-Linear forward. Declared here, defined in another translation unit;
// registered below as the Python function `ttt_linear_forward`.
// Parameters are by-value torch::Tensor handles. Keep this declaration
// identical to the definition: changing a parameter type (e.g. to const&)
// changes the linked symbol.
// NOTE(review): `output` is passed in and a Tensor is also returned —
// presumably the result is written into `output`; confirm in the definition.
extern torch::Tensor ttt_linear_forward(
const torch::Tensor ttt_norm_weight,
const torch::Tensor ttt_norm_bias,
const torch::Tensor W1_init,
const torch::Tensor b1_init,
const torch::Tensor XQ_batch,
const torch::Tensor XV_batch,
const torch::Tensor XK_batch,
const torch::Tensor eta_batch,
const torch::Tensor make_last_b_matrix,
const torch::Tensor make_last_coeff_1_matrix,
const torch::Tensor output
);
#endif
#ifdef TK_COMPILE_TTT_MLP_FORWARD
// TTT-MLP forward. Declared here, defined in another translation unit;
// registered below as the Python function `ttt_mlp_forward`.
// Takes two sets of initial parameters (W1_init/b1_init and W2_init/b2_init)
// plus the XQ/XV/XK/eta `_batch` inputs, and returns a single Tensor.
// Keep this declaration identical to the definition (by-value Tensor params).
extern torch::Tensor ttt_mlp_forward(
const torch::Tensor ttt_norm_weight,
const torch::Tensor ttt_norm_bias,
const torch::Tensor W1_init,
const torch::Tensor b1_init,
const torch::Tensor W2_init,
const torch::Tensor b2_init,
const torch::Tensor XQ_batch,
const torch::Tensor XV_batch,
const torch::Tensor XK_batch,
const torch::Tensor eta_batch
);
#endif
////////////////////////////////
//// ThunderKittens Premade ////
////////////////////////////////
#ifdef TK_COMPILE_ATTN
// Multi-head attention forward; registered below as `mha_forward`.
// `causal` presumably selects causal masking (bidirectional when false).
// Returns a vector of tensors which, per the registration docstring, includes
// the norm vector L of shape (B,H,N) consumed by attention_backward.
extern std::vector<torch::Tensor> attention_forward(
torch::Tensor q, torch::Tensor k, torch::Tensor v, bool causal
);
// Multi-head attention backward; registered below as `mha_backward`.
// NOTE(review): `o` and `l_vec` are presumably the output and norm vector from
// attention_forward and `og` the output gradient; the returned vector
// presumably holds the input gradients — confirm in the definition.
extern std::vector<torch::Tensor> attention_backward(
torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor o,
torch::Tensor l_vec, torch::Tensor og,
bool causal
);
#endif
#ifdef TK_COMPILE_HEDGEHOG
// Hedgehog forward; registered below as `hedgehog`. Per the registration
// docstring: q, k, v are bf16 (B,H,N,64); q_map and k_map are bf16
// (H,E,64,64); alphas and betas are fp32 (H,E).
// Returns a tuple of three tensors.
extern std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> hedgehog(
torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor q_map, torch::Tensor k_map,
torch::Tensor alphas, torch::Tensor betas
);
#endif
#ifdef TK_COMPILE_BASED
// Based forward; registered below as `based`. Per the registration docstring,
// q, k, v are bf16 (B,H,N,64). Returns a pair of tensors.
extern std::tuple<torch::Tensor, torch::Tensor> based(
const torch::Tensor q,
const torch::Tensor k,
const torch::Tensor v
);
#endif
#ifdef TK_COMPILE_CYLON
// Cylon forward; registered below as `cylon`. Returns void, so results must be
// delivered through the tensor arguments — presumably `o` (and possibly
// `kv_state`); confirm in the definition.
extern void cylon(
torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor o, torch::Tensor kv_state,
torch::Tensor q_map, torch::Tensor k_map
);
// Cylon backward; registered below as `cylon_bwd`. Returns void; per the
// registration docstring, q_grad, k_grad, v_grad, q_map_grad and k_map_grad
// are the (fp32) outputs.
extern void cylon_bwd(
torch::Tensor q, torch::Tensor k, torch::Tensor v,
torch::Tensor q_map, torch::Tensor k_map,
torch::Tensor o_grad, torch::Tensor kv_state,
torch::Tensor q_grad, torch::Tensor k_grad, torch::Tensor v_grad,
torch::Tensor q_map_grad, torch::Tensor k_map_grad
);
#endif
#ifdef TK_COMPILE_FLUX
// Fused linear + gate; registered below as `tk_flux_linear_gate`. Per the
// registration docstring all tensors are bf16: x (B,H1), weight (H2,H1),
// bias and gate (H2), y (B,H2).
extern torch::Tensor fused_flux_linear_gate(
const torch::Tensor x,
const torch::Tensor weight,
const torch::Tensor bias,
const torch::Tensor gate,
const torch::Tensor y
);
// Fused linear + GELU; registered below as `tk_flux_linear_gelu`. Per the
// registration docstring all tensors are bf16: x (B,H1), weight (H2,H1),
// bias (H2).
extern torch::Tensor fused_flux_linear_gelu(
const torch::Tensor x,
const torch::Tensor weight,
const torch::Tensor bias
);
#endif
#ifdef TK_COMPILE_FFTCONV
// FFT convolution; registered below as `fftconv`. Takes the real input
// `u_real`, real/imaginary tensor pairs kf_*, f_*, finv_*, tw_*, twinv_*
// (going by the names, presumably the kernel spectrum, forward/inverse DFT
// matrices and twiddle factors — confirm), plus integer sizes B, H, N, N1.
// Per the registration docstring, all tensor arguments are bf16.
extern torch::Tensor fftconv(
const torch::Tensor u_real,
const torch::Tensor kf_real,
const torch::Tensor kf_imag,
const torch::Tensor f_real,
const torch::Tensor f_imag,
const torch::Tensor finv_real,
const torch::Tensor finv_imag,
const torch::Tensor tw_real,
const torch::Tensor tw_imag,
const torch::Tensor twinv_real,
const torch::Tensor twinv_imag,
int B,
int H,
int N,
int N1
);
#endif
#ifdef TK_COMPILE_FUSED_ROTARY
// Fused rotary kernel; registered below as `fused_rotary`. Per the
// registration docstring, x, cos_in and sin_in are all bf16.
extern torch::Tensor fused_rotary(
const torch::Tensor x,
const torch::Tensor cos_in,
const torch::Tensor sin_in
);
#endif
#ifdef TK_COMPILE_FUSED_LAYERNORM
// Fused LayerNorm taking a residual input and a dropout probability;
// registered below as `fused_layernorm`. Per the registration docstring the
// tensor arguments are bf16. Returns a pair of tensors.
extern std::tuple<torch::Tensor, torch::Tensor> fused_layernorm(
const torch::Tensor x,
const torch::Tensor residual,
const torch::Tensor norm_weight,
const torch::Tensor norm_bias,
float dropout_p
);
#endif
#ifdef TK_COMPILE_MAMBA2
// Mamba2 kernel; registered below as `mamba2`. Per the registration
// docstring, q, k, v are bf16 and `a` is float. Returns a single Tensor.
extern torch::Tensor mamba2(
const torch::Tensor q,
const torch::Tensor k,
const torch::Tensor v,
const torch::Tensor a
);
#endif
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Python bindings for every kernel set enabled at build time. Each m.def is
    // guarded by the same TK_COMPILE_* macro as its extern declaration above.
    // The docstrings are what Python's help() shows, so their argument lists
    // and return descriptions must match the declared C++ signatures.
    m.doc() = "ThunderKittens Kernels"; // optional module docstring
#ifdef TK_COMPILE_TTT_LINEAR_FORWARD
    m.def("ttt_linear_forward", &ttt_linear_forward, "TTT-Linear Forward.");
#endif
#ifdef TK_COMPILE_TTT_MLP_FORWARD
    m.def("ttt_mlp_forward", &ttt_mlp_forward, "TTT-MLP Forward.");
#endif
#ifdef TK_COMPILE_ATTN
    // attention_forward takes (q, k, v, causal) and attention_backward takes
    // (q, k, v, o, l_vec, og, causal); both return std::vector<Tensor>, which
    // wrap_pybind_function exposes to Python as a list.
    m.def("mha_forward", torch::wrap_pybind_function(attention_forward),
          "Forward MHA. Takes (q, k, v, causal): Q,K,V in (B,H,N,D) where D must be 64 or 128, "
          "and N must be a multiple of 64; causal (bool) selects causal rather than bidirectional attention. "
          "Returns a list of tensors that includes the norm vector L of shape (B,H,N), used in the backward pass.");
    // TODO: d_vec memory.
    m.def("mha_backward", torch::wrap_pybind_function(attention_backward),
          "Backward MHA. Takes (q, k, v, o, l_vec, og, causal): Q,K,V,O,Og in (B,H,N,D) where D must be 64 or 128, "
          "and N must be a multiple of 64; l_vec is the norm vector L returned by mha_forward; "
          "causal (bool) as in mha_forward. Returns a list of gradient tensors.");
#endif
#ifdef TK_COMPILE_HEDGEHOG
    // hedgehog returns std::tuple of three tensors (exposed as a Python tuple).
    m.def("hedgehog", &hedgehog,
          "Hedgehog forward. Takes tensors (q, k, v, q_map, k_map, alphas, betas). q, k, v are bf16 (B,H,N,64), "
          "q_map and k_map are bf16 (H,E,64,64), alphas and betas are fp32 (H,E). "
          "Returns a tuple of three tensors, including the (B,H,N,64) bf16 output.");
#endif
#ifdef TK_COMPILE_BASED
    // based returns std::tuple of two tensors (exposed as a Python tuple).
    m.def("based", &based,
          "Based forward. Takes tensors (q, k, v). q, k, v are bf16 (B,H,N,64). "
          "Returns a tuple of two tensors, including the (B,H,N,64) bf16 output.");
#endif
#ifdef TK_COMPILE_CYLON
    // Plain C++ string literals: the former """...""" spelling was a Python-ism
    // that only compiled because adjacent literals concatenate ("" "..." "").
    m.def("cylon", &cylon,
          "Cylon forward. Takes tensors (q, k, v, o, kv_state, q_map, k_map). q, k, v, o are bf16 (B,H,N,64), "
          "kv_state is fp32 (B,H,E,64,64), q_map and k_map are bf16 (H,E,64,64).");
    m.def("cylon_bwd", &cylon_bwd,
          "Cylon backward. Takes tensors (q, k, v, q_map, k_map, o_grad, kv_state, q_grad, k_grad, v_grad, q_map_grad, k_map_grad). "
          "q, k, v, o_grad are bf16 (B,H,N,64), q_map and k_map are bf16 (H,E,64,64), kv_state is fp32 (B,H,E,64,64). "
          "Outputs q_grad, k_grad, v_grad are fp32 (B,H,N,64) and q_map_grad, k_map_grad are fp32 (H,E,64,64).");
#endif
#ifdef TK_COMPILE_FLUX
    m.def("tk_flux_linear_gate", &fused_flux_linear_gate,
          "Flux linear gate. Takes tensors (x, weight, bias, gate, y). x is (B, H1), weight is (H2, H1), "
          "bias and gate are (H2), y is (B, H2). x, weight, bias, gate, y are bf16. Returns (B, H2) in bf16.");
    m.def("tk_flux_linear_gelu", &fused_flux_linear_gelu,
          "Flux linear gelu. Takes tensors (x, weight, bias). x is (B, H1), weight is (H2, H1), bias is (H2). "
          "x, weight, bias are bf16. Returns (B, H2) in bf16.");
#endif
#ifdef TK_COMPILE_FFTCONV
    m.def("fftconv", &fftconv,
          "FFTConv TK. Takes (u_real, kf_real, kf_imag, f_real, f_imag, finv_real, finv_imag, tw_real, tw_imag, "
          "twinv_real, twinv_imag, B, H, N, N1). All tensor arguments are bf16; B, H, N, N1 are ints. "
          "Returns (B, H, N, N1) in bf16.");
#endif
#ifdef TK_COMPILE_FUSED_ROTARY
    m.def("fused_rotary", &fused_rotary,
          "Rotary TK. Takes tensors (x, cos_in, sin_in). All tensors are bf16. Returns (B, H, N, 128) in bf16.");
#endif
#ifdef TK_COMPILE_FUSED_LAYERNORM
    // fused_layernorm returns std::tuple of two tensors (exposed as a Python tuple).
    m.def("fused_layernorm", &fused_layernorm,
          "LayerNorm TK. Takes (x, residual, norm_weight, norm_bias, dropout_p). "
          "x, residual, norm_weight, norm_bias are bf16. dropout_p is float. "
          "Returns a tuple of two tensors, including the (B, H, N, 128) bf16 output.");
#endif
#ifdef TK_COMPILE_MAMBA2
    m.def("mamba2", &mamba2,
          "Mamba2 TK. Takes tensors (q, k, v, a). q, k, v tensors are bf16 and a is float.");
#endif
}