superbly jank implementation of manual MWT control
Jemoka committed Oct 24, 2023
1 parent 71253bc commit 27471c0
Showing 6 changed files with 81 additions and 19 deletions.
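
In short, a tokenize postprocessor can now return, for any token, either a plain string, a ("token", is_mwt) pair, or a ("token", ["part", ...]) pair whose word list is applied verbatim as that token's MWT expansion, bypassing the MWT model. A minimal usage sketch (not part of this commit), modeled on the new French test added below and assuming the French models are available locally:

import stanza

# force "aujourd'hui" to split into ["aujourd'", "hui"], while letting the
# tokenizer and MWT model handle everything else (e.g. "du" -> "de le") as usual
def force_aujourdhui(sentences):
    fixed = []
    for sentence in sentences:
        out = []
        for token in sentence:
            # tokens arrive either as plain strings or as ("text", is_mwt) pairs
            text = token if isinstance(token, str) else token[0]
            if text == "aujourd'hui":
                out.append((text, ["aujourd'", "hui"]))  # manual expansion
            else:
                out.append(token)                        # pass through unchanged
        fixed.append(out)
    return fixed

nlp = stanza.Pipeline(lang="fr", processors="tokenize,mwt",
                      tokenize_postprocessor=force_aujourdhui)
doc = nlp("Le prince va manger du poulet aux les magasins aujourd'hui.")
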
3 changes: 3 additions & 0 deletions stanza/models/common/constant.py
@@ -529,3 +529,6 @@ def treebank_to_langid(treebank):
short_name = treebank_to_short_name(treebank)
return short_name.split("_")[0]


# special string to mark a MWT
MANUAL_MWT_MAGIC = "$TOK_MWT"
4 changes: 4 additions & 0 deletions stanza/models/common/doc.py
@@ -10,6 +10,7 @@
import warnings

import networkx as nx
from stanza.models.common.constant import MANUAL_MWT_MAGIC

from stanza.models.common.stanza_object import StanzaObject
from stanza.models.ner.utils import decode_from_bioes
@@ -335,6 +336,9 @@ def get_mwt_expansions(self, evaluation=False):
if m or n:
src = token.text
dst = ' '.join([word.text for word in token.words])
if dst[:len(MANUAL_MWT_MAGIC)] == MANUAL_MWT_MAGIC:
dst = dst[len(MANUAL_MWT_MAGIC):]
src = MANUAL_MWT_MAGIC+dst
expansions.append([src, dst])
if evaluation: expansions = [e[0] for e in expansions]
return expansions
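
Taken together with the tokenizer and MWT changes below, MANUAL_MWT_MAGIC acts as an in-band flag: a manual expansion is stored with the prefix, get_mwt_expansions moves the prefix onto the source side, and the MWT processor later recognizes it and emits the stored expansion untouched. A rough illustration of the string handling only, with hypothetical values:

MANUAL_MWT_MAGIC = "$TOK_MWT"

# expansion text as stored during tokenizer reassembly for a manually expanded token
dst = MANUAL_MWT_MAGIC + "aujourd' hui"

# get_mwt_expansions: strip the marker from the expansion and move it to the source
if dst[:len(MANUAL_MWT_MAGIC)] == MANUAL_MWT_MAGIC:
    dst = dst[len(MANUAL_MWT_MAGIC):]   # "aujourd' hui"
    src = MANUAL_MWT_MAGIC + dst        # "$TOK_MWTaujourd' hui"

# mwt_processor: a marked source means "keep the manual expansion as-is"
pred = src[len(MANUAL_MWT_MAGIC):] if MANUAL_MWT_MAGIC in src else "<model output>"
assert pred == "aujourd' hui"
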
33 changes: 22 additions & 11 deletions stanza/models/tokenization/utils.py
@@ -9,6 +9,7 @@
from torch.utils.data import DataLoader as TorchDataLoader

import stanza.utils.default_paths as default_paths
from stanza.models.common.constant import *
from stanza.models.common.utils import ud_scores, harmonic_mean
from stanza.models.common.doc import Document
from stanza.utils.conll import CoNLL
@@ -356,33 +357,37 @@ def postprocess_doc(doc, postprocessor, orig_text=None):
# collect the words and MWTs separately
corrected_words = []
corrected_mwts = []
corrected_expansions = []

# for each word, if it's just a string (without the ("word", mwt_bool) format),
# we default to the word not being an MWT; if the second element is instead a
# list of strings, those strings are used verbatim as the MWT expansion
for sent in postprocessor_return:
sent_words = []
sent_mwts = []
for word in sent:
if type(word) == str:
if isinstance(word, str):
sent_words.append(word)
sent_mwts.append(False)
else:
sent_words.append(word[0])
sent_mwts.append(word[1])
if isinstance(word[1], bool):
sent_words.append(word[0])
sent_mwts.append(word[1])
corrected_expansions.append(word[0])
else:
sent_words.append(word[0])
sent_mwts.append(True)
corrected_expansions.append(MANUAL_MWT_MAGIC+" ".join(word[1]))
corrected_words.append(sent_words)
corrected_mwts.append(sent_mwts)

# check postprocessor output
token_lens = [len(i) for i in corrected_words]
mwt_lens = [len(i) for i in corrected_mwts]
assert token_lens == mwt_lens, "Postprocessor returned token and MWT lists of different length! Token list lengths %s, MWT list lengths %s" % (token_lens, mwt_lens)


# reassemble document. offsets and oov shouldn't change
doc = reassemble_doc_from_tokens(corrected_words, corrected_mwts, raw_text)
doc = reassemble_doc_from_tokens(corrected_words, corrected_mwts,
corrected_expansions, raw_text)

return doc

def reassemble_doc_from_tokens(tokens, mwts, raw_text):
def reassemble_doc_from_tokens(tokens, mwts, expansions, raw_text):
"""Assemble a Stanza document list format from a list of string tokens, calculating offsets as needed.
Parameters
@@ -392,6 +397,8 @@ def reassemble_doc_from_tokens(tokens, mwts, raw_text):
mwts : List[List[bool]]
Whether or not each of the tokens is an MWT to be analyzed by
the MWT processor.
expansions : List[str]
A list of possible expansions for MWTs
raw_text : str
The raw text against which token offsets are computed.
@@ -435,7 +442,11 @@ def reassemble_doc_from_tokens(tokens, mwts, raw_text):

corrected_doc.append(sentence_doc)

return corrected_doc
# use the built-in MWT system to expand MWTs
doc = Document(corrected_doc, raw_text)
doc.set_mwt_expansions(expansions)

return doc.to_dict()

def decode_predictions(vocab, mwt_dict, orig_text, all_raw, all_preds, no_ssplit, skip_newline, use_la_ittb_shorthand):
"""
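
Note that reassemble_doc_from_tokens now takes the expansions list as a third argument; callers with no manual expansions pass an empty list, as the updated tests below do. A small sketch reusing the test data (the value of good_mwts is an assumption here, marking every token as a non-MWT):

from stanza.models.tokenization import utils

good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
good_mwts = [[False] * 6, [False] * 4]
text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

# third argument: manual MWT expansions; empty when there are none
doc_dict = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, [], text)
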
15 changes: 13 additions & 2 deletions stanza/pipeline/mwt_processor.py
@@ -3,6 +3,7 @@
"""

import io
from stanza.models.common.constant import MANUAL_MWT_MAGIC

import torch

@@ -24,8 +25,9 @@ def _set_up_model(self, config, pipeline, device):

def process(self, document):
batch = DataLoader(document, self.config['batch_size'], self.config, vocab=self.vocab, evaluation=True)
expansions = batch.doc.get_mwt_expansions(evaluation=True)
if len(batch) > 0:
dict_preds = self.trainer.predict_dict(batch.doc.get_mwt_expansions(evaluation=True))
dict_preds = self.trainer.predict_dict(expansions)
# decide trainer type and run eval
if self.config['dict_only']:
preds = dict_preds
@@ -41,7 +43,16 @@ def process(self, document):
# skip eval if dev data does not exist
preds = []

batch.doc.set_mwt_expansions(preds)
# for expansions marked with the special manual-MWT magic string,
# force the output back to the manually specified tokenization
final_preds = []
for inp, out in zip(expansions, preds):
if MANUAL_MWT_MAGIC in inp:
final_preds.append(inp[len(MANUAL_MWT_MAGIC):])
else:
final_preds.append(out)

batch.doc.set_mwt_expansions(final_preds)
return batch.doc

def bulk_process(self, docs):
33 changes: 33 additions & 0 deletions stanza/tests/pipeline/test_tokenizer.py
@@ -128,6 +128,23 @@
EN_DOC_NO_SSPLIT = ["This is a sentence. This is another.", "This is a third."]
EN_DOC_NO_SSPLIT_SENTENCES = [['This', 'is', 'a', 'sentence', '.', 'This', 'is', 'another', '.'], ['This', 'is', 'a', 'third', '.']]

FR_DOC = "Le prince va manger du poulet aux les magasins aujourd'hui."
FR_DOC_POSTPROCESSOR_TOKENS_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', "aujourd'hui", '.']]
FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST = [['Le', 'prince', 'va', 'manger', ('du', True), 'poulet', ('aux', True), 'les', 'magasins', ("aujourd'hui", ["aujourd'", "hui"]), '.']]
FR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS = """
<Token id=1;words=[<Word id=1;text=Le>]>
<Token id=2;words=[<Word id=2;text=prince>]>
<Token id=3;words=[<Word id=3;text=va>]>
<Token id=4;words=[<Word id=4;text=manger>]>
<Token id=5-6;words=[<Word id=5;text=de>, <Word id=6;text=le>]>
<Token id=7;words=[<Word id=7;text=poulet>]>
<Token id=8-9;words=[<Word id=8;text=à>, <Word id=9;text=les>]>
<Token id=10;words=[<Word id=10;text=les>]>
<Token id=11;words=[<Word id=11;text=magasins>]>
<Token id=12-13;words=[<Word id=12;text=aujourd'>, <Word id=13;text=hui>]>
<Token id=14;words=[<Word id=14;text=.>]>
"""

JA_DOC = "北京は中国の首都です。 北京の人口は2152万人です。\n" # add some random whitespaces that need to be skipped
JA_DOC_GOLD_TOKENS = """
<Token id=1;words=[<Word id=1;text=北京>]>
@@ -314,6 +331,22 @@ def dummy_postprocessor(input):
doc = nlp(EN_DOC)
assert EN_DOC_POSTPROCESSOR_COMBINED_TOKENS.strip() == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]).strip()

def test_postprocessor_mwt():

def dummy_postprocessor(input):
# Importantly, FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST returns a few tokens joined
# with spaces. As some languages (such as VN) contain tokens with spaces in between,
# it's important to have space-joined tokens tested here as well
assert input == FR_DOC_POSTPROCESSOR_TOKENS_LIST
return FR_DOC_POSTPROCESSOR_COMBINED_MWT_LIST

nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR,
'lang': 'fr',
'tokenize_postprocessor': dummy_postprocessor})
doc = nlp(FR_DOC)
assert FR_DOC_PRETOKENIZED_LIST_GOLD_TOKENS.strip() == '\n\n'.join([sent.tokens_string() for sent in doc.sentences]).strip()


def test_postprocessor_typeerror():
with pytest.raises(ValueError):
nlp = stanza.Pipeline(**{'processors': 'tokenize', 'dir': TEST_MODELS_DIR, 'lang': 'en',
12 changes: 6 additions & 6 deletions stanza/tests/tokenization/test_tokenize_utils.py
@@ -102,7 +102,7 @@ def test_postprocessor_application():
good_tokenization = [['I', 'am', 'Joe.', '⭆⊱⇞', 'Hi', '.'], ["I'm", 'a', 'chicken', '.']]
text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

target_doc = [[{'id': (1,), 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': (2,), 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': (3,), 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': (4,), 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': (5,), 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': (6,), 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': (1,), 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': (2,), 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': (3,), 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': (4,), 'text': '.', 'start_char': 31, 'end_char': 32}]]
target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32}]]

def postprocesor(_):
return good_tokenization
@@ -121,9 +121,9 @@ def test_reassembly_indexing():

text = "I am Joe. ⭆⊱⇞ Hi. I'm a chicken."

target_doc = [[{'id': (1,), 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': (2,), 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': (3,), 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': (4,), 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': (5,), 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': (6,), 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': (1,), 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': (2,), 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': (3,), 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': (4,), 'text': '.', 'start_char': 31, 'end_char': 32}]]
target_doc = [[{'id': 1, 'text': 'I', 'start_char': 0, 'end_char': 1}, {'id': 2, 'text': 'am', 'start_char': 2, 'end_char': 4}, {'id': 3, 'text': 'Joe.', 'start_char': 5, 'end_char': 9}, {'id': 4, 'text': '⭆⊱⇞', 'start_char': 10, 'end_char': 13}, {'id': 5, 'text': 'Hi', 'start_char': 14, 'end_char': 16}, {'id': 6, 'text': '.', 'start_char': 16, 'end_char': 17}], [{'id': 1, 'text': "I'm", 'start_char': 18, 'end_char': 21}, {'id': 2, 'text': 'a', 'start_char': 22, 'end_char': 23}, {'id': 3, 'text': 'chicken', 'start_char': 24, 'end_char': 31}, {'id': 4, 'text': '.', 'start_char': 31, 'end_char': 32}]]

res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, text)
res = utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, [], text)

assert res == target_doc

@@ -144,12 +144,12 @@ def test_reassembly_reference_failures():
text = "Joe Smith lives in California."

with pytest.raises(ValueError):
utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, text)
utils.reassemble_doc_from_tokens(bad_addition_tokenization, bad_addition_mwts, [], text)

with pytest.raises(ValueError):
utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, text)
utils.reassemble_doc_from_tokens(bad_inline_tokenization, bad_inline_mwts, [], text)

utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, text)
utils.reassemble_doc_from_tokens(good_tokenization, good_mwts, [], text)



