Update LayerNorm with Welford online algorithm (#1374)
Welford's online algorithm improves the numerical stability of the variance
computation compared to the naive or two-pass algorithms.
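For reference, a minimal pure-Python sketch of the Welford update and of the pairwise merge that a parallel reduction would need. This is not the kernel code from this commit: the function names and the `(count, mean, m2)` state tuple are assumptions made for illustration only.

```python
def welford_update(state, x):
    """Fold one value into a running (count, mean, m2) state."""
    count, mean, m2 = state
    count += 1
    delta = x - mean
    mean += delta / count
    # m2 accumulates squared deviations using the old and the new mean.
    m2 += delta * (x - mean)
    return count, mean, m2


def welford_combine(a, b):
    """Merge two partial states, e.g. from two chunks of the same row."""
    count_a, mean_a, m2_a = a
    count_b, mean_b, m2_b = b
    count = count_a + count_b
    if count == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * count_b / count
    m2 = m2_a + m2_b + delta * delta * count_a * count_b / count
    return count, mean, m2


def welford_reduce(values):
    """Single pass over values, returning the final (count, mean, m2)."""
    state = (0, 0.0, 0.0)
    for x in values:
        state = welford_update(state, x)
    return state


def variance(state):
    """Population variance, the form LayerNorm uses (divide by count)."""
    count, _, m2 = state
    return m2 / count


data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]
full = welford_reduce(data)
merged = welford_combine(welford_reduce(data[:2]), welford_reduce(data[2:]))
print(full, merged, variance(full))
```

The update only ever works with deviations from the running mean, so it avoids subtracting two large, nearly equal quantities, which is where the naive `E[x^2] - E[x]^2` formula loses precision.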

Simple test case:
```python
import torch
import torch.nn as nn

B, D = 1, 5
x = torch.tensor([[1e15, 1e15 + 1, 1e15 + 2, 1e15 + 3, 1e15 + 4]], dtype=torch.float32).to("xpu")
layernorm = nn.LayerNorm(D, elementwise_affine=False)
y = layernorm(x)
print("LayerNorm Output:\n", y)
```

Output now (Welford):
```python
tensor([[0., 0., 0., 0., 0.]], device='xpu:0')
```

Output before (two-pass):
```python
tensor([[nan, nan, nan, nan, nan]], device='xpu:0')
```
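A quick check on why an all-zero result is the plausible one for this input: in float32 the spacing between adjacent values near 1e15 is far larger than 4, so all five inputs round to the same number and the row is effectively constant. The sketch below uses only the standard library to emulate float32 rounding; the helper name `to_float32` is an assumption for illustration.

```python
import struct


def to_float32(value):
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", value))[0]


inputs = [1e15 + i for i in range(5)]
rounded = [to_float32(v) for v in inputs]
distinct = set(rounded)

# A constant row has zero variance, so (x - mean) is zero for every element
# and the normalized output is zero once eps is added under the square root.
mean = sum(rounded) / len(rounded)
centered = [v - mean for v in rounded]
print(len(distinct), centered)
```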

---------

Co-authored-by: Yu, Guangye <[email protected]>
Co-authored-by: Yutao Xu <[email protected]>
3 people authored Feb 24, 2025
1 parent 6acd38d commit 306a0ff
Showing 4 changed files with 477 additions and 158 deletions.
11 changes: 6 additions & 5 deletions src/ATen/native/xpu/LayerNorm.cpp
@@ -49,18 +49,19 @@ ::std::tuple<at::Tensor, at::Tensor, at::Tensor> layer_norm_xpu(

Tensor Y = at::native::empty_like(
*X,
-c10::nullopt /* dtype */,
-c10::nullopt /* layout */,
-c10::nullopt /* device */,
-c10::nullopt /* pin_memory */,
+std::nullopt /* dtype */,
+std::nullopt /* layout */,
+std::nullopt /* device */,
+std::nullopt /* pin_memory */,
LEGACY_CONTIGUOUS_MEMORY_FORMAT);

auto acc_type = at::toAccumulateType(input.scalar_type(), true);

Tensor mean = at::empty({M}, X->options().dtype(acc_type));
Tensor rstd = at::empty({M}, X->options().dtype(acc_type));

native::xpu::layer_norm_kernel(
-*X, *gamma, *beta, M, N, epsilon, Y, mean, rstd);
+*X, *gamma, *beta, M, N, epsilon, &Y, &mean, &rstd);

const auto input_shape = input.sizes();
const size_t axis = input.dim() - normalized_shape.size();