support Pipeline parallelism
yezhengmao1 committed Jan 25, 2024
1 parent 04bf63d commit 453b176
Showing 6 changed files with 413 additions and 11 deletions.
32 changes: 23 additions & 9 deletions mlora.py
@@ -26,6 +26,7 @@

# Command Line Arguments
parser = argparse.ArgumentParser(description='m-LoRA main program')
# arguments for the model and tokenizer
parser.add_argument('--base_model', type=str,
help='Path to or name of base model')
parser.add_argument('--tokenizer', type=str,
@@ -34,25 +34,31 @@
help='The model type, support: llama, chatglm')
parser.add_argument('--device', type=str, default='cuda:0',
help='Specify which GPU to be used, default is cuda:0')

parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')

parser.add_argument('--inference', action="store_true",
help='The inference mode (just for test)')

# arguments for LoRA
parser.add_argument('--load_lora', action="store_true",
help="Load lora from file instead of init randomly")
parser.add_argument('--disable_lora', action="store_true",
help="Disable the lora modules")

# arguments for pipeline parallelism
parser.add_argument('--pipeline', action="store_true",
help="Train the LoRA model use the pipeline parallelism")
parser.add_argument('--rank', type=int, default=-1,
help="The device's rank number")
parser.add_argument('--balance', type=int, default=-1,
help="The device's rank number")
# arguments for quantization
parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')
# arguments for the configuration file
parser.add_argument('--config', type=str,
help='Path to finetune configuration')
parser.add_argument('--seed', type=int, default=42,
help='Random seed in integer, default is 42')

# arguments for logging
parser.add_argument('--log_level', type=str, default="INFO",
help="Set the log level.")
parser.add_argument('--log_file', type=str,
@@ -293,6 +300,13 @@ def inference(config: Dict[str, any],
tokenizer, model = load_base_model(config)
init_lora_model(config, model)

if args.pipeline:
pipe = mlora.Pipe(model,
device=torch.device(args.device),
rank=args.rank,
balance=[9, 8, 8, 10])
exit(pipe.run())

torch.cuda.empty_cache()

if args.inference:
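For context, here is an illustrative launcher sketch (not part of this commit) showing how the pipeline mode added above might be started, one process per rank. The MASTER_ADDR/MASTER_PORT rendezvous variables, the model path, and the config path are assumptions; the Pipe implementation itself lives in mlora/pipeline/pipe.py, which is not expanded on this page.

# illustrative launcher, not part of the commit
import os
import subprocess

WORLD_SIZE = 2  # assumption: one process per pipeline stage

procs = []
for rank in range(WORLD_SIZE):
    env = dict(os.environ,
               MASTER_ADDR="127.0.0.1",   # assumption: RPC rendezvous via env vars
               MASTER_PORT="29500")
    procs.append(subprocess.Popen(
        ["python", "mlora.py",
         "--base_model", "/path/to/llama-7b",   # hypothetical model path
         "--config", "config/finetune.json",    # hypothetical config path
         "--pipeline",
         "--rank", str(rank),
         "--device", f"cuda:{rank}"],
        env=env))

for p in procs:
    p.wait()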
4 changes: 3 additions & 1 deletion mlora/__init__.py
@@ -5,6 +5,7 @@
from mlora.model_chatglm import ChatGLMModel
from mlora.modelargs import LLMModelArgs, MultiLoraBatchData, LoraBatchDataConfig
from mlora.dispatcher import TrainTask, Dispatcher
from mlora.pipeline.pipe import Pipe

__all__ = [
"Tokenizer",
@@ -17,5 +18,6 @@
"convert_hf_to_pth",
"save_lora_model",
"TrainTask",
"Dispatcher"
"Dispatcher",
"Pipe"
]
11 changes: 10 additions & 1 deletion mlora/model.py
@@ -8,6 +8,7 @@


# input_tokens shape is: batch_size * seq_len
# or batch_size * seq_len * dim
# default: upper triangular matrix like below, i.e. diagonal = 1
# 0 -inf -inf
# 0 0 -inf
@@ -24,7 +25,12 @@ def precompute_mask(input_tokens: torch.Tensor,
additional_mask: List[Masks] = None,
diagonal: int = 1,
dtype: torch.dtype = torch.float32) -> torch.Tensor:
batch_size, seq_len = input_tokens.shape
if len(input_tokens.shape) == 2:
batch_size, seq_len = input_tokens.shape
elif len(input_tokens.shape) == 3:
batch_size, seq_len, _ = input_tokens.shape
else:
raise Exception(f"{input_tokens.shape} is not supported.")

mask = torch.full((batch_size, n_heads, seq_len, seq_len),
float("-inf"), device=device, dtype=dtype)
@@ -112,6 +118,9 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:


class LLMModel(metaclass=ABCMeta):
n_heads_: int = -1
rope_angle_: torch.Tensor = None

@abstractclassmethod
def forward(self, input: MultiLoraBatchData):
pass
67 changes: 67 additions & 0 deletions mlora/pipeline/function.py
@@ -0,0 +1,67 @@
from mlora.pipeline.messages import Transport, PipeMessage, PipeMessageType
from mlora.model import MultiLoraBatchData

import logging
import torch


class SendOperator(torch.autograd.Function):
# helper to reduce the activation memory
@staticmethod
def forward(ctx,
phony: torch.Tensor,
tensor_data: torch.Tensor,
transport: Transport,
msg_id: int,
input_args: MultiLoraBatchData):
assert isinstance(tensor_data, torch.Tensor)

msg = PipeMessage(src_=transport.worker_name,
dst_=transport.next_worker_name,
msg_type_=PipeMessageType.ACTIVATIONS,
msg_id_=msg_id,
tensor_data_=tensor_data,
batch_data_=input_args)
transport.send_message(msg, False)

return phony

@staticmethod
def backward(ctx, grad_output):
assert ctx.grad_from_next_worker is not None

return (None, ctx.grad_from_next_worker, None, None, None)


class RecvOperator(torch.autograd.Function):
# backward will automatically send the gradient to the previous worker
@staticmethod
def forward(ctx,
phony: torch.Tensor,
transport: Transport,
msg: PipeMessage) -> torch.Tensor:
assert msg.msg_type_ == PipeMessageType.ACTIVATIONS
assert isinstance(msg.tensor_data_, torch.Tensor)

ctx.msg_id_ = msg.msg_id_
ctx.transport_ = transport

return msg.tensor_data_ * phony

@staticmethod
def backward(ctx, grad_output: torch.Tensor):
# for now, only a single gradient tensor is supported
assert isinstance(grad_output, torch.Tensor)

transport: Transport = ctx.transport_
logging.debug(f"Send the gradients to {transport.prev_worker_name}")
transport.send_message(PipeMessage(
src_=transport.worker_name,
dst_=transport.prev_worker_name,
msg_type_=PipeMessageType.GRADIENTS,
msg_id_=ctx.msg_id_,
tensor_data_=grad_output,
batch_data_=None,
))

return (None, None, None)
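To show how these two autograd functions are meant to compose, here is a hedged sketch of what one intermediate stage's step might look like. The names transport and local_layers, the ones-valued phony tensor, and the driver that later stashes the returned GRADIENTS message on ctx are all assumptions about the surrounding Pipe code, which is not expanded on this page.

import torch

from mlora.pipeline.function import RecvOperator, SendOperator
from mlora.pipeline.messages import PipeMessageType, Transport

def middle_stage_step(transport: Transport, local_layers, device: torch.device):
    # Illustrative only. Pull the previous stage's activations off the RPC queue;
    # RecvOperator multiplies them by the phony tensor, so the received data joins
    # this process's autograd graph and its backward ships the gradient upstream.
    msg = transport.recv_message(PipeMessageType.ACTIVATIONS, block=True)
    phony_in = torch.ones(1, device=device, requires_grad=True)  # assumption: scalar-like phony
    activations = RecvOperator.apply(phony_in, transport, msg)

    output = local_layers(activations, msg.batch_data_)  # hypothetical local forward

    # Hand the result to the next stage. SendOperator returns the phony tensor, so the
    # local graph ends here; its backward expects the pipeline driver to have stored the
    # gradient received from the next worker on ctx (grad_from_next_worker) beforehand.
    phony_out = torch.ones(1, device=device, requires_grad=True)
    return SendOperator.apply(phony_out, output, transport, msg.msg_id_, msg.batch_data_)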
96 changes: 96 additions & 0 deletions mlora/pipeline/messages.py
@@ -0,0 +1,96 @@
from mlora.model import MultiLoraBatchData

import torch
import torch.distributed.rpc

from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Queue
from typing import Dict, Tuple
from enum import Enum


class PipeMessageType(Enum):
ACTIVATIONS = "ACTIVATIONS"
GRADIENTS = "GRADIENTS"


@dataclass()
class PipeMessage:
src_: str
dst_: str

msg_type_: PipeMessageType
msg_id_: int

tensor_data_: torch.Tensor
batch_data_: MultiLoraBatchData


# one queue per message type
RPCMessageQueues: Dict[PipeMessageType, Queue] = {
PipeMessageType.ACTIVATIONS: Queue(),
PipeMessageType.GRADIENTS: Queue()
}


def rpc_push_queue(msg: PipeMessage) -> None:
assert msg.msg_type_ in globals()["RPCMessageQueues"], f"unknown message type: {msg.msg_type_.value}"

globals()["RPCMessageQueues"][msg.msg_type_].put(msg)


class Transport(ABC):
rank_: int
device_: torch.device

@property
def next_worker_name(self) -> str:
return f"worker-{self.rank_ + 1}"

@property
def prev_worker_name(self) -> str:
return f"worker-{self.rank_ - 1}"

@property
def worker_name(self) -> str:
return f"worker-{self.rank_}"

@abstractmethod
def recv_message(self, msg_type: PipeMessageType, block: bool = False) -> PipeMessage:
pass

@abstractmethod
def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
pass


class RpcTransport(Transport):
def __init__(self, rank: int, worker_device: torch.device) -> None:
super().__init__()
self.rank_ = rank
self.device_ = worker_device

# recv_message may block while copying the tensor from CPU memory back to GPU memory
def recv_message(self, msg_type: PipeMessageType, block: bool = True) -> PipeMessage:
assert msg_type in globals()["RPCMessageQueues"], f"unknown message type: {msg_type.value}"

queue = globals()["RPCMessageQueues"][msg_type]
if queue.empty():
return None

result: PipeMessage = queue.get(block=block)
result.tensor_data_ = result.tensor_data_.to(self.device_)
return result

# send_message may block while copying the tensor from GPU memory to CPU memory
def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
msg.tensor_data_ = msg.tensor_data_.to("cpu")
if sync:
send_fn = torch.distributed.rpc.rpc_sync
else:
send_fn = torch.distributed.rpc.rpc_async

send_fn(msg.dst_, rpc_push_queue, args=(msg,))
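Finally, a hedged sketch of how the RPC group behind this transport might be initialized. The worker names follow the worker-{rank} convention used by the properties above, while the world size and env-based rendezvous are assumptions; the actual setup presumably lives in mlora/pipeline/pipe.py, which is not expanded on this page.

import os

import torch
import torch.distributed.rpc as rpc

from mlora.pipeline.messages import RpcTransport

def init_worker(rank: int, world_size: int) -> RpcTransport:
    # assumption: rendezvous is driven by the standard env variables
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # the name must match the worker-{rank} scheme so that msg.dst_ resolves to a live
    # peer and rpc_push_queue runs inside the destination process
    rpc.init_rpc(f"worker-{rank}", rank=rank, world_size=world_size)
    return RpcTransport(rank, torch.device(f"cuda:{rank}"))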