-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
126 lines (100 loc) · 3.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
import os
from datetime import datetime
import tensorflow as tf
from tensorflow import keras
from pymongo import MongoClient
from tensorflow.python.keras.callbacks import TensorBoard
from models.DBConnector import DBConnector
from models.source_graph.string2tensor import String2Tensor
def main():
    """Train and evaluate the source-graph binary classification model.

    Reads the command-line options from ``get_arg_parser()``, connects to
    MongoDB, builds the train/validation/test datasets, fits the model with
    TensorBoard logging, checkpointing and early stopping, and finally
    prints the evaluation result on the test dataset.
    """
    args = get_arg_parser().parse_args()
    mongo_uri = args.mongo_uri
    debug_mode = args.debug

    if debug_mode:
        # Run tf.functions eagerly so the Python code can be stepped through.
        tf.config.run_functions_eagerly(True)

    client = MongoClient(mongo_uri)
    db = DBConnector(client)
    max_nodes_per_batch = args.max_nodes

    String2Tensor.configure_default(
        alphabet_string="abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}",
        node_label_max_chars=19
    )

    run_start_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    logs = f"logs/{run_start_time}"

    # NOTE(review): these imports are kept after configure_default() as in the
    # original - presumably the modules rely on the configured defaults at
    # import time; confirm before hoisting them to the top of the file.
    from datasets import MongoDBDataSource, BugFixDataset
    from source_graph.binary_classification.source_graph_binary_classification_model \
        import SourceGraphBinaryPredictionModel

    dataset_source = MongoDBDataSource(
        db_connector=db,
        dataset_class=BugFixDataset(max_nodes_per_batch),
        # presumably the train/validation/test fractions - verify against
        # MongoDBDataSource
        dataset_split=(0.8, 0.1, 0.1)
    )
    (train, valid, test) = dataset_source.get_datasets()

    model = SourceGraphBinaryPredictionModel()
    model.compile(
        optimizer=keras.optimizers.Adam(),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=[
            keras.metrics.BinaryAccuracy(),
            keras.metrics.Precision(),
            keras.metrics.Recall(),
        ])

    epochs = 20

    checkpoint_path = f"checkpoints/{run_start_time}/"
    # makedirs (not mkdir) so a missing parent "checkpoints" directory does
    # not abort the run with FileNotFoundError.
    os.makedirs(checkpoint_path, exist_ok=True)
    checkpoints_path = checkpoint_path + "{epoch:02d}-{val_loss:.2f}.hdf5"

    tboard_callback = TensorBoard(log_dir=logs,
                                  histogram_freq=1,
                                  update_freq="batch")
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoints_path,
        monitor="val_loss",
        # A loss is better when it is lower; "max" would keep the checkpoints
        # with the highest validation loss.
        mode="min",
        save_best_only=True
    )
    early_stop_callback = tf.keras.callbacks.EarlyStopping(
        patience=2
    )

    model.fit(
        # Shuffle first and prefetch last so that the prefetch buffer overlaps
        # input preparation with training.
        x=train.shuffle(100).prefetch(100),
        validation_data=valid.prefetch(100),
        epochs=epochs,
        callbacks=[
            tboard_callback,
            checkpoint_callback,
            early_stop_callback
        ]
    )

    result = model.evaluate(test)
    print(f"Test result: {result}")
def _str_to_bool(value: str) -> bool:
    """Parse a command-line token such as ``true`` or ``false`` into a bool."""
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


def get_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training script.

    Returns:
        A parser exposing ``mongo_uri`` (str, default ``None``), ``debug``
        (bool, default ``False``) and ``max_nodes`` (int, default 100000).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--mongo-uri",
        dest="mongo_uri",
        type=str,
        help="URI of the MongoDB instance which has the graph samples for training.",
    )
    parser.add_argument(
        "-d",
        "--debug",
        required=False,
        default=False,
        # ``type=bool`` would turn every non-empty string (including "False")
        # into True, so the value is parsed explicitly instead. ``nargs="?"``
        # with ``const=True`` lets a bare ``-d`` / ``--debug`` enable the flag
        # while ``--debug True`` / ``--debug False`` keep working.
        nargs="?",
        const=True,
        type=_str_to_bool,
        help="Disable autograph to allow debugging of TensorFlow Python code.",
    )
    parser.add_argument(
        "-n",
        "--max-batch-nodes",
        dest="max_nodes",
        required=False,
        default=100_000,
        type=int,
        help="Maximum amount of graph nodes per training batch.",
    )
    return parser
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()