Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:add Bidistreaming Api #890

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions jraft-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,11 @@
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-generator-annprocess</artifactId>
</dependency>
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-stub</artifactId>
<version>1.17.0</version>
<scope>compile</scope>
</dependency>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么 jraft-core 需要依赖 grpc? 应该是 rpc_grpc-impl 依赖吧?

</dependencies>
</project>
30 changes: 30 additions & 0 deletions jraft-core/src/main/java/com/alipay/sofa/jraft/rpc/RpcClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@
import com.alipay.sofa.jraft.error.RemotingException;
import com.alipay.sofa.jraft.option.RpcOptions;
import com.alipay.sofa.jraft.util.Endpoint;
import com.google.protobuf.Message;
import io.grpc.stub.StreamObserver;

/**
*
* @author jiachun.fjc
* @author HH
*/
public interface RpcClient extends Lifecycle<RpcOptions> {

Expand Down Expand Up @@ -107,4 +110,31 @@ default void invokeAsync(final Endpoint endpoint, final Object request, final In
*/
void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx, final InvokeCallback callback,
final long timeoutMs) throws InterruptedException, RemotingException;

/**
* Streaming invocation with a callback.
*
* @param endpoint target address
* @param request request object
* @param callback invoke callback
* @param timeoutMs timeout millisecond
* @return request stream observer.
*/
default StreamObserver<Message> invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeCallback callback,
final long timeoutMs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这一行好像格式没对齐

return invokeBidiStreaming(endpoint, request, null, callback, timeoutMs);
}

