Skip to content

Commit

Permalink
1. enhance IAuthProvider to allow check permission for connection exp…
Browse files Browse the repository at this point in the history
…licitly, so that auth provider impl could fine-grain control the auth workflow

2. Tenant Setting "BypassPermCheckError" now will turn CheckResult(Error) into CheckResult(Granted) when enabled.
3. change the 'cause' field type of AccessControlError event from Throwable to String
  • Loading branch information
popduke committed Dec 4, 2024
1 parent 7b56c60 commit 4700820
Show file tree
Hide file tree
Showing 16 changed files with 332 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ public final void channelRead(ChannelHandlerContext ctx, Object msg) {
}
long reqId = System.nanoTime();
cancellableTasks.track(authenticate(connMsg))
.thenComposeAsync(okOrGoAway -> {
if (okOrGoAway.goAway != null) {
return CompletableFuture.completedFuture(okOrGoAway);
}
// check conn permission
return checkConnectPermission(connMsg, okOrGoAway.clientInfo);
}, ctx.executor())
.thenComposeAsync(okOrGoAway -> {
if (okOrGoAway.goAway != null) {
handleGoAway(okOrGoAway.goAway);
Expand Down Expand Up @@ -312,6 +319,9 @@ public final void channelRead(ChannelHandlerContext ctx, Object msg) {

protected abstract CompletableFuture<AuthResult> authenticate(MqttConnectMessage message);

protected abstract CompletableFuture<AuthResult> checkConnectPermission(MqttConnectMessage message,
ClientInfo clientInfo);

protected abstract void handleMqttMessage(MqttMessage message);

protected abstract GoAway onNoEnoughResources(MqttConnectMessage message, TenantResourceType resourceType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static com.baidu.bifromq.mqtt.handler.MQTTConnectHandler.AuthResult.goAway;
import static com.baidu.bifromq.mqtt.handler.MQTTConnectHandler.AuthResult.ok;
import static com.baidu.bifromq.mqtt.handler.condition.ORCondition.or;
import static com.baidu.bifromq.mqtt.utils.AuthUtil.buildConnAction;
import static com.baidu.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal;
import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CHANNEL_ID_KEY;
import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CLIENT_ADDRESS_KEY;
Expand Down Expand Up @@ -66,6 +67,7 @@
import com.baidu.bifromq.type.ClientInfo;
import com.baidu.bifromq.type.Message;
import com.baidu.bifromq.type.QoS;
import com.baidu.bifromq.type.UserProperties;
import com.baidu.bifromq.util.TopicUtil;
import com.baidu.bifromq.util.UTF8Util;
import com.bifromq.plugin.resourcethrottler.TenantResourceType;
Expand Down Expand Up @@ -124,12 +126,12 @@ protected GoAway sanityCheck(MqttConnectMessage message) {
.build(),
getLocal(IdentifierRejected.class).peerAddress(clientAddress));
}
if (message.variableHeader().hasUserName() &&
!UTF8Util.isWellFormed(message.payload().userName(), SANITY_CHECK)) {
if (message.variableHeader().hasUserName()
&& !UTF8Util.isWellFormed(message.payload().userName(), SANITY_CHECK)) {
return new GoAway(getLocal(MalformedUserName.class).peerAddress(clientAddress));
}
if (message.variableHeader().isWillFlag() &&
!UTF8Util.isWellFormed(message.payload().willTopic(), SANITY_CHECK)) {
if (message.variableHeader().isWillFlag()
&& !UTF8Util.isWellFormed(message.payload().willTopic(), SANITY_CHECK)) {
return new GoAway(getLocal(MalformedWillTopic.class).peerAddress(clientAddress));
}
return null;
Expand All @@ -148,8 +150,8 @@ protected CompletableFuture<AuthResult> authenticate(MqttConnectMessage message)
.setTenantId(ok.getTenantId())
.setType(MQTT_TYPE_VALUE)
.putAllMetadata(ok.getAttrsMap()) // custom attrs
.putMetadata(MQTT_PROTOCOL_VER_KEY, message.variableHeader().version() == 3 ?
MQTT_PROTOCOL_VER_3_1_VALUE : MQTT_PROTOCOL_VER_3_1_1_VALUE)
.putMetadata(MQTT_PROTOCOL_VER_KEY, message.variableHeader().version() == 3
? MQTT_PROTOCOL_VER_3_1_VALUE : MQTT_PROTOCOL_VER_3_1_1_VALUE)
.putMetadata(MQTT_USER_ID_KEY, ok.getUserId())
.putMetadata(MQTT_CLIENT_ID_KEY, message.payload().clientIdentifier())
.putMetadata(MQTT_CHANNEL_ID_KEY, ctx.channel().id().asLongText())
Expand Down Expand Up @@ -212,6 +214,38 @@ protected CompletableFuture<AuthResult> authenticate(MqttConnectMessage message)
}, ctx.executor());
}

@Override
protected CompletableFuture<AuthResult> checkConnectPermission(MqttConnectMessage message, ClientInfo clientInfo) {
return authProvider.checkPermission(clientInfo, buildConnAction(UserProperties.getDefaultInstance()))
.thenApply(checkResult -> {
switch (checkResult.getTypeCase()) {
case GRANTED -> {
return AuthResult.ok(clientInfo);
}
case DENIED -> {
return goAway(MqttMessageBuilders
.connAck()
.returnCode(CONNECTION_REFUSED_NOT_AUTHORIZED)
.build(),
getLocal(NotAuthorizedClient.class)
.tenantId(clientInfo.getTenantId())
.userId(clientInfo.getMetadataOrDefault(MQTT_USER_ID_KEY, ""))
.clientId(clientInfo.getMetadataOrDefault(MQTT_CLIENT_ID_KEY, ""))
.peerAddress(ChannelAttrs.socketAddress(ctx.channel())));
}
default -> {
return goAway(MqttMessageBuilders
.connAck()
.returnCode(CONNECTION_REFUSED_SERVER_UNAVAILABLE)
.build(),
getLocal(AuthError.class)
.cause("Failed to check connect permission")
.peerAddress(ChannelAttrs.socketAddress(ctx.channel())));
}
}
});
}

@Override
protected void handleMqttMessage(MqttMessage message) {
// never happen in MQTT3
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.maximumPacketSize;
import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.requestProblemInformation;
import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.requestResponseInformation;
import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.toUserProperties;
import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.toWillMessage;
import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.topicAliasMaximum;
import static com.baidu.bifromq.mqtt.utils.AuthUtil.buildConnAction;
import static com.baidu.bifromq.mqtt.utils.MQTT5MessageSizer.MIN_CONTROL_PACKET_SIZE;
import static com.baidu.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal;
import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CHANNEL_ID_KEY;
Expand Down Expand Up @@ -72,6 +74,7 @@
import com.baidu.bifromq.plugin.authprovider.type.Failed;
import com.baidu.bifromq.plugin.authprovider.type.MQTT5AuthData;
import com.baidu.bifromq.plugin.authprovider.type.MQTT5ExtendedAuthData;
import com.baidu.bifromq.plugin.authprovider.type.MQTTAction;
import com.baidu.bifromq.plugin.authprovider.type.Success;
import com.baidu.bifromq.plugin.clientbalancer.IClientBalancer;
import com.baidu.bifromq.plugin.clientbalancer.Redirection;
Expand Down Expand Up @@ -165,8 +168,8 @@ protected GoAway sanityCheck(MqttConnectMessage connMsg) {
.build(),
getLocal(MalformedClientIdentifier.class).peerAddress(clientAddress));
}
if (connMsg.variableHeader().hasUserName() &&
!UTF8Util.isWellFormed(connMsg.payload().userName(), SANITY_CHECK)) {
if (connMsg.variableHeader().hasUserName()
&& !UTF8Util.isWellFormed(connMsg.payload().userName(), SANITY_CHECK)) {
return new GoAway(MqttMessageBuilders
.connAck()
.properties(MQTT5MessageBuilders.connAckProperties()
Expand All @@ -176,8 +179,8 @@ protected GoAway sanityCheck(MqttConnectMessage connMsg) {
.build(),
getLocal(MalformedUserName.class).peerAddress(clientAddress));
}
if (authMethod(connMsg.variableHeader().properties()).isEmpty() &&
authData(connMsg.variableHeader().properties()).isPresent()) {
if (authMethod(connMsg.variableHeader().properties()).isEmpty()
&& authData(connMsg.variableHeader().properties()).isPresent()) {
return new GoAway(MqttMessageBuilders
.connAck()
.properties(MQTT5MessageBuilders.connAckProperties()
Expand Down Expand Up @@ -288,6 +291,45 @@ protected CompletableFuture<AuthResult> authenticate(MqttConnectMessage message)
}
}

@Override
protected CompletableFuture<AuthResult> checkConnectPermission(MqttConnectMessage message, ClientInfo clientInfo) {
MQTTAction connAction = buildConnAction(toUserProperties(message.variableHeader().properties()));
return authProvider.checkPermission(clientInfo, connAction)
.thenApply(checkResult -> {
switch (checkResult.getTypeCase()) {
case GRANTED -> {
return AuthResult.ok(clientInfo);
}
case DENIED -> {
return goAway(MqttMessageBuilders
.connAck()
.properties(MQTT5MessageBuilders.connAckProperties()
.reasonString("Not authorized")
.build())
.returnCode(CONNECTION_REFUSED_NOT_AUTHORIZED_5)
.build(),
getLocal(NotAuthorizedClient.class)
.tenantId(clientInfo.getTenantId())
.userId(clientInfo.getMetadataOrDefault(MQTT_USER_ID_KEY, ""))
.clientId(connMsg.payload().clientIdentifier())
.peerAddress(ChannelAttrs.socketAddress(ctx.channel())));
}
default -> {
return goAway(MqttMessageBuilders
.connAck()
.properties(MQTT5MessageBuilders.connAckProperties()
.reasonString("Failed to check connect permission")
.build())
.returnCode(CONNECTION_REFUSED_UNSPECIFIED_ERROR)
.build(),
getLocal(AuthError.class)
.cause("Failed to check connect permission")
.peerAddress(ChannelAttrs.socketAddress(ctx.channel())));
}
}
});
}

private void extendedAuth(MQTT5ExtendedAuthData authData) {
this.isAuthing = true;
authProvider.extendedAuth(authData)
Expand Down Expand Up @@ -385,8 +427,8 @@ protected void handleMqttMessage(MqttMessage message) {
.peerAddress(ChannelAttrs.socketAddress(ctx.channel()))));
case DISCONNECT -> handleGoAway(GoAway.now(getLocal(EnhancedAuthAbortByClient.class)));
default -> handleGoAway(GoAway.now(getLocal(ProtocolError.class)
.statement("Unexpected control packet during enhanced auth: " +
message.fixedHeader().messageType())
.statement("Unexpected control packet during enhanced auth: "
+ message.fixedHeader().messageType())
.peerAddress(ChannelAttrs.socketAddress(ctx.channel()))));
}
} else {
Expand All @@ -411,8 +453,8 @@ protected void handleMqttMessage(MqttMessage message) {
}
case DISCONNECT -> handleGoAway(GoAway.now(getLocal(EnhancedAuthAbortByClient.class)));
default -> handleGoAway(GoAway.now(getLocal(ProtocolError.class)
.statement("Unexpected control packet during enhanced auth: " +
message.fixedHeader().messageType())
.statement("Unexpected control packet during enhanced auth: "
+ message.fixedHeader().messageType())
.peerAddress(ChannelAttrs.socketAddress(ctx.channel()))));
}
}
Expand Down Expand Up @@ -538,9 +580,9 @@ protected GoAway validate(MqttConnectMessage message, TenantSettings settings, C
.statement("Will QoS not supported")
.clientInfo(clientInfo));
}
if (settings.payloadFormatValidationEnabled &&
isUTF8Payload(connMsg.payload().willProperties()) &&
!UTF8Util.isValidUTF8Payload(connMsg.payload().willMessageInBytes())) {
if (settings.payloadFormatValidationEnabled
&& isUTF8Payload(connMsg.payload().willProperties())
&& !UTF8Util.isValidUTF8Payload(connMsg.payload().willMessageInBytes())) {
return new GoAway(MqttMessageBuilders
.connAck()
.properties(MQTT5MessageBuilders.connAckProperties()
Expand Down Expand Up @@ -693,8 +735,8 @@ protected MqttConnAckMessage onConnected(MqttConnectMessage connMsg,
connPropsBuilder.maximumPacketSize(settings.maxPacketSize);
connPropsBuilder.topicAliasMaximum(settings.maxTopicAlias);
connPropsBuilder.receiveMaximum(settings.receiveMaximum);
if (requestResponseInformation(connMsg.variableHeader().properties()) &&
clientInfo.containsMetadata(MQTT_RESPONSE_INFO)) {
if (requestResponseInformation(connMsg.variableHeader().properties())
&& clientInfo.containsMetadata(MQTT_RESPONSE_INFO)) {
// include response information only when client requested it
connPropsBuilder.responseInformation(clientInfo.getMetadataOrDefault(MQTT_RESPONSE_INFO, ""));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static com.google.protobuf.UnsafeByteOperations.unsafeWrap;

import com.baidu.bifromq.mqtt.handler.ChannelAttrs;
import com.baidu.bifromq.plugin.authprovider.type.ConnAction;
import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData;
import com.baidu.bifromq.plugin.authprovider.type.MQTT5AuthData;
import com.baidu.bifromq.plugin.authprovider.type.MQTT5ExtendedAuthData;
Expand All @@ -39,7 +40,6 @@
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Optional;
import lombok.SneakyThrows;

Expand All @@ -53,7 +53,7 @@ public static MQTT3AuthData buildMQTT3AuthData(Channel channel, MqttConnectMessa
}
X509Certificate cert = ChannelAttrs.clientCertificate(channel);
if (cert != null) {
authData.setCert(unsafeWrap(Base64.getEncoder().encode(cert.getEncoded())));
authData.setCert(unsafeWrap(cert.getEncoded()));
}
if (msg.variableHeader().hasUserName()) {
authData.setUsername(msg.payload().userName());
Expand Down Expand Up @@ -82,7 +82,7 @@ public static MQTT5AuthData buildMQTT5AuthData(Channel channel, MqttConnectMessa
MQTT5AuthData.Builder authData = MQTT5AuthData.newBuilder();
X509Certificate cert = ChannelAttrs.clientCertificate(channel);
if (cert != null) {
authData.setCert(unsafeWrap(Base64.getEncoder().encode(cert.getEncoded())));
authData.setCert(unsafeWrap(cert.getEncoded()));
}
if (msg.variableHeader().hasUserName()) {
authData.setUsername(msg.payload().userName());
Expand Down Expand Up @@ -132,6 +132,14 @@ public static MQTT5ExtendedAuthData buildMQTT5ExtendedAuthData(MqttMessage authM
.build();
}

public static MQTTAction buildConnAction(UserProperties userProps) {
return MQTTAction.newBuilder()
.setConn(ConnAction.newBuilder()
.setUserProps(userProps)
.build())
.build();
}

public static MQTTAction buildPubAction(String topic, QoS qos, boolean retained) {
return MQTTAction.newBuilder()
.setPub(PubAction.newBuilder()
Expand Down
Loading

0 comments on commit 4700820

Please sign in to comment.