Support pipeline parallelism
yezhengmao1 committed Jan 23, 2024
1 parent 04bf63d commit 9f95457
Showing 6 changed files with 333 additions and 10 deletions.
32 changes: 23 additions & 9 deletions mlora.py
@@ -26,6 +26,7 @@

# Command Line Arguments
parser = argparse.ArgumentParser(description='m-LoRA main program')
# arguments for the model and tokenizer
parser.add_argument('--base_model', type=str,
help='Path to or name of base model')
parser.add_argument('--tokenizer', type=str,
@@ -34,25 +35,31 @@
help='The model type; supported: llama, chatglm')
parser.add_argument('--device', type=str, default='cuda:0',
help='Specify which GPU to use, default is cuda:0')

parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')

parser.add_argument('--inference', action="store_true",
help='Run in inference mode (for testing only)')

# LoRA arguments
parser.add_argument('--load_lora', action="store_true",
help="Load LoRA weights from file instead of initializing them randomly")
parser.add_argument('--disable_lora', action="store_true",
help="Disable the LoRA modules")

# pipeline parallelism arguments
parser.add_argument('--pipeline', action="store_true",
help="Train the LoRA model using pipeline parallelism")
parser.add_argument('--rank', type=int, default=-1,
help="The device's rank number")
parser.add_argument('--balance', type=int, default=-1,
help="The number of model layers assigned to this device")
# quantization arguments
parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')
# configuration arguments
parser.add_argument('--config', type=str,
help='Path to finetune configuration')
parser.add_argument('--seed', type=int, default=42,
help='Random seed in integer, default is 42')

# logging arguments
parser.add_argument('--log_level', type=str, default="INFO",
help="Set the log level.")
parser.add_argument('--log_file', type=str,
@@ -293,6 +300,13 @@ def inference(config: Dict[str, any],
tokenizer, model = load_base_model(config)
init_lora_model(config, model)

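# NOTE: the partition balance passed to Pipe below is hard-coded for now; the --balance flag is parsed but not yet consumed here.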
if args.pipeline:
pipe = mlora.Pipe(model,
device=torch.device(args.device),
rank=args.rank,
balance=[9, 8, 8, 10])
exit(pipe.run())

torch.cuda.empty_cache()

if args.inference:
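Taken together, the new --pipeline, --rank and --device flags imply one mlora.py process per GPU, with mlora.Pipe stitching the stages together over torch.distributed.rpc. A minimal launch sketch follows; the helper script name, model path and config path are placeholders and not part of this commit.

# launch_pipeline.py -- hypothetical helper, not part of this commit
import subprocess
import sys

WORLD_SIZE = 4  # must equal torch.cuda.device_count() and len(balance) inside mlora.py

# start one worker process per GPU, each with its own rank and device
procs = [
    subprocess.Popen([
        sys.executable, "mlora.py",
        "--base_model", "/path/to/base_model",        # placeholder
        "--config", "/path/to/finetune_config.json",  # placeholder
        "--pipeline",
        "--rank", str(rank),
        "--device", f"cuda:{rank}",
    ])
    for rank in range(WORLD_SIZE)
]

for p in procs:
    p.wait()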
4 changes: 3 additions & 1 deletion mlora/__init__.py
@@ -5,6 +5,7 @@
from mlora.model_chatglm import ChatGLMModel
from mlora.modelargs import LLMModelArgs, MultiLoraBatchData, LoraBatchDataConfig
from mlora.dispatcher import TrainTask, Dispatcher
from mlora.pipeline.pipe import Pipe

__all__ = [
"Tokenizer",
@@ -17,5 +18,6 @@
"convert_hf_to_pth",
"save_lora_model",
"TrainTask",
"Dispatcher"
"Dispatcher",
"Pipe"
]
3 changes: 3 additions & 0 deletions mlora/model.py
@@ -112,6 +112,9 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:


class LLMModel(metaclass=ABCMeta):
n_heads_: int = -1
rope_angle_: torch.Tensor = None

@abstractclassmethod
def forward(self, input: MultiLoraBatchData):
pass
33 changes: 33 additions & 0 deletions mlora/pipeline/function.py
@@ -0,0 +1,33 @@
from mlora.pipeline.messages import Transport, PipeMessage, PipeMessageType
from mlora.model import MultiLoraBatchData

import torch
from typing import Tuple


class SendOperator(torch.autograd.Function):
# sends the activations to the next worker; the returned phony tensor lets the backward pass route the gradient coming back from that worker
@staticmethod
def forward(ctx,
transport: Transport,
tensors: Tuple[torch.Tensor, ...],
msg_id: int,
input_args: MultiLoraBatchData):
msg = PipeMessage(src_=transport.worker_name,
dst_=transport.next_worker_name,
msg_type_=PipeMessageType.ACTIVATIONS,
msg_id_=msg_id,
tensors_=tensors,
batch_data_=input_args)
transport.send_message(msg, False)
return torch.tensor(1.0)

@staticmethod
def backward(ctx, grad):
assert ctx.grad_from_next_worker is not None
return ctx.grad_from_next_worker


class RecvOperator(torch.autograd.Function):
# backward will automatically send the gradient to the previous worker
pass
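RecvOperator is still a stub in this commit, but the phony tensor returned by SendOperator already shows how gradients are meant to flow back. The toy below is a self-contained sketch of that pattern only; ToySend and the manual gradient injection are illustrative assumptions, not this repository's API.

import torch


class ToySend(torch.autograd.Function):
    # stand-in for SendOperator: forward would ship the activation to the
    # next stage and returns a phony scalar; backward replays whatever
    # gradient was later injected into the saved context
    @staticmethod
    def forward(ctx, activation: torch.Tensor) -> torch.Tensor:
        ctx.grad_from_next_worker = None
        return torch.tensor(1.0)

    @staticmethod
    def backward(ctx, grad):
        assert ctx.grad_from_next_worker is not None
        return ctx.grad_from_next_worker


x = torch.randn(4, requires_grad=True)
phony = ToySend.apply(x)
# pretend a GRADIENTS message arrived from the next worker
phony.grad_fn.grad_from_next_worker = torch.ones_like(x)
phony.backward()
print(x.grad)  # the injected gradient has flowed back into x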
97 changes: 97 additions & 0 deletions mlora/pipeline/messages.py
@@ -0,0 +1,97 @@
from mlora.model import MultiLoraBatchData

import torch
import torch.distributed.rpc

from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Queue
from typing import Dict, Tuple
from enum import Enum


class PipeMessageType(Enum):
ACTIVATIONS = "ACTIVATIONS"
GRADIENTS = "GRADIENTS"


@dataclass()
class PipeMessage:
src_: str
dst_: str

msg_type_: PipeMessageType
msg_id_: int

tensors_: Tuple[torch.Tensor, ...]
batch_data_: MultiLoraBatchData


# message queues, keyed by message type
RPCMessageQueues: Dict[str, Queue] = {
PipeMessageType.ACTIVATIONS.value: Queue(),
PipeMessageType.GRADIENTS.value: Queue()
}


def rpc_push_queue(msg: PipeMessage) -> None:
assert msg.msg_type_.value in globals()["RPCMessageQueues"], f"No such message type: {msg.msg_type_.value}"

globals()["RPCMessageQueues"][msg.msg_type_.value].put(msg)


class Transport(ABC):
rank_: int
device_: torch.device

@property
def next_worker_name(self) -> str:
return f"worker-{self.rank_ + 1}"

@property
def prev_worker_name(self) -> str:
return f"worker-{self.rank_ - 1}"

@property
def worker_name(self) -> str:
return f"worker-{self.rank_}"

@abstractmethod
def recv_message(self, msg_type: PipeMessageType, block: bool = False) -> PipeMessage:
pass

@abstractmethod
def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
pass


class RpcTransport(Transport):
def __init__(self, rank: int, worker_device: torch.device) -> None:
super().__init__()
self.rank_ = rank
self.device_ = worker_device

# receiving a message blocks while the tensors are moved from CPU memory to GPU memory
def recv_message(self, msg_type: PipeMessageType, block: bool = True) -> PipeMessage:
assert msg_type.value in globals()["RPCMessageQueues"], f"No such message type: {msg_type.value}"

queue = globals()["RPCMessageQueues"][msg_type.value]
if queue.empty():
return None

result = queue.get(block=block)
result.tensors_ = tuple(t.to(self.device_) for t in result.tensors_)
return result

# sending a message blocks while the tensors are moved from GPU memory to CPU memory
def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
msg.tensors_ = tuple(t.to("cpu") for t in msg.tensors_)
if sync:
send_fn = torch.distributed.rpc.rpc_sync
else:
send_fn = torch.distributed.rpc.rpc_async

send_fn(msg.dst_, rpc_push_queue, args=(msg,))
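As a rough usage sketch (an assumption about how these pieces combine, not code from this commit), a worker could poll its transport inside a loop once torch.distributed.rpc is initialized:

import torch

from mlora.pipeline.messages import RpcTransport, PipeMessageType

transport = RpcTransport(rank=1, worker_device=torch.device("cuda:1"))

# non-blocking poll: returns None when no ACTIVATIONS message has arrived yet
msg = transport.recv_message(PipeMessageType.ACTIVATIONS, block=False)
if msg is not None:
    # recv_message has already moved the tensors onto this worker's device
    activations = msg.tensors_[0]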
174 changes: 174 additions & 0 deletions mlora/pipeline/pipe.py
@@ -0,0 +1,174 @@
from mlora.modelargs import MultiLoraBatchData, LoraBatchDataConfig
from mlora.model import LLMModel, precompute_mask
from mlora.pipeline.messages import Transport, RpcTransport
from mlora.pipeline.messages import PipeMessage, PipeMessageType
from mlora.pipeline.function import SendOperator, RecvOperator

import os
import torch
import time
import logging

import torch.distributed.rpc
from typing import List, Dict
from enum import Enum, auto
from dataclasses import dataclass


class WorkerRole(Enum):
HEAD = auto()
MID = auto()
TAIL = auto()


@dataclass(frozen=True)
class PipeModelLocInfo():
n_heads_: int = -1
rope_angle_: torch.Tensor = None


class Pipe():
world_size_: int = -1
balance_: List[int] = None

rank_: int = -1
device_: torch.device = None

model_partition_: torch.nn.Sequential = torch.nn.Sequential()
model_info_: PipeModelLocInfo = None
role_: WorkerRole = None

transport_: Transport = None

backward_cache_: Dict[int, torch.Tensor] = {}

def __init__(self,
model: LLMModel,
device: torch.device,
rank: int,
balance: List[int]) -> None:

world_size = torch.cuda.device_count()
assert world_size == len(balance)

self.rank_ = rank
self.world_size_ = world_size
self.balance_ = balance
self.device_ = device
self.model_info_ = PipeModelLocInfo(
n_heads_=model.n_heads_,
rope_angle_=model.rope_angle_
)

if rank == 0:
self.role_ = WorkerRole.HEAD
elif rank == len(balance) - 1:
self.role_ = WorkerRole.TAIL
else:
self.role_ = WorkerRole.MID

self.transport_ = RpcTransport(self.rank_, self.device_)
self.init_partition(model)
self.init_rpc()

def run(self):
while True:
self.process_backward()
self.process_forward()
time.sleep(5)

def process_input(self) -> None:
if self.role_ != WorkerRole.HEAD:
return
logging.debug("Train input data.")
import random
test_data_len = 96
# dummy integer token ids for a pipeline smoke test
tokens_list = [[random.randint(0, 1000) for _ in range(test_data_len)]]
batch_data = MultiLoraBatchData(
batch_tokens_=tokens_list,
additional_mask_=None,
lora_batch_data_config_=[LoraBatchDataConfig(
adapter_name_="lora_0",
batch_start_idx_=0,
batch_end_idx_=1
)],
inference_model_=False
)

tokens_batch = torch.tensor(
batch_data.batch_tokens_, dtype=torch.int64).to(self.device_)
mask = precompute_mask(
tokens_batch, self.model_info_.n_heads_, self.device_)

dummy_data = (tokens_batch, mask,
self.model_info_.rope_angle_, batch_data, True)

for seq_model in self.model_partition_:
dummy_data = seq_model.forward(dummy_data)

msg_id = id(dummy_data[0])
phony: torch.Tensor = SendOperator.apply(
self.transport_, (dummy_data[0], ), msg_id, batch_data)
self.backward_cache_[msg_id] = phony

def process_forward(self) -> None:
if self.role_ == WorkerRole.HEAD:
return self.process_input()
logging.debug("To recv the activations message.")
message = self.transport_.recv_message(
PipeMessageType.ACTIVATIONS, block=False)
if message is None:
logging.debug("No activations to process.")
return
logging.debug(f"Have already recv the activations from {message.src_}")

tokens_batch = message.tensors_[0]
mask = precompute_mask(
tokens_batch, self.model_info_.n_heads_, self.device_)
batch_data = message.batch_data_

dummy_data = (tokens_batch, mask,
self.model_info_.rope_angle_, batch_data, True)

for seq_model in self.model_partition_:
dummy_data = seq_model.forward(dummy_data)

def process_backward(self) -> None:
logging.debug("To recv the gradients message.")
message = self.transport_.recv_message(
PipeMessageType.GRADIENTS, block=False)
if message is None:
logging.debug("No gradients to process.")
return
logging.debug(f"Have already recv the gradients from {message.src_}")

def init_partition(self, model: LLMModel) -> None:
balance = self.balance_[self.rank_]
start_module_idx = sum(self.balance_[:self.rank_])
logging.info(
f"RANK-{self.rank_} in device {self.device_} to load module layers from {start_module_idx} to {start_module_idx + balance}.")

seq_model = model.sequential_module()
del seq_model[:start_module_idx]
del seq_model[balance:]
for idx in range(0, len(seq_model)):
self.model_partition_.append(seq_model[idx])

assert len(self.model_partition_) == balance

del model

torch.cuda.empty_cache()

def init_rpc(self) -> None:
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = "localhost"
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = "12355"

# blocks until every rank in the world size has joined the RPC group
torch.distributed.rpc.init_rpc(
f"worker-{self.rank_}", rank=self.rank_, world_size=self.world_size_)

logging.info(
f"Init rpc with rank {self.rank_} world_size: {self.world_size_}")

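For reference, with the balance=[9, 8, 8, 10] used in mlora.py (35 sequential modules in total), init_partition assigns the following module ranges; the quick check below just prints the boundaries it computes.

balance = [9, 8, 8, 10]
for rank, n in enumerate(balance):
    start = sum(balance[:rank])
    print(f"rank {rank}: modules [{start}, {start + n})")
# rank 0: modules [0, 9)
# rank 1: modules [9, 17)
# rank 2: modules [17, 25)
# rank 3: modules [25, 35)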