forked from OakleyTan/FedSSP
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
102 lines (85 loc) · 3.87 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pandas as pd
from client import collate_pyg_to_dgl
import torch
import numpy as np
def proscess_loader(loader, device):
    """Convert every batch of ``loader`` into cached DGL inputs on ``device``.

    Each batch is moved to ``device`` and passed through ``collate_pyg_to_dgl``.
    The resulting pieces are stored so the conversion happens once, up front;
    ``run_fedSSP`` keeps the result on each client as ``*_preprocessed_batches``.

    Args:
        loader: Iterable of batches exposing ``.to(device)`` and a label tensor
            ``.y``. These are presumably PyTorch Geometric batches (TODO confirm
            against ``client.collate_pyg_to_dgl``).
        device: Target device for every returned tensor/graph.

    Returns:
        A list with one tuple per batch:
        ``(e, u, g, length, valid_labels, n_valid)``, where
        ``valid_labels = batch.y[valid_indices]`` and
        ``n_valid = len(valid_indices)``.
    """
    preprocessed = []
    for batch in loader:
        # Rebind the result: ``.to()`` returns the moved object. PyG happens to
        # move in place, but objects whose ``.to()`` returns a copy would
        # otherwise silently stay on the old device.
        batch = batch.to(device)
        e, u, g, length, valid_indices = collate_pyg_to_dgl(batch)
        valid_labels = batch.y[valid_indices].to(device)
        preprocessed.append(
            (
                e.to(device),
                u.to(device),
                g.to(device),
                length.to(device),
                valid_labels,
                len(valid_indices),
            )
        )
    return preprocessed
def _preprocess_clients(clients, device):
    """Cache each client's train/test/val batches on ``device`` and record ``train_samples``."""
    for client in clients:
        loaders = client.dataLoader
        train_loader = loaders['train']
        val_loader = loaders['val']
        test_loader = loaders['test']
        client.train_preprocessed_batches = proscess_loader(train_loader, device)
        client.test_preprocessed_batches = proscess_loader(test_loader, device)
        client.val_preprocessed_batches = proscess_loader(val_loader, device)
        # NOTE(review): if train_loader is a torch DataLoader, len() is its
        # number of batches, not samples. Confirm this is the intended weight.
        client.train_samples = len(train_loader)


def _register_initial_uploads(server):
    """Record round-1 uploads on the server.

    Appends each selected client's ``id`` and ``train_samples`` to
    ``server.uploaded_ids`` / ``server.uploaded_weights``. It then divides every
    entry of ``server.uploaded_weights`` by the selected clients' total
    ``train_samples``.
    """
    tot_samples = 0
    for client in server.selected_clients:
        tot_samples += client.train_samples
        server.uploaded_ids.append(client.id)
        server.uploaded_weights.append(client.train_samples)
    for i, w in enumerate(server.uploaded_weights):
        server.uploaded_weights[i] = w / tot_samples


def _broadcast_global_consensus(server):
    """Give every selected client a detached copy of the averaged ``current_mean``.

    Each uploaded client's ``current_mean`` is weighted uniformly by
    ``1 / len(server.selected_clients)``, regardless of the values stored in
    ``server.uploaded_weights``. ``zip`` pairs ids with weight slots, so only
    ids that have a weight entry contribute.

    NOTE(review): this runs before ``server.receive_models_SSP()``. The
    ``uploaded_ids`` used here were therefore filled earlier: by round 1's
    appends, or by whatever the previous round's ``receive_models_SSP`` left
    behind. With ``frac < 1`` the weights need not sum to 1; confirm this is
    intended. Ids are also used as indices into ``server.clients``, which
    assumes ``client.id`` equals the client's list position.
    """
    global_consensus = 0
    for cid, _ in zip(server.uploaded_ids, server.uploaded_weights):
        global_consensus += server.clients[cid].current_mean * (1 / len(server.selected_clients))
    for client in server.selected_clients:
        client.global_consensus = global_consensus.data.clone()


