-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Enhance Netty pipeline with WebSocketOnlyHandler to enforce WebSocket…
…-only traffic by rejecting non-WebSocket requests and MqttOverWSHandler to dynamically add MQTT handlers post-WebSocket handshake, ensuring protocol compliance and efficient resource management.
- Loading branch information
Showing
8 changed files
with
409 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
...ifromq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/ws/MqttOverWSHandler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and limitations under the License. | ||
*/ | ||
|
||
package com.baidu.bifromq.mqtt.handler.ws; | ||
|
||
import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler; | ||
import com.baidu.bifromq.mqtt.handler.MQTTMessageDebounceHandler; | ||
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler; | ||
import com.baidu.bifromq.mqtt.handler.condition.DirectMemPressureCondition; | ||
import com.baidu.bifromq.mqtt.handler.condition.HeapMemPressureCondition; | ||
import com.baidu.bifromq.plugin.eventcollector.IEventCollector; | ||
import com.google.common.collect.Sets; | ||
import io.netty.channel.ChannelHandlerContext; | ||
import io.netty.channel.ChannelInboundHandlerAdapter; | ||
import io.netty.channel.ChannelPipeline; | ||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; | ||
import io.netty.handler.codec.mqtt.MqttDecoder; | ||
import io.netty.handler.codec.mqtt.MqttEncoder; | ||
|
||
/** | ||
* A handler that adds MQTT handlers to the pipeline after the WebSocket handshake is complete. | ||
*/ | ||
public class MqttOverWSHandler extends ChannelInboundHandlerAdapter { | ||
private final int maxMQTTConnectPacketSize; | ||
private final int connectTimeoutSeconds; | ||
private final IEventCollector eventCollector; | ||
|
||
public MqttOverWSHandler(int maxMQTTConnectPacketSize, int connectTimeoutSeconds, IEventCollector eventCollector) { | ||
this.maxMQTTConnectPacketSize = maxMQTTConnectPacketSize; | ||
this.connectTimeoutSeconds = connectTimeoutSeconds; | ||
this.eventCollector = eventCollector; | ||
} | ||
|
||
@Override | ||
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { | ||
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { | ||
ChannelPipeline pipeline = ctx.pipeline(); | ||
// Handshake complete, add MQTT handlers. | ||
pipeline.addLast("ws2bytebufDecoder", new WebSocketFrameToByteBufDecoder()); | ||
pipeline.addLast("bytebuf2wsEncoder", new ByteBufToWebSocketFrameEncoder()); | ||
pipeline.addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE); | ||
// insert PacketFilter between Encoder | ||
pipeline.addLast(MqttDecoder.class.getName(), new MqttDecoder(maxMQTTConnectPacketSize)); | ||
pipeline.addLast(MQTTMessageDebounceHandler.NAME, new MQTTMessageDebounceHandler()); | ||
pipeline.addLast(ConditionalRejectHandler.NAME, | ||
new ConditionalRejectHandler(Sets.newHashSet(DirectMemPressureCondition.INSTANCE, | ||
HeapMemPressureCondition.INSTANCE), eventCollector)); | ||
pipeline.addLast(MQTTPreludeHandler.NAME, new MQTTPreludeHandler(connectTimeoutSeconds)); | ||
// Remove the handshake listener after adding MQTT handlers. | ||
ctx.pipeline().remove(this); | ||
} else { | ||
super.userEventTriggered(ctx, evt); | ||
} | ||
} | ||
} |
50 changes: 50 additions & 0 deletions
50
...omq-mqtt-server/src/main/java/com/baidu/bifromq/mqtt/handler/ws/WebSocketOnlyHandler.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and limitations under the License. | ||
*/ | ||
|
||
package com.baidu.bifromq.mqtt.handler.ws; | ||
|
||
import io.netty.channel.ChannelFutureListener; | ||
import io.netty.channel.ChannelHandlerContext; | ||
import io.netty.channel.SimpleChannelInboundHandler; | ||
import io.netty.handler.codec.http.DefaultFullHttpResponse; | ||
import io.netty.handler.codec.http.FullHttpRequest; | ||
import io.netty.handler.codec.http.FullHttpResponse; | ||
import io.netty.handler.codec.http.HttpHeaderNames; | ||
import io.netty.handler.codec.http.HttpResponseStatus; | ||
|
||
/** | ||
* A simple handler that rejects all requests that are not WebSocket upgrade requests. | ||
*/ | ||
public class WebSocketOnlyHandler extends SimpleChannelInboundHandler<FullHttpRequest> { | ||
private final String websocketPath; | ||
|
||
public WebSocketOnlyHandler(String websocketPath) { | ||
super(false); | ||
this.websocketPath = websocketPath; | ||
} | ||
|
||
@Override | ||
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) { | ||
if (!req.uri().equals(websocketPath) | ||
|| | ||
!req.headers().get(HttpHeaderNames.UPGRADE, "").equalsIgnoreCase("websocket")) { | ||
FullHttpResponse response = | ||
new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.BAD_REQUEST); | ||
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); | ||
} else { | ||
// Proceed with the pipeline setup for WebSocket. | ||
ctx.pipeline().remove(this); // Remove the validator after it's used. | ||
ctx.fireChannelRead(req); // Pass the request further if it's valid. | ||
} | ||
} | ||
} |
60 changes: 60 additions & 0 deletions
60
...r/src/test/java/com/baidu/bifromq/mqtt/handler/ws/ByteBufToWebSocketFrameEncoderTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and limitations under the License. | ||
*/ | ||
|
||
package com.baidu.bifromq.mqtt.handler.ws; | ||
|
||
import static org.testng.Assert.assertEquals; | ||
import static org.testng.Assert.assertFalse; | ||
import static org.testng.Assert.assertNotNull; | ||
import static org.testng.Assert.assertTrue; | ||
|
||
import io.netty.buffer.ByteBuf; | ||
import io.netty.buffer.ByteBufUtil; | ||
import io.netty.buffer.Unpooled; | ||
import io.netty.channel.embedded.EmbeddedChannel; | ||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; | ||
import org.testng.annotations.BeforeMethod; | ||
import org.testng.annotations.Test; | ||
|
||
public class ByteBufToWebSocketFrameEncoderTest { | ||
private EmbeddedChannel channel; | ||
|
||
@BeforeMethod | ||
public void setUp() { | ||
// Initialize channel with the encoder before each test | ||
channel = new EmbeddedChannel(new ByteBufToWebSocketFrameEncoder()); | ||
} | ||
|
||
@Test | ||
public void testEncode() { | ||
// Creating a test ByteBuf with sample data | ||
ByteBuf input = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5}); | ||
|
||
// Write the ByteBuf to the channel | ||
input.retain(); | ||
assertTrue(channel.writeOutbound(input.duplicate())); | ||
|
||
// Read the encoded output from the channel | ||
BinaryWebSocketFrame frame = channel.readOutbound(); | ||
|
||
assertNotNull(frame); | ||
assertEquals(input.readerIndex(), frame.content().readerIndex()); | ||
assertEquals(input.writerIndex(), frame.content().writerIndex()); | ||
assertTrue(ByteBufUtil.equals(input, frame.content())); | ||
|
||
// Cleanup | ||
frame.release(); | ||
|
||
assertFalse(channel.finish()); | ||
} | ||
} |
76 changes: 76 additions & 0 deletions
76
...mq-mqtt-server/src/test/java/com/baidu/bifromq/mqtt/handler/ws/MqttOverWSHandlerTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and limitations under the License. | ||
*/ | ||
|
||
package com.baidu.bifromq.mqtt.handler.ws; | ||
|
||
import static org.mockito.Mockito.mock; | ||
import static org.testng.Assert.assertNotNull; | ||
import static org.testng.Assert.assertNull; | ||
|
||
import com.baidu.bifromq.mqtt.handler.ChannelAttrs; | ||
import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler; | ||
import com.baidu.bifromq.mqtt.handler.MQTTMessageDebounceHandler; | ||
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler; | ||
import com.baidu.bifromq.mqtt.session.MQTTSessionContext; | ||
import com.baidu.bifromq.plugin.eventcollector.IEventCollector; | ||
import io.netty.channel.Channel; | ||
import io.netty.channel.ChannelInitializer; | ||
import io.netty.channel.embedded.EmbeddedChannel; | ||
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler; | ||
import io.netty.handler.codec.mqtt.MqttDecoder; | ||
import io.netty.handler.codec.mqtt.MqttEncoder; | ||
import java.net.InetSocketAddress; | ||
import org.testng.annotations.BeforeMethod; | ||
import org.testng.annotations.Test; | ||
|
||
public class MqttOverWSHandlerTest { | ||
private EmbeddedChannel channel; | ||
private MQTTSessionContext sessionContext; | ||
private IEventCollector eventCollector; | ||
|
||
@BeforeMethod | ||
public void setUp() { | ||
eventCollector = mock(IEventCollector.class); | ||
// Initialize channel with the MqttOverWSHandler | ||
sessionContext = MQTTSessionContext.builder() | ||
.eventCollector(eventCollector) | ||
.build(); | ||
channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() { | ||
@Override | ||
protected void initChannel(Channel ch) { | ||
ch.attr(ChannelAttrs.MQTT_SESSION_CTX).set(sessionContext); | ||
ch.attr(ChannelAttrs.PEER_ADDR).set(new InetSocketAddress("127.0.0.1", 8080)); | ||
ch.pipeline().addLast(new MqttOverWSHandler(65536, 30, eventCollector)); | ||
} | ||
}); | ||
} | ||
|
||
@Test | ||
public void testMqttHandlerAdditionAfterHandshakeComplete() { | ||
// Simulate a WebSocket handshake completion event | ||
channel.pipeline() | ||
.fireUserEventTriggered(new WebSocketServerProtocolHandler.HandshakeComplete(null, null, null)); | ||
|
||
// Check if all handlers are added | ||
assertNotNull(channel.pipeline().get(WebSocketFrameToByteBufDecoder.class)); | ||
assertNotNull(channel.pipeline().get(ByteBufToWebSocketFrameEncoder.class)); | ||
assertNotNull(channel.pipeline().get(MqttEncoder.class)); | ||
assertNotNull(channel.pipeline().get(MqttDecoder.class)); | ||
assertNotNull(channel.pipeline().get(MQTTMessageDebounceHandler.class)); | ||
assertNotNull(channel.pipeline().get(ConditionalRejectHandler.class)); | ||
assertNotNull(channel.pipeline().get(MQTTPreludeHandler.class)); | ||
|
||
// Check that the MqttOverWSHandler itself has been removed from the pipeline | ||
assertNull(channel.pipeline().get(MqttOverWSHandler.class)); | ||
} | ||
} |
60 changes: 60 additions & 0 deletions
60
...r/src/test/java/com/baidu/bifromq/mqtt/handler/ws/WebSocketFrameToByteBufDecoderTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and limitations under the License. | ||
*/ | ||
|
||
package com.baidu.bifromq.mqtt.handler.ws; | ||
|
||
import static org.testng.Assert.assertEquals; | ||
import static org.testng.Assert.assertFalse; | ||
import static org.testng.Assert.assertNotNull; | ||
import static org.testng.Assert.assertTrue; | ||
|
||
import io.netty.buffer.ByteBuf; | ||
import io.netty.buffer.ByteBufUtil; | ||
import io.netty.buffer.Unpooled; | ||
import io.netty.channel.embedded.EmbeddedChannel; | ||
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; | ||
import org.testng.annotations.BeforeMethod; | ||
import org.testng.annotations.Test; | ||
|
||
public class WebSocketFrameToByteBufDecoderTest { | ||
private EmbeddedChannel channel; | ||
|
||
@BeforeMethod | ||
public void setUp() { | ||
// Initialize channel with the decoder before each test | ||
channel = new EmbeddedChannel(new WebSocketFrameToByteBufDecoder()); | ||
} | ||
|
||
@Test | ||
public void testDecode() { | ||
// Creating a BinaryWebSocketFrame with sample data | ||
ByteBuf originalContent = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5}); | ||
BinaryWebSocketFrame frame = new BinaryWebSocketFrame(originalContent); | ||
|
||
// Write the frame to the channel | ||
assertTrue(channel.writeInbound(frame)); | ||
|
||
// Read the decoded output from the channel | ||
ByteBuf decoded = channel.readInbound(); | ||
|
||
assertNotNull(decoded); | ||
assertEquals(originalContent.readerIndex(), decoded.readerIndex()); | ||
assertEquals(originalContent.writerIndex(), decoded.writerIndex()); | ||
assertTrue(ByteBufUtil.equals(originalContent, decoded)); | ||
|
||
// Cleanup | ||
decoded.release(); | ||
|
||
assertFalse(channel.finish()); | ||
} | ||
} |
Oops, something went wrong.