diff --git a/config/config.ts b/config/config.ts
index f71c5a715..5ac42533d 100644
--- a/config/config.ts
+++ b/config/config.ts
@@ -63,8 +63,8 @@ import {
   getOraCompressionIcav2PipelineTableStackProps,
 } from './stacks/oraCompressionPipelineManager';
 import { getOraDecompressionManagerStackProps } from './stacks/oraDecompressionPipelineManager';
+import { getWebSocketApiStackProps } from './stacks/clientWebsocketApi';
 import { getPgDDProps } from './stacks/pgDD';
-
 interface EnvironmentConfig {
   name: string;
   region: string;
@@ -131,6 +131,7 @@ export const getEnvironmentConfig = (stage: AppStage): EnvironmentConfig | null
       workflowManagerStackProps: getWorkflowManagerStackProps(stage),
       stackyMcStackFaceProps: getGlueStackProps(stage),
       fmAnnotatorProps: getFmAnnotatorProps(),
+      websocketApiStackProps: getWebSocketApiStackProps(stage),
       pgDDProps: getPgDDProps(stage),
     },
   };
diff --git a/config/stacks/clientWebsocketApi.ts b/config/stacks/clientWebsocketApi.ts
new file mode 100644
index 000000000..33b55bb1b
--- /dev/null
+++ b/config/stacks/clientWebsocketApi.ts
@@ -0,0 +1,25 @@
+import { WebSocketApiStackProps } from '../../lib/workload/stateless/stacks/client-websocket-conn/deploy';
+import { AppStage, vpcProps, region, cognitoUserPoolIdParameterName } from '../constants';
+
+/**
+ * Build the props for the client WebSocket API stack.
+ *
+ * Only the API Gateway name and the deployed stage name vary with `stage`.
+ */
+export const getWebSocketApiStackProps = (stage: AppStage): WebSocketApiStackProps => {
+  return {
+    connectionTableName: 'OrcaBusClientWebsocketApiConnectionTable',
+    messageHistoryTableName: 'OrcaBusClientWebsocketApiMessageHistoryTable',
+    websocketApigatewayName: `OrcaBusClientWebsocketApi${stage}`,
+    lambdaSecurityGroupName: 'OrcaBusClientWebsocketApiSecurityGroup',
+    connectionFunctionName: 'websocketApiConnect',
+    disconnectFunctionName: 'websocketApiDisconnect',
+    messageFunctionName: 'websocketApiMessage',
+
+    vpcProps: vpcProps,
+    websocketApiEndpointParameterName: '/orcabus/client-websocket-api-endpoint',
+    websocketStageName: stage,
+    cognitoRegion: region,
+    cognitoUserPoolIdParameterName: cognitoUserPoolIdParameterName,
+  };
+};
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/README.md b/lib/workload/stateless/stacks/client-websocket-conn/README.md
new file mode 100644
index 000000000..0b6ec31a3
--- /dev/null
+++ b/lib/workload/stateless/stacks/client-websocket-conn/README.md
@@ -0,0 +1,30 @@
+# WebSocket API Stack
+
+A serverless WebSocket API implementation using AWS CDK, API Gateway WebSocket APIs, Lambda, and DynamoDB for real-time communication.
+
+## Architecture
+
+![Architecture Diagram](./websocket-api-arch.png)
+
+### Components
+
+- **API Gateway WebSocket API**: Handles WebSocket connections
+- **Lambda Functions**: Process WebSocket events
+- **DynamoDB**: Stores connection information and message history
+
+## Features
+
+- Real-time bidirectional communication
+- Connection management
+- Message broadcasting
+- Secure VPC deployment
+- Automatic scaling
+- Connection cleanup
+
+## Prerequisites
+
+- AWS CDK CLI
+- Node.js & npm
+- Python 3.12
+- AWS Account and configured credentials
+- VPC with private subnets
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts b/lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts
new file mode 100644
index 000000000..439e0a816
--- /dev/null
+++ b/lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts
@@ -0,0 +1,231 @@
+import { Stack, RemovalPolicy, StackProps, Duration } from 'aws-cdk-lib';
+import { Table, AttributeType } from 'aws-cdk-lib/aws-dynamodb';
+import { Vpc, SecurityGroup, VpcLookupOptions, IVpc, ISecurityGroup } from 'aws-cdk-lib/aws-ec2';
+import { WebSocketApi, WebSocketStage } from 'aws-cdk-lib/aws-apigatewayv2';
+import { WebSocketLambdaIntegration } from 'aws-cdk-lib/aws-apigatewayv2-integrations';
+import { PolicyStatement } from 'aws-cdk-lib/aws-iam';
+import {
+  PythonFunction,
+  PythonFunctionProps,
+  PythonLayerVersion,
+} from '@aws-cdk/aws-lambda-python-alpha';
+import { Runtime, Architecture, LayerVersion } from 'aws-cdk-lib/aws-lambda';
+import { Construct } from 'constructs';
+import * as path from 'path';
+import { WebSocketLambdaAuthorizer } from 'aws-cdk-lib/aws-apigatewayv2-authorizers';
+import { StringParameter } from 'aws-cdk-lib/aws-ssm';
+
+export interface WebSocketApiStackProps extends StackProps {
+  // DynamoDB and Lambda configuration
+  connectionTableName: string;
+  websocketApigatewayName: string;
+  connectionFunctionName: string;
+  disconnectFunctionName: string;
+  messageFunctionName: string;
+  messageHistoryTableName: string;
+
+  // Parameter name for the WebSocket API endpoint
+  websocketApiEndpointParameterName: string;
+  websocketStageName: string;
+
+  // Cognito configuration for the authorizer
+  cognitoRegion: string;
+  cognitoUserPoolIdParameterName: string;
+  lambdaSecurityGroupName: string;
+  vpcProps: VpcLookupOptions;
+}
+
+/**
+ * Stack for the client-facing WebSocket API: two DynamoDB tables (open
+ * connections, message history), the Python Lambda handlers for the
+ * `$connect`, `$disconnect`, `$default` and `sendMessage` routes, and an SSM
+ * parameter holding the endpoint URL.
+ */
+export class WebSocketApiStack extends Stack {
+  private readonly lambdaRuntimePythonVersion = Runtime.PYTHON_3_12;
+  private readonly props: WebSocketApiStackProps;
+  private readonly vpc: IVpc;
+  private readonly lambdaSG: ISecurityGroup;
+
+  constructor(scope: Construct, id: string, props: WebSocketApiStackProps) {
+    super(scope, id, props);
+
+    this.props = props;
+
+    this.vpc = Vpc.fromLookup(this, 'MainVpc', props.vpcProps);
+    this.lambdaSG = SecurityGroup.fromLookupByName(
+      this,
+      'LambdaSecurityGroup',
+      props.lambdaSecurityGroupName,
+      this.vpc
+    );
+
+    // DynamoDB Table for storing connection IDs
+    const connectionTable = new Table(this, 'WebSocketApiConnections', {
+      tableName: props.connectionTableName,
+      partitionKey: {
+        name: 'connectionId',
+        type: AttributeType.STRING,
+      },
+      removalPolicy: RemovalPolicy.DESTROY, // For demo purposes, not recommended for production
+    });
+
+    // DynamoDB Table for message history
+    const messageHistoryTable = new Table(this, 'WebSocketApiMessageHistory', {
+      tableName: props.messageHistoryTableName,
+      partitionKey: {
+        name: 'messageId',
+        type: AttributeType.STRING,
+      },
+      timeToLiveAttribute: 'ttl', // Enable TTL
+      removalPolicy: RemovalPolicy.DESTROY,
+    });
+
+    // Lambda function for $connect
+    const connectHandler = this.createPythonFunction(props.connectionFunctionName, {
+      index: 'connect.py',
+      handler: 'lambda_handler',
+      timeout: Duration.minutes(2),
+    });
+
+    // Lambda function for $disconnect
+    const disconnectHandler = this.createPythonFunction(props.disconnectFunctionName, {
+      index: 'disconnect.py',
+      handler: 'lambda_handler',
+      timeout: Duration.minutes(2),
+    });
+
+    // Lambda function for $default (broadcast messages)
+    const messageHandler = this.createPythonFunction(props.messageFunctionName, {
+      index: 'message.py',
+      handler: 'lambda_handler',
+      timeout: Duration.minutes(2),
+    });
+
+    // build layer from deps
+    const authLayer = new PythonLayerVersion(this, 'BaseLayer', {
+      entry: path.join(__dirname, '../deps'),
+      compatibleRuntimes: [this.lambdaRuntimePythonVersion],
+      compatibleArchitectures: [Architecture.ARM_64],
+    });
+
+    const userPoolId = StringParameter.fromStringParameterName(
+      this,
+      'CognitoUserPoolIdParameter',
+      props.cognitoUserPoolIdParameterName
+    ).stringValue;
+
+    // authorizer function to check the client token based on the JWT token
+    const connectAuthorizer = this.createPythonFunction('AuthHandler', {
+      index: 'auth.py',
+      handler: 'lambda_handler',
+      timeout: Duration.minutes(2),
+      environment: {
+        COGNITO_REGION: props.cognitoRegion,
+        COGNITO_USER_POOL_ID: userPoolId,
+      },
+      layers: [authLayer],
+    });
+
+    // Grant permissions to Lambda functions
+    connectionTable.grantReadWriteData(connectHandler);
+    connectionTable.grantReadWriteData(disconnectHandler);
+    connectionTable.grantReadWriteData(messageHandler);
+    // messageHistoryTable.grantReadData(connectHandler);
+    messageHistoryTable.grantReadWriteData(messageHandler);
+
+    // WebSocket API
+    const api = new WebSocketApi(this, props.websocketApigatewayName, {
+      apiName: props.websocketApigatewayName,
+      description: 'WebSocket API for the app notifications',
+      connectRouteOptions: {
+        integration: new WebSocketLambdaIntegration('ConnectIntegration', connectHandler),
+        // FIXME: uncomment this when auth is implemented
+        // authorizer: new WebSocketLambdaAuthorizer(
+        //   "ConnectAuthorizer",
+        //   connectAuthorizer,
+        //   {
+        //     authorizerName: "ConnectAuthorizer",
+        //     identitySource: [
+        //       "route.request.header.Authorization",
+        //       "route.request.querystring.Authorization",
+        //     ],
+        //   }
+        // ),
+      },
+      disconnectRouteOptions: {
+        integration: new WebSocketLambdaIntegration('DisconnectIntegration', disconnectHandler),
+      },
+      defaultRouteOptions: {
+        integration: new WebSocketLambdaIntegration('DefaultIntegration', messageHandler),
+      },
+    });
+
+    api.addRoute('sendMessage', {
+      integration: new WebSocketLambdaIntegration('SendMessageIntegration', messageHandler),
+    });
+
+    // Deploy WebSocket API to a stage
+    const stage = new WebSocketStage(this, 'WebSocketStage', {
+      webSocketApi: api,
+      stageName: props.websocketStageName,
+      autoDeploy: true,
+    });
+
+    // Create the WebSocket API endpoint URL
+    const webSocketApiEndpoint = `${api.apiEndpoint}/${stage.stageName}`;
+
+    // save this url into the parameter store for the client to use
+    new StringParameter(this, 'WebSocketApiEndpoint', {
+      parameterName: props.websocketApiEndpointParameterName,
+      description: 'The endpoint URL for the WebSocket API',
+      stringValue: webSocketApiEndpoint,
+    });
+
+    const commonEnvironment = {
+      CONNECTION_TABLE: connectionTable.tableName,
+      MESSAGE_HISTORY_TABLE: messageHistoryTable.tableName,
+      WEBSOCKET_API_ENDPOINT: webSocketApiEndpoint,
+    };
+
+    // Add environment variables individually
+    for (const [key, value] of Object.entries(commonEnvironment)) {
+      connectHandler.addEnvironment(key, value);
+      disconnectHandler.addEnvironment(key, value);
+      messageHandler.addEnvironment(key, value);
+    }
+
+    // Grant permissions to the message handler. The ARN names the stage created
+    // above rather than a fixed one, so it matches whatever stage is deployed.
+    messageHandler.addToRolePolicy(
+      new PolicyStatement({
+        actions: ['execute-api:ManageConnections'],
+        resources: [
+          `arn:aws:execute-api:${this.region}:${this.account}:${api.apiId}/${stage.stageName}/POST/@connections/*`,
+        ],
+      })
+    );
+  }
+
+  /**
+   * Create a Python Lambda from `../lambda` in the stack VPC's private subnets.
+   *
+   * @param name construct id of the function
+   * @param props extra function props (index file, timeout, environment, ...)
+   * @returns the created function
+   */
+  private createPythonFunction(
+    name: string,
+    props: Omit<PythonFunctionProps, 'entry' | 'runtime'>
+  ): PythonFunction {
+    return new PythonFunction(this, name, {
+      entry: path.join(__dirname, '../lambda'),
+      runtime: this.lambdaRuntimePythonVersion,
+      securityGroups: [this.lambdaSG],
+      vpc: this.vpc,
+      vpcSubnets: { subnets: this.vpc.privateSubnets },
+      architecture: Architecture.ARM_64,
+      ...props,
+    });
+  }
+}
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/deps/requirements.txt b/lib/workload/stateless/stacks/client-websocket-conn/deps/requirements.txt
new file mode 100644
index 000000000..f93f1f67b
--- /dev/null
+++ b/lib/workload/stateless/stacks/client-websocket-conn/deps/requirements.txt
@@ -0,0 +1,2 @@
+PyJWT[crypto]==2.8.0
+requests==2.31.0
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py b/lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py
new file mode 100644
index 000000000..2a8f40fb8
--- /dev/null
+++ b/lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py
@@ -0,0 +1,93 @@
+import json
+import os
+import logging
+import jwt
+import requests
+from typing import Dict, Any
+
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
+
+def generate_policy(principal_id, effect, resource):
+    """Build the policy document a Lambda authorizer returns to API Gateway."""
+    return {
+        'principalId': principal_id,
+        'policyDocument': {
+            'Version': '2012-10-17',
+            'Statement': [{
+                'Action': 'execute-api:Invoke',
+                'Effect': effect,
+                'Resource': resource
+            }]
+        }
+    }
+
+def get_public_key(region, user_pool_id, token):
+    """Get the Cognito public key that verifies `token`.
+
+    The user pool's JWKS can list several keys, so the one whose `kid` equals
+    the `kid` in the token header is selected and converted to a key object.
+
+    Raises whatever the header parsing, the download or the conversion raise,
+    and KeyError when no listed key matches.
+    """
+    url = f'https://cognito-idp.{region}.amazonaws.com/{user_pool_id}/.well-known/jwks.json'
+    try:
+        kid = jwt.get_unverified_header(token)['kid']
+        response = requests.get(url, timeout=5)
+        response.raise_for_status()
+        for key in response.json()['keys']:
+            if key['kid'] == kid:
+                # jwt.decode cannot verify against the raw JWK dict
+                return jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(key))
+        raise KeyError(f"No public key found for kid {kid}")
+    except Exception as e:
+        logger.error(f"Error getting public key: {str(e)}")
+        raise
+
+def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
+    """Simple Lambda authorizer for WebSocket
+
+    Returns an Allow policy for the token's `sub` when the JWT taken from the
+    `Authorization` header or query string verifies, a Deny policy otherwise.
+    """
+    logger.info("WebSocket authorization request")
+    # Get environment variables
+    assert 'COGNITO_USER_POOL_ID' in os.environ, "COGNITO_USER_POOL_ID is not set"
+    assert 'COGNITO_REGION' in os.environ, "COGNITO_REGION is not set"
+
+    cognito_user_pool_id = os.environ['COGNITO_USER_POOL_ID']
+    cognito_region = os.environ['COGNITO_REGION']
+    try:
+        # Get token from headers
+        # Check both header and querystring; either mapping may be absent or None
+        headers = event.get('headers') or {}
+        query = event.get('queryStringParameters') or {}
+        auth_token = headers.get('Authorization') or query.get('Authorization')
+
+        if not auth_token:
+            return generate_policy('user', 'Deny', event['methodArn'])
+
+        # Accept "Bearer <token>" as well as the bare token
+        if auth_token.startswith('Bearer '):
+            auth_token = auth_token[len('Bearer '):]
+
+        # Get public key
+        public_key = get_public_key(cognito_region, cognito_user_pool_id, auth_token)
+
+        # Verify token
+        # NOTE(review): no `audience` is passed, so PyJWT rejects a token that
+        # carries an `aud` claim - confirm which Cognito token type clients send.
+        decoded = jwt.decode(
+            auth_token,
+            public_key,
+            algorithms=['RS256'],
+            issuer=f'https://cognito-idp.{cognito_region}.amazonaws.com/{cognito_user_pool_id}'
+        )
+
+        # Generate allow policy
+        return generate_policy(decoded['sub'], 'Allow', event['methodArn'])
+    except Exception as e:
+        logger.error(f"Authorization failed: {str(e)}")
+        # Return deny policy
+        return generate_policy('unauthorized', 'Deny', event['methodArn'])
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/lambda/connect.py b/lib/workload/stateless/stacks/client-websocket-conn/lambda/connect.py
new file mode 100644
index 000000000..ee094bbd5
--- /dev/null
+++ b/lib/workload/stateless/stacks/client-websocket-conn/lambda/connect.py
@@ -0,0 +1,23 @@
+import boto3
+import os
+
+def lambda_handler(event, context):
+    """Record the new connection's id in the connection table."""
+    # Get table names from environment variables
+    assert 'CONNECTION_TABLE' in os.environ, "CONNECTION_TABLE environment variable is not set"
+    connections_table_name = os.environ['CONNECTION_TABLE']
+
+    dynamodb = boto3.resource('dynamodb')
+    connections_table = dynamodb.Table(connections_table_name)
+
+    connection_id = event['requestContext']['connectionId']
+
+    try:
+        # Store connection
+        connections_table.put_item(
+            Item={'connectionId': connection_id}
+        )
+        return {'statusCode': 200}
+    except Exception as e:
+        print(f"Error storing connection: {e}")
+        return {'statusCode': 500, 'body': str(e)}
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/lambda/disconnect.py b/lib/workload/stateless/stacks/client-websocket-conn/lambda/disconnect.py
new file mode 100644
index 000000000..c21c6cdfd
--- /dev/null
+++ b/lib/workload/stateless/stacks/client-websocket-conn/lambda/disconnect.py
@@ -0,0 +1,21 @@
+import boto3
+import os
+
+def lambda_handler(event, context):
+    """Remove the closed connection's id from the connection table."""
+    # Get table name from environment variable
+    assert 'CONNECTION_TABLE' in os.environ, "CONNECTION_TABLE environment variable is not set"
+    connections_table_name = os.environ['CONNECTION_TABLE']
+
+    dynamodb = boto3.resource('dynamodb')
+    table = dynamodb.Table(connections_table_name)
+
+    connection_id = event['requestContext']['connectionId']
+
+    try:
+        # The key name must match the table's partition key, 'connectionId'
+        table.delete_item(Key={'connectionId': connection_id})
+        return {'statusCode': 200}
+    except Exception as e:
+        print(f"Error deleting connection: {e}")
+        return {'statusCode': 500, 'body': str(e)}
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/lambda/message.py b/lib/workload/stateless/stacks/client-websocket-conn/lambda/message.py
new file mode 100644
index 000000000..94d4204f5
--- /dev/null
+++ b/lib/workload/stateless/stacks/client-websocket-conn/lambda/message.py
@@ -0,0 +1,110 @@
+import boto3
+import json
+import os
+import uuid
+from datetime import datetime, timedelta
+
+def lambda_handler(event, context):
+    """Store the incoming message and broadcast it to every open connection.
+
+    The payload is the JSON `body` of the event when it has one (API Gateway
+    WebSocket route), otherwise the event itself (direct invocation).
+    """
+    assert 'CONNECTION_TABLE' in os.environ, "CONNECTION_TABLE environment variable is not set"
+    assert 'WEBSOCKET_API_ENDPOINT' in os.environ, "WEBSOCKET_API_ENDPOINT environment variable is not set"
+    assert 'MESSAGE_HISTORY_TABLE' in os.environ, "MESSAGE_HISTORY_TABLE environment variable is not set"
+
+    # Get environment variables
+    connections_table_name = os.environ['CONNECTION_TABLE']
+    message_history_table_name = os.environ['MESSAGE_HISTORY_TABLE']
+
+    # connections URL with replace wss:// header to https
+    websocket_endpoint = os.environ['WEBSOCKET_API_ENDPOINT'].replace('wss://', 'https://')
+
+    dynamodb = boto3.resource('dynamodb')
+    connections_table = dynamodb.Table(connections_table_name)
+    message_table = dynamodb.Table(message_history_table_name)
+    # Initialize API Gateway client
+    apigw_client = boto3.client('apigatewaymanagementapi',
+                                endpoint_url=websocket_endpoint)
+
+    print(f"Received event: {event}, websocket endpoint: {websocket_endpoint}")
+
+    try:
+        # Initialize response data
+        body = event.get('body')
+        data = json.loads(body) if isinstance(body, str) else event
+        if not isinstance(data, dict):
+            return {
+                'statusCode': 400,
+                'body': json.dumps({'error': 'Request body must be a JSON object'})
+            }
+        response_data = {
+            'type': data.get('type', ''),
+            'message': data.get('message', '')
+        }
+
+        # save message to dynamodb
+        message_id = str(uuid.uuid4())
+        now = datetime.now()
+        message_data = {
+            'messageId': message_id,
+            'data': response_data,
+            'timestamp': now.isoformat(),
+            # epoch seconds, two days ahead, for the table's TTL attribute
+            'ttl': int((now + timedelta(days=2)).timestamp())
+        }
+        try:
+            message_table.put_item(Item=message_data)
+        except Exception as e:
+            print(f"Error saving message to dynamodb: {e}")
+
+        # Broadcast to all connections. A single scan call can return a partial
+        # result, so keep going while the response carries LastEvaluatedKey.
+        connections = []
+        scan_kwargs = {}
+        while True:
+            page = connections_table.scan(**scan_kwargs)
+            connections.extend(page['Items'])
+            if 'LastEvaluatedKey' not in page:
+                break
+            scan_kwargs['ExclusiveStartKey'] = page['LastEvaluatedKey']
+
+        payload = json.dumps(response_data)
+        for connection in connections:
+            connection_id = connection['connectionId']
+            try:
+                apigw_client.post_to_connection(
+                    ConnectionId=connection_id,
+                    Data=payload
+                )
+            except apigw_client.exceptions.GoneException:
+                # Remove stale connection
+                connections_table.delete_item(Key={'connectionId': connection_id})
+            except Exception as e:
+                print(f"Failed to post message to {connection_id}: {e}")
+
+        return {'statusCode': 200}
+
+    except json.JSONDecodeError:
+        return {
+            'statusCode': 400,
+            'body': json.dumps({'error': 'Invalid JSON in request body'})
+        }
+    except KeyError as e:
+        return {
+            'statusCode': 400,
+            'body': json.dumps({'error': f'Missing required field: {str(e)}'})
+        }
+    except Exception as e:
+        print(f"Error: {e}")
+        return {
+            'statusCode': 500,
+            'body': json.dumps({'error': 'Internal server error'})
+        }
+
+
+# test case
+# curl -X POST https://<api-id>.execute-api.<region>.amazonaws.com/Prod/message -H "Content-Type: application/json" -d '{"type": "test", "message": "Hello, world!"}'
+# invoke lambda function from aws console, cmd: aws lambda invoke --function-name <function-name> --payload '{"type": "test", "message": "Hello, world!"}' response.json
+# check cloudwatch logs for response
diff --git a/lib/workload/stateless/stacks/client-websocket-conn/websocket-api-arch.png b/lib/workload/stateless/stacks/client-websocket-conn/websocket-api-arch.png
new file mode 100644
index 000000000..a630e7438
Binary files /dev/null and b/lib/workload/stateless/stacks/client-websocket-conn/websocket-api-arch.png differ
diff --git a/lib/workload/stateless/statelessStackCollectionClass.ts b/lib/workload/stateless/statelessStackCollectionClass.ts
index d10167501..a89721676 100644
--- a/lib/workload/stateless/statelessStackCollectionClass.ts
+++ b/lib/workload/stateless/statelessStackCollectionClass.ts
@@ -81,6 +81,8 @@ import {
 } from './stacks/ora-decompression-manager/deploy';
 import { PgDDStack, PgDDStackProps } from './stacks/pg-dd/deploy/stack';
 
