Scale down the contrastive loss if a final epoch is set
AngledLuffa committed Jan 17, 2025
1 parent df55f3a commit 94372c0
Showing 2 changed files with 7 additions and 2 deletions.
8 changes: 6 additions & 2 deletions stanza/models/constituency/parser_training.py
@@ -593,7 +593,7 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
"""
contrastive_loss = 0.0
contrastive_trees_used = 0
if epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
if epoch <= args['contrastive_final_epoch'] and epoch >= args['contrastive_initial_epoch'] and contrastive_loss_function is not None:
reparsed_results = model.parse_sentences(iter([x.tree for x in training_batch]), model.build_batch_from_trees, len(training_batch), model.predict, keep_state=True, keep_constituents=True)
gold_results = model.analyze_trees([x.tree for x in training_batch], keep_constituents=True, keep_scores=False)

@@ -653,7 +653,11 @@ def contrast_trees(reparsed, gold):
         mse = torch.stack([torch.dot(x.squeeze(0), y.squeeze(0)) for x, y in zip(reparsed_negatives, gold_negatives)])
         device = next(model.parameters()).device
         target = torch.zeros(mse.shape[0]).to(device)
-        contrastive_loss = args['contrastive_learning_rate'] * contrastive_loss_function(mse, target)
+        current_contrastive_lr = args['contrastive_learning_rate']
+        if args['contrastive_final_epoch'] != float('inf'):
+            current_contrastive_lr = current_contrastive_lr * (args['contrastive_final_epoch'] - epoch + 1) / args['contrastive_final_epoch']
+            tlogger.info("Current contrastive learning rate: %f", current_contrastive_lr)
+        contrastive_loss = current_contrastive_lr * contrastive_loss_function(mse, target)
         contrastive_trees_used += len(reparsed_negatives)

         # now we add the state to the trees in the batch
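The new code gates the contrastive term to the [contrastive_initial_epoch, contrastive_final_epoch] window and, once a finite final epoch is set, scales its weight down linearly each epoch. A minimal standalone sketch of the resulting schedule (not stanza code; the function name and example values are illustrative):

def contrastive_weight(epoch, base_lr, initial_epoch, final_epoch):
    """Return the multiplier applied to the contrastive loss at a given epoch."""
    if epoch < initial_epoch or epoch > final_epoch:
        return 0.0                # outside the configured window: term disabled
    if final_epoch == float('inf'):
        return base_lr            # no final epoch set: constant weight
    # the same linear decay as the patch above
    return base_lr * (final_epoch - epoch + 1) / final_epoch

# e.g. base_lr=0.1, initial_epoch=1, final_epoch=5 gives roughly
# 0.1, 0.08, 0.06, 0.04, 0.02 for epochs 1..5 and 0.0 from epoch 6 on
print([contrastive_weight(e, 0.1, 1, 5) for e in range(1, 7)])

Note that the weight reaches base_lr / final_epoch at the final epoch itself and only drops to exactly zero once the epoch passes the window, via the new condition in train_model_one_batch.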
1 change: 1 addition & 0 deletions stanza/models/constituency_parser.py
@@ -555,6 +555,7 @@ def build_argparse():

     parser.add_argument('--contrastive_initial_epoch', default=1, type=int, help='When to start contrastive learning')
     parser.add_argument('--contrastive_learning_rate', default=0.0, type=float, help='Multiplicative factor for constrastive learning')
+    parser.add_argument('--contrastive_final_epoch', default=float('inf'), type=int, help='When to stop contrastive learning. Loss will decay to 0')

     parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
     parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')
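Taken together, the three --contrastive_* flags define the schedule. A hypothetical invocation (the flag names are from the diff above, but the entry point and values are assumptions for illustration, not part of this commit):

python -m stanza.models.constituency_parser \
    --contrastive_initial_epoch 1 \
    --contrastive_learning_rate 0.1 \
    --contrastive_final_epoch 20

With these settings the contrastive loss starts at weight 0.1 in epoch 1, decays linearly to 0.005 at epoch 20, and is dropped entirely afterwards. The default=float('inf') alongside type=int works because argparse applies type only to values supplied on the command line; the infinite default passes through unchanged, which is what the != float('inf') check in parser_training.py relies on.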
