From cc4bed08067ef2d2413e3c3e28c54636d6e90f8e Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Sat, 1 Jun 2024 14:30:40 +0300
Subject: [PATCH 1/2] added missing checks

---
 train_gpt2.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index ff4d9c200..6919dbf4c 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -2885,8 +2885,8 @@ void gpt2_free(GPT2 *model) {
     cudaCheck(cudaFree(model->grads_acts_memory));
     cudaCheck(cudaFree(model->inputs));
     cudaCheck(cudaFree(model->targets));
-    cudaFreeHost(model->cpu_losses);
-    cudaFreeHost(model->cpu_losses_fp32);
+    cudaCheck(cudaFreeHost(model->cpu_losses));
+    cudaCheck(cudaFreeHost(model->cpu_losses_fp32));
     free(model->workload_indices);
     free(model->bucket_info);
 }

From 761fb2cf2b287b15f6810da4e92273463dc53851 Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Sat, 1 Jun 2024 14:31:26 +0300
Subject: [PATCH 2/2] rename gelu_backward to gelu_backward_inplace, and made
 it obvious in the code that this is an inplace operation

---
 train_gpt2.cu | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index 6919dbf4c..3c2149118 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -942,12 +942,12 @@ __global__ void gelu_forward_kernel2(floatX* out, const floatX* inp) {
     store128(out + idx, packed_out);
 }
 
-__global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout) {
+__global__ void gelu_backward_inplace_kernel(floatX* d_in_out, const floatX* inp) {
     int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
 
     x128 packed_dinp;
     x128 packed_inp = load128cs(inp + idx);
-    x128 packed_dout = load128cs(dout + idx);
+    x128 packed_dout = load128(d_in_out + idx);
     for (int k = 0; k < packed_inp.size; ++k) {
         float x = (float)packed_inp[k];
         float cube = 0.044715f * x * x * x;
@@ -958,7 +958,7 @@ __global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floa
         float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
         packed_dinp[k] = (floatX)(local_grad * (float)packed_dout[k]);
     }
-    store128(dinp + idx, packed_dinp);
+    store128(d_in_out + idx, packed_dinp);
 }
 
 template <typename OutFloat, bool UseAuxBuffer>
@@ -1747,12 +1747,12 @@ void gelu_forward(floatX* out, const floatX* inp, int N) {
     cudaCheck(cudaGetLastError());
 }
 
-void gelu_backward(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
+void gelu_backward_inplace(floatX* d_in_out, const floatX* inp, const int N) {
     NVTX_RANGE_FN();
     const int block_size = 128;
     assert(N % block_size == 0);
     const int grid_size = CEIL_DIV(N, block_size * x128::size);
-    gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout);
+    gelu_backward_inplace_kernel<<<grid_size, block_size>>>(d_in_out, inp);
     cudaCheck(cudaGetLastError());
 }
 
@@ -2660,7 +2660,7 @@ void gpt2_backward(GPT2 *model, int* inputs) {
             gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
         }
         matmul_backward(dl_bt4c, dl_fcprojw, dl_fcprojb, dresidual, l_fch_gelu, l_fcprojw, scratchF, B, T, 4*C, C);
-        gelu_backward(dl_bt4c, l_fch, dl_bt4c, B*T*4*C);
+        gelu_backward_inplace(dl_bt4c, l_fch, B*T*4*C);
         if(model->recompute >= 2) {
             // same as gelu above, l_ln1 and l_ln2 are just buffers if recompute >= 2, recompute them here on demand
             layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
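
For reference, the in-place contract that the second patch makes explicit is: on entry, d_in_out holds the upstream gradient dL/d(gelu(inp)); on exit, it holds dL/d(inp), computed from the tanh-approximation GELU derivative used in the kernel. Below is a minimal CPU sketch of that same math and contract; gelu_backward_inplace_cpu and its test driver are illustrative names for this note only, not code from train_gpt2.cu.

    #include <math.h>
    #include <stdio.h>

    // sqrt(2/pi), matching the tanh-approximation GELU used in the kernel
    #define GELU_SCALING_FACTOR 0.7978845608f

    // In-place backward: d_in_out is read as dL/dout and overwritten with dL/dinp.
    void gelu_backward_inplace_cpu(float* d_in_out, const float* inp, int N) {
        for (int i = 0; i < N; ++i) {
            float x = inp[i];
            float cube = 0.044715f * x * x * x;
            float tanh_out = tanhf(GELU_SCALING_FACTOR * (x + cube));
            float sech_out = 1.0f - tanh_out * tanh_out;   // sech^2 = 1 - tanh^2
            float local_grad = 0.5f * (1.0f + tanh_out)
                             + x * 0.5f * sech_out * GELU_SCALING_FACTOR
                               * (1.0f + 3.0f * 0.044715f * x * x);
            d_in_out[i] = local_grad * d_in_out[i];        // single buffer: read dout, write dinp
        }
    }

    int main(void) {
        float inp[4]      = {-2.0f, -0.5f, 0.5f, 2.0f};
        float d_in_out[4] = { 1.0f,  1.0f, 1.0f, 1.0f};    // upstream gradient of all ones
        gelu_backward_inplace_cpu(d_in_out, inp, 4);
        for (int i = 0; i < 4; ++i) {
            printf("dinp[%d] = %f\n", i, d_in_out[i]);
        }
        return 0;
    }

Because the upstream gradient buffer dl_bt4c is overwritten with the input gradient, the call site in gpt2_backward needs no separate dout argument, which is exactly what the rename documents.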