Skip to content

Commit

Permalink
Merge pull request #627 from fantes/decoupled_wd
Browse files Browse the repository at this point in the history
dede API for Decoupled wd : adamW and sgdW and amsgradW
  • Loading branch information
beniz authored Aug 24, 2019
2 parents 5e20dda + d9d28ff commit 9731159
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,22 @@ namespace dd
solver_param.set_solver_type(caffe::SolverParameter_SolverType_ADAM);
solver_param.set_amsgrad(true);
}
else if (strcasecmp(solver_type.c_str(),"ADAMW") == 0)
{
solver_param.set_solver_type(caffe::SolverParameter_SolverType_ADAM);
solver_param.set_regularization_type("decoupled");
}
else if (strcasecmp(solver_type.c_str(),"SGDW") == 0)
{
solver_param.set_solver_type(caffe::SolverParameter_SolverType_SGD);
solver_param.set_regularization_type("decoupled");
}
else if (strcasecmp(solver_type.c_str(),"AMSGRADW") == 0)
{
solver_param.set_solver_type(caffe::SolverParameter_SolverType_ADAM);
solver_param.set_amsgrad(true);
solver_param.set_regularization_type("decoupled");
}
caffe::UpgradeSolverType(&solver_param);
}
if (ad_solver.has("test_interval"))
Expand Down Expand Up @@ -1192,6 +1208,10 @@ namespace dd
{
solver.reset(caffe::SolverRegistry<float>::CreateSolver(solver_param));
this->_logger->info("selected solver: " + solver_param.type());
if (solver_param.amsgrad())
this->_logger->info("solver flavor : AMSGRAD ");
if (solver_param.regularization_type() == "decoupled")
this->_logger->info("solver flavor: decoupled weight decay ");
}
catch(...)
{
Expand Down

0 comments on commit 9731159

Please sign in to comment.