-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcsv_output.py
83 lines (64 loc) · 2.42 KB
/
csv_output.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""A `dowel.logger.LogOutput` for CSV files."""
import csv
import warnings
from dowel import TabularInput
from dowel.simple_outputs import FileOutput
from dowel.utils import colorize
class CsvOutput(FileOutput):
    """CSV file output for logger.

    The CSV columns are fixed by the first non-empty TabularInput passed to
    :meth:`record`: at that point a header row containing its keys, sorted,
    is written. Every later record is still written as a row, but keys that
    are not in the header are dropped (``extrasaction="ignore"``) and header
    columns missing from the record are written empty (csv.DictWriter's
    default ``restval``). Each distinct key mismatch triggers at most one
    :class:`CsvOutputWarning`.

    :param file_name: The file this output should log to.
    """

    def __init__(self, file_name):
        super().__init__(file_name)
        # csv.DictWriter bound to self._log_file; created lazily by record()
        # once the first non-empty TabularInput fixes the column set.
        self._writer = None
        # Set of column names taken from that first non-empty record.
        self._fieldnames = None
        # Warning messages already emitted, so each is shown at most once.
        self._warned_once = set()
        # When True, _warn() emits no warnings.
        self._disable_warnings = False

    @property
    def types_accepted(self):
        """Accept TabularInput objects only."""
        return (TabularInput,)

    def record(self, data, prefix=""):
        """Log tabular data to CSV.

        The first non-empty record fixes the columns and writes the header.
        An empty record that arrives before that is skipped without writing
        anything. After a row is written, every key of the record is marked
        on ``data`` via ``data.mark``.

        :param data: The TabularInput to write.
        :param prefix: Unused by this output.
        :raises ValueError: If ``data`` is not a TabularInput.
        """
        if not isinstance(data, TabularInput):
            raise ValueError("Unacceptable type.")

        to_csv = data.as_primitive_dict
        keys = to_csv.keys()

        if self._writer is None:
            if not keys:
                # No columns to build a header from yet; wait for real data.
                return
            self._fieldnames = set(keys)
            self._writer = csv.DictWriter(
                self._log_file,
                fieldnames=sorted(self._fieldnames),
                # Later records may carry extra keys; drop them instead of
                # letting DictWriter raise ValueError.
                extrasaction="ignore",
            )
            self._writer.writeheader()

        # as_primitive_dict is presumably a dict, whose keys view compares
        # like a set, so only membership (not key order) matters here.
        if keys != self._fieldnames:
            self._warn(
                "Inconsistent TabularInput keys detected. "
                "CsvOutput keys: {}. "
                "TabularInput keys: {}. "
                "Did you change key sets after your first "
                "logger.log(TabularInput)?".format(
                    set(self._fieldnames), set(keys)
                )
            )

        self._writer.writerow(to_csv)
        for key in keys:
            data.mark(key)

    def _warn(self, msg):
        """Warn the user once per distinct message using warnings.warn.

        ``stacklevel=3`` skips this helper and :meth:`record`, so the warning
        is attributed to whoever called :meth:`record` (intended to be the
        ``logger.log`` call).

        :param msg: The warning text; it is passed through
            ``colorize(msg, "yellow")`` before being emitted.
        :returns: ``msg``, whether or not a warning was actually emitted.
        """
        if not self._disable_warnings and msg not in self._warned_once:
            warnings.warn(colorize(msg, "yellow"), CsvOutputWarning, stacklevel=3)
            self._warned_once.add(msg)
        return msg

    def disable_warnings(self):
        """Disable logger warnings for testing.

        After this call, :meth:`_warn` emits no warnings.
        """
        self._disable_warnings = True
class CsvOutputWarning(UserWarning):
    """Warning category emitted by :class:`CsvOutput`.

    Because it is a dedicated ``UserWarning`` subclass, callers can filter
    CSV-output warnings separately from other warnings, e.g. with
    :func:`warnings.simplefilter`.
    """