flash-backward.cu
#include <torch/types.h>
#include <cuda.h>
#include <cuda_runtime.h>

// FlashAttention backward pass. One thread block per (batch, head); each of the
// Bc threads owns one row of the current K/V tile and one row of the current Q
// tile, so the kernel assumes Bc == Br.
__global__
void backward_kernel(const float* Q, const float* K, const float* V,
                     const float* O, const float* dO, const float* Mask,
                     const float* l, const float* m,
                     const int N, const int d,
                     const int Tc, const int Tr, const int Bc, const int Br,
                     const float softmax_scale,
                     float* dQ, float* dK, float* dV) {
    // Note: Mask is accepted for interface parity but is not applied yet.
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y;  // batch index, head index

    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d);  // offset into Q,K,V,O,dO,dQ,dK,dV
    int lm_offset = (bx * gridDim.y * N) + (by * N);           // offset into l,m

    // Shared memory layout: seven Bc*d tiles followed by two Bc*Br score tiles.
    extern __shared__ float sram[];
    int tile_size = Bc * d;
    float* Qi  = sram;
    float* Kj  = &sram[tile_size];
    float* Vj  = &sram[tile_size * 2];
    float* Oi  = &sram[tile_size * 3];
    float* dOi = &sram[tile_size * 4];
    float* dKj = &sram[tile_size * 5];
    float* dVj = &sram[tile_size * 6];
    float* S   = &sram[tile_size * 7];               // holds P_ij after exponentiation
    float* dS  = &sram[(tile_size * 7) + (Bc * Br)]; // holds dS_ij
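    // Per tile pair (i, j), the gradients of O = softmax(softmax_scale * Q K^T) V are:
    //   dV_j += P_ij^T dO_i
    //   dP_ij = dO_i V_j^T
    //   dS_ij = P_ij ∘ (dP_ij - D_i),  with D_i = rowsum(dO_i ∘ O_i)
    //   dQ_i += softmax_scale * dS_ij K_j
    //   dK_j += softmax_scale * dS_ij^T Q_i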
    // Outer loop over K/V tiles.
    for (int j = 0; j < Tc; j++) {
        // Load Kj, Vj and zero the dKj, dVj accumulators.
        for (int x = 0; x < d; x++) {
            Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
            Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
            dKj[(tx * d) + x] = 0;
            dVj[(tx * d) + x] = 0;
        }
        __syncthreads();
        // Inner loop over Q tiles.
        for (int i = 0; i < Tr; i++) {
            // Load Qi, Oi, dOi for this Q tile.
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x]  = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
                Oi[(tx * d) + x]  = O[qkv_offset + (tile_size * i) + (tx * d) + x];
                dOi[(tx * d) + x] = dO[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            // Per-row softmax statistics saved by the forward pass.
            float row_m = m[lm_offset + (Br * i) + tx];
            float row_l = l[lm_offset + (Br * i) + tx];

            // D_i = rowsum(dO_i ∘ O_i), used when forming dS_ij.
            float Di = 0;
            for (int x = 0; x < d; x++) {
                Di += dOi[(tx * d) + x] * Oi[(tx * d) + x];
            }
            __syncthreads();

            // Compute S_ij = softmax_scale * Q_i K_j^T and P_ij = exp(S_ij - m_i) / l_i.
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                S[(Bc * tx) + y] = __expf(sum - row_m) / row_l;  // P_ij
            }
            __syncthreads();
            // dV_j += P_ij^T dO_i  (Bc x Br times Br x d; relies on Bc == Br)
            for (int y = 0; y < d; y++) {
                float sum = 0;
                for (int x = 0; x < Br; x++) {
                    sum += S[(Bc * x) + tx] * dOi[(x * d) + y];
                }
                dVj[(tx * d) + y] += sum;
            }
            // dP_ij = dO_i V_j^T, then dS_ij = P_ij ∘ (dP_ij - D_i).
            for (int y = 0; y < Bc; y++) {
                float sum = 0;
                for (int x = 0; x < d; x++) {
                    sum += dOi[(tx * d) + x] * Vj[(y * d) + x];
                }
                dS[(Bc * tx) + y] = S[(Bc * tx) + y] * (sum - Di);
            }
            __syncthreads();
            // dQ_i += softmax_scale * dS_ij K_j, accumulated directly in HBM across j.
            for (int y = 0; y < d; y++) {
                float sum = 0;
                for (int x = 0; x < Bc; x++) {
                    sum += dS[(Bc * tx) + x] * Kj[(x * d) + y];
                }
                dQ[qkv_offset + (tile_size * i) + (tx * d) + y] += softmax_scale * sum;
            }
            // dK_j += softmax_scale * dS_ij^T Q_i.
            for (int y = 0; y < d; y++) {
                float sum = 0;
                for (int x = 0; x < Br; x++) {
                    sum += dS[(Bc * x) + tx] * Qi[(x * d) + y];
                }
                dKj[(tx * d) + y] += softmax_scale * sum;
            }
            __syncthreads();
        }
        // Write the accumulated dK_j, dV_j for this tile back to HBM.
        for (int x = 0; x < d; x++) {
            dK[qkv_offset + (tile_size * j) + (tx * d) + x] = dKj[(tx * d) + x];
            dV[qkv_offset + (tile_size * j) + (tx * d) + x] = dVj[(tx * d) + x];
        }
        __syncthreads();
    }
}
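
// A verification sketch, not part of the original kernel: reference_backward is a
// hypothetical helper that recomputes the same tile-level equations with stock ATen
// ops, so the kernel's dQ/dK/dV can be compared against it with torch::allclose.
std::vector<torch::Tensor> reference_backward(torch::Tensor Q, torch::Tensor K,
                                              torch::Tensor V, torch::Tensor dO) {
    const int d = Q.size(3);
    const double scale = 1.0 / sqrt(d);
    auto P  = torch::softmax(torch::matmul(Q, K.transpose(-2, -1)) * scale, -1);
    auto O  = torch::matmul(P, V);
    auto dV = torch::matmul(P.transpose(-2, -1), dO);            // dV = P^T dO
    auto dP = torch::matmul(dO, V.transpose(-2, -1));            // dP = dO V^T
    auto D  = (dO * O).sum({-1}, /*keepdim=*/true);              // D = rowsum(dO ∘ O)
    auto dS = P * (dP - D);                                      // dS = P ∘ (dP - D)
    auto dQ = torch::matmul(dS, K) * scale;
    auto dK = torch::matmul(dS.transpose(-2, -1), Q) * scale;
    return {dQ, dK, dV};
}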
std::vector<torch::Tensor> backward(torch::Tensor Q, torch::Tensor K, torch::Tensor V,
                                    torch::Tensor O, torch::Tensor dO, torch::Tensor Mask,
                                    torch::Tensor l, torch::Tensor m) {
    const int Bc = 32;
    const int Br = 32;

    const int B = Q.size(0); const int nh = Q.size(1);
    const int N = Q.size(2); const int d = Q.size(3);

    const int Tc = ceil((float) N / Bc); const int Tr = ceil((float) N / Br);
    const float softmax_scale = 1.0 / sqrt(d);

    auto dQ = torch::zeros_like(Q);
    auto dK = torch::zeros_like(K);
    auto dV = torch::zeros_like(V);

    // Dynamic shared memory: seven Bc*d tiles plus the S and dS score tiles.
    const int sram_size = (7 * Bc * d * sizeof(float)) + (2 * Bc * Br * sizeof(float));
    int max_sram_size;
    cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0);
    TORCH_CHECK(sram_size <= max_sram_size, "requested shared memory exceeds the per-block limit");

    dim3 grid_dim(B, nh);  // one block per (batch, head)
    dim3 block_dim(Bc);    // Bc threads per block (Bc == Br assumed)

    backward_kernel<<<grid_dim, block_dim, sram_size>>>(
        Q.data_ptr<float>(), K.data_ptr<float>(), V.data_ptr<float>(),
        O.data_ptr<float>(), dO.data_ptr<float>(), Mask.data_ptr<float>(),
        l.data_ptr<float>(), m.data_ptr<float>(),
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        dQ.data_ptr<float>(), dK.data_ptr<float>(), dV.data_ptr<float>()
    );
    return {dQ, dK, dV};
}
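
// A minimal binding sketch, assuming this file is built as a PyTorch CUDA extension
// (e.g. via torch.utils.cpp_extension.load); the module and exposed names here are
// illustrative, not part of the original file.
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
    mod.def("backward", &backward, "FlashAttention backward pass (CUDA)");
    mod.def("reference_backward", &reference_backward, "Reference backward using ATen ops");
}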