diff --git a/mlperf_logging/rcp_checker/rcp_checker.py b/mlperf_logging/rcp_checker/rcp_checker.py index b63ebe2..8a828f6 100644 --- a/mlperf_logging/rcp_checker/rcp_checker.py +++ b/mlperf_logging/rcp_checker/rcp_checker.py @@ -232,7 +232,7 @@ def _prune_rcps(self): # Step 1 # Find point with fastest convergence and prune all point with smaller batch size # In that way the min batch size point will have the fastest convergenece - fastest_conv = min(min_epochs, key=lambda rc: rc['Min Epochs']) + fastest_conv = min(min_epochs, key=lambda rc: rc['RCP Mean']) min_epochs = list(filter(lambda rc: rc['BS'] >= fastest_conv['BS'], min_epochs)) # Step 2 @@ -249,7 +249,7 @@ def _prune_rcps(self): rcp_max = min_epochs[i+1] bs = min_epochs[i]['BS'] name, rcp = self._create_interp_rcp(bs, rcp_min, rcp_max) - if min_epochs[i]['Min Epochs'] > rcp['Min Epochs']: + if min_epochs[i]['RCP Mean'] > rcp['RCP Mean']: del min_epochs[i] i = i-1 list_len = list_len - 1