Skip to content

Commit

Permalink
Enhance Netty pipeline with WebSocketOnlyHandler to enforce WebSocket…
Browse files Browse the repository at this point in the history
…-only traffic by rejecting non-WebSocket requests and MqttOverWSHandler to dynamically add MQTT handlers post-WebSocket handshake, ensuring protocol compliance and efficient resource management.
  • Loading branch information
popduke committed Jun 25, 2024
1 parent 376ecbf commit 7255969
Show file tree
Hide file tree
Showing 8 changed files with 409 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import com.baidu.bifromq.baseenv.EnvProvider;
import com.baidu.bifromq.baserpc.utils.NettyUtil;
import com.baidu.bifromq.mqtt.handler.ByteBufToWebSocketFrameEncoder;
import com.baidu.bifromq.mqtt.handler.ChannelAttrs;
import com.baidu.bifromq.mqtt.handler.ClientAddrHandler;
import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler;
Expand All @@ -24,7 +23,8 @@
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler;
import com.baidu.bifromq.mqtt.handler.condition.DirectMemPressureCondition;
import com.baidu.bifromq.mqtt.handler.condition.HeapMemPressureCondition;
import com.baidu.bifromq.mqtt.handler.ws.WebSocketFrameToByteBufDecoder;
import com.baidu.bifromq.mqtt.handler.ws.MqttOverWSHandler;
import com.baidu.bifromq.mqtt.handler.ws.WebSocketOnlyHandler;
import com.baidu.bifromq.mqtt.session.MQTTSessionContext;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.RateLimiter;
Expand Down Expand Up @@ -216,19 +216,11 @@ protected void initChannel(SocketChannel ch) {
pipeline.addLast("httpDecoder", new HttpRequestDecoder());
pipeline.addLast("remoteAddr", remoteAddrHandler);
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
pipeline.addLast("webSocketOnly", new WebSocketOnlyHandler(connBuilder.path()));
pipeline.addLast("webSocketHandler", new WebSocketServerProtocolHandler(connBuilder.path(),
MQTT_SUBPROTOCOL_CSV_LIST));
pipeline.addLast("ws2bytebufDecoder", new WebSocketFrameToByteBufDecoder());
pipeline.addLast("bytebuf2wsEncoder", new ByteBufToWebSocketFrameEncoder());
pipeline.addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
// insert PacketFilter here
pipeline.addLast(MqttDecoder.class.getName(), new MqttDecoder(builder.maxBytesInMessage));
pipeline.addLast(MQTTMessageDebounceHandler.NAME, new MQTTMessageDebounceHandler());
pipeline.addLast(ConditionalRejectHandler.NAME,
new ConditionalRejectHandler(
Sets.newHashSet(DirectMemPressureCondition.INSTANCE, HeapMemPressureCondition.INSTANCE),
sessionContext.eventCollector));
pipeline.addLast(MQTTPreludeHandler.NAME, new MQTTPreludeHandler(builder.connectTimeoutSeconds));
pipeline.addLast("webSocketHandshakeListener", new MqttOverWSHandler(
builder.maxBytesInMessage, builder.connectTimeoutSeconds, sessionContext.eventCollector));
}));
}
});
Expand All @@ -247,19 +239,11 @@ protected void initChannel(SocketChannel ch) {
pipeline.addLast("httpDecoder", new HttpRequestDecoder());
pipeline.addLast("remoteAddr", remoteAddrHandler);
pipeline.addLast("aggregator", new HttpObjectAggregator(65536));
pipeline.addLast("webSocketOnly", new WebSocketOnlyHandler(connBuilder.path()));
pipeline.addLast("webSocketHandler", new WebSocketServerProtocolHandler(connBuilder.path(),
MQTT_SUBPROTOCOL_CSV_LIST));
pipeline.addLast("ws2bytebufDecoder", new WebSocketFrameToByteBufDecoder());
pipeline.addLast("bytebuf2wsEncoder", new ByteBufToWebSocketFrameEncoder());
pipeline.addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
// insert PacketFilter between Encoder
pipeline.addLast(MqttDecoder.class.getName(), new MqttDecoder(builder.maxBytesInMessage));
pipeline.addLast(MQTTMessageDebounceHandler.NAME, new MQTTMessageDebounceHandler());
pipeline.addLast(ConditionalRejectHandler.NAME,
new ConditionalRejectHandler(
Sets.newHashSet(DirectMemPressureCondition.INSTANCE, HeapMemPressureCondition.INSTANCE),
sessionContext.eventCollector));
pipeline.addLast(MQTTPreludeHandler.NAME, new MQTTPreludeHandler(builder.connectTimeoutSeconds));
pipeline.addLast("webSocketHandshakeListener", new MqttOverWSHandler(
builder.maxBytesInMessage, builder.connectTimeoutSeconds, sessionContext.eventCollector));
}));
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.bifromq.mqtt.handler;
package com.baidu.bifromq.mqtt.handler.ws;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.bifromq.mqtt.handler.ws;

