diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 8f423e712..70c81f3d7 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -456,13 +456,14 @@ def __call__( layer_prefix = ["dense_layers", "moe_layers"] num_layers = [cfg.first_num_dense_layers, num_moe_layers] for index in range(len(layers)): - y = layers[index](config=cfg, mesh=mesh, name=f"{layer_prefix[index]}_{index}", quant=self.quant)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) + for lyr in range(num_layers[index]): + y = layers[index](config=cfg, mesh=mesh, name=f"{layer_prefix[index]}_{lyr}", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) else: for lyr in range(cfg.num_decoder_layers): RemattedBlockLayer = RemattedBlockLayers[0]