Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: merge dev to main #467

Merged
merged 25 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
43a3411
feat: add tour guide
aws-cloudfront-extension-bot Nov 20, 2024
197d7e9
chore: update tour guide
NingLu Nov 26, 2024
d865ffb
chore: fomatting api stack
IcyKallen Dec 2, 2024
7ef736b
Merge branch 'dev' into lvn
NingLu Dec 2, 2024
a5b08c5
chore: i18n
NingLu Dec 2, 2024
bb722bd
fix: update text
NingLu Dec 2, 2024
ab385e4
chore: update chat history api
IcyKallen Dec 2, 2024
3dc929b
chore: format front end code
IcyKallen Dec 2, 2024
9e6bf36
feat: naive version of clicking feedback
IcyKallen Dec 2, 2024
49aeef2
feat: remove rating for first sentence and move rating to right side …
IcyKallen Dec 2, 2024
2130ebe
feat: support withdraw feedback on frontend
IcyKallen Dec 2, 2024
d7bd07f
Merge pull request #460 from aws-samples/lvn
530051970 Dec 3, 2024
71ee4ee
Merge pull request #461 from aws-samples/xuhan-dev
NingLu Dec 3, 2024
a2c6565
feat: support showing markdown image
NingLu Dec 3, 2024
d831239
Merge pull request #462 from aws-samples/lvn
NingLu Dec 3, 2024
e1ad131
bug fix: miss index when upload intention and update test IO
530051970 Dec 3, 2024
6a364de
reslove conflict
530051970 Dec 3, 2024
facea95
Merge pull request #463 from aws-samples/hubin-daily
NingLu Dec 3, 2024
b113abe
delete unused code
530051970 Dec 3, 2024
f61d8f7
Merge branch 'dev' of https://github.com/aws-samples/Intelli-Agent in…
530051970 Dec 3, 2024
222e2d4
Merge pull request #464 from aws-samples/hubin-daily
NingLu Dec 3, 2024
2348745
fix: miss index when upload intention
530051970 Dec 4, 2024
bb5444f
Merge branch 'dev' of https://github.com/aws-samples/Intelli-Agent in…
530051970 Dec 4, 2024
96ffbb0
fix: miss index when upload intention
530051970 Dec 4, 2024
0086d9d
Merge pull request #465 from aws-samples/hubin-daily
NingLu Dec 4, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
424 changes: 227 additions & 197 deletions source/infrastructure/lib/api/api-stack.ts

Large diffs are not rendered by default.

300 changes: 300 additions & 0 deletions source/lambda/chat_history/chat_history_management.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,300 @@
"""
Lambda function for managing chat history operations.
Provides REST API endpoints for listing sessions, messages,
and managing message ratings.
"""

import json
import logging
import os
from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from typing import Any, Dict, Optional

import boto3
from boto3.dynamodb.conditions import Key
from botocore.paginate import TokenEncoder

# Configure logging
logger = logging.getLogger()
logger.setLevel(logging.INFO)


@dataclass
class AwsResources:
"""Centralized AWS resource management"""

dynamodb = boto3.resource("dynamodb")
dynamodb_client = boto3.client("dynamodb")

def __post_init__(self):
# Initialize DynamoDB tables
self.sessions_table = self.dynamodb.Table(Config.SESSIONS_TABLE_NAME)
self.messages_table = self.dynamodb.Table(Config.MESSAGES_TABLE_NAME)


class Config:
"""Configuration constants"""

SESSIONS_TABLE_NAME = os.environ["SESSIONS_TABLE_NAME"]
MESSAGES_TABLE_NAME = os.environ["MESSAGES_TABLE_NAME"]
SESSIONS_BY_TIMESTAMP_INDEX = os.environ["SESSIONS_BY_TIMESTAMP_INDEX_NAME"]
MESSAGES_BY_SESSION_ID_INDEX = os.environ["MESSAGES_BY_SESSION_ID_INDEX_NAME"]
DEFAULT_PAGE_SIZE = 50
DEFAULT_MAX_ITEMS = 50

CORS_HEADERS = {
"Content-Type": "application/json",
"Access-Control-Allow-Headers": "Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token",
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "*",
}


# Initialize AWS resources
aws_resources = AwsResources()
token_encoder = TokenEncoder()


class DecimalEncoder(json.JSONEncoder):
"""Custom JSON encoder for Decimal types"""

