From 94372c0fdb7b6db5d84997d576977d4c21fa2e10 Mon Sep 17 00:00:00 2001
From: John Bauer
Date: Fri, 17 Jan 2025 00:09:46 -0800
Subject: [PATCH] Scale down the contrastive loss if a final epoch is set

---
 stanza/models/constituency/parser_training.py | 8 ++++++--
 stanza/models/constituency_parser.py          | 1 +
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/stanza/models/constituency/parser_training.py b/stanza/models/constituency/parser_training.py
index bddc40e07..d72d48387 100644
--- a/stanza/models/constituency/parser_training.py
+++ b/stanza/models/constituency/parser_training.py
@@ -593,7 +593,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
     """
     contrastive_loss = 0.0
     contrastive_trees_used = 0
-    if epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
+    if epoch <= args['contrastive_final_epoch'] and epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
         reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True, keep_constituents=True)
         gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False)
@@ -653,7 +653,11 @@ def contrast_trees(reparsed, gold):
             mse = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0)) for x, y in zip(reparsed_negatives, gold_negatives)])
             device = next(model.parameters()).device
             target = torch.zeros(mse.shape[0]).to(device)
-            contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(mse, target)
+            current_contrastive_lr = args['contrastive_learning_rate']
+            if args['contrastive_final_epoch'] != float('inf'):
+                current_contrastive_lr = current_contrastive_lr * (args['contrastive_final_epoch'] - epoch + 1) / args['contrastive_final_epoch']
+            tlogger.info("Current contrastive learning rate: %f", current_contrastive_lr)
+            contrastive_loss = current_contrastive_lr * contrastive_loss_function(mse, target)
             contrastive_trees_used += len(reparsed_negatives)

     # now we add the state to the trees in the batch
diff --git a/stanza/models/constituency_parser.py b/stanza/models/constituency_parser.py
index 79d8024b7..e3dc493c4 100644
--- a/stanza/models/constituency_parser.py
+++ b/stanza/models/constituency_parser.py
@@ -555,6 +555,7 @@ def build_argparse():
     parser.add_argument('--contrastive_initial_epoch', default=1, type=int, help='When to start contrastive learning')
     parser.add_argument('--contrastive_learning_rate', default=0.0, type=float, help='Multiplicative factor for constrastive learning')
+    parser.add_argument('--contrastive_final_epoch', default=float('inf'), type=int, help='When to stop contrastive learning. Loss will decay to 0')
     parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
     parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')
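
For reference, the schedule this patch implements: while epoch lies between --contrastive_initial_epoch and --contrastive_final_epoch, the contrastive term is weighted by contrastive_learning_rate * (final_epoch - epoch + 1) / final_epoch, which decays linearly each epoch down to 1/final_epoch of the base rate at the final epoch; the new epoch <= args['contrastive_final_epoch'] guard then switches the loss off entirely. A minimal standalone sketch of that schedule (the helper below is hypothetical, for illustration only, and is not part of the patch):

def scaled_contrastive_lr(base_lr, epoch, initial_epoch=1, final_epoch=float('inf')):
    """Hypothetical mirror of the patched schedule, for illustration only."""
    # outside the [initial_epoch, final_epoch] window, no contrastive loss is applied
    if epoch < initial_epoch or epoch > final_epoch:
        return 0.0
    # with no final epoch configured, the rate stays constant (the old behavior)
    if final_epoch == float('inf'):
        return base_lr
    # linear decay: final_epoch=10 gives base_lr at epoch 1 and base_lr/10 at epoch 10
    return base_lr * (final_epoch - epoch + 1) / final_epoch

For example, base_lr=0.1 with final_epoch=10 yields 0.1, 0.09, ..., 0.01 over epochs 1-10 and 0.0 afterwards. Note that argparse does not pass a default through type, so default=float('inf') survives type=int on --contrastive_final_epoch and acts as the sentinel the != float('inf') check in parser_training.py relies on.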