From aeea516b7a0eebeb8524303a78f8f8ea67ba58df Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Sun, 29 Sep 2024 20:12:58 +0800 Subject: [PATCH] support normal evaluation --- angle_emb/angle.py | 19 +++++++++++++------ angle_emb/angle_trainer.py | 26 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 6693ede..8466425 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -1412,13 +1412,15 @@ def detect_dataset_format(self, ds: Dataset): def fit(self, train_ds: Dataset, valid_ds: Optional[Dataset] = None, + valid_ds_for_callback: Optional[Dataset] = None, batch_size: int = 32, output_dir: Optional[str] = None, epochs: int = 1, learning_rate: float = 1e-5, warmup_steps: int = 1000, logging_steps: int = 10, - eval_steps: Optional[int] = None, + eval_steps: int = 1000, + evaluation_strategy: str = 'steps', save_steps: int = 100, save_strategy: str = 'steps', save_total_limit: int = 10, @@ -1439,13 +1441,17 @@ def fit(self, :param train_ds: Dataset. tokenized train dataset. Required. :param valid_ds: Optional[Dataset]. tokenized valid dataset. Default None. + :param valid_ds_for_callback: Optional[Dataset]. tokenized valid dataset for callback use. + The dataset format should be `DatasetFormats.A`. The Spearman's correlation will be computed + after each epoch of training and the best model will be saved. Default None. :param batch_size: int. Default 32. :param output_dir: Optional[str]. save dir. Default None. :param epochs: int. Default 1. :param learning_rate: float. Default 1e-5. :param warmup_steps: int. Default 1000. :param logging_steps: int. Default 10. - :param eval_steps: Optional[int]. Default None. + :param eval_steps: int. Default 1000. + :param evaluation_strategy: str. Default 'steps'. :param save_steps: int. Default 100. :param save_strategy: str. Default steps. :param save_total_limit: int. Default 10. 
@@ -1491,16 +1497,16 @@ def fit(self, trainer_kwargs = {} callbacks = None - if valid_ds is not None: + if valid_ds_for_callback is not None: # check format - for obj in valid_ds: + for obj in valid_ds_for_callback: if obj['extra']['dataset_format'] != DatasetFormats.A: raise ValueError('Currently only support evaluation for DatasetFormats.A.') break best_ckpt_dir = None if output_dir is not None: best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint') - evaluate_callback = EvaluateCallback(self, valid_ds, + evaluate_callback = EvaluateCallback(self, valid_ds_for_callback, partial(self.evaluate, batch_size=batch_size), save_dir=best_ckpt_dir, push_to_hub=push_to_hub, @@ -1519,7 +1525,7 @@ def fit(self, model=self.backbone, dataset_format=self.detect_dataset_format(train_ds), train_dataset=train_ds, - eval_dataset=None, + eval_dataset=valid_ds, loss_kwargs=loss_kwargs, tokenizer=self.tokenizer, args=TrainingArguments( @@ -1531,6 +1537,7 @@ def fit(self, fp16=fp16, logging_steps=logging_steps, save_strategy=save_strategy, + evaluation_strategy=evaluation_strategy, eval_steps=eval_steps, save_steps=save_steps, output_dir=output_dir, diff --git a/angle_emb/angle_trainer.py b/angle_emb/angle_trainer.py index 87b6f83..1608ac0 100644 --- a/angle_emb/angle_trainer.py +++ b/angle_emb/angle_trainer.py @@ -35,6 +35,12 @@ help='Specify huggingface datasets subset name for valid set, default None') parser.add_argument('--valid_split_name', type=str, default='train', help='Specify huggingface datasets split name for valid set, default `train`') +parser.add_argument('--valid_name_or_path_for_callback', type=str, default=None, + help='Specify huggingface datasets name or local file path for valid set for callback use, default None.') +parser.add_argument('--valid_subset_name_for_callback', type=str, default=None, + help='Specify huggingface datasets subset name for valid set for callback use, default None') +parser.add_argument('--valid_split_name_for_callback', type=str, 
default='train', + help='Specify huggingface datasets split name for valid set for callback use, default `train`') parser.add_argument('--prompt_template', type=str, default=None, help='Specify prompt_template like "xxx: {text}", default None.' 'This prompt will be applied for all text columns.' @@ -228,6 +234,25 @@ def main(): valid_ds = valid_ds[args.valid_split_name or 'train'].map( AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), num_proc=args.workers) + + valid_ds_for_callback = None + if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None: + logger.info('Validation for callback detected, processing validation...') + if os.path.exists(args.valid_name_or_path_for_callback): + valid_ds_for_callback = load_dataset( + 'json', data_files=[args.valid_name_or_path_for_callback], num_proc=args.workers) + else: + if args.valid_subset_name_for_callback is not None: + valid_ds_for_callback = load_dataset( + args.valid_name_or_path_for_callback, + args.valid_subset_name_for_callback, + num_proc=args.workers) + else: + valid_ds_for_callback = load_dataset( + args.valid_name_or_path_for_callback, num_proc=args.workers) + valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map( + AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template), + num_proc=args.workers) argument_kwargs = {} if args.push_to_hub: @@ -256,6 +281,7 @@ def main(): model.fit( train_ds=train_ds, valid_ds=valid_ds, + valid_ds_for_callback=valid_ds_for_callback, output_dir=args.save_dir, batch_size=args.batch_size, epochs=args.epochs,