Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
twjiang committed Mar 28, 2016
1 parent f6f3c8c commit f8a0199
Show file tree
Hide file tree
Showing 3 changed files with 579 additions and 39 deletions.
12 changes: 9 additions & 3 deletions use_eTransE/test_KB_complete.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ void find_new_triple()
map<int, double> ranked_map;
vector<PAIR> result_score_vec;
ofstream new_KB_file;
new_KB_file.open(("./model/"+model+"/KB_new.txt").c_str());
new_KB_file.open(("./model/"+model+"/KB_new_1.0.txt").c_str());

int rank_count = 0;
int rank_sum = 0;
Expand All @@ -311,7 +311,7 @@ void find_new_triple()
for (int i = 0; i < triple_h.size(); i++)
{
result_score_vec.clear();
ranked_map = link_prediction(triple_h[i], triple_t[i], 0.7, 10);
ranked_map = link_prediction(triple_h[i], triple_t[i], 1.0, 10);
if (ranked_map.size()!=0)
{
new_KB_file << id2entity[triple_h[i]] << "==" << id2entity[triple_t[i]] << "==" << id2relation[triple_r[i]];
Expand All @@ -327,15 +327,21 @@ void find_new_triple()
pos_distance = calc_distance(triple_h[i], triple_r[i], triple_t[i]);
for (int j = 0; j < relation_num; j++)
{
if (j == triple_r[i]){
continue;
}
if (is_good_triple[make_pair(triple_h[i], j)].count(triple_t[i]) > 0){
continue;
}
else
{
neg_distance = calc_distance(triple_h[i], j, triple_t[i]);
//cout << pos_distance << "\t" << neg_distance << endl;
if (neg_distance < pos_distance)
if (neg_distance < pos_distance){
//cout << id2relation[j] << endl;
//cout << pos_distance << "\t" << neg_distance << endl;
before_count++;
}
}
}
rank_count++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ map<int, string> id2entity, id2relation;
int entity_num, relation_num;

vector<vector<double> > entity_vec, relation_vec;
map<pair<int, int>, int> in_KB;
map<pair<int, int>, map<int, int> > is_good_triple;

char buf[100000];
int L1_flag = 0;
Expand Down Expand Up @@ -83,6 +83,29 @@ void load_test_data()
test_file.close();
}

void load_KB_data()
{
ifstream train_file;
train_file.open(("../bigcilin/"+data_size+"/train.txt").c_str());
string h , r, t;
while(!train_file.eof())
{
train_file >> h >> t >> r;

if (entity2id.count(h)==0)
cout << "no entity: " << h << endl;
if (entity2id.count(t)==0)
cout << "no entity: " << t << endl;

if (relation2id.count(r)==0)
cout << "no relation: " << t << endl;

is_good_triple[make_pair(entity2id[h], relation2id[r])][entity2id[t]] = 1;
}

train_file.close();
}

/*************************************************
Function: load_entity_relation_data()
Description: load entity and relation data from file
Expand Down Expand Up @@ -203,6 +226,9 @@ map<int, double> link_prediction(int entity1_id, int entity2_id, double threshol
{
for(int i = 0; i < relation_num; i++)
{
if (is_good_triple[make_pair(entity1_id, i)].count(entity2_id) > 0){
continue;
}
if (ranked_link_map.find(i) != ranked_link_map.end())
continue;
distance = calc_distance(entity1_id, i, entity2_id);
Expand Down Expand Up @@ -231,68 +257,67 @@ Description: predict the ranked_link_list for given entitys
Return: the ranked_entity_map
Others:
*************************************************/
map<int, double> entity_prediction(int entity1_id, int r_id, int num)
/* vector<int> entity_prediction(int entity1_id, int r_id, double threshold, int num)
{
map<int, double> ranked_entity_map;
vector<int> predict_entity_list;
double min = 10000000;
int entity2_id = -1;
int r_id = -1;
double distance;
for(int j = 0; j < num; j++)
{
for(int i = 0; i < entity_num; i++)
for(int i = 0; i < relation_num; i++)
{
if (ranked_entity_map.find(i) != ranked_entity_map.end())
if (is_good_triple[make_pair(entity1_id, i)].count(entity2_id) > 0){
continue;
}
if (ranked_link_map.find(i) != ranked_link_map.end())
continue;
distance = calc_distance(entity1_id, r_id, i);
distance = calc_distance(entity1_id, i, entity2_id);
if(distance < min)
{
min = distance;
entity2_id = i;
r_id = i;
}
}
ranked_entity_map[entity2_id] = min;
if (min >= threshold)
break;
ranked_link_map[r_id] = min;
min = 10000000;
}
return ranked_entity_map;
}
return ranked_link_map;
} */

/*************************************************
Function: find_new_triple()
Description: given the KB and its vector representation, find new triples
Function: test_predict_entity()
Description:
Input:
double threshold: the threshold of score for prediction
Output:
Return:
Others:
*************************************************/
void find_new_triple()
void test_predict_entity()
{
map<int, double> ranked_map;
vector<PAIR> result_score_vec;
ofstream new_KB_file;
new_KB_file.open(("./model/"+model+"/KB_new.txt").c_str());
vector<int> predict_entity_list;
ofstream result_file;
result_file.open(("./model/"+model+"/test_predict_entity.result").c_str());

for (int i = 0; i < triple_h.size(); i++)
{
result_score_vec.clear();
ranked_map = link_prediction(triple_h[i], triple_t[i], 0.7, 10);
if (ranked_map.size()!=0)
cout << id2entity[triple_h[i]] << "==" << id2entity[triple_t[i]] << "==" << id2relation[triple_r[i]] << ": " << calc_distance(triple_h[i], triple_r[i], triple_t[i]) << endl;
/* predict_entity_list.clear();
predict_entity_list = entity_prediction(triple_h[i], triple_r[i], 0.7, 10);
if (predict_entity_list.size()!=0)
{
for (map<int, double>::iterator it=ranked_map.begin(); it!=ranked_map.end(); ++it) {
result_score_vec.push_back(make_pair(it->first, it->second));
}
sort(result_score_vec.begin(), result_score_vec.end(), CmpByValue());
new_KB_file << "=========================================" << endl;
new_KB_file << id2entity[triple_h[i]] << "\t" << id2entity[triple_t[i]] << "\t" << id2relation[triple_r[i]] << endl;
for (vector<PAIR>::iterator it=result_score_vec.begin(); it!=result_score_vec.end(); ++it) {
new_KB_file << id2relation[it->first] << "\t" << it->second << endl;
result_file << id2entity[triple_h[i]] << "==" << id2entity[triple_t[i]] << "==" << id2relation[triple_r[i]];
for (vector<int>::iterator it=predict_entity_list.begin(); it!=predict_entity_list.end(); ++it) {
result_file << "##" << id2entity[*it];
}
new_KB_file.flush();
}
} */
}

new_KB_file.close();
result_file.close();
}

/*************************************************
Expand Down Expand Up @@ -366,14 +391,14 @@ int main(int argc, char **argv)
cout << "load entity and relation embedding data ... ..." << endl;
load_entity_relation_vec();
cout << "load ok." << endl;

load_KB_data();
cout << "load test data ... ..." << endl;
load_test_data();
cout << "load ok." << endl;

cout << "finding ... ..." << endl;
cout << "testing for predicting entity ... ..." << endl;

find_new_triple();
test_predict_entity();

return 0;
}
Loading

0 comments on commit f8a0199

Please sign in to comment.