Skip to content

Commit

Permalink
add 2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
twjiang committed Apr 14, 2016
1 parent f8a0199 commit ea75d78
Show file tree
Hide file tree
Showing 11 changed files with 2,278 additions and 60 deletions.
64 changes: 64 additions & 0 deletions bigcilin/small_filter1/build_train_test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding:utf-8 -*-

import sys, os
import random

# Default file names; the input path can be overridden by the first
# command-line argument.
in_file_name = "triple.txt"
train_file_name = "train.txt"
test_file_name = "test.txt"

# Upper bound on how many input lines are drawn as test-set candidates.
TEST_SAMPLE_SIZE = 9000


def _read_triples(path):
    """Yield (entity1, entity2, relation) for every non-blank line of *path*.

    Each line holds three tab-separated fields in the order
    entity1, entity2, relation.  Blank lines (for example a trailing empty
    line at the end of the file) are skipped; any other line that does not
    have exactly three fields raises ValueError.
    """
    with open(path, 'r') as fin:
        for line in fin:
            stripped = line.strip()
            if not stripped:
                continue
            entity1, entity2, relation = stripped.split('\t')
            yield entity1, entity2, relation


def count_occurrences(path):
    """Count how often each entity and relation occurs in the triple file.

    Returns a tuple ``(entity_counts, relation_counts, triple_count)`` where
    the first two are dicts mapping a name to its number of occurrences and
    the last is the number of triples read.  An entity is counted once per
    position, so a triple whose two entities are equal adds two to its count.
    """
    entity_counts = {}
    relation_counts = {}
    triple_count = 0
    for entity1, entity2, relation in _read_triples(path):
        entity_counts[entity1] = entity_counts.get(entity1, 0) + 1
        entity_counts[entity2] = entity_counts.get(entity2, 0) + 1
        relation_counts[relation] = relation_counts.get(relation, 0) + 1
        triple_count += 1
    return entity_counts, relation_counts, triple_count


def split_train_test(in_path, train_path, test_path,
                     sample_size=TEST_SAMPLE_SIZE):
    """Split the triples of *in_path* into a training and a test file.

    Up to *sample_size* triples are picked at random as test candidates.
    A candidate is written to *test_path* only if, after removing it, both
    of its entities and its relation still occur at least once among the
    triples not yet moved to the test file; every other triple is written
    to *train_path*.  Output lines use the same tab-separated
    entity1, entity2, relation layout as the input.

    Returns ``(train_count, test_count)``, the number of triples written to
    each file.
    """
    entity_counts, relation_counts, total = count_occurrences(in_path)

    # Clamp the sample size so inputs shorter than sample_size do not make
    # random.sample raise; a set gives O(1) membership tests in the loop.
    candidates = set(random.sample(range(total), min(sample_size, total)))

    train_count = 0
    test_count = 0
    with open(train_path, 'w') as ftrain, open(test_path, 'w') as ftest:
        for index, triple in enumerate(_read_triples(in_path)):
            entity1, entity2, relation = triple
            record = "%s\t%s\t%s\n" % (entity1, entity2, relation)

            movable = index in candidates and relation_counts[relation] > 1
            if movable:
                if entity1 == entity2:
                    # A self-loop consumes two occurrences of the same
                    # entity, so it needs more than two to leave one behind.
                    movable = entity_counts[entity1] > 2
                else:
                    movable = (entity_counts[entity1] > 1
                               and entity_counts[entity2] > 1)

            if movable:
                ftest.write(record)
                # The counts track what is left outside the test file.
                entity_counts[entity1] -= 1
                entity_counts[entity2] -= 1
                relation_counts[relation] -= 1
                test_count += 1
            else:
                ftrain.write(record)
                train_count += 1
    return train_count, test_count


def main(argv=None):
    """Run the split using the default file names.

    The first command-line argument, when given, replaces the default
    input file name.
    """
    args = sys.argv if argv is None else argv
    source = args[1] if len(args) > 1 else in_file_name
    split_train_test(source, train_file_name, test_file_name)


if __name__ == "__main__":
    main()
50 changes: 50 additions & 0 deletions bigcilin/small_filter1/get_entitys_relations_triple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding:utf-8 -*-

import sys, os

# Default file names; the input path can be overridden by the first
# command-line argument.
in_file_name = "input.txt"
entity2id_file_name = "entity2id.txt"
relation2id_file_name = "relation2id.txt"
triple_file_name = "triple.txt"


