Skip to content

Commit

Permalink
Add a script to combine OntoNotes and WW into a single NER model with 18 classes
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Nov 11, 2023
1 parent 16482d9 commit 4fa53a2
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 0 deletions.
85 changes: 85 additions & 0 deletions stanza/utils/datasets/ner/ontonotes_multitag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import argparse
import json
import os
import shutil

from stanza.utils import default_paths
from stanza.utils.datasets.ner.simplify_ontonotes_to_worldwide import simplify_ontonotes_to_worldwide
from stanza.utils.datasets.ner.utils import combine_files

def convert_ontonotes_file(filename, simplify, bigger_first):
    """Rewrite an OntoNotes NER json file so each word carries a two-column 'multi_ner' tag.

    One column is the original 18-class OntoNotes tag; the other is either
    the tag simplified to the WorldWide tag set (when simplify is True) or a
    blank "-" placeholder.  bigger_first puts the 18-class tag in the first
    column.  The converted document is written next to the input, with
    "en_ontonotes" in the filename replaced by "en_ontonotes-multi".

    Raises FileNotFoundError if the input file does not exist.
    """
    assert "en_ontonotes" in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)
    new_filename = filename.replace("en_ontonotes", "en_ontonotes-multi")

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        for word in sentence:
            tag = word['ner']
            # second channel: simplified tag, or blank when simplification is off
            other = simplify_ontonotes_to_worldwide(tag) if simplify else "-"
            word['multi_ner'] = (tag, other) if bigger_first else (other, tag)

    with open(new_filename, "w") as fout:
        json.dump(doc, fout, indent=2)

def convert_worldwide_file(filename, bigger_first):
    """Rewrite a WorldWide 9-class NER json file with a two-column 'multi_ner' tag.

    The WorldWide tag fills one column and "-" fills the other; bigger_first
    puts the blank (18-class) column first so the layout lines up with the
    converted OntoNotes files.  Output goes next to the input, with
    "en_worldwide-9class" in the filename replaced by
    "en_worldwide-9class-multi".

    Raises FileNotFoundError if the input file does not exist.
    """
    assert "en_worldwide-9class" in filename
    if not os.path.exists(filename):
        raise FileNotFoundError("Cannot convert missing file %s" % filename)

    new_filename = filename.replace("en_worldwide-9class", "en_worldwide-9class-multi")

    with open(filename) as fin:
        doc = json.load(fin)

    for sentence in doc:
        for word in sentence:
            tag = word['ner']
            word['multi_ner'] = ("-", tag) if bigger_first else (tag, "-")

    with open(new_filename, "w") as fout:
        json.dump(doc, fout, indent=2)

def build_multitag_dataset(base_output_path, short_name, simplify, bigger_first):
    """Build a combined OntoNotes + WorldWide multi-channel NER dataset.

    Converts the train/dev/test shards of both source datasets to the
    two-column 'multi_ner' format, then:
      - train: concatenation of the two converted train files
      - dev/test: the converted OntoNotes files only
    All files are read from and written to base_output_path; the combined
    files are named "<short_name>.<shard>.json".

    simplify / bigger_first are passed through to the converters (see
    convert_ontonotes_file).
    """
    # every shard of both datasets gets the same conversion
    for shard in ("train", "dev", "test"):
        convert_ontonotes_file(os.path.join(base_output_path, "en_ontonotes.%s.json" % shard), simplify, bigger_first)
        convert_worldwide_file(os.path.join(base_output_path, "en_worldwide-9class.%s.json" % shard), bigger_first)

    combine_files(os.path.join(base_output_path, "%s.train.json" % short_name),
                  os.path.join(base_output_path, "en_ontonotes-multi.train.json"),
                  os.path.join(base_output_path, "en_worldwide-9class-multi.train.json"))
    # dev and test scores are reported on OntoNotes only
    for shard in ("dev", "test"):
        shutil.copyfile(os.path.join(base_output_path, "en_ontonotes-multi.%s.json" % shard),
                        os.path.join(base_output_path, "%s.%s.json" % (short_name, shard)))


def main():
    """Command line entry point: build the combined multitag dataset in NER_DATA_DIR.

    Fix: the original referenced default_paths without importing it, which
    raised NameError at runtime; the module now imports
    stanza.utils.default_paths at the top of the file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--no_simplify', dest='simplify', action='store_false', help='By default, this script will simplify the OntoNotes 18 classes to the 8 WorldWide classes in a second column. Turning that off will leave that column blank. Initial experiments with that setting were very bad, though')
    parser.add_argument('--no_bigger_first', dest='bigger_first', action='store_false', help='By default, this script will put the 18 class tags in the first column and the 8 in the second. This flips the order')
    args = parser.parse_args()

    paths = default_paths.get_default_paths()
    base_output_path = paths["NER_DATA_DIR"]

    build_multitag_dataset(base_output_path, "en_ontonotes-ww-multi", args.simplify, args.bigger_first)

if __name__ == '__main__':
    main()

13 changes: 13 additions & 0 deletions stanza/utils/datasets/ner/prepare_ner_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@
import stanza.utils.datasets.ner.convert_nkjp as convert_nkjp
import stanza.utils.datasets.ner.prepare_ner_file as prepare_ner_file
import stanza.utils.datasets.ner.convert_sindhi_siner as convert_sindhi_siner
import stanza.utils.datasets.ner.ontonotes_multitag as ontonotes_multitag
import stanza.utils.datasets.ner.simplify_en_worldwide as simplify_en_worldwide
import stanza.utils.datasets.ner.suc_to_iob as suc_to_iob
import stanza.utils.datasets.ner.suc_conll_to_iob as suc_conll_to_iob
Expand Down Expand Up @@ -1125,6 +1126,17 @@ def process_en_conll03_worldwide(paths, short_name):
shutil.copyfile(os.path.join(paths['NER_DATA_DIR'], "en_conll03.test.json"),
os.path.join(paths['NER_DATA_DIR'], "%s.test.json" % short_name))

def process_en_ontonotes_ww_multi(paths, short_name, simplify=True, bigger_first=True):
    """
    Combine the worldwide data with the OntoNotes data in a multi channel format

    Prepares both source datasets, then delegates to
    ontonotes_multitag.build_multitag_dataset to produce the combined files.

    simplify: fill the second column with OntoNotes tags simplified to the
      WorldWide tag set (blank "-" column when False)
    bigger_first: put the 18 class tags in the first column

    The defaults match the previous hard-coded behavior, so existing
    callers (the dataset dispatch table) are unaffected.
    """
    print("=============== Preparing OntoNotes ===============")
    process_en_ontonotes(paths, "en_ontonotes")
    print("========== Preparing 9 Class Worldwide ================")
    process_en_worldwide_9class(paths, "en_worldwide-9class")
    ontonotes_multitag.build_multitag_dataset(paths['NER_DATA_DIR'], short_name, simplify, bigger_first)


def process_en_conllpp(paths, short_name):
"""
Expand Down Expand Up @@ -1170,6 +1182,7 @@ def process_ar_aqmar(paths, short_name):
"en_conll03ww": process_en_conll03_worldwide,
"en_conllpp": process_en_conllpp,
"en_ontonotes": process_en_ontonotes,
"en_ontonotes-ww-multi": process_en_ontonotes_ww_multi,
"en_worldwide-4class": process_en_worldwide_4class,
"en_worldwide-9class": process_en_worldwide_9class,
"fa_arman": process_fa_arman,
Expand Down

0 comments on commit 4fa53a2

Please sign in to comment.