forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path custom_class.cpp
48 lines (38 loc) · 1.47 KB
/
custom_class.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#include <torch/custom_class.h>
#include <ATen/core/jit_type.h>
#include <atomic>
#include <unordered_map>
namespace torch {

// Returns the registry of custom classes, keyed by fully qualified class name.
// The map is a function-local static, so it is constructed on first call.
// NOTE(review): nothing in this file synchronizes access to this map;
// presumably registration happens during static initialization / library
// load before concurrent lookups begin — confirm before relying on it.
std::unordered_map<std::string, at::ClassTypePtr>& customClasses() {
  static std::unordered_map<std::string, at::ClassTypePtr> registry;
  return registry;
}

// Registers `class_type` in the custom-class registry under its qualified
// name.
//
// @param class_type  Non-null class type that carries a name (both are
//                    internally asserted).
// @throws (via TORCH_CHECK) if a class with the same qualified name is
//         already registered; the existing entry is left untouched.
void registerCustomClass(at::ClassTypePtr class_type) {
  TORCH_INTERNAL_ASSERT(class_type);
  TORCH_INTERNAL_ASSERT(class_type->name());
  auto name = class_type->name()->qualifiedName();
  // One emplace both probes for the key and inserts; emplace never overwrites
  // an existing entry, so on a duplicate the map is unchanged and we report it.
  const bool inserted =
      customClasses().emplace(name, std::move(class_type)).second;
  TORCH_CHECK(
      inserted,
      "Custom class with name ",
      name,
      " is already registered. Ensure that registration with torch::class_ is only called once.");
}

// Looks up a registered custom class by fully qualified name.
//
// @param name  Fully qualified class name.
// @return the registered class type, or nullptr if nothing is registered
//         under `name`.
at::ClassTypePtr getCustomClass(const std::string& name) {
  // BC hack so we can upgrade a binary internally: the legacy SentencePiece
  // name resolves to the class registered under the `fb` sub-namespace.
  if (name == "__torch__.torch.classes.SentencePiece") {
    return getCustomClass("__torch__.torch.classes.fb.SentencePiece");
  }
  const auto& classes = customClasses();
  const auto it = classes.find(name); // single hash lookup
  if (it == classes.end()) {
    return nullptr;
  }
  return it->second;
}

// Returns true iff `v` holds an object whose class type has a name that is
// registered in the custom-class registry.
bool isCustomClass(const c10::IValue& v) {
  if (!v.isObject()) {
    return false;
  }
  // Hold the object and its type in locals rather than re-evaluating the
  // accessor chain for both the null-name check and the lookup.
  const auto obj = v.toObject();
  const auto type = obj->type();
  const auto& type_name = type->name();
  return type_name && getCustomClass(type_name->qualifiedName()) != nullptr;
}

// Returns the list that owns every registered custom-class method.
// Within this file, entries are only appended (by registerCustomClassMethod).
std::vector<std::unique_ptr<jit::Function>>& customClassMethods() {
  static std::vector<std::unique_ptr<jit::Function>> methods;
  return methods;
}

// Transfers ownership of `fn` to the custom-class method list.
//
// @param fn  Method to keep alive for the lifetime of the list.
void registerCustomClassMethod(std::unique_ptr<jit::Function> fn) {
  customClassMethods().emplace_back(std::move(fn));
}

} // namespace torch