Skip to content

Commit

Permalink
Make Remoting select commandFactory base on Protocol (#330)
Browse files Browse the repository at this point in the history
* make Remoting select commandFactory base on Protocol
  • Loading branch information
OrezzerO authored Jun 27, 2023
1 parent 9b98a47 commit 829be55
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 20 deletions.
58 changes: 41 additions & 17 deletions src/main/java/com/alipay/remoting/BaseRemoting.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ public abstract class BaseRemoting {
.getLogger("CommonDefault");
private final static long ABANDONING_REQUEST_THRESHOLD = 0L;

protected CommandFactory commandFactory;
private CommandFactory defalutCommandFactory;

public BaseRemoting(CommandFactory commandFactory) {
this.commandFactory = commandFactory;
this.defalutCommandFactory = commandFactory;
}

/**
Expand All @@ -69,7 +69,7 @@ protected RemotingCommand invokeSync(final Connection conn, final RemotingComman
request.getId(),
conn.getUrl() != null ? conn.getUrl() : RemotingUtil.parseRemoteAddress(conn
.getChannel()));
return this.commandFactory.createTimeoutResponse(conn.getRemoteAddress());
return this.getCommandFactory(conn).createTimeoutResponse(conn.getRemoteAddress());
}

final InvokeFuture future = createInvokeFuture(request, request.getInvokeContext());
Expand All @@ -86,7 +86,7 @@ protected RemotingCommand invokeSync(final Connection conn, final RemotingComman
public void operationComplete(ChannelFuture f) throws Exception {
if (!f.isSuccess()) {
conn.removeInvokeFuture(requestId);
future.putResponse(commandFactory.createSendFailedResponse(
future.putResponse(getCommandFactory(conn).createSendFailedResponse(
conn.getRemoteAddress(), f.cause()));
LOGGER.error("Invoke send failed, id={}", requestId, f.cause());
}
Expand All @@ -98,7 +98,8 @@ public void operationComplete(ChannelFuture f) throws Exception {
}
} catch (Exception e) {
conn.removeInvokeFuture(requestId);
future.putResponse(commandFactory.createSendFailedResponse(conn.getRemoteAddress(), e));
future.putResponse(getCommandFactory(conn).createSendFailedResponse(
conn.getRemoteAddress(), e));
LOGGER.error("Exception caught when sending invocation, id={}", requestId, e);
}
RemotingCommand response = future.waitResponse(remainingTime);
Expand All @@ -109,7 +110,7 @@ public void operationComplete(ChannelFuture f) throws Exception {

if (response == null) {
conn.removeInvokeFuture(requestId);
response = this.commandFactory.createTimeoutResponse(conn.getRemoteAddress());
response = this.getCommandFactory(conn).createTimeoutResponse(conn.getRemoteAddress());
LOGGER.warn("Wait response, request id={} timeout!", requestId);
}

Expand Down Expand Up @@ -137,7 +138,8 @@ protected void invokeWithCallback(final Connection conn, final RemotingCommand r
request.getId(),
conn.getUrl() != null ? conn.getUrl() : RemotingUtil.parseRemoteAddress(conn
.getChannel()));
future.putResponse(commandFactory.createTimeoutResponse(conn.getRemoteAddress()));
future.putResponse(getCommandFactory(conn).createTimeoutResponse(
conn.getRemoteAddress()));
future.tryAsyncExecuteInvokeCallbackAbnormally();
return;
}
Expand All @@ -149,8 +151,8 @@ protected void invokeWithCallback(final Connection conn, final RemotingCommand r
public void run(Timeout timeout) throws Exception {
InvokeFuture future = conn.removeInvokeFuture(requestId);
if (future != null) {
future.putResponse(commandFactory.createTimeoutResponse(conn
.getRemoteAddress()));
future.putResponse(getCommandFactory(conn).createTimeoutResponse(
conn.getRemoteAddress()));
future.tryAsyncExecuteInvokeCallbackAbnormally();
}
}
Expand All @@ -165,7 +167,7 @@ public void operationComplete(ChannelFuture cf) throws Exception {
InvokeFuture f = conn.removeInvokeFuture(requestId);
if (f != null) {
f.cancelTimeout();
f.putResponse(commandFactory.createSendFailedResponse(
f.putResponse(getCommandFactory(conn).createSendFailedResponse(
conn.getRemoteAddress(), cf.cause()));
f.tryAsyncExecuteInvokeCallbackAbnormally();
}
Expand All @@ -179,7 +181,8 @@ public void operationComplete(ChannelFuture cf) throws Exception {
InvokeFuture f = conn.removeInvokeFuture(requestId);
if (f != null) {
f.cancelTimeout();
f.putResponse(commandFactory.createSendFailedResponse(conn.getRemoteAddress(), e));
f.putResponse(getCommandFactory(conn).createSendFailedResponse(
conn.getRemoteAddress(), e));
f.tryAsyncExecuteInvokeCallbackAbnormally();
}
LOGGER.error("Exception caught when sending invocation. The address is {}",
Expand Down Expand Up @@ -208,7 +211,8 @@ protected InvokeFuture invokeWithFuture(final Connection conn, final RemotingCom
request.getId(),
conn.getUrl() != null ? conn.getUrl() : RemotingUtil.parseRemoteAddress(conn
.getChannel()));
future.putResponse(commandFactory.createTimeoutResponse(conn.getRemoteAddress()));
future.putResponse(getCommandFactory(conn).createTimeoutResponse(
conn.getRemoteAddress()));
return future;
}

Expand All @@ -219,8 +223,8 @@ protected InvokeFuture invokeWithFuture(final Connection conn, final RemotingCom
public void run(Timeout timeout) throws Exception {
InvokeFuture future = conn.removeInvokeFuture(requestId);
if (future != null) {
future.putResponse(commandFactory.createTimeoutResponse(conn
.getRemoteAddress()));
future.putResponse(getCommandFactory(conn).createTimeoutResponse(
conn.getRemoteAddress()));
}
}

Expand All @@ -235,7 +239,7 @@ public void operationComplete(ChannelFuture cf) throws Exception {
InvokeFuture f = conn.removeInvokeFuture(requestId);
if (f != null) {
f.cancelTimeout();
f.putResponse(commandFactory.createSendFailedResponse(
f.putResponse(getCommandFactory(conn).createSendFailedResponse(
conn.getRemoteAddress(), cf.cause()));
}
LOGGER.error("Invoke send failed. The address is {}",
Expand All @@ -248,7 +252,8 @@ public void operationComplete(ChannelFuture cf) throws Exception {
InvokeFuture f = conn.removeInvokeFuture(requestId);
if (f != null) {
f.cancelTimeout();
f.putResponse(commandFactory.createSendFailedResponse(conn.getRemoteAddress(), e));
f.putResponse(getCommandFactory(conn).createSendFailedResponse(
conn.getRemoteAddress(), e));
}
LOGGER.error("Exception caught when sending invocation. The address is {}",
RemotingUtil.parseRemoteAddress(conn.getChannel()), e);
Expand Down Expand Up @@ -323,8 +328,27 @@ protected abstract InvokeFuture createInvokeFuture(final Connection conn,
final InvokeContext invokeContext,
final InvokeCallback invokeCallback);

@Deprecated
protected CommandFactory getCommandFactory() {
return commandFactory;
LOGGER
.warn("The method getCommandFactory() is deprecated. Please use getCommandFactory(ProtocolCode/Connection) instead.");
return defalutCommandFactory;
}

protected CommandFactory getCommandFactory(Connection conn) {
ProtocolCode protocolCode = conn.getChannel().attr(Connection.PROTOCOL).get();
return getCommandFactory(protocolCode);
}

protected CommandFactory getCommandFactory(ProtocolCode protocolCode) {
if (protocolCode == null) {
return getCommandFactory();
}
Protocol protocol = ProtocolManager.getProtocol(protocolCode);
if (protocol == null) {
return getCommandFactory();
}
return protocol.getCommandFactory();
}

private int remainingTime(RemotingCommand request, int timeout) {
Expand Down
6 changes: 3 additions & 3 deletions src/main/java/com/alipay/remoting/rpc/RpcRemoting.java
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ public void invokeWithCallback(final Connection conn, final Object request,
protected RemotingCommand toRemotingCommand(Object request, Connection conn,
InvokeContext invokeContext, int timeoutMillis)
throws SerializationException {
RpcRequestCommand command = this.getCommandFactory().createRequestCommand(request);
RpcRequestCommand command = this.getCommandFactory(conn).createRequestCommand(request);

if (null != invokeContext) {
// set client custom serializer for request command if not null
Expand Down Expand Up @@ -370,7 +370,7 @@ private void logDebugInfo(RemotingCommand requestCommand) {
@Override
protected InvokeFuture createInvokeFuture(RemotingCommand request, InvokeContext invokeContext) {
return new DefaultInvokeFuture(request.getId(), null, null, request.getProtocolCode()
.getFirstByte(), this.getCommandFactory(), invokeContext);
.getFirstByte(), this.getCommandFactory(request.getProtocolCode()), invokeContext);
}

/**
Expand All @@ -382,6 +382,6 @@ protected InvokeFuture createInvokeFuture(Connection conn, RemotingCommand reque
InvokeCallback invokeCallback) {
return new DefaultInvokeFuture(request.getId(), new RpcInvokeCallbackListener(
RemotingUtil.parseRemoteAddress(conn.getChannel())), invokeCallback, request
.getProtocolCode().getFirstByte(), this.getCommandFactory(), invokeContext);
.getProtocolCode().getFirstByte(), this.getCommandFactory(conn), invokeContext);
}
}
130 changes: 130 additions & 0 deletions src/test/java/com/alipay/remoting/BaseRemotingTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alipay.remoting;

import com.alipay.remoting.rpc.RpcCommandFactory;
import io.netty.channel.local.LocalChannel;
import org.junit.Test;

import static org.junit.Assert.assertSame;

public class BaseRemotingTest {

@Test
public void getCommandFactory() {
RpcCommandFactory commandFactory = new RpcCommandFactory();
BaseRemoting baseRemoting = new EmptyRemoting(commandFactory);
assertSame(commandFactory, baseRemoting.getCommandFactory());
}

@Test
public void getCommandFactoryFromProtocolCode() {
RpcCommandFactory defaultCommand = new RpcCommandFactory();
BaseRemoting baseRemoting = new EmptyRemoting(defaultCommand);

// no 3a protocol
CommandFactory commandFactory = baseRemoting.getCommandFactory(ProtocolCode
.fromBytes((byte) 0x3a));
assertSame(defaultCommand, commandFactory);

// register 3a protocol
RpcCommandFactory my3aCommandFactory = new RpcCommandFactory();
ProtocolManager.registerProtocol(new MyProtocol(my3aCommandFactory), (byte) 0x3a);
// get 3a
commandFactory = baseRemoting.getCommandFactory(ProtocolCode.fromBytes((byte) 0x3a));
assertSame(my3aCommandFactory, commandFactory);

ProtocolManager.unRegisterProtocol((byte) 0x3a);
}

@Test
public void getCommandFactoryFromConnection() {
RpcCommandFactory defaultCommand = new RpcCommandFactory();
BaseRemoting baseRemoting = new EmptyRemoting(defaultCommand);

Connection connection = new Connection(new LocalChannel());

// no 3a protocol
CommandFactory commandFactory = baseRemoting.getCommandFactory(connection);
assertSame(defaultCommand, commandFactory);

// register 3a protocol
RpcCommandFactory my3aCommandFactory = new RpcCommandFactory();
ProtocolManager.registerProtocol(new MyProtocol(my3aCommandFactory), (byte) 0x3a);
connection.getChannel().attr(Connection.PROTOCOL).set(ProtocolCode.fromBytes((byte) 0x3a));
// get 3a
commandFactory = baseRemoting.getCommandFactory(connection);
assertSame(my3aCommandFactory, commandFactory);

ProtocolManager.unRegisterProtocol((byte) 0x3a);
}

static class EmptyRemoting extends BaseRemoting {

public EmptyRemoting(CommandFactory commandFactory) {
super(commandFactory);
}

@Override
protected InvokeFuture createInvokeFuture(RemotingCommand request,
InvokeContext invokeContext) {
return null;
}

@Override
protected InvokeFuture createInvokeFuture(Connection conn, RemotingCommand request,
InvokeContext invokeContext,
InvokeCallback invokeCallback) {
return null;
}
}

static class MyProtocol implements Protocol {

private CommandFactory commandFactory;

public MyProtocol(CommandFactory commandFactory) {
this.commandFactory = commandFactory;
}

@Override
public CommandEncoder getEncoder() {
return null;
}

@Override
public CommandDecoder getDecoder() {
return null;
}

@Override
public HeartbeatTrigger getHeartbeatTrigger() {
return null;
}

@Override
public CommandHandler getCommandHandler() {
return null;
}

@Override
public CommandFactory getCommandFactory() {
return commandFactory;
}
}

}

0 comments on commit 829be55

Please sign in to comment.