Merge pull request #16 from CogComp/dev
A few API changes
Slash0BZ authored Oct 19, 2018
2 parents df24ecb + 59471ca commit 4a9a6df
Showing 5 changed files with 124 additions and 4 deletions.
27 changes: 26 additions & 1 deletion frontend/index.html
@@ -379,6 +379,12 @@
     document.getElementById("row-simple-" + String(index)).innerHTML = String(types);
 }
 
+function updateVecResult(result) {
+    var types = result["type"];
+    var index = result["index"];
+    document.getElementById("row-vec-" + String(index)).innerHTML = String(types);
+}
+
 function getInferenceMode() {
     if (document.getElementById("preset-taxonomy-select").checked) {
         return document.getElementById("preset-taxonomy-select-value").value;
@@ -459,7 +465,8 @@
             "<tr>\n" +
             "<th scope=\"col\">#</th>\n" +
             "<th scope=\"col\">Mention</th>\n" +
-            "<th scope=\"col\">Cache Type</th>\n" +
+            "<th scope=\"col\">Cached Types (surface)</th>\n" +
+            "<th scope=\"col\">Word2Vec based (surface)</th>\n" +
             "<th scope=\"col\">Contextual Type</th>\n" +
             "<th scope=\"col\">Why?</th>\n" +
             "</tr>\n" +
@@ -469,6 +476,7 @@
     table += "<tr><th score='row'>" + String(i) + "</th>"
         + "<td>" + mention_surfaces[i] + "</td>"
         + "<td id='row-simple-" + String(i) + "'>" + loading_sign + "</td>"
+        + "<td id='row-vec-" + String(i) + "'>" + loading_sign + "</td>"
         + "<td id='row-computed-" + String(i) + "'>" + loading_sign + "</td>"
         + "<td id='row-button-" + String(i) + "'>" + loading_sign + "</td></tr>"
         + "<tr>\n" +
@@ -485,6 +493,23 @@
         "</div>" +
         table;
     for (let i = 0; i < mention_surfaces.length; i++) {
+        let xhr_vec = new XMLHttpRequest();
+        xhr_vec.open("POST", SERVER_API + "annotate_vec", true);
+        xhr_vec.setRequestHeader("Content-Type", "application/json");
+        xhr_vec.onreadystatechange = function () {
+            if (xhr_vec.readyState === XMLHttpRequest.DONE && xhr_vec.status === 200) {
+                var json = JSON.parse(xhr_vec.responseText);
+                updateVecResult(json);
+            }
+        };
+        var data_vec = JSON.stringify({
+            index: i,
+            tokens: sentence.trim().split(" "),
+            mention_starts: [mention_starts[i]],
+            mention_ends: [mention_ends[i]],
+        });
+        xhr_vec.send(data_vec);
+
         let xhr_simple = new XMLHttpRequest();
         xhr_simple.open("POST", SERVER_API + "annotate_cache", true);
         xhr_simple.setRequestHeader("Content-Type", "application/json");
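
Note on the frontend change: each mention now fires its own asynchronous POST to annotate_vec, and the response carries back the index it was sent with, which updateVecResult uses to fill the matching row-vec-<index> cell. Responses can therefore arrive out of order without scrambling the table; a request-style example of the payload appears after the server.py diff below.
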
10 changes: 10 additions & 0 deletions main.py
@@ -38,6 +38,16 @@ def process_sentence(self, sentence, inference_processor=None):
         inference_processor.inference(sentence, elmo_candidates, esa_candidates)
         return sentence
 
+    def process_sentence_vec(self, sentence, inference_processor=None):
+        esa_candidates = self.esa_processor.get_candidates(sentence)
+        elmo_candidates = self.elmo_processor.rank_candidates_vec(sentence, esa_candidates)
+        if len(elmo_candidates) > 0 and elmo_candidates[0][0] == self.elmo_processor.stop_sign:
+            return -1
+        if inference_processor is None:
+            inference_processor = self.inference_processor
+        inference_processor.inference(sentence, elmo_candidates, esa_candidates)
+        return sentence
+
     """
     Helper function to evaluate on a dataset that has multiple sentences
     @file_name: A string indicating the data file.
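
For context, a minimal sketch of how the new entry point might be called, assuming the class above is ZoeRunner (as instantiated in server.py below) and reusing the Sentence constructor argument order used in handle_word2vec_input; the tokens and mention span are made up. The -1 return value signals that the ELMo processor emitted its stop sign.

    from main import ZoeRunner         # assumption: process_sentence_vec lives on ZoeRunner
    from zoe_utils import Sentence

    runner = ZoeRunner(allow_tensorflow=True)
    # Sentence(tokens, mention_start, mention_end, "") mirrors the call in server.py
    sent = Sentence(["Barack", "Obama", "visited", "Paris", "."], 0, 2, "")
    result = runner.process_sentence_vec(sent)
    if result == -1:
        print("ELMo stop sign encountered; no prediction made")
    else:
        print(sent.predicted_types)    # populated by the inference processor
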
3 changes: 2 additions & 1 deletion requirements.txt
@@ -6,4 +6,5 @@ regex
 Flask
 flask-cors
 cython
-ccg_nlpy
\ No newline at end of file
+ccg_nlpy
+gensim
41 changes: 39 additions & 2 deletions server.py
@@ -29,6 +29,7 @@ def __init__(self, sql_db_path, surface_cache_path):
         self.pipeline = local_pipeline.LocalPipeline()
         self.runner = ZoeRunner(allow_tensorflow=True)
         self.runner.elmo_processor.load_sqlite_db(sql_db_path, server_mode=True)
+        self.runner.elmo_processor.rank_candidates_vec()
         signal.signal(signal.SIGINT, self.grace_end)
 
     @staticmethod
@@ -196,11 +197,45 @@ def handle_simple_input(self):
         for sentence in sentences:
             surface = sentence.get_mention_surface()
             cached_types = self.surface_cache.query_cache(surface)
-            types.append(cached_types)
+            distinct = set()
+            for t in cached_types:
+                distinct.add("/" + t.split("/")[1])
+            types.append(list(distinct))
         ret["type"] = types
         ret["index"] = r["index"]
         return json.dumps(ret)
 
+    def handle_word2vec_input(self):
+        ret = {}
+        r = request.get_json()
+        if "tokens" not in r or "mention_starts" not in r or "mention_ends" not in r or "index" not in r:
+            ret["type"] = [["INVALID_INPUT"]]
+            return json.dumps(ret)
+        sentences = []
+        for i in range(0, len(r["mention_starts"])):
+            sentence = Sentence(r["tokens"], int(r["mention_starts"][i]), int(r["mention_ends"][i]), "")
+            sentences.append(sentence)
+        predicted_types = []
+        for sentence in sentences:
+            self.runner.process_sentence_vec(sentence)
+            predicted_types.append(list(sentence.predicted_types))
+        ret["type"] = predicted_types
+        ret["index"] = r["index"]
+        return json.dumps(ret)
+
+    def handle_elmo_input(self):
+        ret = {}
+        results = []
+        r = request.get_json()
+        if "sentence" not in r:
+            ret["vectors"] = []
+            return json.dumps(ret)
+        elmo_map = self.runner.elmo_processor.process_single_continuous(r["sentence"])
+        for token in r["sentence"].split():
+            results.append((token, str(elmo_map[token])))
+        ret["vectors"] = results
+        return json.dumps(ret)
+
     """
     Handler to start the Flask app
     @localhost: Whether the server lives only in localhost
@@ -212,6 +247,8 @@ def start(self, localhost=False, port=80):
         self.app.add_url_rule("/annotate", "annotate", self.handle_input, methods=['POST'])
         self.app.add_url_rule("/annotate_mention", "annotate_mention", self.handle_mention_input, methods=['POST'])
         self.app.add_url_rule("/annotate_cache", "annotate_cache", self.handle_simple_input, methods=['POST'])
+        self.app.add_url_rule("/annotate_vec", "annotate_vec", self.handle_word2vec_input, methods=['POST'])
+        self.app.add_url_rule("/annotate_elmo", "annotate_elmo", self.handle_elmo_input, methods=['POST'])
         if localhost:
             self.app.run()
         else:
@@ -226,6 +263,6 @@ def grace_end(self, signum, frame):
 
 
 if __name__ == '__main__':
-    server = Server("/Volumes/Storage/Resources/wikilinks/elmo_cache_correct.db", "./data/surface_cache_new.db")
+    server = Server("/Volumes/External/elmo_cache_correct.db", "./data/surface_cache_new.db")
     server.start(localhost=True)

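For context, the two new routes can be exercised end to end with a small client. A hedged sketch, assuming the server was launched via start(localhost=True) and is reachable at Flask's default development address (http://127.0.0.1:5000); the sentence and mention span are made up:

    import requests

    BASE = "http://127.0.0.1:5000"  # assumption: Flask default for app.run()

    # /annotate_vec takes the same per-mention payload the frontend builds above
    resp = requests.post(BASE + "/annotate_vec", json={
        "index": 0,
        "tokens": ["Barack", "Obama", "visited", "Paris", "."],
        "mention_starts": [0],
        "mention_ends": [2],
    })
    print(resp.json()["type"])      # one ranked type list per mention

    # /annotate_elmo returns (token, vector-string) pairs for a whitespace-split sentence
    resp = requests.post(BASE + "/annotate_elmo",
                         json={"sentence": "Barack Obama visited Paris ."})
    print(resp.json()["vectors"][0])
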
47 changes: 47 additions & 0 deletions zoe_utils.py
@@ -5,6 +5,7 @@
 import pickle
 import sqlite3
 
+import gensim
 import numpy as np
 import regex
 import tensorflow as tf
@@ -33,6 +34,7 @@ def __init__(self, allow_tensorflow):
         self.batcher, self.ids_placeholder, self.ops, self.sess = initialize_sess(self.vocab_file, self.options_file, self.weight_file)
         self.db_loaded = False
         self.server_mode = False
+        self.word2vec = None
 
     def load_sqlite_db(self, path, server_mode=False):
         self.db_conn = sqlite3.connect(path)
@@ -240,6 +242,51 @@ def rank_candidates(self, sentence, candidates):
         sorted_results = sorted(results.items(), key=lambda kv: kv[1], reverse=True)
         return [(x[0], x[1]) for x in sorted_results][:self.RANKED_RETURN_NUM]
 
+    def word2vec_helper(self, input):
+        vec = np.zeros(300)
+        if self.word2vec is None:
+            return None
+        if input in self.word2vec:
+            return self.word2vec[input]
+        if input.lower() in self.word2vec:
+            return self.word2vec[input.lower()]
+        count = 0.0
+        for token in input.split("_"):
+            if token in self.word2vec:
+                vec += self.word2vec[token]
+                count += 1.0
+            elif token.lower() in self.word2vec:
+                vec += self.word2vec[token.lower()]
+                count += 1.0
+        if count == 0.0:
+            return None
+        return vec / count
+
+    def rank_candidates_vec(self, sentence=None, candidates=None):
+        data_path = "data/word2vec/GoogleNews-vectors-negative300.bin"
+        if not os.path.isfile(data_path):
+            return candidates
+        if self.word2vec is None:
+            self.word2vec = gensim.models.KeyedVectors.load_word2vec_format(data_path, binary=True)
+        if sentence is None:
+            return None
+        candidates = [x[0] for x in candidates]
+        target_vec = self.word2vec_helper(sentence.get_mention_surface())
+        if target_vec is None:
+            print(sentence.get_mention_surface() + " not found in word2vec")
+            return candidates
+        assert(len(target_vec) == 300)
+        results = {}
+        for candidate in candidates:
+            candidate_vec = self.word2vec_helper(candidate)
+            if candidate_vec is None:
+                similarity = 0
+            else:
+                similarity = cosine(target_vec, candidate_vec)
+            results[candidate] = similarity
+        sorted_results = sorted(results.items(), key=lambda kv: kv[1], reverse=True)
+        return [(x[0], x[1]) for x in sorted_results][:self.RANKED_RETURN_NUM]
+
     """
     To save the cache maps generated by the processor instance
     """
