# coding: utf-8
"""
Implements custom initialization
"""
import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.init import _calculate_fan_in_and_fan_out


def xavier_uniform_n_(w: Tensor, gain: float = 1., n: int = 4) -> None:
    """
    Xavier initializer for parameters that combine multiple matrices in one
    parameter for efficiency. This is used e.g. for GRU and LSTM parameters,
    where all gates are computed at the same time by one big matrix.

    :param w: parameter
    :param gain: default 1
    :param n: default 4
    """
    with torch.no_grad():
        fan_in, fan_out = _calculate_fan_in_and_fan_out(w)
        assert fan_out % n == 0, "fan_out should be divisible by n"
        fan_out //= n
        std = gain * math.sqrt(2.0 / (fan_in + fan_out))
        a = math.sqrt(3.0) * std
        nn.init.uniform_(w, -a, a)
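
# Illustrative note (added, not part of the original module): nn.LSTM stacks
# 4 gate matrices in `weight_ih_l0`, so fan_out is 4 * hidden_size and n=4
# recovers the per-gate fan_out; nn.GRU stacks 3 gates, so n=3 would be used.
# A hypothetical call:
#     lstm = nn.LSTM(input_size=16, hidden_size=32)
#     xavier_uniform_n_(lstm.weight_ih_l0.data, gain=1.0, n=4)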


# pylint: disable=too-many-branches
def initialize_model(model: nn.Module, cfg: dict, src_padding_idx: int,
                     trg_padding_idx: int) -> None:
    """
    This initializes a model based on the provided config.

    All initializer configuration is part of the `model` section of the
    configuration file.
    For an example, see e.g. `https://github.com/joeynmt/joeynmt/
    blob/master/configs/iwslt_envi_xnmt.yaml#L47`

    The main initializer is set using the `initializer` key.
    Possible values are `xavier`, `uniform`, `normal` or `zeros`.
    (`xavier` is the default).

    When an initializer is set to `uniform`, then `init_weight` sets the
    range for the values (-init_weight, init_weight).

    When an initializer is set to `normal`, then `init_weight` sets the
    standard deviation for the weights (with mean 0).

    The word embedding initializer is set using `embed_initializer` and takes
    the same values. The default is `normal` with `embed_init_weight = 0.01`.

    Biases are initialized separately using `bias_initializer`.
    The default is `zeros`, but you can use the same initializers as
    the main initializer.

    :param model: model to initialize
    :param cfg: the model configuration
    :param src_padding_idx: index of source padding token
    :param trg_padding_idx: index of target padding token
    """
    # defaults: xavier, embeddings: normal 0.01, biases: zeros, no orthogonal
    gain = float(cfg.get("init_gain", 1.0))  # for xavier
    init = cfg.get("initializer", "xavier")
    init_weight = float(cfg.get("init_weight", 0.01))
    embed_init = cfg.get("embed_initializer", "normal")
    embed_init_weight = float(cfg.get("embed_init_weight", 0.01))
    embed_gain = float(cfg.get("embed_init_gain", 1.0))  # for xavier
    bias_init = cfg.get("bias_initializer", "zeros")
    bias_init_weight = float(cfg.get("bias_init_weight", 0.01))

    # pylint: disable=unnecessary-lambda, no-else-return
    def _parse_init(s, scale, _gain):
        scale = float(scale)
        assert scale > 0., "incorrect init_weight"
        if s.lower() == "xavier":
            return lambda p: nn.init.xavier_uniform_(p, gain=_gain)
        elif s.lower() == "uniform":
            return lambda p: nn.init.uniform_(p, a=-scale, b=scale)
        elif s.lower() == "normal":
            return lambda p: nn.init.normal_(p, mean=0., std=scale)
        elif s.lower() == "zeros":
            return lambda p: nn.init.zeros_(p)
        else:
            raise ValueError("unknown initializer")
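
    # Illustrative (added): _parse_init("uniform", 0.1, 1.0) returns a function
    # that fills a tensor in-place from U(-0.1, 0.1); "zeros" ignores the scale.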
    init_fn_ = _parse_init(init, init_weight, gain)
    embed_init_fn_ = _parse_init(embed_init, embed_init_weight, embed_gain)
    bias_init_fn_ = _parse_init(bias_init, bias_init_weight, gain)

    with torch.no_grad():
        for name, p in model.named_parameters():

            if "embed" in name:
                if "bias" in name:
                    bias_init_fn_(p)
                else:
                    embed_init_fn_(p)

            elif "bias" in name:
                bias_init_fn_(p)

            elif len(p.size()) > 1:

                # RNNs combine multiple matrices in one, which messes up
                # xavier initialization
                if init == "xavier" and "rnn" in name:
                    n = 1
                    if "encoder" in name:
                        n = 4 if isinstance(model.encoder.rnn, nn.LSTM) else 3
                    elif "decoder" in name:
                        n = 4 if isinstance(model.decoder.rnn, nn.LSTM) else 3
                    xavier_uniform_n_(p.data, gain=gain, n=n)
                else:
                    init_fn_(p)

        # zero out paddings
        model.src_embed.lut.weight.data[src_padding_idx].zero_()
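

# A minimal usage sketch (added for illustration, not part of the original
# module). The toy model below only exposes what initialize_model touches
# here: named parameters plus `src_embed.lut`; all class names, sizes and
# config values are assumptions, and no parameter name contains "rnn", so
# the RNN-specific branch is skipped.
if __name__ == "__main__":
    class _ToyEmbeddings(nn.Module):
        def __init__(self):
            super().__init__()
            # lookup table whose padding row gets zeroed by initialize_model
            self.lut = nn.Embedding(num_embeddings=10, embedding_dim=8)

    class _ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.src_embed = _ToyEmbeddings()
            self.output_layer = nn.Linear(8, 8)

    toy_cfg = {
        "initializer": "xavier",        # weight matrices: Xavier uniform
        "init_gain": 1.0,
        "embed_initializer": "normal",  # embeddings: N(0, 0.01)
        "embed_init_weight": 0.01,
        "bias_initializer": "zeros",    # biases: zeros
    }
    toy_model = _ToyModel()
    initialize_model(toy_model, toy_cfg, src_padding_idx=0, trg_padding_idx=0)
    # the padding row of the source embedding is now all zeros
    print(toy_model.src_embed.lut.weight[0])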