diff --git a/jraft-core/pom.xml b/jraft-core/pom.xml
index a815c069e..06d98b6e6 100644
--- a/jraft-core/pom.xml
+++ b/jraft-core/pom.xml
@@ -127,5 +127,9 @@
org.openjdk.jmh
jmh-generator-annprocess
+
+ com.alipay.sofa
+ rpc-grpc-impl
+
diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java
index f6ae21856..5f99d3ea3 100644
--- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java
+++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java
@@ -21,10 +21,13 @@
import com.alipay.sofa.jraft.error.RemotingException;
import com.alipay.sofa.jraft.option.RpcOptions;
import com.alipay.sofa.jraft.util.Endpoint;
+import com.google.protobuf.Message;
+import io.grpc.stub.StreamObserver;
/**
*
* @author jiachun.fjc
+ * @author HH
*/
public interface RpcClient extends Lifecycle {
@@ -107,4 +110,31 @@ default void invokeAsync(final Endpoint endpoint, final Object request, final In
*/
void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx, final InvokeCallback callback,
final long timeoutMs) throws InterruptedException, RemotingException;
+
+ /**
+ * Streaming invocation with a callback.
+ *
+ * @param endpoint target address
+ * @param request request object
+ * @param callback invoke callback
+ * @param timeoutMs timeout millisecond
+ * @return request stream observer.
+ */
+ default StreamObserver invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeCallback callback,
+ final long timeoutMs) {
+ return invokeBidiStreaming(endpoint, request, null, callback, timeoutMs);
+ }
+
+ /**
+ * BidiStreaming invocation.
+ *
+ * @param endpoint target address
+ * @param request request object
+ * @param callback invoke callback
+ * @param ctx invoke context
+ * @param timeoutMs timeout millisecond
+ * @return request stream observer.
+ */
+ StreamObserver invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeContext ctx, final InvokeCallback callback,
+ final long timeoutMs);
}
diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java
index 98407776b..8fa816a18 100644
--- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java
+++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java
@@ -22,6 +22,7 @@
/**
*
* @author jiachun.fjc
+ * @author HH
*/
public interface RpcServer extends Lifecycle {
@@ -39,6 +40,13 @@ public interface RpcServer extends Lifecycle {
*/
void registerProcessor(final RpcProcessor> processor);
+ /**
+ * Register bidiStreaming user processor.
+ *
+ * @param processor the user processor which has a interest
+ */
+ void registerBidiStreamingProcessor(final RpcProcessor> processor);
+
/**
*
* @return bound port
diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java
index 2ba5c6a7c..d1b670d41 100644
--- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java
+++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java
@@ -32,11 +32,14 @@
import com.alipay.sofa.jraft.rpc.impl.core.ClientServiceConnectionEventProcessor;
import com.alipay.sofa.jraft.util.Endpoint;
import com.alipay.sofa.jraft.util.Requires;
+import com.google.protobuf.Message;
+import io.grpc.stub.StreamObserver;
/**
* Bolt rpc client impl.
*
* @author jiachun.fjc
+ * @author HH
*/
public class BoltRpcClient implements RpcClient {
@@ -119,6 +122,12 @@ public void invokeAsync(final Endpoint endpoint, final Object request, final Inv
}
}
+ @Override
+ public StreamObserver invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeContext ctx,
+ final InvokeCallback callback, final long timeoutMs) {
+ throw new UnsupportedOperationException();
+ }
+
public com.alipay.remoting.rpc.RpcClient getRpcClient() {
return rpcClient;
}
diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java
index ae008644d..cf27daf0f 100644
--- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java
+++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java
@@ -22,7 +22,6 @@
import com.alipay.remoting.BizContext;
import com.alipay.remoting.ConnectionEventType;
import com.alipay.remoting.config.BoltClientOption;
-import com.alipay.remoting.config.switches.GlobalSwitch;
import com.alipay.remoting.rpc.protocol.AsyncUserProcessor;
import com.alipay.sofa.jraft.rpc.Connection;
import com.alipay.sofa.jraft.rpc.RpcContext;
@@ -34,6 +33,7 @@
* Bolt RPC server impl.
*
* @author jiachun.fjc
+ * @author HH
*/
public class BoltRpcServer implements RpcServer {
@@ -87,6 +87,11 @@ public void close() {
});
}
+ @Override
+ public void registerBidiStreamingProcessor(final RpcProcessor> processor) {
+ throw new UnsupportedOperationException();
+ }
+
@Override
public int boundPort() {
return this.rpcServer.port();
diff --git a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java
index 256611a79..a5a9f081a 100644
--- a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java
+++ b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java
@@ -57,6 +57,7 @@
*
* @author nicholas.jxf
* @author jiachun.fjc
+ * @author HH
*/
public class GrpcClient implements RpcClient {
@@ -137,10 +138,9 @@ public Object invokeSync(final Endpoint endpoint, final Object request, final In
@Override
public void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx,
final InvokeCallback callback, final long timeoutMs) {
- Requires.requireNonNull(endpoint, "endpoint");
- Requires.requireNonNull(request, "request");
+ nonNullCheck(endpoint, request);
- final Executor executor = callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE;
+ final Executor executor = getExecutor(callback);
final Channel ch = getCheckedChannel(endpoint);
if (ch == null) {
@@ -149,7 +149,7 @@ public void invokeAsync(final Endpoint endpoint, final Object request, final Inv
}
final MethodDescriptor method = getCallMethod(request);
- final CallOptions callOpts = CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
+ final CallOptions callOpts = getCallOpts(timeoutMs);
ClientCalls.asyncUnaryCall(ch.newCall(method, callOpts), (Message) request, new StreamObserver() {
@@ -170,6 +170,53 @@ public void onCompleted() {
});
}
+ @Override
+ public StreamObserver invokeBidiStreaming(Endpoint endpoint, Object request, InvokeContext ctx, InvokeCallback callback, long timeoutMs) {
+ nonNullCheck(endpoint, request);
+
+ final Executor executor = getExecutor(callback);
+
+ final Channel ch = getCheckedChannel(endpoint);
+ if (ch == null) {
+ executor.execute(() -> callback.complete(null, new RemotingException("Fail to connect: " + endpoint)));
+ return null;
+ }
+
+ final MethodDescriptor method = getCallMethod(request);
+ final CallOptions callOpts = getCallOpts(timeoutMs);
+
+ return ClientCalls
+ .asyncBidiStreamingCall(ch.newCall(method, callOpts), new StreamObserver() {
+ @Override
+ public void onNext(final Message value) {
+ executor.execute(() -> callback.complete(value, null));
+ }
+
+ @Override
+ public void onError(final Throwable throwable) {
+ executor.execute(() -> callback.complete(null, throwable));
+ }
+
+ @Override
+ public void onCompleted() {
+ // NO-OP
+ }
+ });
+ }
+
+ private void nonNullCheck(final Endpoint endpoint, final Object request) {
+ Requires.requireNonNull(endpoint, "endpoint");
+ Requires.requireNonNull(request, "request");
+ }
+
+ private Executor getExecutor(final InvokeCallback callback) {
+ return callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE;
+ }
+
+ private CallOptions getCallOpts(final long timeoutMs) {
+ return CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
+ }
+
private MethodDescriptor getCallMethod(final Object request) {
final String interest = request.getClass().getName();
final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest), "null default instance: "
@@ -195,12 +242,13 @@ private ManagedChannel getCheckedChannel(final Endpoint endpoint) {
}
private ManagedChannel getChannel(final Endpoint endpoint, final boolean createIfAbsent) {
- if (createIfAbsent) {
- return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel);
- } else {
- return this.managedChannelPool.get(endpoint);
- }
- }
+ if (createIfAbsent) {
+ return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel);
+ }
+ else {
+ return this.managedChannelPool.get(endpoint);
+ }
+ }
private ManagedChannel newChannel(final Endpoint endpoint) {
final ManagedChannel ch = ManagedChannelBuilder.forAddress(endpoint.getIp(), endpoint.getPort()) //
@@ -222,8 +270,8 @@ private ManagedChannel removeChannel(final Endpoint endpoint) {
}
private void notifyWhenStateChanged(final ConnectivityState state, final Endpoint endpoint, final ManagedChannel ch) {
- ch.notifyWhenStateChanged(state, () -> onStateChanged(endpoint, ch));
- }
+ ch.notifyWhenStateChanged(state, () -> onStateChanged(endpoint, ch));
+ }
private void onStateChanged(final Endpoint endpoint, final ManagedChannel ch) {
final ConnectivityState state = ch.getState(false);
@@ -252,27 +300,29 @@ private void onStateChanged(final Endpoint endpoint, final ManagedChannel ch) {
}
private void notifyReady(final Endpoint endpoint) {
- LOG.info("The channel {} has successfully established.", endpoint);
-
- clearConnFailuresCount(endpoint);
-
- final ReplicatorGroup rpGroup = this.replicatorGroup;
- if (rpGroup != null) {
- try {
- RpcUtils.runInThread(() -> {
- final PeerId peer = new PeerId();
- if (peer.parse(endpoint.toString())) {
- LOG.info("Peer {} is connected.", peer);
- rpGroup.checkReplicator(peer, true);
- } else {
- LOG.error("Fail to parse peer: {}.", endpoint);
- }
- });
- } catch (final Throwable t) {
- LOG.error("Fail to check replicator {}.", endpoint, t);
- }
- }
- }
+ LOG.info("The channel {} has successfully established.", endpoint);
+
+ clearConnFailuresCount(endpoint);
+
+ final ReplicatorGroup rpGroup = this.replicatorGroup;
+ if (rpGroup != null) {
+ try {
+ RpcUtils.runInThread(() -> {
+ final PeerId peer = new PeerId();
+ if (peer.parse(endpoint.toString())) {
+ LOG.info("Peer {} is connected.", peer);
+ rpGroup.checkReplicator(peer, true);
+ }
+ else {
+ LOG.error("Fail to parse peer: {}.", endpoint);
+ }
+ });
+ }
+ catch (final Throwable t) {
+ LOG.error("Fail to check replicator {}.", endpoint, t);
+ }
+ }
+ }
private void notifyFailure(final Endpoint endpoint) {
LOG.warn("There has been some transient failure on this channel {}.", endpoint);
@@ -310,8 +360,8 @@ private boolean checkChannel(final Endpoint endpoint, final boolean createIfAbse
}
private int incConnFailuresCount(final Endpoint endpoint) {
- return this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger()).incrementAndGet();
- }
+ return this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger()).incrementAndGet();
+ }
private void clearConnFailuresCount(final Endpoint endpoint) {
this.transientFailures.remove(endpoint);
diff --git a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java
index 45cf7fd77..93a8423c6 100644
--- a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java
+++ b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java
@@ -35,6 +35,7 @@
import io.grpc.ServerServiceDefinition;
import io.grpc.protobuf.ProtoUtils;
import io.grpc.stub.ServerCalls;
+import io.grpc.stub.StreamObserver;
import io.grpc.util.MutableHandlerRegistry;
import org.slf4j.Logger;
@@ -57,6 +58,7 @@
*
* @author nicholas.jxf
* @author jiachun.fjc
+ * @author HH
*/
public class GrpcServer implements RpcServer {
@@ -84,33 +86,34 @@ public GrpcServer(Server server, MutableHandlerRegistry handlerRegistry, Map()) //
- .threadFactory(new NamedThreadFactory(EXECUTOR_NAME + "-", true)) //
- .rejectedHandler((r, executor) -> {
- throw new RejectedExecutionException("[" + EXECUTOR_NAME + "], task " + r.toString() +
- " rejected from " +
- executor.toString());
- })
- .build();
-
- try {
- this.server.start();
- } catch (final IOException e) {
- ThrowUtil.throwException(e);
- }
- return true;
- }
+ public boolean init(final Void opts) {
+ if (!this.started.compareAndSet(false, true)) {
+ throw new IllegalStateException("grpc server has started");
+ }
+
+ this.defaultExecutor = ThreadPoolUtil.newBuilder() //
+ .poolName(EXECUTOR_NAME) //
+ .enableMetric(true) //
+ .coreThreads(Math.min(20, GrpcRaftRpcFactory.RPC_SERVER_PROCESSOR_POOL_SIZE / 5)) //
+ .maximumThreads(GrpcRaftRpcFactory.RPC_SERVER_PROCESSOR_POOL_SIZE) //
+ .keepAliveSeconds(60L) //
+ .workQueue(new SynchronousQueue<>()) //
+ .threadFactory(new NamedThreadFactory(EXECUTOR_NAME + "-", true)) //
+ .rejectedHandler((r, executor) -> {
+ throw new RejectedExecutionException("[" + EXECUTOR_NAME + "], task " + r.toString() +
+ " rejected from " +
+ executor.toString());
+ })
+ .build();
+
+ try {
+ this.server.start();
+ }
+ catch (final IOException e) {
+ ThrowUtil.throwException(e);
+ }
+ return true;
+ }
@Override
public void shutdown() {
@@ -126,84 +129,128 @@ public void registerConnectionClosedEventListener(final ConnectionClosedEventLis
this.closedEventListeners.add(listener);
}
- @SuppressWarnings("unchecked")
@Override
- public void registerProcessor(final RpcProcessor processor) {
- final String interest = processor.interest();
- final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest), "null default instance: " + interest);
- final MethodDescriptor method = MethodDescriptor //
- .newBuilder() //
- .setType(MethodDescriptor.MethodType.UNARY) //
- .setFullMethodName(
- MethodDescriptor.generateFullMethodName(processor.interest(), GrpcRaftRpcFactory.FIXED_METHOD_NAME)) //
- .setRequestMarshaller(ProtoUtils.marshaller(reqIns)) //
- .setResponseMarshaller(ProtoUtils.marshaller(this.marshallerRegistry.findResponseInstanceByRequest(interest))) //
- .build();
-
- final ServerCallHandler handler = ServerCalls.asyncUnaryCall(
- (request, responseObserver) -> {
- final SocketAddress remoteAddress = RemoteAddressInterceptor.getRemoteAddress();
- final Connection conn = ConnectionInterceptor.getCurrentConnection(this.closedEventListeners);
-
- final RpcContext rpcCtx = new RpcContext() {
-
- @Override
- public void sendResponse(final Object responseObj) {
- try {
- responseObserver.onNext((Message) responseObj);
- responseObserver.onCompleted();
- } catch (final Throwable t) {
- LOG.warn("[GRPC] failed to send response.", t);
- }
- }
-
- @Override
- public Connection getConnection() {
- if (conn == null) {
- throw new IllegalStateException("fail to get connection");
- }
- return conn;
- }
-
- @Override
- public String getRemoteAddress() {
- // Rely on GRPC's capabilities, not magic (netty channel)
- return remoteAddress != null ? remoteAddress.toString() : null;
- }
- };
-
- final RpcProcessor.ExecutorSelector selector = processor.executorSelector();
- Executor executor;
- if (selector != null && request instanceof RpcRequests.AppendEntriesRequest) {
- final RpcRequests.AppendEntriesRequest req = (RpcRequests.AppendEntriesRequest) request;
- final RpcRequests.AppendEntriesRequestHeader.Builder header = RpcRequests.AppendEntriesRequestHeader //
- .newBuilder() //
- .setGroupId(req.getGroupId()) //
- .setPeerId(req.getPeerId()) //
- .setServerId(req.getServerId());
- executor = selector.select(interest, header.build());
- } else {
- executor = processor.executor();
- }
-
- if (executor == null) {
- executor = this.defaultExecutor;
- }
-
- if (executor != null) {
- executor.execute(() -> processor.handleRequest(rpcCtx, request));
- } else {
- processor.handleRequest(rpcCtx, request);
- }
- });
+ public void registerBidiStreamingProcessor(final RpcProcessor processor) {
+ final String interest = processor.interest();
+ final MethodDescriptor method = buildMethodDescriptor(interest);
+
+ final ServerCallHandler handler = ServerCalls.asyncBidiStreamingCall(responseObserver ->
+ new StreamObserver() {
+ @Override
+ public void onNext(final Message request) {
+ handleRequest(processor, request, responseObserver);
+ }
+
+ @Override
+ public void onError(final Throwable throwable) {
+ LOG.error("[Grpc] error", throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+ responseObserver.onCompleted();
+ }
+ });
+
+ serviceRegistry(interest, method, handler);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Override
+ public void registerProcessor(final RpcProcessor processor) {
+ final String interest = processor.interest();
+ final MethodDescriptor method = buildMethodDescriptor(interest);
+
+ final ServerCallHandler handler = ServerCalls.asyncUnaryCall(
+ (request, responseObserver) -> {
+ handleRequest(processor, request, responseObserver);
+ });
+
+ serviceRegistry(interest, method, handler);
+ }
+
+ private MethodDescriptor buildMethodDescriptor(final String interest) {
+ final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest), "null default instance: "
+ + interest);
+ return MethodDescriptor //
+ . newBuilder() //
+ .setType(MethodDescriptor.MethodType.UNARY) //
+ .setFullMethodName(MethodDescriptor.generateFullMethodName(interest, GrpcRaftRpcFactory.FIXED_METHOD_NAME)) //
+ .setRequestMarshaller(ProtoUtils.marshaller(reqIns)) //
+ .setResponseMarshaller(
+ ProtoUtils.marshaller(this.marshallerRegistry.findResponseInstanceByRequest(interest))) //
+ .build();
+ }
+ private void handleRequest(final RpcProcessor processor, final Message request,
+ final StreamObserver responseObserver) {
+ final SocketAddress remoteAddress = RemoteAddressInterceptor.getRemoteAddress();
+ final Connection conn = ConnectionInterceptor.getCurrentConnection(this.closedEventListeners);
+
+ final RpcContext rpcCtx = new RpcContext() {
+
+ @Override
+ public void sendResponse(final Object responseObj) {
+ try {
+ responseObserver.onNext((Message) responseObj);
+ responseObserver.onCompleted();
+ }
+ catch (final Throwable t) {
+ LOG.warn("[GRPC] failed to send response.", t);
+ }
+ }
+
+ @Override
+ public Connection getConnection() {
+ if (conn == null) {
+ throw new IllegalStateException("fail to get connection");
+ }
+ return conn;
+ }
+
+ @Override
+ public String getRemoteAddress() {
+ // Rely on GRPC's capabilities, not magic (netty channel)
+ return remoteAddress != null ? remoteAddress.toString() : null;
+ }
+ };
+
+ final RpcProcessor.ExecutorSelector selector = processor.executorSelector();
+ Executor executor;
+ if (selector != null && request instanceof RpcRequests.AppendEntriesRequest) {
+ final RpcRequests.AppendEntriesRequest req = (RpcRequests.AppendEntriesRequest) request;
+ final RpcRequests.AppendEntriesRequestHeader.Builder header = RpcRequests.AppendEntriesRequestHeader //
+ .newBuilder() //
+ .setGroupId(req.getGroupId()) //
+ .setPeerId(req.getPeerId()) //
+ .setServerId(req.getServerId());
+ executor = selector.select(processor.interest(), header.build());
+ }
+ else {
+ executor = processor.executor();
+ }
+
+ if (executor == null) {
+ executor = this.defaultExecutor;
+ }
+
+ if (executor != null) {
+ executor.execute(() -> processor.handleRequest(rpcCtx, request));
+ }
+ else {
+ processor.handleRequest(rpcCtx, request);
+ }
+ }
+
+ private void serviceRegistry(final String interest, final MethodDescriptor method,
+ final ServerCallHandler handler) {
final ServerServiceDefinition serviceDef = ServerServiceDefinition //
- .builder(interest) //
- .addMethod(method, handler) //
- .build();
+ .builder(interest) //
+ .addMethod(method, handler) //
+ .build();
- this.handlerRegistry
- .addService(ServerInterceptors.intercept(serviceDef, this.serverInterceptors.toArray(new ServerInterceptor[0])));
+ this.handlerRegistry.addService(ServerInterceptors.intercept(serviceDef,
+ this.serverInterceptors.toArray(new ServerInterceptor[0])));
}
@Override