Skip to content

Commit

Permalink
Merge pull request #38 from sahilsuneja1/ss_wt_init_pr
Browse files Browse the repository at this point in the history
Speculator weight initialization fix
  • Loading branch information
JRosenkranz authored Sep 9, 2024
2 parents b1937db + e93d8fa commit 16339f7
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion fms_extras/models/speculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Embedding) or isinstance(m, nn.Linear):
nn.init.trunc_normal_(m.weight, 0, 1 / math.sqrt(self.inner_dim))
nn.init.normal_(m.weight, 0, 1 / math.sqrt(self.inner_dim))
elif isinstance(m, LayerNormParameterized) and hasattr(m, "weight"):
m.weight.data.fill_(1)
m.bias.data.zero_()
Expand Down
4 changes: 2 additions & 2 deletions fms_extras/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,8 +348,8 @@ def forward(
# if use_cache=True, we return the hidden_state as well as the kv cache.
# We only reduce the output, and keep the cache thread-local
if use_cache:
out = reduce_from_tensor_model_parallel_region(out_par[0])
out = reduce_from_tensor_model_parallel_region(out_par[0], self.world_size)
return out, out_par[1]
else:
out = reduce_from_tensor_model_parallel_region(out_par)
out = reduce_from_tensor_model_parallel_region(out_par, self.world_size)
return out
4 changes: 3 additions & 1 deletion fms_extras/utils/cache/paged.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def create(cls, kernel, *args, mutated_inputs=[], **kwargs) -> None:
tensor_args,
non_tensor_args,
unflatten_args,
) = cls.process_kernel(kernel, *args, **kwargs)
) = cls.process_kernel(
kernel, *args, **kwargs
) # type: ignore
for tensor_arg in tensor_args:
tensor_arg.realize()

Expand Down

0 comments on commit 16339f7

Please sign in to comment.