main.py (forked from xzenglab/KGNN)
# -*- coding: utf-8 -*-
import os
import gc
import time
import numpy as np
from collections import defaultdict
from keras import backend as K
from keras import optimizers
from utils import load_data, pickle_load, format_filename, write_log
from models import KGCN
from config import ModelConfig, PROCESSED_DATA_DIR, ENTITY_VOCAB_TEMPLATE, \
    RELATION_VOCAB_TEMPLATE, ADJ_ENTITY_TEMPLATE, ADJ_RELATION_TEMPLATE, LOG_DIR, PERFORMANCE_LOG, \
    DRUG_VOCAB_TEMPLATE
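
# Make only GPU 1 visible to the backend; this must run before the first CUDA
# context is created. Set to '' to force CPU execution.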
os.environ['CUDA_VISIBLE_DEVICES'] = '1'


def get_optimizer(op_type, learning_rate):
    if op_type == 'sgd':
        return optimizers.SGD(learning_rate)
    elif op_type == 'rmsprop':
        return optimizers.RMSprop(learning_rate)
    elif op_type == 'adagrad':
        return optimizers.Adagrad(learning_rate)
    elif op_type == 'adadelta':
        return optimizers.Adadelta(learning_rate)
    elif op_type == 'adam':
        return optimizers.Adam(learning_rate, clipnorm=5)
    else:
        raise ValueError('Optimizer Not Understood: {}'.format(op_type))
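
# Usage sketch (illustrative, not part of the original script):
#   opt = get_optimizer('adam', 5e-4)   # Adam with gradient clipping (clipnorm=5)
#   opt = get_optimizer('sgd', 1e-2)    # plain SGD
# Any other name, e.g. get_optimizer('nadam', 1e-3), raises ValueError.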


def train(train_d, dev_d, test_d, kfold, dataset, neighbor_sample_size, embed_dim, n_depth,
          l2_weight, lr, optimizer_type, batch_size, aggregator_type, n_epoch,
          callbacks_to_add=None, overwrite=True):
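    """Train and evaluate a KGCN model on one cross-validation fold.

    Builds a ModelConfig from the given hyper-parameters, fits the model on
    train_d, restores the best checkpoint, reports AUC/ACC/F1/AUPR on dev_d and
    test_d (plus SWA variants when enabled), appends everything to the
    performance log, and returns the log dict.
    """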
    config = ModelConfig()
    config.neighbor_sample_size = neighbor_sample_size
    config.embed_dim = embed_dim
    config.n_depth = n_depth
    config.l2_weight = l2_weight
    config.dataset = dataset
    config.K_Fold = kfold
    config.lr = lr
    config.optimizer = get_optimizer(optimizer_type, lr)
    config.batch_size = batch_size
    config.aggregator_type = aggregator_type
    config.n_epoch = n_epoch
    config.callbacks_to_add = callbacks_to_add or []  # avoid crashing on None below
    config.drug_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                             DRUG_VOCAB_TEMPLATE,
                                                             dataset=dataset)))
    config.entity_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                               ENTITY_VOCAB_TEMPLATE,
                                                               dataset=dataset)))
    config.relation_vocab_size = len(pickle_load(format_filename(PROCESSED_DATA_DIR,
                                                                 RELATION_VOCAB_TEMPLATE,
                                                                 dataset=dataset)))
    config.adj_entity = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_ENTITY_TEMPLATE,
                                                dataset=dataset))
    config.adj_relation = np.load(format_filename(PROCESSED_DATA_DIR, ADJ_RELATION_TEMPLATE,
                                                  dataset=dataset))
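    # adj_entity / adj_relation are precomputed neighbor lookup tables built by
    # the preprocessing step; in the KGCN design each is indexed by entity id
    # and holds neighbor_sample_size sampled neighbor entities / connecting
    # relations (shape inferred from the config fields above).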
    config.exp_name = f'kgcn_{dataset}_neigh_{neighbor_sample_size}_embed_{embed_dim}_depth_' \
                      f'{n_depth}_agg_{aggregator_type}_optimizer_{optimizer_type}_lr_{lr}_' \
                      f'batch_size_{batch_size}_epoch_{n_epoch}'
    callback_str = '_' + '_'.join(config.callbacks_to_add)
    # Drop 'modelcheckpoint' and 'earlystopping' from the experiment name: with
    # SWA, the final weights come from averaging rather than from these callbacks.
    callback_str = callback_str.replace('_modelcheckpoint', '').replace('_earlystopping', '')
    config.exp_name += callback_str
    train_log = {'exp_name': config.exp_name, 'batch_size': batch_size, 'optimizer': optimizer_type,
                 'epoch': n_epoch, 'learning_rate': lr}
    print('Logging Info - Experiment: %s' % config.exp_name)
    model_save_path = os.path.join(config.checkpoint_dir, '{}.hdf5'.format(config.exp_name))
    model = KGCN(config)

    train_data = np.array(train_d)
    valid_data = np.array(dev_d)
    test_data = np.array(test_d)
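    # Column layout implied by the slicing below: columns 0 and 1 are the two
    # drug indices, column 2 is the binary interaction label.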
    if not os.path.exists(model_save_path) or overwrite:
        start_time = time.time()
        model.fit(x_train=[train_data[:, :1], train_data[:, 1:2]], y_train=train_data[:, 2:3],
                  x_valid=[valid_data[:, :1], valid_data[:, 1:2]], y_valid=valid_data[:, 2:3])
        elapsed_time = time.time() - start_time
        print('Logging Info - Training time: %s' % time.strftime("%H:%M:%S",
                                                                 time.gmtime(elapsed_time)))
        train_log['train_time'] = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
    print('Logging Info - Evaluate over valid data:')
    model.load_best_model()
    auc, acc, f1, aupr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])
    print(f'Logging Info - dev_auc: {auc}, dev_acc: {acc}, dev_f1: {f1}, dev_aupr: {aupr}')
    train_log['dev_auc'] = auc
    train_log['dev_acc'] = acc
    train_log['dev_f1'] = f1
    train_log['dev_aupr'] = aupr
    train_log['k_fold'] = kfold
    train_log['dataset'] = dataset
    train_log['aggregate_type'] = config.aggregator_type
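    # 'swa' = stochastic weight averaging: load_swa_model() presumably restores
    # the running weight average collected by the SWA callback during training.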
    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over valid data based on swa model:')
        auc, acc, f1, aupr = model.score(x=[valid_data[:, :1], valid_data[:, 1:2]], y=valid_data[:, 2:3])
        train_log['swa_dev_auc'] = auc
        train_log['swa_dev_acc'] = acc
        train_log['swa_dev_f1'] = f1
        train_log['swa_dev_aupr'] = aupr
        print(f'Logging Info - swa_dev_auc: {auc}, swa_dev_acc: {acc}, swa_dev_f1: {f1}, swa_dev_aupr: {aupr}')  # updated to report all four metrics
    print('Logging Info - Evaluate over test data:')
    model.load_best_model()
    auc, acc, f1, aupr = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
    train_log['test_auc'] = auc
    train_log['test_acc'] = acc
    train_log['test_f1'] = f1
    train_log['test_aupr'] = aupr
    print(f'Logging Info - test_auc: {auc}, test_acc: {acc}, test_f1: {f1}, test_aupr: {aupr}')
    if 'swa' in config.callbacks_to_add:
        model.load_swa_model()
        print('Logging Info - Evaluate over test data based on swa model:')
        auc, acc, f1, aupr = model.score(x=[test_data[:, :1], test_data[:, 1:2]], y=test_data[:, 2:3])
        train_log['swa_test_auc'] = auc
        train_log['swa_test_acc'] = acc
        train_log['swa_test_f1'] = f1
        train_log['swa_test_aupr'] = aupr
        print(f'Logging Info - swa_test_auc: {auc}, swa_test_acc: {acc}, swa_test_f1: {f1}, swa_test_aupr: {aupr}')
    train_log['timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    write_log(format_filename(LOG_DIR, PERFORMANCE_LOG), log=train_log, mode='a')

    # Release the model and clear the Keras session so that successive folds do
    # not accumulate graph state and leak GPU memory.
    del model
    gc.collect()
    K.clear_session()
    return train_log
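

# Example driver (a sketch, not part of the original file). The dataset name
# 'kegg' and all hyper-parameter values below are illustrative assumptions; the
# real project wires train() to its k-fold split in a separate run script, and
# the preprocessed vocab/adjacency files for the dataset must already exist.
if __name__ == '__main__':
    # A toy fold: each row is [drug_a_id, drug_b_id, label].
    toy_fold = [[0, 1, 1], [2, 3, 0], [4, 5, 1], [6, 7, 0]]
    fold_log = train(toy_fold, toy_fold, toy_fold, kfold=0, dataset='kegg',
                     neighbor_sample_size=4, embed_dim=32, n_depth=2,
                     l2_weight=1e-7, lr=2e-2, optimizer_type='adam',
                     batch_size=32, aggregator_type='sum', n_epoch=1,
                     callbacks_to_add=['swa'])
    print('test AUC: %.4f, test AUPR: %.4f' % (fold_log['test_auc'], fold_log['test_aupr']))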