run.py
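"""
Fit a genetic algorithm for binary subset selection of semantic primitives.

The script loads a prepared graph (graph.json), its encoding dictionary
(encoding_dict.json) and precomputed PageRank scores (pagerank.pickle),
builds a binary subset-selection problem over pre-generated candidate
semantic-primitive lists, runs the genetic algorithm, decodes the final
populations, and saves the results together with the CLI arguments.
"""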
import json
import os
from multiprocessing.pool import ThreadPool
from typing import List

import numpy as np
from joblib import load
from pymoo.core.problem import starmap_parallelized_eval

from optimization import BinarySubsetSelectionOptimizationFunctions, BinarySubsetSelectionSemanticPrimitivesProblem
from page_rank import DictPageRank
from algorithm_utils import BinarySubsetSelectionGeneticAlgorithm
from minimization_utils import AlgorithmMinimizer
from graph_utils import load_graph_dict, get_num_vertices
from postprocessing_utils import PopulationDecoder, load_decoding_dict

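# For example: get_sp_gen_unique_ids([[1, 2], [2, 3]]) -> array([1, 2, 3]);
# element order is not guaranteed, since the ids pass through a set.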
def get_sp_gen_unique_ids(cands: List[List[int]]) -> np.ndarray:
    """
    Get the unique vertices appearing in the generated lists of semantic primitives.
    :param cands: list of lists of ints, generated semantic-primitive lists
    :return: np.ndarray, unordered array of unique vertex ids
    """
    unique_vals = set(sum(cands, []))  # flatten all candidate lists and deduplicate
    # Convert to a list first: np.array() on a set produces a 0-d object array.
    cand_ids = np.array(list(unique_vals))
    return cand_ids

def fit_ga(args):
    # Load the prepared graph, the number of vertices, and the precomputed PageRank scores.
    graph_dict = load_graph_dict(
        json_graph_path=os.path.join(args.load_dir, "graph.json")
    )
    n_vals = get_num_vertices(
        json_enc_dict_path=os.path.join(args.load_dir, "encoding_dict.json")
    )
    optim_pagerank = load(
        os.path.join(args.load_dir, "pagerank.pickle")
    )
    # Load the pre-generated candidate semantic-primitive lists and collect their unique vertex ids.
    with open(args.sp_gen_lists_path, "r") as f:
        sp_gen_lists = json.load(f)
    sp_gen_unique_ids = get_sp_gen_unique_ids(sp_gen_lists)
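    # Objective and constraint functions for the binary subset-selection problem.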
    optim_functions = BinarySubsetSelectionOptimizationFunctions(
        wpagerank=optim_pagerank,
        graph_dict=graph_dict,
        n_vals=n_vals,
        card_mean=int(args.card_mean),
        val_prank_fill=args.val_prank_fill,
        sq_card_diff=args.card_diff**2
    )
    # Thread pool used to evaluate the population in parallel.
    pool = ThreadPool(args.n_threads)
    problem_params = {
        "n_var": n_vals,
        "n_obj": optim_functions.n_obj,
        "n_constr": optim_functions.n_constr,
        "xl": None,
        "xu": None,
        "optim_functions": optim_functions,
        "runner": pool.starmap,
        "func_eval": starmap_parallelized_eval
    }
    problem = BinarySubsetSelectionSemanticPrimitivesProblem(**problem_params)
    algorithm = BinarySubsetSelectionGeneticAlgorithm.get_algorithm(
        sp_gen_lists=sp_gen_lists,
        sp_gen_unique_ids=sp_gen_unique_ids,
        pop_size=args.pop_size,
        max_mutate=args.max_mutate,
        min_mutate=args.min_mutate
    )
    minimizer = AlgorithmMinimizer(
        algorithm=algorithm,
        proplem=problem,  # keyword name kept as spelled in AlgorithmMinimizer's signature
        checkpoint_path=args.chp_path,
        n_max_gen=args.n_max_gen
    )
    minimizer.run_minimization(save_dir=args.save_dir)
    pool.close()
    # Decode the final binary populations with the encoding dictionary and save the result.
    decoding_dict = load_decoding_dict(
        enc_dict_path=os.path.join(args.load_dir, "encoding_dict.json")
    )
    final_populations = np.load(os.path.join(args.save_dir, "final_pop", "X.numpy.npy"))
    PopulationDecoder.decode_binary_populations(
        populations=final_populations,
        decoding_dict=decoding_dict,
        save_dir=args.save_dir
    )
    # Store the CLI arguments alongside the results for reproducibility.
    with open(os.path.join(args.save_dir, "args.json"), "w") as f:
        json.dump(vars(args), f)

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='GA Semantic Primitives Optim., Binary Subset Selection')
    parser.add_argument('--load_dir', type=str, default="wordnet_graph_StanzaLemm_SSCDefsDrop/",
                        help='path to dir which contains the graph.json, encoding_dict.json and pagerank.pickle files')
    parser.add_argument('--chp_path', type=str, default="",
                        help='path to a checkpoint to continue from')
    parser.add_argument('--sp_gen_lists_path', type=str,
                        default="wordnet_graph_StanzaLemm_SSCDefsDrop/wordnet_graph_StanzaLemm_SSCDefsDrop_1000_candidates_random2.json",
                        help='path to the json file with generated semantic-primitive candidate lists')
    parser.add_argument('--n_threads', type=int, default=5,
                        help="Number of threads to use for parallel evaluation")
    parser.add_argument('--val_prank_fill', type=float, default=-1.0,
                        help="Value to return for the pagerank objective if there is a cycle in the graph")
    parser.add_argument('--pop_size', type=int, default=100,
                        help="Population size for the genetic algorithm")
    parser.add_argument('--card_diff', type=int, default=50,
                        help="Maximum allowed cardinality difference")
    parser.add_argument('--card_mean', type=int, default=2800,
                        help="Target (mean) cardinality of the selected subset")
    parser.add_argument('--max_mutate', type=int, default=60,
                        help="Maximum number of elements to mutate per population")
    parser.add_argument('--min_mutate', type=int, default=0,
                        help="Minimum number of elements to mutate per population")
    parser.add_argument('--n_max_gen', type=int, default=30,
                        help="Maximum number of generations to run the algorithm")
    parser.add_argument('--save_dir', type=str, default="GA_fitted",
                        help="Directory to save results to")
    args = parser.parse_args()
    fit_ga(args)
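
# Example invocation (a sketch using the default paths defined above; adjust them to your data):
#   python run.py \
#       --load_dir wordnet_graph_StanzaLemm_SSCDefsDrop/ \
#       --sp_gen_lists_path wordnet_graph_StanzaLemm_SSCDefsDrop/wordnet_graph_StanzaLemm_SSCDefsDrop_1000_candidates_random2.json \
#       --n_threads 5 --pop_size 100 --n_max_gen 30 --save_dir GA_fitted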