-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathexpression_visualizer.cpp
73 lines (56 loc) · 2.16 KB
/
expression_visualizer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/*
* Copyright 2020 Casey Sanchez
*/
#include "expression_visualizer.hpp"
ExpressionVisualizer::ExpressionVisualizer(std::variant<Scalar, Matrix> const &node_variant, std::map<std::string, std::variant<Scalar, Matrix>> const &node_map) : m_node_variant(node_variant), m_node_map(node_map)
{
}
std::string ExpressionVisualizer::Visualize() const
{
std::ostringstream ostringstream;
Visualize(ostringstream);
return ostringstream.str();
}
void ExpressionVisualizer::Visualize(std::ostream &ostream) const
{
Visualize(ostream, m_node_variant, 0);
}
void ExpressionVisualizer::Visualize(std::ostream &ostream, std::variant<Scalar, Matrix> const &node_variant, size_t const &depth) const
{
ostream << std::string(depth * 4, ' ');
if (std::holds_alternative<Matrix>(node_variant)) {
Matrix matrix = std::get<Matrix>(node_variant);
ostream << "[Matrix] ";
ostream << matrix << std::endl;
for (size_t i = 0; i < matrix.Rows(); ++i) {
for (size_t j = 0; j < matrix.Cols(); ++j) {
Visualize(ostream, matrix(i, j), depth + 1);
}
}
}
else if (std::holds_alternative<Scalar>(node_variant)) {
Scalar scalar = std::get<Scalar>(node_variant);
ostream << "[" << scalar->Type() << "] ";
auto node_it = std::find_if(std::cbegin(m_node_map), std::cend(m_node_map),
[&scalar](std::pair<std::string, std::variant<Scalar, Matrix>> const &node_pair) {
if (std::holds_alternative<Scalar>(node_pair.second)) {
return std::get<Scalar>(node_pair.second) == scalar;
}
return false;
});
if (node_it != std::cend(m_node_map)) {
ostream << node_it->first << std::endl;
}
else {
ostream << scalar << std::endl;
}
for (auto const &argument_ptr : scalar->Arguments()) {
Visualize(ostream, argument_ptr, depth + 1);
}
}
}
/// Stream insertion for ExpressionVisualizer: writes the tree rendering
/// produced by Visualize(std::ostream &) and returns the stream for chaining.
std::ostream &operator<<(std::ostream &ostream, ExpressionVisualizer const &expression_visualizer)
{
    std::ostream &destination = ostream;
    expression_visualizer.Visualize(destination);
    return destination;
}