Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add const in ParameterUpdater init #963

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/parameter/ParameterUpdaterBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */

namespace paddle {

void ParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void ParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;
for (ParameterType type : getParameterTypes()) {
for (auto& para : parameters) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/parameter/ParameterUpdaterBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class ParameterUpdater {
parameterTypes_.push_back(type);
}

virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

// called by Trainer when starting a new pass
virtual void startPass() {}
Expand Down Expand Up @@ -105,7 +105,7 @@ class ParameterUpdaterComposite : public ParameterUpdater {
ParameterUpdaterComposite() {}
virtual ~ParameterUpdaterComposite() {}

virtual void init(std::vector<ParameterPtr>& parameters) = 0;
virtual void init(const std::vector<ParameterPtr>& parameters) = 0;

virtual void startPass() {
syncThreadPool_->execPlusOwner(
Expand Down
3 changes: 2 additions & 1 deletion paddle/trainer/ParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ SgdUpdaterWithCpuAverager::SgdUpdaterWithCpuAverager(
updateWorker_.addJob([]() { hl_set_device(FLAGS_gpu_id); });
}

void SgdUpdaterWithCpuAverager::init(std::vector<ParameterPtr>& parameters) {
void SgdUpdaterWithCpuAverager::init(
const std::vector<ParameterPtr>& parameters) {
SgdLocalUpdater::init(parameters);
averager_->init(parameters_.size(), nullptr);
copyEvents_.resize(parameters_.size());
Expand Down
4 changes: 2 additions & 2 deletions paddle/trainer/ParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class SgdLocalUpdater : public ParameterUpdater {
* be initialized.
* @param parameters The parameter need to be initialized.
*/
virtual void init(std::vector<ParameterPtr>& parameters) {
virtual void init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);
optimizer_->init(parameters_.size(), nullptr);
// check no L1 decay in parameter configs
Expand Down Expand Up @@ -208,7 +208,7 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater {
* @brief init. Initialize cpu parameters, model average optimizer.
* @param parameters
*/
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

virtual PassType startBatch(int64_t batchSize) {
averager_->startBatch(-1UL);
Expand Down
7 changes: 4 additions & 3 deletions paddle/trainer/RemoteParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ RemoteParameterUpdater::RemoteParameterUpdater(
addParameterType(PARAMETER_MOMENTUM);
}

void RemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void RemoteParameterUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

if (localUpdater_) {
Expand Down Expand Up @@ -595,7 +595,8 @@ SparseRemoteParameterUpdater::SparseRemoteParameterUpdater(
testing_(testing),
useApplyInPserver_(false) {}

void SparseRemoteParameterUpdater::init(std::vector<ParameterPtr>& parameters) {
void SparseRemoteParameterUpdater::init(
const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

parameterClient_.reset(new ParameterClient2(
Expand Down Expand Up @@ -809,7 +810,7 @@ void SparseRemoteParameterUpdater::saveParametersRemote(
}

void SparseRemoteParameterUpdaterComposite::init(
std::vector<ParameterPtr>& parameters) {
const std::vector<ParameterPtr>& parameters) {
parameters_ = parameters;

std::vector<ParameterPtr> parametersArray[NUMBER_UPDATERS];
Expand Down
6 changes: 3 additions & 3 deletions paddle/trainer/RemoteParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
/**
* initialize the internal parameter client and itself.
*/
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
/**
* @brief start batch
*
Expand Down Expand Up @@ -274,7 +274,7 @@ class SparseRemoteParameterUpdater : public ParameterUpdater {
}

/// initialization
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);

/// stateful batch control
virtual PassType startBatch(int64_t batchSize);
Expand Down Expand Up @@ -360,7 +360,7 @@ class SparseRemoteParameterUpdaterComposite : public ParameterUpdaterComposite {
}

/// initialization of dense and sparse updaters
virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
};

class ParameterUpdaterCreators {
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/ThreadParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ SgdThreadUpdater::SgdThreadUpdater(const OptimizationConfig& optConfig)
}
}

void SgdThreadUpdater::init(std::vector<ParameterPtr>& parameters) {
void SgdThreadUpdater::init(const std::vector<ParameterPtr>& parameters) {
ParameterUpdater::init(parameters);

// calc max parameter id
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/ThreadParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class SgdThreadUpdater : public ParameterUpdater {
// Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost);

virtual void init(std::vector<ParameterPtr>& parameters);
virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize);
// Call finishBatch for each optimizer.
virtual void finishBatch(real cost);
Expand Down