Skip to content

Commit

Permalink
swarm updated
Browse files Browse the repository at this point in the history
  • Loading branch information
NikolayBlagoev committed Mar 6, 2024
1 parent 43fc3ad commit 61a8210
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
32 changes: 23 additions & 9 deletions swarmprotocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from deccom.cryptofuncs.hash import SHA256
from deccom.peers.peer import Peer
from deccom.protocols.abstractprotocol import AbstractProtocol
from deccom.protocols.streamprotocol import StreamProtocol

import torch.nn as nn
import torch.nn.functional as F
from torch import tensor, mean, stack, cat, split, zeros_like
import pickle
from deccom.protocols.wrappers import *
from datetime import datetime

import asyncio
Expand All @@ -33,7 +34,7 @@ class SwarmProtocol(AbstractProtocol):
})
required_lower = AbstractProtocol.required_lower + \
["find_peer", "set_stream_callback",
"open_connection", "send_stream", "get_peer", "get_peers", "set_connected_callback","set_disconnect_callback"]
"open_connection", "send_stream", "get_peer", "get_peers", "connected_callback","disconnect_callback"]
INTRODUCTION = int.from_bytes(b'\xd1', byteorder="big")
COMPLETE = int.from_bytes(b'\xd8', byteorder="big")

Expand All @@ -44,11 +45,7 @@ def __init__(self, rank, net, optimizer, dataloader=None, submodule=None, callba

self.rank = rank
self.dataloader = dataloader
self._lower_find_peer = lambda: ...
self._lower_open_connection = lambda: ...
self._lower_send_stream = lambda: ...
self._lower_get_peer = lambda: ...
self._lower_get_peers = lambda: ...


self.net: nn.Module = net
self.sizes = []
Expand All @@ -70,6 +67,21 @@ def __init__(self, rank, net, optimizer, dataloader=None, submodule=None, callba
self.same_stage = []
self.outstanding: dict[bytes,asyncio.TimerHandle] = dict()
self.forward_start = None
@bindto("find_peer")
async def _lower_find_peer(self, p: Peer):
return None
@bindto("open_connection")
async def _lower_open_connection(self):
return
@bindto("send_stream")
async def send_stream(self, node_id, data):
return

@bindto("get_peer")
def _lower_get_peer(self, p: Peer):
return None

@bindfrom("connected_callback")
def peer_connected(self, nodeid):
# print("NEW PEER")
loop = asyncio.get_running_loop()
Expand Down Expand Up @@ -137,8 +149,8 @@ def send_back(self, seq_id, data):
return
async def start(self, p: Peer):
await super().start(p)
peers = self._lower_get_peers()
for _,p in peers.items():

for _,p in self.bootstrap_peers:
# print("introducing to ",p.addr)
msg = bytearray([SwarmProtocol.INTRODUCTION])
msg = msg + int(self.rank).to_bytes(4,byteorder="big") + self.peer.id_node
Expand Down Expand Up @@ -166,6 +178,7 @@ async def start(self, p: Peer):
ret = SwarmProtocol.train(self.net,self.optimizer, data, rank = 0, stage=1)
ret.retain_grad()
await self.send_forward(new_seq_id,target, ret)
@bindfrom("disconnect_callback")
def peer_disconnected(self,addr, node_id):
if node_id in self.same_stage:
self.same_stage.remove(node_id)
Expand Down Expand Up @@ -226,6 +239,7 @@ def check_for_back(self):
loop = asyncio.get_running_loop()
loop.create_task(
self.send_forward(new_seq_id,target, ret))
@bindfrom("stream_callback")
def process_data(self, data:bytes, nodeid, addr):
seq_id = bytes(data[0:8])
stage = int.from_bytes(data[8:12],byteorder="big")
Expand Down
1 change: 0 additions & 1 deletion swarmtrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from deccom.cryptofuncs.hash import SHA256
from deccom.nodes import StreamNode
from deccom.protocols.defaultprotocol import DefaultProtocol
from deccom.protocols.peerdiscovery.gossipdiscovery import GossipProtocol
from deccom.peers import Peer
from deccom.protocols.streamprotocol import StreamProtocol
from trainingnode import TrainingNode
Expand Down

0 comments on commit 61a8210

Please sign in to comment.