diff --git a/mlora.py b/mlora.py
index ec8b4f6e..1ed93a39 100644
--- a/mlora.py
+++ b/mlora.py
@@ -171,8 +171,6 @@ def train(config: Dict[str, any], llm_model: mlora.LLMModel, dispatcher: mlora.D
     step_cnt = 0
     while not dispatcher.check_task_done():
         input: mlora.MultiLoraBatchData = dispatcher.get_train_data()
-        for lora in input.lora_batch_data_config_:
-            all_optimizer[lora.adapter_name_].zero_grad()
 
         step_cnt += 1
 
@@ -201,6 +199,7 @@ def train(config: Dict[str, any], llm_model: mlora.LLMModel, dispatcher: mlora.D
         for lora in input.lora_batch_data_config_:
             if step_cnt % accumulation_step[lora.adapter_name_] == 0:
                 all_optimizer[lora.adapter_name_].step()
+                all_optimizer[lora.adapter_name_].zero_grad()
 
         if step_cnt % config["save_step"] == 0:
             mlora.save_lora_model(llm_model, config, f"{step_cnt}")