forked from NVIDIA/cudnn-frontend
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpycudnn.cpp
79 lines (64 loc) · 2.66 KB
/
pycudnn.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>

#include "pybind11/pybind11.h"
#include "pybind11/cast.h"
#include "pybind11/stl.h"

#include "cudnn_frontend.h"

namespace py = pybind11;
using namespace pybind11::literals;
namespace cudnn_frontend {

// Raw handle to the cuDNN library. Starts out null and is assigned from Python
// through `_set_dlhandle_cudnn` (bound to set_dlhandle_cudnn below).
// NOTE(review): presumably consumed by the frontend's dynamic-loading code when
// resolving cuDNN symbols — confirm against its readers.
void *cudnn_dlhandle = nullptr;

namespace python_bindings {

/// @brief Raise the C++ exception that corresponds to a frontend error code.
///
/// pybind11 converts C++ exceptions escaping a bound function into Python
/// exceptions; with its default translators std::invalid_argument surfaces as
/// ValueError and std::runtime_error as RuntimeError.
///
/// @param cond        If false, the function returns without doing anything.
/// @param error_code  Selects which exception type is thrown.
/// @param error_msg   Message carried by the thrown exception.
/// @throws std::invalid_argument for ATTRIBUTE_NOT_SET, SHAPE_DEDUCTION_FAILED,
///         INVALID_TENSOR_NAME and INVALID_VARIANT_PACK.
/// @throws std::runtime_error for every other code except OK, including any
///         code not listed in the switch below.
/// When `cond` is true and `error_code` is OK, nothing is thrown.
void
throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const &error_msg) {
    if (!cond) {
        return;
    }

    // Intentionally no `default:` label: -Wswitch can then still report any
    // error_code_t enumerator that this switch does not handle.
    switch (error_code) {
        case cudnn_frontend::error_code_t::OK:
            return;

        case cudnn_frontend::error_code_t::ATTRIBUTE_NOT_SET:
        case cudnn_frontend::error_code_t::SHAPE_DEDUCTION_FAILED:
        case cudnn_frontend::error_code_t::INVALID_TENSOR_NAME:
        case cudnn_frontend::error_code_t::INVALID_VARIANT_PACK:
            throw std::invalid_argument(error_msg);

        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED:
        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED:
        case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED:
        case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED:
        case cudnn_frontend::error_code_t::CUDA_API_FAILED:
        case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE:
        case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT:
        case cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED:
        case cudnn_frontend::error_code_t::HANDLE_ERROR:
            throw std::runtime_error(error_msg);
    }

    // Reached only for a code that matches no case above (an enumerator this
    // switch does not know about, or an out-of-range value). Without this the
    // error would fall off the end of the function and be silently dropped.
    throw std::runtime_error(error_msg);
}

// Registers the pybind11 bindings for the pygraph class on the given module.
// Only declared here; the definition is not in this file.
void
init_pygraph_submodule(py::module_ &);

// Registers the pybind11 bindings for properties and helpers on the given
// module. Only declared here; the definition is not in this file.
void
init_properties(py::module_ &);

/// @brief Store a cuDNN library handle that Python passes in as an integer.
///
/// Exposed to Python as `_set_dlhandle_cudnn`. The integer is reinterpreted as
/// a pointer and written to cudnn_frontend::cudnn_dlhandle unchanged; no
/// validation is performed.
/// NOTE(review): the value is presumably an address obtained from dlopen() on
/// the Python side — confirm against the caller.
void
set_dlhandle_cudnn(std::intptr_t dlhandle) {
    cudnn_dlhandle = reinterpret_cast<void *>(dlhandle);
}

// Defines the `_compiled_module` Python extension module. It exposes:
//   - backend_version       -> cudnn_frontend::get_backend_version
//   - everything registered by init_properties and init_pygraph_submodule
//   - _set_dlhandle_cudnn   -> set_dlhandle_cudnn
PYBIND11_MODULE(_compiled_module, m) {
    m.def("backend_version", &cudnn_frontend::get_backend_version);

    init_properties(m);
    init_pygraph_submodule(m);

    m.def("_set_dlhandle_cudnn", &set_dlhandle_cudnn);
}

}  // namespace python_bindings
}  // namespace cudnn_frontend