diff --git a/stanza/utils/datasets/coref/__init__.py b/stanza/utils/datasets/coref/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/stanza/utils/datasets/coref/convert_ontonotes.py b/stanza/utils/datasets/coref/convert_ontonotes.py
new file mode 100644
index 0000000000..aedf3eb55b
--- /dev/null
+++ b/stanza/utils/datasets/coref/convert_ontonotes.py
@@ -0,0 +1,123 @@
+from collections import defaultdict
+import json
+import os
+
+import stanza
+
+from stanza.models.constituency import tree_reader
+from stanza.utils.default_paths import get_default_paths
+from stanza.utils.get_tqdm import get_tqdm
+
+tqdm = get_tqdm()
+
+def read_paragraphs(section):
+    for doc in section:
+        part_id = None
+        paragraph = []
+        for sentence in doc['sentences']:
+            if part_id is None:
+                part_id = sentence['part_id']
+            elif part_id != sentence['part_id']:
+                yield doc['document_id'], part_id, paragraph
+                paragraph = []
+                part_id = sentence['part_id']
+            paragraph.append(sentence)
+        if paragraph != []:
+            yield doc['document_id'], part_id, paragraph
+
+def convert_dataset_section(pipe, section):
+    processed_section = []
+    section = list(read_paragraphs(section))
+    for idx, (doc_id, part_id, paragraph) in enumerate(tqdm(section)):
+        sentences = [x['words'] for x in paragraph]
+        sentence_lens = [len(x) for x in sentences]
+        cased_words = [y for x in sentences for y in x]
+        sent_id = [y for idx, sent_len in enumerate(sentence_lens) for y in [idx] * sent_len]
+        speaker = [y for x, sent_len in zip(paragraph, sentence_lens) for y in [x['speaker']] * sent_len]
+
+        # use the trees to get the xpos tags
+        # alternatively, could translate the pos_tags field,
+        # but those have numbers, which is annoying
+        #tree_text = "\n".join(x['parse_tree'] for x in paragraph)
+        #trees = tree_reader.read_trees(tree_text)
+        #pos = [x.label for tree in trees for x in tree.yield_preterminals()]
+        # actually, the downstream code doesn't use pos at all.  maybe we can skip?
+
+        clusters = defaultdict(list)
+        word_total = 0
+        for sentence in paragraph:
+            coref_spans = sentence['coref_spans']
+            for span in coref_spans:
+                # input is expected to be start word, end word + 1
+                # counting from 0
+                clusters[span[0]].append((span[1] + word_total, span[2] + word_total + 1))
+            word_total += len(sentence['words'])
+        clusters = sorted([sorted(values) for _, values in clusters.items()])
+
+        doc = pipe(sentences)
+        word_total = 0
+        heads = []
+        # TODO: does SD vs UD matter?
+        deprel = []
+        for sentence in doc.sentences:
+            for word in sentence.words:
+                deprel.append(word.deprel)
+                if word.head == 0:
+                    heads.append(None)   # root word: serialized as null in the json
+                else:
+                    heads.append(word.head - 1 + word_total)
+            word_total += len(sentence.words)
+
+        processed = {
+            "document_id": doc_id,
+            "cased_words": cased_words,
+            "sent_id": sent_id,
+            "part_id": part_id,
+            "speaker": speaker,
+            #"pos": pos,
+            "deprel": deprel,
+            "head": heads,
+            "clusters": clusters
+        }
+        processed_section.append(processed)
+    return processed_section
+
+SECTION_NAMES = {"train": "train",
+                 "dev": "validation",
+                 "test": "test"}
+
+def process_dataset(short_name, ontonotes_path, coref_output_path):
+    try:
+        from datasets import load_dataset
+    except ImportError as e:
+        raise ImportError("Please install the datasets package to process OntoNotes coref with Stanza") from e
+
+    if short_name == 'en_ontonotes':
+        config_name = 'english_v4'
+    elif short_name in ('zh_ontonotes', 'zh-hans_ontonotes'):
+        config_name = 'chinese_v4'
+    elif short_name == 'ar_ontonotes':
+        config_name = 'arabic_v4'
+    else:
+        raise ValueError("Unknown short name for downloading ontonotes: %s" % short_name)
+
+    pipe = stanza.Pipeline("en", processors="tokenize,pos,lemma,depparse", package="default_accurate", tokenize_pretokenized=True)
+    dataset = load_dataset("conll2012_ontonotesv5", config_name, cache_dir=ontonotes_path)
+    for section, hf_name in SECTION_NAMES.items():
+        print("Processing %s" % section)
+        converted_section = convert_dataset_section(pipe, dataset[hf_name])
+        output_filename = os.path.join(coref_output_path, "%s.%s.json" % (short_name, section))
+        with open(output_filename, "w", encoding="utf-8") as fout:
+            json.dump(converted_section, fout, indent=2)
+
+
+def main():
+    paths = get_default_paths()
+    coref_input_path = paths['COREF_BASE']
+    ontonotes_path = os.path.join(coref_input_path, "english", "en_ontonotes")
+    coref_output_path = paths['COREF_DATA_DIR']
+    process_dataset("en_ontonotes", ontonotes_path, coref_output_path)
+
+if __name__ == '__main__':
+    main()
+
diff --git a/stanza/utils/default_paths.py b/stanza/utils/default_paths.py
index 551cab53d9..0618d8b634 100644
--- a/stanza/utils/default_paths.py
+++ b/stanza/utils/default_paths.py
@@ -20,6 +20,7 @@ def get_default_paths():
         "CHARLM_DATA_DIR": DATA_ROOT + "/charlm",
         "SENTIMENT_DATA_DIR": DATA_ROOT + "/sentiment",
         "CONSTITUENCY_DATA_DIR": DATA_ROOT + "/constituency",
+        "COREF_DATA_DIR": DATA_ROOT + "/coref",
 
         # Set directories to store external word vector data
         "WORDVEC_DIR": "extern_data/wordvec",
@@ -32,6 +33,7 @@ def get_default_paths():
         "NERBASE": "extern_data/ner",
         "CONSTITUENCY_BASE": "extern_data/constituency",
         "SENTIMENT_BASE": "extern_data/sentiment",
+        "COREF_BASE": "extern_data/coref",
 
         # there's a stanford github, stanfordnlp/handparsed-treebank,
         # with some data for different languages
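
For reference, the converter writes one json file per split, containing the per-paragraph dictionaries built in convert_dataset_section above. A minimal sketch of a sanity check over the generated dev split could look like the following; the en_ontonotes.dev.json filename follows the "%s.%s.json" pattern used in process_dataset, and the check itself is only illustrative, not part of this change.

import json
import os

from stanza.utils.default_paths import get_default_paths

# Illustrative check of the converted coref json; assumes convert_ontonotes.py
# has already been run so that the dev split exists under COREF_DATA_DIR.
paths = get_default_paths()
dev_file = os.path.join(paths['COREF_DATA_DIR'], "en_ontonotes.dev.json")

with open(dev_file, encoding="utf-8") as fin:
    docs = json.load(fin)

print("Converted %d documents" % len(docs))
for doc in docs[:3]:
    # clusters are lists of [start, end) word offsets into cased_words
    num_mentions = sum(len(cluster) for cluster in doc['clusters'])
    print("%s part %s: %d words, %d clusters, %d mentions"
          % (doc['document_id'], doc['part_id'], len(doc['cased_words']),
             len(doc['clusters']), num_mentions))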