Skip to content

Commit

Permalink
feat: adds certificate validation to ws proxy
Browse files Browse the repository at this point in the history
refs #170
  • Loading branch information
stalniy committed Feb 4, 2025
1 parent 29b5b25 commit 273a786
Show file tree
Hide file tree
Showing 7 changed files with 246 additions and 95 deletions.
4 changes: 1 addition & 3 deletions apps/provider-proxy/src/ClientSocketStats.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ export class ClientWebSocketStats {
private closedOn?: Date;
private usage: WebSocketUsage = "Unknown";

private usageStats: {
[key in WebSocketUsage]: { count: number; data: number };
} = {
private usageStats: Record<WebSocketUsage, { count: number; data: number }> = {
StreamLogs: { count: 0, data: 0 },
StreamEvents: { count: 0, data: 0 },
Shell: { count: 0, data: 0 },
Expand Down
3 changes: 2 additions & 1 deletion apps/provider-proxy/src/app.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ export async function startAppServer(port: number): Promise<AppServer> {
}
});
});
const wss = WebsocketServer.from(httpAppServer, container.wsLogger).listen();
const wss = new WebsocketServer(httpAppServer, container.certificateValidator, container.wsLogger);
wss.listen();
});
}

Expand Down
210 changes: 154 additions & 56 deletions apps/provider-proxy/src/services/WebsocketServer.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
import { LoggerService } from "@akashnetwork/logging";
import { SupportedChainNetworks } from "@akashnetwork/net";
import http from "http";
import https from "https";
import { TLSSocket } from "tls";
import { v4 as uuidv4 } from "uuid";
import WebSocket from "ws";

import { ClientWebSocketStats, WebSocketUsage } from "../ClientSocketStats";
import { container } from "../container";
import { CertificateValidator } from "./CertificateValidator";

// @see https://www.rfc-editor.org/rfc/rfc6455.html#page-46
const WS_ERRORS = {
VIOLATED_POLICY: 1008
};