/**
* BidiStreaming invocation.
*
* @param endpoint target address
* @param request request object
* @param callback invoke callback
* @param ctx invoke context
* @param timeoutMs timeout millisecond
* @return request stream observer.
*/
StreamObserver<Message> invokeBidiStreaming(final Endpoint endpoint, final Object request, final InvokeContext ctx, final InvokeCallback callback,
final long timeoutMs);
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
/**
*
* @author jiachun.fjc
* @author HH
*/
public interface RpcServer extends Lifecycle<Void> {

Expand All @@ -39,6 +40,13 @@ public interface RpcServer extends Lifecycle<Void> {
*/
void registerProcessor(final RpcProcessor<?> processor);

/**
* Register bidiStreaming user processor.
*
* @param processor the user processor which has a interest
*/
void registerBidiStreamingProcessor(final RpcProcessor<?> processor);

/**
*
* @return bound port
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@
import com.alipay.sofa.jraft.rpc.impl.core.ClientServiceConnectionEventProcessor;
import com.alipay.sofa.jraft.util.Endpoint;
import com.alipay.sofa.jraft.util.Requires;
import com.google.protobuf.Message;
import io.grpc.stub.StreamObserver;

/**
* Bolt rpc client impl.
*
* @author jiachun.fjc
* @author HH
*/
public class BoltRpcClient implements RpcClient {

Expand Down Expand Up @@ -119,6 +122,12 @@ public void invokeAsync(final Endpoint endpoint, final Object request, final Inv
}
}

@Override
public StreamObserver<Message> invokeBidiStreaming(Endpoint endpoint, Object request, InvokeContext ctx,
InvokeCallback callback, long timeoutMs) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

建议参考原有代码规范,添加 final 关键字

return null;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可以 throw 一个 UnsupportedException 之类的 error

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要抛出异常,不能吞掉

}

public com.alipay.remoting.rpc.RpcClient getRpcClient() {
return rpcClient;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
* Bolt RPC server impl.
*
* @author jiachun.fjc
* @author HH
*/
public class BoltRpcServer implements RpcServer {

Expand Down Expand Up @@ -87,6 +88,11 @@ public void close() {
});
}

@Override
public void registerBidiStreamingProcessor(RpcProcessor<?> processor) {
//no-op
}

@Override
public int boundPort() {
return this.rpcServer.port();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
*
* @author nicholas.jxf
* @author jiachun.fjc
* @author HH
*/
public class GrpcClient implements RpcClient {

Expand Down Expand Up @@ -111,63 +112,112 @@ public void registerConnectEventListener(final ReplicatorGroup replicatorGroup)
}

@Override
public Object invokeSync(final Endpoint endpoint, final Object request, final InvokeContext ctx,
final long timeoutMs) throws RemotingException {
final CompletableFuture<Object> future = new CompletableFuture<>();

invokeAsync(endpoint, request, ctx, (result, err) -> {
if (err == null) {
future.complete(result);
} else {
future.completeExceptionally(err);
}
}, timeoutMs);

try {
return future.get(timeoutMs, TimeUnit.MILLISECONDS);
} catch (final TimeoutException e) {
future.cancel(true);
throw new InvokeTimeoutException(e);
} catch (final Throwable t) {
future.cancel(true);
throw new RemotingException(t);
}
}
public Object invokeSync(final Endpoint endpoint, final Object request, final InvokeContext ctx,
final long timeoutMs) throws RemotingException {
final CompletableFuture<Object> future = new CompletableFuture<>();

invokeAsync(endpoint, request, ctx, (result, err) -> {
if (err == null) {
future.complete(result);
}
else {
future.completeExceptionally(err);
}
}, timeoutMs);

try {
return future.get(timeoutMs, TimeUnit.MILLISECONDS);
}
catch (final TimeoutException e) {
future.cancel(true);
throw new InvokeTimeoutException(e);
}
catch (final Throwable t) {
future.cancel(true);
throw new RemotingException(t);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个格式需要手动复原一下

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个格式是由mvn clean compile格式化的,需要复原吗?


@Override
public void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx,
final InvokeCallback callback, final long timeoutMs) {
Requires.requireNonNull(endpoint, "endpoint");
Requires.requireNonNull(request, "request");
public void invokeAsync(final Endpoint endpoint, final Object request, final InvokeContext ctx,
final InvokeCallback callback, final long timeoutMs) {
nonNullCheck(endpoint, request);

final Executor executor = callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE;
final Executor executor = getExecutor(callback);

final Channel ch = getCheckedChannel(endpoint);
if (ch == null) {
executor.execute(() -> callback.complete(null, new RemotingException("Fail to connect: " + endpoint)));
return;
}
final Channel ch = getCheckedChannel(endpoint);
if (ch == null) {
executor.execute(() -> callback.complete(null, new RemotingException("Fail to connect: " + endpoint)));
return;
}

final MethodDescriptor<Message, Message> method = getCallMethod(request);
final CallOptions callOpts = CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
final MethodDescriptor<Message, Message> method = getCallMethod(request);
final CallOptions callOpts = getCallOpts(timeoutMs);

ClientCalls.asyncUnaryCall(ch.newCall(method, callOpts), (Message) request, new StreamObserver<Message>() {
ClientCalls.asyncUnaryCall(ch.newCall(method, callOpts), (Message) request, new StreamObserver<Message>() {

@Override
public void onNext(final Message value) {
executor.execute(() -> callback.complete(value, null));
}
@Override
public void onNext(final Message value) {
executor.execute(() -> callback.complete(value, null));
}

@Override
public void onError(final Throwable throwable) {
executor.execute(() -> callback.complete(null, throwable));
}
@Override
public void onError(final Throwable throwable) {
executor.execute(() -> callback.complete(null, throwable));
}

@Override
public void onCompleted() {
// NO-OP
}
});
@Override
public void onCompleted() {
// NO-OP
}
});
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面的代码需要手动复原格式


@Override
public StreamObserver<Message> invokeBidiStreaming(Endpoint endpoint, Object request, InvokeContext ctx, InvokeCallback callback, long timeoutMs) {
nonNullCheck(endpoint, request);

final Executor executor = getExecutor(callback);

final Channel ch = getCheckedChannel(endpoint);
if (ch == null) {
executor.execute(() -> callback.complete(null, new RemotingException("Fail to connect: " + endpoint)));
return null;
}

final MethodDescriptor<Message, Message> method = getCallMethod(request);
final CallOptions callOpts = getCallOpts(timeoutMs);

return ClientCalls
.asyncBidiStreamingCall(ch.newCall(method, callOpts), new StreamObserver<Message>() {
@Override
public void onNext(final Message value) {
executor.execute(() -> callback.complete(value, null));
}

@Override
public void onError(final Throwable throwable) {
executor.execute(() -> callback.complete(null, throwable));
}

@Override
public void onCompleted() {
// NO-OP
}
});
}

private void nonNullCheck(final Endpoint endpoint, final Object request) {
Requires.requireNonNull(endpoint, "endpoint");
Requires.requireNonNull(request, "request");
}

private Executor getExecutor(final InvokeCallback callback) {
return callback.executor() != null ? callback.executor() : DirectExecutor.INSTANCE;
}

private CallOptions getCallOpts(final long timeoutMs) {
return CallOptions.DEFAULT.withDeadlineAfter(timeoutMs, TimeUnit.MILLISECONDS);
}

private MethodDescriptor<Message, Message> getCallMethod(final Object request) {
Expand Down Expand Up @@ -195,12 +245,13 @@ private ManagedChannel getCheckedChannel(final Endpoint endpoint) {
}

private ManagedChannel getChannel(final Endpoint endpoint, final boolean createIfAbsent) {
if (createIfAbsent) {
return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel);
} else {
return this.managedChannelPool.get(endpoint);
}
}
if (createIfAbsent) {
return this.managedChannelPool.computeIfAbsent(endpoint, this::newChannel);
}
else {
return this.managedChannelPool.get(endpoint);
}
}

private ManagedChannel newChannel(final Endpoint endpoint) {
final ManagedChannel ch = ManagedChannelBuilder.forAddress(endpoint.getIp(), endpoint.getPort()) //
Expand All @@ -222,8 +273,8 @@ private ManagedChannel removeChannel(final Endpoint endpoint) {
}

private void notifyWhenStateChanged(final ConnectivityState state, final Endpoint endpoint, final ManagedChannel ch) {
ch.notifyWhenStateChanged(state, () -> onStateChanged(endpoint, ch));
}
ch.notifyWhenStateChanged(state, () -> onStateChanged(endpoint, ch));
}

private void onStateChanged(final Endpoint endpoint, final ManagedChannel ch) {
final ConnectivityState state = ch.getState(false);
Expand Down Expand Up @@ -252,27 +303,29 @@ private void onStateChanged(final Endpoint endpoint, final ManagedChannel ch) {
}

private void notifyReady(final Endpoint endpoint) {
LOG.info("The channel {} has successfully established.", endpoint);

clearConnFailuresCount(endpoint);

final ReplicatorGroup rpGroup = this.replicatorGroup;
if (rpGroup != null) {
try {
RpcUtils.runInThread(() -> {
final PeerId peer = new PeerId();
if (peer.parse(endpoint.toString())) {
LOG.info("Peer {} is connected.", peer);
rpGroup.checkReplicator(peer, true);
} else {
LOG.error("Fail to parse peer: {}.", endpoint);
}
});
} catch (final Throwable t) {
LOG.error("Fail to check replicator {}.", endpoint, t);
}
}
}
LOG.info("The channel {} has successfully established.", endpoint);

clearConnFailuresCount(endpoint);

final ReplicatorGroup rpGroup = this.replicatorGroup;
if (rpGroup != null) {
try {
RpcUtils.runInThread(() -> {
final PeerId peer = new PeerId();
if (peer.parse(endpoint.toString())) {
LOG.info("Peer {} is connected.", peer);
rpGroup.checkReplicator(peer, true);
}
else {
LOG.error("Fail to parse peer: {}.", endpoint);
}
});
}
catch (final Throwable t) {
LOG.error("Fail to check replicator {}.", endpoint, t);
}
}
}

private void notifyFailure(final Endpoint endpoint) {
LOG.warn("There has been some transient failure on this channel {}.", endpoint);
Expand Down Expand Up @@ -310,8 +363,8 @@ private boolean checkChannel(final Endpoint endpoint, final boolean createIfAbse
}

private int incConnFailuresCount(final Endpoint endpoint) {
return this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger()).incrementAndGet();
}
return this.transientFailures.computeIfAbsent(endpoint, ep -> new AtomicInteger()).incrementAndGet();
}

private void clearConnFailuresCount(final Endpoint endpoint) {
this.transientFailures.remove(endpoint);
Expand Down
Loading