From 832bd16ad8c7c4be8ba07f320c7d2ac9fb01fac8 Mon Sep 17 00:00:00 2001 From: deval281shah Date: Tue, 25 Feb 2025 15:57:55 -0800 Subject: [PATCH 1/2] Update models.py to fix num_layers issue for scan_layer=False in DeepSeek --- MaxText/layers/models.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index 8f423e712..fb0fa9081 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -456,13 +456,14 @@ def __call__( layer_prefix = ["dense_layers", "moe_layers"] num_layers = [cfg.first_num_dense_layers, num_moe_layers] for index in range(len(layers)): - y = layers[index](config=cfg, mesh=mesh, name=f"{layer_prefix[index]}_{index}", quant=self.quant)( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) + for index_j in range(num_layers[index]): + y = layers[index](config=cfg, mesh=mesh, name=f"{layer_prefix[index]}_{index_j}", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) else: for lyr in range(cfg.num_decoder_layers): RemattedBlockLayer = RemattedBlockLayers[0] From 1c0a9a07a5cc157c193df7503824905b7823118e Mon Sep 17 00:00:00 2001 From: deval281shah Date: Thu, 27 Feb 2025 15:45:52 -0800 Subject: [PATCH 2/2] Update models.py --- MaxText/layers/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/MaxText/layers/models.py b/MaxText/layers/models.py index fb0fa9081..70c81f3d7 100644 --- a/MaxText/layers/models.py +++ b/MaxText/layers/models.py @@ -456,8 +456,8 @@ def __call__( layer_prefix = ["dense_layers", "moe_layers"] num_layers = [cfg.first_num_dense_layers, num_moe_layers] for index in range(len(layers)): - for index_j in range(num_layers[index]): - y = layers[index](config=cfg, mesh=mesh, name=f"{layer_prefix[index]}_{index_j}", quant=self.quant)( + for lyr in range(num_layers[index]): + y = layers[index](config=cfg, mesh=mesh, name=f"{layer_prefix[index]}_{lyr}", quant=self.quant)( y, decoder_segment_ids, decoder_positions,