From 07c57f7c3470ac417396cabdb8718cf535baf31a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B3=8A=E9=9C=86?=
Date: Wed, 3 Apr 2024 16:43:13 +0800
Subject: [PATCH 1/3] [Distributed] Directly use hvd DistributedOptimizer.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 泊霆
---
 tensorflow/python/distribute/hvd_strategy.py | 26 ++++++--------------------
 1 file changed, 6 insertions(+), 20 deletions(-)

diff --git a/tensorflow/python/distribute/hvd_strategy.py b/tensorflow/python/distribute/hvd_strategy.py
index 8a3ae9c3f43..23d857dc4c6 100644
--- a/tensorflow/python/distribute/hvd_strategy.py
+++ b/tensorflow/python/distribute/hvd_strategy.py
@@ -388,20 +388,16 @@ def wraps_optimizer(cls):
     HvdOptimizer
   '''
   class HvdOptimizer(cls, optimizer.Optimizer):
-    def __init__(self, *args, **kwargs):
-      kwargs["learning_rate"] = kwargs.get("learning_rate", 0.001) *\
-        HvdContext.get().world_size
-      super(HvdOptimizer, self).__init__(*args, **kwargs)
+    def __init__(self, learning_rate=0.001, *args, **kwargs):
+      learning_rate = learning_rate * HvdContext.get().world_size
+      super(HvdOptimizer, self).__init__(learning_rate, *args, **kwargs)
 
-    def compute_gradients(self, loss, **kwargs):
-      loss = hvd.allreduce(loss, op=hvd.Sum)
-      return super().compute_gradients(loss, **kwargs)
-
   if isinstance(cls, HvdOptimizer):
     return cls
   else:
     def horovod_optimizer(*args, **kwargs):
-      return HvdOptimizer(*args, **kwargs)
+      from horovod.tensorflow import DistributedOptimizer
+      return DistributedOptimizer(HvdOptimizer(*args, **kwargs))
   return horovod_optimizer
 
 
@@ -478,16 +474,6 @@ def HorovodMonitoredTrainingSession(*args, **kwargs):  # pylint: disable=invalid-name
   kwargs['config'] = wraps_session_config(kwargs.pop('config', None))
   kwargs['is_chief'] = True
   args = list(args)
-  if args:
-    master = args[0]
-    if not master:
-      master = ''
-    args[0] = master
-  else:
-    master = kwargs.pop('master', None)
-    if not master:
-      master = ''
-    kwargs['master'] = master
   prev_monitored_session = _monitored_session.MonitoredSession
   sess = fn(*args, **kwargs)
 
@@ -1449,4 +1435,4 @@ def export(export_dir_base,
       as_text=as_text,
       clear_devices=clear_devices,
       strip_default_attrs=strip_default_attrs,
-      modes=[mode])
\ No newline at end of file
+      modes=[mode])
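
Note: this patch removes the hand-rolled loss allreduce in
compute_gradients and instead wraps the learning-rate-scaled optimizer in
horovod.tensorflow.DistributedOptimizer, which averages gradients across
workers inside compute_gradients. A minimal usage sketch of the resulting
factory, assuming a TF1-style graph; the GradientDescentOptimizer choice
and the toy loss are illustrative, not part of the patch:

  import tensorflow as tf
  import horovod.tensorflow as hvd
  from tensorflow.python.training import gradient_descent
  from tensorflow.python.distribute.hvd_strategy import wraps_optimizer

  hvd.init()
  w = tf.Variable(1.0)
  loss = tf.square(w)  # stand-in for a real model loss

  # wraps_optimizer returns a factory: it scales the base learning rate
  # by the Horovod world size, then wraps the optimizer so gradients are
  # averaged across workers.
  opt = wraps_optimizer(gradient_descent.GradientDescentOptimizer)(
      learning_rate=0.001)
  train_op = opt.minimize(loss)

Scaling the learning rate linearly with the number of workers is the usual
heuristic for synchronous data parallelism, since the effective batch size
grows by the same factor.
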
From e9cba36dbe82caf24d84d4e52c17f553b69152b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B3=8A=E9=9C=86?=
Date: Wed, 3 Apr 2024 17:01:38 +0800
Subject: [PATCH 2/3] [Distributed] Estimator input args sanity check.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 泊霆
---
 .../distribute/group_embedding_collective_strategy.py | 2 +-
 tensorflow/python/distribute/hvd_strategy.py          | 8 ++++++--
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/distribute/group_embedding_collective_strategy.py b/tensorflow/python/distribute/group_embedding_collective_strategy.py
index 8fa94ea9f7e..a989aaa889c 100644
--- a/tensorflow/python/distribute/group_embedding_collective_strategy.py
+++ b/tensorflow/python/distribute/group_embedding_collective_strategy.py
@@ -101,7 +101,7 @@ def estimator(self, model_fn, **kwargs):
       from tensorflow.python.distribute.hvd_strategy import wraps_estimator
       _estimator = wraps_estimator(_estimator_lib.Estimator)
     elif self._hb:
-      _estimator = hb.estimator.Estimator
+      _estimator = self._hb.estimator.Estimator
     return _estimator(model_fn, **kwargs)
 
 
diff --git a/tensorflow/python/distribute/hvd_strategy.py b/tensorflow/python/distribute/hvd_strategy.py
index 23d857dc4c6..6506d84c977 100644
--- a/tensorflow/python/distribute/hvd_strategy.py
+++ b/tensorflow/python/distribute/hvd_strategy.py
@@ -1060,10 +1060,14 @@ def __init__(self, model_fn, **kwargs):
     self._eval_drop_remainder = kwargs.pop('eval_drop_remainder', True)
     self._predict_drop_remainder = kwargs.pop(
       'predict_drop_remainder', True)
+    config = kwargs.get('config', None)
+    if config is None:
+      config = run_config_lib.RunConfig()
+    else:
+      kwargs.pop('config')
 
     super().__init__(
-      wraps_model_fn(model_fn, model_dir, kwargs['config']),
-      **kwargs)
+      wraps_model_fn(model_fn, model_dir, config), **kwargs)
 
   def _assert_members_are_not_overridden(self):
     r'''disable the overridden check here.

From f35b1df15d6630de92150908669b6929a916c189 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B3=8A=E9=9C=86?=
Date: Sun, 7 Apr 2024 16:31:53 +0800
Subject: [PATCH 3/3] [Distributed] Allow additional horovod
 DistributedOptimizer args.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 泊霆
---
 tensorflow/python/distribute/hvd_strategy.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/distribute/hvd_strategy.py b/tensorflow/python/distribute/hvd_strategy.py
index 6506d84c977..14d221e02e2 100644
--- a/tensorflow/python/distribute/hvd_strategy.py
+++ b/tensorflow/python/distribute/hvd_strategy.py
@@ -397,7 +397,15 @@ def __init__(self, learning_rate=0.001, *args, **kwargs):
   else:
     def horovod_optimizer(*args, **kwargs):
       from horovod.tensorflow import DistributedOptimizer
-      return DistributedOptimizer(HvdOptimizer(*args, **kwargs))
+      horovod_args = DistributedOptimizer.__code__.co_varnames
+      horovod_real_kargs = {}
+      candidate_keys = list(kwargs.keys())
+      for kwarg in candidate_keys:
+        if kwarg in horovod_args:
+          value = kwargs[kwarg]
+          del kwargs[kwarg]
+          horovod_real_kargs[kwarg] = value
+      return DistributedOptimizer(HvdOptimizer(*args, **kwargs), **horovod_real_kargs)
   return horovod_optimizer
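
Note: horovod.tensorflow.DistributedOptimizer is a function, so the patch
inspects its __code__.co_varnames to split the incoming kwargs: keys that
name DistributedOptimizer arguments (e.g. compression, sparse_as_dense)
are routed to Horovod, and the rest go to the wrapped optimizer. A
self-contained sketch of the technique; `wrapper`, `Base`, and
`make_wrapped` are illustrative names, not part of the patch:

  def wrapper(obj, compression=None, sparse_as_dense=False):
    return (obj, compression, sparse_as_dense)

  class Base(object):
    def __init__(self, learning_rate=0.001):
      self.learning_rate = learning_rate

  def make_wrapped(*args, **kwargs):
    # Parameter names come first in co_varnames; slicing by co_argcount
    # keeps only real parameters and skips local variable names.
    code = wrapper.__code__
    params = code.co_varnames[:code.co_argcount]
    wrapper_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in params}
    return wrapper(Base(*args, **kwargs), **wrapper_kwargs)

  # 'sparse_as_dense' is routed to wrapper, 'learning_rate' to Base.
  wrapped = make_wrapped(learning_rate=0.01, sparse_as_dense=True)

One caveat: co_varnames lists a function's local variables after its
parameters, so matching against the full tuple (as the patch does) can
mis-route a kwarg that happens to share a name with a local variable
inside DistributedOptimizer; slicing to co_argcount, as in the sketch,
avoids that.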