Skip to content

Commit

Permalink
[ONNIFI] Implement onnxSetIOAndRunGraph onnxifi extension (pytorch#2360)
Browse files Browse the repository at this point in the history
* Implement onnxSetIOAndRunGraph extension

* address comments
  • Loading branch information
jackm321 authored Feb 6, 2019
1 parent a0bddca commit 2a37c5a
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 2 deletions.
4 changes: 3 additions & 1 deletion lib/Onnxifi/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ void Graph::runAsync(EventPtr inputEvent, EventPtr outputEvent) {
backend()->runAsync([inputEvent, outputEvent, inputPlaceholderToBuffer,
outputPlaceholderToBuffer, this]() {
// Wait for all inputs to be ready.
inputEvent->wait();
if (inputEvent) {
inputEvent->wait();
}
// Run inference.
this->run(inputPlaceholderToBuffer, outputPlaceholderToBuffer);
// Signal that the outputs are ready.
Expand Down
1 change: 1 addition & 0 deletions lib/Onnxifi/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "glow/Support/ThreadPool.h"

#include "onnx/onnxifi.h"
#include "onnx/onnxifi_ext.h"

#include <atomic>
#include <condition_variable>
Expand Down
94 changes: 94 additions & 0 deletions lib/Onnxifi/onnxifiGlow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxGetBackendInfo)(
case ONNXIFI_BACKEND_SYNCHRONIZATION_TYPES:
return setBackendInfoUInt64(infoValue, infoValueSize,
ONNXIFI_SYNCHRONIZATION_EVENT);
case ONNXIFI_BACKEND_EXTENSIONS:
return setBackendInfoString(infoValue, infoValueSize,
"onnxSetIOAndRunGraphFunction");
default:
return ONNXIFI_STATUS_UNSUPPORTED_PROPERTY;
}
Expand Down Expand Up @@ -472,6 +475,62 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxRunGraph)(
return ONNXIFI_STATUS_SUCCESS;
}

/// Binds inputs and outputs of an ONNXIFI graph to specific addresses then
/// asynchronously execute operations in the graph using the provided
/// addresses.
EXTERNC ONNXIFI_PUBLIC ONNXIFI_CHECK_RESULT onnxStatus ONNXIFI_ABI
GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxSetIOAndRunGraph)(
onnxGraph graph, uint32_t inputsCount,
const onnxTensorDescriptorV1 *inputDescriptors, uint32_t outputsCount,
const onnxTensorDescriptorV1 *outputDescriptors,
onnxMemoryFenceV1 *outputFence) {
auto &manager = glow::onnxifi::GlowOnnxifiManager::get();

if (!inputDescriptors || !outputDescriptors || !outputFence) {
return ONNXIFI_STATUS_INVALID_POINTER;
}

// Check output fence is correct type and tag.
if (outputFence->type != ONNXIFI_SYNCHRONIZATION_EVENT ||
outputFence->tag != ONNXIFI_TAG_MEMORY_FENCE_V1) {
return ONNXIFI_STATUS_UNSUPPORTED_TAG;
}

// Check glowGraph is valid.
auto *glowGraph = static_cast<glow::onnxifi::GraphPtr>(graph);
if (!manager.isValid(glowGraph)) {
return ONNXIFI_STATUS_INVALID_GRAPH;
}

// Initialize outputFence's event.
auto outputEventInitStatus = GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(
onnxInitEvent)(glowGraph->backend(), &outputFence->event);
if (outputEventInitStatus != ONNXIFI_STATUS_SUCCESS) {
return outputEventInitStatus;
}

// Verify inputs.
auto inputStatus = verifyDescriptors(inputsCount, inputDescriptors);
if (inputStatus != ONNXIFI_STATUS_SUCCESS) {
return inputStatus;
}

// Verify outputs.
auto outputStatus = verifyDescriptors(outputsCount, outputDescriptors);
if (outputStatus != ONNXIFI_STATUS_SUCCESS) {
return outputStatus;
}

auto *outputEvent = static_cast<glow::onnxifi::EventPtr>(outputFence->event);

// Set graph IO and run asynchronous.
glowGraph->setIO(inputsCount, inputDescriptors, outputsCount,
outputDescriptors);
glowGraph->runAsync(/*inputEvent*/ nullptr, outputEvent);

return ONNXIFI_STATUS_SUCCESS;
}

/// Deinitialize an ONNXIFI graph and release associated resources.
/// It blocks until all in-flight inference operations complete.
EXTERNC ONNXIFI_PUBLIC ONNXIFI_CHECK_RESULT onnxStatus ONNXIFI_ABI
Expand All @@ -487,3 +546,38 @@ GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxReleaseGraph)(onnxGraph graph) {

return ONNXIFI_STATUS_SUCCESS;
}

ONNXIFI_PUBLIC ONNXIFI_CHECK_RESULT onnxStatus ONNXIFI_ABI
onnxGetExtensionFunctionAddress(onnxBackendID backendID, const char *name,
onnxExtensionFunctionPointer *function) {
if (!name || !function) {
return ONNXIFI_STATUS_INVALID_POINTER;
}

auto &manager = glow::onnxifi::GlowOnnxifiManager::get();

auto *glowBackendId = static_cast<glow::onnxifi::BackendIdPtr>(backendID);
if (!manager.isValid(glowBackendId)) {
return ONNXIFI_STATUS_INVALID_ID;
}

// Map of name to onnxExtensionFunctionPointer, one entry for each implemented
// onnxifi extension.
// NOTE: when updating this map, also update the response from
// onnxGetBackendInfo for the ONNXIFI_BACKEND_EXTENSIONS query.
static const std::unordered_map<std::string, onnxExtensionFunctionPointer>
extensionMap = {
{"onnxSetIOAndRunGraphFunction",
reinterpret_cast<onnxExtensionFunctionPointer>(
GLOW_ONNXIFI_LIBRARY_FUNCTION_WRAPPER(onnxSetIOAndRunGraph))}};

auto extensionIt = extensionMap.find(name);

if (extensionIt == extensionMap.end()) {
// No function found for the given name.
return ONNXIFI_STATUS_UNIDENTIFIED_NAME;
}

*function = extensionIt->second;
return ONNXIFI_STATUS_SUCCESS;
}
2 changes: 1 addition & 1 deletion thirdparty/onnx

0 comments on commit 2a37c5a

Please sign in to comment.