diff --git a/example/mxnet/train_gluon_imagenet_byteps_gc.py b/example/mxnet/train_gluon_imagenet_byteps_gc.py index ce5c76d86..049e99fdb 100644 --- a/example/mxnet/train_gluon_imagenet_byteps_gc.py +++ b/example/mxnet/train_gluon_imagenet_byteps_gc.py @@ -402,7 +402,7 @@ def train(ctx): setattr( param, "byteps_compressor_onebit_enable_scale", opt.onebit_scaling) if opt.compress_momentum: - setattr(param, "byteps_momentum_type", "vanilla") + setattr(param, "byteps_momentum_type", "nesterov") setattr(param, "byteps_momentum_mu", opt.momentum) if opt.compress_momentum: