Commit 453b176 (parent 04bf63d): 6 changed files with 413 additions and 11 deletions.
@@ -0,0 +1,67 @@
from mlora.pipeline.messages import Transport, PipeMessage, PipeMessageType
from mlora.model import MultiLoraBatchData

import logging
import torch


class SendOperator(torch.autograd.Function):
    # helper to reduce activation memory: forward sends the activations to the
    # next worker and returns only the phony tensor
    @staticmethod
    def forward(ctx,
                phony: torch.Tensor,
                tensor_data: torch.Tensor,
                transport: Transport,
                msg_id: int,
                input_args: MultiLoraBatchData):
        assert isinstance(tensor_data, torch.Tensor)

        msg = PipeMessage(src_=transport.worker_name,
                          dst_=transport.next_worker_name,
                          msg_type_=PipeMessageType.ACTIVATIONS,
                          msg_id_=msg_id,
                          tensor_data_=tensor_data,
                          batch_data_=input_args)
        transport.send_message(msg, False)

        return phony

    @staticmethod
    def backward(ctx, grad_output):
        # the gradient received from the next worker is expected to be
        # attached to ctx by the caller before backward runs
        assert ctx.grad_from_next_worker is not None

        return (None, ctx.grad_from_next_worker, None, None, None)


class RecvOperator(torch.autograd.Function):
    # backward automatically sends the gradient back to the previous worker
    @staticmethod
    def forward(ctx,
                phony: torch.Tensor,
                transport: Transport,
                msg: PipeMessage) -> torch.Tensor:
        assert msg.msg_type_ == PipeMessageType.ACTIVATIONS
        assert isinstance(msg.tensor_data_, torch.Tensor)

        ctx.msg_id_ = msg.msg_id_
        ctx.transport_ = transport

        return msg.tensor_data_ * phony

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # currently only a single gradient tensor is supported
        assert isinstance(grad_output, torch.Tensor)

        transport: Transport = ctx.transport_
        logging.debug(f"Send the gradients to {transport.prev_worker_name}")
        transport.send_message(PipeMessage(
            src_=transport.worker_name,
            dst_=transport.prev_worker_name,
            msg_type_=PipeMessageType.GRADIENTS,
            msg_id_=ctx.msg_id_,
            tensor_data_=grad_output,
            batch_data_=None,
        ))

        return (None, None, None)
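The driver code that wires these two operators into a pipeline stage is not part of this commit, so the following is only a sketch of how they could be used. The stage_step helper, the layers callable, and the one-element phony tensor are assumptions; how the scheduler attaches grad_from_next_worker to the saved SendOperator context is likewise outside this diff.

# Hypothetical wiring for one pipeline stage (not part of this commit).
# SendOperator / RecvOperator are the classes defined above; `stage_step`,
# `layers`, and the scalar phony are assumptions about the driver code.
import torch

from mlora.pipeline.messages import PipeMessageType, Transport


def stage_step(layers, transport: Transport, msg_id: int, input_args):
    # the phony is a one-element tensor that only exists to carry autograd
    # edges across the communication ops; multiplying by it in
    # RecvOperator.forward leaves the activation values unchanged
    phony = torch.ones(1, requires_grad=True, device=transport.device_)

    # pull the activations sent by the previous worker and re-attach them
    # to the local graph; RecvOperator.backward ships the gradient back
    msg = transport.recv_message(PipeMessageType.ACTIVATIONS, block=True)
    hidden = RecvOperator.apply(phony, transport, msg)

    # run this worker's shard of the model
    hidden = layers(hidden)

    # forward the result to the next worker; forward() returns only the
    # phony, so the full activation is not kept alive for backward here
    return SendOperator.apply(phony, hidden, transport, msg_id, input_args)

Because the activations arrive over RPC with no autograd history, RecvOperator is what re-attaches them to the local graph: the phony input requires grad, so the operator's output does too, and its backward forwards the incoming gradient to the previous worker.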
@@ -0,0 +1,96 @@
from mlora.model import MultiLoraBatchData

import torch
import torch.distributed.rpc

from abc import ABC, abstractmethod
from dataclasses import dataclass
from queue import Queue
from typing import Dict, Optional
from enum import Enum


class PipeMessageType(Enum):
    ACTIVATIONS = "ACTIVATIONS"
    GRADIENTS = "GRADIENTS"


@dataclass()
class PipeMessage:
    src_: str
    dst_: str

    msg_type_: PipeMessageType
    msg_id_: int

    tensor_data_: torch.Tensor
    batch_data_: MultiLoraBatchData


# one queue per message type
RPCMessageQueues: Dict[PipeMessageType, Queue] = {
    PipeMessageType.ACTIVATIONS: Queue(),
    PipeMessageType.GRADIENTS: Queue()
}


def rpc_push_queue(msg: PipeMessage) -> None:
    assert msg.msg_type_ in globals()["RPCMessageQueues"], \
        f"Unknown message type: {msg.msg_type_.value}"

    globals()["RPCMessageQueues"][msg.msg_type_].put(msg)


class Transport(ABC):
    rank_: int
    device_: torch.device

    @property
    def next_worker_name(self) -> str:
        return f"worker-{self.rank_ + 1}"

    @property
    def prev_worker_name(self) -> str:
        return f"worker-{self.rank_ - 1}"

    @property
    def worker_name(self) -> str:
        return f"worker-{self.rank_}"

    @abstractmethod
    def recv_message(self, msg_type: PipeMessageType, block: bool = False) -> Optional[PipeMessage]:
        pass

    @abstractmethod
    def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
        pass


class RpcTransport(Transport):
    def __init__(self, rank: int, worker_device: torch.device) -> None:
        super().__init__()
        self.rank_ = rank
        self.device_ = worker_device

    # receiving blocks on the CPU-to-GPU copy of the tensor payload;
    # returns None immediately if no message of this type is queued
    def recv_message(self, msg_type: PipeMessageType, block: bool = True) -> Optional[PipeMessage]:
        assert msg_type in globals()["RPCMessageQueues"], \
            f"Unknown message type: {msg_type.value}"

        queue = globals()["RPCMessageQueues"][msg_type]
        if queue.empty():
            return None

        result: PipeMessage = queue.get(block=block)
        result.tensor_data_ = result.tensor_data_.to(self.device_)
        return result

    # sending blocks on the GPU-to-CPU copy before the RPC is issued
    def send_message(self, msg: PipeMessage, sync: bool = False) -> None:
        msg.tensor_data_ = msg.tensor_data_.to("cpu")
        if sync:
            send_fn = torch.distributed.rpc.rpc_sync
        else:
            send_fn = torch.distributed.rpc.rpc_async

        send_fn(msg.dst_, rpc_push_queue, args=(msg,))
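This diff does not show how the RPC backend itself is brought up, so here is a rough launcher sketch assuming the standard torch.distributed.rpc API and the worker-{rank} naming convention that Transport's name properties expect. run_worker, the environment defaults, and the two-worker demo payload are placeholders, not part of this commit.

# Hypothetical launcher (not part of this commit): starts the RPC backend
# with the same "worker-{rank}" names that Transport assumes, then exchanges
# one activation message between worker-0 and worker-1.
import os

import torch
import torch.distributed.rpc as rpc

from mlora.pipeline.messages import PipeMessage, PipeMessageType, RpcTransport


def run_worker(rank: int, world_size: int) -> None:
    # single-machine demo defaults; a real launcher would configure these
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc(name=f"worker-{rank}", rank=rank, world_size=world_size)

    # CPU for this demo; a real pipeline would pass each worker's GPU device
    transport = RpcTransport(rank, torch.device("cpu"))

    if rank == 0:
        # worker-0 pushes an activation message into worker-1's queue via RPC
        msg = PipeMessage(src_=transport.worker_name,
                          dst_=transport.next_worker_name,
                          msg_type_=PipeMessageType.ACTIVATIONS,
                          msg_id_=0,
                          tensor_data_=torch.randn(2, 16),
                          batch_data_=None)
        transport.send_message(msg, sync=True)
    else:
        # worker-1 polls its local queue until the message has arrived
        recv = None
        while recv is None:
            recv = transport.recv_message(PipeMessageType.ACTIVATIONS)
        print(recv.msg_id_, recv.tensor_data_.shape)

    rpc.shutdown()

Each rank would execute run_worker in its own process (for example via torch.multiprocessing.spawn). Because recv_message returns None when its local queue is empty, the receiving side polls until the rpc call from worker-0 has invoked rpc_push_queue and deposited the message into the ACTIVATIONS queue.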