-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_mmr.py
50 lines (37 loc) · 1.65 KB
/
run_mmr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
## TODO: merge this file with whi_experiment.py; lots of overlapping code
from distutils.util import strtobool
from turtle import pos
import pandas as pd
import numpy as np
from hydra.utils import instantiate
import hydra
from omegaconf import OmegaConf
import pprint
@hydra.main(version_base=None, config_path="config", config_name="default")
def main(cfg):
    """Run ``cfg.num_iters`` simulate-then-run-model iterations and save results.

    Each iteration builds a data object from ``cfg.data`` with its own
    confounder/noise seeds, generates the datasets, builds a model from
    ``cfg.model`` and runs it. The model's results for that iteration are
    appended as a DataFrame. After every iteration, all results collected so
    far are concatenated and written to ``f"{cfg.model.censoring_type}.csv"``.
    An interrupted run therefore still leaves a CSV behind.

    Args:
        cfg: Hydra/OmegaConf config. Keys read here: ``seed``, ``num_iters``,
            ``data``, ``model`` (including ``model.censoring_type``),
            ``alpha`` and ``falsification_type``.

    Returns:
        list[pandas.DataFrame]: one DataFrame per iteration.
    """
    rng = np.random.default_rng(cfg["seed"])
    # Seeds are drawn with replacement from 0..999, so two iterations may share
    # a seed. Confounder seeds are drawn before noise seeds; keep this order so
    # results stay reproducible for a given cfg.seed.
    confounder_seeds = rng.choice(range(1000), size=(cfg.num_iters,))
    noise_seeds = rng.choice(range(1000), size=(cfg.num_iters,))

    results = []
    for iter_ in range(cfg.num_iters):
        print(f"Simulation Number {iter_ + 1}")

        # Part 1: data simulation.
        print("data generation parameters:")
        print(OmegaConf.to_yaml(cfg.data))
        # NOTE(review): the result of instantiate() is called again with the
        # seeds, so cfg.data presumably sets `_partial_: true` — confirm.
        data_cls = instantiate(cfg.data)(
            confounder_seed=confounder_seeds[iter_],
            noise_seed=noise_seeds[iter_],
        )
        data_cls.generate_dataset()
        data_dicts = data_cls.get_datasets()

        # Part 2: build the model (same partial-instantiation pattern as above)
        # and run it on this iteration's data.
        model_cls = instantiate(cfg.model)(
            seed=cfg["seed"],
            oracle_params=data_cls.oracle_params,
        )
        results_ = model_cls.run(
            data_cls,
            data_dicts,
            alpha=cfg.alpha,
            iter_=iter_,
            falsification_type=cfg.falsification_type,
        )
        # Show this iteration's results, not the accumulated list.
        pprint.pprint(results_, sort_dicts=False)
        results.append(pd.DataFrame(results_))

        # Checkpoint: rewrite the CSV with everything collected so far.
        pd.concat(results).to_csv(f"{cfg.model.censoring_type}.csv")
    return results
if __name__ == "__main__":
    # Hydra's decorator parses the command line and config, then calls main(cfg).
    main()