-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDummyPlatform.py
executable file
·110 lines (92 loc) · 2.98 KB
/
DummyPlatform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Dummy federated learning blockchain backend for faster testing.
"""
class DummyPlatform:
    """
    In-memory stand-in for the blockchain platform, used for faster testing.

    All simulated contract state lives in class attributes, so every Account
    returned by initAccounts() reads and writes the same values: there is a
    single simulated contract per Python process.
    """

    # Serialized global model, as last passed to deploy() or globalUpdate().
    modelBytes = None
    # Number of globalUpdate() calls since the contract was (re)deployed.
    epoch = 0
    # Running sum of the data sizes reported via localUpdate() this epoch;
    # cleared by globalUpdate().
    dataSize = 0
    # Global means / stds, as last passed to globalMeans() / globalStds().
    means = None
    stds = None

    @staticmethod
    def reset():
        """
        Restore the shared contract state to its initial values.

        Because the state is stored on the class rather than on an instance,
        it otherwise survives across separate runs/tests in the same process.
        """
        DummyPlatform.modelBytes = None
        DummyPlatform.epoch = 0
        DummyPlatform.dataSize = 0
        DummyPlatform.means = None
        DummyPlatform.stds = None

    @staticmethod
    def initAccounts(amount: int) -> "list[DummyPlatform.Account]":
        """
        Create `amount` accounts that all talk to the shared dummy contract.
        """
        return [DummyPlatform.Account() for _ in range(amount)]

    class Account:
        """
        Dummy counterpart of the platform's account wrapper.

        No per-account data is kept; every method reads or writes the
        class-level state on DummyPlatform.
        """

        def deploy(self, modelBytes):
            """
            Deploy the contract with this account, storing the initial model.

            A deployment starts a fresh contract, so epoch, dataSize, means and
            stds are reset first; otherwise state left behind by an earlier run
            in the same process would leak into this one.
            """
            DummyPlatform.reset()
            DummyPlatform.modelBytes = modelBytes

        def obtainContract(self):
            """
            After the contract has been deployed by one user, the other users
            call this function to obtain a reference to it.

            In the dummy there is nothing to fetch (all state is shared on
            DummyPlatform), so this is a no-op and sets no attributes.
            """
            pass

        def getUpdateEvents(self, receipts):
            """
            From a list of receipts get the processed update events.

            The dummy's localUpdate() already returns the event payload, so the
            receipts are passed through unchanged.
            """
            return receipts

        def getMeanEvents(self, receipts):
            """
            From a list of receipts get the processed mean events.

            Receipts are passed through unchanged (see localMeans()).
            """
            return receipts

        def getStdEvents(self, receipts):
            """
            From a list of receipts get the processed std events.

            Receipts are passed through unchanged (see localStds()).
            """
            return receipts

        def globalUpdate(self, modelBytes):
            """
            Update the global model after weight averaging.
            Should be called by the owner only (not enforced by the dummy).

            Stores the new model, advances the epoch and clears the data size
            accumulated by localUpdate() for the finished epoch.
            """
            DummyPlatform.modelBytes = modelBytes
            DummyPlatform.epoch += 1
            DummyPlatform.dataSize = 0
            return None

        def localUpdate(self, *vargs):
            """
            Trigger a local update event.

            Only the second and third positional arguments are used; the first
            is ignored. The second is added to DummyPlatform.dataSize, and the
            pair (vargs[1], vargs[2]) is returned in place of a receipt.
            Raises IndexError if fewer than three positional arguments are given.
            """
            DummyPlatform.dataSize += vargs[1]
            return (vargs[1], vargs[2])

        def globalMeans(self, means):
            """
            Update the global means.
            Should be called by the owner only (not enforced by the dummy).
            """
            DummyPlatform.means = means
            return None

        def localMeans(self, *vargs):
            """
            Trigger a local means event.

            Nothing is stored; the first two positional arguments are returned
            in place of a receipt.
            """
            return (vargs[0], vargs[1])

        def globalStds(self, stds):
            """
            Update the global stds.
            Should be called by the owner only (not enforced by the dummy).
            """
            DummyPlatform.stds = stds
            return None

        def localStds(self, *vargs):
            """
            Trigger a local stds event.

            Nothing is stored; the first two positional arguments are returned
            in place of a receipt.
            """
            return (vargs[0], vargs[1])

        # Read-only accessors: they ignore `self` and read the shared state.
        def getModel(self):
            return DummyPlatform.modelBytes

        def getEpoch(self):
            return DummyPlatform.epoch

        def getDataSize(self):
            return DummyPlatform.dataSize

        def getMeans(self):
            return DummyPlatform.means

        def getStds(self):
            return DummyPlatform.stds