diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index b06b0d5ad..9050a9baf 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -49,7 +49,11 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
     """
 
     def __init__(
-        self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
+        self,
+        finetuning_args: "FinetuningArguments",
+        processor: Optional["ProcessorMixin"],
+        gen_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
     ) -> None:
         if is_transformers_version_greater_than("4.46"):
             kwargs["processing_class"] = kwargs.pop("tokenizer")
@@ -58,6 +62,9 @@ def __init__(
 
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if gen_kwargs is not None:
+            # https://github.com/huggingface/transformers/blob/v4.45.0/src/transformers/trainer_seq2seq.py#L287
+            self._gen_kwargs = gen_kwargs
 
         if processor is not None:
             self.add_callback(SaveProcessorCallback(processor))
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index f41f24cba..5b904244f 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -78,6 +78,12 @@ def run_sft(
         metric_module["compute_metrics"] = ComputeAccuracy()
         metric_module["preprocess_logits_for_metrics"] = eval_logit_processor
 
+    # Keyword arguments for `model.generate`
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
     # Initialize our Trainer
     trainer = CustomSeq2SeqTrainer(
         model=model,
@@ -85,17 +91,12 @@ def run_sft(
         finetuning_args=finetuning_args,
         data_collator=data_collator,
         callbacks=callbacks,
+        gen_kwargs=gen_kwargs,
         **dataset_module,
         **tokenizer_module,
         **metric_module,
     )
 
-    # Keyword arguments for `model.generate`
-    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
-    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
-    gen_kwargs["logits_processor"] = get_logits_processor()
-
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)