Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add websocket component stack #710

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion config/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ import {
getOraCompressionIcav2PipelineTableStackProps,
} from './stacks/oraCompressionPipelineManager';
import { getOraDecompressionManagerStackProps } from './stacks/oraDecompressionPipelineManager';
import { getWebSocketApiStackProps } from './stacks/clientWebsocketApi';
import { getPgDDProps } from './stacks/pgDD';

interface EnvironmentConfig {
name: string;
region: string;
Expand Down Expand Up @@ -131,6 +131,7 @@ export const getEnvironmentConfig = (stage: AppStage): EnvironmentConfig | null
workflowManagerStackProps: getWorkflowManagerStackProps(stage),
stackyMcStackFaceProps: getGlueStackProps(stage),
fmAnnotatorProps: getFmAnnotatorProps(),
websocketApiStackProps: getWebSocketApiStackProps(stage),
pgDDProps: getPgDDProps(stage),
},
};
Expand Down
20 changes: 20 additions & 0 deletions config/stacks/clientWebsocketApi.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { WebSocketApiStackProps } from '../../lib/workload/stateless/stacks/client-websocket-conn/deploy';
import { AppStage, vpcProps, region, cognitoUserPoolIdParameterName } from '../constants';

export const getWebSocketApiStackProps = (stage: AppStage): WebSocketApiStackProps => {
return {
connectionTableName: 'OrcaBusClientWebsocketApiConnectionTable',
messageHistoryTableName: 'OrcaBusClientWebsocketApiMessageHistoryTable',
websocketApigatewayName: `OrcaBusClientWebsocketApi${stage}`,
lambdaSecurityGroupName: 'OrcaBusClientWebsocketApiSecurityGroup',
connectionFunctionName: 'websocketApiConnect',
disconnectFunctionName: 'websocketApiDisconnect',
messageFunctionName: 'websocketApiMessage',

vpcProps: vpcProps,
websocketApiEndpointParameterName: `/orcabus/client-websocket-api-endpoint`,
websocketStageName: stage,
cognitoRegion: region,
cognitoUserPoolIdParameterName: cognitoUserPoolIdParameterName,
};
};
30 changes: 30 additions & 0 deletions lib/workload/stateless/stacks/client-websocket-conn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# WebSocket API Stack

A serverless WebSocket API implementation using AWS CDK, API Gateway WebSocket APIs, Lambda, and DynamoDB for real-time communication.

## Architecture

![Architecture Diagram](./websocket-api-arch.png)

### Components

- **API Gateway WebSocket API**: Handles WebSocket connections
- **Lambda Functions**: Process WebSocket events
- **DynamoDB**: Stores connection information

## Features

- Real-time bidirectional communication
- Connection management
- Message broadcasting
- Secure VPC deployment
- Automatic scaling
- Connection cleanup

## Prerequisites

