[feature] support pipeline parallelism #242

Merged · 1 commit · Jul 12, 2024
37 changes: 37 additions & 0 deletions README.md
@@ -58,6 +58,43 @@ For further detailed usage information, please use the `--help` option:
python mlora_train.py --help
```

## Deployment using pipeline parallelism
Similar to the Quickstart, the commands to start training in a two-node environment are as follows:

NOTE 1: Use the environment variables `MASTER_ADDR` and `MASTER_PORT` to set the master node.

NOTE 2: Set `--balance` to specify the number of decoder layers allocated to each rank.


```bash
# in the first node
export MASTER_ADDR=master.svc.cluster.local
export MASTER_PORT=12355
python mlora_pp_train.py \
--base_model TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
--config demo/lora/lora_case_1.yaml \
--pipeline \
--device "cuda:0" \
--rank 0 \
--balance 12 13 \
--recompute False \
--precision fp32

# in the second node
export MASTER_ADDR=master.svc.cluster.local
export MASTER_PORT=12355
python mlora_pp_train.py \
--base_model TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
--config demo/lora/lora_case_1.yaml \
--pipeline \
--device "cuda:1" \
--rank 1 \
--balance 12 13 \
--recompute False \
--precision fp32
```
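
The `--balance` list is positional: entry *i* gives the number of layers assigned to rank *i*. As a rough illustration of how such a list maps to contiguous layer slices (this `partition_layers` helper is purely illustrative, not part of mLoRA):

```python
# Illustrative only: read a balance list such as [12, 13] as contiguous
# layer slices, one per pipeline rank. Not an mLoRA function.
def partition_layers(balance: list[int]) -> list[range]:
    slices: list[range] = []
    start = 0
    for count in balance:
        slices.append(range(start, start + count))
        start += count
    return slices

print(partition_layers([12, 13]))  # [range(0, 12), range(12, 25)]
```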


## Quickstart with Docker
mLoRA offers an official Docker image for quick start and development. The image is available on the Docker Hub Packages registry.

2 changes: 1 addition & 1 deletion mlora/config/task.py
@@ -25,7 +25,7 @@ def __init__(
         self.init(self.__params_map, config)

         self.adapter_ = adapters[config["adapter"]]
-        self.dataset_ = datasets[config["dataset"]]
+        self.dataset_: DatasetConfig | None = datasets[config["dataset"]]


 class TrainTaskConfig(TaskConfig):
3 changes: 2 additions & 1 deletion mlora/executor/__init__.py
@@ -1,3 +1,4 @@
 from .executor import Executor
+from .pipe_executor import PipeExecutor

-__all__ = ["Executor"]
+__all__ = ["Executor", "PipeExecutor"]
11 changes: 9 additions & 2 deletions mlora/executor/dispatcher/__init__.py
@@ -1,6 +1,13 @@
+from typing import Dict, Type
+
 from .backend_dispatcher import BackendDispatcher
 from .dispatcher import Dispatcher
+from .pipe_dispatcher import PipeDispatcher

-DISPATCHER_CLASS = {"default": Dispatcher, "backend": BackendDispatcher}
+DISPATCHER_CLASS: Dict[str, Type[Dispatcher]] = {
+    "default": Dispatcher,
+    "backend": BackendDispatcher,
+    "pipe": PipeDispatcher,
+}

-__all__ = ["Dispatcher", "BackendDispatcher", "DISPATCHER_CLASS"]
+__all__ = ["Dispatcher", "BackendDispatcher", "PipeDispatcher", "DISPATCHER_CLASS"]
2 changes: 1 addition & 1 deletion mlora/executor/dispatcher/dispatcher.py
@@ -131,7 +131,7 @@ def _align_batch_tokens(

         return batch_tokens, batch_masks

-    def data(self) -> MLoRAData:
+    def data(self) -> MLoRAData | None:
         self._dispatch_task_in()

         batch_tokens: List[Tokens] = []
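
Because `data()` can now return `None` (the pipe dispatcher below uses this to signal that every running task is currently locked), callers need a guard; an illustrative fragment, where `dispatcher` is any `Dispatcher` instance:

```python
# Illustrative caller-side guard for the widened return type.
data = dispatcher.data()
if data is None:
    ...  # no task could produce a batch this round; try again later
```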
120 changes: 120 additions & 0 deletions mlora/executor/dispatcher/pipe_dispatcher.py
@@ -0,0 +1,120 @@
from typing import List, Set, override

from mlora.config.dispatcher import DispatcherConfig
from mlora.executor.task import Task
from mlora.model.args import Masks, MLoRAData, MLoRADataConfig, Tokens

from .backend_dispatcher import BackendDispatcher


class PipeDispatcher(BackendDispatcher):
    lock_set_: Set[str]

    def __init__(self, config: DispatcherConfig) -> None:
        super().__init__(config)
        self.lock_set_ = set()

    @override
    def _dispatch_task_in(self):
        # move the ready tasks that should terminate out of the ready queue
        terminate_task = [task for task in self.ready_ if task.is_terminate()]
        self.ready_ = [task for task in self.ready_ if not task.is_terminate()]

        for task in terminate_task:
            self.terminate_event_.notify(task)

        # the pipeline keeps only one running task
        while len(self.running_) <= self.concurrency_num_ and len(self.ready_) > 0:
            task = self.ready_.pop(0)
            self.running_.append(task)
            self.running_event_.notify(task)

    def find_the_task(self, task_name: str) -> Task:
        # the worker does not really dispatch the task,
        # so we just look it up in the ready queue
        for task in self.ready_:
            if task.task_name() != task_name:
                continue
            return task
        raise Exception(f"No such task: {task_name}")

    # if not the head worker, we need to manually dispatch the task
    def dispatch_task_to_run(self, task_name: str):
        task = self.find_the_task(task_name)
        self.running_event_.notify(task)

    def dispatch_task_to_ready(self, task_name: str):
        task = self.find_the_task(task_name)
        self.ready_event_.notify(task)

    def dispatch_task_to_done(self, task_name: str):
        task = self.find_the_task(task_name)
        self.done_event_.notify(task)

    def dispatch_task_to_terminal(self, task_name: str):
        task = self.find_the_task(task_name)
        self.terminate_event_.notify(task)

    def dispatch_task_to_step(self, task_name: str):
        task = self.find_the_task(task_name)
        task.step()
        self.step_event_.notify(task)

    def lock_task(self, name: str):
        self.lock_set_.add(name)

    def unlock_task(self, name: str):
        if name not in self.lock_set_:
            return
        self.lock_set_.remove(name)

    def is_lock(self, name: str):
        return name in self.lock_set_

    @override
    def data(self) -> MLoRAData | None:
        self._dispatch_task_in()

        batch_tokens: List[Tokens] = []
        batch_masks: List[Masks] = []
        data_configs: List[MLoRADataConfig] = []

        can_run_task = list(
            filter(lambda task: not self.is_lock(task.task_name()), self.running_)
        )

        if len(can_run_task) == 0:
            return None

        # get all the train data
        start_idx: int = 0
        # the pipe dispatcher runs just one task at a time
        task = can_run_task[0]

        data, data_config = task.data(start_idx)

        # tag each config with the task name so the task can be unlocked later
        for item in data_config:
            item.task_name_ = task.task_name()

        data_configs.extend(data_config)
        batch_tokens.extend(data)
        start_idx = start_idx + len(data)
        self.lock_task(task.task_name())

        # post-process this batch of data
        batch_tokens, batch_masks = self._align_batch_tokens(batch_tokens, data_configs)

        return MLoRAData(
            batch_tokens=batch_tokens, batch_mask=batch_masks, data_config=data_configs
        )

    def task_step(self, task_name: str):
        # on the head worker the task must be in the running queue
        for task in self.running_:
            if task.task_name() != task_name:
                continue
            task.step()
            self.step_event_.notify(task)

        self._dispatch_task_out()
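
Putting the pieces together, a minimal sketch of the lock/step cycle a worker loop might run around this dispatcher (the loop itself and the way `task_name` is recovered from the batch are illustrative assumptions, not mLoRA's actual executor code):

```python
# Hypothetical worker loop; mLoRA's real wiring lives in PipeExecutor.
dispatcher = PipeDispatcher(config)  # config: DispatcherConfig

while not dispatcher.is_done():
    data = dispatcher.data()   # schedules one unlocked task and locks it
    if data is None:
        continue               # every running task is locked; wait
    # ... run this pipeline stage's forward/backward on `data` ...
    task_name = data.data_config_[0].task_name_  # assumed accessor
    dispatcher.unlock_task(task_name)  # let data() schedule this task again
    dispatcher.task_step(task_name)    # head worker: advance the task a step
```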
3 changes: 2 additions & 1 deletion mlora/executor/executor.py
@@ -99,7 +99,8 @@ def execute(self) -> None:
         mm_collect_step = 0

         while not self.dispatcher_.is_done():
-            data: MLoRAData = self.dispatcher_.data()
+            data: MLoRAData | None = self.dispatcher_.data()
+            assert data is not None

             torch.cuda.reset_peak_memory_stats(device=self.model_.device_)
