diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 72f66ba..333ece1 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -859,6 +859,7 @@ def compute_loss(self, model, inputs, return_outputs: bool = False): return (loss, outputs) if return_outputs else loss + @torch.no_grad() def prediction_step(self, model, inputs, *args, **kwargs): eval_loss = self.compute_loss(model, inputs, return_outputs=False) return eval_loss, None, None @@ -1542,7 +1543,6 @@ def fit(self, logging_steps=logging_steps, save_strategy=save_strategy, evaluation_strategy=evaluation_strategy, - prediction_loss_only=True, eval_steps=eval_steps, save_steps=save_steps, output_dir=output_dir,