support normal evaluation
SeanLee97 committed Sep 29, 2024
1 parent 7104888 commit aeea516
Showing 2 changed files with 39 additions and 6 deletions.
19 changes: 13 additions & 6 deletions angle_emb/angle.py
@@ -1412,13 +1412,15 @@ def detect_dataset_format(self, ds: Dataset):
     def fit(self,
             train_ds: Dataset,
             valid_ds: Optional[Dataset] = None,
+            valid_ds_for_callback: Optional[Dataset] = None,
             batch_size: int = 32,
             output_dir: Optional[str] = None,
             epochs: int = 1,
             learning_rate: float = 1e-5,
             warmup_steps: int = 1000,
             logging_steps: int = 10,
-            eval_steps: Optional[int] = None,
+            eval_steps: int = 1000,
+            evaluation_strategy: str = 'steps',
             save_steps: int = 100,
             save_strategy: str = 'steps',
             save_total_limit: int = 10,
@@ -1439,13 +1441,17 @@ def fit(self,
         :param train_ds: Dataset. tokenized train dataset. Required.
         :param valid_ds: Optional[Dataset]. tokenized valid dataset. Default None.
+        :param valid_ds_for_callback: Optional[Dataset]. tokenized valid dataset for callback use.
+            The dataset format should be `DatasetFormats.A`. Spearman's correlation will be computed
+            after each training epoch and the best model will be saved. Default None.
         :param batch_size: int. Default 32.
         :param output_dir: Optional[str]. save dir. Default None.
         :param epochs: int. Default 1.
         :param learning_rate: float. Default 1e-5.
         :param warmup_steps: int. Default 1000.
         :param logging_steps: int. Default 10.
-        :param eval_steps: Optional[int]. Default None.
+        :param eval_steps: int. Default 1000.
+        :param evaluation_strategy: str. Default 'steps'.
         :param save_steps: int. Default 100.
         :param save_strategy: str. Default steps.
         :param save_total_limit: int. Default 10.
@@ -1491,16 +1497,16 @@ def fit(self,
         trainer_kwargs = {}
 
         callbacks = None
-        if valid_ds is not None:
+        if valid_ds_for_callback is not None:
             # check format
-            for obj in valid_ds:
+            for obj in valid_ds_for_callback:
                 if obj['extra']['dataset_format'] != DatasetFormats.A:
                     raise ValueError('Currently only support evaluation for DatasetFormats.A.')
                 break
             best_ckpt_dir = None
             if output_dir is not None:
                 best_ckpt_dir = os.path.join(output_dir, 'best-checkpoint')
-            evaluate_callback = EvaluateCallback(self, valid_ds,
+            evaluate_callback = EvaluateCallback(self, valid_ds_for_callback,
                                                  partial(self.evaluate, batch_size=batch_size),
                                                  save_dir=best_ckpt_dir,
                                                  push_to_hub=push_to_hub,
@@ -1519,7 +1525,7 @@ def fit(self,
             model=self.backbone,
             dataset_format=self.detect_dataset_format(train_ds),
             train_dataset=train_ds,
-            eval_dataset=None,
+            eval_dataset=valid_ds,
             loss_kwargs=loss_kwargs,
             tokenizer=self.tokenizer,
             args=TrainingArguments(
@@ -1531,6 +1537,7 @@
                 fp16=fp16,
                 logging_steps=logging_steps,
                 save_strategy=save_strategy,
+                evaluation_strategy=evaluation_strategy,
                 eval_steps=eval_steps,
                 save_steps=save_steps,
                 output_dir=output_dir,
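
For reference, a minimal usage sketch of the updated fit() call (not part of this commit): valid_ds now feeds the HF Trainer's eval_dataset for normal loss evaluation, while valid_ds_for_callback drives the Spearman-correlation EvaluateCallback that keeps the best checkpoint. The base checkpoint name, data files, and hyperparameters below are illustrative assumptions.

# Hedged usage sketch; model name and data files are placeholders.
from datasets import load_dataset
from angle_emb import AnglE, AngleDataTokenizer

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', max_length=128)  # assumed base checkpoint
tokenize = AngleDataTokenizer(angle.tokenizer, angle.max_length)

train_ds = load_dataset('json', data_files=['train.jsonl'])['train'].map(tokenize, num_proc=4)
# Goes to the Trainer's eval_dataset (evaluated every `eval_steps`).
valid_ds = load_dataset('json', data_files=['valid.jsonl'])['train'].map(tokenize, num_proc=4)
# Must be DatasetFormats.A; used by EvaluateCallback to compute Spearman's correlation.
valid_ds_for_callback = load_dataset('json', data_files=['sts_valid.jsonl'])['train'].map(tokenize, num_proc=4)

angle.fit(
    train_ds=train_ds,
    valid_ds=valid_ds,                            # new: normal evaluation via eval_dataset
    valid_ds_for_callback=valid_ds_for_callback,  # new: best-checkpoint callback
    output_dir='ckpts/run1',
    batch_size=32,
    epochs=1,
    eval_steps=1000,
    evaluation_strategy='steps',
)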
26 changes: 26 additions & 0 deletions angle_emb/angle_trainer.py
@@ -35,6 +35,12 @@
                     help='Specify huggingface datasets subset name for valid set, default None')
 parser.add_argument('--valid_split_name', type=str, default='train',
                     help='Specify huggingface datasets split name for valid set, default `train`')
+parser.add_argument('--valid_name_or_path_for_callback', type=str, default=None,
+                    help='Specify huggingface datasets name or local file path for valid set for callback use, default None.')
+parser.add_argument('--valid_subset_name_for_callback', type=str, default=None,
+                    help='Specify huggingface datasets subset name for valid set for callback use, default None')
+parser.add_argument('--valid_split_name_for_callback', type=str, default='train',
+                    help='Specify huggingface datasets split name for valid set for callback use, default `train`')
 parser.add_argument('--prompt_template', type=str, default=None,
                     help='Specify prompt_template like "xxx: {text}", default None.'
                          'This prompt will be applied for all text columns.'
@@ -228,6 +234,25 @@ def main():
         valid_ds = valid_ds[args.valid_split_name or 'train'].map(
             AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
             num_proc=args.workers)
 
+    valid_ds_for_callback = None
+    if valid_ds_for_callback is None and args.valid_name_or_path_for_callback is not None:
+        logger.info('Validation for callback detected, processing validation...')
+        if os.path.exists(args.valid_name_or_path_for_callback):
+            valid_ds_for_callback = load_dataset(
+                'json', data_files=[args.valid_name_or_path_for_callback], num_proc=args.workers)
+        else:
+            if args.valid_subset_name_for_callback is not None:
+                valid_ds_for_callback = load_dataset(
+                    args.valid_name_or_path_for_callback,
+                    args.valid_subset_name_for_callback,
+                    num_proc=args.workers)
+            else:
+                valid_ds_for_callback = load_dataset(
+                    args.valid_name_or_path_for_callback, num_proc=args.workers)
+        valid_ds_for_callback = valid_ds_for_callback[args.valid_split_name_for_callback or 'train'].map(
+            AngleDataTokenizer(model.tokenizer, model.max_length, prompt_template=args.prompt_template),
+            num_proc=args.workers)
+
     argument_kwargs = {}
     if args.push_to_hub:
@@ -256,6 +281,7 @@ def main():
     model.fit(
         train_ds=train_ds,
         valid_ds=valid_ds,
+        valid_ds_for_callback=valid_ds_for_callback,
         output_dir=args.save_dir,
         batch_size=args.batch_size,
         epochs=args.epochs,
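
As a usage note for the new trainer flags, here is a hedged sketch of preparing a callback validation file. The text1/text2/label field names follow my reading of DatasetFormats.A and are not shown in this diff; the file names and the command in the trailing comment (entry point and other flags) are illustrative only.

# Hypothetical helper (not part of this commit): write a small DatasetFormats.A-style
# validation file that --valid_name_or_path_for_callback can point at.
import json

rows = [
    {'text1': 'A man is playing a guitar.', 'text2': 'Someone plays a guitar.', 'label': 1.0},
    {'text1': 'A man is playing a guitar.', 'text2': 'A cat sleeps on the sofa.', 'label': 0.0},
]
with open('sts_valid_for_callback.jsonl', 'w', encoding='utf-8') as f:
    for row in rows:
        f.write(json.dumps(row, ensure_ascii=False) + '\n')

# Illustrative invocation (other flags follow the existing angle_trainer.py options):
#   python -m angle_emb.angle_trainer \
#       --train_name_or_path train.jsonl \
#       --valid_name_or_path valid.jsonl \
#       --valid_name_or_path_for_callback sts_valid_for_callback.jsonl \
#       --save_dir ckpts/run1 --batch_size 32 --epochs 1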
