Merge pull request karpathy#509 from ngc92/inplace
Inplace gelu backward
karpathy authored Jun 1, 2024
2 parents bd0f036 + 761fb2c commit 9cf8c2f
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions train_gpt2.cu
@@ -942,12 +942,12 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {
     store128(out + idx, packed_out);
 }
 
-__global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout) {
+__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
 
     x128 packed_dinp;
     x128 packed_inp = load128cs(inp + idx);
-    x128 packed_dout = load128cs(dout + idx);
+    x128 packed_dout = load128(d_in_out + idx);
     for (int k = 0; k < packed_inp.size; ++k) {
         float x = (float)packed_inp[k];
         float cube = 0.044715f * x * x * x;
@@ -958,7 +958,7 @@ __global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout) {
         float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
         packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);
     }
-    store128(dinp + idx, packed_dinp);
+    store128(d_in_out + idx, packed_dinp);
 }
 
 template<typename OutFloat, bool UseAuxBuffer>
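
The lines folded between the two hunks (old lines 954-957) are unchanged; they presumably derive tanh_out and sech_out from the tanh approximation of GELU, as elsewhere in train_gpt2.cu. As a reading aid, here is a minimal scalar sketch of the per-element gradient the kernel computes; the helper name gelu_backward_ref and the host-side framing are invented for illustration, not part of the patch:

    #include <math.h>

    #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)

    // Scalar form of the gradient: d(GELU(x))/dx * dout, tanh approximation.
    float gelu_backward_ref(float x, float dout) {
        float cube = 0.044715f * x * x * x;
        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
        float tanh_out = tanhf(tanh_arg);
        float coshf_out = coshf(tanh_arg);
        float sech_out = 1.0f / (coshf_out * coshf_out);  // sech^2(z) = 1/cosh^2(z)
        float local_grad = 0.5f * (1.0f + tanh_out)
                         + x * 0.5f * sech_out * GELU_SCALING_FACTOR
                           * (1.0f + 3.0f * 0.044715f * x * x);
        return local_grad * dout;
    }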
@@ -1747,12 +1747,12 @@ void gelu_forward(floatX* out, const floatX* inp, int N) {
     cudaCheck(cudaGetLastError());
 }
 
-void gelu_backward(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
+void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N) {
     NVTX_RANGE_FN();
     const int block_size = 128;
     assert(N % block_size == 0);
     const int grid_size = CEIL_DIV(N, block_size * x128::size);
-    gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout);
+    gelu_backward_inplace_kernel<<<grid_size, block_size>>>(d_in_out, inp);
     cudaCheck(cudaGetLastError());
 }
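
Two details make the in-place version correct. First, each thread loads one 128-bit slice of d_in_out, works on it in registers, and stores the result back to the same slice; no thread reads another thread's output, so aliasing the gradient input and output buffers is race-free. Second, the gradient load switches from load128cs to load128: the streaming ("evict first") cache hint suited a buffer read once and discarded, but here the same cache lines are written back immediately. A usage sketch, with illustrative buffer names:

    // d_in_out holds dL/d(gelu output) on entry and dL/d(gelu input) on exit;
    // inp is the pre-GELU activation saved from the forward pass.
    gelu_backward_inplace(d_in_out, inp, N);  // N must be a multiple of 128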

@@ -2660,7 +2660,7 @@ void gpt2_backward(GPT2 *model, int* inputs) {
         gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
     }
     matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C);
-    gelu_backward(dl_bt4c, l_fch, dl_bt4c, B*T*4*C);
+    gelu_backward_inplace(dl_bt4c, l_fch, B*T*4*C);
     if(model->recompute >= 2) {
         // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
         layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
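
Note that this call site was already effectively in-place: the old gelu_backward(dl_bt4c, l_fch, dl_bt4c, ...) passed dl_bt4c as both dinp and dout. The new signature makes that aliasing explicit and drops the redundant argument. The dataflow, annotated (comments are explanatory, not from the patch):

    // matmul_backward leaves dL/d(gelu output) in dl_bt4c ...
    matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu,
                    l_fcprojw, scratchF, B, T, 4*C, C);
    // ... which gelu_backward_inplace then overwrites with dL/d(l_fch),
    // avoiding a separate gradient buffer.
    gelu_backward_inplace(dl_bt4c, l_fch, B*T*4*C);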
@@ -2885,8 +2885,8 @@ void gpt2_free(GPT2 *model) {
     cudaCheck(cudaFree(model->grads_acts_memory));
     cudaCheck(cudaFree(model->inputs));
     cudaCheck(cudaFree(model->targets));
-    cudaFreeHost(model->cpu_losses);
-    cudaFreeHost(model->cpu_losses_fp32);
+    cudaCheck(cudaFreeHost(model->cpu_losses));
+    cudaCheck(cudaFreeHost(model->cpu_losses_fp32));
     free(model->workload_indices);
     free(model->bucket_info);
 }
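
The final hunk is a small, unrelated cleanup: cudaFreeHost returns a cudaError_t like any other CUDA runtime call, so its result is now checked too. cudaCheck is defined elsewhere in train_gpt2.cu and is not shown in this diff; a typical fail-fast wrapper of this kind looks roughly like the following sketch (details may differ from the real definition):

    #include <cstdio>
    #include <cstdlib>
    #include <cuda_runtime.h>

    // Abort with file/line context if a CUDA runtime call failed.
    static void cuda_check_impl(cudaError_t err, const char* file, int line) {
        if (err != cudaSuccess) {
            fprintf(stderr, "[CUDA ERROR] %s:%d: %s\n", file, line, cudaGetErrorString(err));
            exit(EXIT_FAILURE);
        }
    }
    #define cudaCheck(err) cuda_check_impl((err), __FILE__, __LINE__)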