Skip to content

Commit

Permalink
feat: CustomPlugin
Browse files Browse the repository at this point in the history
  • Loading branch information
chongqinghuang committed Mar 1, 2023
1 parent 5c87ac8 commit c919fcb
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 0 deletions.
3 changes: 3 additions & 0 deletions plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ set(PLUGIN_LISTS
specialSlicePlugin
splitPlugin
voxelGeneratorPlugin
# 添加编译Custom插件的选项<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
CustomPlugin

)

# Add BERT sources if ${BERT_GENCODES} was populated
Expand Down
18 changes: 18 additions & 0 deletions plugin/CustomPlugin/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
170 changes: 170 additions & 0 deletions plugin/CustomPlugin/ICustomPlugin.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "CustomPlugin.h"
#include "checkMacrosPlugin.h"
#include "kernel.h"

using namespace nvinfer1;
using nvinfer1::plugin::CustomPluginCreator;
using nvinfer1::plugin::Custom;

static const char* Custom_PLUGIN_VERSION{"1"};
static const char* Custom_PLUGIN_NAME{"Custom_TRT"};
PluginFieldCollection CustomPluginCreator::mFC{};
std::vector<PluginField> CustomPluginCreator::mPluginAttributes;

// LeakyReLU {{{
Custom::Custom(float negSlope)
: mNegSlope(negSlope)
, mBatchDim(1)
{
}

Custom::Custom(const void* buffer, size_t length)
{
const char *d = reinterpret_cast<const char *>(buffer), *a = d;
mNegSlope = read<float>(d);
mBatchDim = read<int>(d);
ASSERT(d == a + length);
}

int Custom::getNbOutputs() const noexcept
{
return 1;
}

Dims Custom::getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept
{
ASSERT(nbInputDims == 1);
ASSERT(index == 0);
return inputs[0];
}

int Custom::enqueue(
int batchSize, const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
const void* inputData = inputs[0];
void* outputData = outputs[0];
pluginStatus_t status = CustomInference(stream, mBatchDim * batchSize, mNegSlope, inputData, outputData);
return status;
}

size_t Custom::getSerializationSize() const noexcept
{
// mNegSlope, mBatchDim
return sizeof(float) + sizeof(int);
}

void Custom::serialize(void* buffer) const noexcept
{
char *d = reinterpret_cast<char *>(buffer), *a = d;
write(d, mNegSlope);
write(d, mBatchDim);
ASSERT(d == a + getSerializationSize());
}

void Custom::configureWithFormat(
const Dims* inputDims, int /* nbInputs */, const Dims* /* outputDims */, int nbOutputs, DataType type, PluginFormat format, int) noexcept
{
ASSERT(type == DataType::kFLOAT && format == PluginFormat::kLINEAR);
ASSERT(mBatchDim == 1);
ASSERT(nbOutputs == 1);
for (int i = 0; i < inputDims[0].nbDims; ++i)
{
mBatchDim *= inputDims[0].d[i];
}
}

bool Custom::supportsFormat(DataType type, PluginFormat format) const noexcept
{
return (type == DataType::kFLOAT && format == PluginFormat::kLINEAR);
}

int Custom::initialize() noexcept
{
return 0;
}

void Custom::terminate() noexcept {}

size_t Custom::getWorkspaceSize(int /* maxBatchSize */) const noexcept
{
return 0;
}

const char* Custom::getPluginType() const noexcept // 对应
{
return Custom_PLUGIN_NAME;
}

const char* Custom::getPluginVersion() const noexcept // 对应
{
return Custom_PLUGIN_VERSION;
}

void Custom::destroy() noexcept
{
delete this;
}

IPluginV2* Custom::clone() const noexcept
{
IPluginV2* plugin = new Custom(mNegSlope);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}

CustomPluginCreator::CustomPluginCreator()
{
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("negSlope", nullptr, PluginFieldType::kFLOAT32, 1));

mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}

const char* CustomPluginCreator::getPluginName() const noexcept
{
return Custom_PLUGIN_NAME;
}

const char* CustomPluginCreator::getPluginVersion() const noexcept
{
return Custom_PLUGIN_VERSION;
}

const PluginFieldCollection* CustomPluginCreator::getFieldNames() noexcept
{
return &mFC;
}

IPluginV2* CustomPluginCreator::createPlugin(const char* name, const PluginFieldCollection* fc) noexcept
{
const PluginField* fields = fc->fields;
ASSERT(fc->nbFields == 1);
ASSERT(fields[0].type == PluginFieldType::kFLOAT32);
float negSlope = *(static_cast<const float*>(fields[0].data));

return new Custom(negSlope);
}

IPluginV2* CustomPluginCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept
{
// This object will be deleted when the network is destroyed, which will
// call CustomPlugin::destroy()
return new Custom(serialData, serialLength);
}
// LeakReLU }}}
101 changes: 101 additions & 0 deletions plugin/CustomPlugin/ICustomPlugin.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_L_Custom_PLUGIN_H
#define TRT_L_Custom_PLUGIN_H
#include "NvInferPlugin.h"
#include "kernel.h"
#include "plugin.h"
#include <cassert>
#include <iostream>
#include <string>
#include <vector>

namespace nvinfer1
{
namespace plugin
{

class Custom : public BasePlugin
{
public:
Custom(float negSlope);

Custom(const void* buffer, size_t length);

~Custom() override = default;

int getNbOutputs() const noexcept override;

Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) noexcept override;

int initialize() noexcept override;

void terminate() noexcept override;

size_t getWorkspaceSize(int maxBatchSize) const noexcept override;

int enqueue(int batchSize, const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept override;

size_t getSerializationSize() const noexcept override;

void serialize(void* buffer) const noexcept override;

void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) noexcept override;

bool supportsFormat(DataType type, PluginFormat format) const noexcept override;

const char* getPluginType() const noexcept override;

const char* getPluginVersion() const noexcept override;

void destroy() noexcept override;

IPluginV2* clone() const noexcept override;

private:
float mNegSlope;
int mBatchDim;
};

// 需要写一个创建Custom插件的类
class CustomPluginCreator : public BaseCreator
{
public:
CustomPluginCreator();

~CustomPluginCreator() override = default;

const char* getPluginName() const noexcept override; // 对应

const char* getPluginVersion() const noexcept override; //对应

const PluginFieldCollection* getFieldNames() noexcept override;

IPluginV2* createPlugin(const char* name, const PluginFieldCollection* fc) noexcept override;

IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) noexcept override;

private:
static PluginFieldCollection mFC;
static std::vector<PluginField> mPluginAttributes;
};

typedef Custom PCustom; // Temporary. For backward compatibilty.
} // namespace plugin
} // namespace nvinfer1

#endif // TRT_L_Custom_PLUGIN_H
3 changes: 3 additions & 0 deletions plugin/InferPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ extern "C"
initializePlugin<nvinfer1::plugin::SpecialSlicePluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::SplitPluginCreator>(logger, libNamespace);
initializePlugin<nvinfer1::plugin::VoxelGeneratorPluginCreator>(logger, libNamespace);
// inferplugin.cpp文件中添加初始化plugin的接口
// CustomPluginCreator locate: plugin/CustomPlugin/ICustomPlugin.h
initializePlugin<nvinfer1::plugin::CustomPluginCreator>(logger, libNamespace);
return true;
}
} // extern "C"

0 comments on commit c919fcb

Please sign in to comment.