diff --git a/jraft-core/pom.xml b/jraft-core/pom.xml index a815c069e..06d98b6e6 100644 --- a/jraft-core/pom.xml +++ b/jraft-core/pom.xml @@ -127,5 +127,9 @@ org.openjdk.jmh jmh-generator-annprocess + + com.alipay.sofa + rpc-grpc-impl + diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java index f6ae21856..5f99d3ea3 100644 --- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java +++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java @@ -21,10 +21,13 @@ import com.alipay.sofa.jraft.error.RemotingException; import com.alipay.sofa.jraft.option.RpcOptions; import com.alipay.sofa.jraft.util.Endpoint; +import com.google.protobuf.Message; +import io.grpc.stub.StreamObserver; /** * * @author jiachun.fjc + * @author HH */ public interface RpcClient extends Lifecycle { @@ -107,4 +110,31 @@ default void invokeAsync(final Endpoint endpoint, final Object request, final In */ void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx, final InvokeCallback callback, final long timeoutMs) throws InterruptedException, RemotingException; + + /** + * Streaming invocation with a callback. + * + * @param endpoint target address + * @param request request object + * @param callback invoke callback + * @param timeoutMs timeout millisecond + * @return request stream observer. + */ + default StreamObserver invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeCallback callback, + final long timeoutMs) { + return invokeBidiStreaming(endpoint, request, null, callback, timeoutMs); + } + + /** + * BidiStreaming invocation. + * + * @param endpoint target address + * @param request request object + * @param callback invoke callback + * @param ctx invoke context + * @param timeoutMs timeout millisecond + * @return request stream observer. + */ + StreamObserver invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeContext ctx, final InvokeCallback callback, + final long timeoutMs); } diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java index 98407776b..8fa816a18 100644 --- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java +++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcServer.java @@ -22,6 +22,7 @@ /** * * @author jiachun.fjc + * @author HH */ public interface RpcServer extends Lifecycle { @@ -39,6 +40,13 @@ public interface RpcServer extends Lifecycle { */ void registerProcessor(final RpcProcessor processor); + /** + * Register bidiStreaming user processor. + * + * @param processor the user processor which has a interest + */ + void registerBidiStreamingProcessor(final RpcProcessor processor); + /** * * @return bound port diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java index 2ba5c6a7c..d1b670d41 100644 --- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java +++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcClient.java @@ -32,11 +32,14 @@ import com.alipay.sofa.jraft.rpc.impl.core.ClientServiceConnectionEventProcessor; import com.alipay.sofa.jraft.util.Endpoint; import com.alipay.sofa.jraft.util.Requires; +import com.google.protobuf.Message; +import io.grpc.stub.StreamObserver; /** * Bolt rpc client impl. * * @author jiachun.fjc + * @author HH */ public class BoltRpcClient implements RpcClient { @@ -119,6 +122,12 @@ public void invokeAsync(final Endpoint endpoint, final Object request, final Inv } } + @Override + public StreamObserver invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeContext ctx, + final InvokeCallback callback, final long timeoutMs) { + throw new UnsupportedOperationException(); + } + public com.alipay.remoting.rpc.RpcClient getRpcClient() { return rpcClient; } diff --git a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java index ae008644d..cf27daf0f 100644 --- a/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java +++ b/jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/impl/BoltRpcServer.java @@ -22,7 +22,6 @@ import com.alipay.remoting.BizContext; import com.alipay.remoting.ConnectionEventType; import com.alipay.remoting.config.BoltClientOption; -import com.alipay.remoting.config.switches.GlobalSwitch; import com.alipay.remoting.rpc.protocol.AsyncUserProcessor; import com.alipay.sofa.jraft.rpc.Connection; import com.alipay.sofa.jraft.rpc.RpcContext; @@ -34,6 +33,7 @@ * Bolt RPC server impl. * * @author jiachun.fjc + * @author HH */ public class BoltRpcServer implements RpcServer { @@ -87,6 +87,11 @@ public void close() { }); } + @Override + public void registerBidiStreamingProcessor(final RpcProcessor processor) { + throw new UnsupportedOperationException(); + } + @Override public int boundPort() { return this.rpcServer.port(); diff --git a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java index 256611a79..a5a9f081a 100644 --- a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java +++ b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcClient.java @@ -57,6 +57,7 @@ * * @author nicholas.jxf * @author jiachun.fjc + * @author HH */ public class GrpcClient implements RpcClient { @@ -137,10 +138,9 @@ public Object invokeSync(final Endpoint endpoint, final Object request, final In @Override public void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx, final InvokeCallback callback, final long timeoutMs) { - Requires.requireNonNull(endpoint, "endpoint"); - Requires.requireNonNull(request, "request"); + nonNullCheck(endpoint, request); - final Executor executor = callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE; + final Executor executor = getExecutor(callback); final Channel ch = getCheckedChannel(endpoint); if (ch == null) { @@ -149,7 +149,7 @@ public void invokeAsync(final Endpoint endpoint, final Object request, final Inv } final MethodDescriptor method = getCallMethod(request); - final CallOptions callOpts = CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS); + final CallOptions callOpts = getCallOpts(timeoutMs); ClientCalls.asyncUnaryCall(ch.newCall(method, callOpts), (Message) request, new StreamObserver() { @@ -170,6 +170,53 @@ public void onCompleted() { }); } + @Override + public StreamObserver invokeBidiStreaming(Endpoint endpoint, Object request, InvokeContext ctx, InvokeCallback callback, long timeoutMs) { + nonNullCheck(endpoint, request); + + final Executor executor = getExecutor(callback); + + final Channel ch = getCheckedChannel(endpoint); + if (ch == null) { + executor.execute(() -> callback.complete(null, new RemotingException("Fail to connect: " + endpoint))); + return null; + } + + final MethodDescriptor method = getCallMethod(request); + final CallOptions callOpts = getCallOpts(timeoutMs); + + return ClientCalls + .asyncBidiStreamingCall(ch.newCall(method, callOpts), new StreamObserver() { + @Override + public void onNext(final Message value) { + executor.execute(() -> callback.complete(value, null)); + } + + @Override + public void onError(final Throwable throwable) { + executor.execute(() -> callback.complete(null, throwable)); + } + + @Override + public void onCompleted() { + // NO-OP + } + }); + } + + private void nonNullCheck(final Endpoint endpoint, final Object request) { + Requires.requireNonNull(endpoint, "endpoint"); + Requires.requireNonNull(request, "request"); + } + + private Executor getExecutor(final InvokeCallback callback) { + return callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE; + } + + private CallOptions getCallOpts(final long timeoutMs) { + return CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS); + } + private MethodDescriptor getCallMethod(final Object request) { final String interest = request.getClass().getName(); final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest), "null default instance: " @@ -195,12 +242,13 @@ private ManagedChannel getCheckedChannel(final Endpoint endpoint) { } private ManagedChannel getChannel(final Endpoint endpoint, final boolean createIfAbsent) { - if (createIfAbsent) { - return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel); - } else { - return this.managedChannelPool.get(endpoint); - } - } + if (createIfAbsent) { + return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel); + } + else { + return this.managedChannelPool.get(endpoint); + } + } private ManagedChannel newChannel(final Endpoint endpoint) { final ManagedChannel ch = ManagedChannelBuilder.forAddress(endpoint.getIp(), endpoint.getPort()) // @@ -222,8 +270,8 @@ private ManagedChannel removeChannel(final Endpoint endpoint) { } private void notifyWhenStateChanged(final ConnectivityState state, final Endpoint endpoint, final ManagedChannel ch) { - ch.notifyWhenStateChanged(state, () -> onStateChanged(endpoint, ch)); - } + ch.notifyWhenStateChanged(state, () -> onStateChanged(endpoint, ch)); + } private void onStateChanged(final Endpoint endpoint, final ManagedChannel ch) { final ConnectivityState state = ch.getState(false); @@ -252,27 +300,29 @@ private void onStateChanged(final Endpoint endpoint, final ManagedChannel ch) { } private void notifyReady(final Endpoint endpoint) { - LOG.info("The channel {} has successfully established.", endpoint); - - clearConnFailuresCount(endpoint); - - final ReplicatorGroup rpGroup = this.replicatorGroup; - if (rpGroup != null) { - try { - RpcUtils.runInThread(() -> { - final PeerId peer = new PeerId(); - if (peer.parse(endpoint.toString())) { - LOG.info("Peer {} is connected.", peer); - rpGroup.checkReplicator(peer, true); - } else { - LOG.error("Fail to parse peer: {}.", endpoint); - } - }); - } catch (final Throwable t) { - LOG.error("Fail to check replicator {}.", endpoint, t); - } - } - } + LOG.info("The channel {} has successfully established.", endpoint); + + clearConnFailuresCount(endpoint); + + final ReplicatorGroup rpGroup = this.replicatorGroup; + if (rpGroup != null) { + try { + RpcUtils.runInThread(() -> { + final PeerId peer = new PeerId(); + if (peer.parse(endpoint.toString())) { + LOG.info("Peer {} is connected.", peer); + rpGroup.checkReplicator(peer, true); + } + else { + LOG.error("Fail to parse peer: {}.", endpoint); + } + }); + } + catch (final Throwable t) { + LOG.error("Fail to check replicator {}.", endpoint, t); + } + } + } private void notifyFailure(final Endpoint endpoint) { LOG.warn("There has been some transient failure on this channel {}.", endpoint); @@ -310,8 +360,8 @@ private boolean checkChannel(final Endpoint endpoint, final boolean createIfAbse } private int incConnFailuresCount(final Endpoint endpoint) { - return this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger()).incrementAndGet(); - } + return this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger()).incrementAndGet(); + } private void clearConnFailuresCount(final Endpoint endpoint) { this.transientFailures.remove(endpoint); diff --git a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java index 45cf7fd77..93a8423c6 100644 --- a/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java +++ b/jraft-extension/rpc-grpc-impl/src/main/java/com/alipay/sofa/jraft/rpc/impl/GrpcServer.java @@ -35,6 +35,7 @@ import io.grpc.ServerServiceDefinition; import io.grpc.protobuf.ProtoUtils; import io.grpc.stub.ServerCalls; +import io.grpc.stub.StreamObserver; import io.grpc.util.MutableHandlerRegistry; import org.slf4j.Logger; @@ -57,6 +58,7 @@ * * @author nicholas.jxf * @author jiachun.fjc + * @author HH */ public class GrpcServer implements RpcServer { @@ -84,33 +86,34 @@ public GrpcServer(Server server, MutableHandlerRegistry handlerRegistry, Map()) // - .threadFactory(new NamedThreadFactory(EXECUTOR_NAME + "-", true)) // - .rejectedHandler((r, executor) -> { - throw new RejectedExecutionException("[" + EXECUTOR_NAME + "], task " + r.toString() + - " rejected from " + - executor.toString()); - }) - .build(); - - try { - this.server.start(); - } catch (final IOException e) { - ThrowUtil.throwException(e); - } - return true; - } + public boolean init(final Void opts) { + if (!this.started.compareAndSet(false, true)) { + throw new IllegalStateException("grpc server has started"); + } + + this.defaultExecutor = ThreadPoolUtil.newBuilder() // + .poolName(EXECUTOR_NAME) // + .enableMetric(true) // + .coreThreads(Math.min(20, GrpcRaftRpcFactory.RPC_SERVER_PROCESSOR_POOL_SIZE / 5)) // + .maximumThreads(GrpcRaftRpcFactory.RPC_SERVER_PROCESSOR_POOL_SIZE) // + .keepAliveSeconds(60L) // + .workQueue(new SynchronousQueue<>()) // + .threadFactory(new NamedThreadFactory(EXECUTOR_NAME + "-", true)) // + .rejectedHandler((r, executor) -> { + throw new RejectedExecutionException("[" + EXECUTOR_NAME + "], task " + r.toString() + + " rejected from " + + executor.toString()); + }) + .build(); + + try { + this.server.start(); + } + catch (final IOException e) { + ThrowUtil.throwException(e); + } + return true; + } @Override public void shutdown() { @@ -126,84 +129,128 @@ public void registerConnectionClosedEventListener(final ConnectionClosedEventLis this.closedEventListeners.add(listener); } - @SuppressWarnings("unchecked") @Override - public void registerProcessor(final RpcProcessor processor) { - final String interest = processor.interest(); - final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest), "null default instance: " + interest); - final MethodDescriptor method = MethodDescriptor // - .newBuilder() // - .setType(MethodDescriptor.MethodType.UNARY) // - .setFullMethodName( - MethodDescriptor.generateFullMethodName(processor.interest(), GrpcRaftRpcFactory.FIXED_METHOD_NAME)) // - .setRequestMarshaller(ProtoUtils.marshaller(reqIns)) // - .setResponseMarshaller(ProtoUtils.marshaller(this.marshallerRegistry.findResponseInstanceByRequest(interest))) // - .build(); - - final ServerCallHandler handler = ServerCalls.asyncUnaryCall( - (request, responseObserver) -> { - final SocketAddress remoteAddress = RemoteAddressInterceptor.getRemoteAddress(); - final Connection conn = ConnectionInterceptor.getCurrentConnection(this.closedEventListeners); - - final RpcContext rpcCtx = new RpcContext() { - - @Override - public void sendResponse(final Object responseObj) { - try { - responseObserver.onNext((Message) responseObj); - responseObserver.onCompleted(); - } catch (final Throwable t) { - LOG.warn("[GRPC] failed to send response.", t); - } - } - - @Override - public Connection getConnection() { - if (conn == null) { - throw new IllegalStateException("fail to get connection"); - } - return conn; - } - - @Override - public String getRemoteAddress() { - // Rely on GRPC's capabilities, not magic (netty channel) - return remoteAddress != null ? remoteAddress.toString() : null; - } - }; - - final RpcProcessor.ExecutorSelector selector = processor.executorSelector(); - Executor executor; - if (selector != null && request instanceof RpcRequests.AppendEntriesRequest) { - final RpcRequests.AppendEntriesRequest req = (RpcRequests.AppendEntriesRequest) request; - final RpcRequests.AppendEntriesRequestHeader.Builder header = RpcRequests.AppendEntriesRequestHeader // - .newBuilder() // - .setGroupId(req.getGroupId()) // - .setPeerId(req.getPeerId()) // - .setServerId(req.getServerId()); - executor = selector.select(interest, header.build()); - } else { - executor = processor.executor(); - } - - if (executor == null) { - executor = this.defaultExecutor; - } - - if (executor != null) { - executor.execute(() -> processor.handleRequest(rpcCtx, request)); - } else { - processor.handleRequest(rpcCtx, request); - } - }); + public void registerBidiStreamingProcessor(final RpcProcessor processor) { + final String interest = processor.interest(); + final MethodDescriptor method = buildMethodDescriptor(interest); + + final ServerCallHandler handler = ServerCalls.asyncBidiStreamingCall(responseObserver -> + new StreamObserver() { + @Override + public void onNext(final Message request) { + handleRequest(processor, request, responseObserver); + } + + @Override + public void onError(final Throwable throwable) { + LOG.error("[Grpc] error", throwable); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }); + + serviceRegistry(interest, method, handler); + } + + @SuppressWarnings("unchecked") + @Override + public void registerProcessor(final RpcProcessor processor) { + final String interest = processor.interest(); + final MethodDescriptor method = buildMethodDescriptor(interest); + + final ServerCallHandler handler = ServerCalls.asyncUnaryCall( + (request, responseObserver) -> { + handleRequest(processor, request, responseObserver); + }); + + serviceRegistry(interest, method, handler); + } + + private MethodDescriptor buildMethodDescriptor(final String interest) { + final Message reqIns = Requires.requireNonNull(this.parserClasses.get(interest), "null default instance: " + + interest); + return MethodDescriptor // + . newBuilder() // + .setType(MethodDescriptor.MethodType.UNARY) // + .setFullMethodName(MethodDescriptor.generateFullMethodName(interest, GrpcRaftRpcFactory.FIXED_METHOD_NAME)) // + .setRequestMarshaller(ProtoUtils.marshaller(reqIns)) // + .setResponseMarshaller( + ProtoUtils.marshaller(this.marshallerRegistry.findResponseInstanceByRequest(interest))) // + .build(); + } + private void handleRequest(final RpcProcessor processor, final Message request, + final StreamObserver responseObserver) { + final SocketAddress remoteAddress = RemoteAddressInterceptor.getRemoteAddress(); + final Connection conn = ConnectionInterceptor.getCurrentConnection(this.closedEventListeners); + + final RpcContext rpcCtx = new RpcContext() { + + @Override + public void sendResponse(final Object responseObj) { + try { + responseObserver.onNext((Message) responseObj); + responseObserver.onCompleted(); + } + catch (final Throwable t) { + LOG.warn("[GRPC] failed to send response.", t); + } + } + + @Override + public Connection getConnection() { + if (conn == null) { + throw new IllegalStateException("fail to get connection"); + } + return conn; + } + + @Override + public String getRemoteAddress() { + // Rely on GRPC's capabilities, not magic (netty channel) + return remoteAddress != null ? remoteAddress.toString() : null; + } + }; + + final RpcProcessor.ExecutorSelector selector = processor.executorSelector(); + Executor executor; + if (selector != null && request instanceof RpcRequests.AppendEntriesRequest) { + final RpcRequests.AppendEntriesRequest req = (RpcRequests.AppendEntriesRequest) request; + final RpcRequests.AppendEntriesRequestHeader.Builder header = RpcRequests.AppendEntriesRequestHeader // + .newBuilder() // + .setGroupId(req.getGroupId()) // + .setPeerId(req.getPeerId()) // + .setServerId(req.getServerId()); + executor = selector.select(processor.interest(), header.build()); + } + else { + executor = processor.executor(); + } + + if (executor == null) { + executor = this.defaultExecutor; + } + + if (executor != null) { + executor.execute(() -> processor.handleRequest(rpcCtx, request)); + } + else { + processor.handleRequest(rpcCtx, request); + } + } + + private void serviceRegistry(final String interest, final MethodDescriptor method, + final ServerCallHandler handler) { final ServerServiceDefinition serviceDef = ServerServiceDefinition // - .builder(interest) // - .addMethod(method, handler) // - .build(); + .builder(interest) // + .addMethod(method, handler) // + .build(); - this.handlerRegistry - .addService(ServerInterceptors.intercept(serviceDef, this.serverInterceptors.toArray(new ServerInterceptor[0]))); + this.handlerRegistry.addService(ServerInterceptors.intercept(serviceDef, + this.serverInterceptors.toArray(new ServerInterceptor[0]))); } @Override