"""Client side implementation"""
import numpy as np
import torch
from dataprocess import eigen_decomposition
from model import Specformer
from train_func import test, train


class Trainer_General:
    """Representation of a single client in the federated setup."""

    def __init__(
        self,
        rank: int,
        class_num,
        adj,
        x,
        y,
        idx_train,
        idx_test,
        nlayer,
        hidden_dim,
        num_heads,
        tran_dropout,
        feat_dropout,
        prop_dropout,
        norm,
        lr,
        weight_decay,
        local_step,
    ):
        # Seed per client rank so runs are reproducible and clients differ.
        torch.manual_seed(rank)
        self.rank = rank

        # Per-step training/test metrics, accumulated across global rounds.
        self.train_losses = []
        self.train_accs = []
        self.test_losses = []
        self.test_accs = []

        self.x = x
        self.adj = adj
        # Spectral decomposition of the local graph: eigenvalues e and
        # eigenvectors u are the inputs Specformer operates on.
        self.e, self.u = eigen_decomposition(self.adj)
        self.e = torch.FloatTensor(self.e)
        self.u = torch.FloatTensor(self.u)
        self.y = y

        self.nlayer = nlayer
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.tran_dropout = tran_dropout
        self.feat_dropout = feat_dropout
        self.prop_dropout = prop_dropout
        self.norm = norm
        self.lr = lr
        self.weight_decay = weight_decay
        self.device = torch.device("cpu")

        self.idx_train = idx_train
        self.idx_test = idx_test

        self.model = Specformer(
            class_num,
            self.x.shape[1],
            self.nlayer,
            self.hidden_dim,
            self.num_heads,
            self.tran_dropout,
            self.feat_dropout,
            self.prop_dropout,
            self.norm,
        )
        self.optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
        )
        # Number of local optimization steps per global round.
        self.local_step = local_step
        self.new_e = []

    @torch.no_grad()
    def update_params(self, params, current_global_epoch) -> None:
        """Overwrite local weights with the aggregated global parameters."""
        self.model.to("cpu")
        for p, mp in zip(params, self.model.parameters()):
            mp.data = p
        self.model.to(self.device)

    def train(self, current_global_round) -> None:
        """Run `local_step` local updates, recording metrics after each."""
        for iteration in range(self.local_step):
            self.model.train()
            loss_train, acc_train, self.new_e = train(
                iteration,
                self.model,
                self.optimizer,
                self.e,
                self.u,
                self.x,
                self.y,
                self.idx_train,
            )
            self.train_losses.append(loss_train)
            self.train_accs.append(acc_train)

            # Evaluate on the local test split after every local step.
            loss_test, acc_test = self.local_test()
            self.test_losses.append(loss_test)
            self.test_accs.append(acc_test)

    def local_test(self) -> list:
        """Evaluate the current model on the local test indices."""
        local_test_loss, local_test_acc = test(
            self.model, self.e, self.u, self.x, self.y, self.idx_test
        )
        return [local_test_loss, local_test_acc]

    def get_params(self) -> tuple:
        """Return model parameters for server-side aggregation."""
        self.optimizer.zero_grad(set_to_none=True)
        return tuple(self.model.parameters())

    def get_e(self):
        """Return the original eigenvalues and the ones produced by training."""
        return self.e, self.new_e

    def get_all_loss_accuracy(self) -> list:
        """Return the recorded metrics as numpy arrays."""
        return [
            np.array(self.train_losses),
            np.array(self.train_accs),
            np.array(self.test_losses),
            np.array(self.test_accs),
        ]

    def get_rank(self):
        return self.rank
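

# --- Usage sketch (not part of the original file) ---------------------------
# A minimal, hypothetical driver showing how a server loop could coordinate
# several Trainer_General clients with FedAvg-style parameter averaging. All
# values below (the toy graph, hyperparameters, norm="none", and the tensor
# index types) are illustrative assumptions, and eigen_decomposition is
# assumed to accept a dense adjacency matrix.
if __name__ == "__main__":
    # Toy 4-node cycle graph with random 8-dim features and two classes.
    adj = np.array(
        [
            [0, 1, 0, 1],
            [1, 0, 1, 0],
            [0, 1, 0, 1],
            [1, 0, 1, 0],
        ],
        dtype=np.float32,
    )
    x = torch.randn(4, 8)
    y = torch.tensor([0, 1, 0, 1])

    clients = [
        Trainer_General(
            rank=r,
            class_num=2,
            adj=adj,
            x=x,
            y=y,
            idx_train=torch.tensor([0, 1]),
            idx_test=torch.tensor([2, 3]),
            nlayer=2,
            hidden_dim=16,
            num_heads=2,
            tran_dropout=0.1,
            feat_dropout=0.1,
            prop_dropout=0.1,
            norm="none",
            lr=0.01,
            weight_decay=5e-4,
            local_step=2,
        )
        for r in range(2)
    ]

    for global_round in range(3):
        for client in clients:
            client.train(global_round)
        # FedAvg: average each parameter elementwise across clients, then
        # broadcast the averaged weights back to every client.
        averaged = [
            torch.mean(torch.stack(ps), dim=0)
            for ps in zip(*(c.get_params() for c in clients))
        ]
        for client in clients:
            client.update_params(averaged, global_round)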