Skip to content

Commit

Permalink
Merge pull request #630 from fantes/decoupled_wd_schedule_fix
Browse files Browse the repository at this point in the history
add API params for decoupled weight decay scheduler
  • Loading branch information
beniz authored Aug 30, 2019
2 parents f477e89 + ccb72b1 commit 19989f7
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/backends/caffe/caffelib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1128,17 +1128,48 @@ namespace dd
{
solver_param.set_solver_type(caffe::SolverParameter_SolverType_ADAM);
solver_param.set_regularization_type("decoupled");
int periods = 4;
int mult = 2;
if (ad_solver.has("decoupled_wd_mult"))
solver_param.set_decoupled_wd_t0(ad_solver.get("decoupled_wd_mult").get<double>());
if (ad_solver.has("decoupled_wd_periods"))
periods = ad_solver.get("decoupled_wd_periods").get<int>();
if (max_iter > 0)
solver_param.set_decoupled_wd_t0(max_iter/ (pow(mult,periods)-1));
else
throw MLLibBadParamException("decoupled weight decay requires iterations to be set");

}
else if (strcasecmp(solver_type.c_str(),"SGDW") == 0)
{
solver_param.set_solver_type(caffe::SolverParameter_SolverType_SGD);
solver_param.set_regularization_type("decoupled");
int periods = 4;
int mult = 2;
if (ad_solver.has("decoupled_wd_mult"))
solver_param.set_decoupled_wd_t0(ad_solver.get("decoupled_wd_mult").get<double>());
if (ad_solver.has("decoupled_wd_periods"))
periods = ad_solver.get("decoupled_wd_periods").get<int>();
if (max_iter > 0)
solver_param.set_decoupled_wd_t0(max_iter / (pow(mult,periods)-1));
else
throw MLLibBadParamException("decoupled weight decay requires iterations to be set");
}
else if (strcasecmp(solver_type.c_str(),"AMSGRADW") == 0)
{
solver_param.set_solver_type(caffe::SolverParameter_SolverType_ADAM);
solver_param.set_amsgrad(true);
solver_param.set_regularization_type("decoupled");
int periods = 4;
int mult = 2;
if (ad_solver.has("decoupled_wd_mult"))
solver_param.set_decoupled_wd_t0(ad_solver.get("decoupled_wd_mult").get<double>());
if (ad_solver.has("decoupled_wd_periods"))
periods = ad_solver.get("decoupled_wd_periods").get<int>();
if (max_iter > 0)
solver_param.set_decoupled_wd_t0(max_iter/ (pow(mult,periods)-1));
else
throw MLLibBadParamException("decoupled weight decay requires iterations to be set");
}
if (ad_solver.has("rectified") && ad_solver.get("rectified").get<bool>())
solver_param.set_rectified(true);
Expand Down

0 comments on commit 19989f7

Please sign in to comment.