def _clean(field):
    """Strip surrounding whitespace from *field* and remove inner spaces."""
    # NOTE(review): both replace() calls look identical here; the second may
    # originally have targeted a different space character (e.g. a
    # full-width or non-breaking space) -- confirm against the repository.
    return field.strip().replace(' ','').replace(' ','')


def _register(name, known, ordered):
    """Record *name* in first-seen order if it has not been seen before."""
    if name not in known:
        known.add(name)
        ordered.append(name)


def _write_ids(names, out):
    """Write one ``name<TAB>id`` line per name, ids counting up from 0."""
    for index, name in enumerate(names):
        out.write("%s\t%d\n" % (name, index))


def extract_triples(in_path, entity_path, relation_path, triple_path):
    """Convert ``entity1--->relation--->entity2`` lines into training files.

    For every input line the three fields are cleaned of whitespace.  Only
    the first triple seen for each (entity1, relation) pair is kept; it is
    written to *triple_path* as tab-separated entity1, entity2, relation.
    Entities and relations of kept triples receive consecutive integer ids
    in order of first appearance, written to *entity_path* and
    *relation_path* as ``name<TAB>id`` lines.  Assigning ids in first-seen
    order (rather than by iterating a set) keeps the id files reproducible
    from run to run.

    Blank lines are skipped.  A non-blank line that does not split into
    exactly three fields raises ValueError naming the offending line.

    Returns ``(entity_count, relation_count, triple_count)``.
    """
    seen_pairs = set()
    entity_seen, entity_order = set(), []
    relation_seen, relation_order = set(), []
    triple_count = 0

    with open(in_path, 'r') as fin, open(triple_path, 'w') as ft:
        for line_no, line in enumerate(fin, 1):
            if not line.strip():
                continue
            parts = line.split('--->')
            if len(parts) != 3:
                raise ValueError(
                    "line %d: expected 'entity1--->relation--->entity2', "
                    "got %r" % (line_no, line))
            entity1, relation, entity2 = [_clean(part) for part in parts]

            if (entity1, relation) in seen_pairs:
                continue
            # NOTE(review): the id bookkeeping below is assumed to belong to
            # the same branch as the write, so entities that only occur in
            # skipped triples get no id -- confirm this is intended.
            ft.write("%s\t%s\t%s\n" % (entity1, entity2, relation))
            seen_pairs.add((entity1, relation))
            _register(entity1, entity_seen, entity_order)
            _register(relation, relation_seen, relation_order)
            _register(entity2, entity_seen, entity_order)
            triple_count += 1

    with open(entity_path, 'w') as fe:
        _write_ids(entity_order, fe)
    with open(relation_path, 'w') as fr:
        _write_ids(relation_order, fr)

    return len(entity_order), len(relation_order), triple_count


def main(argv=None):
    """Run the extraction using the default file names.

    The first command-line argument, when given, replaces the default
    input file name.
    """
    args = sys.argv if argv is None else argv
    source = args[1] if len(args) > 1 else in_file_name
    extract_triples(source, entity2id_file_name, relation2id_file_name,
                    triple_file_name)


if __name__ == "__main__":
    main()
10 changes: 2 additions & 8 deletions use_eTransE/test_KB_complete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ void find_new_triple()
map<int, double> ranked_map;
vector<PAIR> result_score_vec;
ofstream new_KB_file;
new_KB_file.open(("./model/"+model+"/KB_new_1.0.txt").c_str());
new_KB_file.open(("./model/"+model+"/KB_new.txt").c_str());

int rank_count = 0;
int rank_sum = 0;
Expand Down Expand Up @@ -327,21 +327,15 @@ void find_new_triple()
pos_distance = calc_distance(triple_h[i], triple_r[i], triple_t[i]);
for (int j = 0; j < relation_num; j++)
{
if (j == triple_r[i]){
continue;
}
if (is_good_triple[make_pair(triple_h[i], j)].count(triple_t[i]) > 0){
continue;
}
else
{
neg_distance = calc_distance(triple_h[i], j, triple_t[i]);
//cout << pos_distance << "\t" << neg_distance << endl;
if (neg_distance < pos_distance){
//cout << id2relation[j] << endl;
//cout << pos_distance << "\t" << neg_distance << endl;
if (neg_distance < pos_distance)
before_count++;
}
}
}
rank_count++;
Expand Down
Loading

0 comments on commit ea75d78

Please sign in to comment.