From 306a0ffb6e0cae27c5bd9a3b9cd378048c8e00e7 Mon Sep 17 00:00:00 2001 From: min-jean-cho Date: Mon, 24 Feb 2025 06:21:49 -0800 Subject: [PATCH] Update LayerNorm with Welford online algorithm (#1374) Welford online algorithm improves the numerical stability of variance computation compared to naive or two-pass algorithm. Simple test case: ```python import torch import torch.nn as nn B, D = 1, 5 x = torch.tensor([[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]], dtype=torch.float32).to("xpu") layernorm = nn.LayerNorm(D, elementwise_affine=False) y = layernorm(x) print("LayerNorm Output:\n", y) ``` Output now (welford): ```python tensor([[0., 0., 0., 0., 0.]], device='xpu:0') ``` Output before (two-pass): ```python tensor([[nan, nan, nan, nan, nan]], device='xpu:0') ``` --------- Co-authored-by: Yu, Guangye <106960996+guangyey@users.noreply.github.com> Co-authored-by: Yutao Xu --- src/ATen/native/xpu/LayerNorm.cpp | 11 +- src/ATen/native/xpu/sycl/LayerNormKernels.cpp | 596 +++++++++++++----- src/ATen/native/xpu/sycl/LayerNormKernels.h | 8 +- test/regressions/test_layer_norm.py | 20 + 4 files changed, 477 insertions(+), 158 deletions(-) create mode 100644 test/regressions/test_layer_norm.py diff --git a/src/ATen/native/xpu/LayerNorm.cpp b/src/ATen/native/xpu/LayerNorm.cpp index f9e032122..8ee463b84 100644 --- a/src/ATen/native/xpu/LayerNorm.cpp +++ b/src/ATen/native/xpu/LayerNorm.cpp @@ -49,18 +49,19 @@ ::std::tuple layer_norm_xpu( Tensor Y = at::native::empty_like( *X, - c10::nullopt /* dtype */, - c10::nullopt /* layout */, - c10::nullopt /* device */, - c10::nullopt /* pin_memory */, + std::nullopt /* dtype */, + std::nullopt /* layout */, + std::nullopt /* device */, + std::nullopt /* pin_memory */, LEGACY_CONTIGUOUS_MEMORY_FORMAT); + auto acc_type = at::toAccumulateType(input.scalar_type(), true); Tensor mean = at::empty({M}, X->options().dtype(acc_type)); Tensor rstd = at::empty({M}, X->options().dtype(acc_type)); native::xpu::layer_norm_kernel( - *X, *gamma, *beta, M, N, epsilon, Y, mean, rstd); + *X, *gamma, *beta, M, N, epsilon, &Y, &mean, &rstd); const auto input_shape = input.sizes(); const size_t axis = input.dim() - normalized_shape.size(); diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp index 171c54736..6982354f8 100644 --- a/src/ATen/native/xpu/sycl/LayerNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/LayerNormKernels.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -15,105 +16,6 @@ namespace at { namespace native { namespace xpu { -template -class LayerNormForward : public NormForward { - public: - using accscalar_t = acc_type_device; - typedef NormForward NF; - LayerNormForward() = delete; - LayerNormForward( - const scalar_t* X_data, - scalar_t* Y_data, - mean_t* mean_data, - mean_t* var_data, - const weight_t* gamma_data, - const weight_t* beta_data, - accscalar_t eps, - int64_t M, - int64_t N) - : NormForward( - X_data, - Y_data, - mean_data, - var_data, - gamma_data, - beta_data, - eps), - M(M), - N(N) { - numel = M * N; - }; - - template < - int vec_size, - typename index_t, - typename vec_t, - typename weight_vec_t, - typename nd_item_id> - void update( - nd_item_id item_id, - const NormConfig& cfg, - accscalar_t sum1 = 0, - accscalar_t sum2 = 0) const { - auto group_id = item_id.get_group(0); - auto group_id_foreach = item_id.get_group(1); - auto local_id = item_id.get_local_id(2); - - index_t group_offset = group_id * cfg.problem_size; - if (cfg.workgroup_num_foreach == 1) { - if 
(local_id == 0) { - NF::reduce_project(item_id, sum1, sum2, cfg); - } - item_id.barrier(sycl_global_fence); - } - - mean_t mean_val = NF::mean_data[group_id]; - mean_t var_val = NF::var_data[group_id]; - for (index_t j = local_id * vec_size; j < cfg.workgroup_work_size; - j += cfg.workgroup_size * vec_size) { - index_t plane_offset = group_id_foreach * cfg.workgroup_work_size + j; - if (plane_offset < (index_t)cfg.problem_size) { - vec_t X_val = *(reinterpret_cast( - NF::X_data + group_offset + plane_offset)); - weight_vec_t gamma_val, beta_val; - vec_t Y_val; - if (NF::gamma_data != nullptr) { - gamma_val = *(reinterpret_cast( - NF::gamma_data + plane_offset)); - } - if (NF::beta_data != nullptr) { - beta_val = *(reinterpret_cast( - NF::beta_data + plane_offset)); - } - - for (int v = 0; v < vec_size; ++v) { - if (NF::gamma_data != nullptr && NF::beta_data != nullptr) { - Y_val[v] = static_cast(gamma_val[v]) * - (var_val * static_cast(X_val[v] - mean_val)) + - static_cast(beta_val[v]); - } else if (NF::gamma_data != nullptr) { - Y_val[v] = static_cast(gamma_val[v]) * - (var_val * static_cast(X_val[v] - mean_val)); - } else if (NF::beta_data != nullptr) { - Y_val[v] = - (var_val * static_cast(X_val[v] - mean_val)) + - static_cast(beta_val[v]); - } else { - Y_val[v] = - (var_val * static_cast(X_val[v] - mean_val)); - } - } - *(reinterpret_cast(NF::Y_data + group_offset + plane_offset)) = - Y_val; - } - } - }; - - int64_t M; - int64_t N; - int64_t numel; -}; - template class LayerNormBackward : public NormBackward { public: @@ -268,45 +170,445 @@ class LayerNormBackward : public NormBackward { int64_t numel; }; -template +constexpr int vec_size = + 4; // we could make it dependent on dtype, but that would lead to different + // results between float and low-p types + +// Checks alignment of buffers for using vectorized loads / stores +template +bool can_vectorize(const T* ptr, int alignment) { + uint64_t addr = reinterpret_cast(ptr); + return addr % alignment == 0; +}; + +template +struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + using WelfordType = WelfordData; + using WelfordOp = WelfordOps>; + + [[intel::reqd_sub_group_size(SIMD)]] void operator()( + sycl::nd_item<1> item_id) const { + const int64_t i = item_id.get_group(0); + WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false}; + WelfordType val(0, 0, 0, 0); + for (int64_t j = item_id.get_local_id(0); j < N_; + j += item_id.get_local_range(0)) { + const int64_t index = i * N_ + j; + val = welford_op.reduce(val, static_cast(X_[index]), index); + } + + val = GroupReduceWithoutBroadcast( + item_id, val, welford_op, shared_); + + if (item_id.get_local_id(0) == 0) { + T_ACC m1; + T_ACC m2; + std::tie(m2, m1) = welford_op.project(val); + mean_[i] = m1; + rstd_[i] = c10::xpu::compat::rsqrt(m2 + eps_); + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + shared_ = sycl_local_acc_t(SIMD, cgh); + } + + RowwiseMomentsFunctor( + int64_t N, + T_ACC eps, + const T* X, + T_ACC* mean, + T_ACC* rstd) + : N_(N), eps_(eps), X_(X), mean_(mean), rstd_(rstd) {} + + private: + int64_t N_; + T_ACC eps_; + const T* X_; + T_ACC* mean_; + T_ACC* rstd_; + sycl_local_acc_t shared_; +}; + +template +void launch_rowwise_moments_kernel( + int64_t N, + int64_t M, + T_ACC eps, + const T* X_data, + T_ACC* mean_data, + T_ACC* rstd_data) { + RowwiseMomentsFunctor kfn(N, eps, X_data, mean_data, rstd_data); + + int64_t sg_size = SIMD; + int64_t wg_size = get_group_reduce_group_size(sg_size); + sycl::range<1> 
local_range{size_t(wg_size)}; + sycl::range<1> global_range{size_t(M * wg_size)}; + auto queue = getCurrentSYCLQueue(); + + sycl_kernel_submit(global_range, local_range, queue, kfn); +} + +template +struct LayerNormForwardKernelFunctor { + void operator()(sycl::nd_item<1> item_id) const { + const int64_t i = item_id.get_group(0); + for (int64_t j = item_id.get_local_id(0); j < N_; + j += item_id.get_local_range(0)) { + const int64_t index = i * N_ + j; + const T_ACC gamma_v = + gamma_ == nullptr ? T_ACC(1) : static_cast(gamma_[j]); + const T_ACC beta_v = + beta_ == nullptr ? T_ACC(0) : static_cast(beta_[j]); + Y_[index] = + (static_cast(X_[index]) - static_cast(mean_[i])) * + static_cast(rstd_[i]) * gamma_v + + beta_v; + } + } + LayerNormForwardKernelFunctor( + int64_t N, + const T* X, + const T_ACC* mean, + const T_ACC* rstd, + const T* gamma, + const T* beta, + T* Y) + : N_(N), + X_(X), + mean_(mean), + rstd_(rstd), + gamma_(gamma), + beta_(beta), + Y_(Y) {} + + private: + int64_t N_; + const T* X_; + const T_ACC* mean_; + const T_ACC* rstd_; + const T* gamma_; + const T* beta_; + T* Y_; +}; + +template +void launch_layer_norm_forward_kernel( + int64_t N, + int64_t M, + const T* X_data, + const T_ACC* mean_data, + const T_ACC* rstd_data, + const T* gamma_data, + const T* beta_data, + T* Y_data) { + LayerNormForwardKernelFunctor kfn( + N, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); + + int64_t sg_size = SIMD; + int64_t wg_size = get_group_reduce_group_size(sg_size); + sycl::range<1> local_range{size_t(wg_size)}; + sycl::range<1> global_range(M * size_t(wg_size)); + auto queue = getCurrentSYCLQueue(); + + sycl_kernel_submit(global_range, local_range, queue, kfn); +} + +struct WelfordDataLN { + float mean; + float sigma2; + float count; + WelfordDataLN() : mean(0.f), sigma2(0.f), count(0.f) {} + WelfordDataLN(float mean, float sigma2, float count) + : mean(mean), sigma2(sigma2), count(count) {} +}; + +template +WelfordDataLN WelfordOnlineSum(const U val, const WelfordDataLN& curr_sum) { + U delta = val - curr_sum.mean; + U new_count = curr_sum.count + 1.f; + U new_mean = curr_sum.mean + + delta * (1.f / new_count); // proper division is slow, this is less + // accurate but noticeably faster + return { + static_cast(new_mean), + static_cast(curr_sum.sigma2 + delta * (val - new_mean)), + static_cast(new_count)}; +} + +WelfordDataLN WelfordCombine( + const WelfordDataLN dataB, + const WelfordDataLN dataA) { + using U = decltype(dataB.count); + U delta = dataB.mean - dataA.mean; + U count = dataA.count + dataB.count; + U mean, sigma2; + if (count > decltype(dataB.count){0}) { + auto coef = 1.f / count; // NB we don't use --use_fast_math, but this is + // emulation, 1./count goes to intrinsic, `* coef` + // is multiplication, instead of slow fp division + auto nA = dataA.count * coef; + auto nB = dataB.count * coef; + mean = nA * dataA.mean + nB * dataB.mean; + sigma2 = dataA.sigma2 + dataB.sigma2 + delta * delta * dataA.count * nB; + } else { + mean = U(0); + sigma2 = U(0); + } + return {mean, sigma2, count}; +} + +template +WelfordDataLN compute_stats( + const T* RESTRICT X, + const int N, + T_ACC& buf, + sycl::nd_item<2>& item_id) { + // X points to the row to read + using vec_t = aligned_vector; + using acc_t = acc_type_device; + const vec_t* X_vec = reinterpret_cast(X); + const int numx = item_id.get_local_range(1) * item_id.get_local_range(0); + const int thrx = item_id.get_local_linear_id(); + const int n_vec_to_read = N / vec_size; + WelfordDataLN wd(0.f, 0.f, 0.f); + 
// no tail, we check that N is multiple of vec_size + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + wd = WelfordOnlineSum(static_cast(data.val[ii]), wd); + } + } + // intra-warp reduction + auto sg = item_id.get_sub_group(); + for (int offset = (SIMD >> 1); offset > 0; offset >>= 1) { + WelfordDataLN wdB{ + sycl::shift_group_left(sg, wd.mean, offset), + sycl::shift_group_left(sg, wd.sigma2, offset), + sycl::shift_group_left(sg, wd.count, offset)}; + wd = WelfordCombine(wd, wdB); + } + + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (item_id.get_local_range(0) > 1) { + auto addr_offset = item_id.get_local_range(0); + for (int offset = item_id.get_local_range(0) / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) >= offset && + item_id.get_local_id(0) < 2 * offset) { + const int wrt_y = item_id.get_local_id(0) - offset; + buf[2 * wrt_y] = wd.mean; + buf[2 * wrt_y + 1] = wd.sigma2; + buf[wrt_y + addr_offset] = wd.count; + } + item_id.barrier(sycl_local_fence); + + // lower half merges + if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) < offset) { + const int rd_y = item_id.get_local_id(0); + WelfordDataLN wdB{ + static_cast(buf[2 * rd_y]), + static_cast(buf[2 * rd_y + 1]), + static_cast(buf[rd_y + addr_offset])}; + wd = WelfordCombine(wd, wdB); + } + item_id.barrier(sycl_local_fence); + } + + if (item_id.get_local_id(1) == 0 && item_id.get_local_id(0) == 0) { + buf[0] = wd.mean; + buf[1] = wd.sigma2 / float(N); + } + item_id.barrier(sycl_local_fence); + return WelfordDataLN{ + static_cast(buf[0]), static_cast(buf[1]), 0.f}; + } else { + return WelfordDataLN{ + sycl::select_from_group(sg, wd.mean, 0), + sycl::select_from_group(sg, wd.sigma2, 0) / float(N), + 0.f}; + } +} + +template +struct VectorizedLayerNormKernelFunctor + : public __SYCL_KER_CONFIG_CONVENTION__ { + [[intel::reqd_sub_group_size(SIMD)]] void operator()( + sycl::nd_item<2> item_id) const { + auto i1 = item_id.get_group(1); + const T* block_row = X_ + i1 * N_; + WelfordDataLN wd = compute_stats(block_row, N_, buf_, item_id); + + using vec_t = aligned_vector; + const vec_t* X_vec = reinterpret_cast(block_row); + const vec_t* gamma_vec = + (gamma_ != nullptr) ? reinterpret_cast(gamma_) : nullptr; + const vec_t* beta_vec = + (beta_ != nullptr) ? 
reinterpret_cast(beta_) : nullptr; + vec_t* Y_vec = reinterpret_cast(Y_ + i1 * N_); + + const int numx = item_id.get_local_range(1) * item_id.get_local_range(0); + const int thrx = item_id.get_local_linear_id(); + const int n_vec_to_read = N_ / vec_size; + + T_ACC rstd_val = c10::xpu::compat::rsqrt(wd.sigma2 + eps_); + + // No tail, N is guaranteed to be multiple of vec size + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; + vec_t out; + + // Computation is performed in T_ACC, X is cast to T_ACC and result is + // implicitly cast to T + if (gamma_vec != nullptr && beta_vec != nullptr) { +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * + (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); + } + } else if (gamma_vec != nullptr) { +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + out.val[ii] = static_cast(gamma_vec[i].val[ii]) * + (rstd_val * (static_cast(data.val[ii]) - wd.mean)); + } + } else if (beta_vec != nullptr) { +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + out.val[ii] = + (rstd_val * (static_cast(data.val[ii]) - wd.mean)) + + static_cast(beta_vec[i].val[ii]); + } + } else { +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + out.val[ii] = rstd_val * (static_cast(data.val[ii]) - wd.mean); + } + } + Y_vec[i] = out; + } + if (thrx == 0) { + mean_[i1] = wd.mean; + rstd_[i1] = rstd_val; + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + buf_ = sycl_local_acc_t((wg_size_ / SIMD) * 2, cgh); + } + + VectorizedLayerNormKernelFunctor( + const int N, + T_ACC eps, + const T* RESTRICT X, + const T* gamma, + const T* beta, + T_ACC* mean, + T_ACC* rstd, + T* Y, + int64_t wg_size) + : N_(N), + eps_(eps), + X_(X), + gamma_(gamma), + beta_(beta), + mean_(mean), + rstd_(rstd), + Y_(Y), + wg_size_(wg_size) {} + + private: + const int N_; + T_ACC eps_; + const T* RESTRICT X_; + const T* gamma_; + const T* beta_; + T_ACC* mean_; + T_ACC* rstd_; + T* Y_; + int64_t sg_size_; + int64_t wg_size_; + sycl_local_acc_t buf_; +}; + +template +void launch_vectorized_layer_norm_kernel( + int N, + int64_t M, + T_ACC eps, + const T* X_data, + const T* gamma_data, + const T* beta_data, + T* Y_data, + T_ACC* mean_data, + T_ACC* rstd_data) { + using KernelClass = VectorizedLayerNormKernelFunctor; + int64_t wg_size = syclMaxWorkGroupSize(); + KernelClass kfn( + N, + eps, + X_data, + gamma_data, + beta_data, + mean_data, + rstd_data, + Y_data, + wg_size); + sycl::range<2> local_range{size_t(wg_size / SIMD), SIMD}; + sycl::range<2> global_range(size_t(wg_size / SIMD), M * SIMD); + auto queue = getCurrentSYCLQueue(); + sycl_kernel_submit(global_range, local_range, queue, kfn); +} + +template void _layer_norm_kernel( const Tensor& X, const Tensor& gamma, const Tensor& beta, int64_t M, int64_t N, - acc_type_device eps, - Tensor& Y, - Tensor& mean, - Tensor& rstd) { - TORCH_CHECK(X.numel() == M * N); - TORCH_CHECK(!gamma.defined() || gamma.numel() == N); - TORCH_CHECK(!beta.defined() || beta.numel() == N); - - const scalar_t* X_data = X.const_data_ptr(); - scalar_t* Y_data = Y.data_ptr(); - mean_t* mean_data = mean.data_ptr(); - mean_t* var_data = rstd.data_ptr(); - const weight_t* gamma_data = - gamma.defined() ? gamma.const_data_ptr() : nullptr; - const weight_t* beta_data = - beta.defined() ? 
beta.const_data_ptr() : nullptr; - - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); - bool can_use_32bit_index = canUse32BitIndexMath(X); - LayerNormForward norm( - X_data, Y_data, mean_data, var_data, gamma_data, beta_data, eps, M, N); - - if (config.workgroup_num_foreach == 1) { - vectorized_fused_norm_kernel( - norm, config, can_use_32bit_index); + T_ACC eps, + Tensor* Y, + Tensor* mean, + Tensor* rstd) { + const T* X_data = X.const_data_ptr(); + const T* gamma_data = gamma.defined() ? gamma.const_data_ptr() : nullptr; + const T* beta_data = beta.defined() ? beta.const_data_ptr() : nullptr; + T* Y_data = Y->data_ptr(); + T_ACC* mean_data = mean->data_ptr(); + T_ACC* rstd_data = rstd->data_ptr(); + + constexpr int num_vec_elems = vec_size; + constexpr int alignment = num_vec_elems * sizeof(T); + bool can_vec_X = can_vectorize(X_data, alignment); + bool can_vec_Y = can_vectorize(Y_data, alignment); + bool can_vec_gamma = + gamma.defined() ? can_vectorize(gamma_data, alignment) : true; + bool can_vec_beta = + beta.defined() ? can_vectorize(beta_data, alignment) : true; + + if ((std::is_same_v || std::is_same_v || + std::is_same_v)&&N <= + static_cast(1ULL << std::numeric_limits::digits) && + N % num_vec_elems == 0 && can_vec_X && can_vec_Y && can_vec_gamma && + can_vec_beta) { + launch_vectorized_layer_norm_kernel( + static_cast(N), + M, + eps, + X_data, + gamma_data, + beta_data, + Y_data, + mean_data, + rstd_data); } else { - Tensor semaphores, scratchpad; - config.template init_global_reduce(X, semaphores, scratchpad); - rowwise_moments_kernel( - norm, config, can_use_32bit_index); - norm_update_kernel( - norm, config, can_use_32bit_index); + launch_rowwise_moments_kernel(N, M, eps, X_data, mean_data, rstd_data); + launch_layer_norm_forward_kernel( + N, M, X_data, mean_data, rstd_data, gamma_data, beta_data, Y_data); } } @@ -596,30 +898,26 @@ void _layer_norm_backward_kernel( dY, X, mean_data, var_data, dgamma, dbeta, config_w); } -std::tuple layer_norm_kernel( +void layer_norm_kernel( const Tensor& X, const Tensor& gamma, const Tensor& beta, int64_t M, int64_t N, double eps, - Tensor& Y, - Tensor& mean, - Tensor& rstd) { - if (M > 0) { - AT_DISPATCH_FLOATING_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - X.scalar_type(), - "layer_norm_xpu", - [&]() { - using acc_t = acc_type_device; - _layer_norm_kernel( - X, gamma, beta, M, N, static_cast(eps), Y, mean, rstd); - }); - } - - return std::make_tuple(Y, mean, rstd); + Tensor* Y, + Tensor* mean, + Tensor* rstd) { + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + X.scalar_type(), + "layer_norm_xpu", + [&]() { + using acc_t = acc_type_device; + _layer_norm_kernel( + X, gamma, beta, M, N, static_cast(eps), Y, mean, rstd); + }); } std::tuple layer_norm_backward_kernel( diff --git a/src/ATen/native/xpu/sycl/LayerNormKernels.h b/src/ATen/native/xpu/sycl/LayerNormKernels.h index 0c57a61ba..1f6f2f974 100644 --- a/src/ATen/native/xpu/sycl/LayerNormKernels.h +++ b/src/ATen/native/xpu/sycl/LayerNormKernels.h @@ -6,16 +6,16 @@ namespace at { namespace native { namespace xpu { -TORCH_XPU_API std::tuple layer_norm_kernel( +TORCH_XPU_API void layer_norm_kernel( const Tensor& X, const Tensor& gamma, const Tensor& beta, int64_t M, int64_t N, double eps, - Tensor& Y, - Tensor& mean, - Tensor& rstd); + Tensor* Y, + Tensor* mean, + Tensor* rstd); TORCH_XPU_API std::tuple layer_norm_backward_kernel( const Tensor& dY, diff --git a/test/regressions/test_layer_norm.py 
b/test/regressions/test_layer_norm.py
new file mode 100644
index 000000000..254ccb08c
--- /dev/null
+++ b/test/regressions/test_layer_norm.py
@@ -0,0 +1,20 @@
+# Owner(s): ["module: intel"]
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import TestCase
+
+cpu_device = torch.device("cpu")
+xpu_device = torch.device("xpu")
+
+
+class TestLayerNorm(TestCase):
+    def test_layer_norm_no_nan(self, dtype=torch.float):
+        dim = [5]
+        x_cpu = torch.tensor([[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]])
+        layernorm_cpu = nn.LayerNorm(dim)
+        y_cpu = layernorm_cpu(x_cpu)
+
+        x_xpu = x_cpu.to(xpu_device)
+        layernorm_xpu = nn.LayerNorm(dim).to(xpu_device)
+        y_xpu = layernorm_xpu(x_xpu)
+        self.assertEqual(y_cpu, y_xpu.to(cpu_device))
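
Note (illustrative, not part of the patch): the stability claim in the commit message can be reproduced on the host with a minimal standalone C++ sketch. The struct and function names below are chosen for illustration and only mirror the `WelfordDataLN` / `WelfordOnlineSum` recurrence added above; the naive sum-of-squares formula is shown purely as a contrast (the previous kernel used a different two-pass scheme).

```cpp
// Standalone host-side sketch of Welford's online variance. Illustrative only;
// the kernel performs the same recurrence per work-item in T_ACC precision.
#include <cmath>
#include <cstdio>

struct Welford {
  float mean = 0.f;
  float m2 = 0.f;    // running sum of squared deviations
  float count = 0.f;
};

// One online update: shift the mean, then accumulate delta * (val - new_mean),
// which is always non-negative, so m2 can never go below zero.
static Welford welford_step(float val, Welford w) {
  const float count = w.count + 1.f;
  const float delta = val - w.mean;
  const float mean = w.mean + delta / count;
  return {mean, w.m2 + delta * (val - mean), count};
}

int main() {
  // Same row as the regression test: all five values round to the same float32.
  const float x[5] = {1e15f, 1e15f + 1.f, 1e15f + 2.f, 1e15f + 3.f, 1e15f + 4.f};
  const float eps = 1e-5f;

  // Naive sum-of-squares formula, shown only as a contrast: the two large
  // terms cancel catastrophically, so the "variance" can land on the order of
  // 1e23 away from 0 and may even come out negative.
  float s = 0.f, ss = 0.f;
  for (float v : x) { s += v; ss += v * v; }
  const float naive_var = ss / 5.f - (s / 5.f) * (s / 5.f);

  // Welford single pass: for five identical floats the variance is exactly 0.
  Welford w;
  for (float v : x) w = welford_step(v, w);
  const float welford_var = w.m2 / w.count;

  std::printf("naive   variance: %g\n", naive_var);
  std::printf("welford variance: %g\n", welford_var);
  // LayerNorm output with the Welford statistics: (x - mean) * rsqrt(var + eps) == 0.
  std::printf("normalized value: %g\n", (x[0] - w.mean) / std::sqrt(welford_var + eps));
}
```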
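
The sub-group and shared-memory reductions in `compute_stats` merge per-lane partial statistics rather than raw sums. A standalone sketch of that pairwise merge, again with hypothetical names that mirror `WelfordCombine` above rather than the kernel code itself:

```cpp
// Sketch of merging two partial Welford states, as the kernel does when it
// reduces per-lane statistics across a sub-group / work-group. Illustrative only.
#include <cstdio>

struct Welford {
  float mean = 0.f;
  float m2 = 0.f;
  float count = 0.f;
};

static Welford welford_step(float val, Welford w) {
  const float count = w.count + 1.f;
  const float delta = val - w.mean;
  const float mean = w.mean + delta / count;
  return {mean, w.m2 + delta * (val - mean), count};
}

// Chan et al. pairwise merge: weighted mean plus a correction term that
// accounts for the distance between the two partial means.
static Welford welford_combine(Welford a, Welford b) {
  const float count = a.count + b.count;
  if (count == 0.f) return {0.f, 0.f, 0.f};
  const float delta = b.mean - a.mean;
  const float nb = b.count / count;
  return {a.mean + delta * nb,
          a.m2 + b.m2 + delta * delta * a.count * nb,
          count};
}

int main() {
  const float x[8] = {2.f, 4.f, 4.f, 4.f, 5.f, 5.f, 7.f, 9.f};

  // Two "lanes" each accumulate half of the row, then their states are merged.
  Welford lo, hi;
  for (int i = 0; i < 4; ++i) lo = welford_step(x[i], lo);
  for (int i = 4; i < 8; ++i) hi = welford_step(x[i], hi);
  const Welford merged = welford_combine(lo, hi);

  // Matches the sequential single-pass result (mean 5, population variance 4,
  // up to float rounding).
  std::printf("mean = %g, var = %g\n", merged.mean, merged.m2 / merged.count);
}
```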