Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change the demo mention detection #17

Merged
merged 1 commit into from
Oct 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions scripts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pickle
import sqlite3
import sys

from ccg_nlpy import local_pipeline
Expand Down Expand Up @@ -269,6 +270,18 @@ def produce_surface_cache(db_name, cache_name):
progress_bar(counter, total)


def produce_magnitude_vec_file(db_name, out_file):
    """Export key/vector pairs from a SQLite cache into a plain-text .vec file.

    Reads every row of the ``data`` table; row[0] is the key and row[1] is a
    stringified vector such as ``"[0.1, 0.2]"``.  Brackets and commas are
    stripped so each output line reads ``key 0.1 0.2``.

    :param db_name: path to the SQLite database containing a ``data`` table
    :param out_file: path of the .vec file to write
    """
    conn = sqlite3.connect(db_name)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM data")
        # Use a context manager so the file is flushed and closed even on
        # error (the original leaked the open handle and the connection).
        with open(out_file, "w") as w:
            for row in cursor:
                key = row[0]
                # row[1] looks like "[v1, v2, ...]"; drop the surrounding
                # brackets and the commas, keeping the spaces as separators.
                val = row[1][1:-1].replace(",", "")
                w.write(key + " " + val + "\n")
    finally:
        conn.close()


if __name__ == '__main__':
if len(sys.argv) < 2:
print("[ERROR]: No command given.")
Expand All @@ -286,3 +299,5 @@ def produce_surface_cache(db_name, cache_name):
produce_cache()
if sys.argv[1] == "SURFACECACHE":
produce_surface_cache("data/surface_cache.db", "/Volumes/Storage/Resources/wikilinks/elmo_cache_correct.db")
if sys.argv[1] == "PRODUCE_VEC":
produce_magnitude_vec_file("/Volumes/External/elmo_cache_correct.db", "/Volumes/External/elmo_cache.vec")
53 changes: 49 additions & 4 deletions server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import signal
import time
import traceback

from ccg_nlpy import local_pipeline
from flask import Flask
Expand All @@ -27,6 +28,7 @@ def __init__(self, sql_db_path, surface_cache_path):
self.mem_cache = ServerCache()
self.surface_cache = SurfaceCache(surface_cache_path)
self.pipeline = local_pipeline.LocalPipeline()
self.pipeline_initialize_helper(['.'])
self.runner = ZoeRunner(allow_tensorflow=True)
self.runner.elmo_processor.load_sqlite_db(sql_db_path, server_mode=True)
self.runner.elmo_processor.rank_candidates_vec()
Expand Down Expand Up @@ -163,8 +165,15 @@ def handle_input(self):
ret["other_possible_type"] = other_possible_types
return json.dumps(ret)

def pipeline_initialize_helper(self, tokens):
    """Warm up the annotation pipeline by requesting every view the server
    uses later, so the first real request does not pay the load cost.

    :param tokens: list of token strings forming one throwaway sentence
    """
    doc = self.pipeline.doc([tokens], pretokenized=True)
    # NOTE(review): the next three lines look like bare property accesses
    # whose side effect is to trigger annotator/model loading — confirm
    # ccg_nlpy exposes get_shallow_parse / get_ner_conll / get_ner_ontonotes
    # as properties; if they are plain methods these lines are no-ops.
    doc.get_shallow_parse
    doc.get_ner_conll
    doc.get_ner_ontonotes
    doc.get_view("MENTION")

"""
Handles chunker requests for mention filling
Handles requests for mention filling
"""
def handle_mention_input(self):
r = request.get_json()
Expand All @@ -174,9 +183,45 @@ def handle_mention_input(self):
tokens = r["tokens"]
doc = self.pipeline.doc([tokens], pretokenized=True)
shallow_parse_view = doc.get_shallow_parse
for chunk in shallow_parse_view:
if chunk['label'] == 'NP':
ret['mention_spans'].append([chunk['start'], chunk['end']])
ner_conll_view = doc.get_ner_conll
ner_ontonotes_view = doc.get_ner_ontonotes
md_view = doc.get_view("MENTION")
ret_set = set()
ret_list = []
additions_views = []
if ner_ontonotes_view.cons_list is not None:
additions_views.append(ner_ontonotes_view)
if md_view.cons_list is not None:
additions_views.append(md_view)
if shallow_parse_view.cons_list is not None:
additions_views.append(shallow_parse_view)
try:
if ner_conll_view.cons_list is not None:
for ner_conll in ner_conll_view:
for i in range(ner_conll['start'], ner_conll['end']):
ret_set.add(i)
ret_list.append((ner_conll['start'], ner_conll['end']))
for additions_view in additions_views:
for cons in additions_view:
add_to_list = True
if additions_view.view_name != "MENTION":
start = cons['start']
end = cons['end']
else:
start = cons['properties']['EntityHeadStartSpan']
end = cons['properties']['EntityHeadEndSpan']
for i in range(start - 1, end + 1):
if i in ret_set:
add_to_list = False
break
if add_to_list:
for i in range(start, end):
ret_set.add(i)
ret_list.append((start, end))
except Exception as e:
traceback.print_exc()
print(e)
ret['mention_spans'] = ret_list
return json.dumps(ret)

"""
Expand Down