diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 23371b4..622eefd 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -7,13 +7,15 @@ on:
- main
pull_request:
branches: [ main ]
+ workflow_dispatch:
+
jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: [3.7, 3.8]
+ python-version: [3.9]
steps:
- uses: actions/checkout@v3
@@ -40,4 +42,4 @@ jobs:
- name: Test with pytest
run: |
- pytest -W ignore
+ pytest tests
diff --git a/.gitignore b/.gitignore
index 1251b0e..4116e27 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,6 +106,7 @@ celerybeat.pid
.venv
env/
venv/
+venv3/
ENV/
env.bak/
venv.bak/
@@ -133,3 +134,5 @@ dmypy.json
# Project specific
/data
+data
+000README
diff --git a/README.md b/README.md
index 93c5fa9..17abdb7 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,20 @@
[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/radboud-el)](https://pypi.org/project/radboud-el/)
[![PyPI](https://img.shields.io/pypi/v/radboud-el.svg?style=flat)](https://pypi.org/project/radboud-el/)
+---
+
+Example tests:
+
+* Flair: `python3 scripts/efficiency_test.py --process_sentences`
+* Bert: `python3 scripts/efficiency_test.py --use_bert_base_uncased --split_docs_value 500`
+* Server (slower):
+ * `python3 src/REL/server.py --use_bert_base_uncased --split_docs_value 500 --ed-model ed-wiki-2019 data wiki_2019`
+ * `python3 scripts/efficiency_test.py --use_server`
+
+Needs installation of REL documents in directory `data` (`ed-wiki-2019`, `generic` and `wiki_2019`)
+
+---
+
REL is a modular Entity Linking package that is provided as a Python package as well as a web API. REL has various meanings - one might first notice that it stands for relation, which is a suiting name for the problems that can be tackled with this package. Additionally, in Dutch a 'rel' means a disturbance of the public order, which is exactly what we aim to achieve with the release of this package.
REL utilizes *English* Wikipedia as a knowledge base and can be used for the following tasks:
diff --git a/conftest.py b/conftest.py
new file mode 100644
index 0000000..9653aac
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,11 @@
+import os
+import pytest
+
+
+def pytest_addoption(parser):
+ parser.addoption("--base_url", action="store", default=os.path.dirname(__file__) + "/src/data/")
+
+
+@pytest.fixture
+def base_url(request):
+ return request.config.getoption("--base_url")
diff --git a/docs/tutorials/custom_models.md b/docs/tutorials/custom_models.md
index 5dca8c2..373fb95 100644
--- a/docs/tutorials/custom_models.md
+++ b/docs/tutorials/custom_models.md
@@ -20,7 +20,7 @@ model, you can only use a local filepath.
NER and ED models that we provide as part of REL can be loaded easily using
aliases. Available models are listed
-[on the REL repository](https://github.com/informagi/REL/tree/master/REL/models/models.json).
+[on the REL repository](https://github.com/informagi/REL/tree/master/src/REL/models/models.json).
All models that need to be downloaded from the web are cached for subsequent
use.
diff --git a/docs/tutorials/index.md b/docs/tutorials/index.md
index 3aba1f0..0317754 100644
--- a/docs/tutorials/index.md
+++ b/docs/tutorials/index.md
@@ -9,9 +9,10 @@ The remainder of the tutorials are optional and for users who wish to e.g. train
1. [How to get started (project folder and structure).](how_to_get_started/)
2. [End-to-End Entity Linking.](e2e_entity_linking/)
-3. [Evaluate on GERBIL.](evaluate_gerbil/)
-4. [Deploy REL for a new Wikipedia corpus](deploy_REL_new_wiki/):
-5. [Reproducing our results](reproducing_our_results/)
-6. [REL server](server/)
-7. [Notes on using custom models](custom_models/)
-7. [Conversational entity linking](conversations/)
+3. [Mention Detection models.](ner/)
+4. [Evaluate on GERBIL.](evaluate_gerbil/)
+5. [Deploy REL for a new Wikipedia corpus](deploy_REL_new_wiki/):
+6. [Reproducing our results](reproducing_our_results/)
+7. [REL server](server/)
+8. [Notes on using custom models](custom_models/)
+9. [Conversational entity linking](conversations/)
diff --git a/docs/tutorials/ner.md b/docs/tutorials/ner.md
new file mode 100644
index 0000000..03b642b
--- /dev/null
+++ b/docs/tutorials/ner.md
@@ -0,0 +1,24 @@
+# Mention Detection models
+
+REL offers different named entity models for mention detection:
+
+- `flair`: named-entity model for English, expects upper and lower case text (default)
+- `bert_base_cased`: basic named-entity model for English, expects upper and lower case text
+- `bert_base_uncased`: basic named-entity model for English, expects lower case text
+- `bert_large_cased`: extensive named-entity model for English, expects upper and lower case text
+- `bert_large_uncased`: extensive named-entity model for English, expects lower case text
+- `bert_multilingual`: multilingual named-entity model, expects upper and lower case text
+
+To change the default Flair model, specify the required model with the `--tagger_ner_name` option, for example when calling the server:
+
+```bash
+python src/REL/server.py --tagger_ner_name bert_base_cased
+```
+
+or specify a loaded model in the `tagger_ner` parameter of a mention detection call:
+
+```python
+mentions_dataset, n_mentions = mention_detection.find_mentions(docs, process_sentences=True, tagger_ner=tagger_ner)
+```
+
+The available named entity models are specified in the file `src/REL/ner/set_tagger_ner.py`. The file names refer to locations on the website huggingface.co, for example https://huggingface.co/flair/ner-english-fast. The file can be extended with new models, for example for other languages.
diff --git a/requirements.txt b/requirements.txt
index c84bf33..a9025ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ pydantic
segtok
torch
uvicorn
+scipy<=1.10
diff --git a/scripts/efficiency_test.py b/scripts/efficiency_test.py
deleted file mode 100644
index 0875f92..0000000
--- a/scripts/efficiency_test.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import os
-
-import numpy as np
-import requests
-
-from REL.training_datasets import TrainingEvaluationDatasets
-
-np.random.seed(seed=42)
-
-base_url = os.environ.get("REL_BASE_URL")
-wiki_version = "wiki_2019"
-host = "localhost"
-port = "5555"
-datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]
-
-# random_docs = np.random.choice(list(datasets.keys()), 50)
-
-server = True
-docs = {}
-for i, doc in enumerate(datasets):
- sentences = []
- for x in datasets[doc]:
- if x["sentence"] not in sentences:
- sentences.append(x["sentence"])
- text = ". ".join([x for x in sentences])
-
- if len(docs) == 50:
- print("length docs is 50.")
- print("====================")
- break
-
- if len(text.split()) > 200:
- docs[doc] = [text, []]
- # Demo script that can be used to query the API.
- if server:
- myjson = {
- "text": text,
- "spans": [
- # {"start": 41, "length": 16}
- ],
- }
- print("----------------------------")
- print(i, "Input API:")
- print(myjson)
-
- print("Output API:")
- print(requests.post(f"http://{host}:{port}", json=myjson).json())
- print("----------------------------")
-
-
-# --------------------- Now total --------------------------------
-# ------------- RUN SEPARATELY TO BALANCE LOAD--------------------
-if not server:
- from time import time
-
- import flair
- import torch
- from flair.models import SequenceTagger
-
- from REL.entity_disambiguation import EntityDisambiguation
- from REL.mention_detection import MentionDetection
-
- base_url = "C:/Users/mickv/desktop/data_back/"
-
- flair.device = torch.device("cuda:0")
-
- mention_detection = MentionDetection(base_url, wiki_version)
-
- # Alternatively use Flair NER tagger.
- tagger_ner = SequenceTagger.load("ner-fast")
-
- start = time()
- mentions_dataset, n_mentions = mention_detection.find_mentions(docs, tagger_ner)
- print("MD took: {}".format(time() - start))
-
- # 3. Load model.
- config = {
- "mode": "eval",
- "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
- }
- model = EntityDisambiguation(base_url, wiki_version, config)
-
- # 4. Entity disambiguation.
- start = time()
- predictions, timing = model.predict(mentions_dataset)
- print("ED took: {}".format(time() - start))
diff --git a/scripts/gerbil_middleware/Makefile b/scripts/gerbil_middleware/Makefile
deleted file mode 100644
index 519a9f7..0000000
--- a/scripts/gerbil_middleware/Makefile
+++ /dev/null
@@ -1,10 +0,0 @@
-default: build dockerize
-
-build:
- mvn clean package -U
-
-dockerize:
- docker build -t git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test .
-
-push:
- docker push git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index 92244af..28a050a 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -43,17 +43,17 @@ package_dir =
= src
include_package_data = True
install_requires =
- anyascii
colorama
- fastapi
- flair>=0.11
konoha
- nltk
- pydantic
+ flair>=0.11
segtok
- spacy
torch
- uvicorn
+ nltk
+ anyascii
+ termcolor
+ syntok
+ spacy
+ scipy<=1.12.0
[options.extras_require]
develop =
@@ -80,3 +80,4 @@ where = src
# [options.entry_points]
# console_scripts =
+
diff --git a/src/REL/crel/s2e_pe/data.py b/src/REL/crel/s2e_pe/data.py
index 7cb5488..bc96136 100644
--- a/src/REL/crel/s2e_pe/data.py
+++ b/src/REL/crel/s2e_pe/data.py
@@ -210,7 +210,7 @@ def pad_batch(self, batch, max_length):
example[0],
use_fast=False,
add_special_tokens=True,
- pad_to_max_length=True,
+ padding='longest',
max_length=max_length,
return_attention_mask=True,
return_tensors="pt",
diff --git a/src/REL/db/base.py b/src/REL/db/base.py
index b4ef715..a753484 100644
--- a/src/REL/db/base.py
+++ b/src/REL/db/base.py
@@ -186,9 +186,15 @@ def lookup_wik(self, w, table_name, column):
"select {} from {} where word = :word".format(column, table_name),
{"word": w},
).fetchone()
- res = (
- e if e is None else json.loads(e[0].decode()) if column == "p_e_m" else e[0]
- )
+ if not e:
+ res = None
+ elif column == "p_e_m":
+ try:
+ res = json.loads(e[0].decode())
+ except AttributeError:
+ res = json.loads("".join(chr(int(x, 2)) for x in e[0].split()))
+ else:
+ res = e[0]
return res
diff --git a/src/REL/mention_detection.py b/src/REL/mention_detection.py
index d4649ba..ac9f676 100644
--- a/src/REL/mention_detection.py
+++ b/src/REL/mention_detection.py
@@ -1,10 +1,21 @@
+import sys
+from termcolor import colored
from flair.data import Sentence
from flair.models import SequenceTagger
-from segtok.segmenter import split_single
+from syntok import segmenter
from REL.mention_detection_base import MentionDetectionBase
+class Entity:
+ def __init__(self, text, start_position, end_position, score, tag):
+ self.text = text
+ self.start_position = start_position
+ self.end_position = end_position
+ self.score = score
+ self.tag = tag
+
+
class MentionDetection(MentionDetectionBase):
"""
Class responsible for mention detection.
@@ -17,7 +28,7 @@ def __init__(self, base_url, wiki_version):
super().__init__(base_url, wiki_version)
- def format_spans(self, dataset):
+ def format_spans(self, dataset, process_sentences=True, split_docs_value=0, tagger_ner=None):
"""
Responsible for formatting given spans into dataset for the ED step. More specifically,
it returns the mention, its left/right context and a set of candidates.
@@ -25,7 +36,7 @@ def format_spans(self, dataset):
:return: Dictionary with mentions per document.
"""
- dataset, _, _ = self.split_text(dataset)
+ dataset, _, _ = self.split_text(dataset, process_sentences, split_docs_value, tagger_ner)
results = {}
total_ment = 0
@@ -40,7 +51,7 @@ def format_spans(self, dataset):
# end_pos = start_pos + length
# ngram = text[start_pos:end_pos]
- mention = self.preprocess_mention(ngram)
+ mention = self.preprocess_mention(ngram) # mention may be different from ngram
left_ctxt, right_ctxt = self.get_ctxt(
start_pos, end_pos, idx_sent, sentence, sentences_doc
)
@@ -62,7 +73,18 @@ def format_spans(self, dataset):
results[doc] = results_doc
return results, total_ment
- def split_text(self, dataset, is_flair=False):
+
+ def split_single(self, text):
+ sentences_as_token_lists = segmenter.analyze(text)
+ sentences = []
+ for paragraph in sentences_as_token_lists:
+ for sentence in paragraph:
+ tokens = [ str(token) for token in sentence ]
+ sentences.append("".join(tokens))
+ return sentences
+
+
+ def split_text(self, dataset, process_sentences, split_docs_value=0, tagger_ner=None):
"""
Splits text into sentences with optional spans (format is a requirement for GERBIL usage).
This behavior is required for the default NER-tagger, which during experiments was experienced
@@ -74,9 +96,21 @@ def split_text(self, dataset, is_flair=False):
res = {}
splits = [0]
processed_sentences = []
+ is_flair = isinstance(tagger_ner, SequenceTagger)
for doc in dataset:
text, spans = dataset[doc]
- sentences = split_single(text)
+ if process_sentences:
+ sentences = self.split_single(text)
+ if split_docs_value > 0:
+ sentences_split = []
+ for sentence in sentences:
+ split_sentences = self.split_text_in_parts(sentence, split_docs_value, tagger_ner)
+ sentences_split.extend(split_sentences)
+ sentences = sentences_split
+ elif split_docs_value > 0:
+ sentences = self.split_text_in_parts(text, split_docs_value, tagger_ner)
+ else:
+ sentences = [ text ]
res[doc] = {}
i = 0
@@ -104,28 +138,109 @@ def split_text(self, dataset, is_flair=False):
splits.append(splits[-1] + i)
return res, processed_sentences, splits
- def find_mentions(self, dataset, tagger=None):
+
+ def combine_entities(self, ner_results, sentence):
+ ner_results_out = []
+ i = 0
+ while i < len(ner_results):
+ last_end = ner_results[i]["end"]
+ ner_results_out.append(dict(ner_results[i]))
+ j = 1
+ while i + j < len(ner_results) and (ner_results[i+j]["start"] == last_end or
+ (ner_results[i+j]["start"] == last_end + 1 and
+ ner_results[i+j]["entity"].startswith("I") and
+ ner_results[i+j]["entity"][2:] == ner_results[i]["entity"][2:])):
+ if ner_results[i+j]["start"] == last_end:
+ ner_results_out[-1]["word"] += ner_results[i+j]["word"].removeprefix("##")
+ else:
+ ner_results_out[-1]["word"] += " " + ner_results[i+j]["word"]
+ ner_results_out[-1]["end"] = ner_results[i+j]["end"]
+ last_end = ner_results[i+j]["end"]
+ j += 1
+ i += j
+ return [ Entity(sentence[entity["start"]: entity["end"]], entity["start"], entity["end"], entity["score"], entity["entity"])
+ for entity in ner_results_out ]
+
+
+ def split_sentence_in_bert_tokens(self, sentence, tagger_ner):
+ tokenizer_results = tagger_ner.tokenizer([sentence], return_offsets_mapping=True) # warns if sentence is too long (>512)
+ input_ids = tokenizer_results["input_ids"][0]
+ token_spans = tokenizer_results["offset_mapping"][0]
+ tokens = [ tagger_ner.tokenizer.decode(token_id) for token_id in input_ids ]
+ return tokens, token_spans
+
+
+ def split_text_in_parts(self, text, split_docs_value, tagger_ner):
+ """
+ Splits text in parts of as most split_docs_value tokens. Texts are split at sentence
+ boundaries. If a sentence is longer than the limit it will be split in parts of
+ maximally split_docs_value tokens.
+ """
+ sentences = self.split_single(text)
+ token_lists = []
+ texts = []
+ is_flair = isinstance(tagger_ner, SequenceTagger)
+ for sentence in sentences:
+ if is_flair:
+ raise Exception("Splitting documents does not work in combination with Flair")
+ sentence_tokens, token_spans = self.split_sentence_in_bert_tokens(sentence, tagger_ner)
+ if len(token_lists) == 0 or (len(token_lists[-1]) + len(sentence_tokens)) > split_docs_value:
+ token_lists.append([])
+ texts.append("")
+ token_lists[-1].extend(sentence_tokens)
+ if texts[-1] == "":
+ texts[-1] = sentence
+ else:
+ texts[-1] += sentence
+ first_split_point = 0
+ while len(token_lists[-1]) > split_docs_value:
+ token_lists.append(list(token_lists[-1]))
+ token_lists[-2] = token_lists[-2][:split_docs_value]
+ token_lists[-1] = token_lists[-1][split_docs_value:]
+ second_split_point = token_spans[-len(token_lists[-1])][0]
+ texts[-1] = sentence[first_split_point:second_split_point]
+ texts.append(sentence[second_split_point:])
+ first_split_point = second_split_point
+ return texts
+
+
+ def prune_word_internal_mentions(self, raw_text, result_doc, total_ment):
+ """ remove entities which are part of a larger word """
+ to_be_deleted = []
+ for i in range(0, len(result_doc)):
+ start_pos = result_doc[i]["pos"]
+ end_pos = result_doc[i]["end_pos"]
+ if ((i > 0 and raw_text[start_pos-1].isalpha()) or
+ (end_pos < len(raw_text) and raw_text[end_pos].isalpha())):
+ to_be_deleted.append(i)
+ total_ment -= len(to_be_deleted)
+ while len(to_be_deleted) > 0:
+ result_doc.pop(to_be_deleted.pop(-1))
+ return result_doc, total_ment
+
+
+ def find_mentions(self, dataset, process_sentences, split_docs_value=0, tagger_ner=None):
"""
Responsible for finding mentions given a set of documents in a batch-wise manner. More specifically,
it returns the mention, its left/right context and a set of candidates.
:return: Dictionary with mentions per document.
"""
- if tagger is None:
+ if tagger_ner is None:
raise Exception(
"No NER tagger is set, but you are attempting to perform Mention Detection.."
)
# Verify if Flair, else ngram or custom.
- is_flair = isinstance(tagger, SequenceTagger)
+ is_flair = isinstance(tagger_ner, SequenceTagger)
dataset_sentences_raw, processed_sentences, splits = self.split_text(
- dataset, is_flair
+ dataset, process_sentences, split_docs_value, tagger_ner
)
results = {}
total_ment = 0
if is_flair:
- tagger.predict(processed_sentences)
+ tagger_ner.predict(processed_sentences) # predict with Flair
for i, doc in enumerate(dataset_sentences_raw):
- contents = dataset_sentences_raw[doc]
raw_text = dataset[doc][0]
+ contents = dataset_sentences_raw[doc]
sentences_doc = [v[0] for v in contents.values()]
sentences = processed_sentences[splits[i] : splits[i + 1]]
result_doc = []
@@ -134,14 +249,15 @@ def find_mentions(self, dataset, tagger=None):
for (idx_sent, (sentence, ground_truth_sentence)), snt in zip(
contents.items(), sentences
):
- # Only include offset if using Flair.
- if is_flair:
- offset = raw_text.find(sentence, cum_sent_length)
-
+ offset = raw_text.find(sentence, cum_sent_length)
+ if offset < 0:
+ print(colored(f"sentence not found in text: cannot happen: {sentence}", "red"), file=sys.stderr)
+ offset = 0
+ entity_counter = 0
for entity in (
- snt.get_spans("ner")
+ snt.get_spans("ner") # predict with Flair
if is_flair
- else tagger.predict(snt, processed_sentences)
+ else self.combine_entities(tagger_ner(snt), sentence) # predict with BERT
):
text, start_pos, end_pos, conf, tag = (
entity.text,
@@ -151,7 +267,7 @@ def find_mentions(self, dataset, tagger=None):
entity.tag,
)
total_ment += 1
- m = self.preprocess_mention(text)
+ m = self.preprocess_mention(text) # m may be different from text
cands = self.get_candidates(m)
if len(cands) == 0:
continue
@@ -161,7 +277,7 @@ def find_mentions(self, dataset, tagger=None):
start_pos, end_pos, idx_sent, sentence, sentences_doc
)
res = {
- "mention": m,
+ "mention": text, # 20230113 was m
"context": (left_ctxt, right_ctxt),
"candidates": cands,
"gold": ["NONE"],
@@ -175,5 +291,6 @@ def find_mentions(self, dataset, tagger=None):
}
result_doc.append(res)
cum_sent_length += len(sentence) + (offset - cum_sent_length)
+ result_doc, total_ment = self.prune_word_internal_mentions(raw_text, result_doc, total_ment)
results[doc] = result_doc
return results, total_ment
diff --git a/src/REL/ner/bert_wrapper.py b/src/REL/ner/bert_wrapper.py
new file mode 100644
index 0000000..ea67823
--- /dev/null
+++ b/src/REL/ner/bert_wrapper.py
@@ -0,0 +1,11 @@
+from transformers import AutoTokenizer, AutoModelForTokenClassification
+from transformers import pipeline
+
+def load_bert_ner(path_or_url):
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(path_or_url)
+ model = AutoModelForTokenClassification.from_pretrained(path_or_url)
+ return pipeline("ner", model=model, tokenizer=tokenizer)
+ except Exception:
+ raise SystemExit(f"cannot load Bert named entity recognition module from {path_or_url}")
+ return
diff --git a/src/REL/ner/flair_wrapper.py b/src/REL/ner/flair_wrapper.py
index 7a3699d..377b6a2 100644
--- a/src/REL/ner/flair_wrapper.py
+++ b/src/REL/ner/flair_wrapper.py
@@ -8,5 +8,5 @@ def load_flair_ner(path_or_url):
try:
return SequenceTagger.load(path_or_url)
except Exception:
- pass
+ raise SystemExit(f"cannot load Flair named entity recognition module from {path_or_url}")
return SequenceTagger.load(fetch_model(path_or_url, cache_root))
diff --git a/src/REL/ner/set_tagger_ner.py b/src/REL/ner/set_tagger_ner.py
new file mode 100644
index 0000000..9b72f12
--- /dev/null
+++ b/src/REL/ner/set_tagger_ner.py
@@ -0,0 +1,24 @@
+import re
+
+from flair.models import SequenceTagger
+from REL.ner.bert_wrapper import load_bert_ner
+
+taggers_ner = {
+ "flair": "flair/ner-english-fast",
+ "bert_base_cased": "dslim/bert-base-NER",
+ "bert_base_uncased": "dslim/bert-base-NER-uncased",
+ "bert_large_cased": "dslim/bert-large-NER",
+ "bert_large_uncased": "Jorgeutd/bert-large-uncased-finetuned-ner",
+ "bert_multilingual": "Davlan/bert-base-multilingual-cased-ner-hrl"
+}
+
+
+def set_tagger_ner(tagger_ner_name):
+ if re.search("^flair", tagger_ner_name):
+ tagger_ner = SequenceTagger.load(taggers_ner[tagger_ner_name])
+ elif re.search("^bert", tagger_ner_name):
+ tagger_ner = load_bert_ner(taggers_ner[tagger_ner_name])
+ else:
+ raise Exception(f"unknown tagger name: {tagger_ner_name}")
+
+ return tagger_ner
diff --git a/src/REL/server.py b/src/REL/server.py
index ce0e022..299c30f 100644
--- a/src/REL/server.py
+++ b/src/REL/server.py
@@ -1,289 +1,213 @@
-from typing import Annotated, List, Literal, Optional, Tuple, Union
+import json
+import numpy
+import os
+from http.server import BaseHTTPRequestHandler
-from fastapi import FastAPI
-from fastapi.responses import JSONResponse
-from pydantic import BaseModel, Field
+from flair.models import SequenceTagger
-from REL.response_handler import ResponseHandler
+from REL.mention_detection import MentionDetection
+from REL.utils import process_results
+from REL.ner.set_tagger_ner import set_tagger_ner
-DEBUG = False
+API_DOC = "API_DOC"
-app = FastAPI()
-Span = Tuple[int, int]
-
-
-class NamedEntityConfig(BaseModel):
- """Config for named entity linking. For more information, see
-
+def make_handler(base_url, wiki_version, ed_model, tagger_ner, process_sentences, split_docs_value=0):
"""
-
- text: str = Field(..., description="Text for entity linking or disambiguation.")
- spans: Optional[List[Span]] = Field(
- None,
- description=(
- """
-For EL: the spans field needs to be set to an empty list.
-
-For ED: spans should consist of a list of tuples, where each tuple refers to
-the start position and length of a mention.
-
-This is used when mentions are already identified and disambiguation is only
-needed. Each tuple represents start position and length of mention (in
-characters); e.g., `[(0, 8), (15,11)]` for mentions 'Nijmegen' and
-'Netherlands' in text 'Nijmegen is in the Netherlands'.
-"""
- ),
- )
- tagger: Literal[
- "ner-fast",
- "ner-fast-with-lowercase",
- ] = Field("ner-fast", description="NER tagger to use.")
-
- class Config:
- schema_extra = {
- "example": {
- "text": "If you're going to try, go all the way - Charles Bukowski.",
- "spans": [(41, 16)],
- "tagger": "ner-fast",
- }
- }
-
- def response(self):
- """Return response for request."""
- handler = handlers[self.tagger]
- response = handler.generate_response(text=self.text, spans=self.spans)
- return response
-
-
-class NamedEntityConceptConfig(BaseModel):
- """Config for named entity linking. Not yet implemented."""
-
- def response(self):
- """Return response for request."""
- response = JSONResponse(
- content={"msg": "Mode `ne_concept` has not been implemeted."},
- status_code=501,
- )
- return response
-
-
-class ConversationTurn(BaseModel):
- """Specify turns in a conversation. Each turn has a `speaker`
- and an `utterance`."""
-
- speaker: Literal["USER", "SYSTEM"] = Field(
- ..., description="Speaker for this turn, must be one of `USER` or `SYSTEM`."
- )
- utterance: str = Field(..., description="Input utterance to be annotated.")
-
- class Config:
- schema_extra = {
- "example": {
- "speaker": "USER",
- "utterance": "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.",
- }
- }
-
-
-class ConversationConfig(BaseModel):
- """Config for conversational entity linking. For more information:
- .
+ Class/function combination that is used to setup an API that can be used for e.g. GERBIL evaluation.
"""
+ class GetHandler(BaseHTTPRequestHandler):
+ def __init__(self, *args, **kwargs):
+ self.ed_model = ed_model
+ self.tagger_ner = tagger_ner
+ self.process_sentences = process_sentences
+ self.split_docs_value = split_docs_value
+
+ self.base_url = base_url
+ self.wiki_version = wiki_version
+
+ self.custom_ner = not isinstance(tagger_ner, SequenceTagger)
+ self.mention_detection = MentionDetection(base_url, wiki_version)
+
+ super().__init__(*args, **kwargs)
+
+ def do_GET(self):
+ self.send_response(200)
+ self.end_headers()
+ self.wfile.write(
+ bytes(
+ json.dumps(
+ {
+ "schemaVersion": 1,
+ "label": "status",
+ "message": "up",
+ "color": "green",
+ }
+ ),
+ "utf-8",
+ )
+ )
+ return
+
+ def do_HEAD(self):
+ # send bad request response code
+ self.send_response(400)
+ self.end_headers()
+ self.wfile.write(bytes(json.dumps([]), "utf-8"))
+ return
+
+ def solve_floats(self, data):
+ data_new = []
+ for data_set in data:
+ data_set_new_list = []
+ for data_el in data_set:
+ if isinstance(data_el, numpy.float32):
+ data_el = float(data_el)
+ data_set_new_list.append(data_el)
+ data_new.append(data_set_new_list)
+ return data_new
+
+ def do_POST(self):
+ """
+ Returns response.
- text: List[ConversationTurn] = Field(
- ..., description="Conversation as list of turns between two speakers."
- )
- tagger: Literal["default",] = Field("default", description="NER tagger to use.")
-
- class Config:
- schema_extra = {
- "example": {
- "text": (
- {
- "speaker": "USER",
- "utterance": "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.",
- },
- {
- "speaker": "SYSTEM",
- "utterance": "Some people are allergic to histamine in tomatoes.",
- },
- {
- "speaker": "USER",
- "utterance": "Talking of food, can you recommend me a restaurant in my city for our anniversary?",
- },
- ),
- "tagger": "default",
- }
- }
-
- def response(self):
- """Return response for request."""
- text = self.dict()["text"]
- conv_handler = conv_handlers[self.tagger]
- response = conv_handler.annotate(text)
- return response
-
-
-class TurnAnnotation(BaseModel):
- __root__: List[Union[int, str]] = Field(
- ...,
- min_items=4,
- max_items=4,
- description="""
-The 4 values of the annotation represent the start index of the word,
-length of the word, the annotated word, and the prediction.
-""",
- )
-
- class Config:
- schema_extra = {"example": [82, 6, "London", "London"]}
-
-
-class SystemResponse(ConversationTurn):
- """Return input when the speaker equals 'SYSTEM'."""
-
- speaker: str = "SYSTEM"
-
- class Config:
- schema_extra = {
- "example": {
- "speaker": "SYSTEM",
- "utterance": "Some people are allergic to histamine in tomatoes.",
- },
- }
-
-
-class UserResponse(ConversationTurn):
- """Return annotations when the speaker equals 'USER'."""
-
- speaker: str = "USER"
- annotations: List[TurnAnnotation] = Field(..., description="List of annotations.")
-
- class Config:
- schema_extra = {
- "example": {
- "speaker": "USER",
- "utterance": "I am allergic to tomatoes but we have a lot of famous Italian restaurants here in London.",
- "annotations": [
- [17, 8, "tomatoes", "Tomato"],
- [54, 19, "Italian restaurants", "Italian_cuisine"],
- [82, 6, "London", "London"],
- ],
- },
- }
-
-
-TurnResponse = Union[UserResponse, SystemResponse]
-
-
-class NEAnnotation(BaseModel):
- """Annotation for named entity linking."""
-
- __root__: List[Union[int, str, float]] = Field(
- ...,
- min_items=7,
- max_items=7,
- description="""
-The 7 values of the annotation represent the
-start index, end index, the annotated word, prediction, ED confidence, MD confidence, and tag.
-""",
- )
-
- class Config:
- schema_extra = {
- "example": [41, 16, "Charles Bukowski", "Charles_Bukowski", 0, 0, "NULL"]
- }
-
-
-class StatusResponse(BaseModel):
- schemaVersion: int
- label: str
- message: str
- color: str
-
-
-@app.get("/", response_model=StatusResponse)
-def server_status():
- """Returns server status."""
- return {
- "schemaVersion": 1,
- "label": "status",
- "message": "up",
- "color": "green",
- }
+ :return:
+ """
+ try:
+ content_length = int(self.headers["Content-Length"])
+ post_data = self.rfile.read(content_length)
+ self.send_response(200)
+ self.end_headers()
+
+ text, spans = self.read_json(post_data)
+ response = self.generate_response(text, spans)
+
+ self.wfile.write(bytes(json.dumps(self.solve_floats(response)), "utf-8"))
+ except Exception as e:
+ print(f"Encountered exception: {repr(e)}")
+ self.send_response(400)
+ self.end_headers()
+ self.wfile.write(bytes(json.dumps([]), "utf-8"))
+ return
+
+ def read_json(self, post_data):
+ """
+ Reads input JSON message.
+ :return: document text and spans.
+ """
-@app.post("/", response_model=List[NEAnnotation])
-@app.post("/ne", response_model=List[NEAnnotation])
-def named_entity_linking(config: NamedEntityConfig):
- """Submit your text here for entity disambiguation or linking.
+ data = json.loads(post_data.decode("utf-8"))
+ text = data["text"]
+ text = text.replace("&amp;", "&")
+
+ # GERBIL sends dictionary, users send list of lists.
+ if "spans" in data:
+ try:
+ spans = [list(d.values()) for d in data["spans"]]
+ except Exception:
+ spans = data["spans"]
+ pass
+ else:
+ spans = []
+
+ return text, spans
+
+ def convert_bert_result(self, result):
+ new_result = {}
+ for doc_key in result:
+ new_result[doc_key] = []
+ for mention_data in result[doc_key]:
+ new_result[doc_key].append(list(mention_data))
+ new_result[doc_key][-1][2], new_result[doc_key][-1][3] =\
+ new_result[doc_key][-1][3], new_result[doc_key][-1][2]
+ new_result[doc_key][-1] = tuple(new_result[doc_key][-1])
+ return new_result
+
+ def generate_response(self, text, spans):
+ """
+ Generates response for API. Can be either ED only or EL, meaning end-to-end.
- The REL annotation mode can be selected by changing the path.
- use `/` or `/ne/` for annotating regular text with named
- entities (default), `/ne_concept/` for regular text with both concepts and
- named entities, and `/conv/` for conversations with both concepts and
- named entities.
- """
- if DEBUG:
- return []
- return config.response()
+ :return: list of tuples for each entity found.
+ """
+ if len(text) == 0:
+ return []
+
+ if len(spans) > 0:
+ # ED.
+ processed = {API_DOC: [text, spans]}
+ mentions_dataset, total_ment = self.mention_detection.format_spans(
+ processed
+ )
+ else:
+ # EL
+ processed = {API_DOC: [text, spans]}
+ mentions_dataset, total_ment = self.mention_detection.find_mentions(
+ processed, self.process_sentences, self.split_docs_value, self.tagger_ner
+ )
+
+ # Disambiguation
+ predictions, timing = self.ed_model.predict(mentions_dataset)
+
+ # Process result.
+ result = process_results(
+ mentions_dataset,
+ predictions,
+ processed,
+ include_offset=False if ((len(spans) > 0) or self.custom_ner) else True,
+ )
+ # result = self.convert_bert_result(result)
-@app.post("/conv", response_model=List[TurnResponse])
-def conversational_entity_linking(config: ConversationConfig):
- """Submit your text here for conversational entity linking."""
- if DEBUG:
- return []
- return config.response()
+ # Singular document.
+ if len(result) > 0:
+ return [*result.values()][0]
+ return []
-@app.post("/ne_concept", response_model=List[NEAnnotation])
-def conceptual_named_entity_linking(config: NamedEntityConceptConfig):
- """Submit your text here for conceptual entity disambiguation or linking."""
- if DEBUG:
- return []
- return config.response()
+ return GetHandler
if __name__ == "__main__":
import argparse
+ from http.server import HTTPServer
- import uvicorn
+ from REL.entity_disambiguation import EntityDisambiguation
+ from REL.ner.flair_wrapper import load_flair_ner
+ from REL.ner.bert_wrapper import load_bert_ner
p = argparse.ArgumentParser()
- p.add_argument("base_url")
- p.add_argument("wiki_version")
+ p.add_argument("--base_url", default=os.path.abspath(os.path.dirname(__file__) + "/../data/"))
+ p.add_argument("--wiki_version", default="wiki_2019")
p.add_argument("--ed-model", default="ed-wiki-2019")
- p.add_argument("--ner-model", default="ner-fast", nargs="+")
+ p.add_argument("--ner-model", default="ner-fast")
p.add_argument("--bind", "-b", metavar="ADDRESS", default="0.0.0.0")
p.add_argument("--port", "-p", default=5555, type=int)
- args = p.parse_args()
+ p.add_argument("--tagger_ner_name", default="flair", help = "mention detection tagger")
+ p.add_argument("--process_sentences", help = "process sentences rather than documents", action="store_true")
+ p.add_argument("--split_docs_value", action="store", type=int, default=0, help = "threshold number of tokens to split document")
- if not DEBUG:
- from REL.crel.conv_el import ConvEL
- from REL.entity_disambiguation import EntityDisambiguation
- from REL.ner import load_flair_ner
+ args = p.parse_args()
- ed_model = EntityDisambiguation(
- args.base_url,
- args.wiki_version,
- {"mode": "eval", "model_path": args.ed_model},
- )
+ tagger_ner_name = args.tagger_ner_name
+ tagger_ner = set_tagger_ner(tagger_ner_name)
+ split_docs_value = args.split_docs_value
- handlers = {}
+ process_sentences = args.process_sentences
- for ner_model_name in args.ner_model:
- print("Loading NER model:", ner_model_name)
- ner_model = load_flair_ner(ner_model_name)
- handler = ResponseHandler(
- args.base_url, args.wiki_version, ed_model, ner_model
- )
- handlers[ner_model_name] = handler
+ ed_model = EntityDisambiguation(
+ args.base_url, args.wiki_version, {"mode": "eval", "model_path": args.ed_model}
+ )
+ server_address = (args.bind, args.port)
+ server = HTTPServer(
+ server_address,
+ make_handler(args.base_url, args.wiki_version, ed_model, tagger_ner, process_sentences, split_docs_value)
+ )
- conv_handlers = {
- "default": ConvEL(args.base_url, args.wiki_version, ed_model=ed_model)
- }
+ try:
+ print("Ready for listening.")
+ server.serve_forever()
+ except KeyboardInterrupt:
+ exit(0)
- uvicorn.run(app, port=args.port, host=args.bind)
diff --git a/scripts/WikiExtractor.py b/src/scripts/WikiExtractor.py
similarity index 100%
rename from scripts/WikiExtractor.py
rename to src/scripts/WikiExtractor.py
diff --git a/scripts/__init__.py b/src/scripts/__init__.py
similarity index 100%
rename from scripts/__init__.py
rename to src/scripts/__init__.py
diff --git a/scripts/code_tutorials/batch_EL.py b/src/scripts/code_tutorials/batch_EL.py
similarity index 100%
rename from scripts/code_tutorials/batch_EL.py
rename to src/scripts/code_tutorials/batch_EL.py
diff --git a/scripts/code_tutorials/example_custom_MD.py b/src/scripts/code_tutorials/example_custom_MD.py
similarity index 100%
rename from scripts/code_tutorials/example_custom_MD.py
rename to src/scripts/code_tutorials/example_custom_MD.py
diff --git a/scripts/code_tutorials/generate_p_e_m.py b/src/scripts/code_tutorials/generate_p_e_m.py
similarity index 100%
rename from scripts/code_tutorials/generate_p_e_m.py
rename to src/scripts/code_tutorials/generate_p_e_m.py
diff --git a/scripts/code_tutorials/generate_train_val.py b/src/scripts/code_tutorials/generate_train_val.py
similarity index 100%
rename from scripts/code_tutorials/generate_train_val.py
rename to src/scripts/code_tutorials/generate_train_val.py
diff --git a/scripts/code_tutorials/predict_EL.py b/src/scripts/code_tutorials/predict_EL.py
similarity index 100%
rename from scripts/code_tutorials/predict_EL.py
rename to src/scripts/code_tutorials/predict_EL.py
diff --git a/scripts/code_tutorials/test_API.py b/src/scripts/code_tutorials/test_API.py
similarity index 100%
rename from scripts/code_tutorials/test_API.py
rename to src/scripts/code_tutorials/test_API.py
diff --git a/scripts/code_tutorials/train_LR.py b/src/scripts/code_tutorials/train_LR.py
similarity index 100%
rename from scripts/code_tutorials/train_LR.py
rename to src/scripts/code_tutorials/train_LR.py
diff --git a/scripts/code_tutorials/train_eval_ED.py b/src/scripts/code_tutorials/train_eval_ED.py
similarity index 100%
rename from scripts/code_tutorials/train_eval_ED.py
rename to src/scripts/code_tutorials/train_eval_ED.py
diff --git a/scripts/comparison_BLINK/run_server.py b/src/scripts/comparison_BLINK/run_server.py
similarity index 100%
rename from scripts/comparison_BLINK/run_server.py
rename to src/scripts/comparison_BLINK/run_server.py
diff --git a/scripts/comparison_BLINK/test.py b/src/scripts/comparison_BLINK/test.py
similarity index 100%
rename from scripts/comparison_BLINK/test.py
rename to src/scripts/comparison_BLINK/test.py
diff --git a/scripts/download_data.sh b/src/scripts/download_data.sh
similarity index 100%
rename from scripts/download_data.sh
rename to src/scripts/download_data.sh
diff --git a/scripts/efficiency_results.py b/src/scripts/efficiency_results.py
similarity index 100%
rename from scripts/efficiency_results.py
rename to src/scripts/efficiency_results.py
diff --git a/src/scripts/efficiency_test.py b/src/scripts/efficiency_test.py
new file mode 100644
index 0000000..a3ce630
--- /dev/null
+++ b/src/scripts/efficiency_test.py
@@ -0,0 +1,124 @@
+import argparse
+from scripts.evaluate_predictions import evaluate
+import json
+import numpy as np
+import os
+import requests
+
+from REL.ner.set_tagger_ner import set_tagger_ner
+from REL.training_datasets import TrainingEvaluationDatasets
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--max_docs", help = "number of documents", default="50")
+parser.add_argument("--process_sentences", help = "process sentences rather than documents", action="store_true")
+parser.add_argument("--split_docs_value", help = "threshold number of tokens to split document", default="0")
+parser.add_argument("--tagger_ner_name", help = "mention detection tagger", default="flair")
+parser.add_argument("--use_server", help = "use server", action="store_true")
+parser.add_argument("--wiki_version", help = "Wiki version", default="wiki_2019")
+args = parser.parse_args()
+
+np.random.seed(seed=42)
+
+base_url = os.path.abspath(os.path.dirname(__file__) + "/../data/")
+process_sentences = args.process_sentences
+
+split_docs_value = int(args.split_docs_value)
+max_docs = int(args.max_docs)
+wiki_version = args.wiki_version
+
+datasets = TrainingEvaluationDatasets(base_url, wiki_version).load()["aida_testB"]
+
+use_server = args.use_server
+
+docs = {}
+all_results = {}
+for i, doc in enumerate(datasets):
+ sentences = []
+ for x in datasets[doc]:
+ if x["sentence"] not in sentences:
+ sentences.append(x["sentence"])
+ text = ". ".join([x for x in sentences])
+ if len(text.split()) > 200:
+ docs[doc] = [text, []]
+ if len(docs) >= max_docs:
+ break
+
+if use_server:
+    # Query the server for each collected document; docs already holds
+    # at most max_docs entries, so iterate over it directly.
+    for i, doc in enumerate(docs):
+        text = docs[doc][0]
+        results_list = []
+        # Demo payload that can be used to query the API.
+        print(f"max_docs={max_docs} use_server={use_server}")
+        myjson = {
+            "text": text,
+            "spans": [
+                # {"start": 41, "length": 16}
+            ],
+        }
+        print("----------------------------")
+        print(i, "Input API:")
+        print(myjson)
+
+        print("Output API:")
+        results = requests.post("http://0.0.0.0:5555", json=myjson)
+        print("----------------------------")
+        try:
+            # Decode once, inside the try, so a non-JSON reply is
+            # reported instead of crashing the loop.
+            results_json = results.json()
+            print(results_json)
+            results_list = [{"mention": result[2], "prediction": result[3]} for result in results_json]
+        except json.decoder.JSONDecodeError:
+            print("The analysis results are not in json format:", str(results))
+
+        if results_list:
+            all_results[doc] = results_list
+
+
+    if all_results:
+        evaluate(all_results, base_url)
+
+
+# --------------------- Now total --------------------------------
+# ------------- RUN SEPARATELY TO BALANCE LOAD--------------------
+if not use_server:
+ from time import time
+
+ import flair
+ import torch
+ from flair.models import SequenceTagger
+
+ from REL.entity_disambiguation import EntityDisambiguation
+ from REL.mention_detection import MentionDetection
+
+ from REL.ner.bert_wrapper import load_bert_ner
+
+ flair.device = torch.device("cpu")
+ tagger_ner_name = args.tagger_ner_name
+ tagger_ner = set_tagger_ner(tagger_ner_name)
+
+ print(f"max_docs={max_docs} tagger_ner_name={tagger_ner_name} wiki_version={wiki_version} process_sentences={process_sentences} split_docs_value={split_docs_value}")
+
+ mention_detection = MentionDetection(base_url, wiki_version)
+
+ start = time()
+ mentions_dataset, n_mentions = mention_detection.find_mentions(docs, process_sentences=process_sentences, split_docs_value=split_docs_value, tagger_ner=tagger_ner)
+
+ print("MD took: {} seconds".format(round(time() - start, 2)))
+
+ # 3. Load ED model.
+ config = {
+ "mode": "eval",
+ "model_path": "{}/{}/generated/model".format(base_url, wiki_version),
+ }
+ ed_model = EntityDisambiguation(base_url, wiki_version, config)
+
+ # 4. Entity disambiguation.
+ start = time()
+ predictions, timing = ed_model.predict(mentions_dataset)
+ print(f"ED took: {time() - start:.2f} seconds")
+
+
+ evaluate(predictions, base_url)
diff --git a/src/scripts/evaluate_predictions.py b/src/scripts/evaluate_predictions.py
new file mode 100644
index 0000000..1276b69
--- /dev/null
+++ b/src/scripts/evaluate_predictions.py
@@ -0,0 +1,145 @@
+import os
+
+
+UNUSED = -1
+
+
+def get_gold_data(doc, base_url):
+    """Return [surface_form, link] gold entities for document *doc* from the AIDA tsv."""
+    GOLD_DATA_FILE = os.path.join(base_url, "generic/test_datasets/AIDA/AIDA-YAGO2-dataset.tsv")
+    entities = []
+    with open(GOLD_DATA_FILE, "r") as in_file:
+        # Skip ahead to the header of the requested document.
+        for line in in_file:
+            if line.startswith(f"-DOCSTART- ({doc} "):
+                break
+        # Collect B-tagged mentions until the next document starts.
+        for line in in_file:
+            if line.startswith("-DOCSTART- "):
+                break
+            fields = line.strip().split("\t")
+            if len(fields) > 3 and fields[1] == "B":
+                entities.append([fields[2], fields[3]])
+    return entities
+
+
+def md_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i):
+ return gold_entities[gold_i][0].lower() == predicted_entities[predicted_i][0].lower()
+
+
+def el_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i):
+ return(gold_entities[gold_i][0].lower() == predicted_entities[predicted_i][0].lower() and
+ gold_entities[gold_i][1].lower() == predicted_entities[predicted_i][1].lower())
+
+
+def find_correct_els(gold_entities, predicted_entities, gold_links, predicted_links):
+ for gold_i in range(0, len(gold_entities)):
+ if gold_links[gold_i] == UNUSED:
+ for predicted_i in range(0, len(predicted_entities)):
+ if (predicted_links[predicted_i] == UNUSED and
+ el_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i)):
+ gold_links[gold_i] = predicted_i
+ predicted_links[predicted_i] = gold_i
+ return gold_links, predicted_links
+
+
+def find_correct_mds(gold_entities, predicted_entities, gold_links, predicted_links):
+ for gold_i in range(0, len(gold_entities)):
+ if gold_links[gold_i] == UNUSED:
+ for predicted_i in range(0, len(predicted_entities)):
+ if (predicted_links[predicted_i] == UNUSED and
+ md_match(gold_entities, predicted_entities, predicted_links, gold_i, predicted_i)):
+ gold_links[gold_i] = predicted_i
+ predicted_links[predicted_i] = gold_i
+ return gold_links, predicted_links
+
+
+
+def compare_entities(gold_entities, predicted_entities):
+ gold_links = len(gold_entities) * [UNUSED]
+ predicted_links = len(predicted_entities) * [UNUSED]
+ gold_links, predicted_links = find_correct_els(gold_entities, predicted_entities, gold_links, predicted_links)
+ gold_links, predicted_links = find_correct_mds(gold_entities, predicted_entities, gold_links, predicted_links)
+ return gold_links, predicted_links
+
+
+def count_entities(gold_entities, predicted_entities, gold_links, predicted_links):
+ """ returns: - correct: number of entities correctly identified and correctly linked
+ - wrong_md: number of entities identified but wrong
+ - wrong_el: number of entities correctly identified but incorrectly linked
+ - missed: number of gold standard entities not found and not linked
+ """
+ correct = 0
+ wrong_md = 0
+ wrong_el = 0
+ missed = 0
+ for predicted_i in range(0, len(predicted_links)):
+ if predicted_links[predicted_i] == UNUSED:
+ wrong_md += 1
+ elif predicted_entities[predicted_i][1] == gold_entities[predicted_links[predicted_i]][1]:
+ # assumption: predicted_entities[predicted_i][0] == gold_entities[predicted_links[predicted_i]][0]
+ correct += 1
+ else:
+ wrong_el += 1
+ for gold_i in range(0, len(gold_links)):
+ if gold_links[gold_i] == UNUSED:
+ missed += 1
+ return correct, wrong_md, wrong_el, missed
+
+
+def compare_and_count_entities(gold_entities, predicted_entities):
+ gold_links, predicted_links = compare_entities(gold_entities, predicted_entities)
+ return count_entities(gold_entities, predicted_entities, gold_links, predicted_links)
+
+
+def compute_md_scores(correct_all, wrong_md_all, wrong_el_all, missed_all):
+ correct_md_all = correct_all + wrong_el_all
+ if correct_md_all > 0:
+ precision_md = correct_md_all / (correct_md_all + wrong_md_all)
+ recall_md = correct_md_all / (correct_md_all + missed_all)
+ f1_md = 2 * precision_md * recall_md / ( precision_md + recall_md )
+ else:
+ precision_md = 0
+ recall_md = 0
+ f1_md = 0
+ return precision_md, recall_md, f1_md
+
+
+def compute_el_scores(correct_all, wrong_md_all, wrong_el_all, missed_all):
+ """ reported el_scores are combined md plus el scores """
+ if correct_all > 0:
+ precision_el = correct_all / (correct_all + wrong_md_all + wrong_el_all)
+ recall_el = correct_all / (correct_all + wrong_el_all + missed_all)
+ f1_el = 2 * precision_el * recall_el / ( precision_el + recall_el )
+ else:
+ precision_el = 0.0
+ recall_el = 0
+ f1_el = 0
+ return precision_el, recall_el, f1_el
+
+
+def print_scores(correct_all, wrong_md_all, wrong_el_all, missed_all):
+ precision_md, recall_md, f1_md = compute_md_scores(correct_all, wrong_md_all, wrong_el_all, missed_all)
+ precision_el, recall_el, f1_el = compute_el_scores(correct_all, wrong_md_all, wrong_el_all, missed_all)
+ print("Results: PMD RMD FMD PEL REL FEL: ", end="")
+ print(f"{precision_md:0.1%} {recall_md:0.1%} {f1_md:0.1%} | ",end="")
+ print(f"{precision_el:0.1%} {recall_el:0.1%} {f1_el:0.1%}")
+ return precision_md, recall_md, f1_md, precision_el, recall_el, f1_el
+
+
+def evaluate(predictions, base_url):
+ correct_all = 0
+ wrong_md_all = 0
+ wrong_el_all = 0
+ missed_all = 0
+ for doc in predictions:
+ gold_entities = get_gold_data(doc, base_url)
+ predicted_entities = []
+ for mention in predictions[doc]:
+ predicted_entities.append([mention["mention"], mention["prediction"]])
+ correct, wrong_md, wrong_el, missed = compare_and_count_entities(gold_entities, predicted_entities)
+ correct_all += correct
+ wrong_md_all += wrong_md
+ wrong_el_all += wrong_el
+ missed_all += missed
+ print_scores(correct_all, wrong_md_all, wrong_el_all, missed_all)
diff --git a/scripts/gerbil_middleware/.gitignore b/src/scripts/gerbil_middleware/.gitignore
similarity index 100%
rename from scripts/gerbil_middleware/.gitignore
rename to src/scripts/gerbil_middleware/.gitignore
diff --git a/scripts/gerbil_middleware/Dockerfile b/src/scripts/gerbil_middleware/Dockerfile
similarity index 100%
rename from scripts/gerbil_middleware/Dockerfile
rename to src/scripts/gerbil_middleware/Dockerfile
diff --git a/scripts/gerbil_middleware/LICENSE b/src/scripts/gerbil_middleware/LICENSE
similarity index 100%
rename from scripts/gerbil_middleware/LICENSE
rename to src/scripts/gerbil_middleware/LICENSE
diff --git a/src/scripts/gerbil_middleware/Makefile b/src/scripts/gerbil_middleware/Makefile
new file mode 100644
index 0000000..e899701
--- /dev/null
+++ b/src/scripts/gerbil_middleware/Makefile
@@ -0,0 +1,10 @@
+default: build dockerize
+
+build:
+ mvn clean package -U
+
+dockerize:
+ docker build -t git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test .
+
+push:
+ docker push git.project-hobbit.eu:4567/gerbil/spotwrapnifws4test
diff --git a/scripts/gerbil_middleware/README.md b/src/scripts/gerbil_middleware/README.md
similarity index 100%
rename from scripts/gerbil_middleware/README.md
rename to src/scripts/gerbil_middleware/README.md
diff --git a/scripts/gerbil_middleware/curlExample.sh b/src/scripts/gerbil_middleware/curlExample.sh
similarity index 100%
rename from scripts/gerbil_middleware/curlExample.sh
rename to src/scripts/gerbil_middleware/curlExample.sh
diff --git a/scripts/gerbil_middleware/docker-compose.yml b/src/scripts/gerbil_middleware/docker-compose.yml
similarity index 100%
rename from scripts/gerbil_middleware/docker-compose.yml
rename to src/scripts/gerbil_middleware/docker-compose.yml
diff --git a/scripts/gerbil_middleware/example.ttl b/src/scripts/gerbil_middleware/example.ttl
similarity index 100%
rename from scripts/gerbil_middleware/example.ttl
rename to src/scripts/gerbil_middleware/example.ttl
diff --git a/scripts/gerbil_middleware/pom.xml b/src/scripts/gerbil_middleware/pom.xml
similarity index 96%
rename from scripts/gerbil_middleware/pom.xml
rename to src/scripts/gerbil_middleware/pom.xml
index 97e8aa0..af5da3e 100644
--- a/scripts/gerbil_middleware/pom.xml
+++ b/src/scripts/gerbil_middleware/pom.xml
@@ -76,7 +76,7 @@
org.apache.jena
jena-core
- 4.2.0
+ 2.11.1
org.apache.jena
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.md5 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.md5
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.md5
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.sha1 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.sha1
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-javadoc.jar.sha1
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.md5 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.md5
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.md5
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.sha1 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.sha1
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT-sources.jar.sha1
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.md5 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.md5
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.md5
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.sha1 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.sha1
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.jar.sha1
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.md5 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.md5
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.md5
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.sha1 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.sha1
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/gerbil.nif.transfer-1.1.0-SNAPSHOT.pom.sha1
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.md5 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.md5
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.md5
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.sha1 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.sha1
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/1.1.0-SNAPSHOT/maven-metadata-local.xml.sha1
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.md5 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.md5
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.md5
diff --git a/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.sha1 b/src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.sha1
rename to src/scripts/gerbil_middleware/repository/org/aksw/gerbil.nif.transfer/maven-metadata-local.xml.sha1
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.md5 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.md5
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.md5
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.sha1 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.sha1
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.jar.sha1
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.md5 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.md5
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.md5
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.sha1 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.sha1
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/2.2.1/org.restlet.ext.servlet-2.2.1.pom.sha1
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.md5 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.md5
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.md5
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.sha1 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.sha1
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet.ext.servlet/maven-metadata-local.xml.sha1
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.md5 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.md5
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.md5
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.sha1 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.sha1
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.jar.sha1
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.md5 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.md5
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.md5
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.sha1 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.sha1
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/2.2.1/org.restlet-2.2.1.pom.sha1
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.md5 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.md5
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.md5
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.md5
diff --git a/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.sha1 b/src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.sha1
similarity index 100%
rename from scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.sha1
rename to src/scripts/gerbil_middleware/repository/org/restlet/org.restlet/maven-metadata-local.xml.sha1
diff --git a/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/EDResource.java b/src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/EDResource.java
similarity index 100%
rename from scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/EDResource.java
rename to src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/EDResource.java
diff --git a/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/LocalIntermediateWebserver.java b/src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/LocalIntermediateWebserver.java
similarity index 100%
rename from scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/LocalIntermediateWebserver.java
rename to src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/LocalIntermediateWebserver.java
diff --git a/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/MyResource.java b/src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/MyResource.java
similarity index 100%
rename from scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/MyResource.java
rename to src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/MyResource.java
diff --git a/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightClient.java b/src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightClient.java
similarity index 100%
rename from scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightClient.java
rename to src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightClient.java
diff --git a/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightResource.java b/src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightResource.java
similarity index 100%
rename from scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightResource.java
rename to src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/SpotlightResource.java
diff --git a/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/TestApplication.java b/src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/TestApplication.java
similarity index 100%
rename from scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/TestApplication.java
rename to src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/TestApplication.java
diff --git a/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/data_format b/src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/data_format
similarity index 100%
rename from scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/data_format
rename to src/scripts/gerbil_middleware/src/main/java/org/aksw/gerbil/ws4test/data_format
diff --git a/scripts/gerbil_middleware/src/main/resources/log4j.properties b/src/scripts/gerbil_middleware/src/main/resources/log4j.properties
similarity index 100%
rename from scripts/gerbil_middleware/src/main/resources/log4j.properties
rename to src/scripts/gerbil_middleware/src/main/resources/log4j.properties
diff --git a/scripts/gerbil_middleware/src/main/webapp/WEB-INF/web.xml b/src/scripts/gerbil_middleware/src/main/webapp/WEB-INF/web.xml
similarity index 100%
rename from scripts/gerbil_middleware/src/main/webapp/WEB-INF/web.xml
rename to src/scripts/gerbil_middleware/src/main/webapp/WEB-INF/web.xml
diff --git a/scripts/test_server.py b/src/scripts/test_server.py
similarity index 100%
rename from scripts/test_server.py
rename to src/scripts/test_server.py
diff --git a/scripts/truecase/README.md b/src/scripts/truecase/README.md
similarity index 100%
rename from scripts/truecase/README.md
rename to src/scripts/truecase/README.md
diff --git a/scripts/truecase/relq.py b/src/scripts/truecase/relq.py
similarity index 100%
rename from scripts/truecase/relq.py
rename to src/scripts/truecase/relq.py
diff --git a/scripts/truecase/truecase-m.py b/src/scripts/truecase/truecase-m.py
similarity index 100%
rename from scripts/truecase/truecase-m.py
rename to src/scripts/truecase/truecase-m.py
diff --git a/scripts/update_db_pem.py b/src/scripts/update_db_pem.py
similarity index 100%
rename from scripts/update_db_pem.py
rename to src/scripts/update_db_pem.py
diff --git a/scripts/w2v/preprocess.sh b/src/scripts/w2v/preprocess.sh
similarity index 100%
rename from scripts/w2v/preprocess.sh
rename to src/scripts/w2v/preprocess.sh
diff --git a/scripts/w2v/train.sh b/src/scripts/w2v/train.sh
similarity index 100%
rename from scripts/w2v/train.sh
rename to src/scripts/w2v/train.sh
diff --git a/tests/test_bert_md.py b/tests/test_bert_md.py
new file mode 100644
index 0000000..0565547
--- /dev/null
+++ b/tests/test_bert_md.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+from pathlib import Path
+
+import os
+import pytest
+
+from transformers import AutoTokenizer, AutoModelForTokenClassification
+from transformers import pipeline
+
+from REL.mention_detection import MentionDetection
+from REL.ner.bert_wrapper import load_bert_ner
+
+
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
+def test_md(base_url):
+ tagger_ner = load_bert_ner("dslim/bert-base-NER")
+ process_sentences = False
+ split_docs_value = 0
+ wiki_version = "wiki_2019"
+ md = MentionDetection(base_url, wiki_version)
+
+ # first test case: repeating sentences
+ sample1 = {"test_doc": ["Fox, Fox. Fox.", []]}
+ resulting_spans1 = {(0, 3), (5, 3), (10, 3)}
+ predictions = md.find_mentions(sample1, process_sentences, split_docs_value, tagger_ner)
+ predicted_spans = []
+ for i in range(0, 1):
+ p = {
+ (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[i]["test_doc"]
+ }
+ predicted_spans.extend(list(p))
+ predicted_spans = set(predicted_spans)
+ assert resulting_spans1 == predicted_spans
+
+ # second test case: excessive whitespace
+ sample2 = {"test_doc": ["Fox, Fox, Fox.", []]}
+ resulting_spans2 = {(0, 3), (20, 3), (43, 3)}
+ predictions = md.find_mentions(sample2, process_sentences, split_docs_value, tagger_ner)
+ predicted_spans = {
+ (m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
+ }
+ assert resulting_spans2 == predicted_spans
diff --git a/tests/test_ed_pipeline.py b/tests/test_ed_pipeline.py
index b2041e7..10fb7b7 100644
--- a/tests/test_ed_pipeline.py
+++ b/tests/test_ed_pipeline.py
@@ -3,12 +3,18 @@
from pathlib import Path
+import os
+import pytest
+
from REL.entity_disambiguation import EntityDisambiguation
from REL.mention_detection import MentionDetection
from REL.ner import Cmns
from REL.utils import process_results
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
def test_pipeline():
base_url = Path(__file__).parent
wiki_subfolder = "wiki_test"
@@ -22,13 +28,13 @@ def test_pipeline():
tagger = Cmns(base_url, wiki_subfolder, n=5)
model = EntityDisambiguation(base_url, wiki_subfolder, config)
- mentions_dataset, total_mentions = md.format_spans(sample)
+ mentions_dataset, total_mentions = md.format_spans(sample, process_sentences=True)
predictions, _ = model.predict(mentions_dataset)
results = process_results(
mentions_dataset, predictions, sample, include_offset=False
)
- gold_truth = {"test_doc": [(10, 3, "Fox", "fox", -1, "NULL", 0.0)]}
+ gold_truth = {"test_doc": [(10, 3, "fox", "Fox", 0.0, 0.0, "NULL")]}
- return results == gold_truth
+ assert results == gold_truth
diff --git a/tests/test_evaluate_predictions.py b/tests/test_evaluate_predictions.py
new file mode 100644
index 0000000..3e8475f
--- /dev/null
+++ b/tests/test_evaluate_predictions.py
@@ -0,0 +1,37 @@
+from scripts.evaluate_predictions import compare_and_count_entities, print_scores
+import pytest
+
+
+def test_perfect():
+ gold_entities = [ [ "1", "1" ] ]
+ predicted_entities = [ [ "1", "1" ] ]
+ counts = compare_and_count_entities(gold_entities, predicted_entities)
+ scores = print_scores(*counts)
+ assert list(scores) == [1, 1, 1, 1, 1, 1], "should be perfect MD and perfect EL"
+
+
+def test_el_wrong():
+ gold_entities = [ [ "1", "1" ] ]
+ predicted_entities = [ [ "1", "0" ] ]
+ counts = compare_and_count_entities(gold_entities, predicted_entities)
+ scores = print_scores(*counts)
+ assert list(scores) == [1, 1, 1, 0, 0, 0], "should be perfect MD and failed EL"
+
+
+def test_md_wrong():
+ gold_entities = [ [ "1", "1" ] ]
+ predicted_entities = [ [ "0", "1" ] ]
+ counts = compare_and_count_entities(gold_entities, predicted_entities)
+ scores = print_scores(*counts)
+ assert list(scores) == [0, 0, 0, 0, 0, 0], "should be failed MD and failed EL"
+
+
+def test_combined():
+ gold_entities = [ [ "1", "1" ], [ "1", "1" ], [ "2", "2" ] ]
+ predicted_entities = [ [ "0", "0" ], [ "0", "1" ], [ "1", "0" ], [ "1", "1" ] ]
+ counts = compare_and_count_entities(gold_entities, predicted_entities)
+ scores = print_scores(*counts)
+ target_scores = [1/2, 2/3, 4/7, 1/4, 1/3, 2/7]
+ for index in range(0, len(scores)):
+ assert pytest.approx(scores[index], 0.0001) == target_scores[index], "should be various scores"
+
diff --git a/tests/test_flair_md.py b/tests/test_flair_md.py
index ac7ab62..1311830 100644
--- a/tests/test_flair_md.py
+++ b/tests/test_flair_md.py
@@ -3,20 +3,30 @@
from pathlib import Path
+import os
+import pytest
+
from flair.models import SequenceTagger
from REL.mention_detection import MentionDetection
-def test_md():
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
+def test_md(base_url):
# return standard Flair tagger + mention detection object
- tagger = SequenceTagger.load("ner-fast")
- md = MentionDetection(Path(__file__).parent, "wiki_test")
+ tagger_ner = SequenceTagger.load("ner-fast")
+ process_sentences = True
+ split_docs_value = 0
+ wiki_version = "wiki_2019"
+ md = MentionDetection(base_url, wiki_version)
# first test case: repeating sentences
- sample1 = {"test_doc": ["Fox, Fox. Fox.", []]}
+ sample1 = {"test_doc": [ "Fox. Fox. Fox." , []] }
resulting_spans1 = {(0, 3), (5, 3), (10, 3)}
- predictions = md.find_mentions(sample1, tagger)
+ predictions = md.find_mentions(sample1, process_sentences, split_docs_value, tagger_ner)
+
predicted_spans = {
(m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
}
@@ -25,7 +35,7 @@ def test_md():
# second test case: excessive whitespace
sample2 = {"test_doc": ["Fox. Fox. Fox.", []]}
resulting_spans2 = {(0, 3), (20, 3), (43, 3)}
- predictions = md.find_mentions(sample2, tagger)
+ predictions = md.find_mentions(sample2, process_sentences, split_docs_value, tagger_ner)
predicted_spans = {
(m["pos"], m["end_pos"] - m["pos"]) for m in predictions[0]["test_doc"]
}
diff --git a/tests/test_instantiation.py b/tests/test_instantiation.py
index 52c3bf2..16a34a2 100644
--- a/tests/test_instantiation.py
+++ b/tests/test_instantiation.py
@@ -3,6 +3,8 @@
from pathlib import Path
+import os
+import pytest
import torch
from REL.entity_disambiguation import EntityDisambiguation
@@ -11,31 +13,46 @@
from REL.ner import Cmns
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
def test_entity_disambiguation_instantiation():
- return EntityDisambiguation(
+ assert True == bool(EntityDisambiguation(
Path(__file__).parent,
"wiki_test",
{
"mode": "eval",
"model_path": Path(__file__).parent / "wiki_test" / "generated" / "model",
},
- )
+ ))
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
def test_cmns_instantiation():
- return Cmns(Path(__file__).parent, "wiki_test")
+ assert True == bool(Cmns(Path(__file__).parent, "wiki_test"))
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
def test_mention_detection_instantiation():
- return MentionDetection(Path(__file__).parent, "wiki_test")
+ assert True == bool(MentionDetection(Path(__file__).parent, "wiki_test"))
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
def test_prerank_instantiation():
# NOTE: this is basically just a blank constructor; if this fails, something is
# seriously wrong
- return PreRank({})
+ assert True == bool(PreRank({}))
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
def test_mulrel_ranker_instantiation():
# minimal config to make the constructor run
config = {
@@ -46,4 +63,4 @@ def test_mulrel_ranker_instantiation():
"use_local": True,
"use_pad_ent": True,
}
- return MulRelRanker(config, torch.device("cpu"))
+ assert True == bool(MulRelRanker(config, torch.device("cpu")))
diff --git a/tests/test_ngram.py b/tests/test_ngram.py
index 664e757..f849944 100644
--- a/tests/test_ngram.py
+++ b/tests/test_ngram.py
@@ -3,17 +3,26 @@
from pathlib import Path
+import os
+import pytest
+
from REL.ner import Cmns, Span
def compare_spans(a: Span, b: Span, fields=(0, 1, 2)):
+ if len(a) != len(b):
+ return False
for f in fields:
- if a[f] != b[f]:
- return False
+ for index in range(0, len(a)):
+ if a[index][f] != b[index][f]:
+ return (False, a[index][f], b[index][f])
else:
return True
+@pytest.mark.skipif(
+ os.getenv("GITHUB_ACTIONS")=='true', reason="No way of testing this on Github actions."
+)
def test_cmns():
model = Cmns(Path(__file__).parent, "wiki_test", n=5)
predictions = model.predict("the brown fox jumped over the lazy dog", None)
@@ -28,4 +37,4 @@ def test_cmns():
Span("dog", 35, 38, None, None),
]
- return compare_spans(predictions, labels)
+ assert compare_spans(predictions, labels) == True