-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcomputational_graph_map.cpp
87 lines (50 loc) · 2.89 KB
/
computational_graph_map.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#include "tensor.h"
#include "computational_graph_map.h"
#include "m_algorithms_register.h"
#include <assert.h>
#include <utility>
namespace NeuralNetwork {
namespace Computation {
namespace Graph {
// Highest tensor id minted so far. Starts at 0, and _obtain_tensor_id()
// pre-increments before handing an id out, so 0 itself is never issued.
TensorID ComputationalGraphMap::tensor_id{TensorID(0)};
/// Fetch the FunctionObject stored for a tensor id.
///
/// @param my_tensor_id  id expected to lie in (0, tensor_id]. The range is
///                      enforced with assert() only, so NDEBUG builds fall
///                      through to op_registry.at().
/// @return a copy of the stored FunctionObject (returned by value).
///
/// NOTE(review): assuming op_registry is a standard container, .at() throws
/// std::out_of_range on a bad index; inside this noexcept function that ends
/// in std::terminate rather than a catchable error.
FunctionObject ComputationalGraphMap::_get_operation(TensorID my_tensor_id) noexcept {
    // Id 0 is never minted, and anything above tensor_id was never issued.
    assert(my_tensor_id > TensorID(0) && "Must be an op_id greater than 0.");
    assert(my_tensor_id <= tensor_id && "OP registry not this large");
    auto slot = my_tensor_id.get();
    return op_registry.at(slot);
}
/// Fetch the tensor registered under a given id, tracing the lookup to stdout.
///
/// @param my_tensor_id  id expected to lie in (0, tensor_id]; checked with
///                      assert() only.
/// @return the shared_ptr held in tensor_registry for that slot. This may be
///         nullptr if the id was recovered and not yet re-registered.
std::shared_ptr<Tensor> ComputationalGraphMap::_get_tensor(TensorID my_tensor_id) noexcept {
    auto slot = my_tensor_id.get();
    // The trace is flushed (std::endl) before the asserts below run, so the
    // offending id is still visible on stdout if one of them aborts.
    std::cout << "Get Tensor ID: " << slot << std::endl;
    assert(my_tensor_id > TensorID(0) && "Must be an op_id greater than 0.");
    assert(my_tensor_id <= tensor_id && "OP registry not this large");
    return tensor_registry.at(slot);
}
/// Hand a tensor id back to the free pool so _obtain_tensor_id() can reuse it.
///
/// @param my_tensor_id  an id previously issued by _obtain_tensor_id(); must
///                      lie in (0, tensor_id], the same contract the getters
///                      assert.
///
/// Side effects: pushes the id onto recovered_tensor_id and resets the id's
/// tensor_registry slot to nullptr, dropping the registry's reference to the
/// tensor. The matching op_registry slot is left as-is.
///
/// NOTE(review): nothing here detects a double recovery. Recovering the same
/// id twice puts it in the pool twice, and two future tensors would then share
/// one id. Callers must recover each issued id exactly once.
void ComputationalGraphMap::_recover_tensor_id(TensorID my_tensor_id) noexcept {
    // Without these checks, id 0 (never minted, rejected by every getter), or
    // an id above tensor_id (which ++tensor_id will mint again later), could
    // enter the free pool and corrupt id allocation.
    assert(my_tensor_id > TensorID(0) && "Must be an op_id greater than 0.");
    assert(my_tensor_id <= tensor_id && "OP registry not this large");
    recovered_tensor_id.push(my_tensor_id);
    tensor_registry.at(my_tensor_id.get()) = nullptr;
}
/// Produce a tensor id for a new tensor.
///
/// A previously recovered id is reused when one is available (and the reuse is
/// logged to stdout). Otherwise tensor_id is pre-incremented and the new value
/// is returned.
///
/// NOTE(review): the pool is read via top()/pop(), presumably a std::stack
/// (LIFO), so confirm the container type. No overflow guard exists on
/// ++tensor_id, and the registries are not grown here; _register_operation
/// indexes them with .at(), so their capacity must already cover new ids.
TensorID ComputationalGraphMap::_obtain_tensor_id() noexcept {
    if (recovered_tensor_id.empty()) {
        // Nothing to recycle: mint a brand-new id one past the current maximum.
        return ++tensor_id;
    }

    TensorID reused_id = recovered_tensor_id.top();
    recovered_tensor_id.pop();
    std::cout << "Recovered Registry: OP[" << reused_id.get() << "]" << std::endl;
    return reused_id;
}
/// Store a tensor and its producing operation under the tensor's own id.
///
/// @param _t     tensor to register; must be non-null. Its get_tensor_id()
///               selects the registry slot. Taken by value and moved into
///               tensor_registry.
/// @param _node  operation copy-assigned into the matching op_registry slot.
/// @return the slot id used (the tensor's id).
///
/// Logs "Updated Operation: OP[<id>]" to stdout after both registries are
/// written. Bounds are checked with assert() only. In NDEBUG builds, a bad
/// index reaches .at(), which (assuming standard containers) throws and, in
/// this noexcept function, terminates.
TensorID ComputationalGraphMap::_register_operation(std::shared_ptr<Tensor> _t, FunctionObject& _node) noexcept {
    assert(_t && "Cannot register a null tensor.");
    TensorID my_tensor_id = _t->get_tensor_id();
    // Slot 0 is never minted and is rejected by _get_tensor/_get_operation, so
    // a tensor stored there would be unreachable.
    assert(my_tensor_id > TensorID(0) && "Must be an op_id greater than 0.");
    assert(my_tensor_id <= tensor_id && "OP registry not this large");
    op_registry.at(my_tensor_id.get()) = _node;
    // _t is a by-value sink and is not used after this point: move it instead
    // of paying for an extra atomic refcount increment/decrement.
    tensor_registry.at(my_tensor_id.get()) = std::move(_t);
    std::cout << "Updated Operation: OP[" << my_tensor_id.get() << "]" << std::endl;
    return my_tensor_id;
}
}
}
}