diff --git a/fms_extras/models/calico.py b/fms_extras/models/calico.py index 9b7f3b1..6071457 100644 --- a/fms_extras/models/calico.py +++ b/fms_extras/models/calico.py @@ -185,7 +185,7 @@ def __init__( self.pad_id = self.config.pad_id self.max_expected_seq_len = self.config.max_expected_seq_len - shared = WordEmbedding( + self.shared = WordEmbedding( self.config.src_vocab_size, self.config.emb_dim, padding_idx=self.config.pad_id, @@ -194,7 +194,6 @@ def __init__( tie_weights=True, bias=False, ) - self.shared = self.distributed_strategy.distribute_module(shared) self.rot_emb = RotaryEmbedding( dim=self.config.emb_dim // self.config.nheads,