def default(self, o):
if isinstance(o, Decimal):
return str(o)
return super(DecimalEncoder, self).default(o)


class PaginationConfig:
@staticmethod
def get_query_parameter(event: Dict[str, Any], parameter_name: str, default_value: Any = None) -> Any:
"""Extract query parameter from event with default value"""
if event.get("queryStringParameters") and parameter_name in event["queryStringParameters"]:
return event["queryStringParameters"][parameter_name]
return default_value

@classmethod
def get_pagination_config(cls, event: Dict[str, Any]) -> Dict[str, Any]:
"""Build pagination configuration from event parameters"""
return {
"MaxItems": int(cls.get_query_parameter(event, "max_items", Config.DEFAULT_MAX_ITEMS)),
"PageSize": int(cls.get_query_parameter(event, "page_size", Config.DEFAULT_PAGE_SIZE)),
"StartingToken": cls.get_query_parameter(event, "starting_token"),
}


class ChatHistoryManager:
"""Handles chat history related database operations"""

@staticmethod
def get_session(session_id: str, user_id: str) -> Optional[Dict]:
"""Retrieve session details from DynamoDB"""
response = aws_resources.sessions_table.get_item(Key={"sessionId": session_id, "userId": user_id})
return response.get("Item")

@staticmethod
def get_message(message_id: str, session_id: str) -> Optional[Dict]:
"""Retrieve message details from DynamoDB"""
response = aws_resources.messages_table.get_item(Key={"messageId": message_id, "sessionId": session_id})
return response.get("Item")

@staticmethod
def list_sessions(user_id: str, pagination_config: Dict[str, Any]) -> Dict[str, Any]:
"""List sessions for a user with pagination"""
paginator = aws_resources.dynamodb_client.get_paginator("query")

response_iterator = paginator.paginate(
TableName=Config.SESSIONS_TABLE_NAME,
IndexName=Config.SESSIONS_BY_TIMESTAMP_INDEX,
KeyConditionExpression="userId = :user_id",
ExpressionAttributeValues={":user_id": {"S": user_id}},
ScanIndexForward=False,
PaginationConfig=pagination_config,
)

return ChatHistoryManager._process_paginated_response(
response_iterator,
["sessionId", "userId", "createTimestamp", "latestQuestion"],
pagination_config=pagination_config,
)

@staticmethod
def list_messages(session_id: str, pagination_config: Dict[str, Any]) -> Dict[str, Any]:
"""List messages for a session with pagination"""
paginator = aws_resources.dynamodb_client.get_paginator("query")

response_iterator = paginator.paginate(
TableName=Config.MESSAGES_TABLE_NAME,
IndexName=Config.MESSAGES_BY_SESSION_ID_INDEX,
KeyConditionExpression="sessionId = :session_id",
ExpressionAttributeValues={":session_id": {"S": session_id}},
ScanIndexForward=False,
PaginationConfig=pagination_config,
)

return ChatHistoryManager._process_paginated_response(
response_iterator,
["messageId", "role", "content", "createTimestamp"],
pagination_config=pagination_config,
is_messages_list=True,
)

@staticmethod
def _process_paginated_response(
response_iterator, keys: list, pagination_config: Dict[str, Any] = None, is_messages_list: bool = False
) -> Dict[str, Any]:
"""Process paginated responses from DynamoDB"""
output = {}
processed_items = []

for page in response_iterator:
items = page["Items"]

for item in items:
processed_item = {key: item.get(key, {"S": ""})["S"] for key in keys}
# special handling for AI messages while listing messages
if is_messages_list and item.get("role", {}).get("S") == "ai":
processed_item["additional_kwargs"] = json.loads(item["additional_kwargs"]["S"])
processed_items.append(processed_item)

if "LastEvaluatedKey" in page:
output["LastEvaluatedKey"] = token_encoder.encode({"ExclusiveStartKey": page["LastEvaluatedKey"]})
break

# Sort based on createTimestamp
# For sessions list: descending order (newest first)
# For messages list: ascending order (oldest first)
if "createTimestamp" in keys:
processed_items.sort(
key=lambda x: x["createTimestamp"],
reverse=not is_messages_list, # False for messages (ascending), True for sessions (descending)
)

output["Items"] = processed_items
output["Config"] = pagination_config
output["Count"] = len(processed_items)
return output

