Commit a50cbb5 (1 parent: 39c62de)
Showing 6 changed files with 283 additions and 10 deletions.
@@ -0,0 +1,18 @@
import torch


class AutogradWithoutActivations(torch.autograd.Function):
    # Helper to reduce activation memory: forward returns only a dummy
    # scalar, so the local autograd graph does not hold the activations.
    @staticmethod
    def forward(ctx, *x):
        return torch.tensor(1.0)

    @staticmethod
    def backward(ctx, grad):
        # the real gradient is attached to the context by the worker that
        # receives it from the next pipeline stage
        assert ctx.grad_from_next_worker is not None
        return ctx.grad_from_next_worker


class RecvOperator(torch.autograd.Function):
    # backward will automatically send the gradient to the previous worker
    pass
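A minimal usage sketch for AutogradWithoutActivations (hypothetical tensors; not part of this commit). apply() returns only a dummy scalar, so the stage's real activations can be shipped to the next worker without being held by the local graph; when the gradient comes back, it is attached to the function's context (its grad_fn) before calling backward():

# Hypothetical sketch: keep the local graph alive through a dummy scalar
# while the real activations would travel to the next worker over RPC.
import torch

x = torch.randn(4, 8, requires_grad=True)
activation = x * 2  # stage output that would be sent to the next worker

phony = AutogradWithoutActivations.apply(activation)

# later, when the gradient for `activation` arrives from the next worker,
# attach it to the function context and backpropagate through the dummy scalar
phony.grad_fn.grad_from_next_worker = (torch.ones_like(activation),)
phony.backward()
print(x.grad.shape)  # torch.Size([4, 8])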
@@ -0,0 +1,77 @@
import torch
import torch.distributed.rpc

from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Queue
from typing import Dict, Optional, Tuple
from enum import Enum


class PipeMessageType(Enum):
    ACTIVATIONS = "ACTIVATIONS"
    GRADIENTS = "GRADIENTS"


@dataclass()
class PipeMessage:
    src_: str
    dst_: str

    msg_type_: PipeMessageType
    tensors_: Tuple[torch.Tensor, ...]


# one queue per message type
RPCMessageQueues: Dict[str, Queue] = {
    PipeMessageType.ACTIVATIONS.value: Queue(),
    PipeMessageType.GRADIENTS.value: Queue()
}


def rpc_push_queue(msg: PipeMessage) -> None:
    assert msg.msg_type_.value in globals()["RPCMessageQueues"], \
        f"No such message type: {msg.msg_type_.value}"

    globals()["RPCMessageQueues"][msg.msg_type_.value].put(msg)


class Transport(ABC):
    device_: torch.device

    @abstractmethod
    def recv_message(self, msg_type: PipeMessageType, block: bool = False) -> Optional[PipeMessage]:
        pass

    @abstractmethod
    def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
        pass


class RpcTransport(Transport):
    def __init__(self, worker_device: torch.device) -> None:
        super().__init__()
        self.device_ = worker_device

    # receiving blocks while the tensors are copied from CPU to GPU memory
    def recv_message(self, msg_type: PipeMessageType, block: bool = True) -> Optional[PipeMessage]:
        assert msg_type.value in globals()["RPCMessageQueues"], \
            f"No such message type: {msg_type.value}"

        queue = globals()["RPCMessageQueues"][msg_type.value]
        if queue.empty():
            return None

        result = queue.get(block=block)
        result.tensors_ = tuple(t.to(self.device_) for t in result.tensors_)
        return result

    # sending blocks while the tensors are copied from GPU to CPU memory
    def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
        msg.tensors_ = tuple(t.to("cpu") for t in msg.tensors_)
        if sync:
            send_fn = torch.distributed.rpc.rpc_sync
        else:
            send_fn = torch.distributed.rpc.rpc_async

        send_fn(msg.dst_, rpc_push_queue, args=(msg,))
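A sketch of how two workers might exchange a message through RpcTransport (hypothetical worker names, devices, and shapes; not part of this commit). The sender moves its tensors to CPU and pushes a PipeMessage into the receiver's RPCMessageQueues over RPC; the receiver polls its local queue and moves the tensors back onto its own device:

# Hypothetical sketch; assumes torch.distributed.rpc is already initialized
# in both processes with the names "worker-0" and "worker-1"
# (as Pipe.init_rpc does below).

# --- on worker-0: send activations to worker-1 ---
transport = RpcTransport(torch.device("cuda:0"))
msg = PipeMessage(
    src_="worker-0",
    dst_="worker-1",
    msg_type_=PipeMessageType.ACTIVATIONS,
    tensors_=(torch.randn(2, 4, device="cuda:0"),),
)
transport.send_message(msg, sync=False)

# --- on worker-1: poll the local queue for incoming activations ---
transport = RpcTransport(torch.device("cuda:1"))
received = transport.recv_message(PipeMessageType.ACTIVATIONS, block=False)
if received is not None:
    print(received.src_, received.tensors_[0].shape)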
@@ -0,0 +1,159 @@
from mlora.modelargs import MultiLoraBatchData, LoraBatchDataConfig
from mlora.model import LLMModel, precompute_mask
from mlora.pipeline.messages import Transport, RpcTransport
from mlora.pipeline.messages import PipeMessage, PipeMessageType

import os
import torch
import time
import random
import logging

import torch.distributed.rpc
from typing import List
from enum import Enum, auto
from dataclasses import dataclass


class WorkerRole(Enum):
    HEAD = auto()
    MID = auto()
    TAIL = auto()


@dataclass(frozen=True)
class PipeModelLocInfo():
    n_heads_: int = -1
    rope_angle_: torch.Tensor = None


class Pipe():
    world_size_: int = -1
    balance_: List[int] = None

    rank_: int = -1
    device_: torch.device = None

    model_partition_: torch.nn.Sequential = torch.nn.Sequential()
    model_info_: PipeModelLocInfo = None
    role_: WorkerRole = None

    transport_: Transport = None

    def __init__(self,
                 model: LLMModel,
                 device: torch.device,
                 rank: int,
                 balance: List[int]) -> None:

        world_size = torch.cuda.device_count()
        assert world_size == len(balance)

        self.rank_ = rank
        self.world_size_ = world_size
        self.balance_ = balance
        self.device_ = device
        self.model_info_ = PipeModelLocInfo(
            n_heads_=model.n_heads_,
            rope_angle_=model.rope_angle_
        )

        if rank == 0:
            self.role_ = WorkerRole.HEAD
        elif rank == len(balance) - 1:
            self.role_ = WorkerRole.TAIL
        else:
            self.role_ = WorkerRole.MID

        self.transport_ = RpcTransport(self.device_)
        self.init_partition(model)
        self.init_rpc()

    @property
    def next_worker_name(self) -> str:
        return f"worker-{self.rank_ + 1}"

    @property
    def prev_worker_name(self) -> str:
        return f"worker-{self.rank_ - 1}"

    # main worker loop: process pending backward, then forward, then new input
    def run(self):
        while True:
            self.process_backward()
            self.process_forward()
            self.process_input()
            time.sleep(5)

    def process_input(self) -> None:
        if self.role_ != WorkerRole.HEAD:
            return
        logging.debug("Generate dummy train input data.")
        # dummy token batch used as a placeholder input for the head worker
        test_data_len = 96
        tokens_list = [[random.random() for _ in range(test_data_len)]]
        input = MultiLoraBatchData(
            batch_tokens_=tokens_list,
            additional_mask_=None,
            lora_batch_data_config_=[LoraBatchDataConfig(
                adapter_name_="lora_0",
                batch_start_idx_=0,
                batch_end_idx_=1
            )],
            inference_model_=False
        )

        tokens_batch = torch.tensor(
            input.batch_tokens_, dtype=torch.int64).to(self.device_)
        mask = precompute_mask(
            tokens_batch, self.model_info_.n_heads_, self.device_)

        dummy_data = (tokens_batch, mask,
                      self.model_info_.rope_angle_, input, True)

        for seq_model in self.model_partition_:
            dummy_data = seq_model.forward(dummy_data)

    def process_forward(self) -> None:
        if self.role_ == WorkerRole.HEAD:
            return

    def process_backward(self) -> None:
        logging.debug("To recv the gradients message.")
        message = self.transport_.recv_message(
            PipeMessageType.GRADIENTS, block=False)
        if message is None:
            logging.debug("No gradients to process.")
            return
        logging.debug(f"Already received the gradients from {message.src_}")

    def init_partition(self, model: LLMModel) -> None:
        balance = self.balance_[self.rank_]
        start_module_idx = sum(self.balance_[:self.rank_])
        logging.info(
            f"RANK-{self.rank_} in device {self.device_} to load module layers "
            f"from {start_module_idx} to {start_module_idx + balance}.")

        seq_model = model.sequential_module()
        del seq_model[:start_module_idx]
        del seq_model[balance:]
        for idx in range(0, len(seq_model)):
            self.model_partition_.append(seq_model[idx])

        assert len(self.model_partition_) == balance

        del model

        torch.cuda.empty_cache()

    def init_rpc(self) -> None:
        if "MASTER_ADDR" not in os.environ:
            os.environ["MASTER_ADDR"] = "localhost"
        if "MASTER_PORT" not in os.environ:
            os.environ["MASTER_PORT"] = "12355"

        # blocks until all world_size workers have joined the group
        torch.distributed.rpc.init_rpc(
            f"worker-{self.rank_}", rank=self.rank_, world_size=self.world_size_)

        logging.info(
            f"Init rpc with rank {self.rank_} world_size: {self.world_size_}")