diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py
index 2306749fcb..5928a16343 100644
--- a/megatron/model/__init__.py
+++ b/megatron/model/__init__.py
@@ -6,7 +6,7 @@
     from apex.normalization import MixedFusedRMSNorm as RMSNorm
 else:
     from .rmsnorm import RMSNorm
-    from torch.nn import LayerNorm
+    from .layer_norm_p1 import LayerNorm1P as LayerNorm
 
 from .distributed import DistributedDataParallel
 from .bert_model import BertModel
diff --git a/megatron/model/layer_norm_p1.py b/megatron/model/layer_norm_p1.py
new file mode 100644
index 0000000000..edaac6675a
--- /dev/null
+++ b/megatron/model/layer_norm_p1.py
@@ -0,0 +1,38 @@
+import math
+import numbers
+
+import torch
+import torch.nn as nn
+from torch.nn.parameter import Parameter
+from torch.nn import init
+
+
+class LayerNorm1P(torch.nn.Module):
+    def __init__(self, normalized_shape, eps=1e-5, apply_layernorm_1p=False):
+        super(LayerNorm1P, self).__init__()
+        self.eps = eps
+        self.apply_layernorm_1p = apply_layernorm_1p
+
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        self.normalized_shape = torch.Size(normalized_shape)
+        self.weight = Parameter(torch.Tensor(*normalized_shape))
+        self.bias = Parameter(torch.Tensor(*normalized_shape))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+
+        if self.apply_layernorm_1p:
+            init.zeros_(self.weight)
+            init.zeros_(self.bias)
+        else:
+            init.ones_(self.weight)
+            init.zeros_(self.bias)
+
+    def forward(self, input):
+        if self.apply_layernorm_1p:
+            weight_plus_1 = (self.weight + 1)
+            output = torch.nn.functional.layer_norm(input, self.normalized_shape, weight_plus_1, self.bias, self.eps)
+            return output
+        else:
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index e75f13a24f..1672dbe7e0 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -913,7 +913,8 @@ def __init__(self, config,
             else:
                 self.input_layernorm = LayerNorm(
                     config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                    eps=config.layernorm_epsilon,
+                    apply_layernorm_1p=args.apply_layernorm_1p)
         else:
             self.input_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Self attention.
@@ -939,7 +940,8 @@ def __init__(self, config,
             else:
                 self.post_attention_layernorm = LayerNorm(
                     config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                    eps=config.layernorm_epsilon,
+                    apply_layernorm_1p=args.apply_layernorm_1p)
         else:
             self.post_attention_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
         # Cross attention.
@@ -1762,7 +1764,8 @@ def build_layer(layer_number, n_e):
             else:
                 self.final_layernorm = LayerNorm(
                     config.hidden_size,
-                    eps=config.layernorm_epsilon)
+                    eps=config.layernorm_epsilon,
+                    apply_layernorm_1p=args.apply_layernorm_1p)
         else:
             self.final_layernorm = RMSNorm(config.hidden_size, config.layernorm_epsilon)
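
A hedged sketch of what the 1p parameterization does, for context (not part of the diff): with `apply_layernorm_1p=True` the gain is stored as a zero-initialized offset and the forward pass scales by `weight + 1`, so at initialization the module should behave exactly like a stock `torch.nn.LayerNorm` with its default gain of 1. The snippet below assumes the new `megatron/model/layer_norm_p1.py` from this diff is importable; the shapes are illustrative only.

```python
# Minimal sanity check (assumes megatron.model.layer_norm_p1 from this diff
# is on PYTHONPATH). With apply_layernorm_1p=True the stored weight is 0,
# so the effective gain (weight + 1) matches torch.nn.LayerNorm's default
# init, and the outputs should be identical.
import torch
from megatron.model.layer_norm_p1 import LayerNorm1P

hidden_size = 16
x = torch.randn(2, 4, hidden_size)

ln_1p = LayerNorm1P(hidden_size, eps=1e-5, apply_layernorm_1p=True)
ln_ref = torch.nn.LayerNorm(hidden_size, eps=1e-5)

torch.testing.assert_close(ln_1p(x), ln_ref(x))
```

The point of the +1 offset is that zero-centered initialization and weight decay act on the deviation of the gain from 1 rather than on the gain itself, which is what the `args.apply_layernorm_1p` flag threaded through `transformer.py` above toggles.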