@staticmethod
def add_feedback(
session_id: str, user_id: str, message_id: str, feedback_type: str, feedback_reason: str, suggest_message: Dict
) -> Dict[str, Any]:
"""Add feedback to a message"""
# First verify the session belongs to the user
session = ChatHistoryManager.get_session(session_id, user_id)
if not session:
return {"added": False, "error": "Session not found or unauthorized"}

message = ChatHistoryManager.get_message(message_id, session_id)
if not message:
return {"added": False, "error": "Message not found"}

try:
current_timestamp = datetime.utcnow().isoformat() + "Z"

# Update message with feedback
aws_resources.messages_table.update_item(
Key={"messageId": message_id, "sessionId": session_id},
UpdateExpression="SET feedbackType = :ft, feedbackReason = :fr, suggestMessage = :sm, lastModifiedTimestamp = :t",
ExpressionAttributeValues={
":ft": feedback_type,
":fr": feedback_reason,
":sm": suggest_message,
":t": current_timestamp,
},
)

# Update session last modified time
aws_resources.sessions_table.update_item(
Key={"sessionId": session_id, "userId": user_id},
UpdateExpression="SET lastModifiedTimestamp = :t",
ExpressionAttributeValues={":t": current_timestamp},
)

return {"added": True}
except Exception as e:
logger.error("Error adding feedback: %s", str(e))
return {"added": False, "error": str(e)}


class ApiResponse:
"""Standardized API response handler"""

@staticmethod
def success(data: Any, status_code: int = 200) -> Dict:
return {"statusCode": status_code, "headers": Config.CORS_HEADERS, "body": json.dumps(data, cls=DecimalEncoder)}

@staticmethod
def error(message: str, status_code: int = 500) -> Dict:
logger.error("Error: %s", message)
return {"statusCode": status_code, "headers": Config.CORS_HEADERS, "body": json.dumps({"error": str(message)})}


class ApiHandler:
"""API endpoint handlers"""

@staticmethod
def list_sessions(event: Dict) -> Dict:
"""Handle GET /chat-history/sessions endpoint"""
try:
claims = json.loads(event["requestContext"]["authorizer"]["claims"])
user_id = claims["cognito:username"]
pagination_config = PaginationConfig.get_pagination_config(event)
result = ChatHistoryManager.list_sessions(user_id, pagination_config)
return ApiResponse.success(result)
except Exception as e:
return ApiResponse.error(str(e))

@staticmethod
def list_messages(event: Dict) -> Dict:
"""Handle GET /chat-history/sessions/{sessionId}/messages endpoint"""
try:
session_id = event["pathParameters"]["sessionId"]
pagination_config = PaginationConfig.get_pagination_config(event)
result = ChatHistoryManager.list_messages(session_id, pagination_config)
return ApiResponse.success(result)
except Exception as e:
return ApiResponse.error(str(e))

@staticmethod
def add_feedback(event: Dict) -> Dict:
"""Handle POST /sessions/{sessionId}/messages/{messageId}/feedback endpoint"""
try:
# Extract path parameters
session_id = event["pathParameters"]["sessionId"]
message_id = event["pathParameters"]["messageId"]
claims = json.loads(event["requestContext"]["authorizer"]["claims"])
user_id = claims["cognito:username"]

# Parse request body
body = json.loads(event["body"])
result = ChatHistoryManager.add_feedback(
session_id=session_id,
user_id=user_id,
message_id=message_id,
feedback_type=body["feedback_type"],
feedback_reason=body["feedback_reason"],
suggest_message=body["suggest_message"],
)
return ApiResponse.success(result)
except Exception as e:
return ApiResponse.error(str(e))


def lambda_handler(event: Dict, context: Any) -> Dict:
"""Routes API requests to appropriate handlers based on HTTP method and path"""
logger.info("Received event: %s", json.dumps(event))

routes = {
# More RESTful paths
("GET", "/sessions"): ApiHandler.list_sessions,
("GET", "/sessions/{sessionId}/messages"): ApiHandler.list_messages,
("POST", "/sessions/{sessionId}/messages/{messageId}/feedback"): ApiHandler.add_feedback,
}

handler = routes.get((event["httpMethod"], event["resource"]))
if not handler:
return ApiResponse.error("Route not found", 404)

return handler(event)
Loading
Loading