-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
166 lines (140 loc) · 4.44 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import ast
import os
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import List

from IPython import embed
from omegaconf import OmegaConf
@dataclass
class BaseConfig:
    """Run-level bookkeeping options shared by every experiment."""
    save_model: bool = True    # persist model checkpoints to disk
    save_dir: str = 'saves'    # root directory for experiment outputs
    tensorboard: bool = False  # enable TensorBoard logging — presumably; confirm in trainer
    gpus: str = '0'            # GPU id string, e.g. '0' or '0,1' — verify against launcher
    # name of the experiment: used as directory names under 'save_dir'
    exp_name: str = 'sample'
    verbose: int = 0           # logging verbosity level
    seed: int = 2021           # RNG seed for reproducibility
    # model save interval (epochs — presumably; confirm in training loop)
    ckpt_interval: int = 1
    # validation interval
    eval_interval: int = 1
@dataclass
class ExpConfig:
    """Optimization / training hyper-parameters."""
    batch_size: int = 32
    num_epochs: int = 5
    learning_rate: float = 1e-5
    threshold: float = 0.5       # decision threshold — presumably for binarizing predictions; confirm
    warmup_ratio: float = 0.2    # fraction of training used for LR warmup — verify in scheduler
    max_grad_norm: int = 1       # gradient clipping norm
    # Noisy-label / sample-selection settings (co-teaching style — TODO confirm against trainer):
    truncated_loss: bool = True
    forget_rate: float = 0.3
    num_gradual: int = 3
    exponent: float = 2
    prune_epoch: int = 1         # epoch at which pruning starts — presumably; confirm
    probability: float = 0.4
    # Mixture-model fitting settings (look like GaussianMixture arguments — verify):
    n_components: int = 2
    max_iter: int = 100
    tol: float = 1e-3
    reg_covar: float = 1e-6
@dataclass
class DatasetConfig:
    """Dataset construction options for the main (pointwise) setup."""
    datadir: str = './data'        # root directory containing the dataset files
    debug: bool = False            # presumably loads a reduced dataset — confirm in loader
    max_len: int = 120             # maximum tokenized sequence length
    phr_sep: bool = False          # insert a separator token between phrases — verify
    phr_sep_token: str = '[SEP]'   # token used when phr_sep is enabled
    augment_type: str = 'recon'    # augmentation strategy name — meaning defined by the dataset code
    augment_ratio: float = 0.0     # fraction of samples augmented — presumably; confirm
    augment_lambda: float = 1.0    # weight applied to augmented samples — verify
    num_truncate: int = 1
    num_negatives: int = 1         # negatives sampled per example — presumably; confirm
    num_sample: int = 1
@dataclass
class ModelConfig:
    """Architecture options for the main (pointwise) model."""
    # HF model name; known alternatives: monologg/koelectra-base-v3-discriminator, etribert
    bert_model: str = 'monologg/koelectra-base-v3-discriminator'
    fc_layers: list = field(default_factory=lambda: [])  # hidden sizes of the classifier head; empty = single linear layer — verify
    dropout: float = 0.2
    cls_enhanced: bool = False
    pred_num_core: bool = False    # auxiliary prediction head — presumably; confirm in model code
    core_lambda: float = 0.0       # weight of that auxiliary loss — verify
    contrastive: bool = False      # enable contrastive objective
    cont_temp: float = 0.01        # contrastive temperature
    cont_lambda: float = 1.0       # contrastive loss weight
    tree_transformer: bool = False # enable tree-transformer layers
    num_tree_layers: int = 1
    num_tree_heads: int = 12
    truncated_loss: bool = True
    pairwise: bool = True
@dataclass
class PairDatasetConfig:
    """Dataset options for the pairwise setup."""
    # Same as DatasetConfig above
    datadir: str = './data'
    dataset: str = 'naver'
    debug: bool = False
    # doubled because inputs are pairs — NOTE(review): value equals DatasetConfig.max_len; confirm intent
    max_len: int = 120
    num_negatives: int = 1
    num_sample: int = 1
@dataclass
class PairModelConfig:
    """Architecture options for the pairwise model."""
    # HF model name; known alternatives: monologg/koelectra-base-v3-discriminator, etribert
    bert_model: str = 'monologg/koelectra-base-v3-discriminator'
    fc_layers: list = field(default_factory=lambda: [])  # hidden sizes of the classifier head; empty = single linear layer — verify
    dropout: float = 0.2
    pooling: bool = False   # pool token representations — presumably; confirm in model code
    pairwise: bool = False
def load_config():
    """Build the default structured config with base/dataset/model/exp sections."""
    sections = (
        ('base', BaseConfig),
        ('dataset', DatasetConfig),
        ('model', ModelConfig),
        ('exp', ExpConfig),
    )
    # Merge order matches the section order above.
    return OmegaConf.merge(
        *(OmegaConf.structured({name: schema}) for name, schema in sections)
    )
def load_pair_config():
    """Build the default structured config for the pairwise setup."""
    sections = (
        ('base', BaseConfig),
        ('dataset', PairDatasetConfig),
        ('model', PairModelConfig),
        ('exp', ExpConfig),
    )
    # Merge order matches the section order above.
    return OmegaConf.merge(
        *(OmegaConf.structured({name: schema}) for name, schema in sections)
    )
def load_config_from_path(yaml_config_path):
    """Load an OmegaConf configuration from a YAML file.

    Raises:
        FileNotFoundError: if yaml_config_path does not exist.
    """
    if not os.path.exists(yaml_config_path):
        raise FileNotFoundError(f'Config file not found in {yaml_config_path}')
    with open(yaml_config_path, "r") as f:
        return OmegaConf.load(f)
def ensure_value_type(v):
    """Best-effort coercion of a string parameter to its literal Python value.

    '3' -> 3, '0.5' -> 0.5, '[1, 2]' -> [1, 2], 'true'/'True' -> True.
    Strings that do not parse as a Python literal are returned unchanged;
    non-string inputs pass through untouched.
    """
    BOOLEAN = {'false': False, 'False': False,
               'true': True, 'True': True}
    if not isinstance(v, str):
        return v
    try:
        # ast.literal_eval parses only Python literals, unlike the previous
        # eval(), which would execute arbitrary expressions coming from the
        # command line / parameter dict (security and correctness hazard).
        value = ast.literal_eval(v)
        if not isinstance(value, (str, int, float, list, tuple)):
            value = v
    except (ValueError, SyntaxError, TypeError):
        # Not a Python literal: map lowercase boolean spellings, otherwise
        # keep the raw string.
        value = BOOLEAN.get(v, v)
    return value
def update_params(conf, params):
    """Overwrite matching keys in any section of `conf` with values from `params`.

    Each key in `params` is searched across the config sections in order; the
    first section containing the key receives the type-coerced value. Keys not
    found in any section are reported on stdout and skipped.

    Returns the (mutated) `conf` for convenience.
    """
    # NOTE: plain dicts preserve insertion order on Python 3.7+, so the
    # previous copy into an OrderedDict was redundant; a shallow dict copy
    # keeps the original's "iterate over a copy" behavior.
    for k, v in dict(params).items():
        for section in conf.keys():
            if k in conf[section]:
                conf[section][k] = ensure_value_type(v)
                break
        else:
            # key matched no section — report and continue
            print('Parameter not updated. \'%s\' not exists.' % k)
    return conf
if __name__ == '__main__':
    # Smoke test: print the default configuration.
    from pprint import pprint
    pprint(load_config())