def _final_test_report(clients):
    """Evaluate every client once more; print and return a per-client ``test_acc`` table."""
    frame = pd.DataFrame()
    for client in clients:
        _, acc = client.evaluate()
        frame.loc[client.name, 'test_acc'] = acc
    # A pandas Styler's ``.data`` is the DataFrame it wraps. Printing the frame
    # directly therefore shows the same table, without building discarded
    # highlight styling and without needing jinja2 (required by DataFrame.style).
    print(frame)
    return frame


def run_fedSSP(args, clients, server, COMMUNICATION_ROUNDS, local_epoch, samp=None, frac=1.0, summary_writer=None):
    """Run the FedSSP federated training loop and report final per-client test accuracy.

    Setup and rounds:

    * Each client's train/val/test loaders are converted once into cached batch
      lists with ``proscess_loader``.
    * In round 1 every client trains, the selected clients are recorded as
      uploads, and the uniform mean of ``current_mean`` is sent to each selected
      client as ``global_consensus``.
    * In later rounds, clients sampled by ``server.randomSample_clients`` train,
      the consensus is recomputed, and the server runs ``receive_models_SSP``,
      ``aggregate_parameters_SSP`` and ``send_models_SSP``.
    * After every round, all clients are evaluated.

    Args:
        args: Run configuration. Only ``args.alg`` is read here, in the metric tags.
        clients: List of client objects. Each must provide:
            ``dataLoader`` (a dict with 'train'/'val'/'test'), ``local_train``,
            ``evaluate`` (returning ``(loss, acc)``), ``current_mean``, ``id``
            and ``name``.
        server: Server object. Must provide ``randomSample_clients``,
            ``uploaded_ids`` / ``uploaded_weights`` lists and the ``*_SSP``
            methods. Its ``clients`` / ``selected_clients`` are overwritten.
        COMMUNICATION_ROUNDS: Number of federated rounds to run.
        local_epoch: Passed to each selected client's ``local_train``.
        samp: When None, ``frac`` is forced to 1.0.
            NOTE(review): ``server.randomSample_clients`` is the only sampler
            visible here, so it is used for every ``samp`` value.
        frac: Fraction passed to ``server.randomSample_clients`` from round 2 on.
        summary_writer: Optional object with ``add_scalar`` (e.g. a TensorBoard
            SummaryWriter). When given, per-round mean/std test accuracy is logged.

    Returns:
        pandas.DataFrame indexed by client name, with a single 'test_acc' column.
    """
    # Fall back to CPU instead of failing on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    if samp is None:
        frac = 1.0

    _preprocess_clients(clients, device)
    server.clients = clients
    server.selected_clients = clients

    for c_round in range(1, COMMUNICATION_ROUNDS + 1):
        if c_round % 50 == 0:
            print(f" > round {c_round}")

        # Round 1 trains every client; later rounds sample through the server.
        if c_round == 1:
            selected_clients = clients
        else:
            selected_clients = server.randomSample_clients(clients, frac)
        server.selected_clients = selected_clients
        server.clients = clients

        for client in selected_clients:
            client.local_train(local_epoch)

        if c_round == 1:
            _register_initial_uploads(server)
        else:
            # Uniform weight for each existing upload slot, sized by this
            # round's selection.
            n_selected = len(server.selected_clients)
            for i in range(len(server.uploaded_weights)):
                server.uploaded_weights[i] = 1 / n_selected

        _broadcast_global_consensus(server)

        # Server-side SSP aggregation only starts from round 2.
        if c_round != 1:
            server.receive_models_SSP()
            server.aggregate_parameters_SSP()
            server.send_models_SSP()

        accs = []
        for client in clients:
            _, acc = client.evaluate()
            accs.append(acc)
        if summary_writer is not None:
            summary_writer.add_scalar(f'Test/Acc/Mean_{args.alg}', np.mean(accs), c_round)
            summary_writer.add_scalar(f'Test/Acc/Std_{args.alg}', np.std(accs), c_round)

    return _final_test_report(clients)