commit 9f95457 (1 parent: 04bf63d)
Showing 6 changed files with 333 additions and 10 deletions.
@@ -0,0 +1,33 @@
from mlora.pipeline.messages import Transport, PipeMessage, PipeMessageType
from mlora.model import MultiLoraBatchData

import torch
from typing import Tuple


class SendOperator(torch.autograd.Function):
    # helper to reduce the activation memory
    @staticmethod
    def forward(ctx,
                transport: Transport,
                tensors: Tuple[torch.Tensor, ...],
                msg_id: int,
                input_args: MultiLoraBatchData):
        msg = PipeMessage(src_=transport.worker_name,
                          dst_=transport.next_worker_name,
                          msg_type_=PipeMessageType.ACTIVATIONS,
                          msg_id_=msg_id,
                          tensors_=tensors,
                          batch_data_=input_args)
        transport.send_message(msg, False)
        return torch.tensor(1.0)

    @staticmethod
    def backward(ctx, grad):
        assert ctx.grad_from_next_worker is not None
        return ctx.grad_from_next_worker


class RecvOperator(torch.autograd.Function):
    # backward will automatically send the grad to the previous worker
    pass
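
In this commit RecvOperator is still only a stub; its comment states the intent (backward should send the gradient back to the previous worker), but no forward/backward is implemented yet. Below is a minimal hypothetical sketch of what the receive side could look like, mirroring the pattern used by SendOperator and assuming the Transport/PipeMessage API from mlora.pipeline.messages in the next file; the class name and the assumption that only the received activation carries a gradient are illustrative, not part of the commit.

# Hypothetical sketch only -- not part of this commit.
import torch
from mlora.pipeline.messages import Transport, PipeMessage, PipeMessageType

class RecvOperatorSketch(torch.autograd.Function):
    # forward unpacks the activation received from the previous worker;
    # backward ships the incoming gradient back to that worker
    @staticmethod
    def forward(ctx, phony: torch.Tensor, transport: Transport, msg: PipeMessage) -> torch.Tensor:
        ctx.transport_ = transport
        ctx.msg_id_ = msg.msg_id_
        # the first (and here only) tensor in the message is the activation
        return msg.tensors_[0]

    @staticmethod
    def backward(ctx, grad: torch.Tensor):
        transport: Transport = ctx.transport_
        grad_msg = PipeMessage(src_=transport.worker_name,
                               dst_=transport.prev_worker_name,
                               msg_type_=PipeMessageType.GRADIENTS,
                               msg_id_=ctx.msg_id_,
                               tensors_=(grad,),
                               batch_data_=None)
        transport.send_message(grad_msg, False)
        # forward took (phony, transport, msg); only the phony tensor needs a gradient
        return torch.tensor(1.0), None, None

As with SendOperator, a scalar phony tensor keeps the node attached to the autograd graph so that backward fires when the cached phony is back-propagated.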
@@ -0,0 +1,97 @@
from mlora.model import MultiLoraBatchData

import torch
import torch.distributed.rpc

from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Queue
from typing import Dict, Tuple
from enum import Enum


class PipeMessageType(Enum):
    ACTIVATIONS = "ACTIVATIONS"
    GRADIENTS = "GRADIENTS"


@dataclass()
class PipeMessage:
    src_: str
    dst_: str

    msg_type_: PipeMessageType
    msg_id_: int

    tensors_: Tuple[torch.Tensor, ...]
    batch_data_: MultiLoraBatchData


# messages are stored in separate queues, keyed by message type
RPCMessageQueues: Dict[str, Queue] = {
    PipeMessageType.ACTIVATIONS.value: Queue(),
    PipeMessageType.GRADIENTS.value: Queue()
}


def rpc_push_queue(msg: PipeMessage) -> None:
    print(RPCMessageQueues)
    assert msg.msg_type_.value in globals(
    )["RPCMessageQueues"], f"No this message type: {msg.msg_type_.value}"

    globals()["RPCMessageQueues"][msg.msg_type_.value].put(msg)


class Transport(ABC):
    rank_: int
    device_: torch.device

    @property
    def next_worker_name(self) -> str:
        return f"worker-{self.rank_ + 1}"

    @property
    def prev_worker_name(self) -> str:
        return f"worker-{self.rank_ - 1}"

    @property
    def worker_name(self) -> str:
        return f"worker-{self.rank_}"

    @abstractmethod
    def recv_message(self, msg_type: PipeMessageType, block: bool = False) -> PipeMessage:
        pass

    @abstractmethod
    def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
        pass


class RpcTransport(Transport):
    def __init__(self, rank: int, worker_device: torch.device) -> None:
        super().__init__()
        self.rank_ = rank
        self.device_ = worker_device

    # receiving blocks while the tensors are moved from CPU memory back to GPU memory
    def recv_message(self, msg_type: PipeMessageType, block: bool = True) -> PipeMessage:
        assert msg_type.value in globals(
        )["RPCMessageQueues"], f"No this message type: {msg_type.value}"

        queue = globals()["RPCMessageQueues"][msg_type.value]
        if queue.empty():
            return None

        result = queue.get(block=block)
        result.tensors_ = tuple(t.to(self.device_) for t in result.tensors_)
        return result

    # sending blocks while the tensors are moved from GPU memory to CPU memory
    def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
        msg.tensors_ = tuple(t.to("cpu") for t in msg.tensors_)
        if sync:
            send_fn = torch.distributed.rpc.rpc_sync
        else:
            send_fn = torch.distributed.rpc.rpc_async

        send_fn(msg.dst_, rpc_push_queue, args=(msg,))
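
Taken together, rpc_push_queue and RpcTransport form a simple mailbox: the sender serializes its tensors to CPU and RPCs them into the destination worker's per-type queue, and the receiver polls that queue and moves the tensors back onto its own device. A rough usage sketch follows; it assumes two processes that have already joined the RPC group as "worker-0" and "worker-1", uses None as a stand-in for the MultiLoraBatchData payload, and an arbitrary tensor shape.

# Illustrative sketch only -- not part of this commit.
import torch
from mlora.pipeline.messages import RpcTransport, PipeMessage, PipeMessageType

def send_side(transport: RpcTransport) -> None:
    # runs on worker-0: ship one activation tensor to the next worker
    msg = PipeMessage(src_=transport.worker_name,        # "worker-0"
                      dst_=transport.next_worker_name,   # "worker-1"
                      msg_type_=PipeMessageType.ACTIVATIONS,
                      msg_id_=0,
                      tensors_=(torch.randn(1, 96, 4096),),  # arbitrary example shape
                      batch_data_=None)
    # send_message moves the tensors to CPU, then rpc_async invokes
    # rpc_push_queue on worker-1, which enqueues the message there
    transport.send_message(msg, sync=False)

def recv_side(transport: RpcTransport) -> torch.Tensor:
    # runs on worker-1: poll the ACTIVATIONS queue
    msg = transport.recv_message(PipeMessageType.ACTIVATIONS, block=False)
    if msg is None:
        return None  # nothing has arrived yet
    # recv_message has already moved the tensors onto this worker's device
    return msg.tensors_[0]

Because the queues live in module-level globals, rpc_push_queue can be invoked remotely by name without any per-connection state.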
@@ -0,0 +1,174 @@
from mlora.modelargs import MultiLoraBatchData, LoraBatchDataConfig
from mlora.model import LLMModel, precompute_mask
from mlora.pipeline.messages import Transport, RpcTransport
from mlora.pipeline.messages import PipeMessage, PipeMessageType
from mlora.pipeline.function import SendOperator, RecvOperator

import os
import torch
import time
import logging

import torch.distributed.rpc
from typing import List, Dict
from enum import Enum, auto
from dataclasses import dataclass


class WorkerRole(Enum):
    HEAD = auto()
    MID = auto()
    TAIL = auto()


@dataclass(frozen=True)
class PipeModelLocInfo():
    n_heads_: int = -1
    rope_angle_: torch.Tensor = None


class Pipe():
    world_size_: int = -1
    balance_: List[int] = None

    rank_: int = -1
    device_: torch.device = None

    model_partition_: torch.nn.Sequential = torch.nn.Sequential()
    model_info_: PipeModelLocInfo = None
    role_: WorkerRole = None

    transport_: Transport = None

    backward_cache_: Dict[int, torch.Tensor] = {}

    def __init__(self,
                 model: LLMModel,
                 device: torch.device,
                 rank: int,
                 balance: List[int]) -> None:

        world_size = torch.cuda.device_count()
        assert world_size == len(balance)

        self.rank_ = rank
        self.world_size_ = world_size
        self.balance_ = balance
        self.device_ = device
        self.model_info_ = PipeModelLocInfo(
            n_heads_=model.n_heads_,
            rope_angle_=model.rope_angle_
        )

        if rank == 0:
            self.role_ = WorkerRole.HEAD
        elif rank == len(balance) - 1:
            self.role_ = WorkerRole.TAIL
        else:
            self.role_ = WorkerRole.MID

        self.transport_ = RpcTransport(self.rank_, self.device_)
        self.init_partition(model)
        self.init_rpc()

    def run(self):
        while True:
            self.process_backward()
            self.process_forward()
            time.sleep(5)

    def process_input(self) -> None:
        if self.role_ != WorkerRole.HEAD:
            return
        logging.debug("Train input data.")
        import random
        test_data_len = 96
        tokens_list = [[random.random() for _ in range(test_data_len)]]
        batch_data = MultiLoraBatchData(
            batch_tokens_=tokens_list,
            additional_mask_=None,
            lora_batch_data_config_=[LoraBatchDataConfig(
                adapter_name_="lora_0",
                batch_start_idx_=0,
                batch_end_idx_=1
            )],
            inference_model_=False
        )

        tokens_batch = torch.tensor(
            batch_data.batch_tokens_, dtype=torch.int64).to(self.device_)
        mask = precompute_mask(
            tokens_batch, self.model_info_.n_heads_, self.device_)

        dummy_data = (tokens_batch, mask,
                      self.model_info_.rope_angle_, batch_data, True)

        for seq_model in self.model_partition_:
            dummy_data = seq_model.forward(dummy_data)

        msg_id = id(dummy_data[0])
        phony: torch.Tensor = SendOperator.apply(
            self.transport_, (dummy_data[0], ), msg_id, batch_data)
        self.backward_cache_[msg_id] = phony

    def process_forward(self) -> None:
        if self.role_ == WorkerRole.HEAD:
            return self.process_input()
        logging.debug("To recv the activations message.")
        message = self.transport_.recv_message(
            PipeMessageType.ACTIVATIONS, block=False)
        if message is None:
            logging.debug("No activations to process.")
            return
        logging.debug(f"Have already recv the activations from {message.src_}")

        tokens_batch = message.tensors_[0]
        mask = precompute_mask(
            tokens_batch, self.model_info_.n_heads_, self.device_)
        batch_data = message.batch_data_

        dummy_data = (tokens_batch, mask,
                      self.model_info_.rope_angle_, batch_data, True)

        for seq_model in self.model_partition_:
            dummy_data = seq_model.forward(dummy_data)

    def process_backward(self) -> None:
        logging.debug("To recv the gradients message.")
        message = self.transport_.recv_message(
            PipeMessageType.GRADIENTS, block=False)
        if message is None:
            logging.debug("No gradients to process.")
            return
        logging.debug(f"Have already recv the gradients from {message.src_}")

    def init_partition(self, model: LLMModel) -> None:
        balance = self.balance_[self.rank_]
        start_module_idx = sum(self.balance_[:self.rank_])
        logging.info(
            f"RANK-{self.rank_} in device {self.device_} to load module layers from {start_module_idx} to {start_module_idx + balance}.")

        seq_model = model.sequential_module()
        del seq_model[:start_module_idx]
        del seq_model[balance:]
        for idx in range(0, len(seq_model)):
            self.model_partition_.append(seq_model[idx])

        assert len(self.model_partition_) == balance

        del model

        torch.cuda.empty_cache()

    def init_rpc(self) -> None:
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = "localhost"
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "12355"

        # init_rpc blocks until all workers in the world (one per GPU) have joined the group
        torch.distributed.rpc.init_rpc(
            f"worker-{self.rank_}", rank=self.rank_, world_size=self.world_size_)

        logging.info(
            f"Init rpc with rank {self.rank_} world_size: {self.world_size_}")