- AWS CDK CLI
- Node.js & npm
- Python 3.12
- AWS Account and configured credentials
- VPC with private subnets
210 changes: 210 additions & 0 deletions lib/workload/stateless/stacks/client-websocket-conn/deploy/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
import { Stack, RemovalPolicy, StackProps, Duration } from 'aws-cdk-lib';
import { Table, AttributeType } from 'aws-cdk-lib/aws-dynamodb';
import { Vpc, SecurityGroup, VpcLookupOptions, IVpc, ISecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { WebSocketApi, WebSocketStage } from 'aws-cdk-lib/aws-apigatewayv2';
import { WebSocketLambdaIntegration } from 'aws-cdk-lib/aws-apigatewayv2-integrations';
import { PolicyStatement } from 'aws-cdk-lib/aws-iam';
import { PythonFunction, PythonLayerVersion } from '@aws-cdk/aws-lambda-python-alpha';
import { Runtime, Architecture, LayerVersion } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs';
import * as path from 'path';
import { WebSocketLambdaAuthorizer } from 'aws-cdk-lib/aws-apigatewayv2-authorizers';
import { StringParameter } from 'aws-cdk-lib/aws-ssm';

export interface WebSocketApiStackProps extends StackProps {
// DynamoDB and Lambda configuration
connectionTableName: string;
websocketApigatewayName: string;
connectionFunctionName: string;
disconnectFunctionName: string;
messageFunctionName: string;
messageHistoryTableName: string;

// Parameter name for the WebSocket API endpoint
websocketApiEndpointParameterName: string;
websocketStageName: string;

// Cognito configuration for the authorizer
cognitoRegion: string;
cognitoUserPoolIdParameterName: string;
lambdaSecurityGroupName: string;
vpcProps: VpcLookupOptions;
}

export class WebSocketApiStack extends Stack {
private readonly lambdaRuntimePythonVersion = Runtime.PYTHON_3_12;
private readonly props: WebSocketApiStackProps;
private vpc: IVpc;
private lambdaSG: ISecurityGroup;

constructor(scope: Construct, id: string, props: WebSocketApiStackProps) {
super(scope, id, props);

this.props = props;

this.vpc = Vpc.fromLookup(this, 'MainVpc', props.vpcProps);
this.lambdaSG = SecurityGroup.fromLookupByName(
this,
'LambdaSecurityGroup',
props.lambdaSecurityGroupName,
this.vpc
);

// DynamoDB Table for storing connection IDs
const connectionTable = new Table(this, 'WebSocketApiConnections', {
tableName: props.connectionTableName,
partitionKey: {
name: 'connectionId',
type: AttributeType.STRING,
},
removalPolicy: RemovalPolicy.DESTROY, // For demo purposes, not recommended for production
});

//DynamoDB Table for message history
const messageHistoryTable = new Table(this, 'WebSocketApiMessageHistory', {
tableName: props.messageHistoryTableName,
partitionKey: {
name: 'messageId',
type: AttributeType.STRING,
},
timeToLiveAttribute: 'ttl', // Enable TTL
removalPolicy: RemovalPolicy.DESTROY,
});

// Lambda function for $connect
const connectHandler = this.createPythonFunction(props.connectionFunctionName, {
index: 'connect.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
});

// Lambda function for $disconnect
const disconnectHandler = this.createPythonFunction(props.disconnectFunctionName, {
index: 'disconnect.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
});

// Lambda function for $default (broadcast messages)
const messageHandler = this.createPythonFunction(props.messageFunctionName, {
index: 'message.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
});

// build layer from deps
const authLayer = new PythonLayerVersion(this, 'BaseLayer', {
entry: path.join(__dirname, '../deps'),
compatibleRuntimes: [this.lambdaRuntimePythonVersion],
compatibleArchitectures: [Architecture.ARM_64],
});

const userPoolId = StringParameter.fromStringParameterName(
this,
'CognitoUserPoolIdParameter',
props.cognitoUserPoolIdParameterName
).stringValue;

// authorizer function to check the client token based on the JWT token
const connectAuthorizer = this.createPythonFunction('AuthHandler', {
index: 'auth.py',
handler: 'lambda_handler',
timeout: Duration.minutes(2),
environment: {
COGNITO_REGION: props.cognitoRegion,
COGNITO_USER_POOL_ID: userPoolId,
},
layers: [authLayer],
});

// Grant permissions to Lambda functions
connectionTable.grantReadWriteData(connectHandler);
connectionTable.grantReadWriteData(disconnectHandler);
connectionTable.grantReadWriteData(messageHandler);
// messageHistoryTable.grantReadData(connectHandler);
messageHistoryTable.grantReadWriteData(messageHandler);

// WebSocket API
const api = new WebSocketApi(this, props.websocketApigatewayName, {
apiName: props.websocketApigatewayName,
description: 'WebSocket API for the app notifications',
connectRouteOptions: {
integration: new WebSocketLambdaIntegration('ConnectIntegration', connectHandler),
// FIXME: uncomment this when auth is implemented
// authorizer: new WebSocketLambdaAuthorizer(
// "ConnectAuthorizer",
// connectAuthorizer,
// {
// authorizerName: "ConnectAuthorizer",
// identitySource: [
// "route.request.header.Authorization",
// "route.request.querystring.Authorization",
// ],
// }
// ),
},
disconnectRouteOptions: {
integration: new WebSocketLambdaIntegration('DisconnectIntegration', disconnectHandler),
},
defaultRouteOptions: {
integration: new WebSocketLambdaIntegration('DefaultIntegration', messageHandler),
},
});

api.addRoute('sendMessage', {
integration: new WebSocketLambdaIntegration('SendMessageIntegration', messageHandler),
});

// Deploy WebSocket API to a stage
const stage = new WebSocketStage(this, 'WebSocketStage', {
webSocketApi: api,
stageName: props.websocketStageName,
autoDeploy: true,
});

// Create the WebSocket API endpoint URL
const webSocketApiEndpoint = `${api.apiEndpoint}/${stage.stageName}`;

// save this url into the parameter store for the client to use
new StringParameter(this, 'WebSocketApiEndpoint', {
parameterName: props.websocketApiEndpointParameterName,
description: 'The endpoint URL for the WebSocket API',
stringValue: webSocketApiEndpoint,
});

const commonEnvironment = {
CONNECTION_TABLE: connectionTable.tableName,
MESSAGE_HISTORY_TABLE: messageHistoryTable.tableName,
WEBSOCKET_API_ENDPOINT: webSocketApiEndpoint,
};

// Add environment variables individually
for (const [key, value] of Object.entries(commonEnvironment)) {
connectHandler.addEnvironment(key, value);
disconnectHandler.addEnvironment(key, value);
messageHandler.addEnvironment(key, value);
}

// Grant permissions to the message handler
messageHandler.addToRolePolicy(
new PolicyStatement({
actions: ['execute-api:ManageConnections'],
resources: [
`arn:aws:execute-api:${this.region}:${this.account}:${api.apiId}/dev/POST/@connections/*`,
],
})
);
}

private createPythonFunction(name: string, props: object): PythonFunction {
return new PythonFunction(this, name, {
entry: path.join(__dirname, '../lambda'),
runtime: this.lambdaRuntimePythonVersion,
securityGroups: [this.lambdaSG],
vpc: this.vpc,
vpcSubnets: { subnets: this.vpc.privateSubnets },
architecture: Architecture.ARM_64,
...props,
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
PyJWT==2.8.0
requests==2.31.0
71 changes: 71 additions & 0 deletions lib/workload/stateless/stacks/client-websocket-conn/lambda/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os
import logging
import jwt
import requests
from typing import Dict, Any

logger = logging.getLogger()
logger.setLevel(logging.INFO)

def generate_policy(principal_id, effect, resource):
return {
'principalId': principal_id,
'policyDocument': {
'Version': '2012-10-17',
'Statement': [{
'Action': 'execute-api:Invoke',
'Effect': effect,
'Resource': resource
}]
}
}

def get_public_key():
"""Get Cognito public key for JWT verification"""
url = f'https://cognito-idp.{COGNITO_REGION}.amazonaws.com/{COGNITO_USER_POOL_ID}/.well-known/jwks.json'
try:
response = requests.get(url)
return response.json()['keys'][0] # Get the first key
except Exception as e:
logger.error(f"Error getting public key: {str(e)}")
raise

def lambda_handler(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
"""Simple Lambda authorizer for WebSocket"""
logger.info("WebSocket authorization request")
# Get environment variables
assert 'COGNITO_USER_POOL_ID' in os.environ, "COGNITO_USER_POOL_ID is not set"
assert 'COGNITO_REGION' in os.environ, "COGNITO_REGION is not set"

COGNITO_USER_POOL_ID = os.environ['COGNITO_USER_POOL_ID']
COGNITO_REGION = os.environ.get('COGNITO_REGION', 'ap-southeast-2')
try:
# Get token from headers
# Check both header and querystring
auth_token = None
if event.get('headers', {}).get('Authorization'):
auth_token = event['headers']['Authorization']
elif event.get('queryStringParameters', {}).get('Authorization'):
auth_token = event['queryStringParameters']['Authorization']

if not auth_token:
return generate_policy('user', 'Deny', event['methodArn'])


# Get public key
public_key = get_public_key()

# Verify token
decoded = jwt.decode(
auth_token,
public_key,
algorithms=['RS256'],
issuer=f'https://cognito-idp.{COGNITO_REGION}.amazonaws.com/{COGNITO_USER_POOL_ID}'
)

# Generate allow policy
return generate_policy(decoded['sub'], 'Allow', event['methodArn'])
except Exception as e:
logger.error(f"Authorization failed: {str(e)}")
# Return deny policy
return generate_policy('unauthorized', 'Deny', event['methodArn'])
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import boto3
import os

def lambda_handler(event, context):
# Get table names from environment variables
assert 'CONNECTION_TABLE' in os.environ, "CONNECTION_TABLE environment variable is not set"
connections_table_name = os.environ['CONNECTION_TABLE']

dynamodb = boto3.resource('dynamodb')
connections_table = dynamodb.Table(connections_table_name)

connection_id = event['requestContext']['connectionId']

try:
# Store connection
connections_table.put_item(
Item={'connectionId': connection_id}
)
return {'statusCode': 200}
except Exception as e:
print(f"Error storing connection: {e}")
return {'statusCode': 500, 'body': str(e)}
Loading