+import { WebSocketApiStackProps, WebSocketApiStack } from './stacks/client-websocket-conn/deploy';
+
 export interface StatelessStackCollectionProps {
   metadataManagerStackProps: MetadataManagerStackProps;
   sequenceRunManagerStackProps: SequenceRunManagerStackProps;
@@ -105,6 +107,7 @@ export interface StatelessStackCollectionProps {
   workflowManagerStackProps: WorkflowManagerStackProps;
   stackyMcStackFaceProps: GlueStackProps;
   fmAnnotatorProps: FMAnnotatorConfigurableProps;
+  websocketApiStackProps: WebSocketApiStackProps;
   pgDDProps?: PgDDStackProps;
 }
 
@@ -133,6 +136,7 @@ export class StatelessStackCollection {
   readonly workflowManagerStack: Stack;
   readonly stackyMcStackFaceStack: Stack;
   readonly fmAnnotator: Stack;
+  readonly websocketApiStack: Stack;
   readonly pgDDStack: Stack;
 
   constructor(
@@ -313,6 +317,10 @@ export class StatelessStackCollection {
       domainName: fileManagerStack.domainName,
     });
 
+    this.websocketApiStack = new WebSocketApiStack(scope, 'WebSocketApiStack', {
+      ...this.createTemplateProps(env, 'WebSocketApiStack'),
+      ...statelessConfiguration.websocketApiStackProps,
+    });
     if (statelessConfiguration.pgDDProps) {
       this.pgDDStack = new PgDDStack(scope, 'PgDDStack', {
         ...this.createTemplateProps(env, 'PgDDStack'),