From b4a286a128de4657d67a629ed68be5db3a9aecdb Mon Sep 17 00:00:00 2001 From: Xiaodong Wang Date: Tue, 15 Oct 2024 18:58:49 -0700 Subject: [PATCH] [AMD] hipify torchaudio Differential Revision: D64184710 Pull Request resolved: https://github.com/pytorch/audio/pull/3840 --- src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh | 8 ++++++++ src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh | 16 ++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh index c13dd1ef71..f4ad3add2b 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernel_utils.cuh @@ -39,7 +39,11 @@ __global__ void ReduceMax2D( CAST_DTYPE shf; for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { +#ifndef USE_ROCM shf = __shfl_down_sync(0xFFFFFFFF, val, stride); +#else + shf = __shfl_down(val, stride); +#endif if (threadIdx.x < stride && threadIdx.x + stride < dim) { if (shf > val) { val = shf; @@ -81,7 +85,11 @@ __global__ void ReduceLogSumExpGivenMax2D( CAST_DTYPE shf; for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) { +#ifndef USE_ROCM shf = __shfl_down_sync(0xFFFFFFFF, val, stride); +#else + shf = __shfl_down(val, stride); +#endif if (threadIdx.x < stride && threadIdx.x + stride < dim) { val = val + shf; } diff --git a/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh b/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh index 4f2737891e..136e6844f2 100644 --- a/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh +++ b/src/libtorchaudio/rnnt/gpu/gpu_kernels.cuh @@ -126,7 +126,11 @@ __device__ void ComputeAlphas( #pragma unroll for (int i = 1; i < warpSize; i <<= 1) { +#ifndef USE_ROCM val = __shfl_up_sync(0xffffffff, skip_prob, i); +#else + val = __shfl_up(skip_prob, i); +#endif if (i <= threadIdx.x) { skip_prob = skip_prob + val; } @@ -150,7 +154,11 @@ __device__ void ComputeAlphas( CAST_DTYPE out = val; for (int i = 1; i < warpSize; ++i) { +#ifndef USE_ROCM val = __shfl_up_sync(0xffffffff, val, 1); +#else + val = __shfl_up(val, 1); +#endif if (i == threadIdx.x) { val = math::lse(val + skip_prob, emit); out = val; @@ -225,7 +233,11 @@ __device__ void ComputeBetasCosts( #pragma unroll for (int i = 1; i < warpSize; i <<= 1) { +#ifndef USE_ROCM val = __shfl_up_sync(0xffffffff, skip_prob, i); +#else + val = __shfl_up(skip_prob, i); +#endif if (i <= threadIdx.x) { skip_prob = skip_prob + val; } @@ -248,7 +260,11 @@ __device__ void ComputeBetasCosts( CAST_DTYPE out = val; for (int i = 1; i < warpSize; ++i) { +#ifndef USE_ROCM val = __shfl_up_sync(0xffffffff, val, 1); +#else + val = __shfl_up(val, 1); +#endif if (i == threadIdx.x) { val = math::lse(val + skip_prob, emit); out = val;