diff --git a/stanza/models/common/constant.py b/stanza/models/common/constant.py
index 7e24394019..d5daf60646 100644
--- a/stanza/models/common/constant.py
+++ b/stanza/models/common/constant.py
@@ -529,3 +529,6 @@ def treebank_to_langid(treebank):

     short_name = treebank_to_short_name(treebank)
     return short_name.split("_")[0]
+
+# special string to mark a MWT
+MANUAL_MWT_MAGIC = "$TOK_MWT"
diff --git a/stanza/models/common/doc.py b/stanza/models/common/doc.py
index 3e33567b8c..94db3ffec4 100644
--- a/stanza/models/common/doc.py
+++ b/stanza/models/common/doc.py
@@ -10,6 +10,7 @@ import warnings

 import networkx as nx

+from stanza.models.common.constant import MANUAL_MWT_MAGIC
 from stanza.models.common.stanza_object import StanzaObject
 from stanza.models.ner.utils import decode_from_bioes

@@ -335,6 +336,9 @@ def get_mwt_expansions(self, evaluation=False):
                 if m or n:
                     src = token.text
                     dst = ' '.join([word.text for word in token.words])
+                    if dst[:len(MANUAL_MWT_MAGIC)] == MANUAL_MWT_MAGIC:
+                        dst = dst[len(MANUAL_MWT_MAGIC):]
+                        src = MANUAL_MWT_MAGIC+dst
                     expansions.append([src, dst])
         if evaluation: expansions = [e[0] for e in expansions]
         return expansions
diff --git a/stanza/models/tokenization/utils.py b/stanza/models/tokenization/utils.py
index be809761b9..b2481e8099 100644
--- a/stanza/models/tokenization/utils.py
+++ b/stanza/models/tokenization/utils.py
@@ -9,6 +9,7 @@ from torch.utils.data import DataLoader as TorchDataLoader

 import stanza.utils.default_paths as default_paths
+from stanza.models.common.constant import *
 from stanza.models.common.utils import ud_scores, harmonic_mean
 from stanza.models.common.doc import Document
 from stanza.utils.conll import CoNLL
@@ -356,6 +357,7 @@ def postprocess_doc(doc, postprocessor, orig_text=None):
     # collect the words and MWTs seperately
     corrected_words = []
     corrected_mwts = []
+    corrected_expansions = []

     # for each word, if its just a string (without the ("word", mwt_bool) format)
     # we default that the word is not a MWT.
@@ -363,26 +365,29 @@ def postprocess_doc(doc, postprocessor, orig_text=None):
         sent_words = []
         sent_mwts = []
         for word in sent:
-            if type(word) == str:
+            if isinstance(word, str):
                 sent_words.append(word)
                 sent_mwts.append(False)
             else:
-                sent_words.append(word[0])
-                sent_mwts.append(word[1])
+                if isinstance(word[1], bool):
+                    sent_words.append(word[0])
+                    sent_mwts.append(word[1])
+                    corrected_expansions.append(word[0])
+                else:
+                    sent_words.append(word[0])
+                    sent_mwts.append(True)
+                    corrected_expansions.append(MANUAL_MWT_MAGIC+" ".join(word[1]))
         corrected_words.append(sent_words)
         corrected_mwts.append(sent_mwts)
-
-    # check postprocessor output
-    token_lens = [len(i) for i in corrected_words]
-    mwt_lens = [len(i) for i in corrected_mwts]
-    assert token_lens == mwt_lens, "Postprocessor returned token and MWT lists of different length! Token list lengths %s, MWT list lengths %s" % (token_lens, mwt_lens)
+
     # recassemble document. offsets and oov shouldn't change
-    doc = reassemble_doc_from_tokens(corrected_words, corrected_mwts, raw_text)
+    doc = reassemble_doc_from_tokens(corrected_words, corrected_mwts,
+                                     corrected_expansions, raw_text)
     return doc


-def reassemble_doc_from_tokens(tokens, mwts, raw_text):
+def reassemble_doc_from_tokens(tokens, mwts, expansions, raw_text):
     """Assemble a Stanza document list format from a list of string tokens, calculating offsets as needed.

     Parameters
@@ -392,6 +397,8 @@ def reassemble_doc_from_tokens(tokens, mwts, raw_text):
     mwts : List[List[bool]]
         Whether or not each of the tokens are MWTs to be analyzed by the MWT raw.
+    expansions : List[str]
+        A list of possible expansions for MWTs
     parser_text : str
         The raw text off of which we can compare offsets.

@@ -435,7 +442,11 @@
         corrected_doc.append(sentence_doc)

-    return corrected_doc
+    # use the built in MWT system to expand MWTs
+    doc = Document(corrected_doc, raw_text)
+    doc.set_mwt_expansions(expansions)
+
+    return doc.to_dict()


 def decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand):
     """
diff --git a/stanza/pipeline/mwt_processor.py b/stanza/pipeline/mwt_processor.py
index e424496d77..392e995d85 100644
--- a/stanza/pipeline/mwt_processor.py
+++ b/stanza/pipeline/mwt_processor.py
@@ -3,6 +3,7 @@
 """

 import io
+from stanza.models.common.constant import MANUAL_MWT_MAGIC

 import torch
@@ -24,8 +25,9 @@ def _set_up_model(self, config, pipeline, device):

     def process(self, document):
         batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True)
+        expansions = batch.doc.get_mwt_expansions(evaluation=True)
         if len(batch) > 0:
-            dict_preds = self.trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True))
+            dict_preds = self.trainer.predict_dict(expansions)
             # decide trainer type and run eval
             if self.config['dict_only']:
                 preds = dict_preds
@@ -41,7 +43,16 @@ def process(self, document):
                 # skip eval if dev data does not exist
                 preds = []

-        batch.doc.set_mwt_expansions(preds)
+        # for expansions marked with the special manual MWT string,
+        # force the tokenization supplied by the postprocessor
+        final_preds = []
+        for inp, out in zip(expansions, preds):
+            if MANUAL_MWT_MAGIC in inp:
+                final_preds.append(inp[len(MANUAL_MWT_MAGIC):])
+            else:
+                final_preds.append(out)
+
+        batch.doc.set_mwt_expansions(final_preds)
         return batch.doc

     def bulk_process(self, docs):
diff --git a/stanza/tests/pipeline/test_tokenizer.py b/stanza/tests/pipeline/test_tokenizer.py
index 99cb9e1062..81a3b57fb2 100644
--- a/stanza/tests/pipeline/test_tokenizer.py
+++ b/stanza/tests/pipeline/test_tokenizer.py
@@ -128,6 +128,23 @@ EN_DOC_NO_SSPLIT = ["This is a sentence. This is another.", "This is a third."]
 EN_DOC_NO_SSPLIT_SENTENCES = [['This', 'is', 'a', 'sentence', '.', 'This', 'is', 'another', '.'], ['This', 'is', 'a', 'third', '.']]

+FR_DOC = "Le prince va manger du poulet aux les magasins aujourd'hui."
+FR_DOC_POSTPROCESSOR_TOKENS_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', "aujourd'hui", '.']]
+FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', ("aujourd'hui", ["aujourd'", "hui"]), '.']]
+FR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS = """
+<Token id=1;words=[<Word id=1;text=Le>]>
+<Token id=2;words=[<Word id=2;text=prince>]>
+<Token id=3;words=[<Word id=3;text=va>]>
+<Token id=4;words=[<Word id=4;text=manger>]>
+<Token id=(5, 6);words=[<Word id=5;text=de>, <Word id=6;text=le>]>
+<Token id=7;words=[<Word id=7;text=poulet>]>
+<Token id=(8, 9);words=[<Word id=8;text=à>, <Word id=9;text=les>]>
+<Token id=10;words=[<Word id=10;text=les>]>
+<Token id=11;words=[<Word id=11;text=magasins>]>
+<Token id=(12, 13);words=[<Word id=12;text=aujourd'>, <Word id=13;text=hui>]>
+<Token id=14;words=[<Word id=14;text=.>]>
+"""
+
 JA_DOC = "北京は中国の首都です。 北京の人口は2152万人です。\n" # add some random whitespaces that need to be skipped
 JA_DOC_GOLD_TOKENS = """
 <Token id=1;words=[<Word id=1;text=北京>]>
@@ -314,6 +331,22 @@ def dummy_postprocessor(input):
     doc = nlp(EN_DOC)
     assert EN_DOC_POSTPROCESSOR_COMBINED_TOKENS.strip() == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]).strip()

+def test_postprocessor_mwt():
+
+    def dummy_postprocessor(input):
+        # Importantly, EN_DOC_POSTPROCESSOR_COMBINED_LIST returns a few tokens joined
+        # with space. As some languages (such as VN) contain tokens with space in between,
+        # it's important to have a space-joined token tested as one of the tokens
+        assert input == FR_DOC_POSTPROCESSOR_TOKENS_LIST
+        return FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST
+
+    nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR,
+                             'lang': 'fr',
+                             'tokenize_postprocessor': dummy_postprocessor})
+    doc = nlp(FR_DOC)
+    assert FR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS.strip() == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]).strip()
+
+
 def test_postprocessor_typeerror():
     with pytest.raises(ValueError):
         nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en',
diff --git a/stanza/tests/tokenization/test_tokenize_utils.py b/stanza/tests/tokenization/test_tokenize_utils.py
index 9970d40c80..3787859045 100644
--- a/stanza/tests/tokenization/test_tokenize_utils.py
+++ b/stanza/tests/tokenization/test_tokenize_utils.py
@@ -102,7 +102,7 @@ def test_postprocessor_application():
     good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
     text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

-    target_doc = [[{'id': (1,), 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': (2,), 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': (3,), 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': (4,), 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': (5,), 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': (6,), 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': (1,), 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': (2,), 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': (3,), 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': (4,), 'text': '.', 'start_char': 31, 'end_char': 32}]]
+    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32}]]

     def postprocesor(_):
         return good_tokenization
@@ -121,9 +121,9 @@ def test_reassembly_indexing():

     text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

-    target_doc = [[{'id': (1,), 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': (2,), 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': (3,), 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': (4,), 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': (5,), 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': (6,), 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': (1,), 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': (2,), 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': (3,), 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': (4,), 'text': '.', 'start_char': 31, 'end_char': 32}]]
+    target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32}]]

-    res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, text)
+    res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, [], text)

     assert res == target_doc

@@ -144,12 +144,12 @@ def test_reassembly_reference_failures():

     text = "Joe Smith lives in California."

     with pytest.raises(ValueError):
-        utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, text)
+        utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, [], text)

     with pytest.raises(ValueError):
-        utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, text)
+        utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, [], text)

-    utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, text)
+    utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, [], text)
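
For reviewers, a minimal usage sketch of the tokenize_postprocessor contract exercised by test_postprocessor_mwt above (not part of the patch; the helper name force_aujourdhui is illustrative, and it assumes the French tokenize/mwt models are downloaded):

# Not part of the patch: a sketch of the postprocessor contract, mirroring
# test_postprocessor_mwt. Assumes the fr tokenize/mwt models are available.
import stanza

def force_aujourdhui(sentences):
    # Each sentence is a list of tokens. A token is either a plain string (not an MWT),
    # a (text, bool) tuple (bool marks it as an MWT for the MWT model to expand), or a
    # (text, [parts]) tuple (an MWT whose expansion is forced via MANUAL_MWT_MAGIC).
    fixed_sentences = []
    for sentence in sentences:
        fixed = []
        for token in sentence:
            text = token if isinstance(token, str) else token[0]
            if text == "aujourd'hui":
                fixed.append((text, ["aujourd'", "hui"]))  # force this expansion
            else:
                fixed.append(token)
        fixed_sentences.append(fixed)
    return fixed_sentences

nlp = stanza.Pipeline(lang='fr', processors='tokenize,mwt',
                      tokenize_postprocessor=force_aujourdhui)
doc = nlp("Le prince va manger du poulet aux les magasins aujourd'hui.")
for token in doc.sentences[0].tokens:
    print(token.text, [word.text for word in token.words])

Per the gold tokens in the test, "du" and "aux" should still be expanded by the MWT model, while "aujourd'hui" should come back as the two manually forced words.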