From 470082086d971a5d8eead3fa4253fba4c8ede138 Mon Sep 17 00:00:00 2001 From: Yonny Hao Date: Wed, 4 Dec 2024 15:08:46 +0800 Subject: [PATCH] 1. enhance IAuthProvider to allow check permission for connection explicitly, so that auth provider impl could fine-grain control the auth workflow 2. Tenant Setting "BypassPermCheckError" now will turn CheckResult(Error) into CheckResult(Granted) when enabled. 3. change the 'cause' field type of AccessControlError event from Throwable to String --- .../mqtt/handler/MQTTConnectHandler.java | 10 +++ .../mqtt/handler/v3/MQTT3ConnectHandler.java | 46 +++++++++-- .../mqtt/handler/v5/MQTT5ConnectHandler.java | 68 ++++++++++++---- .../baidu/bifromq/mqtt/utils/AuthUtil.java | 14 +++- .../mqtt/handler/MQTTConnectHandlerTest.java | 30 +++++++ .../bifromq/mqtt/handler/v3/BaseMQTTTest.java | 6 ++ .../handler/v3/MQTT3ConnectHandlerTest.java | 13 ++++ .../mqtt/handler/v5/EnhancedAuthTest.java | 17 ++++ .../handler/v5/MQTT5ConnectHandlerTest.java | 24 ++++++ .../mqtt/integration/v3/MQTTConnectTest.java | 78 +++++++++++++++++-- .../integration/v3/MQTTDisconnectTest.java | 10 +++ .../mqtt/integration/v3/MQTTKickTest.java | 6 ++ .../authprovider/AuthProviderManager.java | 12 ++- .../authprovider/AuthProviderManagerTest.java | 23 +++++- .../src/main/proto/mqtt_actions.proto | 5 ++ .../accessctrl/AccessControlError.java | 2 +- 16 files changed, 332 insertions(+), 32 deletions(-) diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandler.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandler.java index e75dababe..70f7df96a 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandler.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandler.java @@ -127,6 +127,13 @@ public final void channelRead(ChannelHandlerContext ctx, Object msg) { } long reqId = System.nanoTime(); cancellableTasks.track(authenticate(connMsg)) + .thenComposeAsync(okOrGoAway -> { + if (okOrGoAway.goAway != null) { + return CompletableFuture.completedFuture(okOrGoAway); + } + // check conn permission + return checkConnectPermission(connMsg, okOrGoAway.clientInfo); + }, ctx.executor()) .thenComposeAsync(okOrGoAway -> { if (okOrGoAway.goAway != null) { handleGoAway(okOrGoAway.goAway); @@ -312,6 +319,9 @@ public final void channelRead(ChannelHandlerContext ctx, Object msg) { protected abstract CompletableFuture authenticate(MqttConnectMessage message); + protected abstract CompletableFuture checkConnectPermission(MqttConnectMessage message, + ClientInfo clientInfo); + protected abstract void handleMqttMessage(MqttMessage message); protected abstract GoAway onNoEnoughResources(MqttConnectMessage message, TenantResourceType resourceType, diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandler.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandler.java index 87d1fd71c..1c7e6f6b4 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandler.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandler.java @@ -17,6 +17,7 @@ import static com.baidu.bifromq.mqtt.handler.MQTTConnectHandler.AuthResult.goAway; import static com.baidu.bifromq.mqtt.handler.MQTTConnectHandler.AuthResult.ok; import static com.baidu.bifromq.mqtt.handler.condition.ORCondition.or; +import static com.baidu.bifromq.mqtt.utils.AuthUtil.buildConnAction; import static com.baidu.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal; import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CHANNEL_ID_KEY; import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CLIENT_ADDRESS_KEY; @@ -66,6 +67,7 @@ import com.baidu.bifromq.type.ClientInfo; import com.baidu.bifromq.type.Message; import com.baidu.bifromq.type.QoS; +import com.baidu.bifromq.type.UserProperties; import com.baidu.bifromq.util.TopicUtil; import com.baidu.bifromq.util.UTF8Util; import com.bifromq.plugin.resourcethrottler.TenantResourceType; @@ -124,12 +126,12 @@ protected GoAway sanityCheck(MqttConnectMessage message) { .build(), getLocal(IdentifierRejected.class).peerAddress(clientAddress)); } - if (message.variableHeader().hasUserName() && - !UTF8Util.isWellFormed(message.payload().userName(), SANITY_CHECK)) { + if (message.variableHeader().hasUserName() + && !UTF8Util.isWellFormed(message.payload().userName(), SANITY_CHECK)) { return new GoAway(getLocal(MalformedUserName.class).peerAddress(clientAddress)); } - if (message.variableHeader().isWillFlag() && - !UTF8Util.isWellFormed(message.payload().willTopic(), SANITY_CHECK)) { + if (message.variableHeader().isWillFlag() + && !UTF8Util.isWellFormed(message.payload().willTopic(), SANITY_CHECK)) { return new GoAway(getLocal(MalformedWillTopic.class).peerAddress(clientAddress)); } return null; @@ -148,8 +150,8 @@ protected CompletableFuture authenticate(MqttConnectMessage message) .setTenantId(ok.getTenantId()) .setType(MQTT_TYPE_VALUE) .putAllMetadata(ok.getAttrsMap()) // custom attrs - .putMetadata(MQTT_PROTOCOL_VER_KEY, message.variableHeader().version() == 3 ? - MQTT_PROTOCOL_VER_3_1_VALUE : MQTT_PROTOCOL_VER_3_1_1_VALUE) + .putMetadata(MQTT_PROTOCOL_VER_KEY, message.variableHeader().version() == 3 + ? MQTT_PROTOCOL_VER_3_1_VALUE : MQTT_PROTOCOL_VER_3_1_1_VALUE) .putMetadata(MQTT_USER_ID_KEY, ok.getUserId()) .putMetadata(MQTT_CLIENT_ID_KEY, message.payload().clientIdentifier()) .putMetadata(MQTT_CHANNEL_ID_KEY, ctx.channel().id().asLongText()) @@ -212,6 +214,38 @@ protected CompletableFuture authenticate(MqttConnectMessage message) }, ctx.executor()); } + @Override + protected CompletableFuture checkConnectPermission(MqttConnectMessage message, ClientInfo clientInfo) { + return authProvider.checkPermission(clientInfo, buildConnAction(UserProperties.getDefaultInstance())) + .thenApply(checkResult -> { + switch (checkResult.getTypeCase()) { + case GRANTED -> { + return AuthResult.ok(clientInfo); + } + case DENIED -> { + return goAway(MqttMessageBuilders + .connAck() + .returnCode(CONNECTION_REFUSED_NOT_AUTHORIZED) + .build(), + getLocal(NotAuthorizedClient.class) + .tenantId(clientInfo.getTenantId()) + .userId(clientInfo.getMetadataOrDefault(MQTT_USER_ID_KEY, "")) + .clientId(clientInfo.getMetadataOrDefault(MQTT_CLIENT_ID_KEY, "")) + .peerAddress(ChannelAttrs.socketAddress(ctx.channel()))); + } + default -> { + return goAway(MqttMessageBuilders + .connAck() + .returnCode(CONNECTION_REFUSED_SERVER_UNAVAILABLE) + .build(), + getLocal(AuthError.class) + .cause("Failed to check connect permission") + .peerAddress(ChannelAttrs.socketAddress(ctx.channel()))); + } + } + }); + } + @Override protected void handleMqttMessage(MqttMessage message) { // never happen in MQTT3 diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandler.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandler.java index 79b9bbacb..9257b019c 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandler.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandler.java @@ -23,8 +23,10 @@ import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.maximumPacketSize; import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.requestProblemInformation; import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.requestResponseInformation; +import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.toUserProperties; import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.toWillMessage; import static com.baidu.bifromq.mqtt.handler.v5.MQTT5MessageUtils.topicAliasMaximum; +import static com.baidu.bifromq.mqtt.utils.AuthUtil.buildConnAction; import static com.baidu.bifromq.mqtt.utils.MQTT5MessageSizer.MIN_CONTROL_PACKET_SIZE; import static com.baidu.bifromq.plugin.eventcollector.ThreadLocalEventPool.getLocal; import static com.baidu.bifromq.type.MQTTClientInfoConstants.MQTT_CHANNEL_ID_KEY; @@ -72,6 +74,7 @@ import com.baidu.bifromq.plugin.authprovider.type.Failed; import com.baidu.bifromq.plugin.authprovider.type.MQTT5AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT5ExtendedAuthData; +import com.baidu.bifromq.plugin.authprovider.type.MQTTAction; import com.baidu.bifromq.plugin.authprovider.type.Success; import com.baidu.bifromq.plugin.clientbalancer.IClientBalancer; import com.baidu.bifromq.plugin.clientbalancer.Redirection; @@ -165,8 +168,8 @@ protected GoAway sanityCheck(MqttConnectMessage connMsg) { .build(), getLocal(MalformedClientIdentifier.class).peerAddress(clientAddress)); } - if (connMsg.variableHeader().hasUserName() && - !UTF8Util.isWellFormed(connMsg.payload().userName(), SANITY_CHECK)) { + if (connMsg.variableHeader().hasUserName() + && !UTF8Util.isWellFormed(connMsg.payload().userName(), SANITY_CHECK)) { return new GoAway(MqttMessageBuilders .connAck() .properties(MQTT5MessageBuilders.connAckProperties() @@ -176,8 +179,8 @@ protected GoAway sanityCheck(MqttConnectMessage connMsg) { .build(), getLocal(MalformedUserName.class).peerAddress(clientAddress)); } - if (authMethod(connMsg.variableHeader().properties()).isEmpty() && - authData(connMsg.variableHeader().properties()).isPresent()) { + if (authMethod(connMsg.variableHeader().properties()).isEmpty() + && authData(connMsg.variableHeader().properties()).isPresent()) { return new GoAway(MqttMessageBuilders .connAck() .properties(MQTT5MessageBuilders.connAckProperties() @@ -288,6 +291,45 @@ protected CompletableFuture authenticate(MqttConnectMessage message) } } + @Override + protected CompletableFuture checkConnectPermission(MqttConnectMessage message, ClientInfo clientInfo) { + MQTTAction connAction = buildConnAction(toUserProperties(message.variableHeader().properties())); + return authProvider.checkPermission(clientInfo, connAction) + .thenApply(checkResult -> { + switch (checkResult.getTypeCase()) { + case GRANTED -> { + return AuthResult.ok(clientInfo); + } + case DENIED -> { + return goAway(MqttMessageBuilders + .connAck() + .properties(MQTT5MessageBuilders.connAckProperties() + .reasonString("Not authorized") + .build()) + .returnCode(CONNECTION_REFUSED_NOT_AUTHORIZED_5) + .build(), + getLocal(NotAuthorizedClient.class) + .tenantId(clientInfo.getTenantId()) + .userId(clientInfo.getMetadataOrDefault(MQTT_USER_ID_KEY, "")) + .clientId(connMsg.payload().clientIdentifier()) + .peerAddress(ChannelAttrs.socketAddress(ctx.channel()))); + } + default -> { + return goAway(MqttMessageBuilders + .connAck() + .properties(MQTT5MessageBuilders.connAckProperties() + .reasonString("Failed to check connect permission") + .build()) + .returnCode(CONNECTION_REFUSED_UNSPECIFIED_ERROR) + .build(), + getLocal(AuthError.class) + .cause("Failed to check connect permission") + .peerAddress(ChannelAttrs.socketAddress(ctx.channel()))); + } + } + }); + } + private void extendedAuth(MQTT5ExtendedAuthData authData) { this.isAuthing = true; authProvider.extendedAuth(authData) @@ -385,8 +427,8 @@ protected void handleMqttMessage(MqttMessage message) { .peerAddress(ChannelAttrs.socketAddress(ctx.channel())))); case DISCONNECT -> handleGoAway(GoAway.now(getLocal(EnhancedAuthAbortByClient.class))); default -> handleGoAway(GoAway.now(getLocal(ProtocolError.class) - .statement("Unexpected control packet during enhanced auth: " + - message.fixedHeader().messageType()) + .statement("Unexpected control packet during enhanced auth: " + + message.fixedHeader().messageType()) .peerAddress(ChannelAttrs.socketAddress(ctx.channel())))); } } else { @@ -411,8 +453,8 @@ protected void handleMqttMessage(MqttMessage message) { } case DISCONNECT -> handleGoAway(GoAway.now(getLocal(EnhancedAuthAbortByClient.class))); default -> handleGoAway(GoAway.now(getLocal(ProtocolError.class) - .statement("Unexpected control packet during enhanced auth: " + - message.fixedHeader().messageType()) + .statement("Unexpected control packet during enhanced auth: " + + message.fixedHeader().messageType()) .peerAddress(ChannelAttrs.socketAddress(ctx.channel())))); } } @@ -538,9 +580,9 @@ protected GoAway validate(MqttConnectMessage message, TenantSettings settings, C .statement("Will QoS not supported") .clientInfo(clientInfo)); } - if (settings.payloadFormatValidationEnabled && - isUTF8Payload(connMsg.payload().willProperties()) && - !UTF8Util.isValidUTF8Payload(connMsg.payload().willMessageInBytes())) { + if (settings.payloadFormatValidationEnabled + && isUTF8Payload(connMsg.payload().willProperties()) + && !UTF8Util.isValidUTF8Payload(connMsg.payload().willMessageInBytes())) { return new GoAway(MqttMessageBuilders .connAck() .properties(MQTT5MessageBuilders.connAckProperties() @@ -693,8 +735,8 @@ protected MqttConnAckMessage onConnected(MqttConnectMessage connMsg, connPropsBuilder.maximumPacketSize(settings.maxPacketSize); connPropsBuilder.topicAliasMaximum(settings.maxTopicAlias); connPropsBuilder.receiveMaximum(settings.receiveMaximum); - if (requestResponseInformation(connMsg.variableHeader().properties()) && - clientInfo.containsMetadata(MQTT_RESPONSE_INFO)) { + if (requestResponseInformation(connMsg.variableHeader().properties()) + && clientInfo.containsMetadata(MQTT_RESPONSE_INFO)) { // include response information only when client requested it connPropsBuilder.responseInformation(clientInfo.getMetadataOrDefault(MQTT_RESPONSE_INFO, "")); } diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/AuthUtil.java b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/AuthUtil.java index 5acd3026d..18ade470e 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/AuthUtil.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/utils/AuthUtil.java @@ -20,6 +20,7 @@ import static com.google.protobuf.UnsafeByteOperations.unsafeWrap; import com.baidu.bifromq.mqtt.handler.ChannelAttrs; +import com.baidu.bifromq.plugin.authprovider.type.ConnAction; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT5AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT5ExtendedAuthData; @@ -39,7 +40,6 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.security.cert.X509Certificate; -import java.util.Base64; import java.util.Optional; import lombok.SneakyThrows; @@ -53,7 +53,7 @@ public static MQTT3AuthData buildMQTT3AuthData(Channel channel, MqttConnectMessa } X509Certificate cert = ChannelAttrs.clientCertificate(channel); if (cert != null) { - authData.setCert(unsafeWrap(Base64.getEncoder().encode(cert.getEncoded()))); + authData.setCert(unsafeWrap(cert.getEncoded())); } if (msg.variableHeader().hasUserName()) { authData.setUsername(msg.payload().userName()); @@ -82,7 +82,7 @@ public static MQTT5AuthData buildMQTT5AuthData(Channel channel, MqttConnectMessa MQTT5AuthData.Builder authData = MQTT5AuthData.newBuilder(); X509Certificate cert = ChannelAttrs.clientCertificate(channel); if (cert != null) { - authData.setCert(unsafeWrap(Base64.getEncoder().encode(cert.getEncoded()))); + authData.setCert(unsafeWrap(cert.getEncoded())); } if (msg.variableHeader().hasUserName()) { authData.setUsername(msg.payload().userName()); @@ -132,6 +132,14 @@ public static MQTT5ExtendedAuthData buildMQTT5ExtendedAuthData(MqttMessage authM .build(); } + public static MQTTAction buildConnAction(UserProperties userProps) { + return MQTTAction.newBuilder() + .setConn(ConnAction.newBuilder() + .setUserProps(userProps) + .build()) + .build(); + } + public static MQTTAction buildPubAction(String topic, QoS qos, boolean retained) { return MQTTAction.newBuilder() .setPub(PubAction.newBuilder() diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandlerTest.java index 0a335d20c..ecd9b60ef 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/MQTTConnectHandlerTest.java @@ -26,6 +26,8 @@ import com.baidu.bifromq.mqtt.MockableTest; import com.baidu.bifromq.mqtt.handler.record.GoAway; import com.baidu.bifromq.mqtt.session.MQTTSessionContext; +import com.baidu.bifromq.plugin.authprovider.type.CheckResult; +import com.baidu.bifromq.plugin.authprovider.type.Denied; import com.baidu.bifromq.plugin.eventcollector.IEventCollector; import com.baidu.bifromq.plugin.settingprovider.ISettingProvider; import com.baidu.bifromq.plugin.settingprovider.Setting; @@ -115,6 +117,24 @@ public void authenticateFailed() { assertFalse(channel.isOpen()); } + @Test + public void checkConnPermissionFailed() { + MqttConnectMessage connMsg = MqttMessageBuilders.connect() + .clientId("client") + .protocolVersion(MqttVersion.MQTT_3_1_1) + .build(); + ClientInfo clientInfo = ClientInfo.newBuilder().build(); + when(connectHandler.sanityCheck(connMsg)).thenReturn(null); + when(connectHandler.authenticate(connMsg)).thenReturn( + CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); + when(connectHandler.checkConnectPermission(eq(connMsg), eq(clientInfo))).thenReturn( + CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.goAway(null))); + channel.writeInbound(connMsg); + channel.advanceTimeBy(6, TimeUnit.SECONDS); + channel.runScheduledPendingTasks(); + assertFalse(channel.isOpen()); + } + @Test public void noTotalConnectionResource() { MqttConnectMessage connMsg = MqttMessageBuilders.connect() @@ -127,6 +147,8 @@ public void noTotalConnectionResource() { when(connectHandler.sanityCheck(connMsg)).thenReturn(null); when(connectHandler.authenticate(connMsg)).thenReturn( CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); + when(connectHandler.checkConnectPermission(eq(connMsg), eq(clientInfo))).thenReturn( + CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); when(connectHandler.onNoEnoughResources(connMsg, TotalConnections, clientInfo)).thenReturn(new GoAway()); when(resourceThrottler.hasResource(tenantId, TotalConnections)).thenReturn(false); channel.writeInbound(connMsg); @@ -147,6 +169,8 @@ public void noTotalSessionMemoryBytesResource() { when(connectHandler.sanityCheck(connMsg)).thenReturn(null); when(connectHandler.authenticate(connMsg)).thenReturn( CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); + when(connectHandler.checkConnectPermission(eq(connMsg), eq(clientInfo))).thenReturn( + CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); when(connectHandler.onNoEnoughResources(connMsg, TotalSessionMemoryBytes, clientInfo)).thenReturn(new GoAway()); when(resourceThrottler.hasResource(tenantId, TotalSessionMemoryBytes)).thenReturn(false); channel.writeInbound(connMsg); @@ -167,6 +191,8 @@ public void noTotalConnectPerSecondResource() { when(connectHandler.sanityCheck(connMsg)).thenReturn(null); when(connectHandler.authenticate(connMsg)).thenReturn( CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); + when(connectHandler.checkConnectPermission(eq(connMsg), eq(clientInfo))).thenReturn( + CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); when(connectHandler.onNoEnoughResources(connMsg, TotalConnectPerSecond, clientInfo)).thenReturn(new GoAway()); when(resourceThrottler.hasResource(tenantId, TotalConnectPerSecond)).thenReturn(false); channel.writeInbound(connMsg); @@ -186,6 +212,8 @@ public void validationFailed() { when(connectHandler.sanityCheck(connMsg)).thenReturn(null); when(connectHandler.authenticate(connMsg)).thenReturn( CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); + when(connectHandler.checkConnectPermission(eq(connMsg), eq(clientInfo))).thenReturn( + CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); when(connectHandler.validate(eq(connMsg), any(), eq(clientInfo))).thenReturn(new GoAway()); channel.writeInbound(connMsg); channel.advanceTimeBy(6, TimeUnit.SECONDS); @@ -204,6 +232,8 @@ public void needRedirect() { when(connectHandler.sanityCheck(connMsg)).thenReturn(null); when(connectHandler.authenticate(connMsg)).thenReturn( CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); + when(connectHandler.checkConnectPermission(eq(connMsg), eq(clientInfo))).thenReturn( + CompletableFuture.completedFuture(MQTTConnectHandler.AuthResult.ok(clientInfo))); when(connectHandler.needRedirect(clientInfo)).thenReturn(new GoAway()); channel.writeInbound(connMsg); channel.advanceTimeBy(6, TimeUnit.SECONDS); diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/BaseMQTTTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/BaseMQTTTest.java index 14143b6cd..646a67287 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/BaseMQTTTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/BaseMQTTTest.java @@ -31,6 +31,7 @@ import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.lenient; @@ -75,6 +76,7 @@ import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult; +import com.baidu.bifromq.plugin.authprovider.type.MQTTAction; import com.baidu.bifromq.plugin.authprovider.type.Ok; import com.baidu.bifromq.plugin.authprovider.type.Reject; import com.baidu.bifromq.plugin.clientbalancer.IClientBalancer; @@ -247,6 +249,10 @@ protected void mockAuthPass(String... attrsKeyValues) { .putAllAttrs(attrsMap) .build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); } protected void mockAuthReject(Reject.Code code, String reason) { diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandlerTest.java index 5177fdd9c..2ad493ca8 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v3/MQTT3ConnectHandlerTest.java @@ -17,6 +17,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; @@ -27,8 +28,11 @@ import com.baidu.bifromq.mqtt.handler.ChannelAttrs; import com.baidu.bifromq.mqtt.session.MQTTSessionContext; import com.baidu.bifromq.plugin.authprovider.IAuthProvider; +import com.baidu.bifromq.plugin.authprovider.type.CheckResult; +import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult; +import com.baidu.bifromq.plugin.authprovider.type.MQTTAction; import com.baidu.bifromq.plugin.authprovider.type.Ok; import com.baidu.bifromq.plugin.clientbalancer.IClientBalancer; import com.baidu.bifromq.plugin.clientbalancer.Redirection; @@ -37,6 +41,7 @@ import com.baidu.bifromq.plugin.eventcollector.mqttbroker.clientdisconnect.Redirect; import com.baidu.bifromq.plugin.settingprovider.ISettingProvider; import com.baidu.bifromq.plugin.settingprovider.Setting; +import com.baidu.bifromq.type.ClientInfo; import com.bifromq.plugin.resourcethrottler.IResourceThrottler; import io.netty.channel.Channel; import io.netty.channel.ChannelInitializer; @@ -114,6 +119,10 @@ public void needMove() { .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() .setOk(Ok.newBuilder().setTenantId("tenantId").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); when(clientBalancer.needRedirect(any())).thenReturn( Optional.of(new Redirection(true, Optional.of("server1")))); channel.writeInbound(connMsg); @@ -137,6 +146,10 @@ public void needUseAnotherServer() { .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() .setOk(Ok.newBuilder().setTenantId("tenantId").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); when(clientBalancer.needRedirect(any())).thenReturn( Optional.of(new Redirection(false, Optional.empty()))); channel.writeInbound(connMsg); diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/EnhancedAuthTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/EnhancedAuthTest.java index fca137a4a..f0bd91f25 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/EnhancedAuthTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/EnhancedAuthTest.java @@ -15,6 +15,7 @@ import static com.baidu.bifromq.mqtt.handler.condition.ORCondition.or; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.when; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -33,16 +34,20 @@ import com.baidu.bifromq.mqtt.service.LocalSessionRegistry; import com.baidu.bifromq.mqtt.session.MQTTSessionContext; import com.baidu.bifromq.plugin.authprovider.IAuthProvider; +import com.baidu.bifromq.plugin.authprovider.type.CheckResult; import com.baidu.bifromq.plugin.authprovider.type.Continue; import com.baidu.bifromq.plugin.authprovider.type.Failed; +import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT5ExtendedAuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT5ExtendedAuthResult; +import com.baidu.bifromq.plugin.authprovider.type.MQTTAction; import com.baidu.bifromq.plugin.authprovider.type.Success; import com.baidu.bifromq.plugin.clientbalancer.IClientBalancer; import com.baidu.bifromq.plugin.eventcollector.IEventCollector; import com.baidu.bifromq.plugin.settingprovider.ISettingProvider; import com.baidu.bifromq.plugin.settingprovider.Setting; import com.baidu.bifromq.sessiondict.client.ISessionDictClient; +import com.baidu.bifromq.type.ClientInfo; import com.bifromq.plugin.resourcethrottler.IResourceThrottler; import com.google.common.util.concurrent.RateLimiter; import com.google.protobuf.ByteString; @@ -179,6 +184,10 @@ public void testAuthSuccess() { MQTT5ExtendedAuthResult.newBuilder() .setSuccess(Success.newBuilder().build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); channel.writeInbound(connect); MqttConnAckMessage connAckMessage = channel.readOutbound(); assertEquals(connAckMessage.variableHeader().connectReturnCode(), MqttConnectReturnCode.CONNECTION_ACCEPTED); @@ -219,6 +228,10 @@ public void testAuthSuccess2() { MQTT5ExtendedAuthResult.newBuilder() .setSuccess(Success.newBuilder().setTenantId("tenant").setUserId("user").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); channel.writeInbound(MqttMessageBuilders.auth() .reasonCode(MQTT5AuthReasonCode.Continue.value()) .properties(MQTT5MessageUtils.mqttProps() @@ -338,6 +351,10 @@ public void testReAuth() { MQTT5ExtendedAuthResult.newBuilder() .setSuccess(Success.newBuilder().build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); channel.writeInbound(connect); MqttConnAckMessage connAckMessage = channel.readOutbound(); assertEquals(connAckMessage.variableHeader().connectReturnCode(), MqttConnectReturnCode.CONNECTION_ACCEPTED); diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandlerTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandlerTest.java index 37155d1d7..f6fecd28c 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandlerTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/v5/MQTT5ConnectHandlerTest.java @@ -41,9 +41,12 @@ import com.baidu.bifromq.mqtt.handler.ChannelAttrs; import com.baidu.bifromq.mqtt.session.MQTTSessionContext; import com.baidu.bifromq.plugin.authprovider.IAuthProvider; +import com.baidu.bifromq.plugin.authprovider.type.CheckResult; +import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT5AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT5AuthResult; import com.baidu.bifromq.plugin.authprovider.type.MQTT5ExtendedAuthData; +import com.baidu.bifromq.plugin.authprovider.type.MQTTAction; import com.baidu.bifromq.plugin.authprovider.type.Success; import com.baidu.bifromq.plugin.clientbalancer.IClientBalancer; import com.baidu.bifromq.plugin.clientbalancer.Redirection; @@ -54,6 +57,7 @@ import com.baidu.bifromq.plugin.eventcollector.mqttbroker.clientdisconnect.ResourceThrottled; import com.baidu.bifromq.plugin.settingprovider.ISettingProvider; import com.baidu.bifromq.plugin.settingprovider.Setting; +import com.baidu.bifromq.type.ClientInfo; import com.bifromq.plugin.resourcethrottler.IResourceThrottler; import com.bifromq.plugin.resourcethrottler.TenantResourceType; import com.google.protobuf.ByteString; @@ -256,6 +260,10 @@ public void noTotalConnectionResource() { .thenReturn(CompletableFuture.completedFuture(MQTT5AuthResult.newBuilder() .setSuccess(Success.newBuilder().setTenantId("tenantId").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); channel.writeInbound(connMsg); channel.advanceTimeBy(6, TimeUnit.SECONDS); channel.runScheduledPendingTasks(); @@ -280,6 +288,10 @@ public void noSessionMemoryResource() { .thenReturn(CompletableFuture.completedFuture(MQTT5AuthResult.newBuilder() .setSuccess(Success.newBuilder().setTenantId("tenantId").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); channel.writeInbound(connMsg); channel.advanceTimeBy(6, TimeUnit.SECONDS); channel.runScheduledPendingTasks(); @@ -304,6 +316,10 @@ public void noTotalConnectPerSecondResource() { .thenReturn(CompletableFuture.completedFuture(MQTT5AuthResult.newBuilder() .setSuccess(Success.newBuilder().setTenantId("tenantId").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); channel.writeInbound(connMsg); channel.advanceTimeBy(6, TimeUnit.SECONDS); channel.runScheduledPendingTasks(); @@ -327,6 +343,10 @@ public void needMove() { .thenReturn(CompletableFuture.completedFuture(MQTT5AuthResult.newBuilder() .setSuccess(Success.newBuilder().setTenantId("tenantId").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); when(clientBalancer.needRedirect(any())).thenReturn( Optional.of(new Redirection(true, Optional.of("server1")))); channel.writeInbound(connMsg); @@ -352,6 +372,10 @@ public void needUseAnotherServer() { .thenReturn(CompletableFuture.completedFuture(MQTT5AuthResult.newBuilder() .setSuccess(Success.newBuilder().setTenantId("tenantId").build()) .build())); + when(authProvider.checkPermission(any(ClientInfo.class), argThat(MQTTAction::hasConn))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); when(clientBalancer.needRedirect(any())).thenReturn( Optional.of(new Redirection(false, Optional.empty()))); channel.writeInbound(connMsg); diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTConnectTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTConnectTest.java index fc3323509..1f0ce0821 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTConnectTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTConnectTest.java @@ -18,6 +18,7 @@ import static org.eclipse.paho.client.mqttv3.MqttException.REASON_CODE_NOT_AUTHORIZED; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -28,10 +29,15 @@ import com.baidu.bifromq.mqtt.TestUtils; import com.baidu.bifromq.mqtt.integration.MQTTTest; import com.baidu.bifromq.mqtt.integration.v3.client.MqttTestClient; +import com.baidu.bifromq.plugin.authprovider.type.CheckResult; +import com.baidu.bifromq.plugin.authprovider.type.Denied; +import com.baidu.bifromq.plugin.authprovider.type.Error; +import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult; import com.baidu.bifromq.plugin.authprovider.type.Ok; import com.baidu.bifromq.plugin.authprovider.type.Reject; +import com.baidu.bifromq.plugin.eventcollector.Event; import com.baidu.bifromq.plugin.eventcollector.mqttbroker.channelclosed.AuthError; import com.baidu.bifromq.plugin.eventcollector.mqttbroker.channelclosed.NotAuthorizedClient; import com.baidu.bifromq.plugin.eventcollector.mqttbroker.channelclosed.UnauthenticatedClient; @@ -42,6 +48,7 @@ import lombok.extern.slf4j.Slf4j; import org.eclipse.paho.client.mqttv3.MqttConnectOptions; import org.eclipse.paho.client.mqttv3.MqttException; +import org.mockito.ArgumentCaptor; import org.testng.annotations.Test; @Slf4j @@ -68,6 +75,10 @@ public void connectWithCleanSessionTrue() { .setTenantId(tenantId) .setUserId("testUser") .build()).build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); MqttConnectOptions connOpts = new MqttConnectOptions(); connOpts.setCleanSession(true); @@ -86,6 +97,10 @@ public void connectWithCleanSessionFalse() { .setTenantId(tenantId) .setUserId("testUser") .build()).build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); MqttConnectOptions connOpts = new MqttConnectOptions(); connOpts.setCleanSession(false); @@ -97,13 +112,17 @@ public void connectWithCleanSessionFalse() { } @Test(groups = "integration") - public void testBadWillTopic() { + public void badWillTopic() { when(authProvider.auth(any(MQTT3AuthData.class))) .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() .setOk(Ok.newBuilder() .setTenantId(tenantId) .setUserId("testUser") .build()).build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); MqttConnectOptions connOpts = new MqttConnectOptions(); connOpts.setWill("$share/badwilltopic", new byte[] {}, 0, false); @@ -120,8 +139,52 @@ public void testBadWillTopic() { .report(argThat(event -> event instanceof InvalidTopic)); } + @Test + public void connectPermissionDenied() { + when(authProvider.auth(any(MQTT3AuthData.class))) + .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() + .setOk(Ok.newBuilder() + .setTenantId(tenantId) + .setUserId("testUser") + .build()).build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setDenied(Denied.getDefaultInstance()) + .build())); + + MqttConnectOptions connOpts = new MqttConnectOptions(); + connOpts.setMqttVersion(4); + connOpts.setCleanSession(true); + connOpts.setUserName("abcdef/testClient"); + MqttException e = TestUtils.expectThrow(() -> mqttClient.connect(connOpts)); + assertEquals(e.getReasonCode(), REASON_CODE_NOT_AUTHORIZED); + verify(eventCollector, atLeast(1)).report(argThat(event -> event instanceof NotAuthorizedClient)); + } + + @Test + public void connectPermissionError() { + when(authProvider.auth(any(MQTT3AuthData.class))) + .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() + .setOk(Ok.newBuilder() + .setTenantId(tenantId) + .setUserId("testUser") + .build()).build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setError(Error.getDefaultInstance()) + .build())); + + MqttConnectOptions connOpts = new MqttConnectOptions(); + connOpts.setMqttVersion(4); + connOpts.setCleanSession(true); + connOpts.setUserName("abcdef/testClient"); + MqttException e = TestUtils.expectThrow(() -> mqttClient.connect(connOpts)); + assertEquals(e.getReasonCode(), REASON_CODE_BROKER_UNAVAILABLE); + verify(eventCollector, atLeast(1)).report(argThat(event -> event instanceof AuthError)); + } + @Test(groups = "integration") - public void testUnauthenticated() { + public void unauthenticated() { when(authProvider.auth(any(MQTT3AuthData.class))) .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() .setReject(Reject.newBuilder() @@ -140,7 +203,7 @@ public void testUnauthenticated() { } @Test(groups = "integration") - public void testBanned() { + public void banned() { when(authProvider.auth(any(MQTT3AuthData.class))) .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() .setReject(Reject.newBuilder() @@ -155,11 +218,11 @@ public void testBanned() { MqttException e = TestUtils.expectThrow(() -> mqttClient.connect(connOpts)); assertEquals(e.getReasonCode(), REASON_CODE_NOT_AUTHORIZED); - verify(eventCollector).report(argThat(event -> event instanceof NotAuthorizedClient)); + verify(eventCollector, atLeast(1)).report(argThat(event -> event instanceof NotAuthorizedClient)); } @Test(groups = "integration") - public void testAuthError() { + public void authError() { when(authProvider.auth(any(MQTT3AuthData.class))) .thenReturn(CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder() .setReject(Reject.newBuilder() @@ -175,6 +238,9 @@ public void testAuthError() { MqttException e = TestUtils.expectThrow(() -> mqttClient.connect(connOpts)); assertEquals(e.getReasonCode(), REASON_CODE_BROKER_UNAVAILABLE); - verify(eventCollector).report(argThat(event -> event instanceof AuthError)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Event.class); + verify(eventCollector, atLeast(1)).report(argThat(event -> event instanceof AuthError)); +// verify(eventCollector, atLeast(1)).report(argumentCaptor.capture()); +// argumentCaptor.getAllValues(); } } diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTDisconnectTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTDisconnectTest.java index 3fad7305b..9985ebbb7 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTDisconnectTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTDisconnectTest.java @@ -22,6 +22,8 @@ import com.baidu.bifromq.mqtt.integration.MQTTTest; import com.baidu.bifromq.mqtt.integration.v3.client.MqttTestClient; +import com.baidu.bifromq.plugin.authprovider.type.CheckResult; +import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult; import com.baidu.bifromq.plugin.authprovider.type.Ok; @@ -42,6 +44,10 @@ public void disconnectDirectly() { .setTenantId(tenantId) .setUserId("testUser") .build()).build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); MqttConnectOptions connOpts = new MqttConnectOptions(); connOpts.setCleanSession(true); @@ -67,6 +73,10 @@ public void disconnect() { .setTenantId(tenantId) .setUserId("testUser") .build()).build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); MqttConnectOptions connOpts = new MqttConnectOptions(); connOpts.setCleanSession(true); diff --git a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTKickTest.java b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTKickTest.java index 74855363c..3f2d13f0b 100644 --- a/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTKickTest.java +++ b/bifromq-mqtt/bifromq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/integration/v3/MQTTKickTest.java @@ -22,6 +22,8 @@ import com.baidu.bifromq.mqtt.integration.MQTTTest; import com.baidu.bifromq.mqtt.integration.v3.client.MqttTestClient; +import com.baidu.bifromq.plugin.authprovider.type.CheckResult; +import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthResult; import com.baidu.bifromq.plugin.authprovider.type.Ok; @@ -44,6 +46,10 @@ public void testKick() { .setUserId(deviceKey) .build()) .build())); + when(authProvider.checkPermission(any(), any())) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build())); MqttConnectOptions connOpts = new MqttConnectOptions(); connOpts.setMqttVersion(4); diff --git a/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/main/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManager.java b/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/main/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManager.java index 54ef7999c..f67307f4f 100644 --- a/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/main/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManager.java +++ b/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/main/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManager.java @@ -204,11 +204,19 @@ public CompletableFuture checkPermission(ClientInfo client, MQTTAct return delegate.checkPermission(client, action) .thenApply(v -> { start.stop(metricMgr.checkCallTimer); + if (v.getTypeCase() == CheckResult.TypeCase.ERROR + && (boolean) settingProvider.provide(ByPassPermCheckError, client.getTenantId())) { + eventCollector.report( + getLocal(AccessControlError.class).clientInfo(client).cause(v.getError().getReason())); + return CheckResult.newBuilder() + .setGranted(Granted.getDefaultInstance()) + .build(); + } return v; }) .exceptionally(e -> { metricMgr.checkCallErrorCounter.increment(); - eventCollector.report(getLocal(AccessControlError.class).clientInfo(client).cause(e)); + eventCollector.report(getLocal(AccessControlError.class).clientInfo(client).cause(e.getMessage())); boolean byPass = settingProvider.provide(ByPassPermCheckError, client.getTenantId()); if (byPass) { return CheckResult.newBuilder() @@ -225,7 +233,7 @@ public CompletableFuture checkPermission(ClientInfo client, MQTTAct }); } catch (Throwable e) { metricMgr.checkCallErrorCounter.increment(); - eventCollector.report(getLocal(AccessControlError.class).clientInfo(client).cause(e)); + eventCollector.report(getLocal(AccessControlError.class).clientInfo(client).cause(e.getMessage())); boolean byPass = settingProvider.provide(ByPassPermCheckError, client.getTenantId()); if (byPass) { return CompletableFuture.completedFuture(CheckResult.newBuilder() diff --git a/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/test/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManagerTest.java b/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/test/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManagerTest.java index a74b96c67..97cc1eb96 100644 --- a/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/test/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManagerTest.java +++ b/bifromq-plugin/bifromq-plugin-auth-provider-helper/src/test/java/com/baidu/bifromq/plugin/authprovider/AuthProviderManagerTest.java @@ -25,6 +25,7 @@ import static org.testng.Assert.assertTrue; import com.baidu.bifromq.plugin.authprovider.type.CheckResult; +import com.baidu.bifromq.plugin.authprovider.type.Error; import com.baidu.bifromq.plugin.authprovider.type.Failed; import com.baidu.bifromq.plugin.authprovider.type.Granted; import com.baidu.bifromq.plugin.authprovider.type.MQTT3AuthData; @@ -352,7 +353,7 @@ public void checkPermissionReturnErrorAndNoPass() { ArgumentCaptor eventArgumentCaptor = ArgumentCaptor.forClass(AccessControlError.class); verify(eventCollector).report(eventArgumentCaptor.capture()); assertEquals(eventArgumentCaptor.getValue().type(), EventType.ACCESS_CONTROL_ERROR); - assertTrue(eventArgumentCaptor.getValue().cause().getMessage().contains("Intend Error")); + assertTrue(eventArgumentCaptor.getValue().cause().contains("Intend Error")); assertEquals(meterRegistry.find(CALL_TIMER).tag(TAG_METHOD, "AuthProvider/check").timer().count(), 0); @@ -377,4 +378,24 @@ public void checkPermissionThrowsExceptionAndPass() { assertEquals(meterRegistry.find(CALL_FAIL_COUNTER).tag(TAG_METHOD, "AuthProvider/check").counter().count(), 1); } + + @Test + public void byPassCheckResultError() { + when(settingProvider.provide(ByPassPermCheckError, clientInfo.getTenantId())).thenReturn(true); + manager = + new AuthProviderManager(mockProvider.getClass().getName(), pluginManager, settingProvider, eventCollector); + when(mockProvider.checkPermission(any(ClientInfo.class), any(MQTTAction.class))) + .thenReturn(CompletableFuture.completedFuture(CheckResult.newBuilder() + .setError(Error.newBuilder().build()) + .build())); + assertTrue(manager.checkPermission(clientInfo, mockActionInfo).join().hasGranted()); + ArgumentCaptor eventArgumentCaptor = ArgumentCaptor.forClass(AccessControlError.class); + verify(eventCollector).report(eventArgumentCaptor.capture()); + assertEquals(eventArgumentCaptor.getValue().type(), EventType.ACCESS_CONTROL_ERROR); + + assertEquals(meterRegistry.find(CALL_TIMER).tag(TAG_METHOD, "AuthProvider/check").timer().count(), + 1); + assertEquals(meterRegistry.find(CALL_FAIL_COUNTER).tag(TAG_METHOD, "AuthProvider/check").counter().count(), + 0); + } } diff --git a/bifromq-plugin/bifromq-plugin-auth-provider/src/main/proto/mqtt_actions.proto b/bifromq-plugin/bifromq-plugin-auth-provider/src/main/proto/mqtt_actions.proto index 85440ee31..3ec213bc9 100644 --- a/bifromq-plugin/bifromq-plugin-auth-provider/src/main/proto/mqtt_actions.proto +++ b/bifromq-plugin/bifromq-plugin-auth-provider/src/main/proto/mqtt_actions.proto @@ -27,11 +27,16 @@ message UnsubAction { commontype.UserProperties userProps = 2; } +message ConnAction { + commontype.UserProperties userProps = 1; +} + message MQTTAction { oneof Type{ PubAction pub = 1; SubAction sub = 2; UnsubAction unsub = 3; + ConnAction conn = 4; } } diff --git a/bifromq-plugin/bifromq-plugin-event-collector/src/main/java/com/baidu/bifromq/plugin/eventcollector/mqttbroker/accessctrl/AccessControlError.java b/bifromq-plugin/bifromq-plugin-event-collector/src/main/java/com/baidu/bifromq/plugin/eventcollector/mqttbroker/accessctrl/AccessControlError.java index c70f80463..5cd42fae4 100644 --- a/bifromq-plugin/bifromq-plugin-event-collector/src/main/java/com/baidu/bifromq/plugin/eventcollector/mqttbroker/accessctrl/AccessControlError.java +++ b/bifromq-plugin/bifromq-plugin-event-collector/src/main/java/com/baidu/bifromq/plugin/eventcollector/mqttbroker/accessctrl/AccessControlError.java @@ -25,7 +25,7 @@ @Accessors(fluent = true, chain = true) @ToString(callSuper = true) public final class AccessControlError extends ClientEvent { - private Throwable cause; + private String cause; @Override public EventType type() {