From 39fc7a0daca53879ee491cb3bb6ca1b5f97ad799 Mon Sep 17 00:00:00 2001 From: Pablo Gonzalez Date: Mon, 20 May 2024 12:05:51 -0500 Subject: [PATCH] Fix scaling: prune RCPs by mean epochs --- mlperf_logging/rcp_checker/rcp_checker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlperf_logging/rcp_checker/rcp_checker.py b/mlperf_logging/rcp_checker/rcp_checker.py index b63ebe2..8a828f6 100644 --- a/mlperf_logging/rcp_checker/rcp_checker.py +++ b/mlperf_logging/rcp_checker/rcp_checker.py @@ -232,7 +232,7 @@ def _prune_rcps(self): # Step 1 # Find point with fastest convergence and prune all point with smaller batch size # In that way the min batch size point will have the fastest convergenece - fastest_conv = min(min_epochs, key=lambda rc: rc['Min Epochs']) + fastest_conv = min(min_epochs, key=lambda rc: rc['RCP Mean']) min_epochs = list(filter(lambda rc: rc['BS'] >= fastest_conv['BS'], min_epochs)) # Step 2 @@ -249,7 +249,7 @@ def _prune_rcps(self): rcp_max = min_epochs[i+1] bs = min_epochs[i]['BS'] name, rcp = self._create_interp_rcp(bs, rcp_min, rcp_max) - if min_epochs[i]['Min Epochs'] > rcp['Min Epochs']: + if min_epochs[i]['RCP Mean'] > rcp['RCP Mean']: del min_epochs[i] i = i-1 list_len = list_len - 1