Support pipeline parallelism
yezhengmao1 committed Jan 22, 2024
1 parent 39c62de commit 3d6f506
Showing 7 changed files with 275 additions and 10 deletions.
32 changes: 23 additions & 9 deletions mlora.py
@@ -26,6 +26,7 @@

# Command Line Arguments
parser = argparse.ArgumentParser(description='m-LoRA main program')
# arguments for the model and tokenizer
parser.add_argument('--base_model', type=str,
help='Path to or name of base model')
parser.add_argument('--tokenizer', type=str,
@@ -34,25 +35,31 @@
                    help='The model type; supported: llama, chatglm')
parser.add_argument('--device', type=str, default='cuda:0',
help='Specify which GPU to be used, default is cuda:0')

parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')

parser.add_argument('--inference', action="store_true",
                    help='Enable inference mode (for testing only)')

# arguments for LoRA
parser.add_argument('--load_lora', action="store_true",
                    help="Load LoRA weights from file instead of initializing them randomly")
parser.add_argument('--disable_lora', action="store_true",
                    help="Disable the LoRA modules")

# arguments for pipeline parallelism
parser.add_argument('--pipeline', action="store_true",
                    help="Train the LoRA model using pipeline parallelism")
parser.add_argument('--rank', type=int, default=-1,
help="The device's rank number")
parser.add_argument('--balance', type=int, default=-1,
                    help="The number of model layers assigned to this device")
# arguments for quantization
parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')
# arguments for the finetune configuration
parser.add_argument('--config', type=str,
help='Path to finetune configuration')
parser.add_argument('--seed', type=int, default=42,
help='Random seed in integer, default is 42')

# arguments for logging
parser.add_argument('--log_level', type=str, default="INFO",
help="Set the log level.")
parser.add_argument('--log_file', type=str,
@@ -294,6 +301,13 @@ def inference(config: Dict[str, any],
tokenizer, model = load_base_model(config)
init_lora_model(config, model)

if args.pipeline:
pipe = mlora.Pipe(model,
device=torch.device(args.device),
rank=args.rank,
balance=[9, 8, 8, 10])
exit(pipe.run())

torch.cuda.empty_cache()

if args.inference:
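With the flags above, pipeline training runs one mlora.py process per GPU, each joining the RPC group with its own --rank. A minimal launcher sketch under that assumption (the model and config paths are placeholders, the four-stage setup matches the hard-coded balance=[9, 8, 8, 10] above, and this launcher itself is not part of the commit):

import os
import subprocess

# Hypothetical launcher (assumption, not part of this commit): spawn one
# mlora.py process per pipeline stage. MASTER_ADDR/MASTER_PORT must agree
# across ranks so torch.distributed.rpc.init_rpc() inside Pipe can form the group.
WORLD_SIZE = 4  # must equal len(balance) and torch.cuda.device_count()
env = dict(os.environ, MASTER_ADDR="localhost", MASTER_PORT="12355")

procs = [
    subprocess.Popen(
        ["python", "mlora.py",
         "--base_model", "/path/to/base_model",        # placeholder path
         "--config", "/path/to/finetune_config.json",  # placeholder path
         "--pipeline",
         "--rank", str(rank),
         "--device", f"cuda:{rank}"],
        env=env)
    for rank in range(WORLD_SIZE)
]
for p in procs:
    p.wait()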
4 changes: 3 additions & 1 deletion mlora/__init__.py
@@ -5,6 +5,7 @@
from mlora.model_chatglm import ChatGLMModel
from mlora.modelargs import LLMModelArgs, MultiLoraBatchData, LoraBatchDataConfig
from mlora.dispatcher import TrainTask, Dispatcher
from mlora.pipeline.pipe import Pipe

__all__ = [
"Tokenizer",
@@ -17,5 +18,6 @@
"convert_hf_to_pth",
"save_lora_model",
"TrainTask",
"Dispatcher"
"Dispatcher",
"Pipe"
]
3 changes: 3 additions & 0 deletions mlora/model.py
@@ -112,6 +112,9 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:


class LLMModel(metaclass=ABCMeta):
n_heads_: int = -1
    rope_angle_: torch.Tensor = None

@abstractclassmethod
def forward(self, input: MultiLoraBatchData):
pass
Empty file added mlora/pipeline/dispatcher.py
Empty file.
9 changes: 9 additions & 0 deletions mlora/pipeline/function.py
@@ -0,0 +1,9 @@
import torch

# Helper to reduce activation memory.
class AutogradWithoutActivations(torch.autograd.Function):
pass

# backward() will automatically send the gradient to the previous worker.
class RecvOperator(torch.autograd.Function):
pass
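Both classes are empty placeholders in this commit. As a rough illustration only, a custom autograd.Function along the lines the comment describes might hand the received activation to the local sub-model in forward() and ship the incoming gradient back to the previous worker in backward(); the name, signature, and phony trigger tensor below are assumptions, not the author's implementation:

import torch

from mlora.pipeline.messages import PipeMessage, PipeMessageType, Transport


class RecvOperatorSketch(torch.autograd.Function):
    # Hypothetical sketch: forward() returns the activation received from the
    # previous stage; backward() sends the gradient back to that stage.
    @staticmethod
    def forward(ctx, phony: torch.Tensor, transport: Transport,
                msg: PipeMessage) -> torch.Tensor:
        # phony is a small requires_grad tensor whose only job is to make
        # autograd call backward() on this node.
        ctx.transport_ = transport
        ctx.src_ = msg.src_
        ctx.dst_ = msg.dst_
        return msg.tensors_[0]

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Push the gradient into the previous worker's GRADIENTS queue; its
        # process_backward() loop will pick it up.
        ctx.transport_.send_message(PipeMessage(
            src_=ctx.dst_, dst_=ctx.src_,
            msg_type_=PipeMessageType.GRADIENTS,
            tensors_=(grad_output,)))
        # One return value per forward() argument (phony, transport, msg).
        return None, None, None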
77 changes: 77 additions & 0 deletions mlora/pipeline/messages.py
@@ -0,0 +1,77 @@
import torch
import torch.distributed.rpc

from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Queue
from typing import Dict, Tuple
from enum import Enum


class PipeMessageType(Enum):
ACTIVATIONS = "ACTIVATIONS"
GRADIENTS = "GRADIENTS"


@dataclass()
class PipeMessage:
src_: str
dst_: str

msg_type_: PipeMessageType
    tensors_: Tuple[torch.Tensor, ...]


# message queues keyed by message type
RPCMessageQueues: Dict[str, Queue] = {
PipeMessageType.ACTIVATIONS.value: Queue(),
PipeMessageType.GRADIENTS.value: Queue()
}


def rpc_push_queue(msg: PipeMessage) -> None:
    assert msg.msg_type_.value in globals()["RPCMessageQueues"], \
        f"Unknown message type: {msg.msg_type_.value}"

    globals()["RPCMessageQueues"][msg.msg_type_.value].put(msg)


class Transport(ABC):
device_: torch.device

@abstractmethod
def recv_message(self, msg_type: PipeMessageType, block: bool = False) -> PipeMessage:
pass

@abstractmethod
def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
pass


class RpcTransport(Transport):
def __init__(self, worker_device: torch.device) -> None:
super().__init__()
self.device_ = worker_device

    # Receiving a message blocks while the tensors are copied from CPU to GPU memory.
    def recv_message(self, msg_type: PipeMessageType, block: bool = True) -> PipeMessage:
        assert msg_type.value in globals()["RPCMessageQueues"], \
            f"Unknown message type: {msg_type.value}"

queue = globals()["RPCMessageQueues"][msg_type.value]
if queue.empty():
return None

result = queue.get(block=block)
result.tensors_ = tuple(t.to(self.device_) for t in result.tensors_)
return result

    # Sending a message blocks while the tensors are copied from GPU to CPU memory.
def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
msg.tensors_ = tuple(t.to("cpu") for t in msg.tensors_)
if sync:
send_fn = torch.distributed.rpc.rpc_sync
else:
send_fn = torch.distributed.rpc.rpc_async

send_fn(msg.dst_, rpc_push_queue, args=(msg,))
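A small usage sketch of the transport, assuming both processes have already joined the RPC group (as Pipe.init_rpc below does) and follow the worker-{rank} naming convention; the tensor shape is illustrative:

import torch

from mlora.pipeline.messages import PipeMessage, PipeMessageType, RpcTransport

# On worker-0: push activations to worker-1 (tensors are staged through CPU).
transport = RpcTransport(torch.device("cuda:0"))
transport.send_message(PipeMessage(
    src_="worker-0", dst_="worker-1",
    msg_type_=PipeMessageType.ACTIVATIONS,
    tensors_=(torch.randn(1, 96, 4096),)))  # illustrative shape

# On worker-1 (inside its polling loop): drain the ACTIVATIONS queue.
recv_transport = RpcTransport(torch.device("cuda:1"))
msg = recv_transport.recv_message(PipeMessageType.ACTIVATIONS, block=False)
if msg is not None:
    hidden = msg.tensors_[0]  # already moved onto worker-1's device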
160 changes: 160 additions & 0 deletions mlora/pipeline/pipe.py
@@ -0,0 +1,160 @@
from mlora.modelargs import MultiLoraBatchData, LoraBatchDataConfig
from mlora.model import LLMModel, precompute_mask
from mlora.pipeline.messages import Transport, RpcTransport
from mlora.pipeline.messages import PipeMessage, PipeMessageType

import os
import torch
import time
import logging

import torch.distributed.rpc
from typing import List
from enum import Enum, auto
from dataclasses import dataclass


class WorkerRole(Enum):
HEAD = auto()
MID = auto()
TAIL = auto()


@dataclass(frozen=True)
class PipeModelLocInfo():
n_heads_: int = -1
rope_angle_: torch.Tensor = None


class Pipe():
world_size_: int = -1
balance_: List[int] = None

rank_: int = -1
device_: torch.device = None

model_partition_: torch.nn.Sequential = torch.nn.Sequential()
model_info_: PipeModelLocInfo = None
role_: WorkerRole = None

transport_: Transport = None

def __init__(self,
model: LLMModel,
device: torch.device,
rank: int,
balance: List[int]) -> None:

world_size = torch.cuda.device_count()
assert world_size == len(balance)

self.rank_ = rank
self.world_size_ = world_size
self.balance_ = balance
self.device_ = device
self.model_info_ = PipeModelLocInfo(
n_heads_=model.n_heads_,
rope_angle_=model.rope_angle_
)

if rank == 0:
self.role_ = WorkerRole.HEAD
elif rank == len(balance) - 1:
self.role_ = WorkerRole.TAIL
else:
self.role_ = WorkerRole.MID

self.transport_ = RpcTransport(self.device_)
self.init_partition(model)
self.init_rpc()

@property
def next_worker_name(self) -> str:
return f"worker-{self.rank_ + 1}"

@property
def prev_worker_name(self) -> str:
return f"worker-{self.rank_ - 1}"

    # Main loop: drain pending gradients, then pending activations (forward),
    # then feed new input on the head worker.
def run(self):
while True:
self.process_backward()
self.process_forward()
self.process_input()
time.sleep(5)

def process_input(self) -> None:
if self.role_ != WorkerRole.HEAD:
return
logging.debug("Train input data.")
        import random
        test_data_len = 96
        # dummy integer token ids for the test batch (cast to int64 below)
        tokens_list = [[random.randint(1, 1000)
                        for _ in range(test_data_len)]]
input = MultiLoraBatchData(
batch_tokens_=tokens_list,
additional_mask_=None,
lora_batch_data_config_=[LoraBatchDataConfig(
adapter_name_="lora_0",
batch_start_idx_=0,
batch_end_idx_=1
)],
inference_model_=False
)

tokens_batch = torch.tensor(
input.batch_tokens_, dtype=torch.int64).to(self.device_)
mask = precompute_mask(
tokens_batch, self.model_info_.n_heads_, self.device_)

dummy_data = (tokens_batch, mask,
self.model_info_.rope_angle_, input, True)

for seq_model in self.model_partition_:
dummy_data = seq_model.forward(dummy_data)

def process_forward(self) -> None:
if self.role_ == WorkerRole.HEAD:
return
pass

def process_backward(self) -> None:
logging.debug("To recv the gradients message.")
message = self.transport_.recv_message(
PipeMessageType.GRADIENTS, block=False)
if message is None:
logging.debug("No gradients to deal.")
return
logging.debug(f"Have already recv the gradients from {message.src_}")

def init_partition(self, model: LLMModel) -> None:
balance = self.balance_[self.rank_]
start_module_idx = sum(self.balance_[:self.rank_])
logging.info(
f"RANK-{self.rank_} in device {self.device_} to load module layers from {start_module_idx} to {start_module_idx + balance}.")

seq_model = model.sequential_module()
del seq_model[:start_module_idx]
del seq_model[balance:]
for idx in range(0, len(seq_model)):
self.model_partition_.append(seq_model[idx])

assert len(self.model_partition_) == balance

del model

torch.cuda.empty_cache()

def init_rpc(self) -> None:
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = "localhost"
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = "12355"

        # Blocks until all world_size workers have joined the RPC group.
torch.distributed.rpc.init_rpc(
f"worker-{self.rank_}", rank=self.rank_, world_size=self.world_size_)

logging.info(
f"Init rpc with rank {self.rank_} world_size: {self.world_size_}")
