ParameterServer python api #1066

Closed
32 changes: 32 additions & 0 deletions demo/quick_start/cluster/pserver.py
@@ -0,0 +1,32 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from py_paddle import swig_paddle as api
import paddle.proto.ParameterServerConfig_pb2 as ParameterServerConfig


def main():
    api.initPaddle()
    pServerConfig = ParameterServerConfig.ParameterServerConfig()
    pServerConfig.ports_num = 1
    pServerConfig.nics = "lo0"
    pServerConfig.num_gradient_servers = 1
    pServerConfig.port = 7164
    pserver = api.ParameterServer.createFromConfigProto(pServerConfig)
    pserver.start()
    pserver.wait()


if __name__ == '__main__':
    main()
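
The demo above goes through the createFromConfigProto helper that this PR monkey-patches onto swig_paddle (see paddle/py_paddle/util.py below). For reference, a minimal sketch (not part of the diff) of the lower-level path that helper wraps: serialize the config yourself and call the raw SWIG binding createByConfigProtoStr. It assumes py_paddle is installed and importable.

from py_paddle import swig_paddle as api
import paddle.proto.ParameterServerConfig_pb2 as ParameterServerConfig

# Sketch only: same effect as the demo, via the raw SWIG binding.
api.initPaddle()
conf = ParameterServerConfig.ParameterServerConfig()
conf.port = 7164
pserver = api.ParameterServer.createByConfigProtoStr(conf.SerializeToString())
pserver.start()
pserver.wait()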
3 changes: 2 additions & 1 deletion paddle/api/CMakeLists.txt
@@ -28,7 +28,8 @@ set(API_SOURCES
    SequenceGenerator.cpp
    Trainer.cpp
    Util.cpp
    Vector.cpp)
    Vector.cpp
    ParameterServer.cpp)
set(API_HEADER
    PaddleAPI.h
    Internal.h)
3 changes: 3 additions & 0 deletions paddle/api/Paddle.swig
@@ -179,6 +179,8 @@ namespace std {
%newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater;
%newobject ParameterUpdater::createRemoteUpdater;
%newobject ParameterServer::createByConfigProtoPtr;
%newobject ParameterServer::createByConfigProtoStr;

%feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide
@@ -197,5 +199,6 @@ namespace std {
%ignore ParameterConfigPrivate;
%ignore OptimizationConfigPrivate;
%ignore ParameterTraverseCallbackPrivate;
%ignore ParameterServerPrivate;
%include "utils/GlobalConstants.h"
%include "api/PaddleAPI.h"
22 changes: 22 additions & 0 deletions paddle/api/PaddleAPI.h
@@ -874,6 +874,28 @@ class ParameterUpdater {
  ParameterUpdaterPrivate* m;
};

struct ParameterServerPrivate;
class ParameterServer {
private:
  ParameterServer();

public:
  static ParameterServer* createByConfigProtoPtr(const void* confPtr);
  static ParameterServer* createByConfigProtoStr(const std::string& protoStr);

  ~ParameterServer();

  /**
   * @brief start the parameter server; wait() blocks until it exits.
   */
  void start();
  void wait();

private:
  ParameterServerPrivate* m;
};

struct EvaluatorPrivate;
class Evaluator {
private:
5 changes: 5 additions & 0 deletions paddle/api/PaddleAPIPrivate.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/parameter/ParameterUpdaterBase.h"
#include "paddle/pserver/ParameterServerController.h"
#include "paddle/trainer/TrainerConfigHelper.h"

struct GradientMachinePrivate {
@@ -72,6 +73,10 @@ struct ParameterUpdaterPrivate {
  std::unique_ptr<paddle::ParameterUpdater> updater;
};

struct ParameterServerPrivate {
  std::unique_ptr<paddle::ParameterServerController> parameterServerController;
};

struct ParameterPrivate {
  std::shared_ptr<paddle::Parameter> sharedPtr;
  paddle::Parameter* rawPtr;  // rawPtr only used in ParameterUpdater,
44 changes: 44 additions & 0 deletions paddle/api/ParameterServer.cpp
@@ -0,0 +1,44 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "PaddleAPI.h"

#include "PaddleAPIPrivate.h"

ParameterServer::ParameterServer() : m(new ParameterServerPrivate()) {}

ParameterServer* ParameterServer::createByConfigProtoPtr(const void* confPtr) {
  auto& conf = *(const paddle::ParameterServerConfig*)(confPtr);
  auto pServer = new ParameterServer();
  pServer->m->parameterServerController.reset(
      paddle::ParameterServerController::create(conf));
  return pServer;
}

ParameterServer* ParameterServer::createByConfigProtoStr(
    const std::string& protoStr) {
  paddle::ParameterServerConfig conf;
  conf.ParseFromString(protoStr);
  if (conf.IsInitialized()) {
    return ParameterServer::createByConfigProtoPtr(&conf);
  } else {
    return nullptr;
  }
}

ParameterServer::~ParameterServer() { delete m; }

void ParameterServer::start() { m->parameterServerController->start(); }

void ParameterServer::wait() { m->parameterServerController->wait(); }
32 changes: 26 additions & 6 deletions paddle/py_paddle/util.py
@@ -15,18 +15,20 @@
Some Useful method for py_paddle.
"""

import swig_paddle
import os
import paddle.trainer.PyDataProviderWrapper
import paddle.proto.ParameterConfig_pb2
import paddle.proto.ModelConfig_pb2
import paddle.proto.TrainerConfig_pb2
import weakref
import numpy
import struct
import sys
import copy

import swig_paddle
import paddle.trainer.PyDataProviderWrapper
import paddle.proto.ParameterConfig_pb2
import paddle.proto.ModelConfig_pb2
import paddle.proto.TrainerConfig_pb2
import paddle.proto.ParameterServerConfig_pb2


def initializePaddle(*args):
"""
@@ -558,11 +560,29 @@ def getForwardOutput(self):
swig_paddle.Trainer.getForwardOutput = getForwardOutput


def __monkeypatch_parameter_server__():
    def createFromConfigProto(protoObj):
        """
        Create Parameter Server From Proto object.
        :param protoObj: ParameterServer Config
        :type protoObj: proto.ParameterServerConfig_pb2.ParameterServerConfig
        :return: paddle.ParameterServer
        """
        assert isinstance(
            protoObj,
            paddle.proto.ParameterServerConfig_pb2.ParameterServerConfig)
        return swig_paddle.ParameterServer.createByConfigProtoStr(
            protoObj.SerializeToString())

    swig_paddle.ParameterServer.createFromConfigProto = \
        staticmethod(createFromConfigProto)


def monkeypatches():
    patches = [
        __monkeypatch_init_paddle__, __monkeypatch_gradient_machine__,
        __monkey_patch_protobuf_objects__, __monkey_patch_parameter__,
        __monkey_patch_trainer__
        __monkey_patch_trainer__, __monkeypatch_parameter_server__
    ]
    for patch in patches:
        patch()
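
A note on error handling, not part of the diff: createByConfigProtoStr returns nullptr when the parsed config is not initialized, which surfaces in Python as None. A short usage sketch of the patched entry point, assuming monkeypatches() has been applied and Paddle has been initialized, with a guard for that case:

# Sketch only; relies on this module's imports of swig_paddle and
# paddle.proto.ParameterServerConfig_pb2.
conf = paddle.proto.ParameterServerConfig_pb2.ParameterServerConfig()
conf.port = 7164
pserver = swig_paddle.ParameterServer.createFromConfigProto(conf)
assert pserver is not None  # None means the config failed to initialize
pserver.start()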
18 changes: 9 additions & 9 deletions proto/ParameterServerConfig.proto
@@ -29,22 +29,22 @@ message ParameterClientConfig {
message ParameterServerConfig {
  // The ports number for parameter send,
  // increment based on default port number
  required int32 ports_num = 1 [default = 1];
  optional int32 ports_num = 1 [default = 1];
  // The ports number for parameter send,
  // increment based on default (port + ports_num
  required int32 ports_num_for_sparse = 2 [default = 0];
  optional int32 ports_num_for_sparse = 2 [default = 0];
  // network device name for pservers
  required string nics = 3 [default = "xgbe0,xgbe1"];
  required string rdma_tcp = 4 [default = "tcp"];
  optional string nics = 3 [default = "xgbe0,xgbe1"];
  optional string rdma_tcp = 4 [default = "tcp"];
  // Listening port for pserver
  required int32 port = 5 [default = 20134];
  optional int32 port = 5 [default = 20134];
  // number of gradient servers
  required int32 num_gradient_servers = 6 [default = 1];
  optional int32 num_gradient_servers = 6 [default = 1];
  // number of threads for sync op exec
  required int32 pserver_num_threads = 7 [default = 1];
  optional int32 pserver_num_threads = 7 [default = 1];
  // control config_.async_lagged_grad_discard_ratio() min value
  required double async_lagged_ratio_min = 8 [default = 1.0];
  optional double async_lagged_ratio_min = 8 [default = 1.0];
  // if async_lagged_grad_discard_ratio is not set in trainer_config.conf
  // use it as defalut value
  required double async_lagged_ratio_default = 9 [default = 1.5];
  optional double async_lagged_ratio_default = 9 [default = 1.5];
}
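
The switch from required to optional above matters for the Python path: proto2 required fields are not satisfied by their defaults, so a ParameterServerConfig built in Python would fail to serialize (or would leave createByConfigProtoStr returning nullptr) unless every field were set explicitly. A sketch, not part of the diff, of the minimal usage this enables, relying on the proto defaults such as port = 20134 and ports_num = 1:

from py_paddle import swig_paddle as api
import paddle.proto.ParameterServerConfig_pb2 as pb

# Sketch only: a default-constructed config is now fully initialized,
# so only fields that differ from the proto defaults need to be set.
api.initPaddle()
conf = pb.ParameterServerConfig()
conf.nics = "lo0"  # override only what differs from the defaults
pserver = api.ParameterServer.createFromConfigProto(conf)
pserver.start()
pserver.wait()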