bug fixes
NikolayBlagoev committed Mar 6, 2024
1 parent 67be3c0 commit 8b304a2
Showing 5 changed files with 16 additions and 9 deletions.
deccom/protocols/peerdiscovery/abstractpeerdiscovery.py (2 changes: 1 addition & 1 deletion)
@@ -62,4 +62,4 @@ def get_peers(self) -> dict[bytes, Peer]:
         return self.peers
 
     async def find_peer(self, id: bytes) -> Peer:
-        return None
+        return self.get_peer(id)
deccom/protocols/peerdiscovery/kademliadiscovery.py (2 changes: 1 addition & 1 deletion)
@@ -247,7 +247,7 @@ async def _find_peer(self, fut, id):
             return
 
     async def find_peer(self, id) -> Peer:
-        if self.peers.get(id) == None:
+        if self.get_peer(id) == None:
             if self.peer_crawls.get(id) == None:
                 loop = asyncio.get_running_loop()
                 fut = loop.create_future()
deccom/protocols/streamprotocol.py (2 changes: 1 addition & 1 deletion)
@@ -101,7 +101,7 @@ async def open_connection(self, remote_ip, remote_port, node_id: bytes, duplicat
             print("OPENING TO SELF???")
             return
         if remote_port == None:
-            return None
+            return None
         if self.connections.get(node_id) != None:
             # print("duplicate connection OPENED")
 
trainer.py (2 changes: 1 addition & 1 deletion)
@@ -109,7 +109,7 @@ def forward(self, x):
 
 optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                       momentum=momentum)
-training = TrainingProtocol(6,3,int(argv[1]),net,optimizer,train_loader)
+training = TrainingProtocol(3,3,int(argv[1]),net,optimizer,train_loader)
 training.set_lower(stream)
 me = TrainingNode(training,"127.0.0.1", 10015 if argv[1] == "0" else None)
 print( "TCP", me.tcp_port)
trainingprotocol.py (17 changes: 12 additions & 5 deletions)
@@ -7,6 +7,7 @@
 from deccom.protocols.streamprotocol import StreamProtocol
 from deccom.protocols.wrappers import *
 import torch.nn as nn
+from torch import zeros_like
 import torch.nn.functional as F
 import torch.optim as optim
 from torch import tensor, mean, stack, cat, split
@@ -91,6 +92,7 @@ async def start(self):
             batch_idx, ret = next(self.dataloader)
             data = ret['text']
             target = ret['text']
+            print(data.shape, target.shape)
         except StopIteration :
             print("TRAINING COMPLETE")
             return
@@ -118,15 +120,16 @@ def _apply_grad(self):
         for i, param in enumerate(self.net.parameters()):
             param.data = param.data - 0.01*tmp[i].view(self.sizes[i])
         self.aggregation = []
-    @bindfrom("st")
+    @bindfrom("stream_callback")
     def process_data(self, data:bytes, nodeid, addr):
         seq_id = bytes(data[0:8])
 
         data=pickle.loads(data[8:])
         peer: Peer = self._lower_get_peer(nodeid)
         if nodeid == self.prev:
             if self.pipeline_rank == 0:
-                loss = self.net.task_layer(data,self.buffer_in.get(seq_id)[0])
+                loss = self.net.task_layer(data,self.buffer_in.get(seq_id))
+
                 loss.backward()
                 if self.iter % 100 == 0:
                     print(loss.item())
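
For context on the hunk above: dropping the `[0]` index suggests the per-sequence buffer now stores the target tensor directly rather than a tuple. A minimal, hypothetical sketch of that bookkeeping pattern in plain PyTorch; the names `buffer_in`, `remember`, and `on_activations`, and the use of `CrossEntropyLoss` as a stand-in for `net.task_layer`, are illustrative assumptions, not the deccom/TrainingProtocol API.

import torch
import torch.nn as nn

# Hypothetical stand-in for the pipeline-rank-0 bookkeeping: store the target
# by sequence id when a micro-batch is dispatched, then compute the loss when
# the corresponding activations arrive back.
buffer_in: dict[bytes, torch.Tensor] = {}
task_layer = nn.CrossEntropyLoss()  # stand-in for net.task_layer

def remember(seq_id: bytes, target: torch.Tensor) -> None:
    # Store the tensor itself (not a tuple), matching the fix above.
    buffer_in[seq_id] = target

def on_activations(seq_id: bytes, activations: torch.Tensor) -> torch.Tensor:
    target = buffer_in.pop(seq_id)          # retrieve and free the slot
    loss = task_layer(activations, target)  # loss on the returned activations
    loss.backward()
    return loss
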
@@ -150,7 +153,11 @@ def process_data(self, data:bytes, nodeid, addr):
             tmp = []
             self.len_sizes = []
             for param in self.net.parameters():
-                tmp.append(param.grad.view(-1))
+                if param.grad == None:
+                    tmp.append(zeros_like(param.view(-1)))
+                else:
+                    tmp.append(param.grad.view(-1))
+
                 self.len_sizes.append(len(tmp[-1]))
             loop = asyncio.get_running_loop()
             self.prev_grad = cat(tmp)
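
The `param.grad == None` guard above addresses a common PyTorch situation: parameters that took no part in the backward pass have no `.grad` tensor, so `param.grad.view(-1)` would raise an `AttributeError` and the flattened vector would change length across workers. A standalone sketch of the same flatten-with-fallback pattern in plain PyTorch; `flatten_grads` is a hypothetical helper, not a method of TrainingProtocol.

import torch
import torch.nn as nn

def flatten_grads(net: nn.Module) -> tuple[torch.Tensor, list[int]]:
    # Concatenate every parameter gradient into one flat vector; parameters
    # whose .grad is None contribute zeros of the same size, so the layout
    # stays identical on every worker.
    chunks, sizes = [], []
    for param in net.parameters():
        if param.grad is None:
            chunks.append(torch.zeros_like(param).view(-1))
        else:
            chunks.append(param.grad.view(-1))
        sizes.append(chunks[-1].numel())
    return torch.cat(chunks), sizes
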
@@ -160,10 +167,10 @@ def process_data(self, data:bytes, nodeid, addr):
             loop.create_task(self.send_stream(peer,pickle.dumps(self.prev_grad),seqdata=seq_id))
 
         self.aggregation.append(self.prev_grad)
-        # print("calculating\n\n\n\n",len(self.dp_group))
+        print("calculating\n\n\n\n",len(self.dp_group))
         if len(self.aggregation) == len(self.dp_group):
             self._apply_grad()
-            # print("\n\n\n\ncalculated")
+            print("\n\n\n\ncalculated")
             try:
                 batch_idx, ret = next(self.dataloader)
                 data = ret['text']
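
Tying the last two hunks together: each data-parallel peer contributes one flat gradient vector to `self.aggregation`, and once `len(self.aggregation) == len(self.dp_group)` the update in `_apply_grad` is applied. A hedged sketch of that aggregate-and-apply step in plain PyTorch: averaging with `mean`/`stack` is an assumption suggested by the imports, the visible `_apply_grad` line only shows a fixed `0.01` step on each reshaped slice, and `apply_flat_gradients` is a hypothetical helper name.

import torch
import torch.nn as nn

def apply_flat_gradients(net: nn.Module,
                         aggregation: list[torch.Tensor],
                         lr: float = 0.01) -> None:
    # Average the flat gradient vectors collected from all data-parallel
    # peers, split the result back into per-parameter slices, and take one
    # SGD-style step. Assumes every vector was built with the same layout,
    # e.g. by flatten_grads above.
    avg = torch.mean(torch.stack(aggregation), dim=0)
    lengths = [p.numel() for p in net.parameters()]
    with torch.no_grad():
        for param, chunk in zip(net.parameters(), torch.split(avg, lengths)):
            param -= lr * chunk.view_as(param)
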
