Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement message size validation to prevent excessive payloads #1197

Merged
merged 7 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/sdk-socket-server-next/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ export const withAdminUI: boolean = process.env.ADMIN_UI === 'true';
const HOUR_IN_SECONDS = 60 * 60;
const THIRTY_DAYS_IN_SECONDS = 30 * 24 * 60 * 60; // expiration time of entries in Redis
export const MAX_CLIENTS_PER_ROOM = 2;
export const MAX_MESSAGE_LENGTH = 1_000_000; // 1MB limit

export const config = {
msgExpiry: HOUR_IN_SECONDS,
Expand Down
17 changes: 16 additions & 1 deletion packages/sdk-socket-server-next/src/protocol/handleMessage.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { Server, Socket } from 'socket.io';
import { v4 as uuidv4 } from 'uuid';
import { pubClient } from '../analytics-api';
import { config, isDevelopment } from '../config';
import { config, isDevelopment, MAX_MESSAGE_LENGTH } from '../config';
import { getLogger } from '../logger';
import {
increaseRateLimits,
Expand Down Expand Up @@ -58,6 +58,21 @@ export const handleMessage = async ({
let ready = false; // Determines if the keys have been exchanged and both side support the full protocol

try {
// Add message size validation
const messageSize = typeof message === 'string'
? message.length
: JSON.stringify(message).length;

if (messageSize > MAX_MESSAGE_LENGTH) {
logger.warn(`[handleMessage] Message size ${messageSize} exceeds limit of ${MAX_MESSAGE_LENGTH} bytes`, {
channelId,
socketId,
clientIp,
});
callback?.(`Message size ${messageSize} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`);
return;
}

if (clientType) {
// new protocol, get channelConfig
const channelConfigKey = `channel_config:${channelId}`;
Expand Down
2 changes: 2 additions & 0 deletions packages/sdk/src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@ export const EXTENSION_EVENTS = {
CONNECT: 'connect',
CONNECTED: 'connected',
};

export const MAX_MESSAGE_LENGTH = 1_000_000; // 1MB limit
abretonc7s marked this conversation as resolved.
Show resolved Hide resolved
47 changes: 47 additions & 0 deletions packages/sdk/src/services/MobilePortStream/write.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Buffer } from 'buffer';
import { MAX_MESSAGE_LENGTH } from '../../config';
import { write } from './write';

describe('write function', () => {
Expand Down Expand Up @@ -77,4 +78,50 @@ describe('write function', () => {
new Error('MobilePortStream - disconnected'),
);
});

describe('Message Size Validation', () => {
beforeEach(() => {
jest.clearAllMocks();
global.window = {
location: { href: 'http://example.com' },
ReactNativeWebView: { postMessage: mockPostMessage },
} as any;
});

it('should reject messages exceeding MAX_MESSAGE_LENGTH', () => {
const largeData = {
data: {
jsonrpc: '2.0',
method: 'test_method',
params: ['x'.repeat(MAX_MESSAGE_LENGTH)],
},
};

write(largeData, 'utf-8', cb);

expect(cb).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringMatching(
/Message size \d+ exceeds maximum allowed size of \d+ bytes/u,
),
}),
);
expect(mockPostMessage).not.toHaveBeenCalled();
});

it('should accept messages within MAX_MESSAGE_LENGTH', () => {
const validData = {
data: {
jsonrpc: '2.0',
method: 'test_method',
params: ['x'.repeat(100)],
},
};

write(validData, 'utf-8', cb);

expect(cb).toHaveBeenCalledWith();
expect(mockPostMessage).toHaveBeenCalled();
});
});
});
24 changes: 19 additions & 5 deletions packages/sdk/src/services/MobilePortStream/write.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Buffer } from 'buffer';
import { MAX_MESSAGE_LENGTH } from '../../config';

