Functionalize distributed_fused_lamb kernel #53896

Merged (21 commits) on May 23, 2023
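
For context, "functionalizing" a kernel in Paddle means replacing the legacy struct-style OpKernel registered with PD_REGISTER_STRUCT_KERNEL by a free function in the phi namespace registered with PD_REGISTER_KERNEL, where inputs, attributes, and outputs become explicit parameters instead of being fetched from an ExecutionContext. The sketch below illustrates the general pattern with hypothetical names (my_op, MyOpKernel); the actual signatures for this operator appear in the diff that follows.

// Hypothetical before/after sketch of kernel functionalization; the names
// my_op / MyOpKernel are illustrative only and are not part of this PR.

// Before: a struct-style kernel tied to the fluid operator framework.
//
//   template <typename T, typename DeviceContext>
//   class MyOpKernel : public framework::OpKernel<T> {
//    public:
//     void Compute(const framework::ExecutionContext &ctx) const override {
//       // read inputs/attributes from ctx, write outputs back into ctx
//     }
//   };
//   PD_REGISTER_STRUCT_KERNEL(my_op, CPU, ALL_LAYOUT, ops::MyOpKernel, float) {}

// After: a free function in the phi namespace; every input, attribute, and
// output is an explicit parameter, and registration goes through phi.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void MyOpKernel(const Context &dev_ctx,
                const DenseTensor &x,
                float scale,
                DenseTensor *out) {
  // ... kernel body operating on x and writing to out ...
}

}  // namespace phi

PD_REGISTER_KERNEL(my_op, CPU, ALL_LAYOUT, phi::MyOpKernel, float) {}
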
68 changes: 62 additions & 6 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
@@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/kernel_registry.h"

namespace paddle {
namespace operators {
@@ -170,8 +171,63 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
                             ops::DistributedFusedLambOp,
                             ops::DistributedFusedLambOpMaker);

-PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::DistributedFusedLambOpKernel,
-                          float) {}
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void DistributedFusedLambKernel(const Context &dev_ctx,
+                                const std::vector<const DenseTensor *> &param,
+                                const std::vector<const DenseTensor *> &grad,
+                                const paddle::optional<DenseTensor> &fp32_param,
+                                const paddle::optional<DenseTensor> &fp32_grad,
+                                const paddle::optional<DenseTensor> &fp16_param,
+                                const paddle::optional<DenseTensor> &fp16_grad,
+                                const DenseTensor &moment1,
+                                const DenseTensor &moment2,
+                                const DenseTensor &beta1_pow,
+                                const DenseTensor &beta2_pow,
+                                const DenseTensor &param_offsets,
+                                const DenseTensor &fp32_partial_offsets,
+                                const DenseTensor &fp16_partial_offsets,
+                                const DenseTensor &param_info,
+                                const DenseTensor &param_order,
+                                const DenseTensor &learning_rate,
+                                const DenseTensor &global_scale,
+                                int acc_steps,
+                                float beta1,
+                                float beta2,
+                                float epsilon,
+                                float max_global_grad_norm,
+                                float weight_decay,
+                                bool clip_after_allreduce,
+                                bool use_master_param_norm,
+                                bool use_master_acc_grad,
+                                bool is_grad_scaled_by_nranks,
+                                bool use_hierarchical_allreduce,
+                                int64_t nranks,
+                                const std::vector<int> &ring_ids,
+                                DenseTensor *fp32_param_out,
+                                DenseTensor *fp16_param_out,
+                                DenseTensor *fp32_acc_grad,
+                                DenseTensor *fp16_acc_grad,
+                                DenseTensor *moment1_out,
+                                DenseTensor *moment2_out,
+                                DenseTensor *beta1_pow_out,
+                                DenseTensor *beta2_pow_out,
+                                DenseTensor *param_out,
+                                DenseTensor *found_inf,
+                                DenseTensor *acc_step,
+                                DenseTensor *stop_update,
+                                DenseTensor *step) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The distributed_fused_lamb operator does not support CPU yet."));
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(distributed_fused_lamb,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::DistributedFusedLambKernel,
+                   float) {}
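
A note on the registration above: PD_REGISTER_KERNEL takes, in order, the kernel name, the backend, the tensor layout, the kernel function, and the supported data types. Because the CPU function only raises Unimplemented, dispatching distributed_fused_lamb to a CPU place fails at runtime by design; the working implementation is presumably the GPU kernel added elsewhere in this PR. An annotated copy of the same call, with comments added purely for illustration:

// Annotated copy of the registration above; comments are illustrative.
PD_REGISTER_KERNEL(distributed_fused_lamb,                   // kernel name
                   CPU,                                      // backend
                   ALL_LAYOUT,                               // any tensor layout
                   phi::fusion::DistributedFusedLambKernel,  // kernel function
                   float) {}                                 // data type(s)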