Skip to content

Commit

Permalink
fix - update bert model weight init
Browse files Browse the repository at this point in the history
  • Loading branch information
ne7ermore committed Aug 8, 2019
1 parent 254c133 commit 17caad1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions BERT/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ def reshape(x):
return self.lm(outputs + residual)

def reset_parameters(self):
- self.w_qs.data.normal_(INIT_RANGE)
- self.w_ks.data.normal_(INIT_RANGE)
- self.w_vs.data.normal_(INIT_RANGE)
- self.w_o.weight.data.normal_(INIT_RANGE)
+ self.w_qs.data.normal_(std=INIT_RANGE)
+ self.w_ks.data.normal_(std=INIT_RANGE)
+ self.w_vs.data.normal_(std=INIT_RANGE)
+ self.w_o.weight.data.normal_(std=INIT_RANGE)


class EncoderLayer(nn.Module):
Expand All @@ -166,7 +166,7 @@ def __init__(self, d_model):
super().__init__()

self.linear = nn.Linear(d_model, d_model)
- self.linear.weight.data.normal_(INIT_RANGE)
+ self.linear.weight.data.normal_(std=INIT_RANGE)
self.linear.bias.data.zero_()

def forward(self, x):
Expand Down Expand Up @@ -199,10 +199,10 @@ def __init__(self, args):
self.gelu = GELU()

def reset_parameters(self):
- self.enc_ebd.weight.data.normal_(INIT_RANGE)
- self.seg_ebd.weight.data.normal_(INIT_RANGE)
+ self.enc_ebd.weight.data.normal_(std=INIT_RANGE)
+ self.seg_ebd.weight.data.normal_(std=INIT_RANGE)

- self.transform.weight.data.normal_(INIT_RANGE)
+ self.transform.weight.data.normal_(std=INIT_RANGE)
self.transform.bias.data.zero_()

def forward(self, inp, pos, segment_label):
Expand Down

0 comments on commit 17caad1

Please sign in to comment.