
Commit

Feature/support latest transformers (#104)
* support latest transformers

* evaluation_strategy -> eval_strategy
SeanLee97 authored Nov 14, 2024
1 parent 8137ac0 commit 4f7f963
Showing 2 changed files with 8 additions and 8 deletions.
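
For context: recent transformers releases may forward extra keyword arguments (such as num_items_in_batch in newer versions) to Trainer.compute_loss, which breaks overrides with a strict signature, and TrainingArguments deprecates evaluation_strategy in favor of eval_strategy. The two files below are updated accordingly. A minimal sketch of the signature issue, not taken from this repository (PermissiveTrainer is an illustrative name):

from transformers import Trainer

class PermissiveTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs any extra keyword arguments newer Trainer
        # versions pass to this hook (e.g. num_items_in_batch).
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss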
10 changes: 5 additions & 5 deletions angle_emb/angle.py
@@ -827,7 +827,7 @@ def compute_mlm_loss(self, logits, mask_target_labels):
             ignore_index=self.pad_token_id,
         )
 
-    def compute_loss(self, model, inputs, return_outputs: bool = False):
+    def compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
         """ Compute loss for AnglE.
 
         :param model: Huggingface model.
@@ -942,7 +942,7 @@ def compute_student_loss(self,
             ) / division
         return (loss + compression_loss) / (self.n_layers - 1)
 
-    def compute_loss(self, model, inputs, return_outputs=False):
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
         """ Compute loss for Espresso.
 
         :param model: Huggingface model.
@@ -1431,7 +1431,7 @@ def fit(self,
             warmup_steps: int = 1000,
             logging_steps: int = 10,
             eval_steps: int = 1000,
-            evaluation_strategy: str = 'steps',
+            eval_strategy: str = 'steps',
             save_steps: int = 100,
             save_strategy: str = 'steps',
             save_total_limit: int = 1,
@@ -1462,7 +1462,7 @@ def fit(self,
         :param warmup_steps: int. Default 1000.
         :param logging_steps: int. Default 10.
         :param eval_steps: int. Default 1000.
-        :param evaluation_strategy: str. Default 'steps'.
+        :param eval_strategy: str. Default 'steps'.
         :param save_steps: int. Default 100.
         :param save_strategy: str. Default steps.
         :param save_total_limit: int. Default 10.
@@ -1549,7 +1549,7 @@ def fit(self,
             logging_steps=logging_steps,
             save_steps=save_steps,
             save_strategy=save_strategy,
-            evaluation_strategy=evaluation_strategy if valid_ds is not None else 'no',
+            eval_strategy=eval_strategy if valid_ds is not None else 'no',
             eval_steps=eval_steps,
             output_dir=output_dir,
             save_total_limit=save_total_limit,
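
The fit() changes above simply pass the new keyword name through to TrainingArguments. As a hedged sketch (not part of this commit), code that must support both old and new transformers releases could pick the argument name at runtime; 'ckpts' is an arbitrary output directory used for illustration.

import inspect
from transformers import TrainingArguments

# Use eval_strategy if the installed release exposes it, else fall back.
_params = inspect.signature(TrainingArguments.__init__).parameters
eval_kwarg = 'eval_strategy' if 'eval_strategy' in _params else 'evaluation_strategy'

training_args = TrainingArguments(
    output_dir='ckpts',
    eval_steps=1000,
    **{eval_kwarg: 'steps'},
)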
6 changes: 3 additions & 3 deletions angle_emb/angle_trainer.py
@@ -99,8 +99,8 @@
 parser.add_argument('--save_strategy', type=str, default='steps', choices=['steps', 'epoch', 'no'],
                     help='Specify save_strategy, default steps')
 parser.add_argument('--eval_steps', type=int, default=1000, help='Specify eval_steps, default 1000')
-parser.add_argument('--evaluation_strategy', type=str, default='steps', choices=['steps', 'epoch', 'no'],
-                    help='Specify evaluation_strategy, default steps')
+parser.add_argument('--eval_strategy', type=str, default='steps', choices=['steps', 'epoch', 'no'],
+                    help='Specify eval_strategy, default steps')
 parser.add_argument('--batch_size', type=int, default=32, help='Specify batch size, default 32')
 parser.add_argument('--maxlen', type=int, default=512, help='Specify max length, default 512')
 parser.add_argument('--streaming', action='store_true', default=False,
@@ -307,7 +307,7 @@ def main():
         save_strategy=args.save_strategy,
         save_total_limit=args.save_total_limit,
         eval_steps=args.eval_steps,
-        evaluation_strategy=args.evaluation_strategy,
+        eval_strategy=args.eval_strategy,
         warmup_steps=args.warmup_steps,
         logging_steps=args.logging_steps,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
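
Note that this rename also changes the CLI: launch scripts that pass --evaluation_strategy must switch to --eval_strategy. A small sketch (not in this commit) of how argparse could keep the old flag working as an alias, if backward compatibility were desired:

import argparse

parser = argparse.ArgumentParser()
# Both spellings map to the same destination; --eval_strategy is canonical.
parser.add_argument('--eval_strategy', '--evaluation_strategy', dest='eval_strategy',
                    type=str, default='steps', choices=['steps', 'epoch', 'no'],
                    help='Specify eval_strategy, default steps')

args = parser.parse_args(['--evaluation_strategy', 'epoch'])
print(args.eval_strategy)  # prints: epoch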
