From 271dcd9d6a19ec83913832caf92688eb9aa26d16 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 2 Sep 2024 00:31:53 -0400 Subject: [PATCH 01/17] initial work on contextual-aware tokenizer --- stanza/models/tokenization/data.py | 17 ++-- stanza/models/tokenization/model.py | 123 +++++++++++++++++++++++++- stanza/models/tokenization/trainer.py | 16 ++-- stanza/models/tokenizer.py | 28 +++++- stanza/pipeline/tokenize_processor.py | 5 +- 5 files changed, 170 insertions(+), 19 deletions(-) diff --git a/stanza/models/tokenization/data.py b/stanza/models/tokenization/data.py index 3ff919b0ba..3c4746a6d8 100644 --- a/stanza/models/tokenization/data.py +++ b/stanza/models/tokenization/data.py @@ -1,5 +1,5 @@ from bisect import bisect_right -from copy import copy +from copy import copy, deepcopy import numpy as np import random import logging @@ -355,15 +355,20 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']): features[i, :len(f_), :] = f_ raw_units.append(r_ + [''] * (pad_len - len(r_))) + # so we always return text that's not been ed, but will return + # IDs with s in them. REVIEW: check if the raw text is used + # anywhere else such that the lack of UNKs will cause a problem + dropped_units = deepcopy(raw_units) + if unit_dropout > 0 and not self.eval: # dropout characters/units at training time and replace them with UNKs mask = np.random.random_sample(units.shape) < unit_dropout mask[units == padid] = 0 units[mask] = unkid - for i in range(len(raw_units)): - for j in range(len(raw_units[i])): + for i in range(len(dropped_units)): + for j in range(len(dropped_units[i])): if mask[i, j]: - raw_units[i][j] = '' + dropped_units[i][j] = '' # dropout unit feature vector in addition to only torch.dropout in the model. # experiments showed that only torch.dropout hurts the model @@ -372,8 +377,8 @@ def strings_starting(id_pair, offset=0, pad_len=self.args['max_seqlen']): if self.args['use_dictionary'] and feat_unit_dropout > 0 and not self.eval: mask_feat = np.random.random_sample(units.shape) < feat_unit_dropout mask_feat[units == padid] = 0 - for i in range(len(raw_units)): - for j in range(len(raw_units[i])): + for i in range(len(dropped_units)): + for j in range(len(dropped_units[i])): if mask_feat[i,j]: features[i,j,:] = 0 diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 1f60987126..1d8053b0d7 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -1,12 +1,50 @@ import torch import torch.nn.functional as F import torch.nn as nn +from itertools import tee + +from stanza.models.common.seq2seq_constant import PAD, UNK, UNK_ID + +class SentenceAnalyzer(nn.Module): + def __init__(self, args, pretrain, hidden_dim, device=None): + super().__init__() + + assert pretrain != None, "2nd pass sentence anayzer is missing pretrain word vectors" + + self.args = args + self.vocab = pretrain.vocab + self.embeddings = nn.Embedding.from_pretrained( + torch.from_numpy(pretrain.emb), freeze=True) + + self.emb_proj = nn.Linear(pretrain.emb.shape[1], hidden_dim) + self.conv = nn.Conv1d(hidden_dim, hidden_dim, + args["sentence_analyzer_kernel"], padding="same", + padding_mode="circular") + self.ffnn = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 1) + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, x): + # map the vocab to pretrain IDs + embs = self.embeddings(torch.tensor([[self.vocab[j] for j in i] for i in x], + 
device=self.device)) + net = self.emb_proj(embs) + net = self.conv(net.permute(0,2,1)).permute(0,2,1) + return self.ffnn(net) + class Tokenizer(nn.Module): - def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout): + def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pretrain=None): super().__init__() self.args = args + self.pretrain = pretrain feat_dim = args['feat_dim'] self.embeddings = nn.Embedding(nchars, emb_dim, padding_idx=0) @@ -36,12 +74,15 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout): if self.args['use_mwt']: self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) + if args['sentence_second_pass']: + self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim) + self.dropout = nn.Dropout(dropout) self.dropout_feat = nn.Dropout(feat_dropout) self.toknoise = nn.Dropout(self.args['tok_noise']) - def forward(self, x, feats): + def forward(self, x, feats, text): emb = self.embeddings(x) emb = self.dropout(emb) feats = self.dropout_feat(feats) @@ -87,12 +128,86 @@ def forward(self, x, feats): nontok = F.logsigmoid(-tok0) tok = F.logsigmoid(tok0) - nonsent = F.logsigmoid(-sent0) - sent = F.logsigmoid(sent0) if self.args['use_mwt']: nonmwt = F.logsigmoid(-mwt0) mwt = F.logsigmoid(mwt0) + # use the rough predictions from the char tokenizer to create word tokens + # then use those word tokens + contextual/fixed word embeddings to refine + # sentence predictions + if self.args["sentence_second_pass"]: + # these are the draft predictions for only token-level decisinos + # which we can use to slice the text + draft_preds = torch.cat([nontok, tok+nonmwt, tok+mwt], 2).argmax(dim=2) + draft_preds = (draft_preds > 0) + # we add a prefix zero + # TODO inefficient / how to parallelize this? + token_locations = [[-1] + i.nonzero().squeeze(1).cpu().tolist() + for i in draft_preds] + + # both: batch x seq x [variable: text token count] + batch_tokens = [] # str tokens + batch_tokenid_locations = [] # id locations for the *end* of each str token + # corresponding to char token + for location,chars, toks in zip(token_locations, text, x): + # we append len(chars)-1 to append the last token which wouldn't + # necessearily have been captured by the splits; though in theory + # the model should put a token at the end of each sentence so this + # should be less of a problem + + a,b = tee(location+[len(chars)-1]) + tokens = [] + tokenid_locations = [] + next(b) # because we want to start iterating on the NEXT id to create pairs + j = -1 + for i,j in zip(a,b): + split = chars[i+1:j+1] + # if the entire unit is UNK, leave as UNK into the predictor + is_unk = ((toks[i+1:j+1]) == UNK_ID).all().cpu().item() + if set(split) == set([PAD]): + continue + tokenid_locations.append(j) + + if not is_unk: + tokens.append("".join(split).replace(PAD, "")) + else: + tokens.append(UNK) + + batch_tokens.append(tokens) + batch_tokenid_locations.append(tokenid_locations) + + # dynamically pad the batch tokens to size + # why max 5? 
our + max_size = max(max([len(i) for i in batch_tokens]), + self.args["sentence_analyzer_kernel"]) + batch_tokens_padded = [] + batch_tokens_isntpad = [] + for i in batch_tokens: + batch_tokens_padded.append(i + [PAD for _ in range(max_size-len(i))]) + batch_tokens_isntpad.append([True for _ in range(len(i))] + + [False for _ in range(max_size-len(i))]) + + ##### TODO EVERYTHING BELOW THIS LINE IS UNTESTED ##### + second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded) + + # we only add scores for slots for which we have a possible word ending + # i.e. its not padding and its also not a middle of rough score's resulting + # words + second_pass_chars_align = torch.zeros_like(sent0) + token_location_selectors = torch.tensor([[i,k] for i,j in + enumerate(batch_tokenid_locations) + for k in j]) + + second_pass_chars_align[ + token_location_selectors[:,0], + token_location_selectors[:,1] + ] = second_pass_scores[torch.tensor(batch_tokens_isntpad)] + + sent0 += second_pass_chars_align + + nonsent = F.logsigmoid(-sent0) + sent = F.logsigmoid(sent0) + if self.args['use_mwt']: pred = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2) else: diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py index 254f419f92..d9842259fe 100644 --- a/stanza/models/tokenization/trainer.py +++ b/stanza/models/tokenization/trainer.py @@ -14,7 +14,10 @@ logger = logging.getLogger('stanza') class Trainer(BaseTrainer): - def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None): + def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, pretrain=None): + if args["sentence_second_pass"]: + assert bool(pretrain), "context-aware sentence analysis requires pretrained wordvectors; download them!" 
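# A hedged usage sketch (not part of the patch): how a caller is expected to
# construct the pretrain object that satisfies the assertion above.  It is a
# simplified version of the load_pretrain() helper added to
# stanza/models/tokenizer.py later in this patch, assuming
# --wordvec_pretrain_file points at an existing pretrain file.
from stanza.models.common import pretrain

def build_tokenizer_trainer(args, vocab, lexicon, dictionary, device):
    pt = None
    if args.get('sentence_second_pass'):
        # same positional Pretrain(...) call that load_pretrain() uses
        pt = pretrain.Pretrain(args['wordvec_pretrain_file'], None,
                               args.get('pretrain_max_vocab', 250000))
    return Trainer(args=args, vocab=vocab, lexicon=lexicon,
                   dictionary=dictionary, device=device, pretrain=pt)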
+ if model_file is not None: # load everything from file self.load(model_file) @@ -24,23 +27,24 @@ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_f self.vocab = vocab self.lexicon = lexicon self.dictionary = dictionary - self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) + self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], pretrain=pretrain) self.model = self.model.to(device) self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device) self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay']) self.feat_funcs = self.args.get('feat_funcs', None) self.lang = self.args['lang'] # language determines how token normalization is done + self.pretrain = pretrain def update(self, inputs): self.model.train() - units, labels, features, _ = inputs + units, labels, features, text = inputs device = next(self.model.parameters()).device units = units.to(device) labels = labels.to(device) features = features.to(device) - pred = self.model(units, features) + pred = self.model(units, features, text) self.optimizer.zero_grad() classes = pred.size(2) @@ -54,13 +58,13 @@ def update(self, inputs): def predict(self, inputs): self.model.eval() - units, _, features, _ = inputs + units, _, features, text = inputs device = next(self.model.parameters()).device units = units.to(device) features = features.to(device) - pred = self.model(units, features) + pred = self.model(units, features, text) return pred.data.cpu().numpy() diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index 8e78798b2c..07eac024e2 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -27,6 +27,8 @@ from stanza.models.tokenization.trainer import Trainer from stanza.models.tokenization.data import DataLoader, TokenizationDataset from stanza.models.tokenization.utils import load_mwt_dict, eval_model, output_predictions, load_lexicon, create_dictionary +from stanza.models.common import pretrain + from stanza.models import _training_logging logger = logging.getLogger('stanza') @@ -46,6 +48,13 @@ def build_argparse(): parser.add_argument('--lang', type=str, help="Language") parser.add_argument('--shorthand', type=str, help="UD treebank shorthand") + parser.add_argument('--wordvec_dir', type=str, default='extern_data/wordvec', help='Directory of word vectors.') + parser.add_argument('--wordvec_file', type=str, default=None, help='Word vectors filename.') + parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') + parser.add_argument('--pretrain_max_vocab', type=int, default=250000) + + parser.add_argument('--sentence-analyzer-kernel', type=int, default=4) + parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.") @@ -54,6 +63,7 @@ def build_argparse(): parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters. 
,, separates layers and , separates filter sizes in the same layer.") parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections") parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer") + parser.add_argument('--no-sentence-second-pass', dest='sentence_second_pass', action='store_false', help="predict the sentences together with tokens instead of after") parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers") parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well") parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN") @@ -113,6 +123,17 @@ def model_file_name(args): return save_name return os.path.join(args['save_dir'], save_name) +def load_pretrain(args): + pt = None + if args['sentence_second_pass']: + pretrain_file = pretrain.find_pretrain_file(args['wordvec_pretrain_file'], args['save_dir'], args['shorthand'], args['lang']) + if os.path.exists(pretrain_file): + vec_file = None + else: + vec_file = args['wordvec_file'] if args['wordvec_file'] else utils.get_wordvec_file(args['wordvec_dir'], args['shorthand']) + pt = pretrain.Pretrain(pretrain_file, vec_file, args['pretrain_max_vocab']) + return pt + def main(args=None): args = parse_args(args=args) @@ -164,7 +185,9 @@ def train(args): args['use_mwt'] = train_batches.has_mwt() logger.info("Found {}mwts in the training data. Setting use_mwt to {}".format(("" if args['use_mwt'] else "no "), args['use_mwt'])) - trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device']) + # load pretrained vectors if needed + pretrain = load_pretrain(args) + trainer = Trainer(args=args, vocab=vocab, lexicon=lexicon, dictionary=dictionary, device=args['device'], pretrain=pretrain) if args['load_name'] is not None: load_name = os.path.join(args['save_dir'], args['load_name']) @@ -234,7 +257,8 @@ def train(args): def evaluate(args): mwt_dict = load_mwt_dict(args['mwt_json_file']) - trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device']) + pretrain = load_pretrain(args) + trainer = Trainer(model_file=args['load_name'] or args['save_name'], device=args['device'], pretrain=pretrain) loaded_args, vocab = trainer.args, trainer.vocab for k in loaded_args: diff --git a/stanza/pipeline/tokenize_processor.py b/stanza/pipeline/tokenize_processor.py index f2fc242db2..92ae8f3ad9 100644 --- a/stanza/pipeline/tokenize_processor.py +++ b/stanza/pipeline/tokenize_processor.py @@ -37,11 +37,14 @@ class TokenizeProcessor(UDProcessor): MAX_SEQ_LENGTH_DEFAULT = 1000 def _set_up_model(self, config, pipeline, device): + # get pretrained word vectors + self._pretrain = pipeline.foundation_cache.load_pretrain(config['pretrain_path']) if 'pretrain_path' in config else None + # set up trainer if config.get('pretokenized'): self._trainer = None else: - self._trainer = Trainer(model_file=config['model_path'], device=device) + self._trainer = Trainer(model_file=config['model_path'], device=device, pretrain=self.pretrain) # get and typecheck the postprocessor postprocessor = config.get('postprocessor') From 7e9388d632e2103172177a3372e5f82c6f3a13a1 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sun, 1 Sep 2024 22:00:20 -0700 Subject: [PATCH 02/17] last mile tokenizer changes for running (still 
slow) --- stanza/models/tokenization/trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py index d9842259fe..edaf4f28e2 100644 --- a/stanza/models/tokenization/trainer.py +++ b/stanza/models/tokenization/trainer.py @@ -15,9 +15,7 @@ class Trainer(BaseTrainer): def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_file=None, device=None, pretrain=None): - if args["sentence_second_pass"]: - assert bool(pretrain), "context-aware sentence analysis requires pretrained wordvectors; download them!" - + self.pretrain = pretrain if model_file is not None: # load everything from file self.load(model_file) @@ -28,6 +26,10 @@ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_f self.lexicon = lexicon self.dictionary = dictionary self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], pretrain=pretrain) + + if self.args["sentence_second_pass"]: + assert bool(pretrain), "context-aware sentence analysis requires pretrained wordvectors; download them!" + self.model = self.model.to(device) self.criterion = nn.CrossEntropyLoss(ignore_index=-1).to(device) self.optimizer = utils.get_optimizer("adam", self.model, lr=self.args['lr0'], betas=(.9, .9), weight_decay=self.args['weight_decay']) @@ -92,7 +94,7 @@ def load(self, filename): # Default to True as many currently saved models # were built with mwt layers self.args['use_mwt'] = True - self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout']) + self.model = Tokenizer(self.args, self.args['vocab_size'], self.args['emb_dim'], self.args['hidden_dim'], dropout=self.args['dropout'], feat_dropout=self.args['feat_dropout'], pretrain=self.pretrain) self.model.load_state_dict(checkpoint['model']) self.vocab = Vocab.load_state_dict(checkpoint['vocab']) self.lexicon = checkpoint['lexicon'] From 58e90b89e54c0d30d229ef419185128a5eef18c3 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sun, 1 Sep 2024 22:14:43 -0700 Subject: [PATCH 03/17] small optmizitain changes --- stanza/models/tokenization/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 1d8053b0d7..1717f7b315 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -142,7 +142,9 @@ def forward(self, x, feats, text): draft_preds = (draft_preds > 0) # we add a prefix zero # TODO inefficient / how to parallelize this? 
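# A self-contained illustration (toy data, not part of the patch) of the
# pairwise split performed below: inclusive token-end indices are expanded
# into (start, end] character spans and joined into word strings.  The real
# code additionally skips all-PAD spans and feeds all-UNK spans to the
# second pass as the UNK token.
from itertools import tee

def pairwise_split(chars, end_indices):
    # a leading -1 makes the first span start at character 0; appending the
    # last index keeps trailing characters that were never marked as an end
    locations = [-1] + end_indices + [len(chars) - 1]
    a, b = tee(locations)
    next(b)  # advance so zip(a, b) yields consecutive (start, end) pairs
    return ["".join(chars[i + 1:j + 1]) for i, j in zip(a, b) if j > i]

# pairwise_split(list("Hi there."), [1, 7]) == ['Hi', ' there', '.']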
- token_locations = [[-1] + i.nonzero().squeeze(1).cpu().tolist() + front_pad = torch.tensor([-1]).to(draft_preds.device) + back_pad = torch.tensor([len(text[0])-1]).to(draft_preds.device) + token_locations = [torch.cat([front_pad, i.nonzero().squeeze(1).detach(), back_pad]) for i in draft_preds] # both: batch x seq x [variable: text token count] @@ -155,7 +157,7 @@ def forward(self, x, feats, text): # the model should put a token at the end of each sentence so this # should be less of a problem - a,b = tee(location+[len(chars)-1]) + a,b = tee(location) tokens = [] tokenid_locations = [] next(b) # because we want to start iterating on the NEXT id to create pairs @@ -163,7 +165,7 @@ def forward(self, x, feats, text): for i,j in zip(a,b): split = chars[i+1:j+1] # if the entire unit is UNK, leave as UNK into the predictor - is_unk = ((toks[i+1:j+1]) == UNK_ID).all().cpu().item() + is_unk = ((toks[i+1:j+1]) == UNK_ID).all() if set(split) == set([PAD]): continue tokenid_locations.append(j) From 1b9a9ce0fea4c74551599e2da9c7c8f78d8f5f76 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 2 Sep 2024 01:17:42 -0400 Subject: [PATCH 04/17] remove extra spaces --- stanza/models/tokenization/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 1d8053b0d7..97ccb5ea21 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -32,7 +32,7 @@ def device(self): def forward(self, x): # map the vocab to pretrain IDs - embs = self.embeddings(torch.tensor([[self.vocab[j] for j in i] for i in x], + embs = self.embeddings(torch.tensor([[self.vocab[j.strip()] for j in i] for i in x], device=self.device)) net = self.emb_proj(embs) net = self.conv(net.permute(0,2,1)).permute(0,2,1) From 23d4ab57dd9f80a56e570a12d186ea103e6124e3 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 2 Sep 2024 14:59:23 -0700 Subject: [PATCH 05/17] use mwt info --- stanza/models/tokenization/model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 627534fe35..acaca2e626 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -138,7 +138,10 @@ def forward(self, x, feats, text): if self.args["sentence_second_pass"]: # these are the draft predictions for only token-level decisinos # which we can use to slice the text - draft_preds = torch.cat([nontok, tok+nonmwt, tok+mwt], 2).argmax(dim=2) + if self.args['use_mwt']: + draft_preds = torch.cat([nontok, tok+nonmwt, tok+mwt], 2).argmax(dim=2) + else: + draft_preds = torch.cat([nontok, tok], 2).argmax(dim=2) draft_preds = (draft_preds > 0) # we add a prefix zero # TODO inefficient / how to parallelize this? 
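The reason argmax(dim=2) > 0 yields draft token boundaries in the hunk above is that column 0 of the concatenated log-scores is the "no token ends here" class, while every other column is some flavor of token end. A toy example, assuming the model's batch x seq x 1 score shapes (values chosen for illustration only):

import torch
import torch.nn.functional as F

tok0 = torch.tensor([[[2.0], [-3.0], [1.0]]])    # raw token-end logits
mwt0 = torch.tensor([[[-1.0], [-1.0], [4.0]]])   # raw MWT logits
nontok, tok = F.logsigmoid(-tok0), F.logsigmoid(tok0)
nonmwt, mwt = F.logsigmoid(-mwt0), F.logsigmoid(mwt0)

scores = torch.cat([nontok, tok + nonmwt, tok + mwt], 2)  # batch x seq x 3
draft_preds = scores.argmax(dim=2) > 0
# tensor([[ True, False,  True]]): characters 0 and 2 end a draft token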
From 38a3398ed46a9187d30514f8938eca3680f7a25e Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Tue, 3 Sep 2024 22:07:14 -0700 Subject: [PATCH 06/17] make the tokenizer a smidge more efficinet --- stanza/models/tokenization/model.py | 104 ++++++++++++---------------- 1 file changed, 46 insertions(+), 58 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index acaca2e626..aae96f3a00 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -17,10 +17,16 @@ def __init__(self, args, pretrain, hidden_dim, device=None): torch.from_numpy(pretrain.emb), freeze=True) self.emb_proj = nn.Linear(pretrain.emb.shape[1], hidden_dim) - self.conv = nn.Conv1d(hidden_dim, hidden_dim, - args["sentence_analyzer_kernel"], padding="same", - padding_mode="circular") + self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, + args["sentence_analyzer_kernel"], padding="same", + padding_mode="zeros") + self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, + args["sentence_analyzer_kernel"], padding="same", + padding_mode="zeros") + self.ffnn = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1) @@ -32,11 +38,12 @@ def device(self): def forward(self, x): # map the vocab to pretrain IDs - embs = self.embeddings(torch.tensor([[self.vocab[j.strip()] for j in i] for i in x], + token_ids = [[self.vocab[j.strip()] for j in i] for i in x] + embs = self.embeddings(torch.tensor(token_ids, device=self.device)) net = self.emb_proj(embs) - net = self.conv(net.permute(0,2,1)).permute(0,2,1) - return self.ffnn(net) + net = self.conv2(F.relu(self.conv1(net.permute(0,2,1)))) + return self.ffnn(net.permute(0,2,1)) class Tokenizer(nn.Module): @@ -143,70 +150,51 @@ def forward(self, x, feats, text): else: draft_preds = torch.cat([nontok, tok], 2).argmax(dim=2) draft_preds = (draft_preds > 0) - # we add a prefix zero - # TODO inefficient / how to parallelize this? 
- front_pad = torch.tensor([-1]).to(draft_preds.device) - back_pad = torch.tensor([len(text[0])-1]).to(draft_preds.device) - token_locations = [torch.cat([front_pad, i.nonzero().squeeze(1).detach(), back_pad]) - for i in draft_preds] - - # both: batch x seq x [variable: text token count] - batch_tokens = [] # str tokens - batch_tokenid_locations = [] # id locations for the *end* of each str token - # corresponding to char token - for location,chars, toks in zip(token_locations, text, x): - # we append len(chars)-1 to append the last token which wouldn't - # necessearily have been captured by the splits; though in theory - # the model should put a token at the end of each sentence so this - # should be less of a problem - - a,b = tee(location) - tokens = [] - tokenid_locations = [] - next(b) # because we want to start iterating on the NEXT id to create pairs - j = -1 - for i,j in zip(a,b): - split = chars[i+1:j+1] - # if the entire unit is UNK, leave as UNK into the predictor - is_unk = ((toks[i+1:j+1]) == UNK_ID).all() - if set(split) == set([PAD]): - continue - tokenid_locations.append(j) - - if not is_unk: - tokens.append("".join(split).replace(PAD, "")) - else: - tokens.append(UNK) - - batch_tokens.append(tokens) - batch_tokenid_locations.append(tokenid_locations) + # these boolean indicies are *inclusive*, so predict it or not + # we need to split on the last token if we want to keep the + # final word + draft_preds[:,-1] = True + + # both: batch x [variable: text token count] + extracted_tokens = [] + partial = [] + last = 0 + last_batch = -1 + + nonzero = draft_preds.nonzero().cpu().tolist() + for i,j in nonzero: + if i != last_batch: + last_batch = i + last = 0 + if i != 0: + extracted_tokens.append(partial) + partial = [] + + substring = text[i][last:j+1] + last = j+1 + + partial.append("".join(substring)) + extracted_tokens.append(partial) # dynamically pad the batch tokens to size - # why max 5? our - max_size = max(max([len(i) for i in batch_tokens]), + # why to at least a fix size? it must be wider + # than our kernel + max_size = max(max([len(i) for i in extracted_tokens]), self.args["sentence_analyzer_kernel"]) batch_tokens_padded = [] batch_tokens_isntpad = [] - for i in batch_tokens: + for i in extracted_tokens: batch_tokens_padded.append(i + [PAD for _ in range(max_size-len(i))]) batch_tokens_isntpad.append([True for _ in range(len(i))] + [False for _ in range(max_size-len(i))]) - ##### TODO EVERYTHING BELOW THIS LINE IS UNTESTED ##### second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded) - # we only add scores for slots for which we have a possible word ending - # i.e. its not padding and its also not a middle of rough score's resulting - # words + # # we only add scores for slots for which we have a possible word ending + # # i.e. 
its not padding and its also not a middle of rough score's resulting + # # words second_pass_chars_align = torch.zeros_like(sent0) - token_location_selectors = torch.tensor([[i,k] for i,j in - enumerate(batch_tokenid_locations) - for k in j]) - - second_pass_chars_align[ - token_location_selectors[:,0], - token_location_selectors[:,1] - ] = second_pass_scores[torch.tensor(batch_tokens_isntpad)] + second_pass_chars_align[draft_preds] = second_pass_scores[torch.tensor(batch_tokens_isntpad)] sent0 += second_pass_chars_align From e4b691a2692c9fb76ac3917ee43b6b09c451902b Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sat, 7 Sep 2024 16:20:32 -0400 Subject: [PATCH 07/17] now use an LSTM --- stanza/models/tokenization/model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index acaca2e626..1e37c61e70 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -17,11 +17,15 @@ def __init__(self, args, pretrain, hidden_dim, device=None): torch.from_numpy(pretrain.emb), freeze=True) self.emb_proj = nn.Linear(pretrain.emb.shape[1], hidden_dim) - self.conv = nn.Conv1d(hidden_dim, hidden_dim, - args["sentence_analyzer_kernel"], padding="same", - padding_mode="circular") + self.lstm = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True, + batch_first=True, num_layers=args['rnn_layers']) + + # standard up and down projection a la transformers self.ffnn = nn.Sequential( - nn.Linear(hidden_dim, hidden_dim), + nn.Linear(hidden_dim*2, hidden_dim*4), + nn.ReLU(), + nn.Linear(hidden_dim*4, hidden_dim), + nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1) ) @@ -35,7 +39,7 @@ def forward(self, x): embs = self.embeddings(torch.tensor([[self.vocab[j.strip()] for j in i] for i in x], device=self.device)) net = self.emb_proj(embs) - net = self.conv(net.permute(0,2,1)).permute(0,2,1) + net = self.lstm(net)[0] return self.ffnn(net) From ca1423b1eaf666ff7fa4aa7621f47e2a684360aa Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sun, 8 Sep 2024 21:18:35 -0700 Subject: [PATCH 08/17] various tokenizer changes for a smaller model --- stanza/models/tokenization/model.py | 27 ++++++++++++--------------- stanza/models/tokenizer.py | 4 ++-- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index bc4093c7c7..b759cf8a09 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -20,26 +20,16 @@ def __init__(self, args, pretrain, hidden_dim, device=None): self.lstm = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True, batch_first=True, num_layers=args['rnn_layers']) - self.ffnn = nn.Sequential( - nn.Linear(hidden_dim*2, hidden_dim*4), - nn.ReLU(), - nn.Linear(hidden_dim*4, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, hidden_dim), - nn.ReLU(), - nn.Linear(hidden_dim, 1) - ) + self.ffnn = nn.Linear(hidden_dim*2, 1, bias=False) @property def device(self): return next(self.parameters()).device - def forward(self, x): + def forward(self, x, s0): # map the vocab to pretrain IDs token_ids = [[self.vocab[j.strip()] for j in i] for i in x] - embs = self.embeddings(torch.tensor(token_ids, - device=self.device)) + embs = self.embeddings(torch.tensor(token_ids, device=self.device)) net = self.emb_proj(embs) net = self.lstm(net)[0] return self.ffnn(net) @@ -82,6 +72,8 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, 
feat_dropout, pre if args['sentence_second_pass']: self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim) + self.sent_2nd_smoother = nn.Conv1d(1, 1, args["sentence_analyzer_kernel"], padding="same", padding_mode="replicate") + self.sent_2nd_mix = nn.Parameter(torch.full((1,), 0.0), requires_grad=True) self.dropout = nn.Dropout(dropout) self.dropout_feat = nn.Dropout(feat_dropout) @@ -187,7 +179,7 @@ def forward(self, x, feats, text): batch_tokens_isntpad.append([True for _ in range(len(i))] + [False for _ in range(max_size-len(i))]) - second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded) + second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded, sent0[draft_preds]) # # we only add scores for slots for which we have a possible word ending # # i.e. its not padding and its also not a middle of rough score's resulting @@ -195,7 +187,12 @@ def forward(self, x, feats, text): second_pass_chars_align = torch.zeros_like(sent0) second_pass_chars_align[draft_preds] = second_pass_scores[torch.tensor(batch_tokens_isntpad)] - sent0 += second_pass_chars_align + mix = F.sigmoid(self.sent_2nd_mix) + smoothed = self.sent_2nd_smoother( + second_pass_chars_align.permute(0,2,1) + ).permute(0,2,1) + + sent0 = (1-mix)*sent0 + mix*smoothed nonsent = F.logsigmoid(-sent0) sent = F.logsigmoid(sent0) diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index 07eac024e2..048e4b5c96 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -53,7 +53,7 @@ def build_argparse(): parser.add_argument('--wordvec_pretrain_file', type=str, default=None, help='Exact name of the pretrain file to read') parser.add_argument('--pretrain_max_vocab', type=int, default=250000) - parser.add_argument('--sentence-analyzer-kernel', type=int, default=4) + parser.add_argument('--sentence_analyzer_kernel', type=int, default=4) parser.add_argument('--mode', default='train', choices=['train', 'predict']) parser.add_argument('--skip_newline', action='store_true', help="Whether to skip newline characters in input. Particularly useful for languages like Chinese.") @@ -63,7 +63,7 @@ def build_argparse(): parser.add_argument('--conv_filters', type=str, default="1,9", help="Configuration of conv filters. 
,, separates layers and , separates filter sizes in the same layer.") parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections") parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer") - parser.add_argument('--no-sentence-second-pass', dest='sentence_second_pass', action='store_false', help="predict the sentences together with tokens instead of after") + parser.add_argument('--no_sentence_second_pass', dest='sentence_second_pass', action='store_false', help="predict the sentences together with tokens instead of after") parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers") parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well") parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN") From a8934d53ce50912984fbd1bd6e260ff09837010a Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sun, 8 Sep 2024 22:00:07 -0700 Subject: [PATCH 09/17] some edits to second pass classifier --- stanza/models/tokenization/model.py | 22 +++++++++++++++------- stanza/models/tokenization/trainer.py | 15 +++++++++++++-- stanza/models/tokenizer.py | 5 +++++ 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index b759cf8a09..6e10ccdd07 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -17,6 +17,7 @@ def __init__(self, args, pretrain, hidden_dim, device=None): torch.from_numpy(pretrain.emb), freeze=True) self.emb_proj = nn.Linear(pretrain.emb.shape[1], hidden_dim) + self.tok_proj = nn.Linear(hidden_dim*2, hidden_dim) self.lstm = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True, batch_first=True, num_layers=args['rnn_layers']) @@ -26,12 +27,13 @@ def __init__(self, args, pretrain, hidden_dim, device=None): def device(self): return next(self.parameters()).device - def forward(self, x, s0): + def forward(self, x, inp0, pad_mask): # map the vocab to pretrain IDs token_ids = [[self.vocab[j.strip()] for j in i] for i in x] embs = self.embeddings(torch.tensor(token_ids, device=self.device)) - net = self.emb_proj(embs) + net = self.emb_proj(embs) net = self.lstm(net)[0] + net[pad_mask] += inp0 return self.ffnn(net) @@ -73,14 +75,16 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pre if args['sentence_second_pass']: self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim) self.sent_2nd_smoother = nn.Conv1d(1, 1, args["sentence_analyzer_kernel"], padding="same", padding_mode="replicate") - self.sent_2nd_mix = nn.Parameter(torch.full((1,), 0.0), requires_grad=True) + # initially, don't use 2nd pass that much (this is near 0, meaning it will pretty much + # not be mixed in + self.sent_2nd_mix = nn.Parameter(torch.full((1,), -5.0), requires_grad=True) self.dropout = nn.Dropout(dropout) self.dropout_feat = nn.Dropout(feat_dropout) self.toknoise = nn.Dropout(self.args['tok_noise']) - def forward(self, x, feats, text): + def forward(self, x, feats, text, detach_2nd_pass=False): emb = self.embeddings(x) emb = self.dropout(emb) feats = self.dropout_feat(feats) @@ -179,20 +183,24 @@ def forward(self, x, feats, text): batch_tokens_isntpad.append([True for _ in range(len(i))] + [False for _ in range(max_size-len(i))]) - second_pass_scores = 
self.sent_2nd_pass_clf(batch_tokens_padded, sent0[draft_preds]) + pad_mask = torch.tensor(batch_tokens_isntpad) + second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded, inp[draft_preds], pad_mask) # # we only add scores for slots for which we have a possible word ending # # i.e. its not padding and its also not a middle of rough score's resulting # # words second_pass_chars_align = torch.zeros_like(sent0) - second_pass_chars_align[draft_preds] = second_pass_scores[torch.tensor(batch_tokens_isntpad)] + second_pass_chars_align[draft_preds] = second_pass_scores[pad_mask] mix = F.sigmoid(self.sent_2nd_mix) smoothed = self.sent_2nd_smoother( second_pass_chars_align.permute(0,2,1) ).permute(0,2,1) - sent0 = (1-mix)*sent0 + mix*smoothed + if detach_2nd_pass: + sent0 = (1-mix.detach())*sent0 + mix.detach()*smoothed.detach() + else: + sent0 = (1-mix)*sent0 + mix*smoothed nonsent = F.logsigmoid(-sent0) sent = F.logsigmoid(sent0) diff --git a/stanza/models/tokenization/trainer.py b/stanza/models/tokenization/trainer.py index edaf4f28e2..6d892a5a5c 100644 --- a/stanza/models/tokenization/trainer.py +++ b/stanza/models/tokenization/trainer.py @@ -36,8 +36,15 @@ def __init__(self, args=None, vocab=None, lexicon=None, dictionary=None, model_f self.feat_funcs = self.args.get('feat_funcs', None) self.lang = self.args['lang'] # language determines how token normalization is done self.pretrain = pretrain + self.global_step_counter_ = 0 + self.train_2nd_pass = False + + @property + def steps(self): + return self.global_step_counter_ def update(self, inputs): + self.global_step_counter_ += 1 self.model.train() units, labels, features, text = inputs @@ -46,7 +53,8 @@ def update(self, inputs): labels = labels.to(device) features = features.to(device) - pred = self.model(units, features, text) + # we detach 2nd pass if we are not training second pass + pred = self.model(units, features, text, not self.train_2nd_pass) self.optimizer.zero_grad() classes = pred.size(2) @@ -75,7 +83,8 @@ def save(self, filename): 'model': self.model.state_dict() if self.model is not None else None, 'vocab': self.vocab.state_dict(), 'lexicon': self.lexicon, - 'config': self.args + 'config': self.args, + 'steps': self.global_step_counter_ } try: torch.save(params, filename, _use_new_zipfile_serialization=False) @@ -99,6 +108,8 @@ def load(self, filename): self.vocab = Vocab.load_state_dict(checkpoint['vocab']) self.lexicon = checkpoint['lexicon'] + self.global_step_counter_ = checkpoint.get("steps", 0) + if self.lexicon is not None: self.dictionary = create_dictionary(self.lexicon) else: diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index 048e4b5c96..e4a546c357 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -64,6 +64,7 @@ def build_argparse(): parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections") parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer") parser.add_argument('--no_sentence_second_pass', dest='sentence_second_pass', action='store_false', help="predict the sentences together with tokens instead of after") + parser.add_argument('--second_pass_start_steps', type=int, help="when (how many steps) to start training the second pass classifier", default=256) parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers") 
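# A hedged sketch (not part of the patch) of how --second_pass_start_steps is
# meant to gate training: the training-loop change later in this commit (and
# its reordering in a later commit of this series) amounts to the following.
def training_step(trainer, batch, second_pass_start_steps):
    # let gradients reach the second-pass analyzer only after the character-
    # level tokenizer has had enough steps to produce sensible draft tokens
    if trainer.steps > second_pass_start_steps:
        trainer.train_2nd_pass = True
    # while train_2nd_pass is False, Trainer.update() runs the model with the
    # second-pass branch detached, so only the first-pass scores are trained
    return trainer.update(batch)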
parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well") parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN") @@ -214,6 +215,10 @@ def train(args): batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout']) loss = trainer.update(batch) + + if trainer.steps > args["second_pass_start_steps"]: + trainer.train_2nd_pass = True + if step % args['report_steps'] == 0: logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss)) if args['wandb']: From b687c6fe9b41e601be5eb1dd62b7adec2774fd26 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Sun, 8 Sep 2024 22:03:04 -0700 Subject: [PATCH 10/17] bump second pass start steps --- stanza/models/tokenizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index e4a546c357..c95c007113 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -64,7 +64,7 @@ def build_argparse(): parser.add_argument('--no-residual', dest='residual', action='store_false', help="Add linear residual connections") parser.add_argument('--no-hierarchical', dest='hierarchical', action='store_false', help="\"Hierarchical\" RNN tokenizer") parser.add_argument('--no_sentence_second_pass', dest='sentence_second_pass', action='store_false', help="predict the sentences together with tokens instead of after") - parser.add_argument('--second_pass_start_steps', type=int, help="when (how many steps) to start training the second pass classifier", default=256) + parser.add_argument('--second_pass_start_steps', type=int, help="when (how many steps) to start training the second pass classifier", default=5000) parser.add_argument('--hier_invtemp', type=float, default=0.5, help="Inverse temperature used in propagating tokenization predictions between RNN layers") parser.add_argument('--input_dropout', action='store_true', help="Dropout input embeddings as well") parser.add_argument('--conv_res', type=str, default=None, help="Convolutional residual layers for the RNN") From 49ef2e6f8a711a65dec0f972e73ddaec58178fc3 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 9 Sep 2024 10:30:51 -0700 Subject: [PATCH 11/17] fixing some ordering problems --- stanza/models/tokenization/model.py | 5 ++--- stanza/models/tokenizer.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 6e10ccdd07..8c8e033615 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -27,13 +27,12 @@ def __init__(self, args, pretrain, hidden_dim, device=None): def device(self): return next(self.parameters()).device - def forward(self, x, inp0, pad_mask): + def forward(self, x): # map the vocab to pretrain IDs token_ids = [[self.vocab[j.strip()] for j in i] for i in x] embs = self.embeddings(torch.tensor(token_ids, device=self.device)) net = self.emb_proj(embs) net = self.lstm(net)[0] - net[pad_mask] += inp0 return self.ffnn(net) @@ -184,7 +183,7 @@ def forward(self, x, feats, text, detach_2nd_pass=False): [False for _ in range(max_size-len(i))]) pad_mask = torch.tensor(batch_tokens_isntpad) - second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded, inp[draft_preds], pad_mask) + second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded) # # we only add scores for slots for which we have a possible word ending # # i.e. 
its not padding and its also not a middle of rough score's resulting diff --git a/stanza/models/tokenizer.py b/stanza/models/tokenizer.py index c95c007113..badf718439 100644 --- a/stanza/models/tokenizer.py +++ b/stanza/models/tokenizer.py @@ -214,11 +214,11 @@ def train(args): for step in range(1, steps+1): batch = train_batches.next(unit_dropout=args['unit_dropout'], feat_unit_dropout = args['feat_unit_dropout']) - loss = trainer.update(batch) - if trainer.steps > args["second_pass_start_steps"]: trainer.train_2nd_pass = True + loss = trainer.update(batch) + if step % args['report_steps'] == 0: logger.info("Step {:6d}/{:6d} Loss: {:.3f}".format(step, steps, loss)) if args['wandb']: From 58345bae46077e408b15e27dc7f7bdf3cf1c0d7e Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 9 Sep 2024 11:55:31 -0700 Subject: [PATCH 12/17] split on even using draft positions --- stanza/models/tokenization/model.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 8c8e033615..15cff86d6b 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -133,16 +133,21 @@ def forward(self, x, feats, text, detach_2nd_pass=False): nonmwt = F.logsigmoid(-mwt0) mwt = F.logsigmoid(mwt0) + nonsent = F.logsigmoid(-sent0) + sent = F.logsigmoid(sent0) + # use the rough predictions from the char tokenizer to create word tokens # then use those word tokens + contextual/fixed word embeddings to refine # sentence predictions + if self.args["sentence_second_pass"]: # these are the draft predictions for only token-level decisinos # which we can use to slice the text if self.args['use_mwt']: - draft_preds = torch.cat([nontok, tok+nonmwt, tok+mwt], 2).argmax(dim=2) + draft_preds = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2).argmax(dim=2) else: - draft_preds = torch.cat([nontok, tok], 2).argmax(dim=2) + draft_preds = torch.cat([nontok, tok+nonsent, tok+sent], 2).argmax(dim=2) + draft_preds = (draft_preds > 0) # these boolean indicies are *inclusive*, so predict it or not # we need to split on the last token if we want to keep the @@ -191,11 +196,13 @@ def forward(self, x, feats, text, detach_2nd_pass=False): second_pass_chars_align = torch.zeros_like(sent0) second_pass_chars_align[draft_preds] = second_pass_scores[pad_mask] - mix = F.sigmoid(self.sent_2nd_mix) smoothed = self.sent_2nd_smoother( second_pass_chars_align.permute(0,2,1) ).permute(0,2,1) + mix = F.sigmoid(self.sent_2nd_mix) + + # update sent0 value if detach_2nd_pass: sent0 = (1-mix.detach())*sent0 + mix.detach()*smoothed.detach() else: From 337405cd38348db1b14f77933d7a4a536b2ef3fb Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 9 Sep 2024 12:09:09 -0700 Subject: [PATCH 13/17] initially, an identity transform --- stanza/models/tokenization/model.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 8c8e033615..d80cb7bfc2 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -21,7 +21,10 @@ def __init__(self, args, pretrain, hidden_dim, device=None): self.lstm = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True, batch_first=True, num_layers=args['rnn_layers']) - self.ffnn = nn.Linear(hidden_dim*2, 1, bias=False) + # this is zero-initialized to make the second pass initially the id + # function; and then it could change 
only as needed but would otherwise + # be zero + self.final_proj = nn.Parameter(torch.zeros(hidden_dim*2, 1), requires_grad=True) @property def device(self): @@ -33,7 +36,7 @@ def forward(self, x): embs = self.embeddings(torch.tensor(token_ids, device=self.device)) net = self.emb_proj(embs) net = self.lstm(net)[0] - return self.ffnn(net) + return self.final_proj @ net class Tokenizer(nn.Module): @@ -73,7 +76,6 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pre if args['sentence_second_pass']: self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim) - self.sent_2nd_smoother = nn.Conv1d(1, 1, args["sentence_analyzer_kernel"], padding="same", padding_mode="replicate") # initially, don't use 2nd pass that much (this is near 0, meaning it will pretty much # not be mixed in self.sent_2nd_mix = nn.Parameter(torch.full((1,), -5.0), requires_grad=True) @@ -192,14 +194,10 @@ def forward(self, x, feats, text, detach_2nd_pass=False): second_pass_chars_align[draft_preds] = second_pass_scores[pad_mask] mix = F.sigmoid(self.sent_2nd_mix) - smoothed = self.sent_2nd_smoother( - second_pass_chars_align.permute(0,2,1) - ).permute(0,2,1) - if detach_2nd_pass: - sent0 = (1-mix.detach())*sent0 + mix.detach()*smoothed.detach() + sent0 = (1-mix.detach())*sent0 + mix.detach()*second_pass_chars_align.detach() else: - sent0 = (1-mix)*sent0 + mix*smoothed + sent0 = (1-mix)*sent0 + mix*second_pass_chars_align nonsent = F.logsigmoid(-sent0) sent = F.logsigmoid(sent0) From e7ae33a7a0098d96a8b181dff551335083228810 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 9 Sep 2024 12:29:13 -0700 Subject: [PATCH 14/17] whopps, it was applied in the othre direction --- stanza/models/tokenization/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 8a4a147573..020703735e 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -36,7 +36,7 @@ def forward(self, x): embs = self.embeddings(torch.tensor(token_ids, device=self.device)) net = self.emb_proj(embs) net = self.lstm(net)[0] - return self.final_proj @ net + return net @ self.final_proj class Tokenizer(nn.Module): From d6e427695cb9d9062dad631b2a327730dafa0f61 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 9 Sep 2024 20:04:37 -0700 Subject: [PATCH 15/17] add some dropout --- stanza/models/tokenization/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 020703735e..ed2949631e 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -6,7 +6,7 @@ from stanza.models.common.seq2seq_constant import PAD, UNK, UNK_ID class SentenceAnalyzer(nn.Module): - def __init__(self, args, pretrain, hidden_dim, device=None): + def __init__(self, args, pretrain, hidden_dim, device=None, dropout=0): super().__init__() assert pretrain != None, "2nd pass sentence anayzer is missing pretrain word vectors" @@ -21,6 +21,8 @@ def __init__(self, args, pretrain, hidden_dim, device=None): self.lstm = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True, batch_first=True, num_layers=args['rnn_layers']) + self.dropout = nn.Dropout(dropout) + # this is zero-initialized to make the second pass initially the id # function; and then it could change only as needed but would otherwise # be zero @@ -34,7 +36,7 @@ def forward(self, x): # map the vocab to pretrain IDs token_ids = 
[[self.vocab[j.strip()] for j in i] for i in x] embs = self.embeddings(torch.tensor(token_ids, device=self.device)) - net = self.emb_proj(embs) + net = self.dropout(self.emb_proj(embs)) net = self.lstm(net)[0] return net @ self.final_proj @@ -75,7 +77,7 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pre self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) if args['sentence_second_pass']: - self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim) + self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim, dropout) # initially, don't use 2nd pass that much (this is near 0, meaning it will pretty much # not be mixed in self.sent_2nd_mix = nn.Parameter(torch.full((1,), -5.0), requires_grad=True) From 7545f39ff3d1f85e22b8943ea8d7c0e0f482a097 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Mon, 9 Sep 2024 22:00:08 -0700 Subject: [PATCH 16/17] include character information in the tokenizer model --- stanza/models/tokenization/model.py | 51 ++++++++++++++++------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index ed2949631e..587cc9d60d 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -6,7 +6,7 @@ from stanza.models.common.seq2seq_constant import PAD, UNK, UNK_ID class SentenceAnalyzer(nn.Module): - def __init__(self, args, pretrain, hidden_dim, device=None, dropout=0): + def __init__(self, args, pretrain, hidden_dim, device=None): super().__init__() assert pretrain != None, "2nd pass sentence anayzer is missing pretrain word vectors" @@ -17,11 +17,10 @@ def __init__(self, args, pretrain, hidden_dim, device=None, dropout=0): torch.from_numpy(pretrain.emb), freeze=True) self.emb_proj = nn.Linear(pretrain.emb.shape[1], hidden_dim) - self.tok_proj = nn.Linear(hidden_dim*2, hidden_dim) - self.lstm = nn.LSTM(hidden_dim, hidden_dim, bidirectional=True, + self.lstm = nn.LSTM(hidden_dim*3, hidden_dim, bidirectional=True, batch_first=True, num_layers=args['rnn_layers']) - self.dropout = nn.Dropout(dropout) + self.hidden = hidden_dim # this is zero-initialized to make the second pass initially the id # function; and then it could change only as needed but would otherwise @@ -32,12 +31,22 @@ def __init__(self, args, pretrain, hidden_dim, device=None, dropout=0): def device(self): return next(self.parameters()).device - def forward(self, x): + def forward(self, words, tok_embeds, word_tok_mapping, padding_mask): # map the vocab to pretrain IDs - token_ids = [[self.vocab[j.strip()] for j in i] for i in x] + token_ids = [[self.vocab[j.strip()] for j in i] for i in words] embs = self.embeddings(torch.tensor(token_ids, device=self.device)) - net = self.dropout(self.emb_proj(embs)) - net = self.lstm(net)[0] + net = self.emb_proj(embs) + # we want to now concatenate token embeddings with the word embeddings + final_inp = torch.zeros(tok_embeds.size(0), tok_embeds.size(1), + self.hidden*3).to(tok_embeds.device) + final_inp[:,:,:tok_embeds.size(2)] = tok_embeds + # because we want to set the values for that's relavent to the word token embedding + # to True, but everything else to False (including the slots for tok_embs) + final_inp_second_idx = word_tok_mapping.unsqueeze(-1).repeat(1,1,self.hidden*3) + final_inp_second_idx[:,:,:tok_embeds.size(2)] = False + final_inp[final_inp_second_idx] = net[padding_mask].view(-1) + + net = self.lstm(final_inp)[0] return net @ self.final_proj @@ -77,7 +86,7 @@ def __init__(self, 
args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pre self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) if args['sentence_second_pass']: - self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim, dropout) + self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim) # initially, don't use 2nd pass that much (this is near 0, meaning it will pretty much # not be mixed in self.sent_2nd_mix = nn.Parameter(torch.full((1,), -5.0), requires_grad=True) @@ -148,15 +157,15 @@ def forward(self, x, feats, text, detach_2nd_pass=False): # these are the draft predictions for only token-level decisinos # which we can use to slice the text if self.args['use_mwt']: - draft_preds = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2).argmax(dim=2) + draft_pred_locs = torch.cat([nontok, tok+nonsent+nonmwt, tok+sent+nonmwt, tok+nonsent+mwt, tok+sent+mwt], 2).argmax(dim=2) else: - draft_preds = torch.cat([nontok, tok+nonsent, tok+sent], 2).argmax(dim=2) + draft_pred_locs = torch.cat([nontok, tok+nonsent, tok+sent], 2).argmax(dim=2) - draft_preds = (draft_preds > 0) + draft_pred_locs = (draft_pred_locs > 0) # these boolean indicies are *inclusive*, so predict it or not # we need to split on the last token if we want to keep the # final word - draft_preds[:,-1] = True + draft_pred_locs[:,-1] = True # both: batch x [variable: text token count] extracted_tokens = [] @@ -164,7 +173,7 @@ def forward(self, x, feats, text, detach_2nd_pass=False): last = 0 last_batch = -1 - nonzero = draft_preds.nonzero().cpu().tolist() + nonzero = draft_pred_locs.nonzero().cpu().tolist() for i,j in nonzero: if i != last_batch: last_batch = i @@ -190,23 +199,19 @@ def forward(self, x, feats, text, detach_2nd_pass=False): batch_tokens_padded.append(i + [PAD for _ in range(max_size-len(i))]) batch_tokens_isntpad.append([True for _ in range(len(i))] + [False for _ in range(max_size-len(i))]) - pad_mask = torch.tensor(batch_tokens_isntpad) - second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded) - # # we only add scores for slots for which we have a possible word ending - # # i.e. 
its not padding and its also not a middle of rough score's resulting - # # words - second_pass_chars_align = torch.zeros_like(sent0) - second_pass_chars_align[draft_preds] = second_pass_scores[pad_mask] + + # pass the aligned result to the second pass classifier + second_pass_scores = self.sent_2nd_pass_clf(batch_tokens_padded, inp, draft_pred_locs, pad_mask) mix = F.sigmoid(self.sent_2nd_mix) # update sent0 value if detach_2nd_pass: - sent0 = (1-mix.detach())*sent0 + mix.detach()*second_pass_chars_align.detach() + sent0 = (1-mix.detach())*sent0 + mix.detach()*second_pass_scores.detach() else: - sent0 = (1-mix)*sent0 + mix*second_pass_chars_align + sent0 = (1-mix)*sent0 + mix*second_pass_scores nonsent = F.logsigmoid(-sent0) sent = F.logsigmoid(sent0) From 6fe35de72b56cbff318e41b15169961315e61a49 Mon Sep 17 00:00:00 2001 From: Houjun Liu Date: Wed, 11 Sep 2024 22:08:11 -0700 Subject: [PATCH 17/17] add a tiny bit of dropout --- stanza/models/tokenization/model.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/stanza/models/tokenization/model.py b/stanza/models/tokenization/model.py index 587cc9d60d..8ce08775db 100644 --- a/stanza/models/tokenization/model.py +++ b/stanza/models/tokenization/model.py @@ -6,7 +6,7 @@ from stanza.models.common.seq2seq_constant import PAD, UNK, UNK_ID class SentenceAnalyzer(nn.Module): - def __init__(self, args, pretrain, hidden_dim, device=None): + def __init__(self, args, pretrain, hidden_dim, device=None, dropout=0): super().__init__() assert pretrain != None, "2nd pass sentence anayzer is missing pretrain word vectors" @@ -20,6 +20,8 @@ def __init__(self, args, pretrain, hidden_dim, device=None): self.lstm = nn.LSTM(hidden_dim*3, hidden_dim, bidirectional=True, batch_first=True, num_layers=args['rnn_layers']) + self.dropout = nn.Dropout(dropout) + self.hidden = hidden_dim # this is zero-initialized to make the second pass initially the id @@ -46,7 +48,7 @@ def forward(self, words, tok_embeds, word_tok_mapping, padding_mask): final_inp_second_idx[:,:,:tok_embeds.size(2)] = False final_inp[final_inp_second_idx] = net[padding_mask].view(-1) - net = self.lstm(final_inp)[0] + net = self.lstm(self.dropout(final_inp))[0] return net @ self.final_proj @@ -86,7 +88,7 @@ def __init__(self, args, nchars, emb_dim, hidden_dim, dropout, feat_dropout, pre self.mwt_clf2 = nn.Linear(hidden_dim * 2, 1, bias=False) if args['sentence_second_pass']: - self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim) + self.sent_2nd_pass_clf = SentenceAnalyzer(args, pretrain, hidden_dim, dropout) # initially, don't use 2nd pass that much (this is near 0, meaning it will pretty much # not be mixed in self.sent_2nd_mix = nn.Parameter(torch.full((1,), -5.0), requires_grad=True)
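Taken together, the series leaves the second-pass sentence analyzer blended into the character-level sentence logits through a single learned gate that starts almost closed (sigmoid(-5.0) is about 0.007). The short sketch below restates that mixing rule outside the diff context; the names follow the patch, and it is an illustration rather than a drop-in piece of the model:

import torch

# gate parameter exactly as initialized in model.py above: starts near zero
sent_2nd_mix = torch.nn.Parameter(torch.full((1,), -5.0))

def mix_sentence_scores(sent0, second_pass_scores, detach_2nd_pass=False):
    mix = torch.sigmoid(sent_2nd_mix)  # roughly 0.0067 at initialization
    if detach_2nd_pass:
        # before --second_pass_start_steps is reached, no gradients flow into
        # the second-pass analyzer or into the gate itself
        return (1 - mix.detach()) * sent0 + mix.detach() * second_pass_scores.detach()
    return (1 - mix) * sent0 + mix * second_pass_scores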