Skip to content

Commit

Permalink
merge develop
Browse files Browse the repository at this point in the history
  • Loading branch information
RichardWooSJTU committed Jan 9, 2024
2 parents afb1dd0 + dab175b commit ed6db33
Show file tree
Hide file tree
Showing 27 changed files with 1,000 additions and 30 deletions.
23 changes: 21 additions & 2 deletions llm/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,8 +1304,19 @@ def create_predictor(
dtype=predictor_args.dtype,
)
model.eval()
elif "qwen" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
QWenForCausalLMInferenceModel,
)

model = QWenForCausalLMInferenceModel.from_pretrained(
predictor_args.model_name_or_path,
config=config,
dtype=predictor_args.dtype,
)
model.eval()
else:
raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]")
if predictor_args.block_attn:
predictor = DygraphBlockInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
else:
Expand Down Expand Up @@ -1369,8 +1380,16 @@ def create_predictor(
cache_kvs_shape = GPTForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
elif "qwen" in config.architectures[0].lower():
from paddlenlp.experimental.transformers import (
QWenForCausalLMInferenceModel,
)

cache_kvs_shape = QWenForCausalLMInferenceModel.get_cache_kvs_shape(
config, predictor_args.batch_size, predictor_args.total_max_length
)
else:
raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt]")
raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]")
if predictor_args.block_attn:
predictor = StaticBlockInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
else:
Expand Down
4 changes: 2 additions & 2 deletions llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ def _collate_data(data, stack_fn=Stack()):
tokens = tokens_[:, :-1]

return {
"input_ids": tokens,
"labels": labels,
"input_ids": paddle.to_tensor(tokens),
"labels": paddle.to_tensor(labels),
}

if need_data:
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/experimental/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@
from .gpt import *
from .llama import *
from .opt import *
from .qwen import *
15 changes: 15 additions & 0 deletions paddlenlp/experimental/transformers/qwen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling import *
Loading

0 comments on commit ed6db33

Please sign in to comment.