Skip to content

Commit

Permalink
Add a skeleton for the orthogonal loss
Browse files Browse the repository at this point in the history
  • Loading branch information
AngledLuffa committed Dec 4, 2024
1 parent 449feae commit a87da7d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
22 changes: 17 additions & 5 deletions stanza/models/constituency/parser_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@

TrainItem = namedtuple("TrainItem", ['tree', 'gold_sequence', 'preterminals'])

class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
class EpochStats(namedtuple("EpochStats", ['epoch_loss', 'orthogonal_loss', 'transitions_correct', 'transitions_incorrect', 'repairs_used', 'fake_transitions_used', 'nans'])):
def __add__(self, other):
transitions_correct = self.transitions_correct + other.transitions_correct
transitions_incorrect = self.transitions_incorrect + other.transitions_incorrect
repairs_used = self.repairs_used + other.repairs_used
fake_transitions_used = self.fake_transitions_used + other.fake_transitions_used
epoch_loss = self.epoch_loss + other.epoch_loss
orthogonal_loss = self.orthogonal_loss + other.orthogonal_loss
nans = self.nans + other.nans
return EpochStats(epoch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(epoch_loss, orthogonal_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def evaluate(args, model_file, retag_pipeline):
"""
Expand Down Expand Up @@ -429,10 +430,14 @@ def iterate_training(args, trainer, train_trees, train_sequences, transitions, d
"Epoch %d finished" % trainer.epochs_trained,
"Transitions correct: %s" % epoch_stats.transitions_correct,
"Transitions incorrect: %s" % epoch_stats.transitions_incorrect,
]
if args['orthogonal_learning_rate'] > 0.0 and args['orthogonal_initial_epoch'] <= trainer.epochs_trained:
stats_log_lines.append("Orthogonal loss for epoch: %.5f" % epoch_stats.orthogonal_loss)
stats_log_lines.extend([
"Total loss for epoch: %.5f" % epoch_stats.epoch_loss,
"Dev score (%5d): %8f" % (trainer.epochs_trained, f1),
"Best dev score (%5d): %8f" % (trainer.best_epoch, trainer.best_f1)
]
])
tlogger.info("\n ".join(stats_log_lines))

old_lr = trainer.optimizer.param_groups[0]['lr']
Expand Down Expand Up @@ -525,7 +530,7 @@ def train_model_one_epoch(epoch, trainer, transition_tensors, process_outputs, m

optimizer = trainer.optimizer

epoch_stats = EpochStats(0.0, Counter(), Counter(), Counter(), 0, 0)
epoch_stats = EpochStats(0.0, 0.0, Counter(), Counter(), Counter(), 0, 0)

for batch_idx, interval_start in enumerate(tqdm(interval_starts, postfix="Epoch %d" % epoch)):
batch = epoch_data[interval_start:interval_start+args['train_batch_size']]
Expand Down Expand Up @@ -651,8 +656,12 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
errors = torch.cat(all_errors)
answers = torch.cat(all_answers)

# TODO
orthogonal_loss = 0.0

errors = process_outputs(errors)
tree_loss = model_loss_function(errors, answers)
tree_loss += orthogonal_loss
tree_loss.backward()
if args['watch_regex']:
matched = False
Expand All @@ -670,13 +679,16 @@ def train_model_one_batch(epoch, batch_idx, model, training_batch, transition_te
if not matched:
tlogger.info(" (none found!)")
if torch.any(torch.isnan(tree_loss)):
orthogonal_loss = 0.0
batch_loss = 0.0
nans = 1
else:
batch_loss = tree_loss.item()
if not isinstance(orthogonal_loss, float):
orthogonal_loss = orthogonal_loss.item()
nans = 0

return EpochStats(batch_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)
return EpochStats(batch_loss, orthogonal_loss, transitions_correct, transitions_incorrect, repairs_used, fake_transitions_used, nans)

def run_dev_set(model, retagged_trees, original_trees, args, evaluator=None):
"""
Expand Down
3 changes: 3 additions & 0 deletions stanza/models/constituency_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,9 @@ def build_argparse():
parser.add_argument('--grad_clipping', default=None, type=float, help='Clip abs(grad) to this amount. Use --no_grad_clipping to turn off grad clipping')
parser.add_argument('--no_grad_clipping', action='store_const', const=None, dest='grad_clipping', help='Use --no_grad_clipping to turn off grad clipping')

parser.add_argument('--orthogonal_initial_epoch', default=1, type=int, help='When to start using the orthogonal loss')
parser.add_argument('--orthogonal_learning_rate', default=0.0, type=float, help='Multiplicative factor for the orthogonal loss')

# Large Margin is from Large Margin In Softmax Cross-Entropy Loss
# it did not help on an Italian VIT test
# scores went from 0.8252 to 0.8248
Expand Down

0 comments on commit a87da7d

Please sign in to comment.