-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtabular_input.py
152 lines (120 loc) · 4.62 KB
/
tabular_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
"""A `dowel.logger` input for tabular (key-value) data."""
import contextlib
import warnings
import numpy as np
import tabulate
from dowel.utils import colorize
class TabularInput:
    """This class allows the user to create tables for easy display.

    TabularInput may be passed to the logger via its log() method.

    Entries are stored as key/value pairs. Keys recorded while one or more
    prefixes are active (see :meth:`prefix`) are stored with the concatenated
    prefixes prepended. Keys can be marked as "recorded"; :meth:`clear` warns
    about every entry that was never marked before discarding the table.
    """

    def __init__(self):
        # Key/value entries of the table, keyed by the prefixed key string.
        self._dict = {}
        # Keys that have been marked via mark()/mark_str()/mark_all().
        self._recorded = set()
        # Stack of active prefixes and their cached concatenation.
        self._prefixes = []
        self._prefix_str = ""
        # Warning messages already seen, so each one is emitted at most once.
        self._warned_once = set()
        self._disable_warnings = False

    def __str__(self):
        """Return a string representation of the table for the logger.

        Only scalar entries (see :attr:`as_primitive_dict`) are included,
        sorted by key.
        """
        return tabulate.tabulate(
            sorted(self.as_primitive_dict.items(), key=lambda x: x[0])
        )

    def record(self, key, val):
        """Save key/value entries for the table.

        :param key: Key corresponding to the value. It is converted with
            str() and the currently active prefix string is prepended.
        :param val: Value that is to be stored in the table.
        """
        self._dict[self._prefix_str + str(key)] = val

    def mark(self, key):
        """Mark key as recorded.

        :param key: The key to mark. It is used as given; the active prefix
            is not prepended here.
        """
        self._recorded.add(key)

    def mark_str(self):
        """Mark keys in the primitive dict."""
        self._recorded.update(self.as_primitive_dict.keys())

    def mark_all(self):
        """Mark all keys."""
        self._recorded.update(self._dict.keys())

    def record_misc_stat(self, key, values, placement="back"):
        """Record statistics of an array.

        Records the average, standard deviation, median, minimum and maximum
        of ``values``. If ``values`` is None or has no elements, NaN is
        recorded for every statistic instead.

        :param key: String key corresponding to the values.
        :param values: Array of values to be analyzed.
        :param placement: Where the statistic name goes relative to ``key``:
            "front" yields e.g. ``"Average" + key``; any other value
            (default "back") yields ``key + "Average"``.
        """
        if placement == "front":
            front, back = "", key
        else:
            front, back = key, ""

        # Decide emptiness by element count rather than truthiness: the truth
        # value of a numpy array with more than one element is ambiguous and
        # raises ValueError.
        has_data = values is not None and np.size(values) > 0

        stats = (
            ("Average", np.average),
            ("Std", np.std),
            ("Median", np.median),
            ("Min", np.min),
            ("Max", np.max),
        )
        for name, func in stats:
            stat = func(values) if has_data else np.nan
            self.record(front + name + back, stat)

    @contextlib.contextmanager
    def prefix(self, prefix):
        """Handle pushing and popping of a tabular prefix.

        Can be used in the following way:

        with tabular.prefix('your_prefix_'):
            # your code
            tabular.record(key, val)

        :param prefix: The string prefix to be prepended to logs.
        """
        self.push_prefix(prefix)
        try:
            yield
        finally:
            # Always restore the previous prefix, even if the body raised.
            self.pop_prefix()

    def clear(self):
        """Clear the tabular.

        Warns about every entry whose key was never marked as recorded, then
        empties both the table and the set of recorded keys.
        """
        # Warn if something wasn't logged
        for k, v in self._dict.items():
            if k not in self._recorded:
                warning = (
                    "TabularInput {{{}: type({})}} was not accepted by any "
                    "output".format(k, type(v).__name__)
                )
                self._warn(warning)

        self._dict.clear()
        self._recorded.clear()

    def push_prefix(self, prefix):
        """Push prefix to be appended before printed table.

        :param prefix: The string prefix to be prepended to logs.
        """
        self._prefixes.append(prefix)
        self._prefix_str = "".join(self._prefixes)

    def pop_prefix(self):
        """Pop prefix that was appended to the printed table.

        Raises IndexError if no prefix is currently pushed.
        """
        del self._prefixes[-1]
        self._prefix_str = "".join(self._prefixes)

    @property
    def as_primitive_dict(self):
        """Return the dictionary, excluding all nonprimitive types.

        An entry is kept when ``np.isscalar`` is true for its value.
        """
        return {key: val for key, val in self._dict.items() if np.isscalar(val)}

    @property
    def as_dict(self):
        """Return a dictionary of the tabular items.

        This is the internal dictionary itself, not a copy.
        """
        return self._dict

    def _warn(self, msg):
        """Warns the user using warnings.warn.

        The stacklevel parameter needs to be 3 to ensure the call to logger.log
        is the one printed.

        :param msg: The warning message. A given message is emitted at most
            once, and never while warnings are disabled.
        :return: The message that was passed in.
        """
        if not self._disable_warnings and msg not in self._warned_once:
            warnings.warn(colorize(msg, "yellow"), TabularInputWarning, stacklevel=3)
        self._warned_once.add(msg)
        return msg

    def disable_warnings(self):
        """Disable logger warnings for testing."""
        self._disable_warnings = True
class TabularInputWarning(UserWarning):
    """Warning category for TabularInput-related warnings."""