-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathutils.py
63 lines (47 loc) · 1.84 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
import sys
from argparse import Namespace
import git
class Config(Namespace):
    """Expose the keys of a configuration mapping as attributes.

    Every key of ``config`` becomes an attribute of the instance. A dict value
    is wrapped in its own ``Config`` (so nesting is followed recursively). A
    list or tuple value is turned into a list in which each direct dict element
    is wrapped in ``Config``; elements nested more deeply inside the sequence
    are kept as-is.

    Args:
        config: mapping exposing ``.items()`` (e.g. a dict loaded from YAML/JSON).
    """

    def __init__(self, config):
        def wrap(item):
            # Only dicts are promoted to Config; anything else passes through.
            return Config(item) if isinstance(item, dict) else item

        for key, value in config.items():
            if isinstance(value, (list, tuple)):
                # Sequences always come out as lists, with dict elements wrapped.
                converted = list(map(wrap, value))
            else:
                converted = wrap(value)
            setattr(self, key, converted)
class Logger(object):
    """Save terminal outputs to log file, and continue to print on the terminal.

    The stream captured as ``terminal`` is whatever ``sys.stdout`` is at
    construction time. Every ``write`` goes to both the terminal and the log
    file, which makes the object usable as a drop-in replacement for stdout
    (e.g. ``sys.stdout = Logger(path)``).
    """

    def __init__(self, log_filename):
        self.terminal = sys.stdout
        # Append mode keeps earlier runs; buffering=1 makes the file
        # line-buffered so each completed line reaches disk promptly. An
        # explicit UTF-8 encoding avoids locale-dependent UnicodeEncodeErrors.
        self.log = open(log_filename, 'a', buffering=1, encoding='utf-8')

    def write(self, message):
        """Write ``message`` to the terminal and the log; return its length."""
        self.terminal.write(message)
        self.log.write(message)
        # File-like write() is expected to return the number of characters written.
        return len(message)

    def flush(self):
        """Flush both streams.

        Needed because callers such as ``print(..., flush=True)`` and the
        interpreter at shutdown call ``flush()`` on stdout.
        """
        self.terminal.flush()
        # The log may already have been closed via close() while this object
        # is still installed as stdout; flushing a closed file would raise.
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Flush, fsync and close the log file. Safe to call more than once.

        The terminal stream is not closed.
        """
        if self.log.closed:
            return
        try:
            self.log.flush()
            # Force the OS to commit the data to disk, not just its page cache.
            os.fsync(self.log.fileno())
        finally:
            # Release the file handle even if fsync fails.
            self.log.close()
def get_git_hash(path=None):
    """Return the full SHA-1 hex digest of the commit currently checked out (HEAD).

    Args:
        path: any path inside the repository. Defaults to ``None``, which lets
            GitPython use the current working directory (as before).

    Returns:
        str: 40-character hexadecimal commit hash of HEAD.
    """
    # search_parent_directories lets this work when the process is started
    # from a subdirectory of the repository, not only from its root.
    # The context manager releases the repo's cached file handles/processes.
    with git.Repo(path, search_parent_directories=True) as repository:
        return repository.head.object.hexsha
def format_time(s):
    """Convert time in seconds to time in hours, minutes and seconds.

    Args:
        s: duration in seconds (int or float). Fractional seconds are
            truncated toward zero.

    Returns:
        str: ``'HHhMMmSSs'``, e.g. ``format_time(3661) == '01h01m01s'``.
        Hours are not capped, so 100 h or more yields extra digits. A negative
        duration is rendered as its magnitude with a leading ``'-'``.
    """
    total = int(s)
    # divmod floors toward -inf, so a negative total would give nonsense like
    # '-1h59m59s' for -1; format the magnitude and prefix the sign instead.
    sign = '-' if total < 0 else ''
    m, sec = divmod(abs(total), 60)
    h, m = divmod(m, 60)
    return f'{sign}{h:02d}h{m:02d}m{sec:02d}s'
def print_model_spec(model, name=''):
    """Print total and trainable parameter counts of ``model`` in millions.

    Args:
        model: object accepted by ``count_n_parameters`` (presumably a
            ``torch.nn.Module`` — TODO confirm with callers).
        name: label printed after "Model"; empty by default.
    """
    total = count_n_parameters(model)
    trainable = count_n_parameters(model, only_trainable=True)
    # The trailing '\n' plus print's own newline leaves a blank line after the summary.
    message = f'Model {name}: {total:.2f}M parameters of which {trainable:.2f}M are trainable.\n'
    print(message)
def count_n_parameters(model, only_trainable=False):
    """Return the number of parameters in ``model``, expressed in millions.

    Args:
        model: object whose ``parameters()`` yields items with a ``numel()``
            method and a ``requires_grad`` flag (presumably a
            ``torch.nn.Module`` — TODO confirm with callers).
        only_trainable: if True, count only parameters with ``requires_grad`` set.

    Returns:
        float: total element count divided by 10**6 (0.0 if there are none).
    """
    params = model.parameters()
    if only_trainable:
        params = (p for p in params if p.requires_grad)
    # Feed sum() a generator rather than building a throwaway list.
    return sum(p.numel() for p in params) / 10**6