-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathclient_funct.py
149 lines (122 loc) · 5.04 KB
/
client_funct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import numpy as np
import torch
import torch.nn.functional as F
from utils import validate, model_parameter_vector
import copy
from nodes import Node
##############################################################################
# General client function
##############################################################################
def receive_server_model(args, client_nodes, central_node):
    '''
    Broadcast the current server (central) model to every client node.

    Each client's model is overwritten with a deep copy of the central
    model's weights.  FedLAW-style servers store parameters through a
    custom `load_param`/`get_param` API; all other server methods use the
    standard `state_dict` interface.

    Returns the same `client_nodes` container, with models updated in place.
    '''
    # Decide once which parameter-transfer API this server method uses.
    use_fedlaw = 'fedlaw' in args.server_method
    # Index-based loop: client_nodes may be a dict keyed by 0..n-1, so we
    # cannot iterate it directly without risking iterating keys.
    for idx in range(len(client_nodes)):
        node = client_nodes[idx]
        if use_fedlaw:
            node.model.load_param(copy.deepcopy(central_node.model.get_param(clone = True)))
        else:
            node.model.load_state_dict(copy.deepcopy(central_node.model.state_dict()))
    return client_nodes
def Client_update(args, client_nodes, central_node):
    '''
    Run one communication round of client-side local training.

    Every client first receives the current server model, then performs
    `args.E` local epochs with the training routine selected by
    `args.client_method` ('local_train', 'fedprox' or 'feddyn').

    Returns:
        (client_nodes, train_loss) where train_loss is the mean over
        clients of each client's mean epoch loss.

    Raises:
        ValueError: if `args.client_method` is not one of the supported values.
    '''
    # clients receive the server model
    client_nodes = receive_server_model(args, client_nodes, central_node)

    if args.client_method == 'local_train':
        client_losses = []
        for i in range(len(client_nodes)):
            epoch_losses = []
            for epoch in range(args.E):
                loss = client_localTrain(args, client_nodes[i])
                epoch_losses.append(loss)
            client_losses.append(sum(epoch_losses) / len(epoch_losses))
        train_loss = sum(client_losses) / len(client_losses)

    elif args.client_method == 'fedprox':
        # Snapshot of the global parameters used as the proximal anchor.
        global_model_param = copy.deepcopy(list(central_node.model.parameters()))
        client_losses = []
        for i in range(len(client_nodes)):
            epoch_losses = []
            for epoch in range(args.E):
                loss = client_fedprox(global_model_param, args, client_nodes[i])
                epoch_losses.append(loss)
            client_losses.append(sum(epoch_losses) / len(epoch_losses))
        train_loss = sum(client_losses) / len(client_losses)

    elif args.client_method == 'feddyn':
        # Flattened global parameter vector for the FedDyn regularizers.
        global_model_vector = copy.deepcopy(model_parameter_vector(args, central_node.model).detach().clone())
        client_losses = []
        for i in range(len(client_nodes)):
            epoch_losses = []
            for epoch in range(args.E):
                loss = client_feddyn(global_model_vector, args, client_nodes[i])
                epoch_losses.append(loss)
            client_losses.append(sum(epoch_losses) / len(epoch_losses))
            # Update this client's dual variable (old grad) after its local
            # epochs; must run per client, inside the loop, so every client
            # (not only the last) gets its old_grad refreshed.
            v1 = model_parameter_vector(args, client_nodes[i].model).detach()
            client_nodes[i].old_grad = client_nodes[i].old_grad - args.mu * (v1 - global_model_vector)
        train_loss = sum(client_losses) / len(client_losses)

    else:
        # BUG FIX: this branch checks args.client_method, but the original
        # message misleadingly said 'Undefined server method...'.
        raise ValueError('Undefined client method...')

    return client_nodes, train_loss
def Client_validate(args, client_nodes):
    '''
    Validate every client node and return the mean accuracy.

    Used to measure local personalization: each client is evaluated with
    the project-level `validate` helper and the accuracies are averaged.
    '''
    total_acc = 0.0
    for idx in range(len(client_nodes)):
        # per-client accuracy from the shared validation routine
        total_acc += validate(args, client_nodes[idx])
    avg_client_acc = total_acc / len(client_nodes)
    return avg_client_acc
# Vanilla local training
def client_localTrain(args, node, loss = 0.0):
    '''
    One local training epoch of plain SGD (vanilla FedAvg client).

    Iterates the client's local loader once, doing a standard
    zero_grad / forward / cross-entropy / backward / step cycle per batch.
    Returns the mean cross-entropy loss over the epoch's batches.
    NOTE(review): the `loss` parameter is immediately overwritten and is
    kept only for signature compatibility.
    '''
    node.model.train()
    running_loss = 0.0
    loader = node.local_data  # iid
    for data, target in loader:
        node.optimizer.zero_grad()
        # move the batch onto the GPU before the forward pass
        data, target = data.cuda(), target.cuda()
        logits = node.model(data)
        batch_loss = F.cross_entropy(logits, target)
        batch_loss.backward()
        running_loss += batch_loss.item()
        node.optimizer.step()
    return running_loss / len(loader)
# FedProx
def client_fedprox(global_model_param, args, node, loss = 0.0):
    '''
    One local training epoch for a FedProx client.

    Identical to vanilla local training except the optimizer step receives
    the snapshot of global parameters, so the (project-specific) proximal
    optimizer can pull the update toward the global model.
    Returns the mean cross-entropy loss over the epoch's batches.
    '''
    node.model.train()
    running_loss = 0.0
    loader = node.local_data  # iid
    for data, target in loader:
        node.optimizer.zero_grad()
        data, target = data.cuda(), target.cuda()
        logits = node.model(data)
        batch_loss = F.cross_entropy(logits, target)
        batch_loss.backward()
        running_loss += batch_loss.item()
        # fedprox update: the custom optimizer applies the proximal term
        # against the supplied global parameters
        node.optimizer.step(global_model_param)
    return running_loss / len(loader)
#FedDyn
def client_feddyn(global_model_vector, args, node, loss = 0.0):
    '''
    One local training epoch for a FedDyn client.

    The backward objective augments cross-entropy with the FedDyn terms:
    a proximal penalty (mu/2) * ||v - v_global||_2 and a linear correction
    -<v, old_grad>, where v is the client's flattened parameter vector.
    Only the plain cross-entropy is accumulated into the reported loss.
    Returns the mean cross-entropy loss over the epoch's batches.
    '''
    node.model.train()
    running_ce = 0.0
    loader = node.local_data  # iid
    for data, target in loader:
        node.optimizer.zero_grad()
        data, target = data.cuda(), target.cuda()
        logits = node.model(data)
        ce_loss = F.cross_entropy(logits, target)
        # record the raw CE before the regularizers are added
        running_ce += ce_loss.item()
        # feddyn update: proximal + linear dual terms on the parameter vector
        v1 = model_parameter_vector(args, node.model)
        objective = ce_loss + args.mu/2 * torch.norm(v1 - global_model_vector, 2) - torch.dot(v1, node.old_grad)
        objective.backward()
        node.optimizer.step()
    return running_ce / len(loader)