From 69d8d63b500913de8e3e6e4dc984175e97dc3f2d Mon Sep 17 00:00:00 2001 From: yezhem Date: Fri, 12 Jul 2024 08:02:01 +0000 Subject: [PATCH] [feature] support pipeline parallelism --- README.md | 37 ++ mlora/config/task.py | 2 +- mlora/executor/__init__.py | 3 +- mlora/executor/dispatcher/__init__.py | 11 +- mlora/executor/dispatcher/dispatcher.py | 2 +- mlora/executor/dispatcher/pipe_dispatcher.py | 120 ++++++ mlora/executor/executor.py | 3 +- mlora/executor/pipe_executor.py | 414 +++++++++++++++++++ mlora/executor/pipeline/__init__.py | 0 mlora/executor/pipeline/function.py | 78 ++++ mlora/executor/pipeline/messages.py | 27 ++ mlora/executor/pipeline/queue.py | 109 +++++ mlora/executor/pipeline/rpc_transport.py | 247 +++++++++++ mlora/executor/pipeline/stream.py | 17 + mlora/executor/pipeline/transport.py | 50 +++ mlora/executor/task/cit_task.py | 1 + mlora/executor/task/cpo_task.py | 5 +- mlora/executor/task/dpo_task.py | 2 + mlora/executor/task/task.py | 10 +- mlora/executor/task/train_task.py | 1 + mlora/model/args.py | 16 + mlora/model/llm/model_llama.py | 5 + mlora/model/llm/model_llm.py | 5 + mlora/utils/cmd.py | 3 + mlora_pp_train.py | 49 +++ mlora_server.py | 16 +- tests/lora_op_test.py | 2 + 27 files changed, 1223 insertions(+), 12 deletions(-) create mode 100644 mlora/executor/dispatcher/pipe_dispatcher.py create mode 100644 mlora/executor/pipe_executor.py create mode 100644 mlora/executor/pipeline/__init__.py create mode 100644 mlora/executor/pipeline/function.py create mode 100644 mlora/executor/pipeline/messages.py create mode 100644 mlora/executor/pipeline/queue.py create mode 100644 mlora/executor/pipeline/rpc_transport.py create mode 100644 mlora/executor/pipeline/stream.py create mode 100644 mlora/executor/pipeline/transport.py create mode 100644 mlora_pp_train.py diff --git a/README.md b/README.md index f529fe12..9309a679 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,43 @@ For further detailed usage information, please use `--help` option: python mlora_train.py --help ``` +## Deployment using pipeline parallelism Similar to the Quickstart, the commands to start training in a two-node environment are as follows: NOTE1: Use the environment variables `MASTER_ADDR`/`MASTER_PORT` to set the master node. NOTE2: Set `--balance` to specify the number of sequential model modules (embedding, decoder layers, norm, and output head) allocated to each rank; see the balance-checking sketch below. + +```bash # on the first node export MASTER_ADDR=master.svc.cluster.local export MASTER_PORT=12355 python mlora_pp_train.py \ --base_model TinyLlama/TinyLlama-1.1B-Chat-v0.4 \ --config demo/lora/lora_case_1.yaml \ --pipeline \ --device "cuda:0" \ --rank 0 \ --balance 12 13 \ --recompute False \ --precision fp32 + +# on the second node export MASTER_ADDR=master.svc.cluster.local export MASTER_PORT=12355 python mlora_pp_train.py \ --base_model TinyLlama/TinyLlama-1.1B-Chat-v0.4 \ --config demo/lora/lora_case_1.yaml \ --pipeline \ --device "cuda:1" \ --rank 1 \ --balance 12 13 \ --recompute False \ --precision fp32 +``` + + ## Quickstart with Docker mLoRA offers an official Docker image for quick start and development. The image is available on the Dockerhub Packages registry. 
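Regarding the `--balance` option above: the values must sum to the total number of modules in the model's sequential decomposition, and `PipeExecutor` asserts this when it partitions the model across ranks. For the TinyLlama-1.1B example, 12 + 13 = 25 accounts for the 22 decoder layers plus the embedding, final norm, and output modules. Below is a minimal sketch for choosing a balance before launching the ranks; the script name is hypothetical, and it relies only on the standard CLI flags plus the `sequential()` method introduced in this patch:

```python
# balance_check.py -- hypothetical helper script, not part of this patch.
# Prints how many sequential modules the base model splits into, so that the
# --balance values across all ranks can be chosen to sum to exactly that count.
import mlora.model
import mlora.utils

if __name__ == "__main__":
    args = mlora.utils.get_cmd_args()  # same flags as mlora_train.py
    mlora.utils.setup_seed(args.seed)

    _tokenizer, model = mlora.model.load_model(args)
    print(f"sequential modules: {len(model.sequential())}")
```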
diff --git a/mlora/config/task.py b/mlora/config/task.py index f95418df..4e46a9fb 100644 --- a/mlora/config/task.py +++ b/mlora/config/task.py @@ -25,7 +25,7 @@ def __init__( self.init(self.__params_map, config) self.adapter_ = adapters[config["adapter"]] - self.dataset_ = datasets[config["dataset"]] + self.dataset_: DatasetConfig | None = datasets[config["dataset"]] class TrainTaskConfig(TaskConfig): diff --git a/mlora/executor/__init__.py b/mlora/executor/__init__.py index 69514566..c7515218 100644 --- a/mlora/executor/__init__.py +++ b/mlora/executor/__init__.py @@ -1,3 +1,4 @@ from .executor import Executor +from .pipe_executor import PipeExecutor -__all__ = ["Executor"] +__all__ = ["Executor", "PipeExecutor"] diff --git a/mlora/executor/dispatcher/__init__.py b/mlora/executor/dispatcher/__init__.py index 2652fe6c..ec400b31 100644 --- a/mlora/executor/dispatcher/__init__.py +++ b/mlora/executor/dispatcher/__init__.py @@ -1,6 +1,13 @@ +from typing import Dict, Type + from .backend_dispatcher import BackendDispatcher from .dispatcher import Dispatcher +from .pipe_dispatcher import PipeDispatcher -DISPATCHER_CLASS = {"default": Dispatcher, "backend": BackendDispatcher} +DISPATCHER_CLASS: Dict[str, Type[Dispatcher]] = { + "default": Dispatcher, + "backend": BackendDispatcher, + "pipe": PipeDispatcher, +} -__all__ = ["Dispatcher", "BackendDispatcher", "DISPATCHER_CLASS"] +__all__ = ["Dispatcher", "BackendDispatcher", "PipeDispatcher", "DISPATCHER_CLASS"] diff --git a/mlora/executor/dispatcher/dispatcher.py b/mlora/executor/dispatcher/dispatcher.py index 5ff8b435..2277297c 100644 --- a/mlora/executor/dispatcher/dispatcher.py +++ b/mlora/executor/dispatcher/dispatcher.py @@ -131,7 +131,7 @@ def _align_batch_tokens( return batch_tokens, batch_masks - def data(self) -> MLoRAData: + def data(self) -> MLoRAData | None: self._dispatch_task_in() batch_tokens: List[Tokens] = [] diff --git a/mlora/executor/dispatcher/pipe_dispatcher.py b/mlora/executor/dispatcher/pipe_dispatcher.py new file mode 100644 index 00000000..80366eb4 --- /dev/null +++ b/mlora/executor/dispatcher/pipe_dispatcher.py @@ -0,0 +1,120 @@ +from typing import List, Set, override + +from mlora.config.dispatcher import DispatcherConfig +from mlora.executor.task import Task +from mlora.model.args import Masks, MLoRAData, MLoRADataConfig, Tokens + +from .backend_dispatcher import BackendDispatcher + + +class PipeDispatcher(BackendDispatcher): + lock_set_: Set[str] + + def __init__(self, config: DispatcherConfig) -> None: + super().__init__(config) + self.lock_set_ = set() + + @override + def _dispatch_task_in(self): + # ready tasks that have been asked to terminate + terminate_task = [task for task in self.ready_ if task.is_terminate()] + self.ready_ = [task for task in self.ready_ if not task.is_terminate()] + + for task in terminate_task: + self.terminate_event_.notify(task) + + # the pipeline only has one running task + while len(self.running_) <= self.concurrency_num_ and len(self.ready_) > 0: + task = self.ready_.pop(0) + self.running_.append(task) + self.running_event_.notify(task) + + def find_the_task(self, task_name: str) -> Task: + # this worker does not really dispatch the task, + # so we just look it up in the ready list + for task in self.ready_: + if task.task_name() != task_name: + continue + return task + raise Exception(f"No such task: {task_name}") + + # if not the head worker, we need to manually dispatch the task + def dispatch_task_to_run(self, task_name: str): + task = self.find_the_task(task_name) + self.running_event_.notify(task) + + def 
dispatch_task_to_ready(self, task_name: str): + task = self.find_the_task(task_name) + self.ready_event_.notify(task) + + def dispatch_task_to_done(self, task_name: str): + task = self.find_the_task(task_name) + self.done_event_.notify(task) + + def dispatch_task_to_terminal(self, task_name: str): + task = self.find_the_task(task_name) + self.terminate_event_.notify(task) + + def dispatch_task_to_step(self, task_name: str): + task = self.find_the_task(task_name) + task.step() + self.step_event_.notify(task) + + def lock_task(self, name: str): + self.lock_set_.add(name) + + def unlock_task(self, name: str): + if name not in self.lock_set_: + return + self.lock_set_.remove(name) + + def is_lock(self, name: str): + return name in self.lock_set_ + + @override + def data(self) -> MLoRAData | None: + self._dispatch_task_in() + + batch_tokens: List[Tokens] = [] + batch_masks: List[Masks] = [] + data_configs: List[MLoRADataConfig] = [] + + can_run_task = list( + filter(lambda task: not self.is_lock(task.task_name()), self.running_) + ) + + if len(can_run_task) == 0: + return None + + # get all train data + start_idx: int = 0 + # pipe dispatcher just run one task + task = can_run_task[0] + + data, data_config = task.data(start_idx) + + # for unlock the task + for item in data_config: + item.task_name_ = task.task_name() + + data_configs.extend(data_config) + batch_tokens.extend(data) + start_idx = start_idx + len(data) + self.lock_task(task.task_name()) + + # post process this batch data + batch_tokens, batch_masks = self._align_batch_tokens(batch_tokens, data_configs) + + return MLoRAData( + batch_tokens=batch_tokens, batch_mask=batch_masks, data_config=data_configs + ) + + def task_step(self, task_name: str): + # in head worker the task must in running + for task in self.running_: + if task.task_name() != task_name: + continue + task.step() + self.step_event_.notify(task) + + self._dispatch_task_out() diff --git a/mlora/executor/executor.py b/mlora/executor/executor.py index 8c79b112..0b6ade5c 100644 --- a/mlora/executor/executor.py +++ b/mlora/executor/executor.py @@ -99,7 +99,8 @@ def execute(self) -> None: mm_collect_step = 0 while not self.dispatcher_.is_done(): - data: MLoRAData = self.dispatcher_.data() + data: MLoRAData | None = self.dispatcher_.data() + assert data is not None torch.cuda.reset_peak_memory_stats(device=self.model_.device_) diff --git a/mlora/executor/pipe_executor.py b/mlora/executor/pipe_executor.py new file mode 100644 index 00000000..03334f74 --- /dev/null +++ b/mlora/executor/pipe_executor.py @@ -0,0 +1,414 @@ +import logging +import time +import uuid +from enum import Enum, auto +from typing import Any, Dict, List, OrderedDict, cast + +import torch + +from mlora.config import MLoRAConfig +from mlora.config.task import TaskConfig +from mlora.model.args import LinearInfo, MLoRAData, ModelData +from mlora.model.llm import LLMModel +from mlora.model.llm.model_llama import precompute_mask +from mlora.model.tokenizer import Tokenizer + +from .dispatcher import DISPATCHER_CLASS, PipeDispatcher +from .executor import Executor +from .pipeline.function import RecvOperator, SendOperator +from .pipeline.messages import PipeMessage, PipeMessageType +from .pipeline.queue import DeviceSwapQueue +from .pipeline.rpc_transport import RpcTransport +from .pipeline.stream import CudaStream +from .task import Task + + +class WorkerRole(Enum): + HEAD = auto() + MID = auto() + TAIL = auto() + + +class PipeExecutor(Executor): + role_: WorkerRole + device_: str + + rank_: int + world_size_: 
int + balance_: List[int] + + # info about model + partial_model_: torch.nn.Sequential + heads_: int + model_name_: str + recompute_: bool + + input_queue_: DeviceSwapQueue + transport_: RpcTransport + + # cache some tensor + backward_cache_: Dict[int, torch.Tensor] + input_cache_: Dict[int, MLoRAData] + + dispatcher_: PipeDispatcher + + def __init__( + self, + model: LLMModel, + tokenizer: Tokenizer, + config: MLoRAConfig, + device: str, + rank: int, + balance: List[int], + recompute: bool = False, + ) -> None: + self.model_ = model + self.tokenizer_ = tokenizer + self.heads_ = self.model_.n_heads_ + self.model_name_ = self.model_.name_or_path_ + + self.device_ = device + self.rank_ = rank + self.balance_ = balance + self.world_size_ = len(balance) + + self.backward_cache_ = {} + self.input_cache_ = {} + + self.recompute_ = recompute + + self.__init_worker() + self.__init_partition() + + # record the default stream to sync + self.default_stream_ = CudaStream(torch.cuda.default_stream(self.device_)) + # init the rpc and wait the cluster node ready + self.transport_ = RpcTransport( + self.rank_, self.world_size_, torch.device(self.device_) + ) + + self.dispatcher_: PipeDispatcher = cast( + PipeDispatcher, DISPATCHER_CLASS["pipe"](config.dispatcher_) + ) + + hook_func = { + "init": self.__task_init_hook, + "running": self.__task_to_running_hook, + "ready": self.__task_to_ready_hook, + "done": self.__task_to_done_hook, + "terminate": self.__task_to_terminate_hook, + } + + for hook, cb in hook_func.items(): + self.dispatcher_.register_hook(hook, cb) + + def __init_worker(self): + # init the different worker + if self.rank_ == 0: + self.role_ = WorkerRole.HEAD + self.input_queue_ = DeviceSwapQueue( + torch.device("cpu"), torch.device(self.device_), 4, "input_data_queue" + ) + self.input_queue_.start() + elif self.rank_ == self.world_size_ - 1: + self.role_ = WorkerRole.TAIL + else: + self.role_ = WorkerRole.MID + + def __init_partition(self) -> None: + balance = self.balance_[self.rank_] + start_module_idx = sum(self.balance_[: self.rank_]) + end_module_idx = start_module_idx + balance + logging.info( + f"RANK-{self.rank_} in device {self.device_} to load module layers " + f"from {start_module_idx} to {end_module_idx}." 
+ ) + + seq_model: torch.nn.Sequential = self.model_.sequential() + assert sum(self.balance_) == len(seq_model) + + self.partial_model_ = torch.nn.Sequential() + + for idx in range(start_module_idx, end_module_idx): + self.partial_model_.append(seq_model[idx]) + + assert len(self.partial_model_) == balance + + del seq_model[:start_module_idx] + del seq_model[balance:] + del self.model_ + + torch.cuda.empty_cache() + + def __head_worker_run(self): + while True: + # we get the model's output, and calc the loss + self.__process_comm() + self.__process_backward() + self.__process_output() + self.__process_input() + time.sleep(1 / 100000) + + def __not_head_worker_run(self): + while True: + self.__process_comm() + self.__process_backward() + self.__process_forward() + time.sleep(1 / 100000) + + def __head_process_step(self, message: PipeMessage): + assert message.model_data_ is not None + train_data: MLoRAData = self.input_cache_[message.model_data_.random_id_] + + # like dpo one task have two data config + task_names = set() + for item in train_data.data_config_: + task_names.add(item.task_name_) + + for task_name in task_names: + self.dispatcher_.task_step(task_name) + self.dispatcher_.unlock_task(task_name) + + assert message.model_data_ is not None + del self.input_cache_[message.model_data_.random_id_] + + def __process_backward(self): + message = self.transport_.recv_message(PipeMessageType.GRADIENTS, block=False) + if message is None: + return + + logging.debug( + f"Recv the gradients - {str(message.msg_id_)[:8]} from {message.src_}." + ) + + msg_id = message.msg_id_ + + assert msg_id in self.backward_cache_ + + phony: torch.Tensor = self.backward_cache_[msg_id] + phony.grad_fn.grad_from_next_worker = message.tensor_data_ # type: ignore + phony.backward() + + del self.backward_cache_[msg_id] + + if self.role_ == WorkerRole.HEAD: + self.__head_process_step(message) + else: + assert message.model_data_ is not None + for task_name in message.model_data_.task_name_: + self.dispatcher_.dispatch_task_to_step(task_name) + + def __process_forward(self): + assert self.role_ != WorkerRole.HEAD + + # recv the tensors from prev-worker + message = self.transport_.recv_message(PipeMessageType.ACTIVATIONS, block=False) + if message is None: + return + + logging.debug( + f"Recv the activations - {str(message.msg_id_)[:8]} from {message.src_}." 
+ ) + + data = RecvOperator.apply( + torch.tensor(1.0, requires_grad=True), self.transport_, message + ) + # we need to wait the default stream calcuate all tensor + # and then send it, so we hook the pre stage fn to poll the stream + data.grad_fn.pre_stage_fn = self.default_stream_.poll # type: ignore + assert message.model_data_ is not None + data = self.__forward(data, message.model_data_) + + self.default_stream_.poll() + assert message.model_data_ is not None + return self.__send_activations(data, message.model_data_) + + def __process_comm(self): + try: + msg: PipeMessage = self.transport_.recv_comm( + PipeMessageType.COMM, block=False + ) + comm_data = msg.comm_data_ + except Exception: + return + + if comm_data["comm"] == "task_add": + self.add_task(comm_data["data"]) + elif comm_data["comm"] == "task_running": + self.dispatcher_.dispatch_task_to_run(comm_data["data"]) + elif comm_data["comm"] == "task_ready": + self.dispatcher_.dispatch_task_to_ready(comm_data["data"]) + elif comm_data["comm"] == "task_done": + self.dispatcher_.dispatch_task_to_done(comm_data["data"]) + elif comm_data["comm"] == "task_terminal": + self.dispatcher_.dispatch_task_to_terminal(comm_data["data"]) + else: + raise NotImplementedError + + def __process_output(self): + assert self.role_ == WorkerRole.HEAD + + # recv the tensors from prev-worker + message = self.transport_.recv_message(PipeMessageType.ACTIVATIONS, block=False) + if message is None: + return + + logging.debug( + f"Recv the activations - {str(message.msg_id_)[:8]} from {message.src_}." + ) + + output: torch.Tensor = RecvOperator.apply( + torch.tensor(1.0, requires_grad=True), self.transport_, message + ) + # we need to wait the default stream calcuate all tensor + # and then send it, so we hook the pre stage fn to poll the stream + output.grad_fn.pre_stage_fn = self.default_stream_.poll # type: ignore + + assert message.model_data_ is not None + train_data: MLoRAData = self.input_cache_[message.model_data_.random_id_] + labels = torch.tensor(train_data.batch_tokens_, dtype=torch.long) + masks = torch.tensor(train_data.batch_mask_) + + total_loss: torch.Tensor | None = None + + for config in train_data.data_config_: + loss = config.loss_fn_(output, labels, masks) + if loss is None: + continue + total_loss = loss if total_loss is None else total_loss + loss + + if total_loss is not None: + total_loss.backward() + + def __process_input(self): + train_data: MLoRAData | None = self.dispatcher_.data() + if train_data is None: + return + # step1. get the model data and execute the forward + tensor_data = torch.tensor( + train_data.batch_tokens_, + dtype=torch.long, + device=self.device_, + requires_grad=False, + ) + + hidden_data = self.__forward(tensor_data, train_data.model_data()) + + # step2. then send the hidden state to next worker + self.default_stream_.poll() + self.__send_activations(hidden_data, train_data.model_data()) + + # step3. 
cache the input, we need it to calc the loss + self.input_cache_[train_data.model_data().random_id_] = train_data + + def __send_activations(self, tensor_data: torch.Tensor, batch_data: ModelData): + assert isinstance(tensor_data, torch.Tensor) + assert batch_data is None or isinstance(batch_data, ModelData) + + msg_id = uuid.uuid4().int + assert msg_id not in self.backward_cache_ + + phony: torch.Tensor = SendOperator.apply( + torch.tensor(1.0, requires_grad=True), + tensor_data, + self.transport_, + msg_id, + batch_data, + ) + + self.backward_cache_[msg_id] = phony + + def __send_comm(self, data: Any): + self.transport_.send_comm(PipeMessageType.COMM, data) + + def __forward(self, tensor_data: torch.Tensor, batch_data: ModelData): + mask = precompute_mask( + tensor_data, self.heads_, self.device_, batch_data.batch_mask_ + ) + data = (tensor_data, mask, batch_data, self.recompute_) + + for seq in self.partial_model_: + data = seq.forward(data) + + return data[0] + + def execute(self) -> None: + if self.role_ == WorkerRole.HEAD: + self.__head_worker_run() + elif self.role_ == WorkerRole.MID or self.role_ == WorkerRole.TAIL: + self.__not_head_worker_run() + else: + raise NotImplementedError + + def add_task(self, config: TaskConfig): + if self.role_ != WorkerRole.TAIL: + self.__send_comm({"comm": "task_add", "data": config}) + if self.role_ != WorkerRole.HEAD: + # only the head worker need to load dataset + config.dataset_ = None + self.dispatcher_.add_task(config, self.model_name_) + + def __task_init_hook(self, task: Task): + logging.info( + f"Init {task.task_type()} : {task.task_name()} " + + f"task with adapters: {task.adapter_name()}" + ) + task.prepare(self.__linears_info(), self.tokenizer_) + + def __task_to_running_hook(self, task: Task): + logging.info(f"Task to running, need to load adapters: {task.adapter_name()}") + if self.role_ != WorkerRole.TAIL: + self.__send_comm({"comm": "task_running", "data": task.task_name()}) + + task.switch_device(self.device_) + for adapter_model in task.adapter_model(): + for partial_layer in self.partial_model_: + if partial_layer.name() != "Decoder": + continue + partial_layer.wrapper_module_.load_adapter(adapter_model) + + def __task_to_ready_hook(self, task: Task): + logging.info(f"Base model offload adapters: {task.adapter_name()}") + if self.role_ != WorkerRole.TAIL: + self.__send_comm({"comm": "task_ready", "data": task.task_name()}) + + for adapter_name in task.adapter_name(): + for partial_layer in self.partial_model_: + if partial_layer.name() != "Decoder": + continue + partial_layer.wrapper_module_.offload_adapter(adapter_name) + task.switch_device("cpu") + + def __task_to_done_hook(self, task: Task): + logging.info(f"Finish and base model offload adapter - {task.adapter_name()}") + if self.role_ != WorkerRole.TAIL: + self.__send_comm({"comm": "task_done", "data": task.task_name()}) + + task.switch_device("cpu") + for adapter_name in task.adapter_name(): + for partial_layer in self.partial_model_: + if partial_layer.name() != "Decoder": + continue + partial_layer.wrapper_module_.offload_adapter(adapter_name) + task.done() + + def __task_to_terminate_hook(self, task: Task): + logging.info(f"Task - {task.task_name()} terminate.") + if self.role_ != WorkerRole.TAIL: + self.__send_comm({"comm": "task_terminal", "data": task.task_name()}) + + task.switch_device("cpu") + for adapter_name in task.adapter_name(): + for partial_layer in self.partial_model_: + if partial_layer.name() != "Decoder": + continue + 
partial_layer.wrapper_module_.offload_adapter(adapter_name) + task.terminate() + + def __linears_info(self) -> OrderedDict[str, LinearInfo]: + ret_val = OrderedDict() + for module in self.partial_model_: + if module.name() != "Decoder": + continue + ret_val.update(module.wrapper_module_.linears_info()) + return ret_val diff --git a/mlora/executor/pipeline/__init__.py b/mlora/executor/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mlora/executor/pipeline/function.py b/mlora/executor/pipeline/function.py new file mode 100644 index 00000000..80b20d83 --- /dev/null +++ b/mlora/executor/pipeline/function.py @@ -0,0 +1,78 @@ +import logging + +import torch + +from mlora.model.args import ModelData + +from .messages import PipeMessage, PipeMessageType +from .transport import Transport + + +class SendOperator(torch.autograd.Function): + # helper to reduce the activation memory + @staticmethod + def forward( + ctx, + phony: torch.Tensor, + tensor_data: torch.Tensor, + transport: Transport, + msg_id: int, + input_args: ModelData, + ): + assert isinstance(tensor_data, torch.Tensor) + + msg = PipeMessage( + src_=transport.worker_name, + dst_=transport.next_worker_name, + msg_type_=PipeMessageType.ACTIVATIONS, + msg_id_=msg_id, + tensor_data_=tensor_data, + model_data_=input_args, + comm_data_=None, + ) + transport.send_message(msg, False) + + return phony + + @staticmethod + def backward(ctx, grad_output): + assert ctx.grad_from_next_worker is not None + + return (None, ctx.grad_from_next_worker, None, None, None) + + +class RecvOperator(torch.autograd.Function): + # backward will auto send the grad to pre worker + @staticmethod + def forward( + ctx, phony: torch.Tensor, transport: Transport, msg: PipeMessage + ) -> torch.Tensor: + assert msg.msg_type_ == PipeMessageType.ACTIVATIONS + assert isinstance(msg.tensor_data_, torch.Tensor) + + ctx.msg_id_ = msg.msg_id_ + ctx.transport_ = transport + ctx.model_data_ = msg.model_data_ + + return msg.tensor_data_ * phony + + @staticmethod + def backward(ctx, *grad_outputs: torch.Tensor): + transport: Transport = ctx.transport_ + if hasattr(ctx, "pre_stage_fn") and ctx.pre_stage_fn is not None: + ctx.pre_stage_fn() + + logging.debug(f"Send the gradients to {transport.prev_worker_name}") + transport.send_message( + PipeMessage( + src_=transport.worker_name, + dst_=transport.prev_worker_name, + msg_type_=PipeMessageType.GRADIENTS, + msg_id_=ctx.msg_id_, + tensor_data_=grad_outputs[0], + model_data_=ctx.model_data_, + comm_data_=None, + ) + ) + + return (None, None, None) diff --git a/mlora/executor/pipeline/messages.py b/mlora/executor/pipeline/messages.py new file mode 100644 index 00000000..0f98267e --- /dev/null +++ b/mlora/executor/pipeline/messages.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any + +import torch + +from mlora.model.args import ModelData + + +class PipeMessageType(Enum): + ACTIVATIONS = "ACTIVATIONS" + GRADIENTS = "GRADIENTS" + COMM = "COMM" + + +@dataclass() +class PipeMessage: + src_: str + dst_: str + + msg_type_: PipeMessageType + msg_id_: int + + tensor_data_: torch.Tensor | None + model_data_: ModelData | None + + comm_data_: Any diff --git a/mlora/executor/pipeline/queue.py b/mlora/executor/pipeline/queue.py new file mode 100644 index 00000000..f0ceefcb --- /dev/null +++ b/mlora/executor/pipeline/queue.py @@ -0,0 +1,109 @@ +import logging +from queue import Queue +from threading import Thread +from typing import Optional + +import torch + +from 
.messages import PipeMessage +from .stream import CudaStream + + +class DeviceSwapQueue: + copy_stream_: CudaStream + + def __init__( + self, + source_device: torch.device, + target_device: torch.device, + target_size: int = 0, + queue_name: str = "default", + ) -> None: + source_device_is_cpu: bool = ( + True if source_device == torch.device("cpu") else False + ) + target_device_is_cpu: bool = ( + True if target_device == torch.device("cpu") else False + ) + + assert source_device_is_cpu ^ target_device_is_cpu + + if source_device_is_cpu: + self.copy_stream_: CudaStream = CudaStream(torch.cuda.Stream(target_device)) + else: + self.copy_stream_: CudaStream = CudaStream(torch.cuda.Stream(source_device)) + + self.target_device_: torch.device = target_device + self.source_device_: torch.device = source_device + # TODO: change the size by the size of avaliable gpu memory + self.src_queue_: Queue = Queue() + self.dst_queue_: Queue = Queue(target_size) + + self.queue_name_: str = queue_name + + self.stop_: bool = False + + def swap_thread_loop(self): + try: + msg: PipeMessage = self.src_queue_.get(block=True, timeout=0.001) + except Exception: + return + logging.debug( + f"{self.queue_name_} swap the message - {str(msg.msg_id_)[:8]} start." + ) + + # must ensure the msg.tensor_data_ sync done + with torch.cuda.stream(self.copy_stream_.stream_): + # do not use the pined_memory maybe can speedup + # need more test + assert msg.tensor_data_ is not None + copy_tensor = ( + torch.zeros_like(msg.tensor_data_, device=self.target_device_) + .copy_(msg.tensor_data_, non_blocking=True) + .detach() + ) + msg.tensor_data_ = copy_tensor + # msg.tensor_data_ = msg.tensor_data_.to( + # self.target_device_, non_blocking=True).detach() + + self.copy_stream_.poll() + + logging.debug( + f"{self.queue_name_} swap the message - {str(msg.msg_id_)[:8]} device end." 
+ ) + self.dst_queue_.put(msg, block=True) + + def swap_thread(self): + logging.info(f"DeviceSwapQueue - {self.queue_name_} start.") + while not self.stop_ or not self.src_queue_.empty(): + self.swap_thread_loop() + logging.info(f"DeviceSwapQueue - {self.queue_name_} stop.") + + def start(self): + self.swap_thread_ = Thread(target=self.swap_thread) + self.swap_thread_.start() + + def stop(self): + self.stop_ = True + self.swap_thread_.join() + + def get(self) -> PipeMessage: + return self.dst_queue_.get(block=True) + + def get_waitime(self, timeout: int = 10) -> Optional[PipeMessage]: + try: + return self.dst_queue_.get(block=True, timeout=timeout) + except Exception: + return None + + def get_nowait(self) -> Optional[PipeMessage]: + try: + return self.dst_queue_.get_nowait() + except Exception: + return None + + def put(self, msg: PipeMessage): + self.src_queue_.put(msg) + + def empty(self) -> bool: + return self.src_queue_.empty() and self.dst_queue_.empty() diff --git a/mlora/executor/pipeline/rpc_transport.py b/mlora/executor/pipeline/rpc_transport.py new file mode 100644 index 00000000..63743b4a --- /dev/null +++ b/mlora/executor/pipeline/rpc_transport.py @@ -0,0 +1,247 @@ +import logging +import os +import queue +import uuid +from threading import Thread +from typing import Any, Dict, override + +import torch +import torch.distributed.rpc + +from .messages import PipeMessage, PipeMessageType +from .queue import DeviceSwapQueue +from .transport import Transport + +# queues keyed by message type +# the recv/send queues automatically move tensors to the right device +RPCMessageRecvQueues: Dict[PipeMessageType, DeviceSwapQueue] = {} + +RPCMessageSendQueues: Dict[PipeMessageType, DeviceSwapQueue] = {} + +RPCCOMMMessageRecvQueues: Dict[PipeMessageType, queue.Queue] = {} + +RPCCOMMMessageSendQueues: Dict[PipeMessageType, queue.Queue] = {} + + +def rpc_push_device_swap_queue(msg: PipeMessage) -> None: + global RPCMessageRecvQueues + + assert ( + msg.msg_type_ in RPCMessageRecvQueues + ), f"No such message type: {msg.msg_type_.value}" + assert RPCMessageRecvQueues[msg.msg_type_] is not None + + logging.debug(f"RpcTransport async recv the message: {str(msg.msg_id_)[:8]}.") + RPCMessageRecvQueues[msg.msg_type_].put(msg) + + +def rpc_push_comm_queue(msg: PipeMessage) -> None: + global RPCCOMMMessageRecvQueues + + assert ( + msg.msg_type_ in RPCCOMMMessageRecvQueues + ), f"No such comm message type: {msg.msg_type_.value}" + assert RPCCOMMMessageRecvQueues[msg.msg_type_] is not None + + logging.debug(f"RpcTransport async recv the comm message: {str(msg.msg_id_)[:8]}.") + RPCCOMMMessageRecvQueues[msg.msg_type_].put(msg) + + +# rpc transport thread +class RpcTransport(Transport): + rank_: int + world_size_: int + worker_device_: torch.device + + stop_: bool + activations_send_thread_: Thread + gradients_send_thread_: Thread + comm_send_thread_: Thread + + def __init__(self, rank: int, world_size: int, worker_device: torch.device) -> None: + super().__init__(rank, world_size, worker_device) + + self.stop_: bool = False + + self.__init_device_swap_queue() + self.__init_comm_queue() + self.__init_background_thread() + self.__init_rpc() + + def __init_rpc(self) -> None: + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "12355" + + assert self.rank_ > -1 + assert self.world_size_ > -1 + assert self.worker_device_ is not None + + # blocks until every rank in the world size has joined the group + 
torch.distributed.rpc.init_rpc( + f"worker-{self.rank_}", rank=self.rank_, world_size=self.world_size_ + ) + + logging.info(f"Init rpc with rank {self.rank_} world_size: {self.world_size_}") + + def __init_device_swap_queue(self): + cpu_device = torch.device("cpu") + + global RPCMessageSendQueues + for key in [PipeMessageType.ACTIVATIONS, PipeMessageType.GRADIENTS]: + RPCMessageSendQueues[key] = DeviceSwapQueue( + self.worker_device_, cpu_device, queue_name=f"{key.value}_send" + ) + RPCMessageSendQueues[key].start() + + global RPCMessageRecvQueues + for key in [PipeMessageType.ACTIVATIONS, PipeMessageType.GRADIENTS]: + RPCMessageRecvQueues[key] = DeviceSwapQueue( + cpu_device, self.worker_device_, queue_name=f"{key.value}_recv" + ) + RPCMessageRecvQueues[key].start() + + def __init_comm_queue(self): + global RPCCOMMMessageSendQueues + for key in [PipeMessageType.COMM]: + RPCCOMMMessageSendQueues[key] = queue.Queue() + + global RPCCOMMMessageRecvQueues + for key in [PipeMessageType.COMM]: + RPCCOMMMessageRecvQueues[key] = queue.Queue() + + def __init_background_thread(self): + self.gradients_send_thread_ = Thread( + target=self.__send_loop, args=(PipeMessageType.GRADIENTS,) + ) + self.activations_send_thread_ = Thread( + target=self.__send_loop, args=(PipeMessageType.ACTIVATIONS,) + ) + self.comm_send_thread_ = Thread( + target=self.__comm_send_loop, args=(PipeMessageType.COMM,) + ) + + self.gradients_send_thread_.start() + self.activations_send_thread_.start() + self.comm_send_thread_.start() + + def __send_loop(self, msg_type: PipeMessageType): + global RPCMessageSendQueues + send_queue: DeviceSwapQueue = RPCMessageSendQueues[msg_type] + assert send_queue is not None + + while not self.stop_ or not send_queue.empty(): + msg = send_queue.get_waitime() + if msg is None: + continue + assert msg.tensor_data_ is not None + assert msg.tensor_data_.device == torch.device("cpu") + logging.debug( + f"RpcTransport async send the message: {str(msg.msg_id_)[:8]} " + f"to {msg.dst_}." + ) + torch.distributed.rpc.rpc_async( + msg.dst_, rpc_push_device_swap_queue, args=(msg,) + ) + + def __comm_send_loop(self, msg_type: PipeMessageType): + global RPCCOMMMessageSendQueues + send_queue: queue.Queue = RPCCOMMMessageSendQueues[msg_type] + assert send_queue is not None + + while not self.stop_ or not send_queue.empty(): + try: + msg = send_queue.get(block=True, timeout=10) + except Exception: + continue + + logging.debug( + f"RpcTransport async send the message: {str(msg.msg_id_)[:8]}" + f" to {msg.dst_}." 
+ ) + torch.distributed.rpc.rpc_async(msg.dst_, rpc_push_comm_queue, args=(msg,)) + + def __stop_send_loop(self): + global RPCMessageRecvQueues + global RPCMessageSendQueues + + # stop the recv queues first + for key in RPCMessageRecvQueues: + RPCMessageRecvQueues[key].stop() + + # then stop the send queues + for key in RPCMessageSendQueues: + RPCMessageSendQueues[key].stop() + + self.stop_ = True + self.activations_send_thread_.join() + self.gradients_send_thread_.join() + self.comm_send_thread_.join() + + def __stop_rpc(self): + torch.distributed.rpc.shutdown() + + def stop(self): + self.__stop_send_loop() + self.__stop_rpc() + + @override + def recv_message( + self, msg_type: PipeMessageType, block: bool = False + ) -> PipeMessage | None: + global RPCMessageRecvQueues + + assert msg_type in RPCMessageRecvQueues + recv_queue: DeviceSwapQueue = RPCMessageRecvQueues[msg_type] + + if block: + return recv_queue.get() + else: + return recv_queue.get_nowait() + + @override + def send_message(self, msg: PipeMessage, sync: bool = False) -> None: + assert not sync, "RPC transport does not support sync == True!" + + global RPCMessageSendQueues + assert msg.msg_type_ in RPCMessageSendQueues + send_queue: DeviceSwapQueue = RPCMessageSendQueues[msg.msg_type_] + send_queue.put(msg) + + @override + def recv_comm(self, msg_type: PipeMessageType, block: bool = False) -> PipeMessage: + global RPCCOMMMessageRecvQueues + + assert msg_type in RPCCOMMMessageRecvQueues + recv_queue: queue.Queue = RPCCOMMMessageRecvQueues[msg_type] + + if block: + return recv_queue.get() + else: + return recv_queue.get_nowait() + + @override + def send_comm( + self, msg_type: PipeMessageType, data: Any, sync: bool = False + ) -> None: + assert not sync, "RPC transport does not support sync == True!" 
+ + msg_id = uuid.uuid4().int + + msg = PipeMessage( + src_=self.worker_name, + dst_=self.next_worker_name, + msg_type_=msg_type, + msg_id_=msg_id, + tensor_data_=None, + model_data_=None, + comm_data_=data, + ) + + global RPCCOMMMessageSendQueues + assert msg.msg_type_ in RPCCOMMMessageSendQueues + + send_queue: queue.Queue = RPCCOMMMessageSendQueues[msg.msg_type_] + send_queue.put(msg) diff --git a/mlora/executor/pipeline/stream.py b/mlora/executor/pipeline/stream.py new file mode 100644 index 00000000..e1e7217a --- /dev/null +++ b/mlora/executor/pipeline/stream.py @@ -0,0 +1,17 @@ +import time + +import torch + + +class CudaStream: + stream_: torch.cuda.Stream + event_: torch.cuda.Event + + def __init__(self, stream: torch.cuda.Stream) -> None: + self.stream_ = stream + self.event_ = torch.cuda.Event() + + def poll(self) -> None: + self.event_.record(stream=self.stream_) + while not self.event_.query(): + time.sleep(1 / 1000) diff --git a/mlora/executor/pipeline/transport.py b/mlora/executor/pipeline/transport.py new file mode 100644 index 00000000..6e6b99ac --- /dev/null +++ b/mlora/executor/pipeline/transport.py @@ -0,0 +1,50 @@ +from abc import ABC, abstractmethod +from typing import Any + +import torch +import torch.distributed.rpc + +from .messages import PipeMessage, PipeMessageType + + +class Transport(ABC): + rank_: int + world_size_: int + worker_device_: torch.device + + @property + def next_worker_name(self) -> str: + return f"worker-{(self.rank_ + 1) % self.world_size_}" + + @property + def prev_worker_name(self) -> str: + return f"worker-{(self.rank_ - 1) % self.world_size_}" + + @property + def worker_name(self) -> str: + return f"worker-{self.rank_}" + + @abstractmethod + def recv_message( + self, msg_type: PipeMessageType, block: bool = False + ) -> PipeMessage | None: + pass + + @abstractmethod + def send_message(self, msg: PipeMessage, sync: bool = False) -> None: + pass + + @abstractmethod + def recv_comm(self, msg_type: PipeMessageType, block: bool = False) -> PipeMessage: + pass + + @abstractmethod + def send_comm( + self, msg_type: PipeMessageType, data: Any, sync: bool = False + ) -> None: + pass + + def __init__(self, rank: int, world_size: int, worker_device: torch.device) -> None: + self.rank_ = rank + self.world_size_ = world_size + self.worker_device_ = worker_device diff --git a/mlora/executor/task/cit_task.py b/mlora/executor/task/cit_task.py index 603db94b..600218b7 100644 --- a/mlora/executor/task/cit_task.py +++ b/mlora/executor/task/cit_task.py @@ -108,6 +108,7 @@ def loss_fn( end_idx, self._expand_batch_tokens, loss_fn, + self.task_name(), ) return ret_tokens, [data_config] diff --git a/mlora/executor/task/cpo_task.py b/mlora/executor/task/cpo_task.py index 5281cc36..b5a82e97 100644 --- a/mlora/executor/task/cpo_task.py +++ b/mlora/executor/task/cpo_task.py @@ -100,7 +100,9 @@ def loss_fn( self.context_.path_ + "_loss", loss.item(), self.now_step_ ) mlora.profiler.metric_log( - self.context_.path_ + "_loss_prefer", loss_prefer.item(), self.now_step_ + self.context_.path_ + "_loss_prefer", + loss_prefer.mean().item(), + self.now_step_, ) logging.info(f"Adapter {self.context_.name_} loss: {loss}") return loss @@ -112,6 +114,7 @@ def loss_fn( end_idx, self._expand_batch_tokens, loss_fn, + self.task_name(), ) return ret_tokens, [data_config] diff --git a/mlora/executor/task/dpo_task.py b/mlora/executor/task/dpo_task.py index e8066809..0dbfc1f2 100644 --- a/mlora/executor/task/dpo_task.py +++ b/mlora/executor/task/dpo_task.py @@ -185,6 +185,7 @@ def 
loss_fn( ref_end_idx, self._expand_batch_tokens, lambda *_: None, + self.task_name(), ) policy_model_config = MLoRADataConfig( @@ -194,6 +195,7 @@ def loss_fn( policy_end_idx, self._expand_batch_tokens, loss_fn, + self.task_name(), ) return ret_tokens, [ref_model_config, policy_model_config] diff --git a/mlora/executor/task/task.py b/mlora/executor/task/task.py index 5e1b2f74..69f9995a 100644 --- a/mlora/executor/task/task.py +++ b/mlora/executor/task/task.py @@ -33,8 +33,6 @@ class Task: def __init__(self, config: TaskConfig, llm_name: str) -> None: self.config_ = config - self.prompter_ = PrompterFactory.create(config.dataset_) - self.data_ = [] self.now_data_idx_ = 0 self.now_step_ = 1 @@ -79,6 +77,14 @@ def _pre_dataset(self): "sort": lambda data: data.sort(), } + if self.config_.dataset_ is None: + logging.info( + "Task dataset is empty, maybe in pipeline we do not load dataset." + ) + return + + self.prompter_ = PrompterFactory.create(self.config_.dataset_) + logging.info(f"Task load data from {self.config_.dataset_.data_path_}") data = load_dataset( "json", data_files={"data_points": self.config_.dataset_.data_path_} diff --git a/mlora/executor/task/train_task.py b/mlora/executor/task/train_task.py index a68381c0..f1307348 100644 --- a/mlora/executor/task/train_task.py +++ b/mlora/executor/task/train_task.py @@ -153,6 +153,7 @@ def loss_fn( end_idx, self._expand_batch_tokens, loss_fn, + self.task_name(), ) return ret_tokens, [data_config] diff --git a/mlora/model/args.py b/mlora/model/args.py index a3747cd9..004e8d82 100644 --- a/mlora/model/args.py +++ b/mlora/model/args.py @@ -1,4 +1,5 @@ import logging +import uuid from dataclasses import dataclass from typing import Callable, List, Optional, Tuple @@ -87,6 +88,10 @@ class ModelData: enable_checkpoint_: bool + # the flag for serialize + random_id_: int + task_name_: List[str] + class MLoRADataConfig: adapter_name_: str @@ -102,6 +107,8 @@ class MLoRADataConfig: [torch.Tensor, torch.Tensor, torch.Tensor], Optional[torch.Tensor] ] + task_name_: str + def __init__( self, adapter_name: str, @@ -110,6 +117,7 @@ def __init__( end_idx: int, expand_fn: Callable, loss_fn: Callable, + task_name: str, ) -> None: self.adapter_name_ = adapter_name self.adapter_type_ = adapter_type @@ -119,6 +127,8 @@ def __init__( self.expand_fn_ = expand_fn self.loss_fn_ = loss_fn + self.task_name_ = task_name + def model_data_config(self) -> ModelDataConfig: return ModelDataConfig( adapter_name_=self.adapter_name_, @@ -134,6 +144,9 @@ class MLoRAData: batch_mask_: List[Masks] data_config_: List[MLoRADataConfig] + # the flag for serialize + random_id_: int + def __init__( self, batch_tokens: List[Tokens], @@ -143,6 +156,7 @@ def __init__( self.batch_tokens_ = batch_tokens self.batch_mask_ = batch_mask self.data_config_ = data_config + self.random_id_ = uuid.uuid4().int def model_data(self) -> ModelData: return ModelData( @@ -150,6 +164,8 @@ def model_data(self) -> ModelData: batch_mask_=self.batch_mask_, data_config_=[config.model_data_config() for config in self.data_config_], enable_checkpoint_=True, + task_name_=[config.task_name_ for config in self.data_config_], + random_id_=self.random_id_, ) def batch_size(self) -> int: diff --git a/mlora/model/llm/model_llama.py b/mlora/model/llm/model_llama.py index 947b8e32..1a464940 100644 --- a/mlora/model/llm/model_llama.py +++ b/mlora/model/llm/model_llama.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Tuple, override import torch +from torch.nn.modules import Sequential from transformers import 
AutoConfig, AutoModelForCausalLM from mlora.model.args import LinearInfo, LLMModelArgs, Masks, ModelData @@ -327,3 +328,7 @@ def linears_info(self) -> OrderedDict[str, LinearInfo]: continue ret_val.update(module.wrapper_module_.linears_info()) return ret_val + + @override + def sequential(self) -> Sequential: + return self.seq_module_ diff --git a/mlora/model/llm/model_llm.py b/mlora/model/llm/model_llm.py index 26c14f68..95f2980a 100644 --- a/mlora/model/llm/model_llm.py +++ b/mlora/model/llm/model_llm.py @@ -2,6 +2,8 @@ from collections import OrderedDict from typing import List, Optional +import torch + from mlora.model.args import LinearInfo, ModelData from mlora.model.modules import AdapterModel @@ -33,3 +35,6 @@ def offload_adapter(self, adapter_name: str): ... @abstractmethod def linears_info(self) -> OrderedDict[str, LinearInfo]: ... + + @abstractmethod + def sequential(self) -> torch.nn.Sequential: ... diff --git a/mlora/utils/cmd.py b/mlora/utils/cmd.py index 9354340b..9cf45de1 100644 --- a/mlora/utils/cmd.py +++ b/mlora/utils/cmd.py @@ -34,6 +34,9 @@ def _add_base_cmd(parser): ) parser.add_argument("--rank", type=int, default=-1, help="The device's rank number") parser.add_argument("--balance", type=int, nargs="+", help="The model's balance") + parser.add_argument( + # a plain type=bool would treat any non-empty string (e.g. "False") as True + "--recompute", type=lambda x: str(x).lower() in ("true", "1", "yes"), + default=True, help="Enable recompute to save memory" + ) # configuration about log parser.add_argument( "--log_level", type=str, default="INFO", help="Set the log level." diff --git a/mlora_pp_train.py b/mlora_pp_train.py new file mode 100644 index 00000000..922a1aa9 --- /dev/null +++ b/mlora_pp_train.py @@ -0,0 +1,49 @@ +# m-LoRA: Efficient Multi-LoRA Fine Tuning with Shared-Based Model +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (C) 2024 All Rights Reserved. 
+# +# Github: https://github.com/TUDB-Labs/mLoRA + +import mlora.model +import mlora.utils +import mlora.executor +import mlora.config + +if __name__ == "__main__": + args = mlora.utils.get_cmd_args() + + mlora.utils.setup_seed(args.seed) + mlora.utils.setup_logging(args.log_level, args.log_file) + mlora.utils.setup_cuda_check() + mlora.utils.setup_metric_logger(args.metric_file) + + # enable the trace mode for profiling performance + if args.trace: + mlora.utils.setup_trace_mode() + + tokenizer, model = mlora.model.load_model(args) + config = mlora.config.MLoRAConfig(args.config) + + # init all task from config file + executor = mlora.executor.PipeExecutor( + model, tokenizer, config, args.device, args.rank, args.balance, args.recompute + ) + + # only the header node can add task + if args.rank == 0: + for item in config.tasks_: + executor.add_task(item) + + executor.execute() diff --git a/mlora_server.py b/mlora_server.py index b60bbb58..4a726246 100644 --- a/mlora_server.py +++ b/mlora_server.py @@ -74,8 +74,7 @@ def backend_server_run_fn(args): root_dir_list = mlora.server.root_dir_list() root_dir_list = dict( - map(lambda kv: (kv[0], os.path.join( - args.root, kv[1])), root_dir_list.items()) + map(lambda kv: (kv[0], os.path.join(args.root, kv[1])), root_dir_list.items()) ) mlora.server.set_root_dir_list(root_dir_list) @@ -171,7 +170,18 @@ def task_terminate_callback_fn(task: mlora.executor.task.Task): config = mlora.config.MLoRAServerConfig( {"name": "backend", "concurrency_num": args.concurrency_num} ) - executor = mlora.executor.Executor(model, tokenizer, config) + if args.pipeline: + executor = mlora.executor.PipeExecutor( + model, + tokenizer, + config, + args.device, + args.rank, + args.balance, + args.recompute, + ) + else: + executor = mlora.executor.Executor(model, tokenizer, config) executor.register_hook("done", task_done_callback_fn) executor.register_hook("step", task_step_callback_fn) executor.register_hook("terminate", task_terminate_callback_fn) diff --git a/tests/lora_op_test.py b/tests/lora_op_test.py index 6ad97775..c9eb308e 100644 --- a/tests/lora_op_test.py +++ b/tests/lora_op_test.py @@ -43,6 +43,8 @@ def lora_mlora(self): batch_mask_=[], data_config_=[ModelDataConfig("", "", 0, 2)], enable_checkpoint_=False, + random_id_=0, + task_name_=[""], ) weight = LoRAFunction.apply(