From 278ce89069439b35b6752d7ef52dc391266b6838 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Fri, 19 Jan 2024 23:15:39 +0800 Subject: [PATCH] fix lint error --- mlora/tasks.py | 1 - mlora/train.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mlora/tasks.py b/mlora/tasks.py index 1b3110af..e8171a30 100644 --- a/mlora/tasks.py +++ b/mlora/tasks.py @@ -280,7 +280,6 @@ def evaluate(model: LLMModel, tokenizer: Tokenizer, configs: List[EvaluateConfig], max_seq_len: int = 512): - device = torch.device(model.device_) max_iterations = 0 for config in configs: config.init_task() diff --git a/mlora/train.py b/mlora/train.py index 660f6ccd..86b910b2 100644 --- a/mlora/train.py +++ b/mlora/train.py @@ -1,4 +1,4 @@ -from mlora.modelargs import MultiLoraBatchData, LoraConfig, MixConfig +from mlora.modelargs import LoraConfig, MixConfig from mlora.dispatcher import Dispatcher from mlora.mix_lora import router_loss_factory from mlora.tasks import train_task_factory @@ -75,8 +75,8 @@ def prepare(self, train_paramas: List[torch.Tensor]): def step_lr_scheduler(self, total_epoch, len_dataset): if self.lr_scheduler_ is None: - total_steps = (len_dataset // self.batch_size_)*total_epoch if len_dataset % self.batch_size_ == 0 else ( - len_dataset // self.batch_size_ + 1)*total_epoch + total_steps = (len_dataset // self.batch_size_) * total_epoch if len_dataset % self.batch_size_ == 0 else ( + len_dataset // self.batch_size_ + 1) * total_epoch warmup_steps = self.warmup_steps_ * \ total_steps if isinstance( self.warmup_steps_, float) else self.warmup_steps_