diff --git a/packages/sdk-socket-server-next/src/config.ts b/packages/sdk-socket-server-next/src/config.ts index 7acb5dad1..5ad8fdb40 100644 --- a/packages/sdk-socket-server-next/src/config.ts +++ b/packages/sdk-socket-server-next/src/config.ts @@ -17,6 +17,7 @@ export const withAdminUI: boolean = process.env.ADMIN_UI === 'true'; const HOUR_IN_SECONDS = 60 * 60; const THIRTY_DAYS_IN_SECONDS = 30 * 24 * 60 * 60; // expiration time of entries in Redis export const MAX_CLIENTS_PER_ROOM = 2; +export const MAX_MESSAGE_LENGTH = 1_000_000; // 1MB limit export const config = { msgExpiry: HOUR_IN_SECONDS, diff --git a/packages/sdk-socket-server-next/src/protocol/handleMessage.ts b/packages/sdk-socket-server-next/src/protocol/handleMessage.ts index da96c1989..db4b353e2 100644 --- a/packages/sdk-socket-server-next/src/protocol/handleMessage.ts +++ b/packages/sdk-socket-server-next/src/protocol/handleMessage.ts @@ -1,7 +1,7 @@ import { Server, Socket } from 'socket.io'; import { v4 as uuidv4 } from 'uuid'; import { pubClient } from '../analytics-api'; -import { config, isDevelopment } from '../config'; +import { config, isDevelopment, MAX_MESSAGE_LENGTH } from '../config'; import { getLogger } from '../logger'; import { increaseRateLimits, @@ -58,6 +58,21 @@ export const handleMessage = async ({ let ready = false; // Determines if the keys have been exchanged and both side support the full protocol try { + // Add message size validation + const messageSize = typeof message === 'string' + ? message.length + : JSON.stringify(message).length; + + if (messageSize > MAX_MESSAGE_LENGTH) { + logger.warn(`[handleMessage] Message size ${messageSize} exceeds limit of ${MAX_MESSAGE_LENGTH} bytes`, { + channelId, + socketId, + clientIp, + }); + callback?.(`Message size ${messageSize} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`); + return; + } + if (clientType) { // new protocol, get channelConfig const channelConfigKey = `channel_config:${channelId}`; diff --git a/packages/sdk/src/config.ts b/packages/sdk/src/config.ts index e55357fa6..d0a62dc37 100644 --- a/packages/sdk/src/config.ts +++ b/packages/sdk/src/config.ts @@ -73,3 +73,5 @@ export const EXTENSION_EVENTS = { CONNECT: 'connect', CONNECTED: 'connected', }; + +export const MAX_MESSAGE_LENGTH = 1_000_000; // 1MB limit diff --git a/packages/sdk/src/services/MobilePortStream/write.test.ts b/packages/sdk/src/services/MobilePortStream/write.test.ts index ee09552f3..97676c8a5 100644 --- a/packages/sdk/src/services/MobilePortStream/write.test.ts +++ b/packages/sdk/src/services/MobilePortStream/write.test.ts @@ -1,4 +1,5 @@ import { Buffer } from 'buffer'; +import { MAX_MESSAGE_LENGTH } from '../../config'; import { write } from './write'; describe('write function', () => { @@ -77,4 +78,50 @@ describe('write function', () => { new Error('MobilePortStream - disconnected'), ); }); + + describe('Message Size Validation', () => { + beforeEach(() => { + jest.clearAllMocks(); + global.window = { + location: { href: 'http://example.com' }, + ReactNativeWebView: { postMessage: mockPostMessage }, + } as any; + }); + + it('should reject messages exceeding MAX_MESSAGE_LENGTH', () => { + const largeData = { + data: { + jsonrpc: '2.0', + method: 'test_method', + params: ['x'.repeat(MAX_MESSAGE_LENGTH)], + }, + }; + + write(largeData, 'utf-8', cb); + + expect(cb).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringMatching( + /Message size \d+ exceeds maximum allowed size of \d+ bytes/u, + ), + }), + ); + expect(mockPostMessage).not.toHaveBeenCalled(); + }); + + it('should accept messages within MAX_MESSAGE_LENGTH', () => { + const validData = { + data: { + jsonrpc: '2.0', + method: 'test_method', + params: ['x'.repeat(100)], + }, + }; + + write(validData, 'utf-8', cb); + + expect(cb).toHaveBeenCalledWith(); + expect(mockPostMessage).toHaveBeenCalled(); + }); + }); }); diff --git a/packages/sdk/src/services/MobilePortStream/write.ts b/packages/sdk/src/services/MobilePortStream/write.ts index b4da3c9a0..b11d9ff36 100644 --- a/packages/sdk/src/services/MobilePortStream/write.ts +++ b/packages/sdk/src/services/MobilePortStream/write.ts @@ -1,4 +1,5 @@ import { Buffer } from 'buffer'; +import { MAX_MESSAGE_LENGTH } from '../../config'; /** * Handles communication between the in-app browser and MetaMask mobile application. @@ -15,6 +16,7 @@ export function write( cb: (error?: Error | null) => void, ) { try { + let stringifiedData: string; if (Buffer.isBuffer(chunk)) { const data: { type: 'Buffer'; @@ -23,18 +25,30 @@ export function write( } = chunk.toJSON(); data._isBuffer = true; - window.ReactNativeWebView?.postMessage( - JSON.stringify({ ...data, origin: window.location.href }), - ); + stringifiedData = JSON.stringify({ + ...data, + origin: window.location.href, + }); } else { if (chunk.data) { chunk.data.toNative = true; } - window.ReactNativeWebView?.postMessage( - JSON.stringify({ ...chunk, origin: window.location.href }), + stringifiedData = JSON.stringify({ + ...chunk, + origin: window.location.href, + }); + } + + if (stringifiedData.length > MAX_MESSAGE_LENGTH) { + return cb( + new Error( + `Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`, + ), ); } + + window.ReactNativeWebView?.postMessage(stringifiedData); } catch (err) { return cb(new Error('MobilePortStream - disconnected')); } diff --git a/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.test.ts b/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.test.ts index 5170b97e6..c5a096726 100644 --- a/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.test.ts +++ b/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.test.ts @@ -1,9 +1,9 @@ -import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream'; // Adjust the import based on your project structure -import { METHODS_TO_REDIRECT } from '../../config'; +import { MAX_MESSAGE_LENGTH, METHODS_TO_REDIRECT } from '../../config'; import * as loggerModule from '../../utils/logger'; // Adjust the import based on your project structure -import { write } from './write'; // Adjust the import based on your project structure +import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure import { extractMethod } from './extractMethod'; +import { write } from './write'; // Adjust the import based on your project structure jest.mock('./extractMethod'); jest.mock('../Ethereum'); @@ -162,11 +162,22 @@ describe('write function', () => { mockIsMobileWeb.mockReturnValue(false); mockIsSecure.mockReturnValue(true); mockGetChannelId.mockReturnValue('some_channel_id'); + mockIsMetaMaskInstalled.mockReturnValue(true); + mockGetKeyInfo.mockReturnValue({ ecies: { public: 'test_public_key' } }); + mockHasDeeplinkProtocol.mockReturnValue(false); }); it('should redirect if method exists in METHODS_TO_REDIRECT', async () => { mockExtractMethod.mockReturnValue({ method: Object.keys(METHODS_TO_REDIRECT)[0], + data: { + data: { + jsonrpc: '2.0', + method: Object.keys(METHODS_TO_REDIRECT)[0], + params: [], + }, + }, + triggeredInstaller: false, }); await write( @@ -239,4 +250,71 @@ describe('write function', () => { expect(spyLogger).toHaveBeenCalled(); }); }); + + describe('Message Size Validation', () => { + it('should reject messages exceeding MAX_MESSAGE_LENGTH', async () => { + mockGetChannelId.mockReturnValue('some_channel_id'); + mockIsReady.mockReturnValue(true); + mockIsConnected.mockReturnValue(true); + + // Mock extractMethod to return large data + const largeData = { + jsonrpc: '2.0', + method: 'eth_call', + params: ['x'.repeat(MAX_MESSAGE_LENGTH + 1)], + }; + + mockExtractMethod.mockReturnValue({ + method: 'eth_call', + data: { + data: largeData, + }, + }); + + await write( + mockRemoteCommunicationPostMessageStream, + { jsonrpc: '2.0', method: 'eth_call' }, + 'utf8', + callback, + ); + + // Don't test for exact error message, just verify it contains the key parts + expect(callback).toHaveBeenCalledWith( + expect.objectContaining({ + message: expect.stringMatching( + /Message size \d+ exceeds maximum allowed size of \d+ bytes/u, + ), + }), + ); + expect(mockSendMessage).not.toHaveBeenCalled(); + }); + + it('should accept messages within MAX_MESSAGE_LENGTH', async () => { + mockGetChannelId.mockReturnValue('some_channel_id'); + mockIsReady.mockReturnValue(true); + mockIsConnected.mockReturnValue(true); + + // Mock extractMethod to return valid-sized data + mockExtractMethod.mockReturnValue({ + method: 'eth_call', + data: { + data: { + jsonrpc: '2.0', + method: 'eth_call', + params: ['x'.repeat(100)], + }, + }, + }); + + await write( + mockRemoteCommunicationPostMessageStream, + { jsonrpc: '2.0', method: 'eth_call' }, + 'utf8', + callback, + ); + + expect(callback).toHaveBeenCalledWith(); + expect(mockSendMessage).toHaveBeenCalled(); + }); + }); }); diff --git a/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.ts b/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.ts index ffc92a600..0ca239ec8 100644 --- a/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.ts +++ b/packages/sdk/src/services/RemoteCommunicationPostMessageStream/write.ts @@ -1,5 +1,9 @@ import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream'; -import { METHODS_TO_REDIRECT, RPC_METHODS } from '../../config'; +import { + METHODS_TO_REDIRECT, + RPC_METHODS, + MAX_MESSAGE_LENGTH, +} from '../../config'; import { METAMASK_CONNECT_BASE_URL, METAMASK_DEEPLINK_BASE, @@ -57,11 +61,17 @@ export async function write( deeplinkProtocolAvailable && mobileWeb && authorized; try { - console.warn( - `[RCPMS: _write()] triggeredInstaller=${triggeredInstaller} activeDeeplinkProtocol=${activeDeeplinkProtocol}`, - ); - if (!triggeredInstaller) { + // Check message size before sending + const stringifiedData = JSON.stringify(data?.data); + if (stringifiedData.length > MAX_MESSAGE_LENGTH) { + return callback( + new Error( + `Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`, + ), + ); + } + // The only reason not to send via network is because the rpc call will be sent in the deeplink instance.state.remote ?.sendMessage(data?.data)