/**
* Handles communication between the in-app browser and MetaMask mobile application.
Expand All @@ -15,6 +16,7 @@ export function write(
cb: (error?: Error | null) => void,
) {
try {
let stringifiedData: string;
if (Buffer.isBuffer(chunk)) {
const data: {
type: 'Buffer';
Expand All @@ -23,18 +25,30 @@ export function write(
} = chunk.toJSON();

data._isBuffer = true;
window.ReactNativeWebView?.postMessage(
JSON.stringify({ ...data, origin: window.location.href }),
);
stringifiedData = JSON.stringify({
...data,
origin: window.location.href,
});
} else {
if (chunk.data) {
chunk.data.toNative = true;
}

window.ReactNativeWebView?.postMessage(
JSON.stringify({ ...chunk, origin: window.location.href }),
stringifiedData = JSON.stringify({
...chunk,
origin: window.location.href,
});
}

if (stringifiedData.length > MAX_MESSAGE_LENGTH) {
return cb(
new Error(
`Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`,
),
);
}

window.ReactNativeWebView?.postMessage(stringifiedData);
} catch (err) {
return cb(new Error('MobilePortStream - disconnected'));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure
import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream'; // Adjust the import based on your project structure
import { METHODS_TO_REDIRECT } from '../../config';
import { MAX_MESSAGE_LENGTH, METHODS_TO_REDIRECT } from '../../config';
import * as loggerModule from '../../utils/logger'; // Adjust the import based on your project structure
import { write } from './write'; // Adjust the import based on your project structure
import { Ethereum } from '../Ethereum'; // Adjust the import based on your project structure
import { extractMethod } from './extractMethod';
import { write } from './write'; // Adjust the import based on your project structure

jest.mock('./extractMethod');
jest.mock('../Ethereum');
Expand Down Expand Up @@ -162,11 +162,22 @@ describe('write function', () => {
mockIsMobileWeb.mockReturnValue(false);
mockIsSecure.mockReturnValue(true);
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsMetaMaskInstalled.mockReturnValue(true);
mockGetKeyInfo.mockReturnValue({ ecies: { public: 'test_public_key' } });
mockHasDeeplinkProtocol.mockReturnValue(false);
});

it('should redirect if method exists in METHODS_TO_REDIRECT', async () => {
mockExtractMethod.mockReturnValue({
method: Object.keys(METHODS_TO_REDIRECT)[0],
data: {
data: {
jsonrpc: '2.0',
method: Object.keys(METHODS_TO_REDIRECT)[0],
params: [],
},
},
triggeredInstaller: false,
});

await write(
Expand Down Expand Up @@ -239,4 +250,71 @@ describe('write function', () => {
expect(spyLogger).toHaveBeenCalled();
});
});

describe('Message Size Validation', () => {
it('should reject messages exceeding MAX_MESSAGE_LENGTH', async () => {
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsReady.mockReturnValue(true);
mockIsConnected.mockReturnValue(true);

// Mock extractMethod to return large data
const largeData = {
jsonrpc: '2.0',
method: 'eth_call',
params: ['x'.repeat(MAX_MESSAGE_LENGTH + 1)],
};

mockExtractMethod.mockReturnValue({
method: 'eth_call',
data: {
data: largeData,
},
});

await write(
mockRemoteCommunicationPostMessageStream,
{ jsonrpc: '2.0', method: 'eth_call' },
'utf8',
callback,
);

// Don't test for exact error message, just verify it contains the key parts
expect(callback).toHaveBeenCalledWith(
expect.objectContaining({
message: expect.stringMatching(
/Message size \d+ exceeds maximum allowed size of \d+ bytes/u,
),
}),
);
expect(mockSendMessage).not.toHaveBeenCalled();
});

it('should accept messages within MAX_MESSAGE_LENGTH', async () => {
mockGetChannelId.mockReturnValue('some_channel_id');
mockIsReady.mockReturnValue(true);
mockIsConnected.mockReturnValue(true);

// Mock extractMethod to return valid-sized data
mockExtractMethod.mockReturnValue({
method: 'eth_call',
data: {
data: {
jsonrpc: '2.0',
method: 'eth_call',
params: ['x'.repeat(100)],
},
},
});

await write(
mockRemoteCommunicationPostMessageStream,
{ jsonrpc: '2.0', method: 'eth_call' },
'utf8',
callback,
);

expect(callback).toHaveBeenCalledWith();
expect(mockSendMessage).toHaveBeenCalled();
});
});
});
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { RemoteCommunicationPostMessageStream } from '../../PostMessageStream/RemoteCommunicationPostMessageStream';
import { METHODS_TO_REDIRECT, RPC_METHODS } from '../../config';
import {
METHODS_TO_REDIRECT,
RPC_METHODS,
MAX_MESSAGE_LENGTH,
} from '../../config';
import {
METAMASK_CONNECT_BASE_URL,
METAMASK_DEEPLINK_BASE,
Expand Down Expand Up @@ -57,11 +61,17 @@ export async function write(
deeplinkProtocolAvailable && mobileWeb && authorized;

try {
console.warn(
`[RCPMS: _write()] triggeredInstaller=${triggeredInstaller} activeDeeplinkProtocol=${activeDeeplinkProtocol}`,
);

if (!triggeredInstaller) {
// Check message size before sending
const stringifiedData = JSON.stringify(data?.data);
if (stringifiedData.length > MAX_MESSAGE_LENGTH) {
return callback(
new Error(
`Message size ${stringifiedData.length} exceeds maximum allowed size of ${MAX_MESSAGE_LENGTH} bytes`,
),
);
}

// The only reason not to send via network is because the rpc call will be sent in the deeplink
instance.state.remote
?.sendMessage(data?.data)
Expand Down
Loading