Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: input component issues when pages opened by firefox #505

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/infrastructure/lib/api/model-management.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ export class ModelApi extends Construct {
modelLambda.addToRolePolicy(this.iamHelper.cfnStatement);
modelLambda.addToRolePolicy(this.iamHelper.stsStatement);
modelLambda.addToRolePolicy(this.iamHelper.cfnStatement);
modelLambda.addToRolePolicy(this.iamHelper.serviceQuotaStatement);
modelLambda.addToRolePolicy(this.iamHelper.sagemakerModelManagementStatement);

// API Gateway Lambda Integration to manage model
const lambdaModelIntegration = new apigw.LambdaIntegration(modelLambda, {
Expand Down
56 changes: 55 additions & 1 deletion source/infrastructure/lib/shared/iam-helper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ export class IAMHelper extends Construct {
public secretStatement: PolicyStatement;
public codePipelineStatement: PolicyStatement;
public cfnStatement: PolicyStatement;
public serviceQuotaStatement: PolicyStatement;
public sagemakerModelManagementStatement: PolicyStatement;

public createPolicyStatement(actions: string[], resources: string[]) {
return new PolicyStatement({
Expand Down Expand Up @@ -79,7 +81,7 @@ export class IAMHelper extends Construct {
"sagemaker:CreateEndpoint",
"sagemaker:CreateEndpointConfig",
"sagemaker:InvokeEndpointAsync",
"sagemaker:UpdateEndpointWeightsAndCapacities",
"sagemaker:UpdateEndpointWeightsAndCapacities"
],
[`arn:${Aws.PARTITION}:sagemaker:${Aws.REGION}:${Aws.ACCOUNT_ID}:endpoint/*`],
);
Expand Down Expand Up @@ -191,5 +193,57 @@ export class IAMHelper extends Construct {
],
["*"],
);
this.sagemakerModelManagementStatement = this.createPolicyStatement(
[
"sagemaker:List*",
"sagemaker:ListEndpoints",
"sagemaker:DeleteModel",
"sagemaker:DeleteEndpoint",
"sagemaker:DescribeEndpoint",
"sagemaker:DeleteEndpointConfig",
"sagemaker:DescribeEndpointConfig",
"sagemaker:InvokeEndpoint",
"sagemaker:CreateModel",
"sagemaker:CreateEndpoint",
"sagemaker:CreateEndpointConfig",
"sagemaker:InvokeEndpointAsync",
"sagemaker:UpdateEndpointWeightsAndCapacities"
],
["*"],
);
this.serviceQuotaStatement = this.createPolicyStatement(
[
"autoscaling:DescribeAccountLimits",
"cloudformation:DescribeAccountLimits",
"cloudwatch:DescribeAlarmsForMetric",
"cloudwatch:DescribeAlarms",
"cloudwatch:GetMetricData",
"cloudwatch:GetMetricStatistics",
"dynamodb:DescribeLimits",
"elasticloadbalancing:DescribeAccountLimits",
"iam:GetAccountSummary",
"kinesis:DescribeLimits",
"organizations:DescribeAccount",
"organizations:DescribeOrganization",
"organizations:ListAWSServiceAccessForOrganization",
"rds:DescribeAccountAttributes",
"route53:GetAccountLimit",
"tag:GetTagKeys",
"tag:GetTagValues",
"servicequotas:GetAssociationForServiceQuotaTemplate",
"servicequotas:GetAWSDefaultServiceQuota",
"servicequotas:GetRequestedServiceQuotaChange",
"servicequotas:GetServiceQuota",
"servicequotas:GetServiceQuotaIncreaseRequestFromTemplate",
"servicequotas:ListAWSDefaultServiceQuotas",
"servicequotas:ListRequestedServiceQuotaChangeHistory",
"servicequotas:ListRequestedServiceQuotaChangeHistoryByQuota",
"servicequotas:ListServices",
"servicequotas:ListServiceQuotas",
"servicequotas:ListServiceQuotaIncreaseRequestsInTemplate",
"servicequotas:ListTagsForResource"
],
["*"],
);
}
}
234 changes: 141 additions & 93 deletions source/lambda/etl/sfn_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,124 +2,172 @@
import logging
import os
from datetime import datetime, timezone
from typing import Dict, List, TypedDict

import boto3
from chatbot_management import create_chatbot
from constant import ExecutionStatus, IndexType, UiStatus
from utils.parameter_utils import get_query_parameter

# Initialize AWS resources once
# The Step Functions client, the DynamoDB resource and the table handles below
# are created when the module is imported, so every call in this module reuses
# the same objects.
client = boto3.client("stepfunctions")
dynamodb = boto3.resource("dynamodb")
# os.environ.get() yields None for an unset variable; nothing here checks that
# the table names are actually present before they are handed to Table().
execution_table = dynamodb.Table(os.environ.get("EXECUTION_TABLE_NAME"))
index_table = dynamodb.Table(os.environ.get("INDEX_TABLE_NAME"))
chatbot_table = dynamodb.Table(os.environ.get("CHATBOT_TABLE_NAME"))
model_table = dynamodb.Table(os.environ.get("MODEL_TABLE_NAME"))
embedding_endpoint = os.environ.get("EMBEDDING_ENDPOINT")
# NOTE(review): index_table is bound a second time with the identical
# expression; the repeat looks redundant (possibly a leftover) - confirm.
index_table = dynamodb.Table(os.environ.get("INDEX_TABLE_NAME"))
# ARN of the state machine passed to client.start_execution().
sfn_arn = os.environ.get("SFN_ARN")
# NOTE(review): evaluated once at import, not per call, so every use of
# create_time sees the same timestamp - presumably not intended for
# per-request records; confirm.
create_time = str(datetime.now(timezone.utc))


# Consolidate constants at the top
# Headers attached to the HTTP responses built in this module. Both
# Access-Control-Allow-Origin and Access-Control-Allow-Methods are set to the
# wildcard "*".
CORS_HEADERS = {
"Content-Type": "application/json",
"Access-Control-Allow-Headers": "Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "*",
}

# Initialize logging at the top level
# getLogger() without a name returns the root logger; its level is set to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)


def handler(event, context):
# Check the event for possible S3 created event
input_payload = {}
logger.info(event)
resp_header = {
"Content-Type": "application/json",
"Access-Control-Allow-Headers": "Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "*",
def validate_index_type(index_type: str) -> bool:
    """Report whether ``index_type`` names a supported index kind.

    A value is accepted when it equals the ``.value`` of ``IndexType.QD``,
    ``IndexType.QQ`` or ``IndexType.INTENTION``; anything else (including
    ``None``) yields ``False``.
    """
    supported_members = (IndexType.QD, IndexType.QQ, IndexType.INTENTION)
    return any(member.value == index_type for member in supported_members)


def get_etl_info(group_name: str, chatbot_id: str, index_type: str):
    """
    Look up the ingestion settings for one chatbot and index type.

    Reads the chatbot item (keyed by groupName/chatbotId) and its embedding
    model item (keyed by groupName and "<chatbot_id>-embedding"), then picks
    the first index registered under ``index_type``. The result is used to
    perform knowledge ingestion to opensearch.

    Returns: Tuple of (index_id, model_type, model_endpoint)
    Raises: ValueError if the chatbot or model item is missing, or if the
        chatbot has no indices for ``index_type``.
    """
    chatbot_key = {"groupName": group_name, "chatbotId": chatbot_id}
    chatbot_record = chatbot_table.get_item(Key=chatbot_key).get("Item")

    model_key = {"groupName": group_name, "modelId": f"{chatbot_id}-embedding"}
    model_record = model_table.get_item(Key=model_key).get("Item")

    # Both lookups must succeed before anything else is derived from them.
    if not chatbot_record or not model_record:
        raise ValueError("Chatbot or model not found")

    model_params = model_record.get("parameter", {})

    indices_by_type = chatbot_record.get("indexIds", {})
    indices_for_type = indices_by_type.get(index_type, {}).get("value", {})
    if not indices_for_type:
        raise ValueError("No indices found for the given index type")

    # Only the first stored index id is used for this index type.
    first_index_id = next(iter(indices_for_type.values()))
    model_type = model_params.get("ModelType")
    model_endpoint = model_params.get("ModelEndpoint")
    return (first_index_id, model_type, model_endpoint)


def create_execution_record(
    execution_id: str, input_body: Dict, sfn_execution_id: str
) -> None:
    """Create execution record in DynamoDB.

    Builds a new item from a copy of ``input_body`` plus the execution
    bookkeeping fields and writes it to ``execution_table``. ``input_body``
    itself is not modified.

    Args:
        execution_id: Stored as ``executionId`` on the item.
        input_body: Request fields copied onto the item; a ``tableItemId``
            key, if present, is not persisted.
        sfn_execution_id: Stored as ``sfnExecutionId`` on the item.
    """
    execution_record = {
        **input_body,
        "sfnExecutionId": sfn_execution_id,
        "executionStatus": ExecutionStatus.IN_PROGRESS.value,
        "executionId": execution_id,
        "uiStatus": UiStatus.ACTIVE.value,
        # Timestamp is taken per call so each record gets its own create time.
        "createTime": str(datetime.now(timezone.utc)),
    }
    # tableItemId is not stored on the execution item. pop() with a default
    # tolerates bodies that never carried the key, where a bare ``del`` would
    # raise KeyError and abort the write.
    execution_record.pop("tableItemId", None)
    execution_table.put_item(Item=execution_record)


def handler(event: Dict, context) -> Dict:
"""Main Lambda handler for ETL operations."""
logger.info(event)

authorizer_type = event["requestContext"].get("authorizer", {}).get("authorizerType")
if authorizer_type == "lambda_authorizer":
claims = json.loads(event["requestContext"]["authorizer"]["claims"])
try:
# Validate and extract authorization
authorizer = event["requestContext"].get("authorizer", {})
if authorizer.get("authorizerType") != "lambda_authorizer":
raise ValueError("Invalid authorizer type")

claims = json.loads(authorizer.get("claims", {}))
if "use_api_key" in claims:
group_name = get_query_parameter(event, "GroupName", "Admin")
cognito_groups_list = [group_name]
else:
cognito_groups = claims["cognito:groups"]
cognito_groups_list = cognito_groups.split(",")
else:
logger.error("Invalid authorizer type")
cognito_groups_list = claims["cognito:groups"].split(",")

# Process input
input_body = json.loads(event["body"])
index_type = input_body.get("indexType")

if not validate_index_type(index_type):
return {
"statusCode": 400,
"headers": CORS_HEADERS,
"body": f"Invalid indexType, valid values are {', '.join([t.value for t in IndexType])}",
}

group_name = input_body.get("groupName") or (
"Admin"
if "Admin" in cognito_groups_list
else cognito_groups_list[0]
)
chatbot_id = input_body.get("chatbotId", group_name.lower())
index_id, embedding_model_type, embedding_endpoint = get_etl_info(
group_name, chatbot_id, index_type
)

# Update input body with processed values
input_body.update(
{
"chatbotId": chatbot_id,
"groupName": group_name,
"tableItemId": context.aws_request_id,
"indexId": index_id,
"embeddingModelType": embedding_model_type,
"embeddingEndpoint": embedding_endpoint,
}
)

# Start step function and create execution record
sfn_response = client.start_execution(
stateMachineArn=sfn_arn, input=json.dumps(input_body)
)

execution_id = context.aws_request_id
create_execution_record(
execution_id,
input_body,
sfn_response["executionArn"].split(":")[-1],
)

return {
"statusCode": 403,
"headers": resp_header,
"body": json.dumps({"error": "Invalid authorizer type"}),
"statusCode": 200,
"headers": CORS_HEADERS,
"body": json.dumps(
{
"execution_id": execution_id,
"step_function_arn": sfn_response["executionArn"],
"input_payload": input_body,
}
),
}

# Parse the body from the event object
input_body = json.loads(event["body"])
if "indexType" not in input_body or input_body["indexType"] not in [
IndexType.QD.value,
IndexType.QQ.value,
IndexType.INTENTION.value,
]:
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
return {
"statusCode": 400,
"headers": resp_header,
"body": (
f"Invalid indexType, valid values are "
f"{IndexType.QD.value}, {IndexType.QQ.value}, "
f"{IndexType.INTENTION.value}"
),
"statusCode": 500,
"headers": CORS_HEADERS,
"body": json.dumps({"error": str(e)}),
}
index_type = input_body["indexType"]
group_name = "Admin" if "Admin" in cognito_groups_list else cognito_groups_list[0]
chatbot_id = input_body.get("chatbotId", group_name.lower())

if "indexId" in input_body:
index_id = input_body["indexId"]
else:
# Use default index id if not specified in the request
index_id = f"{chatbot_id}-qd-default"
if index_type == IndexType.QQ.value:
index_id = f"{chatbot_id}-qq-default"
elif index_type == IndexType.INTENTION.value:
index_id = f"{chatbot_id}-intention-default"

if "tag" in input_body:
tag = input_body["tag"]
else:
tag = index_id

input_body["indexId"] = index_id
input_body["groupName"] = group_name if "groupName" not in input_body else input_body["groupName"]
chatbot_event_body = input_body
chatbot_event_body["group_name"] = group_name
chatbot_event = {"body": json.dumps(chatbot_event_body)}
chatbot_result = create_chatbot(chatbot_event, group_name)

input_body["tableItemId"] = context.aws_request_id
input_body["chatbotId"] = chatbot_id
input_body["embeddingModelType"] = chatbot_result["modelType"]
input_payload = json.dumps(input_body)
response = client.start_execution(stateMachineArn=sfn_arn, input=input_payload)

# Update execution table item
if "tableItemId" in input_body:
del input_body["tableItemId"]
execution_id = response["executionArn"].split(":")[-1]
input_body["sfnExecutionId"] = execution_id
input_body["executionStatus"] = ExecutionStatus.IN_PROGRESS.value
input_body["indexId"] = index_id
input_body["executionId"] = context.aws_request_id
input_body["uiStatus"] = UiStatus.ACTIVE.value
input_body["createTime"] = create_time

execution_table.put_item(Item=input_body)

return {
"statusCode": 200,
"headers": resp_header,
"body": json.dumps(
{
"execution_id": context.aws_request_id,
"step_function_arn": response["executionArn"],
"input_payload": input_payload,
}
),
}
Loading
Loading