import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler;
import com.baidu.bifromq.mqtt.handler.MQTTMessageDebounceHandler;
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler;
import com.baidu.bifromq.mqtt.handler.condition.DirectMemPressureCondition;
import com.baidu.bifromq.mqtt.handler.condition.HeapMemPressureCondition;
import com.baidu.bifromq.plugin.eventcollector.IEventCollector;
import com.google.common.collect.Sets;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder;

/**
* A handler that adds MQTT handlers to the pipeline after the WebSocket handshake is complete.
*/
public class MqttOverWSHandler extends ChannelInboundHandlerAdapter {
private final int maxMQTTConnectPacketSize;
private final int connectTimeoutSeconds;
private final IEventCollector eventCollector;

public MqttOverWSHandler(int maxMQTTConnectPacketSize, int connectTimeoutSeconds, IEventCollector eventCollector) {
this.maxMQTTConnectPacketSize = maxMQTTConnectPacketSize;
this.connectTimeoutSeconds = connectTimeoutSeconds;
this.eventCollector = eventCollector;
}

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
ChannelPipeline pipeline = ctx.pipeline();
// Handshake complete, add MQTT handlers.
pipeline.addLast("ws2bytebufDecoder", new WebSocketFrameToByteBufDecoder());
pipeline.addLast("bytebuf2wsEncoder", new ByteBufToWebSocketFrameEncoder());
pipeline.addLast(MqttEncoder.class.getName(), MqttEncoder.INSTANCE);
// insert PacketFilter between Encoder
pipeline.addLast(MqttDecoder.class.getName(), new MqttDecoder(maxMQTTConnectPacketSize));
pipeline.addLast(MQTTMessageDebounceHandler.NAME, new MQTTMessageDebounceHandler());
pipeline.addLast(ConditionalRejectHandler.NAME,
new ConditionalRejectHandler(Sets.newHashSet(DirectMemPressureCondition.INSTANCE,
HeapMemPressureCondition.INSTANCE), eventCollector));
pipeline.addLast(MQTTPreludeHandler.NAME, new MQTTPreludeHandler(connectTimeoutSeconds));
// Remove the handshake listener after adding MQTT handlers.
ctx.pipeline().remove(this);
} else {
super.userEventTriggered(ctx, evt);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.bifromq.mqtt.handler.ws;

import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpResponseStatus;

/**
* A simple handler that rejects all requests that are not WebSocket upgrade requests.
*/
public class WebSocketOnlyHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
private final String websocketPath;

public WebSocketOnlyHandler(String websocketPath) {
super(false);
this.websocketPath = websocketPath;
}

@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest req) {
if (!req.uri().equals(websocketPath)
||
!req.headers().get(HttpHeaderNames.UPGRADE, "").equalsIgnoreCase("websocket")) {
FullHttpResponse response =
new DefaultFullHttpResponse(req.protocolVersion(), HttpResponseStatus.BAD_REQUEST);
ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
} else {
// Proceed with the pipeline setup for WebSocket.
ctx.pipeline().remove(this); // Remove the validator after it's used.
ctx.fireChannelRead(req); // Pass the request further if it's valid.
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.bifromq.mqtt.handler.ws;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

public class ByteBufToWebSocketFrameEncoderTest {
private EmbeddedChannel channel;

@BeforeMethod
public void setUp() {
// Initialize channel with the encoder before each test
channel = new EmbeddedChannel(new ByteBufToWebSocketFrameEncoder());
}

@Test
public void testEncode() {
// Creating a test ByteBuf with sample data
ByteBuf input = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5});

// Write the ByteBuf to the channel
input.retain();
assertTrue(channel.writeOutbound(input.duplicate()));

// Read the encoded output from the channel
BinaryWebSocketFrame frame = channel.readOutbound();

assertNotNull(frame);
assertEquals(input.readerIndex(), frame.content().readerIndex());
assertEquals(input.writerIndex(), frame.content().writerIndex());
assertTrue(ByteBufUtil.equals(input, frame.content()));

// Cleanup
frame.release();

assertFalse(channel.finish());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.bifromq.mqtt.handler.ws;

import static org.mockito.Mockito.mock;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;

import com.baidu.bifromq.mqtt.handler.ChannelAttrs;
import com.baidu.bifromq.mqtt.handler.ConditionalRejectHandler;
import com.baidu.bifromq.mqtt.handler.MQTTMessageDebounceHandler;
import com.baidu.bifromq.mqtt.handler.MQTTPreludeHandler;
import com.baidu.bifromq.mqtt.session.MQTTSessionContext;
import com.baidu.bifromq.plugin.eventcollector.IEventCollector;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder;
import java.net.InetSocketAddress;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

public class MqttOverWSHandlerTest {
private EmbeddedChannel channel;
private MQTTSessionContext sessionContext;
private IEventCollector eventCollector;

@BeforeMethod
public void setUp() {
eventCollector = mock(IEventCollector.class);
// Initialize channel with the MqttOverWSHandler
sessionContext = MQTTSessionContext.builder()
.eventCollector(eventCollector)
.build();
channel = new EmbeddedChannel(true, true, new ChannelInitializer<>() {
@Override
protected void initChannel(Channel ch) {
ch.attr(ChannelAttrs.MQTT_SESSION_CTX).set(sessionContext);
ch.attr(ChannelAttrs.PEER_ADDR).set(new InetSocketAddress("127.0.0.1", 8080));
ch.pipeline().addLast(new MqttOverWSHandler(65536, 30, eventCollector));
}
});
}

@Test
public void testMqttHandlerAdditionAfterHandshakeComplete() {
// Simulate a WebSocket handshake completion event
channel.pipeline()
.fireUserEventTriggered(new WebSocketServerProtocolHandler.HandshakeComplete(null, null, null));

// Check if all handlers are added
assertNotNull(channel.pipeline().get(WebSocketFrameToByteBufDecoder.class));
assertNotNull(channel.pipeline().get(ByteBufToWebSocketFrameEncoder.class));
assertNotNull(channel.pipeline().get(MqttEncoder.class));
assertNotNull(channel.pipeline().get(MqttDecoder.class));
assertNotNull(channel.pipeline().get(MQTTMessageDebounceHandler.class));
assertNotNull(channel.pipeline().get(ConditionalRejectHandler.class));
assertNotNull(channel.pipeline().get(MQTTPreludeHandler.class));

// Check that the MqttOverWSHandler itself has been removed from the pipeline
assertNull(channel.pipeline().get(MqttOverWSHandler.class));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) 2024. The BifroMQ Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and limitations under the License.
*/

package com.baidu.bifromq.mqtt.handler.ws;

import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

public class WebSocketFrameToByteBufDecoderTest {
private EmbeddedChannel channel;

@BeforeMethod
public void setUp() {
// Initialize channel with the decoder before each test
channel = new EmbeddedChannel(new WebSocketFrameToByteBufDecoder());
}

@Test
public void testDecode() {
// Creating a BinaryWebSocketFrame with sample data
ByteBuf originalContent = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4, 5});
BinaryWebSocketFrame frame = new BinaryWebSocketFrame(originalContent);

// Write the frame to the channel
assertTrue(channel.writeInbound(frame));

// Read the decoded output from the channel
ByteBuf decoded = channel.readInbound();

assertNotNull(decoded);
assertEquals(originalContent.readerIndex(), decoded.readerIndex());
assertEquals(originalContent.writerIndex(), decoded.writerIndex());
assertTrue(ByteBufUtil.equals(originalContent, decoded));

// Cleanup
decoded.release();

assertFalse(channel.finish());
}
}
Loading

0 comments on commit 7255969

Please sign in to comment.