SPSA: apply the stochastic-rounding offset before clipping.

* spsa_param_clip() takes a third argument `r` that is added to
  theta + increment before the result is clamped to [min, max].
* generate_spsa() passes its random offset `r` into spsa_param_clip()
  and floors the clamped value, instead of adding `r` after clamping.
* The stored "c" becomes half of |w_value - b_value|.
* update_spsa() passes 0 for `r`.

NOTE(review): this patch was recovered from a copy whose line breaks had
been lost. Hunk line counts match the headers, but the leading whitespace
of the Python lines is reconstructed -- verify against rundb.py or apply
with whitespace-insensitive matching.

diff --git a/server/fishtest/rundb.py b/server/fishtest/rundb.py
index 3eb4c4499..28790d062 100644
--- a/server/fishtest/rundb.py
+++ b/server/fishtest/rundb.py
@@ -1505,8 +1505,8 @@ def purge_run(self, run, p=0.001, res=7.0, iters=1):
         self.buffer(run, True)
         return message
 
-    def spsa_param_clip(self, param, increment):
-        return min(max(param["theta"] + increment, param["min"]), param["max"])
+    def spsa_param_clip(self, param, increment, r):
+        return min(max(param["theta"] + increment + r, param["min"]), param["max"])
 
     # Store SPSA parameters for each worker
     spsa_params = {}
@@ -1556,15 +1556,15 @@ def generate_spsa(self, run):
             r = random.uniform(0, 1)
             flip = 1 if random.getrandbits(1) else -1
             # Stochastic rounding and probability for float N.p: (N, 1-p); (N+1, p)
-            w_value = math.floor(self.spsa_param_clip(param, c * flip) + r)
-            b_value = math.floor(self.spsa_param_clip(param, -c * flip) + r)
+            w_value = math.floor(self.spsa_param_clip(param, c * flip, r))
+            b_value = math.floor(self.spsa_param_clip(param, -c * flip, r))
             result["w_params"].append(
                 {
                     "name": param["name"],
                     "value": w_value,
                     "R": param["a"] / (spsa["A"] + iter_local) ** spsa["alpha"] / c**2,
-                    #Set c to the real delta after stochastic rounding is applied
-                    "c": abs(w_value - b_value),
+                    # Set c to the real delta after stochastic rounding is applied
+                    "c": abs(w_value - b_value) / 2,
                     "flip": flip,
                 }
             )
@@ -1600,7 +1600,7 @@ def update_spsa(self, worker, run, spsa_results):
             R = w_params[idx]["R"]
             c = w_params[idx]["c"]
             flip = w_params[idx]["flip"]
-            param["theta"] = self.spsa_param_clip(param, R * c * result * flip)
+            param["theta"] = self.spsa_param_clip(param, R * c * result * flip, 0)
             if grow_summary:
                 summary.append({"theta": param["theta"], "R": R, "c": c})
 