refactor: Add other bedrock models #432

Merged
33 commits, merged Nov 13, 2024. Changes shown from 29 commits.

Commits:
f786e76 (Oct 17, 2024): add langchain_integration folder; support lazy load model and chains
0e9cc92 (Oct 18, 2024): modify
7d34b82 (Oct 18, 2024): modify
b68322d (Oct 18, 2024): Merge branch 'dev' of https://github.com/aws-samples/Intelli-Agent in…
52cdcc8 (Oct 24, 2024): refactor: unified tool to adapt to langchian's tool
d62f3a9 (Oct 24, 2024): refactor: modify module import
e3f063e (Oct 24, 2024): refactor: add llm chain tool_calling_api and it's prompt template
a4c99df (Oct 29, 2024): complete the tool refactor, next to test
541c3e7 (Oct 31, 2024): fix: one rag tool should respect to one index; refactor: add register…
ffb1f43 (Nov 2, 2024): move langchain_integration into common_logic
4401f38 (Nov 2, 2024): modify agent prompt; add python_repl tool; adapt to pydantic v2
e6e3066 (Nov 3, 2024): remove lambda invoke in intention stage
6577e26 (Nov 3, 2024): move retrievers to common_logic
d9b6851 (Nov 5, 2024): move functions to __functions
cdb3e15 (Nov 6, 2024): add lambda tool test
e526644 (Nov 7, 2024): remove functions layer
bcc3742 (Nov 7, 2024): fix bug in streaming
de41e80 (Nov 7, 2024): add new model in ui
c94f7b1 (Nov 7, 2024): modify logger, fix bug about inaccurate filename output
406a72f (Nov 7, 2024): remove llm_generate_utils
2442434 (Nov 7, 2024): add CLAUDE_3_5_SONNET_V2 and CLAUDE_3_5_HAIKU models
882e17a (Nov 7, 2024): modify online requirements
57f5fdc (Nov 7, 2024): add enable_prefill parameter; optimize prompt
ec997ba (Nov 8, 2024): modify PythonREPL, fix bug running on lambda
6a45433 (Nov 12, 2024): add llama-3.2
5171b31 (Nov 12, 2024): merge from dev
5b622d7 (Nov 12, 2024): modify agent prompt; add new intent logic
0e89072 (Nov 12, 2024): debug intention logic
0686dc3 (Nov 12, 2024): add sso example
75f194c (Nov 13, 2024): modify .viperlightignore
8f8df04 (Nov 13, 2024): remove __functions __llm_generate_utils
7aa5eec (Nov 13, 2024): modify according to the pr comments
6976c5a (Nov 13, 2024): modify glue job requirements
40 changes: 20 additions & 20 deletions source/infrastructure/lib/chat/chat-stack.ts
@@ -69,7 +69,7 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
private lambdaOnlineAgent: Function;
private lambdaOnlineLLMGenerate: Function;
private chatbotTableName: string;
-  private lambdaOnlineFunctions: Function;
+  // private lambdaOnlineFunctions: Function;

constructor(scope: Construct, id: string, props: ChatStackProps) {
super(scope, id);
@@ -282,23 +282,23 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
this.lambdaOnlineLLMGenerate.addToRolePolicy(this.iamHelper.dynamodbStatement);


-    const lambdaOnlineFunctions = new LambdaFunction(this, "lambdaOnlineFunctions", {
-      runtime: Runtime.PYTHON_3_12,
-      handler: "lambda_tools.lambda_handler",
-      code: Code.fromAsset(
-        join(__dirname, "../../../lambda/online/functions/functions_utils"),
-      ),
-      memorySize: 4096,
-      vpc: vpc,
-      securityGroups: securityGroups,
-      layers: [apiLambdaOnlineSourceLayer, apiLambdaJobSourceLayer],
-      environment: {
-        CHATBOT_TABLE: props.sharedConstructOutputs.chatbotTable.tableName,
-        INDEX_TABLE: this.indexTableName,
-        MODEL_TABLE: this.modelTableName,
-      },
-    });
-    this.lambdaOnlineFunctions = lambdaOnlineFunctions.function;
+    // const lambdaOnlineFunctions = new LambdaFunction(this, "lambdaOnlineFunctions", {
+    //   runtime: Runtime.PYTHON_3_12,
+    //   handler: "lambda_tools.lambda_handler",
+    //   code: Code.fromAsset(
+    //     join(__dirname, "../../../lambda/online/functions/functions_utils"),
+    //   ),
+    //   memorySize: 4096,
+    //   vpc: vpc,
+    //   securityGroups: securityGroups,
+    //   layers: [apiLambdaOnlineSourceLayer, apiLambdaJobSourceLayer],
+    //   environment: {
+    //     CHATBOT_TABLE: props.sharedConstructOutputs.chatbotTable.tableName,
+    //     INDEX_TABLE: this.indexTableName,
+    //     MODEL_TABLE: this.modelTableName,
+    //   },
+    // });
+    // this.lambdaOnlineFunctions = lambdaOnlineFunctions.function;

this.lambdaOnlineQueryPreprocess.grantInvoke(this.lambdaOnlineMain);

@@ -310,8 +310,8 @@ export class ChatStack extends NestedStack implements ChatStackOutputs {
this.lambdaOnlineLLMGenerate.grantInvoke(this.lambdaOnlineQueryPreprocess);
this.lambdaOnlineLLMGenerate.grantInvoke(this.lambdaOnlineAgent);

-    this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineMain);
-    this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineIntentionDetection);
+    // this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineMain);
+    // this.lambdaOnlineFunctions.grantInvoke(this.lambdaOnlineIntentionDetection);

if (props.config.chat.amazonConnect.enabled) {
new ConnectConstruct(this, "connect-construct", {
15 changes: 8 additions & 7 deletions source/lambda/job/dep/llm_bot_dep/sm_utils.py
@@ -1,11 +1,12 @@
import json
import io
from typing import Any, Dict, Iterator, List, Mapping, Optional
-from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint
-from langchain.embeddings import SagemakerEndpointEmbeddings, BedrockEmbeddings
-from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
+from langchain_community.llms import SagemakerEndpoint
Review comment (Collaborator):
Does the glue job's requirements.txt need to be updated as well?
langchain -> langchain_community

Reply (Author):
Yes, "langchain_community" needs to be added to its requirements.txt.
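A sketch of the resulting dependency entry (the exact path of the glue job's requirements.txt and any version pin are assumptions, not shown in this diff):

    # glue job requirements.txt (path and pinning illustrative)
    langchain_community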

+from langchain_community.llms.sagemaker_endpoint import LLMContentHandler
+from langchain_community.embeddings import SagemakerEndpointEmbeddings,BedrockEmbeddings
+from langchain_community.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.utils import enforce_stop_tokens
+from langchain_community.llms.utils import enforce_stop_tokens
from typing import Dict, List, Optional, Any,Iterator
from langchain_core.outputs import GenerationChunk
import boto3
@@ -234,12 +235,12 @@ def transform_output(self, output: bytes) -> str:
function. See `boto3`_. docs for more info.
.. _boto3: <https://boto3.amazonaws.com/v1/documentation/api/latest/index.html>
"""
-    content_type = "application/json"
-    accepts = "application/json"
+    content_type: str = "application/json"
+    accepts: str = "application/json"
class Config:
"""Configuration for this pydantic object."""

-        extra = Extra.forbid
+        extra = Extra.forbid.value

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
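For context, these import moves follow LangChain's 0.1 package split, which relocated third-party integrations out of the core langchain package into langchain_community. A minimal before/after sketch of the pattern, using module names taken from this diff:

    # Before the split (older langchain releases):
    # from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
    # from langchain.embeddings import BedrockEmbeddings

    # After the split (langchain >= 0.1, with langchain-community installed):
    from langchain_community.llms import SagemakerEndpoint
    from langchain_community.embeddings import BedrockEmbeddings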
11 changes: 11 additions & 0 deletions source/lambda/online/__functions/__init__.py
@@ -0,0 +1,11 @@
# tool
# from ._tool_base import get_tool_by_name,Tool,tool_manager

# def init_common_tools():
# from . import lambda_common_tools

# def init_aws_qa_tools():
# from . import lambda_aws_qa_tools

# def init_retail_tools():
# from . import lambda_retail_tools
@@ -3,6 +3,7 @@
from enum import Enum
from common_logic.common_utils.constant import SceneType,ToolRuningMode


class ToolDefType(Enum):
openai = "openai"

@@ -19,6 +20,7 @@ class Tool(BaseModel):
scene: str = Field(description="tool use scene",default=SceneType.COMMON)
# should_ask_parameter: bool = Field(description="tool use scene")


class ToolManager:
def __init__(self) -> None:
self.tools = {}
@@ -1,6 +1,5 @@
import json
import os

os.environ["PYTHONUNBUFFERED"] = "1"
import logging
import sys
@@ -24,9 +23,9 @@
GoogleRetriever,
)
from langchain.retrievers import (
-    AmazonKnowledgeBasesRetriever,
    ContextualCompressionRetriever,
)
+from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain_community.retrievers import AmazonKnowledgeBasesRetriever
@@ -9,7 +9,7 @@
logger = logging.getLogger()
logger.setLevel(logging.INFO)

-from langchain.utilities import GoogleSearchAPIWrapper
+from langchain_community.utilities import GoogleSearchAPIWrapper
from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.docstore.document import Document
from langchain.schema.retriever import BaseRetriever
@@ -16,6 +16,7 @@ def lambda_handler(event_body, context=None):
retriever_params = state["chatbot_config"]["private_knowledge_config"]
retriever_params["query"] = state[retriever_params.get(
"retriever_config", {}).get("query_key", "query")]

output: str = invoke_lambda(
event_body=retriever_params,
lambda_name="Online_Functions",
16 changes: 13 additions & 3 deletions source/lambda/online/common_logic/common_utils/constant.py
@@ -82,17 +82,19 @@ class LLMTaskType(ConstantBase):
HYDE_TYPE = "hyde"
CONVERSATION_SUMMARY_TYPE = "conversation_summary"
RETAIL_CONVERSATION_SUMMARY_TYPE = "retail_conversation_summary"

MKT_CONVERSATION_SUMMARY_TYPE = "mkt_conversation_summary"
MKT_QUERY_REWRITE_TYPE = "mkt_query_rewrite"
STEPBACK_PROMPTING_TYPE = "stepback_prompting"
TOOL_CALLING = "tool_calling"
TOOL_CALLING_XML = "tool_calling_xml"
TOOL_CALLING_API = "tool_calling_api"
RETAIL_TOOL_CALLING = "retail_tool_calling"
RAG = "rag"
MTK_RAG = "mkt_rag"
CHAT = 'chat'
AUTO_EVALUATION = "auto_evaluation"



class MessageType(ConstantBase):
HUMAN_MESSAGE_TYPE = 'human'
AI_MESSAGE_TYPE = 'ai'
@@ -126,19 +128,26 @@ class LLMModelType(ConstantBase):
CLAUDE_2 = "anthropic.claude-v2"
CLAUDE_21 = "anthropic.claude-v2:1"
CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"
CLAUDE_3_5_HAIKU = "anthropic.claude-3-5-haiku-20241022-v1:0"
CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"
CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"
CLAUDE_3_5_SONNET_V2 = "anthropic.claude-3-5-sonnet-20241022-v2:0"
MIXTRAL_8X7B_INSTRUCT = "mistral.mixtral-8x7b-instruct-v0:1"
BAICHUAN2_13B_CHAT = "Baichuan2-13B-Chat-4bits"
INTERNLM2_CHAT_7B = "internlm2-chat-7b"
INTERNLM2_CHAT_20B = "internlm2-chat-20b"
GLM_4_9B_CHAT = "glm-4-9b-chat"
-CHATGPT_35_TURBO = "gpt-3.5-turbo-0125"
+CHATGPT_35_TURBO_0125 = "gpt-3.5-turbo-0125"
CHATGPT_4_TURBO = "gpt-4-turbo"
CHATGPT_4O = "gpt-4o"
QWEN2INSTRUCT7B = "qwen2-7B-instruct"
QWEN2INSTRUCT72B = "qwen2-72B-instruct"
QWEN15INSTRUCT32B = "qwen1_5-32B-instruct"
LLAMA3_1_70B_INSTRUCT = "meta.llama3-1-70b-instruct-v1:0"
LLAMA3_2_90B_INSTRUCT = "us.meta.llama3-2-90b-instruct-v1:0"
MISTRAL_LARGE_2407 = "mistral.mistral-large-2407-v1:0"
COHERE_COMMAND_R_PLUS = "cohere.command-r-plus-v1:0"



class EmbeddingModelType(ConstantBase):
@@ -179,4 +188,5 @@ class KBType(Enum):

class Threshold(ConstantBase):
QQ_IN_RAG_CONTEXT = 0.5
INTENTION_ALL_KNOWLEDGAE_RETRIEVE = 0.4
Review comment (Collaborator):
Typo issue: KNOWLEDGE_RETRIEVAL

Reply (Author):
Modified.

@@ -3,18 +3,23 @@
import importlib
import json
import time
import os
from typing import Any, Dict, Optional, Callable, Union
import threading

import requests
from common_logic.common_utils.constant import StreamMessageType
from common_logic.common_utils.logger_utils import get_logger
from common_logic.common_utils.websocket_utils import is_websocket_request, send_to_ws_client
-from langchain.pydantic_v1 import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_validator


from .exceptions import LambdaInvokeError

logger = get_logger("lambda_invoke_utils")

# thread_local = threading.local()
thread_local = threading.local()
CURRENT_STATE = None

__FUNC_NAME_MAP = {
"query_preprocess": "Preprocess for Multi-round Conversation",
@@ -26,6 +31,38 @@
"llm_direct_results_generation": "LLM Response"
}


class StateContext:

    def __init__(self, state):
        self.state = state

    @classmethod
    def get_current_state(cls):
        state = CURRENT_STATE
        assert state is not None, "There is no valid state in the current context"
        return state

    @classmethod
    def set_current_state(cls, state):
        global CURRENT_STATE
        assert CURRENT_STATE is None, "Parallel node executions are not allowed"
        CURRENT_STATE = state

    @classmethod
    def clear_state(cls):
        global CURRENT_STATE
        CURRENT_STATE = None

    def __enter__(self):
        self.set_current_state(self.state)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.clear_state()
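A minimal usage sketch of this context manager (mirroring how the node wrapper further down publishes the state for nested helpers):

    # Illustrative only; `state` is whatever dict a graph node receives.
    state = {"query": "hello"}
    with StateContext(state):
        # Code running inside the block can read the state back without
        # it being passed down explicitly.
        assert StateContext.get_current_state() is state
    # Outside the block the slot is cleared, so get_current_state()
    # would now fail its assertion.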


class LAMBDA_INVOKE_MODE(enum.Enum):
LAMBDA = "lambda"
LOCAL = "local"
@@ -55,26 +92,24 @@ class LambdaInvoker(BaseModel):
region_name: str = None
credentials_profile_name: Optional[str] = Field(default=None, exclude=True)

-    @root_validator()
+    @model_validator(mode="before")
def validate_environment(cls, values: Dict):
if values.get("client") is not None:
return values
try:
import boto3

try:
-            if values["credentials_profile_name"] is not None:
+            if values.get("credentials_profile_name") is not None:
session = boto3.Session(
profile_name=values["credentials_profile_name"]
)
else:
# use default credentials
session = boto3.Session()

values["client"] = session.client(
"lambda", region_name=values["region_name"]
"lambda",
region_name=values.get("region_name",os.environ['AWS_REGION'])
)

except Exception as e:
raise ValueError(
"Could not load credentials to authenticate with AWS client. "
@@ -97,8 +132,9 @@ def invoke_with_lambda(self, lambda_name: str, event_body: dict):
)
response_body = invoke_response["Payload"]
response_str = response_body.read().decode()

response_body = json.loads(response_str)
if "body" in response_body:
response_body = json.loads(response_body['body'])

if "errorType" in response_body:
error = (
Expand All @@ -108,7 +144,6 @@ def invoke_with_lambda(self, lambda_name: str, event_body: dict):
+ f"{response_body['errorType']}: {response_body['errorMessage']}"
)
raise LambdaInvokeError(error)

return response_body

def invoke_with_local(
@@ -285,7 +320,10 @@ def wrapper(state: Dict[str, Any]) -> Dict[str, Any]:
current_stream_use, ws_connection_id, enable_trace)
state['trace_infos'].append(
f"Enter: {func.__name__}, time: {time.time()}")
-        output = func(state)
+
+        with StateContext(state):
+            output = func(state)

current_monitor_infos = output.get(monitor_key, None)
if current_monitor_infos is not None:
send_trace(f"\n\n {current_monitor_infos}",
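One behavioural change worth noting in invoke_with_lambda above: responses that arrive in an API-Gateway-style envelope, where the JSON payload is nested as a string under "body", are now decoded twice. A self-contained sketch of that logic (the helper name is hypothetical):

    import json

    def parse_lambda_payload(response_str: str) -> dict:
        # Mirrors the new invoke_with_lambda logic: decode the outer
        # payload, then decode again if it carries a "body" string.
        response_body = json.loads(response_str)
        if "body" in response_body:
            response_body = json.loads(response_body["body"])
        return response_body

    assert parse_lambda_payload('{"body": "{\\"answer\\": 42}"}') == {"answer": 42}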
30 changes: 17 additions & 13 deletions source/lambda/online/common_logic/common_utils/logger_utils.py
@@ -1,4 +1,3 @@

import logging
import threading
import os
@@ -10,15 +9,13 @@
logger_lock = threading.Lock()


-def cloud_print_wrapper(fn):
-    @wraps(fn)
-    def _inner(msg, *args, **kwargs):
+class CloudStreamHandler(logging.StreamHandler):
+    def emit(self, record):
        from common_logic.common_utils.lambda_invoke_utils import is_running_local
        if not is_running_local:
            # enable multiline as one message in cloudwatch
-            msg = msg.replace("\n", "\r")
-            return fn(msg, *args, **kwargs)
-    return _inner
+            record.msg = record.msg.replace("\n", "\r")
+        return super().emit(record)


class Logger:
@@ -37,16 +34,11 @@ def _get_logger(
logger = logging.getLogger(name)
logger.propagate = 0
# Create a handler
-        c_handler = logging.StreamHandler()
+        c_handler = CloudStreamHandler()
formatter = logging.Formatter(format, datefmt=datefmt)
c_handler.setFormatter(formatter)
logger.addHandler(c_handler)
logger.setLevel(level)
-        logger.info = cloud_print_wrapper(logger.info)
-        logger.error = cloud_print_wrapper(logger.error)
-        logger.warning = cloud_print_wrapper(logger.warning)
-        logger.critical = cloud_print_wrapper(logger.critical)
-        logger.debug = cloud_print_wrapper(logger.debug)
cls.logger_map[name] = logger
return logger

@@ -72,3 +64,15 @@ def print_llm_messages(msg, logger=logger):
"ENABLE_PRINT_MESSAGES", 'True').lower() in ('true', '1', 't')
if enable_print_messages:
logger.info(msg)


+def llm_messages_print_decorator(fn):
+    @wraps(fn)
+    def _inner(*args, **kwargs):
+        if args:
+            print_llm_messages(args)
+        if kwargs:
+            print_llm_messages(kwargs)
+        return fn(*args, **kwargs)
+    return _inner
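A hypothetical usage of the new decorator, wrapping a model-invocation function so its inputs are logged whenever ENABLE_PRINT_MESSAGES is on:

    @llm_messages_print_decorator
    def invoke_model(messages, temperature=0.0):
        # args and kwargs are passed to print_llm_messages before the
        # wrapped call is forwarded.
        ...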
