
Commit

rename gelu_backward to gelu_backward_inplace, and made it obvious in the code that this is an inplace operation
ngc92 committed Jun 1, 2024
1 parent cc4bed0 commit 761fb2c
Showing 1 changed file with 6 additions and 6 deletions.
train_gpt2.cu: 12 changes (6 additions & 6 deletions)
@@ -942,12 +942,12 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {
     store128(out + idx, packed_out);
 }
 
-__global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout) {
+__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
 
     x128 packed_dinp;
     x128 packed_inp = load128cs(inp + idx);
-    x128 packed_dout = load128cs(dout + idx);
+    x128 packed_dout = load128(d_in_out + idx);
     for (int k = 0; k < packed_inp.size; ++k) {
         float x = (float)packed_inp[k];
         float cube = 0.044715f * x * x * x;
@@ -958,7 +958,7 @@ __global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout) {
         float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
         packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);
     }
-    store128(dinp + idx, packed_dinp);
+    store128(d_in_out + idx, packed_dinp);
 }
 
 template<typename OutFloat, bool UseAuxBuffer>
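
For reference, the per-element math the packed kernel applies is the derivative of the tanh-approximation GELU, evaluated at the saved forward input and multiplied by the incoming gradient. The scalar sketch below restates it in plain C; it is not part of this commit: the tanh_arg/tanh_out/sech_out lines collapsed between the two hunks are reconstructed from the visible local_grad expression, the name gelu_backward_inplace_cpu is hypothetical, and GELU_SCALING_FACTOR is assumed to be sqrtf(2.0f / M_PI) as defined elsewhere in train_gpt2.cu.

#include <math.h>

#define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

// Scalar, in-place GELU backward: d_in_out holds dL/d(gelu output) on entry
// and is overwritten with dL/d(gelu input); inp is the forward-pass input.
void gelu_backward_inplace_cpu(float* d_in_out, const float* inp, int N) {
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);
        float local_grad = 0.5f * (1.0f + tanh_out)
                         + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
        d_in_out[i] = local_grad * d_in_out[i]; // multiply by the upstream gradient, in place
    }
}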
@@ -1747,12 +1747,12 @@ void gelu_forward(floatX* out, const floatX* inp, int N) {
     cudaCheck(cudaGetLastError());
 }
 
-void gelu_backward(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
+void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N) {
     NVTX_RANGE_FN();
     const int block_size = 128;
     assert(N % block_size == 0);
     const int grid_size = CEIL_DIV(N, block_size * x128::size);
-    gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout);
+    gelu_backward_inplace_kernel<<<grid_size, block_size>>>(d_in_out, inp);
     cudaCheck(cudaGetLastError());
 }
 
@@ -2660,7 +2660,7 @@ void gpt2_backward(GPT2 *model, int* inputs) {
             gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
         }
         matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C);
-        gelu_backward(dl_bt4c, l_fch, dl_bt4c, B*T*4*C);
+        gelu_backward_inplace(dl_bt4c, l_fch, B*T*4*C);
         if(model->recompute >= 2) {
             // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
             layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
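
The call-site hunk above also shows why the in-place form was already the de-facto behavior: the one updated call passed dl_bt4c as both dinp and dout, so the kernel aliased its gradient buffers anyway, and the rename simply encodes that in the signature. The aliasing is safe because each thread copies its own x128 slice of d_in_out into registers (packed_dout) before store128 overwrites that same slice; the switch from the streaming load128cs to load128 for the reused buffer is presumably because its cache lines are written right back. A tiny check built on the scalar sketch above (hypothetical, not code from the repository) makes the contract concrete: the array handed in as the upstream gradient comes back holding the input gradient.

#include <stdio.h>

int main(void) {
    float inp[4]  = {-2.0f, -0.5f, 0.5f, 2.0f};  // pre-GELU activations from the forward pass
    float grad[4] = { 1.0f,  1.0f, 1.0f, 1.0f};  // enters as dL/d(gelu output)
    gelu_backward_inplace_cpu(grad, inp, 4);      // same buffer leaves as dL/d(gelu input)
    for (int i = 0; i < 4; i++) printf("dinp[%d] = %f\n", i, grad[i]);
    return 0;
}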
