Develop (#154)
Add TF checkpoint converter for BERT (#159)
The memory written by AddSplitTranspose was not contiguous for the Q, K, V outputs.
Fix BERT profiler bugs.
Update the GPU fixed-length benchmark scripts.
Add protection for SplitTranspose's Q, K, V outputs.
feifeibear authored Aug 19, 2020
1 parent e623096 commit 8fbbd2a
Showing 18 changed files with 477 additions and 104 deletions.
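The central fix: the fused QKV activation used to be split and transposed in place, leaving each per-tensor view non-contiguous. A minimal NumPy sketch of the problem (hypothetical shapes, not from the patch):

    import numpy as np

    batch, seq_len, heads, size_per_head = 2, 4, 3, 8
    # fused QKV activation: [batch, seq, 3 (q/k/v), head, size_per_head]
    qkv = np.ones((batch, seq_len, 3, heads, size_per_head), dtype=np.float32)
    # transpose to [3, batch, head, seq, size_per_head], then slice out Q
    q_view = qkv.transpose(2, 0, 3, 1, 4)[0]
    print(q_view.flags["C_CONTIGUOUS"])   # False: the view still strides across K and V
    # the approach taken in this commit: give each of Q, K, V its own contiguous buffer
    q_safe = np.ascontiguousarray(q_view)
    print(q_safe.flags["C_CONTIGUOUS"])   # True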
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -21,7 +21,7 @@ set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_FLAGS "-Wall")
set(CMAKE_C_FLAGS "-Wall")

-set(TURBO_TRANSFORMERS_VERSION 0.4.1)
+set(TURBO_TRANSFORMERS_VERSION 0.4.2)

option(WITH_PROFILER "Compile with profiler" OFF)
option(WITH_GPU "Build with GPU" OFF)
2 changes: 1 addition & 1 deletion benchmark/run_gpu_fixed_benchmark.sh
@@ -29,7 +29,7 @@ do
for framework in ${FRAMEWORKS[*]}
do
python benchmark.py ${MODEL} --seq_len=${seq_len} --batch_size=${batch_size}\
-        -n ${N} --framework=${framework}
+        -n ${N} --framework=${framework} --use_gpu
done
done
done
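With --use_gpu appended, each inner-loop call now expands to a command like the following (the model name and loop values here are hypothetical placeholders for the script's variables):

    python benchmark.py bert --seq_len=64 --batch_size=1 -n 150 --framework=turbo-transformers --use_gpu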
10 changes: 5 additions & 5 deletions example/cpp/bert_model.cpp
@@ -211,18 +211,18 @@ struct BertModel::Impl {
      layer(hidden, extendedAttentionMask, &attOut, &intermediateOut, &hidden);
    }

-    core::Tensor poolingOutput(nullptr);
-    layers::SequencePool(static_cast<layers::types::PoolType>(pooling))(
-        hidden, &poolingOutput);
    std::vector<float> vec;
    if (use_pooler) {
      core::Tensor output(nullptr);
+      core::Tensor poolingOutput(nullptr);
+      layers::SequencePool(static_cast<layers::types::PoolType>(pooling))(
+          hidden, &poolingOutput);
      (*pooler_)(poolingOutput, &output);
      vec.resize(output.numel());
      core::Copy(output, vec);
    } else {
-      vec.resize(poolingOutput.numel());
-      core::Copy(poolingOutput, vec);
+      vec.resize(hidden.numel());
+      core::Copy(hidden, vec);
    }

    return vec;
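Net effect of this hunk: SequencePool now runs only when use_pooler is set, and the else branch copies out the full hidden states rather than a pooled tensor, so the pooling work is no longer computed and then discarded on the non-pooler path.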
85 changes: 48 additions & 37 deletions example/python/bert_for_sequence_classification_example.py
@@ -19,46 +19,51 @@
# Import the acceleration model class; here we use BertForSequenceClassification as the example.
from transformers.modeling_bert import BertModel as TorchBertModel
from transformers import BertTokenizer
-from transformers.modeling_bert import BertForSequenceClassification as TorchBertForSequenceClassification
+from transformers.modeling_bert import (
+    BertForSequenceClassification as TorchBertForSequenceClassification,
+)
import os
import torch
from typing import Optional


-#TODO(jiarufang) developed under v0.1.0, after that not tested.
-#Contact me if you find it is wrong.
+# TODO(jiarufang) developed under v0.1.0; not tested on later versions.
+# Contact me if you find it is wrong.
class BertForSequenceClassification:  # a new class that wraps the accelerated model
    def __init__(
-            self, bertmodel, classifier
+        self, bertmodel, classifier
    ):  # the init function (we can just copy it)
        self.bert = bertmodel
        self.classifier = classifier

    def __call__(
-            self,  # the call function (we can just copy it)
-            inputs,
-            attention_masks=None,
-            token_type_ids=None,
-            position_ids=None,
-            pooling_type=PoolingType.FIRST,
-            return_type=None):
-        pooler_output, _, _ = self.bert(inputs,
-                                        attention_masks,
-                                        token_type_ids,
-                                        position_ids,
-                                        pooling_type,
-                                        return_type=ReturnType.TORCH)
+        self,  # the call function (we can just copy it)
+        input_ids,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        pooling_type=PoolingType.FIRST,
+        return_type=None,
+    ):
+        bert_outputs = self.bert(
+            input_ids,
+            attention_mask,
+            token_type_ids,
+            position_ids,
+            pooling_type,
+            return_type=ReturnType.TORCH,
+        )
+        pooled_output = bert_outputs[1]
        logits = self.classifier(
-            pooler_output
+            pooled_output
        )  # the classifier output; to return a different type, post-process it here
        return logits

    @staticmethod
    def from_torch(
-            model: TorchBertModel,  # implementation of from_torch
-            device: Optional[torch.device] = None):
-        if device is not None and 'cuda' in device.type and torch.cuda.is_available(
-        ):
+        model: TorchBertModel, device: Optional[torch.device] = None  # implementation of from_torch
+    ):
+        if device is not None and "cuda" in device.type and torch.cuda.is_available():
            model.to(device)
        bertmodel = turbo_transformers.BertModel.from_torch(model.bert)
        # We can copy the following code and do not change it
@@ -67,11 +72,11 @@ def from_torch(
        return BertForSequenceClassification(bertmodel, model.classifier)

    @staticmethod
-    def from_pretrained(model_id_or_path: str,
-                        device: Optional[torch.device] = None):
+    def from_pretrained(model_id_or_path: str, device: Optional[torch.device] = None):
        # First, use from_pretrained to load the model you trained.
        torch_model = TorchBertForSequenceClassification.from_pretrained(
-            model_id_or_path)
+            model_id_or_path
+        )
        # Then, use the acceleration model's init function to build it.
        model = BertForSequenceClassification.from_torch(torch_model, device)
        model._torch_model = torch_model  # keep a reference so the torch model is not destroyed
@@ -82,18 +87,24 @@ def from_pretrained(model_id_or_path: str,
turbo_transformers.set_num_threads(4)

model_id = os.path.join(
-    os.path.dirname(__file__),
-    'test-seq-classification-model')  # the path of the huggingface model
-tokenizer = BertTokenizer.from_pretrained(
-    model_id)  # initialize the tokenizer
+    os.path.dirname(__file__), "bert_model"
+)  # the path of the huggingface model
+tokenizer = BertTokenizer.from_pretrained(model_id)  # initialize the tokenizer
turbo_model = BertForSequenceClassification.from_pretrained(
-    model_id,
-    torch.device('cpu:0'))  # initialize the acceleration model
+    model_id, torch.device("cpu:0")
+)  # initialize the acceleration model

# predict after loading the model
-input_ids = torch.tensor(
-    tokenizer.encode('Test whether the performance and accuracy of the BERT model meet requirements?',
-                     add_special_tokens=True)).unsqueeze(0)
-torch_result = turbo_model(input_ids)
-print(torch_result)
-# tensor([[ 0.1451, -0.0373]], grad_fn=<AddmmBackward>)
+text = "Sample input text"
+inputs = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors="pt")
+# turbo_result holds the logits returned by the TurboTransformers model
+turbo_result = turbo_model(**inputs)
+
+torch_model = TorchBertForSequenceClassification.from_pretrained(model_id)
+# torch_result holds the logits returned by the original Transformers model
+torch_result = torch_model(**inputs)[0]
+print(turbo_result)
+# tensor([[0.2716, 0.0318]], grad_fn=<AddmmBackward>)
+print(torch_result)  # torch_result and turbo_result should hold the same logits
+# tensor([[0.2716, 0.0318]], grad_fn=<AddmmBackward>)
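If you prefer the comparison to fail loudly rather than relying on eyeballing the two printouts, a small hedged addition (not part of the shipped example) is:

    assert torch.allclose(turbo_result, torch_result, atol=1e-3), "turbo/torch logits diverge"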
119 changes: 119 additions & 0 deletions tools/convert_tf_bert_to_npz.py
@@ -0,0 +1,119 @@
# Copyright (C) 2020 THL A29 Limited, a Tencent company.
# All rights reserved.
# Licensed under the BSD 3-Clause License (the "License"); you may
# not use this file except in compliance with the License. You may
# obtain a copy of the License at
# https://opensource.org/licenses/BSD-3-Clause
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" basis,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
# See the AUTHORS file for names of contributors.

from transformers import BertConfig
try:
    import tensorflow as tf
except ImportError:
    print("please install TensorFlow 2.x by running `pip install tensorflow`")
    raise
import numpy as np
import sys
import os


# Users should define the mapping from TF layer names to TurboTransformers layer names
def build_dic(num_layers):
dic = {
'bert/embeddings/word_embeddings':
'embeddings.word_embeddings.weight',
'bert/embeddings/position_embeddings':
'embeddings.position_embeddings.weight',
'bert/embeddings/token_type_embeddings':
'embeddings.token_type_embeddings.weight',
'bert/embeddings/LayerNorm/gamma':
'embeddings.LayerNorm.weight',
'bert/embeddings/LayerNorm/beta':
'embeddings.LayerNorm.bias',
'bert/pooler/dense/kernel': 'pooler.dense.weight',
'bert/pooler/dense/bias': 'pooler.dense.bias'
}

for i in range(num_layers):
dic[f'bert/encoder/layer_{i}/attention/self/query/kernel'] = f'encoder.layer.{i}.attention.self.query.weight'
dic[f'bert/encoder/layer_{i}/attention/self/query/bias'] = f'encoder.layer.{i}.attention.self.query.bias'
dic[f'bert/encoder/layer_{i}/attention/self/key/kernel'] = f'encoder.layer.{i}.attention.self.key.weight'
dic[f'bert/encoder/layer_{i}/attention/self/key/bias'] = f'encoder.layer.{i}.attention.self.key.bias'
dic[f'bert/encoder/layer_{i}/attention/self/value/kernel'] = f'encoder.layer.{i}.attention.self.value.weight'
dic[f'bert/encoder/layer_{i}/attention/self/value/bias'] = f'encoder.layer.{i}.attention.self.value.bias'
dic[f'bert/encoder/layer_{i}/attention/output/dense/kernel'] = f'encoder.layer.{i}.attention.output.dense.weight'
dic[f'bert/encoder/layer_{i}/attention/output/dense/bias'] = f'encoder.layer.{i}.attention.output.dense.bias'
dic[f'bert/encoder/layer_{i}/attention/output/LayerNorm/gamma'] = f'encoder.layer.{i}.attention.output.LayerNorm.weight'
dic[f'bert/encoder/layer_{i}/attention/output/LayerNorm/beta'] = f'encoder.layer.{i}.attention.output.LayerNorm.bias'
dic[f'bert/encoder/layer_{i}/intermediate/dense/kernel'] = f'encoder.layer.{i}.intermediate.dense.weight'
dic[f'bert/encoder/layer_{i}/intermediate/dense/bias'] = f'encoder.layer.{i}.intermediate.dense.bias'
dic[f'bert/encoder/layer_{i}/output/dense/kernel'] = f'encoder.layer.{i}.output.dense.weight'
dic[f'bert/encoder/layer_{i}/output/dense/bias'] = f'encoder.layer.{i}.output.dense.bias'
dic[f'bert/encoder/layer_{i}/output/LayerNorm/gamma'] = f'encoder.layer.{i}.output.LayerNorm.weight'
dic[f'bert/encoder/layer_{i}/output/LayerNorm/beta'] = f'encoder.layer.{i}.output.LayerNorm.bias'
return dic


def trans_layer_name_tf2turbo(dic, name):
return dic[name]
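# For example, a single lookup behaves like this (illustrative, not executed by the script):
#   trans_layer_name_tf2turbo(dic, 'bert/encoder/layer_0/attention/self/query/kernel')
#   -> 'encoder.layer.0.attention.self.query.weight'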


def main():
if len(sys.argv) != 3:
print(
"Usage: \n"
" convert_tf_bert_to_npz.py model_name output_file")
exit(0)
model_path = sys.argv[1]
ckpt_path = os.path.join(model_path, "bert_model.ckpt")
cfg = BertConfig.from_pretrained(os.path.join(model_path, "bert_config.json"))
dic = build_dic(cfg.num_hidden_layers)
names = [v[0] for v in tf.train.list_variables(ckpt_path)]

arrays = {}
for i in range(len(names)):
if names[i].startswith("cls"):
continue
arrays[trans_layer_name_tf2turbo(dic, names[i])] = tf.train.load_variable(ckpt_path, names[i])

q_weight_key = 'self.query.weight'
k_weight_key = 'self.key.weight'
v_weight_key = 'self.value.weight'

q_bias_key = 'self.query.bias'
k_bias_key = 'self.key.bias'
v_bias_key = 'self.value.bias'

numpy_dict = {}

for k in arrays.keys():
if k.endswith(q_weight_key):
ret = []
ret.append(arrays[k])
ret.append(arrays[k[:-len(q_weight_key)] + k_weight_key])
ret.append(arrays[k[:-len(q_weight_key)] + v_weight_key])
v = np.concatenate(ret, axis=1)
numpy_dict[k[:-len(q_weight_key)] +
"qkv.weight"] = np.ascontiguousarray(v)
elif k.endswith(q_bias_key):
ret = []
ret.append(arrays[k])
ret.append(arrays[k[:-len(q_bias_key)] + k_bias_key])
ret.append(arrays[k[:-len(q_bias_key)] + v_bias_key])
v = np.ascontiguousarray(np.concatenate(ret, axis=0))
numpy_dict[k[:-len(q_bias_key)] + 'qkv.bias'] = v
elif any((k.endswith(suffix) for suffix in (k_weight_key, v_weight_key,
k_bias_key, v_bias_key))):
continue
else:
numpy_dict[k] = np.ascontiguousarray(arrays[k])

np.savez_compressed(sys.argv[2], **numpy_dict)


if __name__ == '__main__':
main()
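A hedged usage sketch (the directory layout is an assumption taken from the paths built above: the model directory must contain bert_model.ckpt and bert_config.json):

    python tools/convert_tf_bert_to_npz.py /path/to/tf_bert_model bert.npz

Note the fusion step in main(): for hidden size H the TF q/k/v kernels are each (H, H), so concatenating along axis=1 yields a fused qkv.weight of shape (H, 3H), while the three (H,) biases concatenate along axis=0 into a (3H,) qkv.bias; np.ascontiguousarray ensures every array written to the .npz is contiguous, matching the commit's theme.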
2 changes: 1 addition & 1 deletion tools/docker/Dockerfile_release.cpu
@@ -14,5 +14,5 @@ RUN /opt/conda/bin/conda install pytorch==1.5.0 cpuonly -c pytorch && \
/opt/conda/bin/conda install make cmake git graphviz gperftools git-lfs docopt -c conda-forge && \
/opt/conda/bin/conda clean -afy

-RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt
+RUN pip --no-cache-dir install contexttimer future transformers==3.0.2 docopt onnxruntime-tools
WORKDIR /workspace
1 change: 0 additions & 1 deletion turbo_transformers/layers/bert_embedding.cpp
@@ -18,7 +18,6 @@
#include "turbo_transformers/layers/kernels/embedding.h"
#include "turbo_transformers/layers/kernels/layer_norm.h"


namespace turbo_transformers {
namespace layers {

2 changes: 1 addition & 1 deletion turbo_transformers/layers/bert_intermediate.cpp
@@ -45,7 +45,7 @@ void BertIntermediate::operator()(const core::Tensor& input_tensor,
kernels::AddBiasAct<float, kernels::ActivationType::Gelu>(
dense_bias_, output_tensor, "BertIntermediate/AddBiasAct");
#ifdef WITH_PERFTOOLS
profile_ctx.end_profile("BertIntermediate");
profile_ctx.end_profile("BertIntermediate", input_tensor.device_type());
#endif
}

2 changes: 1 addition & 1 deletion turbo_transformers/layers/bert_output.cpp
@@ -48,7 +48,7 @@ void BertOutput::operator()(const core::Tensor &hidden_states,
input_tensor, dense_bias_, layer_norm_weight_, layer_norm_bias_,
output_tensor, 1e-12, "BertOutput/AddBiasLayerNorm");
#ifdef WITH_PERFTOOLS
profile_ctx.end_profile("BertOutput");
profile_ctx.end_profile("BertOutput", input_tensor.device_type());
#endif
}

60 changes: 60 additions & 0 deletions turbo_transformers/layers/kernels/gpu_transpose_kernel.cu
@@ -73,6 +73,66 @@ void GPUSplitAddBiasTransposeForScore(
weight_num, size_per_head, out_data);
}

/*
Output transpose results into three tensors
*/
static __global__ void split_add_bias_transpose_for_score_3output(
const float* input_data, const float* bias_data, const int batch_size,
const int seq_len, const int head_num, const int weight_num,
const int size_per_head, float* q_output_data, float* k_output_data,
float* v_output_data) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int idx = tid;
int batch_id = bid / (seq_len * weight_num * head_num);
int seq_id =
bid % (seq_len * weight_num * head_num) / (weight_num * head_num);
int weight_id = bid % (weight_num * head_num) / head_num;
int head_id = bid % head_num;

int head_num_size_per_head = head_num * size_per_head;
int weight_id_head_num_size_per_head = weight_id * head_num_size_per_head;
int head_id_size_per_head = head_id * size_per_head;

float* output_data = nullptr;
if (weight_id == 0) {
output_data = q_output_data;
} else if (weight_id == 1) {
output_data = k_output_data;
} else if (weight_id == 2) {
output_data = v_output_data;
}

while (idx < size_per_head) {
float bias_val = bias_data[weight_id_head_num_size_per_head +
head_id_size_per_head + idx];
output_data[batch_id * seq_len * head_num_size_per_head +
head_id * seq_len * size_per_head + seq_id * size_per_head +
idx] =
input_data[batch_id * seq_len * weight_num * head_num_size_per_head +
seq_id * weight_num * head_num_size_per_head +
weight_id_head_num_size_per_head + head_id_size_per_head +
idx] +
bias_val;
idx += blockDim.x;
}
}

template <>
void GPUSplitAddBiasTransposeForScoreThreeOutput(
const float* input_data, const float* bias_data, int64_t batch_size,
int64_t seq_len, int64_t weight_num, int64_t num_attention_heads,
int64_t size_per_head, cudaStream_t stream, float* q_out_data,
float* k_out_data, float* v_out_data) {
const int n = size_per_head;
const int m = batch_size * seq_len * num_attention_heads * weight_num;
dim3 grid(m);
dim3 block(min(n, 1024));
split_add_bias_transpose_for_score_3output<<<grid, block, 0, stream>>>(
input_data, bias_data, batch_size, seq_len, num_attention_heads,
weight_num, size_per_head, q_out_data, k_out_data, v_out_data);
}

namespace {

// batch, head, seq, size_per_head -> batch head seq size_per_head
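For readers checking the new kernel's index arithmetic, here is a NumPy reference of the same transform (shapes inferred from the indexing above; this sketch is not part of the patch):

    import numpy as np

    def split_add_bias_transpose_ref(x, bias):
        # x: [batch, seq, 3 (q/k/v), head, size_per_head]
        # bias: [3, head, size_per_head], broadcast over batch and seq
        y = x + bias
        y = y.transpose(2, 0, 3, 1, 4)  # -> [3, batch, head, seq, size_per_head]
        # one contiguous tensor per projection, matching the kernel's three outputs
        q, k, v = (np.ascontiguousarray(t) for t in y)
        return q, k, v

Each output then has the [batch, head, seq, size_per_head] layout that the attention score computation consumes.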
[Diffs for the remaining 8 changed files are not shown.]