export class WebsocketServer {
private readonly openProviderSockets: Record<string, WebSocket> = {};
private readonly openProviderSockets: Record<
string,
{
ws: WebSocket;
isVerified: boolean;
}
> = {};
private wss?: WebSocket.Server;

static from(appServer: http.Server, logger?: LoggerService): WebsocketServer {
return new WebsocketServer(appServer, logger);
}

constructor(
private readonly appServer: http.Server,
private readonly certificateValidator: CertificateValidator,
private readonly logger?: LoggerService
) {}

Expand Down Expand Up @@ -59,7 +70,7 @@ export class WebsocketServer {
stats.close();

if (id in this.openProviderSockets) {
this.openProviderSockets[id].terminate();
this.openProviderSockets[id].ws.terminate();
delete this.openProviderSockets[id];
} else {
wsLogger?.debug("Corresponding provider socket not found");
Expand Down Expand Up @@ -103,21 +114,24 @@ export class WebsocketServer {
private proxyMessageToProvider(message: WsMessage, ws: WebSocket, stats: ClientWebSocketStats, logger?: LoggerService): void {
const url = message.url.replace("https://", "wss://");

let providerWs = this.openProviderSockets[stats.id];
if (!providerWs || providerWs?.url !== url) {
providerWs?.terminate();
let socketDetails = this.openProviderSockets[stats.id];
if (
!socketDetails ||
socketDetails.ws.url !== url ||
socketDetails.ws.readyState === WebSocket.CLOSED ||
socketDetails.ws.readyState === WebSocket.CLOSING
) {
socketDetails?.ws.terminate();
logger?.info(`Initializing new provider websocket connection: ${url}`);
providerWs = new WebSocket(url, {
socketDetails = this.createProviderSocket(url, {
wsId: stats.id,
cert: message.certPem,
key: message.keyPem,
agent: new https.Agent({
// create new Agent to ensure TLS resumption is not used for websockets
sessionTimeout: 0,
rejectUnauthorized: false
})
chainNetwork: message.chainNetwork,
providerAddress: message.providerAddress,
logger
});
linkSockets(providerWs, ws, stats, logger);
this.openProviderSockets[stats.id] = providerWs;
this.linkSockets(socketDetails.ws, ws, stats, logger);
}

if (!message.data) {
Expand All @@ -133,56 +147,138 @@ export class WebsocketServer {
});
};

if (providerWs.readyState === WebSocket.OPEN) {
providerWs.send(data, callback);
if (socketDetails.ws.readyState === WebSocket.OPEN && socketDetails.isVerified) {
socketDetails.ws.send(data, callback);
} else {
providerWs.once("open", () => providerWs.send(data, callback));
socketDetails.ws.once("verified", () => socketDetails.ws.send(data, callback));
}
}
}

function linkSockets(providerWs: WebSocket, ws: WebSocket, stats: ClientWebSocketStats, logger?: LoggerService): void {
providerWs.on("open", function open() {
logger?.info(`Connected to provider websocket: ${providerWs.url}`);
});
private createProviderSocket(url: string, options: CreateProviderSocketOptions) {
const pws = new WebSocket(url, {
key: options.key,
cert: options.cert,
agent: new https.Agent({
// do not use TLS session resumption for websocket
sessionTimeout: 0,
rejectUnauthorized: false
})
});

this.openProviderSockets[options.wsId] = { ws: pws, isVerified: false };

pws.on("upgrade", response => {
// Using sync function here to ensure that no data is processed by event handlers until SSL cert validation is finished
const certificate = response.socket && response.socket instanceof TLSSocket ? response.socket.getPeerX509Certificate() : undefined;

if (!certificate) {
// call destroy manually because at this time websocket is not connected with the actual socket
response.socket.destroy();
pws.close(1008, `Server ${url} didn't provide SSL certificate`);
return;
}

if (!options.chainNetwork || !options.providerAddress) {
// temporary certificate validation is optional
pws.once("open", () => {
this.openProviderSockets[options.wsId].isVerified = true;
pws.emit("verified");
});
return;
}

// stop reading data from socket until we validate certificate
response.socket.pause();
this.certificateValidator
.validate(certificate, options.chainNetwork, options.providerAddress)
.catch(error => {
options.logger?.error({
message: "Could not validate SSL certificate",
error,
chainNetwork: options.chainNetwork,
providerAddress: options.providerAddress
});
return {
ok: false,
code: "serverError"
} as const;
})
.then(result => {
if (result.ok === false) {
// ensure that no messages are proxied from untrusted websocket
pws.removeAllListeners("message");

const reason = result.code === "serverError" ? "Could not validate SSL certificate" : `Invalid SSL certificate: ${result.code}`;
pws.close(WS_ERRORS.VIOLATED_POLICY, reason);
} else {
this.openProviderSockets[options.wsId].isVerified = true;
pws.emit("verified");
}

providerWs.on("message", socketMessage => {
if (!socketMessage) return;
const data = JSON.stringify({
type: "websocket",
message: socketMessage
// need to call this in error and success case, otherwise listeners will not be notified about close event
response.socket.resume();
});
});
stats.logDataTransfer(Buffer.from(data).length);
ws.send(data);
});

providerWs.on("error", error => {
logger?.error({
message: "Websocket received an error",
error
return this.openProviderSockets[options.wsId];
}

private linkSockets(providerWs: WebSocket, ws: WebSocket, stats: ClientWebSocketStats, logger?: LoggerService): void {
providerWs.on("open", function open() {
logger?.info(`Connected to provider websocket: ${providerWs.url}`);
});
const data = JSON.stringify({
type: "websocket",
message: error,
error

providerWs.on("message", socketMessage => {
if (!socketMessage) return;
const data = JSON.stringify({
type: "websocket",
message: socketMessage
});
stats.logDataTransfer(Buffer.from(data).length);
ws.send(data);
});
stats.logDataTransfer(Buffer.from(data).length);
ws.send(data);
});

providerWs.on("close", event => {
logger?.info({
message: "Provider websocket was closed",
event

providerWs.on("error", error => {
logger?.error({
message: "Websocket received an error",
error
});
const data = JSON.stringify({
type: "websocket",
message: error,
error
});
stats.logDataTransfer(Buffer.from(data).length);
ws.send(data);
});
const data = JSON.stringify({
type: "websocket",
message: "",
closed: true

providerWs.on("close", (code, reason) => {
delete this.openProviderSockets[stats.id];
logger?.info({
message: "Provider websocket was closed",
code,
reason
});
const data = JSON.stringify({
type: "websocket",
message: "",
closed: true,
code,
reason
});
stats.logDataTransfer(Buffer.from(data).length);
ws.send(data);
});
stats.logDataTransfer(Buffer.from(data).length);
ws.send(data);
});
}
}

interface CreateProviderSocketOptions {
wsId: string;
cert: string;
key: string;
chainNetwork?: SupportedChainNetworks;
providerAddress?: string;
logger?: LoggerService;
}

function getWebSocketUsage(message: any): WebSocketUsage {

Check warning on line 284 in apps/provider-proxy/src/services/WebsocketServer.ts

View workflow job for this annotation

GitHub Actions / build

Unexpected any. Specify a different type
Expand Down Expand Up @@ -222,6 +318,8 @@ interface WsMessage {
url: string;
certPem?: string;
keyPem?: string;
chainNetwork?: SupportedChainNetworks;
providerAddress?: string;
/**
* Currently it's used only for service shell communication
* and stores only buffered representation of string in char codes
Expand Down
29 changes: 11 additions & 18 deletions apps/provider-proxy/test/functional/provider-proxy-http.spec.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import { SupportedChainNetworks } from "@akashnetwork/net";
import { bech32 } from "bech32";
import { setTimeout } from "timers/promises";

import { createX509CertPair } from "../seeders/createX509CertPair";
import { mockOnChainCertificates, stopChainAPIServer } from "../setup/chainApiServer";
import { generateBech32, startChainApiServer, stopChainAPIServer } from "../setup/chainApiServer";
import { startProviderServer, stopProviderServer } from "../setup/providerServer";
import { request } from "../setup/proxyServer";
import { startServer, stopServer } from "../setup/proxyServer";
Expand All @@ -29,7 +28,7 @@ describe("Provider HTTP proxy", () => {
const providerAddress = generateBech32();
const validCertPair = createX509CertPair({ commonName: providerAddress, validFrom: new Date(Date.now() - ONE_HOUR) });

await mockOnChainCertificates([validCertPair.cert]);
await startChainApiServer([validCertPair.cert]);
const providerUrl = await startProviderServer({ certPair: validCertPair });

const response = await request("/", {
Expand All @@ -53,7 +52,7 @@ describe("Provider HTTP proxy", () => {
const providerAddress = generateBech32();
const validCertPair = createX509CertPair({ commonName: providerAddress, validFrom: new Date(Date.now() - ONE_HOUR) });

await mockOnChainCertificates([validCertPair.cert]);
await startChainApiServer([validCertPair.cert]);
const providerUrl = await startProviderServer({ certPair: validCertPair });

const response = await request("/", {
Expand All @@ -78,7 +77,7 @@ describe("Provider HTTP proxy", () => {
const providerAddress = generateBech32();
const validCertPair = createX509CertPair({ commonName: providerAddress, validFrom: new Date(Date.now() - ONE_HOUR) });

const chainServer = await mockOnChainCertificates([validCertPair.cert]);
const chainServer = await startChainApiServer([validCertPair.cert]);
const providerUrl = await startProviderServer({ certPair: validCertPair });

let response = await request("/", {
Expand Down Expand Up @@ -112,7 +111,7 @@ describe("Provider HTTP proxy", () => {
const providerAddress = generateBech32();
const validCertPair = createX509CertPair({ commonName: providerAddress, validFrom: new Date(Date.now() - ONE_HOUR) });

await mockOnChainCertificates([
await startChainApiServer([
createX509CertPair({
commonName: providerAddress,
validFrom: new Date(Date.now() + ONE_HOUR),
Expand Down Expand Up @@ -152,7 +151,7 @@ describe("Provider HTTP proxy", () => {
validTo: new Date(Date.now() - ONE_HOUR)
});

await mockOnChainCertificates([createX509CertPair({ commonName: providerAddress, validFrom: new Date(Date.now() + ONE_HOUR) }).cert, validCertPair.cert]);
await startChainApiServer([createX509CertPair({ commonName: providerAddress, validFrom: new Date(Date.now() + ONE_HOUR) }).cert, validCertPair.cert]);
const providerUrl = await startProviderServer({ certPair: validCertPair });

const requestProvider = () =>
Expand Down Expand Up @@ -199,7 +198,7 @@ describe("Provider HTTP proxy", () => {
})
});
await setTimeout(200);
await mockOnChainCertificates([validCertPair.cert]);
await startChainApiServer([validCertPair.cert]);
const response = await responsePromise;

expect(response.status).toBe(200);
Expand All @@ -214,8 +213,8 @@ describe("Provider HTTP proxy", () => {
});

const providerUrl = await startProviderServer({ certPair: validCertPair });
await mockOnChainCertificates([validCertPair.cert], {
respondWithOnceWith: 502
await startChainApiServer([validCertPair.cert], {
respondOnceWith: 502
});

const response = await request("/", {
Expand All @@ -240,7 +239,7 @@ describe("Provider HTTP proxy", () => {
});

const providerUrl = await startProviderServer({ certPair: validCertPair });
await mockOnChainCertificates([validCertPair.cert]);
await startChainApiServer([validCertPair.cert]);

const response = await request("/", {
method: "POST",
Expand All @@ -264,7 +263,7 @@ describe("Provider HTTP proxy", () => {
});

const providerUrl = await startProviderServer({ certPair: validCertPair });
await mockOnChainCertificates([validCertPair.cert]);
await startChainApiServer([validCertPair.cert]);

const response = await request("/", {
method: "POST",
Expand All @@ -281,10 +280,4 @@ describe("Provider HTTP proxy", () => {
const body = await response.text();
expect(body).toBe("Fast");
});

let index = 0;
function generateBech32() {
const words = bech32.toWords(Buffer.from("foobar2", "utf8"));
return bech32.encode(`test${++index}`, words);
}
});
Loading

0 comments on commit 273a786

Please sign in to comment.