-
Notifications
You must be signed in to change notification settings - Fork 0
/
single_main.py
executable file
·125 lines (98 loc) · 3.48 KB
/
single_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#!/usr/bin/env python3
import argparse
import pprint
import torch
from config import get_single_config, load_config, parse_config_arg
from two_step_zoo import (
get_single_module, get_single_trainer, get_loaders_from_config,
get_writer, get_evaluator, get_ood_evaluator
)
# Command-line interface for training (or only testing) a single density
# estimator or generalized autoencoder.
parser = argparse.ArgumentParser(
    description="Single Density Estimation or Generalized Autoencoder Training Module"
)

# What to train. These two are only needed when starting a fresh run.
parser.add_argument("--dataset", type=str,
    help="Dataset to train on. Required if load_dir not specified.")
parser.add_argument("--model", type=str,
    help="Model to train. Required if load_dir not specified.")
parser.add_argument("--is-gae", action="store_true",
    help="Indicates that we are training a generalized autoencoder.")

# Resuming from an existing run directory.
parser.add_argument("--load-dir", type=str, default="",
    help="Directory to load from.")
parser.add_argument("--max-epochs-loaded", type=int,
    help="New maximum shared epochs for loaded model.")
parser.add_argument("--load-best-valid-first", action="store_true",
    help="Load the best_valid checkpoint first")

# Repeatable flag; every occurrence is appended to the list.
parser.add_argument("--config", default=[], action="append",
    help="Override config entries. Specify as `key=value`.")

# Evaluation switches.
parser.add_argument("--only-test", action="store_true",
    help="Only perform a test, no training.")
parser.add_argument("--test-ood", action="store_true",
    help="Perform an OOD test.")

args = parser.parse_args()

# The help text documents --dataset and --model as required whenever --load-dir
# is absent. Enforce that here so a missing flag produces a usage error instead
# of passing None on to the config lookup.
if not args.load_dir:
    missing_flags = [
        flag
        for flag, value in (("--dataset", args.dataset), ("--model", args.model))
        if value is None
    ]
    if missing_flags:
        parser.error(
            "the following arguments are required when --load-dir is not given: "
            + ", ".join(missing_flags)
        )
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the run configuration, either from a previous run directory or from
# the dataset/model named on the command line.
if not args.load_dir:
    base_cfg = get_single_config(
        dataset=args.dataset,
        model=args.model,
        gae=args.is_gae,
        standalone=True,
    )
    # Each --config entry is turned into a (key, value) pair by
    # parse_config_arg; these pairs take precedence over the base values.
    cfg = {**base_cfg, **dict(map(parse_config_arg, args.config))}
else:
    # NOTE: Not updating config values using cmd line arguments (besides max_epochs)
    #       when loading a run.
    cfg = load_config(args=args)
# Print the resolved configuration under a small banner.
# HACK: give the pprint module its own pass-through `sorted`, which shadows the
# builtin inside that module so dict entries keep their existing order.
pprint.sorted = lambda items, key=None: items
pp = pprint.PrettyPrinter(indent=4)
print("-" * 10 + "cfg" + "-" * 10)
pp.pprint(cfg)
# Data loaders for the three splits, all derived from the configuration.
train_loader, valid_loader, test_loader = get_loaders_from_config(cfg)

# Writer for this run; its `logdir` attribute is read later in the script.
writer = get_writer(args, cfg=cfg)

# Build the model from the config, then move it to the selected device.
# NOTE(review): data_dim / data_shape / train_dataset_size are read from cfg
# only after the loaders are created - presumably get_loaders_from_config
# fills them in; confirm before reordering these statements.
module = get_single_module(
    cfg,
    data_dim=cfg["data_dim"],
    data_shape=cfg["data_shape"],
    train_dataset_size=cfg["train_dataset_size"],
)
module = module.to(device)
# Pick the evaluator. The OOD evaluator is used when --test-ood is given or
# when "likelihood_ood_acc" is among the configured test metrics; otherwise
# the regular evaluator is built.
if not args.test_ood and "likelihood_ood_acc" not in cfg["test_metrics"]:
    # When "fd" is the early-stopping metric, make sure it is also listed in
    # the validation metrics (presumably so it gets computed at validation).
    if cfg["early_stopping_metric"] == "fd":
        if "fd" not in cfg["valid_metrics"]:
            cfg["valid_metrics"].append("fd")

    evaluator = get_evaluator(
        module,
        train_loader=train_loader,
        valid_loader=valid_loader,
        test_loader=test_loader,
        valid_metrics=cfg["valid_metrics"],
        test_metrics=cfg["test_metrics"],
        **cfg.get("metric_kwargs", {}),
    )
else:
    evaluator = get_ood_evaluator(
        module,
        cfg=cfg,
        include_low_dim=False,
        valid_loader=valid_loader,
        test_loader=test_loader,
        train_loader=train_loader,
        savedir=writer.logdir,
    )
# The checkpoint prefix is "gae" when cfg["gae"] is truthy and "de" otherwise.
if cfg["gae"]:
    ckpt_prefix = "gae"
else:
    ckpt_prefix = "de"

# Assemble the trainer from the model, loaders, writer and evaluator.
trainer = get_single_trainer(
    module=module,
    ckpt_prefix=ckpt_prefix,
    writer=writer,
    cfg=cfg,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    evaluator=evaluator,
    only_test=args.only_test,
)
# Order in which checkpoints are tried; --load-best-valid-first flips it.
checkpoint_load_list = (
    ["best_valid", "latest"]
    if args.load_best_valid_first
    else ["latest", "best_valid"]
)

# Restore the first checkpoint that exists. A missing file is reported and the
# next candidate is tried; if none is found the script carries on regardless.
for ckpt_name in checkpoint_load_list:
    try:
        trainer.load_checkpoint(ckpt_name)
    except FileNotFoundError:
        print(f"Did not find {ckpt_name} checkpoint")
    else:
        break

trainer.train()