Functionalize distributed_fused_lamb kernel (#53896)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update HostAlloc

* update param name

* update cpu kernel

* remove kernel header

* update

* update
huangjiyi authored May 23, 2023
1 parent 6e0cf61 commit 5f8e7d8
Showing 5 changed files with 1,135 additions and 1,037 deletions.
68 changes: 62 additions & 6 deletions paddle/fluid/operators/optimizers/distributed_fused_lamb_op.cc
@@ -12,7 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/optimizers/distributed_fused_lamb_op.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/kernel_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -170,8 +171,63 @@ REGISTER_OP_WITHOUT_GRADIENT(distributed_fused_lamb,
                              ops::DistributedFusedLambOp,
                              ops::DistributedFusedLambOpMaker);
 
-PD_REGISTER_STRUCT_KERNEL(distributed_fused_lamb,
-                          CPU,
-                          ALL_LAYOUT,
-                          ops::DistributedFusedLambOpKernel,
-                          float) {}
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void DistributedFusedLambKernel(const Context &dev_ctx,
+                                const std::vector<const DenseTensor *> &param,
+                                const std::vector<const DenseTensor *> &grad,
+                                const paddle::optional<DenseTensor> &fp32_param,
+                                const paddle::optional<DenseTensor> &fp32_grad,
+                                const paddle::optional<DenseTensor> &fp16_param,
+                                const paddle::optional<DenseTensor> &fp16_grad,
+                                const DenseTensor &moment1,
+                                const DenseTensor &moment2,
+                                const DenseTensor &beta1_pow,
+                                const DenseTensor &beta2_pow,
+                                const DenseTensor &param_offsets,
+                                const DenseTensor &fp32_partial_offsets,
+                                const DenseTensor &fp16_partial_offsets,
+                                const DenseTensor &param_info,
+                                const DenseTensor &param_order,
+                                const DenseTensor &learning_rate,
+                                const DenseTensor &global_scale,
+                                int acc_steps,
+                                float beta1,
+                                float beta2,
+                                float epsilon,
+                                float max_global_grad_norm,
+                                float weight_decay,
+                                bool clip_after_allreduce,
+                                bool use_master_param_norm,
+                                bool use_master_acc_grad,
+                                bool is_grad_scaled_by_nranks,
+                                bool use_hierarchical_allreduce,
+                                int64_t nranks,
+                                const std::vector<int> &ring_ids,
+                                DenseTensor *fp32_param_out,
+                                DenseTensor *fp16_param_out,
+                                DenseTensor *fp32_acc_grad,
+                                DenseTensor *fp16_acc_grad,
+                                DenseTensor *moment1_out,
+                                DenseTensor *moment2_out,
+                                DenseTensor *beta1_pow_out,
+                                DenseTensor *beta2_pow_out,
+                                DenseTensor *param_out,
+                                DenseTensor *found_inf,
+                                DenseTensor *acc_step,
+                                DenseTensor *stop_update,
+                                DenseTensor *step) {
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "The distributed_fused_lamb operator does not support CPU yet."));
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(distributed_fused_lamb,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::fusion::DistributedFusedLambKernel,
+                   float) {}
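
For context on the title: "functionalize" here means replacing fluid's struct-style kernel with a phi free-function kernel. Below is a minimal sketch of the two shapes, assuming the removed DistributedFusedLambOpKernel followed the standard framework::OpKernel pattern; the "Param" and "beta1" keys in the old-style snippet are illustrative placeholders, not copied from the op definition.

// Old style (removed above): a class kernel whose Compute() looks up inputs,
// attributes, and outputs on the ExecutionContext by string key at runtime.
// Illustrative sketch only, not the actual removed code.
template <typename T, typename DeviceContext>
class DistributedFusedLambOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto params = ctx.MultiInput<phi::DenseTensor>("Param");  // assumed key
    float beta1 = ctx.Attr<float>("beta1");                   // assumed key
    // ... kernel body ...
  }
};

// New style (added above): a free function whose inputs, attributes, and
// outputs are explicit typed parameters; PD_REGISTER_KERNEL binds it to the
// phi kernel registry for a given backend, layout, and dtype list.
template <typename T, typename Context>
void DistributedFusedLambKernel(const Context &dev_ctx,
                                const std::vector<const DenseTensor *> &param,
                                /* ...tensors, attributes, outputs... */
                                DenseTensor *step);

Because the functional form takes typed DenseTensor and attribute parameters instead of reading them from an ExecutionContext, the op's .cc no longer needs the old kernel header, which is why the include of distributed_fused_lamb_op.h is replaced by op_registry.h and kernel_registry.h in the first hunk.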