-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
68 lines (54 loc) · 1.93 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Standard libraries
import os
import logging
import wandb
# For downloading pre-trained models
import urllib.request
from urllib.error import HTTPError
from pytorch_lightning.loggers import WandbLogger
# PyTorch Lightning
import pytorch_lightning as pl
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from utils.agent_utils import parse_params
from config.hparams import Parameters
from Agents.BaseTrainer import BaseTrainer
def _init_wandb(parameters, wdb_config, job_type):
    """Start a wandb run for this job, logged to the configured project/entity.

    Args:
        parameters: Parsed ``Parameters`` object. ``hparams.wandb_project`` and
            ``hparams.wandb_entity`` select where the run is logged.
        wdb_config: Config returned by ``parse_params(parameters)``; recorded
            as the run's config.
        job_type: wandb job-type label (``"train"`` or ``"test"`` in ``main``).
    """
    # TODO: log the full parameters (vars(parameters)) instead of the subset
    # produced by parse_params.
    wandb.init(
        config=wdb_config,
        project=parameters.hparams.wandb_project,
        entity=parameters.hparams.wandb_entity,
        allow_val_change=True,
        job_type=job_type,
    )


def main():
    """Script entry point: train, or run prediction.

    If ``parameters.hparams.train`` is set:
        1. Start a ``"train"`` wandb run.
        2. Build a Lightning ``WandbLogger`` and pass it to ``BaseTrainer``.
        3. Call ``agent.run()``, unless ``only_create_abstract_embeddings`` or
           ``only_create_keywords`` is set in ``parameters.data_param``.

    Otherwise, start a ``"test"`` wandb run and call ``agent.predict()``.
    """
    parameters = Parameters.parse()
    # Config subset recorded in wandb (see TODO in _init_wandb).
    wdb_config = parse_params(parameters)

    if parameters.hparams.train:
        _init_wandb(parameters, wdb_config, job_type="train")
        # NOTE(review): a wandb run is already active at this point, so
        # WandbLogger presumably attaches to it rather than starting a second
        # run — confirm; the explicit wandb.init above may be redundant.
        wandb_run = WandbLogger(
            config=wdb_config,
            project=parameters.hparams.wandb_project,
            entity=parameters.hparams.wandb_entity,
            allow_val_change=True,
        )
        agent = BaseTrainer(parameters, wandb_run)
        data_param = parameters.data_param
        # NOTE(review): when either "only create" flag is set, nothing runs
        # beyond constructing BaseTrainer; presumably its constructor does the
        # embedding/keyword creation — confirm.
        if not (data_param.only_create_abstract_embeddings
                or data_param.only_create_keywords):
            agent.run()
    else:
        _init_wandb(parameters, wdb_config, job_type="test")
        agent = BaseTrainer(parameters)
        agent.predict()
# Run the entry point only when executed as a script, not on import.
if __name__ == '__main__':
    main()