Initial version of script to convert OntoNotes from HF using Stanza's depparse
AngledLuffa committed Nov 19, 2023
1 parent 3e2f3a8 commit 1b7aaeb
Showing 3 changed files with 124 additions and 0 deletions.
Empty file.
122 changes: 122 additions & 0 deletions stanza/utils/datasets/coref/convert_ontonotes.py
@@ -0,0 +1,122 @@
from collections import defaultdict
import json
import os

import stanza

from stanza.models.constituency import tree_reader
from stanza.utils.default_paths import get_default_paths
from stanza.utils.get_tqdm import get_tqdm

tqdm = get_tqdm()

def read_paragraphs(section):
    """Yield (document_id, part_id, sentences) tuples, splitting each document on changes of part_id"""
    for doc in section:
        part_id = None
        paragraph = []
        for sentence in doc['sentences']:
            if part_id is None:
                part_id = sentence['part_id']
            elif part_id != sentence['part_id']:
                yield doc['document_id'], part_id, paragraph
                paragraph = []
                part_id = sentence['part_id']
            paragraph.append(sentence)
        if paragraph:
            yield doc['document_id'], part_id, paragraph

def convert_dataset_section(pipe, section):
    """Convert one split of the HF OntoNotes dataset into a list of coref documents"""
    processed_section = []
    section = list(read_paragraphs(section))
    for idx, (doc_id, part_id, paragraph) in enumerate(tqdm(section)):
        sentences = [x['words'] for x in paragraph]
        sentence_lens = [len(x) for x in sentences]
        cased_words = [y for x in sentences for y in x]
        sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]
        speaker = [y for x, sent_len in zip(paragraph, sentence_lens) for y in [x['speaker']] * sent_len]

        # use the trees to get the xpos tags
        # alternatively, could translate the pos_tags field,
        # but those have numbers, which is annoying
        #tree_text = "\n".join(x['parse_tree'] for x in paragraph)
        #trees = tree_reader.read_trees(tree_text)
        #pos = [x.label for tree in trees for x in tree.yield_preterminals()]
        # actually, the downstream code doesn't use pos at all. maybe we can skip?

        clusters = defaultdict(list)
        word_total = 0
        for sentence in paragraph:
            coref_spans = sentence['coref_spans']
            for span in coref_spans:
                # input is expected to be start word, end word + 1
                # counting from 0
                clusters[span[0]].append((span[1] + word_total, span[2] + word_total + 1))
            word_total += len(sentence['words'])
        clusters = sorted([sorted(values) for _, values in clusters.items()])

        doc = pipe(sentences)
        word_total = 0
        heads = []
        # TODO: does SD vs UD matter?
        deprel = []
        for sentence in doc.sentences:
            for word in sentence.words:
                deprel.append(word.deprel)
                if word.head == 0:
                    heads.append("null")
                else:
                    # convert the 1-based within-sentence head to a 0-based document-wide index
                    heads.append(word.head - 1 + word_total)
            word_total += len(sentence.words)

        processed = {
            "document_id": doc_id,
            "cased_words": cased_words,
            "sent_id": sent_id,
            "part_id": part_id,
            "speaker": speaker,
            #"pos": pos,
            "deprel": deprel,
            "head": heads,
            "clusters": clusters
        }
        processed_section.append(processed)
    return processed_section

SECTION_NAMES = {"train": "train",
                 "dev": "validation",
                 "test": "test"}

def process_dataset(short_name, ontonotes_path, coref_output_path):
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError("Please install the datasets package to process OntoNotes coref with Stanza")

    if short_name == 'en_ontonotes':
        config_name = 'english_v4'
    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):
        config_name = 'chinese_v4'
    elif short_name == 'ar_ontonotes':
        config_name = 'arabic_v4'
    else:
        raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name)

    # note: the depparse pipeline is hardcoded to English for now
    pipe = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True)
    dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=ontonotes_path)
    for section, hf_name in SECTION_NAMES.items():
        print("Processing %s" % section)
        converted_section = convert_dataset_section(pipe, dataset[hf_name])
        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section))
        with open(output_filename, "w", encoding="utf-8") as fout:
            json.dump(converted_section, fout, indent=2)


def main():
    paths = get_default_paths()
    coref_input_path = paths['COREF_BASE']
    ontonotes_path = os.path.join(coref_input_path, "english", "en_ontonotes")
    coref_output_path = paths['COREF_DATA_DIR']
    process_dataset("en_ontonotes", ontonotes_path, coref_output_path)

if __name__ == '__main__':
    main()
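
For reference, a hedged sketch of running the new converter and inspecting one converted split. The output location assumes stanza's default DATA_ROOT of "data" (so COREF_DATA_DIR resolves to data/coref); the record keys are the ones built by convert_dataset_section above.

# python -m stanza.utils.datasets.coref.convert_ontonotes

import json

# assumes the default COREF_DATA_DIR of data/coref; adjust for a custom configuration
with open("data/coref/en_ontonotes.train.json", encoding="utf-8") as fin:
    docs = json.load(fin)

# each entry is one (document_id, part_id) paragraph with flat, word-level fields
doc = docs[0]
print(doc["document_id"], doc["part_id"])
print(len(doc["cased_words"]), len(doc["sent_id"]), len(doc["head"]), len(doc["deprel"]))
print(doc["clusters"][:2])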

2 changes: 2 additions & 0 deletions stanza/utils/default_paths.py
@@ -20,6 +20,7 @@ def get_default_paths():
        "CHARLM_DATA_DIR": DATA_ROOT + "/charlm",
        "SENTIMENT_DATA_DIR": DATA_ROOT + "/sentiment",
        "CONSTITUENCY_DATA_DIR": DATA_ROOT + "/constituency",
        "COREF_DATA_DIR": DATA_ROOT + "/coref",

        # Set directories to store external word vector data
        "WORDVEC_DIR": "extern_data/wordvec",
@@ -32,6 +33,7 @@
        "NERBASE": "extern_data/ner",
        "CONSTITUENCY_BASE": "extern_data/constituency",
        "SENTIMENT_BASE": "extern_data/sentiment",
        "COREF_BASE": "extern_data/coref",

# there's a stanford github, stanfordnlp/handparsed-treebank,
# with some data for different languages
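
The two new entries follow the existing *_BASE / *_DATA_DIR split: main() in the converter above uses COREF_BASE as the location of the raw OntoNotes download/cache and COREF_DATA_DIR as the destination for the converted json. A minimal sketch of reading them:

from stanza.utils.default_paths import get_default_paths

paths = get_default_paths()
print(paths["COREF_BASE"])      # extern_data/coref
print(paths["COREF_DATA_DIR"])  # DATA_ROOT + "/coref"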
