From 17acf221b44fb5c6284acd5e45ffeac243ead13c Mon Sep 17 00:00:00 2001 From: w5688414 Date: Fri, 5 Jan 2024 17:26:13 +0800 Subject: [PATCH 1/9] Update convert_files_to_dicts_splitter (#7748) --- pipelines/pipelines/utils/preprocessing.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pipelines/pipelines/utils/preprocessing.py b/pipelines/pipelines/utils/preprocessing.py index 2cc2bcdb3148..f90b5a3aef7f 100644 --- a/pipelines/pipelines/utils/preprocessing.py +++ b/pipelines/pipelines/utils/preprocessing.py @@ -59,7 +59,7 @@ def document_rough_split(document_list, max_token=4500): return document_index_rough -def split_document(document_index, all_document, split_text, split_paragraphs: bool, clean_func, path, split_answers): +def split_document(document_index, all_document, splitter, split_paragraphs: bool, clean_func, path, split_answers): start = document_index[0] end = document_index[1] documents = [] @@ -68,7 +68,7 @@ def split_document(document_index, all_document, split_text, split_paragraphs: b if clean_func: text = clean_func(text) if split_paragraphs is True: - text_splits = split_text.split_text(text) + text_splits = splitter.split_text(text) for txt in text_splits: if not txt.strip(): # skip empty paragraphs continue @@ -95,7 +95,7 @@ def split_document(document_index, all_document, split_text, split_paragraphs: b def run_process( document_combination_index, list_documents, - split_text, + splitter, process_num, split_paragraphs, clean_func, @@ -107,7 +107,7 @@ def run_process( split_document_c = functools.partial( split_document, all_document=list_documents, - split_text=split_text, + splitter=splitter, split_paragraphs=split_paragraphs, clean_func=clean_func, path=path, @@ -168,8 +168,6 @@ def convert_files_to_dicts( documents = [] for suffix, paths in suffix2paths.items(): for path in paths: - if encoding is None and suffix == ".pdf": - encoding = "Latin1" logger.info("Converting {}".format(path)) list_documents = suffix2converter[suffix].convert( file_path=path, @@ -280,7 +278,11 @@ def convert_files_to_dicts_splitter( pipeline="en_core_web_sm", ) pdf_splitter = SpacyTextSplitter( - separator=separator, chunk_size=chunk_size, chunk_overlap=chunk_overlap, filters=filters + separator=separator, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + filters=filters, + pipeline="en_core_web_sm", ) text_splitter = CharacterTextSplitter( separator=separator, chunk_size=chunk_size, chunk_overlap=chunk_overlap, filters=filters @@ -309,8 +311,6 @@ def convert_files_to_dicts_splitter( suffix2splitter[file_suffix] = markdown_splitter for suffix, paths in suffix2paths.items(): for path in paths: - if encoding is None and suffix == ".pdf": - encoding = "Latin1" logger.info("Converting {}".format(path)) list_documents = suffix2converter[suffix].convert( file_path=path, @@ -330,7 +330,7 @@ def convert_files_to_dicts_splitter( document_mul = run_process( document_combination_index=document_combination_index, list_documents=list_documents, - split_text=suffix2splitter[suffix], + splitter=suffix2splitter[suffix], process_num=process_num, split_paragraphs=split_paragraphs, clean_func=clean_func, From 079f0674b64b03931ee7bb30e051582db6b895dd Mon Sep 17 00:00:00 2001 From: lugimzzz <63761690+lugimzzz@users.noreply.github.com> Date: Mon, 8 Jan 2024 11:45:02 +0800 Subject: [PATCH 2/9] fix (#7781) --- paddlenlp/peft/lora/lora_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/paddlenlp/peft/lora/lora_model.py b/paddlenlp/peft/lora/lora_model.py index 4f6a6d2ef4bf..33cfd32e9307 100644 --- a/paddlenlp/peft/lora/lora_model.py +++ b/paddlenlp/peft/lora/lora_model.py @@ -195,8 +195,7 @@ def _merge_trainable_tensor_parallel(self, trainable_state_dict): if key in trainable_name_action_mappings: ret = distributed_gather(tensor, group=mp_group, offload=True) action = trainable_name_action_mappings[key] - is_collumn = self.lora_split_mapping[key] - if "_scale" in key and not is_collumn and is_dst: + if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst: ret = paddle.to_tensor(ret) tensor = paddle.max(ret, axis=0) else: From ff1e91088a5fedd11f34877b00a63ad7403e7abf Mon Sep 17 00:00:00 2001 From: qingzhong1 <137043369+qingzhong1@users.noreply.github.com> Date: Mon, 8 Jan 2024 15:48:09 +0800 Subject: [PATCH 3/9] [Paddle-Pipelines] update faiss (#7793) * update faiss * update faiss * update faiss --- pipelines/pipelines/document_stores/faiss.py | 6 +----- pipelines/pipelines/document_stores/sql.py | 3 --- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/pipelines/pipelines/document_stores/faiss.py b/pipelines/pipelines/document_stores/faiss.py index 8ba60f1ba056..4a55160a5982 100644 --- a/pipelines/pipelines/document_stores/faiss.py +++ b/pipelines/pipelines/document_stores/faiss.py @@ -391,7 +391,7 @@ def update_embeddings( vector_id_map = {} for doc in document_batch: - vector_id_map[str(doc.id)] = str(vector_id) + vector_id_map[str(doc.id)] = str(vector_id) + "_" + index vector_id += 1 self.update_vector_ids(vector_id_map, index=index) progress_bar.set_description_str("Documents Processed") @@ -443,7 +443,6 @@ def get_all_documents_generator( ) if return_embedding is None: return_embedding = self.return_embedding - for doc in documents: if return_embedding: if doc.meta and doc.meta.get("vector_id") is not None: @@ -588,7 +587,6 @@ def query_by_embedding( if filters: logger.warning("Query filters are not implemented for the FAISSDocumentStore.") - index = index or self.index if not self.faiss_indexes.get(index): raise Exception(f"Index named '{index}' does not exists. Use 'update_embeddings()' to create an index.") @@ -599,11 +597,9 @@ def query_by_embedding( query_emb = query_emb.reshape(1, -1).astype(np.float32) if self.similarity == "cosine": self.normalize_embedding(query_emb) - score_matrix, vector_id_matrix = self.faiss_indexes[index].search(query_emb, top_k) vector_ids_for_query = [str(vector_id) + "_" + index for vector_id in vector_id_matrix[0] if vector_id != -1] documents = self.get_documents_by_vector_ids(vector_ids_for_query, index=index) - # assign query score to each document scores_for_vector_ids: Dict[str, float] = { str(v_id): s for v_id, s in zip(vector_id_matrix[0], score_matrix[0]) diff --git a/pipelines/pipelines/document_stores/sql.py b/pipelines/pipelines/document_stores/sql.py index 5fcafb72fb95..e579513f55e5 100644 --- a/pipelines/pipelines/document_stores/sql.py +++ b/pipelines/pipelines/document_stores/sql.py @@ -216,7 +216,6 @@ def get_documents_by_vector_ids( ): """Fetch documents by specifying a list of text vector id strings""" index = index or self.index - documents = [] for i in range(0, len(vector_ids), batch_size): query = self.session.query(DocumentORM).filter( @@ -224,7 +223,6 @@ def get_documents_by_vector_ids( ) for row in query.all(): documents.append(self._convert_sql_row_to_document(row)) - sorted_documents = sorted(documents, key=lambda doc: vector_ids.index(doc.meta["vector_id"])) return sorted_documents @@ -405,7 +403,6 @@ def write_documents( document_objects = [Document.from_dict(d) if isinstance(d, dict) else d for d in documents] else: document_objects = documents - document_objects = self._handle_duplicate_documents( documents=document_objects, index=index, duplicate_documents=duplicate_documents ) From 487428b3996e1e4f9c09029a608a5b4db959f14e Mon Sep 17 00:00:00 2001 From: DrownFish19 Date: Mon, 8 Jan 2024 15:59:12 +0800 Subject: [PATCH 4/9] Fix shared weights sync for PipelineLayer (#7772) * fix shared weights sync * fix typo --- paddlenlp/transformers/model_utils.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index 56a698a01170..69c0746c6300 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -41,7 +41,10 @@ ) from huggingface_hub.utils import EntryNotFoundError from paddle import Tensor -from paddle.distributed.fleet.meta_parallel.parallel_layers import SharedLayerDesc +from paddle.distributed.fleet.meta_parallel.parallel_layers import ( + PipelineLayer, + SharedLayerDesc, +) from paddle.nn import Embedding, Layer # TODO(fangzeyang) Temporary fix and replace by paddle framework downloader later @@ -933,6 +936,18 @@ def _post_init(self, original_init, *args, **kwargs): ): self.init_weights() + # Note: + # 1. PipelineLayer will create parameters for each layer and + # call `_synchronize_shared_weights()` to synchronize the shared parameters. + # 2. When setting the model `state_dict`, `_synchronize_shared_weights` will be called to + # synchronize the shared parameters. + # However, `self._init_weights` will re-initialize the parameters without + # synchronizing the shared parameters. If the following step does not load a checkpoint, + # the shared parameters will be different. + + if isinstance(self, PipelineLayer): + self._synchronize_shared_weights() + def _init_weights(self, layer): """ Initialize the weights. This method should be overridden by derived class. From 97f6158e76369ac1cd587f66411b6a6320ca852b Mon Sep 17 00:00:00 2001 From: yujun <50394665+JunnYu@users.noreply.github.com> Date: Mon, 8 Jan 2024 18:38:02 +0800 Subject: [PATCH 5/9] slow (#7798) --- tests/transformers/load_subfolder/test_config.py | 4 ++++ tests/transformers/load_subfolder/test_image_processor.py | 2 ++ tests/transformers/load_subfolder/test_model.py | 7 +++++++ tests/transformers/load_subfolder/test_processor.py | 2 ++ tests/transformers/load_subfolder/test_tokenizer.py | 4 ++++ 5 files changed, 19 insertions(+) diff --git a/tests/transformers/load_subfolder/test_config.py b/tests/transformers/load_subfolder/test_config.py index 1e7c1f687af8..b6b7af459e82 100644 --- a/tests/transformers/load_subfolder/test_config.py +++ b/tests/transformers/load_subfolder/test_config.py @@ -15,9 +15,11 @@ from paddlenlp.transformers import AutoConfig, BertConfig, CLIPConfig, T5Config from paddlenlp.utils.log import logger +from tests.testing_utils import slow class ConfigLoadTester(unittest.TestCase): + @slow def test_bert_config_load(self): logger.info("Download Bert Config from PaddleNLP BOS") bert_config = BertConfig.from_pretrained("bert-base-uncased", from_hf_hub=False) @@ -43,6 +45,7 @@ def test_bert_config_load(self): bert_config = BertConfig.from_pretrained("aistudio/bert-base-uncased", from_aistudio=True) bert_config = AutoConfig.from_pretrained("aistudio/bert-base-uncased", from_aistudio=True) + @slow def test_clip_config_load(self): logger.info("Download CLIP Config from PaddleNLP BOS") clip_config = CLIPConfig.from_pretrained("openai/clip-vit-base-patch32", from_hf_hub=False) @@ -68,6 +71,7 @@ def test_clip_config_load(self): clip_config = CLIPConfig.from_pretrained("aistudio/clip-vit-base-patch32", from_aistudio=True) clip_config = AutoConfig.from_pretrained("aistudio/clip-vit-base-patch32", from_aistudio=True) + @slow def test_t5_config_load(self): logger.info("Download T5 Config from PaddleNLP BOS") t5_config = T5Config.from_pretrained("t5-small", from_hf_hub=False) diff --git a/tests/transformers/load_subfolder/test_image_processor.py b/tests/transformers/load_subfolder/test_image_processor.py index a909015e804d..fc55da116525 100644 --- a/tests/transformers/load_subfolder/test_image_processor.py +++ b/tests/transformers/load_subfolder/test_image_processor.py @@ -16,9 +16,11 @@ from paddlenlp.transformers import AutoImageProcessor, CLIPImageProcessor from paddlenlp.utils.log import logger +from tests.testing_utils import slow class ImageProcessorLoadTester(unittest.TestCase): + @slow def test_clip_load(self): logger.info("Download model from PaddleNLP BOS") clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32", from_hf_hub=False) diff --git a/tests/transformers/load_subfolder/test_model.py b/tests/transformers/load_subfolder/test_model.py index 6fbdb53caa7d..938ee33e9608 100644 --- a/tests/transformers/load_subfolder/test_model.py +++ b/tests/transformers/load_subfolder/test_model.py @@ -20,6 +20,7 @@ from paddlenlp.transformers import AutoModel, BertModel, CLIPTextModel, T5Model from paddlenlp.utils.log import logger +from tests.testing_utils import slow class ModelLoadTester(unittest.TestCase): @@ -58,6 +59,7 @@ def test_cache_dir( else: assert any(".pdparams" in f for f in file_list), "*.pdparams not in cache_dir" + @slow def test_bert_load(self): # BOS logger.info("Download model from PaddleNLP BOS") @@ -194,6 +196,7 @@ def test_bert_load(self): use_safetensors=False, ) + @slow def test_bert_load_safe(self): # BOS logger.info("Download model from PaddleNLP BOS") @@ -320,6 +323,7 @@ def test_bert_load_safe(self): use_safetensors=True, ) + @slow def test_clip_load(self): # BOS logger.info("Download model from PaddleNLP BOS") @@ -466,6 +470,7 @@ def test_clip_load(self): use_safetensors=False, ) + @slow def test_clip_load_safe(self): # BOS logger.info("Download model from PaddleNLP BOS") @@ -608,6 +613,7 @@ def test_clip_load_safe(self): use_safetensors=True, ) + @slow def test_t5_load(self): # BOS logger.info("Download model from PaddleNLP BOS") @@ -726,6 +732,7 @@ def test_t5_load(self): use_safetensors=False, ) + @slow def test_t5_load_safe(self): # BOS logger.info("Download model from PaddleNLP BOS") diff --git a/tests/transformers/load_subfolder/test_processor.py b/tests/transformers/load_subfolder/test_processor.py index ac4af8859c1d..bd1e4751660b 100644 --- a/tests/transformers/load_subfolder/test_processor.py +++ b/tests/transformers/load_subfolder/test_processor.py @@ -17,9 +17,11 @@ from paddlenlp.transformers import AutoProcessor, CLIPProcessor from paddlenlp.utils.log import logger +from tests.testing_utils import slow class ProcessorLoadTester(unittest.TestCase): + @slow def test_clip_load(self): logger.info("Download model from PaddleNLP BOS") clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", from_hf_hub=False) diff --git a/tests/transformers/load_subfolder/test_tokenizer.py b/tests/transformers/load_subfolder/test_tokenizer.py index 2508b326ca6f..9f3cd636af11 100644 --- a/tests/transformers/load_subfolder/test_tokenizer.py +++ b/tests/transformers/load_subfolder/test_tokenizer.py @@ -22,9 +22,11 @@ T5Tokenizer, ) from paddlenlp.utils.log import logger +from tests.testing_utils import slow class TokenizerLoadTester(unittest.TestCase): + @slow def test_bert_load(self): logger.info("Download model from PaddleNLP BOS") bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", from_hf_hub=False) @@ -57,6 +59,7 @@ def test_bert_load(self): "aistudio/paddlenlp-test-model", subfolder="bert-base-uncased", from_aistudio=True ) + @slow def test_clip_load(self): logger.info("Download model from PaddleNLP BOS") clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32", from_hf_hub=False) @@ -89,6 +92,7 @@ def test_clip_load(self): "aistudio/paddlenlp-test-model", subfolder="clip-vit-base-patch32", from_aistudio=True ) + @slow def test_t5_load(self): logger.info("Download model from PaddleNLP BOS") t5_tokenizer = T5Tokenizer.from_pretrained("t5-small", from_hf_hub=False) From 37b4fe0285dbfd098fc29efcf3eef25984970cc0 Mon Sep 17 00:00:00 2001 From: Wu Chencan <77946882+DanGuge@users.noreply.github.com> Date: Tue, 9 Jan 2024 11:16:52 +0800 Subject: [PATCH 6/9] [INFER][LLM] Support qwen in fined grained dybatch v1 (#7644) * init qwen inference model * fix name * fix hidden dim * fix dtype * fix length * fix attention_mask * fix up & gate dtype bug * fix ffn1 weight * modify codes * remote unused variable * remove unused code * add qwen weight only * format with black * format with isort * fix dtype * add qwen inference model in static graph * add qwen unittest * format with black * print log * remove print * set safetensors usage False * remove tests * Empty-Commit --- llm/predictor.py | 23 +- .../experimental/transformers/__init__.py | 1 + .../transformers/qwen/__init__.py | 15 + .../transformers/qwen/modeling.py | 504 ++++++++++++++++++ 4 files changed, 541 insertions(+), 2 deletions(-) create mode 100644 paddlenlp/experimental/transformers/qwen/__init__.py create mode 100644 paddlenlp/experimental/transformers/qwen/modeling.py diff --git a/llm/predictor.py b/llm/predictor.py index 805cbd58c0e3..7175a58e697d 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -880,8 +880,19 @@ def create_predictor( dtype=predictor_args.dtype, ) model.eval() + elif "qwen" in config.architectures[0].lower(): + from paddlenlp.experimental.transformers import ( + QWenForCausalLMInferenceModel, + ) + + model = QWenForCausalLMInferenceModel.from_pretrained( + predictor_args.model_name_or_path, + config=config, + dtype=predictor_args.dtype, + ) + model.eval() else: - raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]") + raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]") predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) elif predictor_args.mode == "static": config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) @@ -925,8 +936,16 @@ def create_predictor( cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape( config, predictor_args.batch_size, predictor_args.total_max_length ) + elif "qwen" in config.architectures[0].lower(): + from paddlenlp.experimental.transformers import ( + QWenForCausalLMInferenceModel, + ) + + cache_kvs_shape = QWenForCausalLMInferenceModel.get_cache_kvs_shape( + config, predictor_args.batch_size, predictor_args.total_max_length + ) else: - raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]") + raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]") predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) else: raise ValueError("the `mode` should be one of [dynamic, static]") diff --git a/paddlenlp/experimental/transformers/__init__.py b/paddlenlp/experimental/transformers/__init__.py index dca226d668f1..1c7c0e2c0077 100644 --- a/paddlenlp/experimental/transformers/__init__.py +++ b/paddlenlp/experimental/transformers/__init__.py @@ -19,3 +19,4 @@ from .gpt import * from .llama import * from .opt import * +from .qwen import * diff --git a/paddlenlp/experimental/transformers/qwen/__init__.py b/paddlenlp/experimental/transformers/qwen/__init__.py new file mode 100644 index 000000000000..c2a7f656c636 --- /dev/null +++ b/paddlenlp/experimental/transformers/qwen/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .modeling import * diff --git a/paddlenlp/experimental/transformers/qwen/modeling.py b/paddlenlp/experimental/transformers/qwen/modeling.py new file mode 100644 index 000000000000..facf664fc1c8 --- /dev/null +++ b/paddlenlp/experimental/transformers/qwen/modeling.py @@ -0,0 +1,504 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import paddle +from paddle import nn +from paddle.nn.quant import weight_quantize +from paddlenlp_ops import fused_get_rotary_embedding, get_padding_offset + +from paddlenlp.experimental.transformers.fused_transformer_layers import ( + FusedMultiTransformerBase, + FusedMultiTransformerConfig, + FusedMultiTransformerWeightOnly, +) +from paddlenlp.experimental.transformers.generation_utils import ( + GenerationInferenceModel, +) +from paddlenlp.transformers import QWenConfig, QWenPretrainedModel +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from paddlenlp.transformers.model_utils import ( + dy2st_nocheck_guard_context, + register_base_model, +) +from paddlenlp.transformers.qwen.modeling import QWenLMHead, QWenPretrainingCriterion + +__all__ = ["QWenForCausalLMInferenceModel"] + + +class FusedQWenRMSNorm(nn.Layer): + def __init__(self, config): + super().__init__() + self.eps = config.layer_norm_epsilon + self.weight = paddle.create_parameter( + shape=[config.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + + def forward(self, x): + result = paddle.incubate.nn.functional.fused_rms_norm(x, self.weight, None, self.eps, begin_norm_axis=1) + if isinstance(result, tuple): + return result[0] + return result + + +@register_base_model +class QWenInferenceModel(QWenPretrainedModel): + def __init__(self, config: QWenConfig): + super(QWenPretrainedModel, self).__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.intermediate_size = config.intermediate_size + self.num_layers = config.num_hidden_layers + self.layer_norm_epsilon = config.layer_norm_epsilon + self.max_position_embeddings = config.max_position_embeddings + self.quant_type = config.quant_type + self.weight_only_quant_bits = config.weight_only_quant_bits + + if self.quant_type is not None and "weight_only_int" in self.quant_type: + self.use_weight_only = True + else: + self.use_weight_only = False + + if self.use_weight_only: + assert ( + self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4" + ), "Expected quant_type equal to 'weight_only_int8' or 'weight_only_int4', but received {}".format( + self.quant_type + ) + + self.wte = nn.Embedding(self.vocab_size, self.hidden_size) + + ln_scale_attrs = [paddle.ParamAttr(name="fuseqwen.{}.ln_scale".format(i)) for i in range(self.num_layers)] + qkv_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen.{}.qkv_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + qkv_bias_attrs = [paddle.ParamAttr(name="fuseqwen.{}.qkv_bias".format(i)) for i in range(self.num_layers)] + out_proj_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen.{}.out_proj_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + ffn_ln_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen.{}.ffn_ln_scale".format(i)) for i in range(self.num_layers) + ] + ffn1_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen.{}.ffn1_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + ffn2_weight_attrs = [ + paddle.ParamAttr( + name="fuseqwen.{}.ffn2_weight".format(i), initializer=paddle.nn.initializer.Constant(value=0) + ) + for i in range(self.num_layers) + ] + + qkv_weight_scale_attrs = None + out_proj_weight_scale_attrs = None + ffn1_weight_scale_attrs = None + ffn2_weight_scale_attrs = None + + if self.use_weight_only: + qkv_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen.{}.qkv_weight_scale".format(i)) for i in range(self.num_layers) + ] + out_proj_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen.{}.out_proj_weight_scale".format(i)) for i in range(self.num_layers) + ] + ffn1_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen.{}.ffn1_weight_scale".format(i)) for i in range(self.num_layers) + ] + ffn2_weight_scale_attrs = [ + paddle.ParamAttr(name="fuseqwen.{}.ffn2_weight_scale".format(i)) for i in range(self.num_layers) + ] + + transformer_config = FusedMultiTransformerConfig( + self.hidden_size, + self.num_attention_heads, + self.intermediate_size // 2, + weight_only_quant_bits=self.weight_only_quant_bits, + activation="swiglu", + num_layers=config.num_hidden_layers, + nranks=1, + ring_id=-1, + ln_scale_attrs=ln_scale_attrs, + qkv_weight_attrs=qkv_weight_attrs, + qkv_weight_scale_attrs=qkv_weight_scale_attrs, + linear_weight_attrs=out_proj_weight_attrs, + linear_weight_scale_attrs=out_proj_weight_scale_attrs, + ffn_ln_scale_attrs=ffn_ln_scale_attrs, + ffn1_weight_attrs=ffn1_weight_attrs, + ffn1_weight_scale_attrs=ffn1_weight_scale_attrs, + ffn2_weight_attrs=ffn2_weight_attrs, + ffn2_weight_scale_attrs=ffn2_weight_scale_attrs, + qkv_bias_attrs=qkv_bias_attrs, + epsilon=self.layer_norm_epsilon, + norm_type="rmsnorm", + use_neox_rotary_style=True, + ) + + if self.use_weight_only: + self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config) + else: + self.transformer_block = FusedMultiTransformerBase(transformer_config) + + self.ln_f = FusedQWenRMSNorm(config) + + self.cache_kvs = None + self.head_dim_shape_tensor = paddle.ones((self.hidden_size // self.num_attention_heads), dtype="int8") + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, value): + self.wte = value + + @paddle.no_grad() + def set_state_dict(self, state_dict): + dtype = paddle.get_default_dtype() + wte_weight = paddle.to_tensor(state_dict["qwen.wte.weight"], dtype=dtype) + ln_f_weight = paddle.to_tensor(state_dict["qwen.ln_f.weight"], dtype=self.ln_f.weight.dtype) + self.wte.weight.set_value(wte_weight) + self.ln_f.weight.set_value(ln_f_weight) + + for idx in range(self.num_layers): + ln_scale = paddle.to_tensor( + state_dict["qwen.h.{}.ln_1.weight".format(idx)], dtype=self.transformer_block.ln_scales[idx].dtype + ) + self.transformer_block.ln_scales[idx].set_value(ln_scale) + + qkv_weight = paddle.to_tensor( + state_dict["qwen.h.{}.attn.c_attn.weight".format(idx)].transpose([1, 0]), dtype=dtype + ) + if self.use_weight_only: + qkv_weight = paddle.transpose(qkv_weight, perm=[1, 0]) + qkv_quanted_weight, qkv_weight_scale = weight_quantize(qkv_weight, algo=self.quant_type) + self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight) + self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale) + else: + self.transformer_block.qkv_weights[idx].set_value(qkv_weight) + + qkv_bias = paddle.to_tensor(state_dict["qwen.h.{}.attn.c_attn.bias".format(idx)], dtype=dtype) + self.transformer_block.qkv_biases[idx].set_value(qkv_bias) + + linear_weight = paddle.to_tensor(state_dict["qwen.h.{}.attn.c_proj.weight".format(idx)], dtype=dtype) + if self.use_weight_only: + linear_quanted_weight, linear_weight_scale = weight_quantize(linear_weight, algo=self.quant_type) + self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight) + self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale) + else: + self.transformer_block.linear_weights[idx].set_value(linear_weight) + + ffn_ln_scale = paddle.to_tensor( + state_dict["qwen.h.{}.ln_2.weight".format(idx)], dtype=self.transformer_block.ffn_ln_scales[idx].dtype + ) + self.transformer_block.ffn_ln_scales[idx].set_value(ffn_ln_scale) + + up_weight = paddle.to_tensor(state_dict["qwen.h.{}.mlp.w1.weight".format(idx)], dtype=dtype) + gate_weight = paddle.to_tensor(state_dict["qwen.h.{}.mlp.w2.weight".format(idx)], dtype=dtype) + ffn1_weight = paddle.concat(x=[gate_weight, up_weight], axis=-1) + if self.use_weight_only: + ffn1_quanted_weight, ffn1_weight_scale = weight_quantize(ffn1_weight, algo=self.quant_type) + self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight) + self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale) + else: + self.transformer_block.ffn1_weights[idx].set_value(ffn1_weight) + + ffn2_weight = paddle.to_tensor(state_dict["qwen.h.{}.mlp.c_proj.weight".format(idx)], dtype=dtype) + if self.use_weight_only: + ffn2_quanted_weight, ffn2_weight_scale = weight_quantize(ffn2_weight, algo=self.quant_type) + self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight) + self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale) + else: + self.transformer_block.ffn2_weights[idx].set_value(ffn2_weight) + + def remove_padding(self, input_ids, seq_lens_this_time): + cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time) + token_num = paddle.sum(seq_lens_this_time) + ids_remove_padding, cum_offsets, padding_offset = get_padding_offset( + input_ids, cum_offsets_now, token_num, seq_lens_this_time + ) + return ids_remove_padding, padding_offset, cum_offsets + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + cache_kvs=None, + pre_caches=None, + seq_len_encoder=None, + seq_len_decoder=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + **kwargs, + ): + # kwargs["cache"] is used used to distinguish between encoder and decoder phase. + past_key_values = kwargs.get("cache", None) + is_decoder = past_key_values is not None + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + batch, seq_len, hidden_dim = inputs_embeds.shape + inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim]) + + if past_key_values is None: + past_key_values = tuple([None] * self.config.num_hidden_layers) + + if not is_decoder: + ids_remove_padding, padding_offset, cum_offsets = self.remove_padding(input_ids, seq_len_encoder) + else: + ids_remove_padding = input_ids + padding_offset = None + cum_offsets = None + + if inputs_embeds is None: + inputs_embeds = self.wte(ids_remove_padding) + hidden_states = inputs_embeds + + # decoder layers + presents = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + seq_lens = seq_len_decoder if is_decoder else seq_len_encoder + + position_offset = 0 + if not is_decoder and pre_caches is not None: + position_offset = 128 + + new_rope = fused_get_rotary_embedding( + input_ids, position_ids, self.head_dim_shape_tensor, position_offset, True + ) + + with dy2st_nocheck_guard_context(): + hidden_states, _ = self.transformer_block( + input_ids, + hidden_states, + cum_offsets=cum_offsets, + padding_offset=padding_offset, + attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype), + caches=cache_kvs, + pre_caches=pre_caches, + pre_caches_length=position_offset, + seq_lens=seq_lens, + rotary_embs=new_rope, + rotary_emb_dims=1, + time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None, + ) + + hidden_states = self.ln_f(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class QWenForCausalLMInferenceModel(GenerationInferenceModel, QWenPretrainedModel): + def __init__(self, config: QWenConfig, **kwargs): + super(QWenForCausalLMInferenceModel, self).__init__(config) + self.qwen = QWenInferenceModel(config) + self.lm_head = QWenLMHead(config) + self.criterion = QWenPretrainingCriterion(config) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, from_hf_hub: bool = False, subfolder: str | None = None, *args, **kwargs + ): + # TODO: Support safetensors loading. + kwargs["use_safetensors"] = False + return super().from_pretrained(pretrained_model_name_or_path, from_hf_hub, subfolder, *args, **kwargs) + + @classmethod + def get_cache_kvs_shape( + cls, config: QWenConfig, max_batch_size: int = None, max_length: int = None + ) -> list[list[int]]: + """get cache_kvs tensor for qwen model + + Args: + max_batch_size (int): the max batch size + max_length (int | None, optional): the max_length of cache_kvs. Defaults to None. + + Returns: + list[paddle.Tensor]: the list tensor shape for cache + """ + if max_length is None: + max_length = config.max_position_embeddings + + cache_kvs = [] + for _ in range(config.num_hidden_layers): + cache_kvs.append( + [ + 2, + max_batch_size, + config.num_attention_heads // max(config.tensor_parallel_degree, 1), + max_length, + config.hidden_size // config.num_attention_heads, + ] + ) + return cache_kvs + + def prepare_inputs_for_generation( + self, + input_ids, + cache_kvs, + seq_len_encoder, + seq_len_decoder, + tgt_ids, + tgt_pos, + tgt_generation_mask, + **kwargs, + ): + position_ids = kwargs.get("position_ids", None) + attention_mask = kwargs.get("attention_mask", None) + cache = kwargs.get("cache", None) + pre_caches = kwargs.get("pre_caches", None) + inputs_embeds = kwargs.get("inputs_embeds", None) + if cache is not None: + input_ids = tgt_ids + position_ids = tgt_pos + attention_mask = (tgt_generation_mask - 1) * 1e4 + # make inputs_embeds be none in decoder phase. + # in forward function, it will be assigned according to input_ids. + inputs_embeds = None + else: + attention_mask = (attention_mask - 1) * 1e4 + model_inputs = { + "input_ids": input_ids, + "inputs_embeds": inputs_embeds, + "position_ids": position_ids, + "attention_mask": attention_mask, + "cache_kvs": cache_kvs, + "seq_len_encoder": seq_len_encoder, + "seq_len_decoder": seq_len_decoder, + "cache": cache, + "pre_caches": pre_caches, + } + return model_inputs + + def forward( + self, + input_ids, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + cache=None, + cache_kvs=None, + pre_caches=None, + seq_len_encoder=None, + seq_len_decoder=None, + past_key_values=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.qwen( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache=cache, + cache_kvs=cache_kvs, + pre_caches=pre_caches, + seq_len_encoder=seq_len_encoder, + seq_len_decoder=seq_len_decoder, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + lm_logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + loss = None + if labels is not None: + loss = self.criterion(lm_logits, labels) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @paddle.no_grad() + def set_state_dict(self, state_dict): + if "lm_head.weight" in state_dict: + lm_head_weight = paddle.to_tensor(state_dict["lm_head.weight"], dtype=self.lm_head.weight.dtype) + self.lm_head.weight.set_value(lm_head_weight) + self.qwen.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) From 393ac187899b58a61d44144d8a9a300b6944117a Mon Sep 17 00:00:00 2001 From: yinwei Date: Tue, 9 Jan 2024 14:07:55 +0800 Subject: [PATCH 7/9] [CE] Add CE for Distributed Hybrid Parallel (#7782) --- ...retrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh | 38 ++++++++++++++++++ ..._bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh | 39 +++++++++++++++++++ ...2_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh | 39 +++++++++++++++++++ ..._bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh | 39 +++++++++++++++++++ ...retrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh | 38 ++++++++++++++++++ ..._bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh | 39 +++++++++++++++++++ ...2_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh | 39 +++++++++++++++++++ ..._bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh | 39 +++++++++++++++++++ .../ce_gpt/benchmark_common/run_benchmark.sh | 7 +++- 9 files changed, 316 insertions(+), 1 deletion(-) create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh new file mode 100644 index 000000000000..70a9a92519b1 --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh @@ -0,0 +1,38 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=CE_gpt-345m_seqlen1024_pretrain +dp_degree=2 +mp_degree=2 +pp_degree=2 +bs_item=32 +fp_item=bf16 +run_mode=MP2-PP2-DP2-mbs8-acc2 +device_num=N1C8 +max_iter=50000 +sharding=False +sharding_degree=1 +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=False + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh new file mode 100644 index 000000000000..781e4b9b47c0 --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh @@ -0,0 +1,39 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=2 +bs_item=32 +fp_item=bf16 +run_mode=MP2-PP2-SD2-Stage1-mbs8-acc2 +device_num=N1C8 +max_iter=50000 +sharding=stage2 +sharding_degree=1 + +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=False + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh new file mode 100644 index 000000000000..7b68ae7bd182 --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh @@ -0,0 +1,39 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=CE_gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=2 +bs_item=32 +fp_item=bf16 +run_mode=MP2-SP2-PP2-SD2-Stage1-mbs8-acc2 +device_num=N1C8 +max_iter=50000 +sharding=stage1 +sharding_degree=2 + +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=True + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh new file mode 100644 index 000000000000..fbfb37e900e5 --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh @@ -0,0 +1,39 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=CE_gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=1 +bs_item=32 +fp_item=bf16 +run_mode=MP2-SP2-SD2-Stage1-mbs8-acc2 +device_num=N1C8 +max_iter=50000 +sharding=stage1 +sharding_degree=2 + +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=False +sequence_parallel=True + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh new file mode 100644 index 000000000000..67d0bf3ae162 --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-DP2-mbs8-acc2.sh @@ -0,0 +1,38 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=gpt-345m_seqlen1024_pretrain +dp_degree=2 +mp_degree=2 +pp_degree=2 +bs_item=32 +fp_item=bf16 +run_mode=MP2-PP2-DP2-mbs8-acc2 +device_num=N1C8 +max_iter=100 +sharding=False +sharding_degree=1 +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=False + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh new file mode 100644 index 000000000000..c95dc1d4582c --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-PP2-SD2-Stage1-mbs8-acc2.sh @@ -0,0 +1,39 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=2 +bs_item=32 +fp_item=bf16 +run_mode=MP2-PP2-SD2-Stage1-mbs8-acc2 +device_num=N1C8 +max_iter=100 +sharding=stage1 +sharding_degree=2 + +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=False + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh new file mode 100644 index 000000000000..732be5e7f990 --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP2-SD2-Stage1-mbs8-acc2.sh @@ -0,0 +1,39 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=2 +bs_item=32 +fp_item=bf16 +run_mode=MP2-SP2-PP2-SD2-Stage1-mbs8-acc2 +device_num=N1C8 +max_iter=100 +sharding=stage1 +sharding_degree=2 + +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=True + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh new file mode 100644 index 000000000000..934e99b7c67e --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-SD2-Stage1-mbs8-acc2.sh @@ -0,0 +1,39 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=1 +bs_item=32 +fp_item=bf16 +run_mode=MP2-SP2-SD2-Stage1-mbs8-acc2 +device_num=N1C8 +max_iter=100 +sharding=stage1 +sharding_degree=2 + +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=False +sequence_parallel=True + +model=gpt +micro_bs=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh index e52ced5842eb..8ee5ec54d8ba 100644 --- a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh @@ -150,11 +150,16 @@ function _train(){ run_pretrain.py ${train_cmd}" workerlog_id=0 ;; - DP8-mbs2-acc2|SD8-stage1-mbs2-acc2|SD8-stage2-mbs2-acc2|SD8-stage3-mbs2-acc2|MP2-SD4-stage1-mbs4-acc2|MP2-SP2-PP2-DP2-mbs8-acc2|MP8-mbs16-acc2) echo "run run_mode: ${run_mode}" + DP8-mbs2-acc2|SD8-stage1-mbs2-acc2|SD8-stage2-mbs2-acc2|SD8-stage3-mbs2-acc2|MP2-SD4-stage1-mbs4-acc2|MP2-SP2-PP2-DP2-mbs8-acc2|MP8-mbs16-acc2|MP2-PP2-DP2-mbs8-acc2|MP2-PP2-SD2-Stage1-mbs8-acc2|MP2-SP2-PP2-SD2-Stage1-mbs8-acc2) echo "run run_mode: ${run_mode}" train_cmd="python -m paddle.distributed.launch --log_dir=./mylog --devices=0,1,2,3,4,5,6,7 ${PADDLE_RANK_OPTION}\ run_pretrain.py ${train_cmd}" workerlog_id=0 ;; + MP2-SP2-SD2-Stage1-mbs8-acc2) echo "run run_mode: ${run_mode}" + train_cmd="python -m paddle.distributed.launch --log_dir=./mylog --devices=0,1,2,3 ${PADDLE_RANK_OPTION}\ + run_pretrain.py ${train_cmd}" + workerlog_id=0 + ;; *) echo "choose run_mode "; exit 1; esac cd ../llm/gpt-3 From fc6ab70e8e9361b3505e21dbe04a71aee68a1fc2 Mon Sep 17 00:00:00 2001 From: tianhaodongbd <137985359+tianhaodongbd@users.noreply.github.com> Date: Tue, 9 Jan 2024 15:08:18 +0800 Subject: [PATCH 8/9] [CE] Add MP2-SP2-pp4-vpp2-SD2-stage1-mbs2-acc8 ce (#7774) --- ...rain_bs32_bf16_MP2-SD4-stage1-mbs4-acc2.sh | 9 ++++- ...6_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh | 40 +++++++++++++++++++ ...6_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh | 40 +++++++++++++++++++ .../ce_gpt/benchmark_common/run_benchmark.sh | 11 +++-- 4 files changed, 95 insertions(+), 5 deletions(-) create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh create mode 100644 tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SD4-stage1-mbs4-acc2.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SD4-stage1-mbs4-acc2.sh index fe257f5970d5..1238bb41fd04 100644 --- a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SD4-stage1-mbs4-acc2.sh +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N1C8/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SD4-stage1-mbs4-acc2.sh @@ -18,6 +18,11 @@ mp_degree=2 pp_degree=1 sharding_degree=4 sharding=stage1 +virtual_pp_degree=1 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=False +sequence_parallel=False bs_item=32 fp_item=bf16 run_mode=MP2-SD4-stage1-mbs4-acc2 @@ -26,8 +31,10 @@ max_iter=50000 model=gpt micro_bs=4 +acc=2 +seed=3589 bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh # run bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ -${max_iter} ${sharding} ${sharding_degree} 2>&1; +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} ${acc} ${seed} 2>&1; diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh new file mode 100644 index 000000000000..e476037f9ada --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/CE_gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh @@ -0,0 +1,40 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=CE_gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=4 +bs_item=32 +fp_item=bf16 +run_mode=MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8 +device_num=N2C16 +max_iter=50000 +sharding=stage1 +sharding_degree=2 +virtual_pp_degree=2 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=True + +model=gpt +micro_bs=2 +acc=8 +seed=3589 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} ${acc} ${seed} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh new file mode 100644 index 000000000000..03a50a587e9e --- /dev/null +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/N2C16/gpt-345m_seqlen1024_pretrain_bs32_bf16_MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8.sh @@ -0,0 +1,40 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +model_item=gpt-345m_seqlen1024_pretrain +dp_degree=1 +mp_degree=2 +pp_degree=4 +bs_item=32 +fp_item=bf16 +run_mode=MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8 +device_num=N2C16 +max_iter=100 +sharding=stage1 +sharding_degree=2 + +virtual_pp_degree=2 +use_recompute=True +eval_freq=25 +use_pipeline_parallel=True +sequence_parallel=True + +model=gpt +micro_bs=2 +acc=8 + +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/prepare.sh +# run +bash ./test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh ${model_item} ${fp_item} ${dp_degree} ${mp_degree} ${pp_degree} ${micro_bs} ${bs_item} ${run_mode} ${device_num} \ +${max_iter} ${sharding} ${sharding_degree} ${virtual_pp_degree} ${use_recompute} ${eval_freq} ${use_pipeline_parallel} ${sequence_parallel} ${acc} 2>&1; \ No newline at end of file diff --git a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh index 8ee5ec54d8ba..b32470ab8d1f 100644 --- a/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh +++ b/tests/test_tipc/dygraph/hybrid_parallelism/ce_gpt/benchmark_common/run_benchmark.sh @@ -37,11 +37,13 @@ function _set_params(){ sharding_degree=${12:-"1"} num_workers=0 # (可选) base_batch_size=$global_batch_size - virtual_pp_degree=${13:-"2"} # (可选) virtualpp数据并行度 + vpp_degree=${13:-"1"} # (可选) virtualpp数据并行度 use_recompute=${14:-"True"} # (可选)是否打开recompute eval_freq=${15:-"25"} # (可选)模型评估间隔 use_pipeline_parallel=${16:-"False"} # (可选)是否开启pipeline_parallel_config sequence_parallel=${17:-"False"} # (可选)是否开启sequence_parallel + acc=${18:-"2"} + seed=${19:-"1234"} # 以下为通用执行命令,无特殊可不用修改 model_name=${model_item}_bs${global_batch_size}_${fp_item}_${run_mode} # (必填) 且格式不要改动,与竞品名称对齐 device=${CUDA_VISIBLE_DEVICES//,/ } @@ -108,10 +110,11 @@ function _train(){ --tensor_parallel_degree ${mp_degree} \ --pipeline_parallel_degree ${pp_degree} \ ${pp_config_disable_partial_send_recv} \ + --virtual_pp_degree ${vpp_degree} \ --sequence_parallel ${sequence_parallel} \ --split 949,50,1 \ --max_seq_length 1024 \ - --seed 1234 \ + --seed ${seed} \ --fuse_attention_qkv True \ --use_flash_attention True \ --bf16 ${bf16} \ @@ -125,7 +128,7 @@ function _train(){ --dataloader_num_workers 1 \ --eval_steps 1000 \ --disable_tqdm True \ - --gradient_accumulation_steps 2 \ + --gradient_accumulation_steps ${acc} \ --weight_decay 0.01\ --max_steps ${max_iter}\ --save_steps 5000\ @@ -150,7 +153,7 @@ function _train(){ run_pretrain.py ${train_cmd}" workerlog_id=0 ;; - DP8-mbs2-acc2|SD8-stage1-mbs2-acc2|SD8-stage2-mbs2-acc2|SD8-stage3-mbs2-acc2|MP2-SD4-stage1-mbs4-acc2|MP2-SP2-PP2-DP2-mbs8-acc2|MP8-mbs16-acc2|MP2-PP2-DP2-mbs8-acc2|MP2-PP2-SD2-Stage1-mbs8-acc2|MP2-SP2-PP2-SD2-Stage1-mbs8-acc2) echo "run run_mode: ${run_mode}" + DP8-mbs2-acc2|SD8-stage1-mbs2-acc2|SD8-stage2-mbs2-acc2|SD8-stage3-mbs2-acc2|MP2-SD4-stage1-mbs4-acc2|MP2-SP2-PP2-DP2-mbs8-acc2|MP8-mbs16-acc2|MP2-PP2-DP2-mbs8-acc2|MP2-PP2-SD2-Stage1-mbs8-acc2|MP2-SP2-PP2-SD2-Stage1-mbs8-acc2|MP2-SP2-PP4-VPP2-SD2-stage1-mbs2-acc8) echo "run run_mode: ${run_mode}" train_cmd="python -m paddle.distributed.launch --log_dir=./mylog --devices=0,1,2,3,4,5,6,7 ${PADDLE_RANK_OPTION}\ run_pretrain.py ${train_cmd}" workerlog_id=0 From dab175b58bfcb4f6dfc3d00891efc40e2ab5d9d7 Mon Sep 17 00:00:00 2001 From: Siming Dai <908660116@qq.com> Date: Tue, 9 Jan 2024 15:56:38 +0800 Subject: [PATCH 9/9] [Pretrain] Fix eval during pretrain (#7806) * fix eval during pretrain --- llm/run_pretrain.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index 00f8928d2ead..377a3666fd45 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -265,8 +265,8 @@ def _collate_data(data, stack_fn=Stack()): tokens = tokens_[:, :-1] return { - "input_ids": tokens, - "labels": labels, + "input_ids": paddle.to_tensor(tokens), + "labels": paddle.to_tensor(labels), } if need_data: