From 8fbbd2ac86e9a0d02933dc668d648bfbf8e145ab Mon Sep 17 00:00:00 2001
From: feifeibear
Date: Wed, 19 Aug 2020 17:09:34 +0800
Subject: [PATCH] Develop (#154)

Add a TF ckpt converter for BERT (#159).
The memory of AddSplitTranspose is not contiguous for the Q, K, V outputs.
Fix BERT profiler bugs.
Update the GPU fixed-length benchmark scripts.
Add protection for SplitTranspose's QKV outputs.
---
 CMakeLists.txt                                 |   2 +-
 benchmark/run_gpu_fixed_benchmark.sh           |   2 +-
 example/cpp/bert_model.cpp                     |  10 +-
 ...ert_for_sequence_classification_example.py  |  85 +++++++-----
 tools/convert_tf_bert_to_npz.py                | 119 ++++++++++++++++
 tools/docker/Dockerfile_release.cpu            |   2 +-
 turbo_transformers/layers/bert_embedding.cpp   |   1 -
 .../layers/bert_intermediate.cpp               |   2 +-
 turbo_transformers/layers/bert_output.cpp      |   2 +-
 .../layers/kernels/gpu_transpose_kernel.cu     |  60 ++++++++
 .../layers/kernels/gpu_transpose_kernel.h      |   7 +
 .../layers/kernels/transpose.cpp               | 110 ++++++++++++++-
 turbo_transformers/layers/kernels/transpose.h  |   9 ++
 .../layers/multi_headed_attention.cpp          |  35 +++--
 .../python/tests/qbert_intermediate_test.py    |   4 +-
 .../python/tests/qbert_output_test.py          |   2 -
 .../turbo_transformers/layers/__init__.py      |   1 -
 .../layers/qmodeling_bert.py                   | 128 +++++++++++++----
 18 files changed, 477 insertions(+), 104 deletions(-)
 create mode 100644 tools/convert_tf_bert_to_npz.py

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 05c7bc44..bb47f2d0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,7 +21,7 @@ set(CMAKE_CXX_STANDARD 14)
 set(CMAKE_CXX_FLAGS "-Wall")
 set(CMAKE_C_FLAGS "-Wall")
-set(TURBO_TRANSFORMERS_VERSION 0.4.1)
+set(TURBO_TRANSFORMERS_VERSION 0.4.2)
 option(WITH_PROFILER "Compile with profiler" OFF)
 option(WITH_GPU "Build with GPU" OFF)

diff --git a/benchmark/run_gpu_fixed_benchmark.sh b/benchmark/run_gpu_fixed_benchmark.sh
index 3bcade10..c27e9841 100644
--- a/benchmark/run_gpu_fixed_benchmark.sh
+++ b/benchmark/run_gpu_fixed_benchmark.sh
@@ -29,7 +29,7 @@ do
     for framework in ${FRAMEWORKS[*]}
     do
       python benchmark.py ${MODEL} --seq_len=${seq_len} --batch_size=${batch_size}\
-          -n ${N} --framework=${framework}
+          -n ${N} --framework=${framework} --use_gpu
     done
   done
 done

diff --git a/example/cpp/bert_model.cpp b/example/cpp/bert_model.cpp
index 72fddaa1..96fe7dfc 100644
--- a/example/cpp/bert_model.cpp
+++ b/example/cpp/bert_model.cpp
@@ -211,18 +211,18 @@ struct BertModel::Impl {
       layer(hidden, extendedAttentionMask, &attOut, &intermediateOut, &hidden);
     }
-    core::Tensor poolingOutput(nullptr);
-    layers::SequencePool(static_cast(pooling))(
-        hidden, &poolingOutput);
     std::vector<float> vec;
     if (use_pooler) {
       core::Tensor output(nullptr);
+      core::Tensor poolingOutput(nullptr);
+      layers::SequencePool(static_cast(pooling))(
+          hidden, &poolingOutput);
       (*pooler_)(poolingOutput, &output);
       vec.resize(output.numel());
       core::Copy(output, vec);
     } else {
-      vec.resize(poolingOutput.numel());
-      core::Copy(poolingOutput, vec);
+      vec.resize(hidden.numel());
+      core::Copy(hidden, vec);
     }
     return vec;

diff --git a/example/python/bert_for_sequence_classification_example.py b/example/python/bert_for_sequence_classification_example.py
index 1cfbdedc..d4f63696 100644
--- a/example/python/bert_for_sequence_classification_example.py
+++ b/example/python/bert_for_sequence_classification_example.py
@@ -19,46 +19,51 @@
 # import the class of the acceleration model. here is the example of BertForSequenceClassification.
from transformers.modeling_bert import BertModel as TorchBertModel from transformers import BertTokenizer -from transformers.modeling_bert import BertForSequenceClassification as TorchBertForSequenceClassification +from transformers.modeling_bert import ( + BertForSequenceClassification as TorchBertForSequenceClassification, +) import os import torch from typing import Optional -#TODO(jiarufang) developed under v0.1.0, after that not tested. -#Contact me if you find it is wrong. +# TODO(jiarufang) developed under v0.1.0, after that not tested. +# Contact me if you find it is wrong. class BertForSequenceClassification: # create a new class for speeding up def __init__( - self, bertmodel, classifier + self, bertmodel, classifier ): # the realization of the init function(we can just copy it) self.bert = bertmodel self.classifier = classifier def __call__( - self, # the realization of the call function(we can just copy it) - inputs, - attention_masks=None, - token_type_ids=None, - position_ids=None, - pooling_type=PoolingType.FIRST, - return_type=None): - pooler_output, _, _ = self.bert(inputs, - attention_masks, - token_type_ids, - position_ids, - pooling_type, - return_type=ReturnType.TORCH) + self, # the realization of the call function(we can just copy it) + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + pooling_type=PoolingType.FIRST, + return_type=None, + ): + bert_outputs = self.bert( + input_ids, + attention_mask, + token_type_ids, + position_ids, + pooling_type, + return_type=ReturnType.TORCH, + ) + pooled_output = bert_outputs[1] logits = self.classifier( - pooler_output + pooled_output ) # It's the output of classifier, if User want to output the other type, he can define them after that. return logits @staticmethod def from_torch( - model: TorchBertModel, # from_torch函数实现 - device: Optional[torch.device] = None): - if device is not None and 'cuda' in device.type and torch.cuda.is_available( - ): + model: TorchBertModel, device: Optional[torch.device] = None # from_torch函数实现 + ): + if device is not None and "cuda" in device.type and torch.cuda.is_available(): model.to(device) bertmodel = turbo_transformers.BertModel.from_torch(model.bert) # We can copy the following code and do not change it @@ -67,11 +72,11 @@ def from_torch( return BertForSequenceClassification(bertmodel, model.classifier) @staticmethod - def from_pretrained(model_id_or_path: str, - device: Optional[torch.device] = None): + def from_pretrained(model_id_or_path: str, device: Optional[torch.device] = None): # First, Use the function of from_pretrained to load the model you trained. torch_model = TorchBertForSequenceClassification.from_pretrained( - model_id_or_path) + model_id_or_path + ) # Then, Use the init function of the acceleration model to get it. model = BertForSequenceClassification.from_torch(torch_model, device) model._torch_model = torch_model # prevent destroy torch model. 
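A minimal GPU usage sketch of the wrapper defined in this file (illustration only, not a hunk of this patch; it assumes a CUDA build of turbo_transformers, i.e. WITH_GPU=ON, and a local fine-tuned checkpoint directory; the directory name "bert_model" and the sample sentence are placeholders):

    import torch
    from transformers import BertTokenizer

    model_id = "bert_model"  # placeholder: directory with a fine-tuned BertForSequenceClassification checkpoint
    tokenizer = BertTokenizer.from_pretrained(model_id)
    # from_pretrained()/from_torch() accept an optional device; passing a CUDA device
    # moves the torch weights there before the turbo model is built.
    turbo_model = BertForSequenceClassification.from_pretrained(
        model_id, torch.device("cuda:0"))
    inputs = tokenizer.encode_plus("Sample input text",
                                   add_special_tokens=True,
                                   return_tensors="pt")
    # keep the input tensors on the same device as the model
    inputs = {k: v.to("cuda:0") for k, v in inputs.items()}
    logits = turbo_model(**inputs)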
@@ -82,18 +87,24 @@ def from_pretrained(model_id_or_path: str, turbo_transformers.set_num_threads(4) model_id = os.path.join( - os.path.dirname(__file__), - 'test-seq-classification-model') # the model of huggingface's path -tokenizer = BertTokenizer.from_pretrained( - model_id) # the initialization of tokenizer + os.path.dirname(__file__), "bert_model" +) # the model of huggingface's path +tokenizer = BertTokenizer.from_pretrained(model_id) # the initialization of tokenizer turbo_model = BertForSequenceClassification.from_pretrained( - model_id, - torch.device('cpu:0')) # the initialization of the acceleration model + model_id, torch.device("cpu:0") +) # the initialization of the acceleration model # predict after loading the model -input_ids = torch.tensor( - tokenizer.encode('测试一下bert模型的性能和精度是不是符合要求?', - add_special_tokens=True)).unsqueeze(0) -torch_result = turbo_model(input_ids) -print(torch_result) -# tensor([[ 0.1451, -0.0373]], grad_fn=) + +text = "Sample input text" +inputs = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors="pt") +# turbo_result holds the returned logits from TurboTransformers model +turbo_result = turbo_model(**inputs) + +torch_model = TorchBertForSequenceClassification.from_pretrained(model_id) +# torch_result holds the returned logits from original Transformers model +torch_result = torch_model(**inputs)[0] +print(turbo_result) +# tensor([[0.2716, 0.0318]], grad_fn=) +print(torch_result) # torch_result and turbo_result should hold the same logits +# tensor([[0.2716, 0.0318]], grad_fn=) diff --git a/tools/convert_tf_bert_to_npz.py b/tools/convert_tf_bert_to_npz.py new file mode 100644 index 00000000..4a4ba370 --- /dev/null +++ b/tools/convert_tf_bert_to_npz.py @@ -0,0 +1,119 @@ +# Copyright (C) 2020 THL A29 Limited, a Tencent company. +# All rights reserved. +# Licensed under the BSD 3-Clause License (the "License"); you may +# not use this file except in compliance with the License. You may +# obtain a copy of the License at +# https://opensource.org/licenses/BSD-3-Clause +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" basis, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. +# See the AUTHORS file for names of contributors. 
+ +from transformers import BertConfig +try: + import tensorflow as tf +except ImportError: + print("please install tensorflow 2.0 by run `pip install tensorflow`") +import numpy as np +import sys +import os + + +# User should define the map between tf model's layer name to tt model's layer name +def build_dic(num_layers): + dic = { + 'bert/embeddings/word_embeddings': + 'embeddings.word_embeddings.weight', + 'bert/embeddings/position_embeddings': + 'embeddings.position_embeddings.weight', + 'bert/embeddings/token_type_embeddings': + 'embeddings.token_type_embeddings.weight', + 'bert/embeddings/LayerNorm/gamma': + 'embeddings.LayerNorm.weight', + 'bert/embeddings/LayerNorm/beta': + 'embeddings.LayerNorm.bias', + 'bert/pooler/dense/kernel': 'pooler.dense.weight', + 'bert/pooler/dense/bias': 'pooler.dense.bias' + } + + for i in range(num_layers): + dic[f'bert/encoder/layer_{i}/attention/self/query/kernel'] = f'encoder.layer.{i}.attention.self.query.weight' + dic[f'bert/encoder/layer_{i}/attention/self/query/bias'] = f'encoder.layer.{i}.attention.self.query.bias' + dic[f'bert/encoder/layer_{i}/attention/self/key/kernel'] = f'encoder.layer.{i}.attention.self.key.weight' + dic[f'bert/encoder/layer_{i}/attention/self/key/bias'] = f'encoder.layer.{i}.attention.self.key.bias' + dic[f'bert/encoder/layer_{i}/attention/self/value/kernel'] = f'encoder.layer.{i}.attention.self.value.weight' + dic[f'bert/encoder/layer_{i}/attention/self/value/bias'] = f'encoder.layer.{i}.attention.self.value.bias' + dic[f'bert/encoder/layer_{i}/attention/output/dense/kernel'] = f'encoder.layer.{i}.attention.output.dense.weight' + dic[f'bert/encoder/layer_{i}/attention/output/dense/bias'] = f'encoder.layer.{i}.attention.output.dense.bias' + dic[f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'] = f'encoder.layer.{i}.attention.output.LayerNorm.weight' + dic[f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'] = f'encoder.layer.{i}.attention.output.LayerNorm.bias' + dic[f'bert/encoder/layer_{i}/intermediate/dense/kernel'] = f'encoder.layer.{i}.intermediate.dense.weight' + dic[f'bert/encoder/layer_{i}/intermediate/dense/bias'] = f'encoder.layer.{i}.intermediate.dense.bias' + dic[f'bert/encoder/layer_{i}/output/dense/kernel'] = f'encoder.layer.{i}.output.dense.weight' + dic[f'bert/encoder/layer_{i}/output/dense/bias'] = f'encoder.layer.{i}.output.dense.bias' + dic[f'bert/encoder/layer_{i}/output/LayerNorm/gamma'] = f'encoder.layer.{i}.output.LayerNorm.weight' + dic[f'bert/encoder/layer_{i}/output/LayerNorm/beta'] = f'encoder.layer.{i}.output.LayerNorm.bias' + return dic + + +def trans_layer_name_tf2turbo(dic, name): + return dic[name] + + +def main(): + if len(sys.argv) != 3: + print( + "Usage: \n" + " convert_tf_bert_to_npz.py model_name output_file") + exit(0) + model_path = sys.argv[1] + ckpt_path = os.path.join(model_path, "bert_model.ckpt") + cfg = BertConfig.from_pretrained(os.path.join(model_path, "bert_config.json")) + dic = build_dic(cfg.num_hidden_layers) + names = [v[0] for v in tf.train.list_variables(ckpt_path)] + + arrays = {} + for i in range(len(names)): + if names[i].startswith("cls"): + continue + arrays[trans_layer_name_tf2turbo(dic, names[i])] = tf.train.load_variable(ckpt_path, names[i]) + + q_weight_key = 'self.query.weight' + k_weight_key = 'self.key.weight' + v_weight_key = 'self.value.weight' + + q_bias_key = 'self.query.bias' + k_bias_key = 'self.key.bias' + v_bias_key = 'self.value.bias' + + numpy_dict = {} + + for k in arrays.keys(): + if k.endswith(q_weight_key): + ret = [] + 
ret.append(arrays[k]) + ret.append(arrays[k[:-len(q_weight_key)] + k_weight_key]) + ret.append(arrays[k[:-len(q_weight_key)] + v_weight_key]) + v = np.concatenate(ret, axis=1) + numpy_dict[k[:-len(q_weight_key)] + + "qkv.weight"] = np.ascontiguousarray(v) + elif k.endswith(q_bias_key): + ret = [] + ret.append(arrays[k]) + ret.append(arrays[k[:-len(q_bias_key)] + k_bias_key]) + ret.append(arrays[k[:-len(q_bias_key)] + v_bias_key]) + v = np.ascontiguousarray(np.concatenate(ret, axis=0)) + numpy_dict[k[:-len(q_bias_key)] + 'qkv.bias'] = v + elif any((k.endswith(suffix) for suffix in (k_weight_key, v_weight_key, + k_bias_key, v_bias_key))): + continue + else: + numpy_dict[k] = np.ascontiguousarray(arrays[k]) + + np.savez_compressed(sys.argv[2], **numpy_dict) + + +if __name__ == '__main__': + main() diff --git a/tools/docker/Dockerfile_release.cpu b/tools/docker/Dockerfile_release.cpu index a22a2b27..4b96e547 100644 --- a/tools/docker/Dockerfile_release.cpu +++ b/tools/docker/Dockerfile_release.cpu @@ -14,5 +14,5 @@ RUN /opt/conda/bin/conda install pytorch==1.5.0 cpuonly -c pytorch && \ /opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \ /opt/conda/bin/conda clean -afy -RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt +RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt onnxruntime-tools WORKDIR /workspace diff --git a/turbo_transformers/layers/bert_embedding.cpp b/turbo_transformers/layers/bert_embedding.cpp index f080599d..d5a353ff 100644 --- a/turbo_transformers/layers/bert_embedding.cpp +++ b/turbo_transformers/layers/bert_embedding.cpp @@ -18,7 +18,6 @@ #include "turbo_transformers/layers/kernels/embedding.h" #include "turbo_transformers/layers/kernels/layer_norm.h" - namespace turbo_transformers { namespace layers { diff --git a/turbo_transformers/layers/bert_intermediate.cpp b/turbo_transformers/layers/bert_intermediate.cpp index 01df8d2e..02f9ee51 100644 --- a/turbo_transformers/layers/bert_intermediate.cpp +++ b/turbo_transformers/layers/bert_intermediate.cpp @@ -45,7 +45,7 @@ void BertIntermediate::operator()(const core::Tensor& input_tensor, kernels::AddBiasAct( dense_bias_, output_tensor, "BertIntermediate/AddBiasAct"); #ifdef WITH_PERFTOOLS - profile_ctx.end_profile("BertIntermediate"); + profile_ctx.end_profile("BertIntermediate", input_tensor.device_type()); #endif } diff --git a/turbo_transformers/layers/bert_output.cpp b/turbo_transformers/layers/bert_output.cpp index 636d29c4..62e3852b 100644 --- a/turbo_transformers/layers/bert_output.cpp +++ b/turbo_transformers/layers/bert_output.cpp @@ -48,7 +48,7 @@ void BertOutput::operator()(const core::Tensor &hidden_states, input_tensor, dense_bias_, layer_norm_weight_, layer_norm_bias_, output_tensor, 1e-12, "BertOutput/AddBiasLayerNorm"); #ifdef WITH_PERFTOOLS - profile_ctx.end_profile("BertOutput"); + profile_ctx.end_profile("BertOutput", input_tensor.device_type()); #endif } diff --git a/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu b/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu index 5d6b65ff..0d8c484b 100644 --- a/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu +++ b/turbo_transformers/layers/kernels/gpu_transpose_kernel.cu @@ -73,6 +73,66 @@ void GPUSplitAddBiasTransposeForScore( weight_num, size_per_head, out_data); } +/* + Output transpose results into three tensors +*/ +static __global__ void split_add_bias_transpose_for_score_3output( + const float* input_data, const float* bias_data, 
const int batch_size,
+    const int seq_len, const int head_num, const int weight_num,
+    const int size_per_head, float* q_output_data, float* k_output_data,
+    float* v_output_data) {
+  int tid = threadIdx.x;
+  int bid = blockIdx.x;
+  int idx = tid;
+  int batch_id = bid / (seq_len * weight_num * head_num);
+  int seq_id =
+      bid % (seq_len * weight_num * head_num) / (weight_num * head_num);
+  int weight_id = bid % (weight_num * head_num) / head_num;
+  int head_id = bid % head_num;
+
+  int head_num_size_per_head = head_num * size_per_head;
+  int weight_id_head_num_size_per_head = weight_id * head_num_size_per_head;
+  int head_id_size_per_head = head_id * size_per_head;
+
+  float* output_data = nullptr;
+  if (weight_id == 0) {
+    output_data = q_output_data;
+  } else if (weight_id == 1) {
+    output_data = k_output_data;
+  } else if (weight_id == 2) {
+    output_data = v_output_data;
+  }
+
+  while (idx < size_per_head) {
+    float bias_val = bias_data[weight_id_head_num_size_per_head +
+                               head_id_size_per_head + idx];
+    output_data[batch_id * seq_len * head_num_size_per_head +
+                head_id * seq_len * size_per_head + seq_id * size_per_head +
+                idx] =
+        input_data[batch_id * seq_len * weight_num * head_num_size_per_head +
+                   seq_id * weight_num * head_num_size_per_head +
+                   weight_id_head_num_size_per_head + head_id_size_per_head +
+                   idx] +
+        bias_val;
+    idx += blockDim.x;
+  }
+}
+
+template <>
+void GPUSplitAddBiasTransposeForScoreThreeOutput(
+    const float* input_data, const float* bias_data, int64_t batch_size,
+    int64_t seq_len, int64_t weight_num, int64_t num_attention_heads,
+    int64_t size_per_head, cudaStream_t stream, float* q_out_data,
+    float* k_out_data, float* v_out_data) {
+  const int n = size_per_head;
+  const int m = batch_size * seq_len * num_attention_heads * weight_num;
+  dim3 grid(m);
+  dim3 block(min(n, 1024));
+  split_add_bias_transpose_for_score_3output<<<grid, block, 0, stream>>>(
+      input_data, bias_data, batch_size, seq_len, num_attention_heads,
+      weight_num, size_per_head, q_out_data, k_out_data, v_out_data);
+}
+
 namespace {
 // batch, head, seq, size_per_head -> batch head seq size_per_head

diff --git a/turbo_transformers/layers/kernels/gpu_transpose_kernel.h b/turbo_transformers/layers/kernels/gpu_transpose_kernel.h
index 74285418..0b56fb61 100644
--- a/turbo_transformers/layers/kernels/gpu_transpose_kernel.h
+++ b/turbo_transformers/layers/kernels/gpu_transpose_kernel.h
@@ -25,6 +25,13 @@ void GPUSplitAddBiasTransposeForScore(const T* input_data, const T* bias_data,
                                       int64_t size_per_head, cudaStream_t stream);

+template <typename T>
+void GPUSplitAddBiasTransposeForScoreThreeOutput(
+    const T* input_data, const T* bias_data, int64_t batch_size,
+    int64_t seq_len, int64_t weight_num, int64_t num_attention_heads,
+    int64_t size_per_head, cudaStream_t stream, T* q_out_data, T* k_out_data,
+    T* v_out_data);
+
 template <typename T>
 void GPUTransposeForScore(const T* input_data, const T* bias,
                           int64_t batch_size, int64_t seq_len,

diff --git a/turbo_transformers/layers/kernels/transpose.cpp b/turbo_transformers/layers/kernels/transpose.cpp
index e8464099..49969eac 100644
--- a/turbo_transformers/layers/kernels/transpose.cpp
+++ b/turbo_transformers/layers/kernels/transpose.cpp
@@ -84,7 +84,8 @@ void TransposeForScore(core::Tensor* output, const core::Tensor& input,
   profile_ctx.start_profile(name, input.device_type());
 #endif
   TT_ENFORCE_EQ(input.n_dim(), 4, "input should be a 4-D tensor");
-  TT_ENFORCE_GE(output->n_dim(), 3, "output tensor dim should be greater than 3");
+  TT_ENFORCE_GE(output->n_dim(), 3,
+                "output tensor dim should be 
greater than 3"); TT_ENFORCE_EQ(input.numel(), output->numel(), "input.numel() and output.numel() should be the same"); if (input.device_type() == kDLCPU && output->device_type() == kDLCPU) { @@ -158,9 +159,6 @@ void SplitAddBiasTransposeForScore(core::Tensor* output_tensor, TT_ENFORCE_EQ(output_tensor->n_dim(), 5, "output_tensor should be (weight_num, batch_size, seq_length, " "num_attention_heads, size_per_head)"); - // TT_ENFORCE_EQ(bias_tensor.n_dim(), 1, - // "output_tensor should be (weight_num * num_attention_heads, " - // "size_per_head)"); auto batch_size = output_tensor->shape(1); auto seq_length = output_tensor->shape(3); @@ -228,6 +226,110 @@ void SplitAddBiasTransposeForScore(core::Tensor* output_tensor, #endif } +// input_tensor: 4D array (batch_size, seq_length, 3, head_num * size_per_head) +void SplitAddBiasTransposeForScore(const core::Tensor& input_tensor, + const core::Tensor& bias_tensor, + core::Tensor& q_out_tensor, + core::Tensor& k_out_tensor, + core::Tensor& v_out_tensor, + const std::string name) { +#ifdef WITH_PERFTOOLS + auto& profile_ctx = core::Profiler::GetInstance(); + profile_ctx.start_profile(name, input_tensor.device_type()); +#endif + + TT_ENFORCE_EQ(input_tensor.n_dim(), 4, + "output_tensor should be (batch_size, seq_length, " + "num_attention_heads * size_per_head)"); + + auto batch_size = input_tensor.shape(0); + auto seq_length = input_tensor.shape(1); + auto weight_num = 3; + auto num_attention_heads = q_out_tensor.shape(1); + auto width = input_tensor.shape(3) / num_attention_heads; + auto input = input_tensor.data(); + auto bias = bias_tensor.data(); + auto q_out = q_out_tensor.mutableData(); + auto k_out = k_out_tensor.mutableData(); + auto v_out = v_out_tensor.mutableData(); + + TT_ENFORCE_EQ(common::is_same_device_ctx(input_tensor.device_ctx(), + bias_tensor.device_ctx()), + true, + "SplitAddBiasTransposeForScore: input_tensor and bias_tensor " + "should have the same device type and device id."); + TT_ENFORCE_EQ(common::is_same_device_ctx(input_tensor.device_ctx(), + q_out_tensor.device_ctx()), + true, + "SplitAddBiasTransposeForScore: input_tensor and q_out_tensor " + "should have the same device type and device id."); + TT_ENFORCE_EQ(q_out_tensor.numel(), input_tensor.numel() / 3, + "numel of q_out_tensor should 1/3 of input tensor"); + TT_ENFORCE_EQ(k_out_tensor.numel(), input_tensor.numel() / 3, + "numel of q_out_tensor should 1/3 of input tensor"); + TT_ENFORCE_EQ(v_out_tensor.numel(), input_tensor.numel() / 3, + "numel of q_out_tensor should 1/3 of input tensor"); + if (q_out_tensor.device_type() == kDLCPU && + input_tensor.device_type() == kDLCPU && + bias_tensor.device_type() == kDLCPU) { +#pragma omp parallel for + for (int64_t idx = 0; idx < batch_size * weight_num * seq_length; ++idx) { + auto batch_idx = idx / (seq_length * weight_num); + auto seq_idx = idx / weight_num % seq_length; + auto weight_idx = idx % weight_num; + + for (int64_t head_idx = 0; head_idx < num_attention_heads; ++head_idx) { + auto* src_ptr = + input + + batch_idx * + (seq_length * weight_num * num_attention_heads * width) + + seq_idx * weight_num * num_attention_heads * width + + weight_idx * (num_attention_heads * width) + head_idx * width; + float* dst_ptr = nullptr; + switch (weight_idx) { + case 0: + dst_ptr = q_out + + batch_idx * (num_attention_heads * seq_length * width) + + head_idx * seq_length * width + seq_idx * width; + break; + case 1: + dst_ptr = k_out + + batch_idx * (num_attention_heads * seq_length * width) + + head_idx * seq_length * width + 
seq_idx * width; + break; + case 2: + dst_ptr = v_out + + batch_idx * (num_attention_heads * seq_length * width) + + head_idx * seq_length * width + seq_idx * width; + break; + default: + break; + } + auto* bias_ptr = + bias + weight_idx * width * num_attention_heads + head_idx * width; +#pragma omp simd + for (int64_t width_idx = 0; width_idx < width; ++width_idx) { + dst_ptr[width_idx] = src_ptr[width_idx] + bias_ptr[width_idx]; + } + } + } // end for + } else if (q_out_tensor.device_type() == kDLGPU && + input_tensor.device_type() == kDLGPU && + bias_tensor.device_type() == kDLGPU) { +#ifdef TT_WITH_CUDA + core::CUDADeviceContext& cuda_ctx = core::CUDADeviceContext::GetInstance(); + GPUSplitAddBiasTransposeForScoreThreeOutput( + input, bias, batch_size, seq_length, weight_num, num_attention_heads, + width, cuda_ctx.stream(), q_out, k_out, v_out); +#endif + } else { + TT_THROW("device_type is not supported"); + } +#ifdef WITH_PERFTOOLS + profile_ctx.end_profile(name, input_tensor.device_type()); +#endif +} + } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/kernels/transpose.h b/turbo_transformers/layers/kernels/transpose.h index fb8e0f7b..6b9d625a 100644 --- a/turbo_transformers/layers/kernels/transpose.h +++ b/turbo_transformers/layers/kernels/transpose.h @@ -44,11 +44,20 @@ extern void AddBiasTransposeForScore( // input: (batch_size, seq_length, 3, head_num, *size_per_head) // bias: (3, head_num, size_per_head) // output: (3, batch_size, num_attention_heads, seq_length, size_per_head) +// TODO(jiaruifang) the output is a tensor contains a continous memory space, +// which stores q_out, k_out, v_out. Because the lifetime of q_out, k_out and +// v_out are different, we should seperate them into 3 memory space extern void SplitAddBiasTransposeForScore( core::Tensor* output, const core::Tensor& input_tensor, const core::Tensor& bias_tensor, const std::string name = "SplitAddBiasTransposeForScore"); +// A API friendly to variable-length input memory allocations. +extern void SplitAddBiasTransposeForScore( + const core::Tensor& input_tensor, const core::Tensor& bias_tensor, + core::Tensor& q_out, core::Tensor& k_out, core::Tensor& v_out, + const std::string name = "SplitAddBiasTransposeForScore"); + } // namespace kernels } // namespace layers } // namespace turbo_transformers diff --git a/turbo_transformers/layers/multi_headed_attention.cpp b/turbo_transformers/layers/multi_headed_attention.cpp index 49788807..9eb27651 100644 --- a/turbo_transformers/layers/multi_headed_attention.cpp +++ b/turbo_transformers/layers/multi_headed_attention.cpp @@ -77,8 +77,9 @@ void MultiHeadedAttention::operator()( auto devtype = query_tensor.device_type(); auto devid = query_tensor.device_id(); - // TODO we should caching allocate intermediate tensor. + // TODO we should caching allocated intermediate tensor. core::Tensor *q_ptr{nullptr}, *k_ptr{nullptr}, *v_ptr{nullptr}; + core::Tensor q_out{nullptr}, k_out{nullptr}, v_out{nullptr}; core::Tensor q_out1(nullptr); core::Tensor v_out1(nullptr); core::Tensor k_out1(nullptr); @@ -86,7 +87,6 @@ void MultiHeadedAttention::operator()( core::Tensor v_out2(nullptr); core::Tensor k_out2(nullptr); core::Tensor qkv_out1(nullptr); - core::Tensor qkv_out2(nullptr); bool layer_cache_not_none = layer_cache.size() > 0 ? 
true : false; bool memory_keys_not_none = false, memory_values_not_none = false, @@ -199,7 +199,7 @@ void MultiHeadedAttention::operator()( } } // else } else if (attn_type == "self") { - qkv_out1.Reshape({3, batch_size, query_seq_length, hidden_size}, + qkv_out1.Reshape({batch_size, query_seq_length, 3, hidden_size}, devtype, devid, "self/qkv_out1/Reshape"); if (pre_layernorm) { core::Tensor layernormed_query(nullptr); @@ -215,26 +215,33 @@ void MultiHeadedAttention::operator()( kernels::MatMul(query_tensor, false, qkv_weight_, is_trans_weight, 1.0, &qkv_out1, 0.0, "self/gemm012_fused"); } - qkv_out2.Reshape( - {3, batch_size, num_attention_heads_, query_seq_length, size_per_head}, - devtype, devid, "self/qkv_out2/Reshape"); + q_out.Reshape( + {batch_size, num_attention_heads_, query_seq_length, size_per_head}, + devtype, devid, "self/q/Reshape"); + k_out.Reshape( + {batch_size, num_attention_heads_, query_seq_length, size_per_head}, + devtype, devid, "self/k/Reshape"); + v_out.Reshape( + {batch_size, num_attention_heads_, query_seq_length, size_per_head}, + devtype, devid, "self/v/Reshape"); + kernels::SplitAddBiasTransposeForScore( - &qkv_out2, qkv_out1, qkv_bias_, "self/SplitAddBiasTransposeForScore"); - q_ptr = - new core::Tensor(qkv_out2[0]); // copy temporary tensor to heap space. + qkv_out1, qkv_bias_, q_out, k_out, v_out, + "self/SplitAddBiasTransposeForScore"); + q_ptr = &q_out; if (self_keys_not_none) { - kernels::Concat(*layer_cache["self_keys"], qkv_out2[1], 2, &k_out2, + kernels::Concat(*layer_cache["self_keys"], k_out, 2, &k_out2, "self/keys/Concat"); k_ptr = &k_out2; } else { - k_ptr = new core::Tensor(qkv_out2[1]); + k_ptr = &k_out; } if (self_values_not_none) { - kernels::Concat(*layer_cache["self_values"], qkv_out2[2], 2, - &v_out2, "self/values/Concat"); + kernels::Concat(*layer_cache["self_values"], v_out, 2, &v_out2, + "self/values/Concat"); v_ptr = &v_out2; } else { - v_ptr = new core::Tensor(qkv_out2[2]); + v_ptr = &v_out; } if (layer_cache_not_none) { layer_cache["self_keys"]->Reshape( diff --git a/turbo_transformers/python/tests/qbert_intermediate_test.py b/turbo_transformers/python/tests/qbert_intermediate_test.py index db4bf722..19fda2f6 100644 --- a/turbo_transformers/python/tests/qbert_intermediate_test.py +++ b/turbo_transformers/python/tests/qbert_intermediate_test.py @@ -4,7 +4,6 @@ from turbo_transformers.layers.utils import convert2tt_tensor, try_convert, convert_returns_as_type, ReturnType import time - cfg = transformers.BertConfig() model = transformers.BertModel(cfg) model.eval() @@ -33,5 +32,4 @@ end = time.time() print("turbo int8 layer QPS =", loops/(end-start)) -assert torch.max(torch.abs(res-res2)) < 1e-3 - +assert torch.max(torch.abs(res-res2)) < 1e-3 \ No newline at end of file diff --git a/turbo_transformers/python/tests/qbert_output_test.py b/turbo_transformers/python/tests/qbert_output_test.py index f863ac44..46e43340 100644 --- a/turbo_transformers/python/tests/qbert_output_test.py +++ b/turbo_transformers/python/tests/qbert_output_test.py @@ -4,7 +4,6 @@ from turbo_transformers.layers.utils import convert2tt_tensor, try_convert, convert_returns_as_type, ReturnType import time - cfg = transformers.BertConfig() model = transformers.BertModel(cfg) model.eval() @@ -35,4 +34,3 @@ print("turbo int8 layer QPS =", loops/(end-start)) assert torch.max(torch.abs(res-res2)) < 1e-3 - diff --git a/turbo_transformers/python/turbo_transformers/layers/__init__.py b/turbo_transformers/python/turbo_transformers/layers/__init__.py index daf56457..74f992c8 
100644 --- a/turbo_transformers/python/turbo_transformers/layers/__init__.py +++ b/turbo_transformers/python/turbo_transformers/layers/__init__.py @@ -31,5 +31,4 @@ 'PositionwiseFeedForward', 'TransformerDecoderLayer', 'TransformerDecoder', 'RobertaModel', 'QBertIntermediate', 'QBertOutput', 'QBertLayer', 'QBertEncoder', 'QBertModel', 'GPT2Model' - ] diff --git a/turbo_transformers/python/turbo_transformers/layers/qmodeling_bert.py b/turbo_transformers/python/turbo_transformers/layers/qmodeling_bert.py index 5e0080dc..db01bb9c 100644 --- a/turbo_transformers/python/turbo_transformers/layers/qmodeling_bert.py +++ b/turbo_transformers/python/turbo_transformers/layers/qmodeling_bert.py @@ -1,5 +1,6 @@ import turbo_transformers.turbo_transformers_cxx as cxx import torch +import numpy as np from .return_type import convert_returns_as_type, ReturnType from .utils import try_convert, convert2tt_tensor, to_param_dict_convert_tt, to_param_dict, create_empty_if_none, AnyTensor from .modeling_bert import BertAttention, BertEmbeddings, BertEncoder, BertPooler, SequencePool, PoolingType, PoolingMap @@ -102,12 +103,58 @@ def from_torch(encoder): layers = [QBertLayer.from_torch(bert_layer) for bert_layer in encoder.layer] return QBertEncoder(layers) +def _build_onnxrt_session(model): + # using https://github.com/microsoft/onnxruntime/tree/master/onnxruntime/python/tools/transformers + dummy_input = {'input_ids': torch.ones(1,128, dtype=torch.int64), + 'attention_mask': torch.ones(1,128, dtype=torch.int64), + 'token_type_ids': torch.ones(1,128, dtype=torch.int64)} + symbolic_names = {0: 'batch_size', 1: 'max_seq_len'} + onnx_model_path = "/tmp/temp_turbo_onnx.model" + onnx_opt_model_path = "/tmp/temp_turbo_onnx_opt.model" + quantized_model_path = "/tmp/temp_turbo_onnx_q.model" + # (1) export to onnx fp32 model + with open(onnx_model_path, 'wb') as f: + torch.onnx.export(model, (dummy_input['input_ids'], dummy_input['attention_mask'], dummy_input['token_type_ids']), + f, input_names=['input_ids', 'attention_mask', 'token_type_ids'], output_names=['output'], + opset_version=11, + dynamic_axes={'input_ids': symbolic_names, 'attention_mask': symbolic_names, 'token_type_ids': symbolic_names}) + # (2) optimize the fp32 model + from onnxruntime_tools import optimizer + from onnxruntime_tools.transformers.onnx_model_bert import BertOptimizationOptions + opt_options = BertOptimizationOptions('bert') + opt_options.enable_embed_layer_norm = False + opt_model = optimizer.optimize_model( + onnx_model_path, + 'bert', + num_heads=model.config.num_attention_heads, + hidden_size=model.config.hidden_size, + optimization_options=opt_options) + opt_model.save_model_to_file(onnx_opt_model_path) + # (3) quantize the model + from onnxruntime.quantization import quantize, QuantizationMode + import onnx + import onnxruntime + import onnxruntime.backend + opt_model = onnx.load(onnx_opt_model_path) + quantized_onnx_model = quantize(opt_model, quantization_mode=QuantizationMode.IntegerOps, symmetric_weight=True, force_fusions=True) + onnx.save(quantized_onnx_model, quantized_model_path) + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + return onnxruntime.InferenceSession(quantized_model_path, sess_options) + + class QBertModel: - def __init__(self, model): - self.embeddings = BertEmbeddings.from_torch(model.embeddings) - self.encoder = QBertEncoder.from_torch(model.encoder) - self.pooler = BertPooler.from_torch(model.pooler) - self.prepare = 
cxx.PrepareBertMasks() + def __init__(self, model, backend='onnxrt'): + if backend == 'turbo': + self.backend = 'turbo' + self.embeddings = BertEmbeddings.from_torch(model.embeddings) + self.encoder = QBertEncoder.from_torch(model.encoder) + self.pooler = BertPooler.from_torch(model.pooler) + self.prepare = cxx.PrepareBertMasks() + else: + self.backend = 'onnxrt' + self.session = _build_onnxrt_session(model) + def __call__(self, inputs, attention_masks = None, token_type_ids = None, @@ -118,31 +165,48 @@ def __call__(self, inputs, output_hidden_states = None, pooling_type = PoolingType.FIRST, pooler_output = None): - attention_masks = try_convert(create_empty_if_none(attention_masks)) - token_type_ids = try_convert(create_empty_if_none(token_type_ids)) - position_ids = try_convert(create_empty_if_none(position_ids)) - inputs = try_convert(inputs) - extended_attention_masks = cxx.Tensor.create_empty() - self.prepare(inputs, attention_masks, token_type_ids, position_ids, extended_attention_masks) - hidden_cache = self.embeddings( - inputs, - position_ids=position_ids, - token_type_ids=token_type_ids, - return_type=ReturnType.TORCH) - encoder_outputs = self.encoder( - hidden_states=hidden_cache, - attention_mask=extended_attention_masks, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states) - sequence_output = encoder_outputs[0] - self.seq_pool = SequencePool(PoolingMap[pooling_type]) - sequence_pool_output = self.seq_pool( - input_tensor=sequence_output, - return_type=ReturnType.TORCH) - pooler_output = self.pooler(sequence_pool_output, ReturnType.TORCH, - pooler_output) - return (sequence_output, pooler_output, ) + encoder_outputs[1:] - @staticmethod - def from_torch(model): - return QBertModel(model) + if self.backend == 'turbo': + attention_masks = try_convert(create_empty_if_none(attention_masks)) + token_type_ids = try_convert(create_empty_if_none(token_type_ids)) + position_ids = try_convert(create_empty_if_none(position_ids)) + inputs = try_convert(inputs) + extended_attention_masks = cxx.Tensor.create_empty() + self.prepare(inputs, attention_masks, token_type_ids, position_ids, extended_attention_masks) + hidden_cache = self.embeddings( + inputs, + position_ids=position_ids, + token_type_ids=token_type_ids, + return_type=ReturnType.TORCH) + encoder_outputs = self.encoder( + hidden_states=hidden_cache, + attention_mask=extended_attention_masks, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states) + sequence_output = encoder_outputs[0] + self.seq_pool = SequencePool(PoolingMap[pooling_type]) + sequence_pool_output = self.seq_pool( + input_tensor=sequence_output, + return_type=ReturnType.TORCH) + pooler_output = self.pooler(sequence_pool_output, ReturnType.TORCH, + pooler_output) + return (sequence_output, pooler_output, ) + encoder_outputs[1:] + else: + if attention_masks is None: + attention_masks = np.ones(inputs.size(), dtype=np.int64) + else: + attention_masks = attention_masks.cpu().numpy() + if token_type_ids is None: + token_type_ids = np.zeros(inputs.size(), dtype=np.int64) + else: + token_type_ids = token_type_ids.cpu().numpy() + ort_inputs = {'input_ids': inputs.cpu().numpy(), + 'attention_mask': attention_masks, + 'token_type_ids': token_type_ids} + outputs = self.session.run(None, ort_inputs) + for idx, item in enumerate(outputs): + outputs[idx] = torch.tensor(item, device=inputs.device) + return tuple(outputs) + @staticmethod + def from_torch(model, backend='onnxrt'): + return QBertModel(model, backend)
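A minimal usage sketch of the QBertModel added in qmodeling_bert.py above (illustration only, not a hunk of this patch; it assumes onnxruntime and onnxruntime-tools are installed, and that QBertModel is re-exported at the package level as the updated layers/__init__.py suggests):

    import torch
    import transformers
    import turbo_transformers

    torch_model = transformers.BertModel.from_pretrained("bert-base-uncased")
    torch_model.eval()

    # backend='onnxrt' (the default) exports the model to ONNX, fuses and quantizes
    # the graph, and serves it through onnxruntime; backend='turbo' keeps the turbo
    # runtime with the int8 QBertIntermediate/QBertOutput layers.
    q_model = turbo_transformers.QBertModel.from_torch(torch_model, backend='onnxrt')

    input_ids = torch.randint(low=0, high=torch_model.config.vocab_size - 1,
                              size=(1, 64), dtype=torch.int64)
    sequence_output, pooled_output = q_model(input_ids)[:2]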