From 6c8105c1dcfc0e56b2d0d0a8b0c62cb1b7b6ad3d Mon Sep 17 00:00:00 2001
From: Supriya
Date: Sun, 13 Oct 2024 18:12:44 -0700
Subject: [PATCH] RAG using Llama3.1 Bedrock KB

---
 .../rag_llama31_qna.py                      | 225 ++++++++++++++++++
 llama3-bedrock-rag-streamlit/rag_qna_app.py |  30 +++
 2 files changed, 255 insertions(+)
 create mode 100644 llama3-bedrock-rag-streamlit/rag_llama31_qna.py
 create mode 100644 llama3-bedrock-rag-streamlit/rag_qna_app.py

diff --git a/llama3-bedrock-rag-streamlit/rag_llama31_qna.py b/llama3-bedrock-rag-streamlit/rag_llama31_qna.py
new file mode 100644
index 0000000..02c444c
--- /dev/null
+++ b/llama3-bedrock-rag-streamlit/rag_llama31_qna.py
@@ -0,0 +1,225 @@
+import json
+
+import boto3
+from botocore.client import Config
+from langchain.prompts import PromptTemplate
+
+MAX_MESSAGES = 20
+MODEL_ID = 'meta.llama3-1-70b-instruct-v1:0'
+KNOWLEDGE_BASE_ID = "DYTL71ODQZ"
+
+bedrock_client = boto3.client(service_name='bedrock-runtime')
+
+
+class ChatMessage:
+    """One chat turn; role is 'user' or 'assistant'."""
+
+    def __init__(self, role, text):
+        self.role = role
+        self.text = text
+
+
+def get_tools():
+    """Tool definitions for the Converse API. The model requests the
+    amazon_shareholder_information tool whenever a question needs
+    knowledge-base retrieval."""
+    tools = [
+        {
+            "toolSpec": {
+                "name": "amazon_shareholder_information",
+                "description": "Retrieve information about Amazon shareholder 2023 documents.",
+                "inputSchema": {
+                    "json": {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                                "description": "The retrieval-augmented generation query used to search the knowledge base for Amazon shareholder information."
+                            }
+                        },
+                        "required": ["query"]
+                    }
+                }
+            }
+        }
+    ]
+
+    return tools
+
+
+def transform_messages_for_api(chat_messages):
+    """Convert ChatMessage objects into Converse API message dicts."""
+    return [{"role": msg.role, "content": [{"text": msg.text}]} for msg in chat_messages]
+
+
+def process_tool(response_message, messages, bedrock, tool_list):
+    """Resolve any toolUse blocks in the model's response: run the
+    knowledge-base lookup, send the result back as a toolResult, and
+    return (True, answer). Returns (False, None) if no tool was used."""
+    messages.append(response_message)
+
+    response_content_blocks = response_message['content']
+    follow_up_content_blocks = []
+
+    for content_block in response_content_blocks:
+        if 'toolUse' in content_block:
+            tool_use_block = content_block['toolUse']
+
+            if tool_use_block['name'] == 'amazon_shareholder_information':
+                query = tool_use_block['input']['query']
+                rag_content = get_shareholder_info(query)
+
+                follow_up_content_blocks.append({
+                    "toolResult": {
+                        "toolUseId": tool_use_block['toolUseId'],
+                        "content": [
+                            {"text": rag_content}
+                        ]
+                    }
+                })
+
+    if len(follow_up_content_blocks) > 0:
+        follow_up_message = {
+            "role": "user",
+            "content": follow_up_content_blocks,
+        }
+        messages.append(follow_up_message)
+
+        response = bedrock.converse(
+            modelId=MODEL_ID,
+            messages=messages,
+            inferenceConfig={
+                "maxTokens": 2000,
+                "temperature": 0,
+                "topP": 0.9,
+                "stopSequences": []
+            },
+            toolConfig={
+                "tools": tool_list
+            }
+        )
+
+        return True, response['output']['message']['content'][0]['text']
+
+    return False, None
+
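+# Illustrative only: the message shapes below are inferred from how the
+# code above reads them; the IDs and values are made up. A tool request
+# arrives from converse() as an assistant message like:
+#   {"role": "assistant", "content": [{"toolUse": {
+#       "toolUseId": "tooluse_abc123",
+#       "name": "amazon_shareholder_information",
+#       "input": {"query": "2023 free cash flow"}}}]}
+# and process_tool() answers it with a user message like:
+#   {"role": "user", "content": [{"toolResult": {
+#       "toolUseId": "tooluse_abc123",
+#       "content": [{"text": "...retrieved passages..."}]}}]}
+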
+def get_contexts(retrievalResults):
+    """Join the retrieved chunk texts into one context string."""
+    contexts = []
+    for retrievedResult in retrievalResults:
+        text = retrievedResult['content']['text']
+        if text.startswith("Document 1: "):
+            text = text[len("Document 1: "):]
+        contexts.append(text)
+    return ', '.join(contexts)
+
+
+def get_shareholder_info(question):
+    """Retrieve knowledge-base passages for the question and ask Llama 3.1
+    to answer strictly from them."""
+    response_retrieve = retrieve(question, KNOWLEDGE_BASE_ID)["retrievalResults"]
+    contexts = get_contexts(response_retrieve)
+
+    PROMPT_TEMPLATE = """DOCUMENT:
+    {context}
+
+    QUESTION:
+    {message}
+
+    INSTRUCTIONS:
+    Answer the user's QUESTION using only the DOCUMENT text above. Greet the user in a friendly way if the QUESTION contains "hi" or "hello".
+    Keep your answer strictly grounded in the facts provided. Do not refer to the "DOCUMENT", "documents", "provided text", "based on", or any similar phrases in your answer.
+    If the provided text contains the facts to answer the QUESTION, include all relevant details in your answer.
+    If the provided text doesn't contain the facts to answer the QUESTION, respond only with "I don't know" and do not add any further information.
+    """
+
+    prompt = PromptTemplate(template=PROMPT_TEMPLATE,
+                            input_variables=["context", "message"])
+
+    prompt_final = prompt.format(context=contexts, message=question)
+
+    native_request = {
+        "prompt": prompt_final,
+        "max_gen_len": 2048,
+        "temperature": 0.5,
+    }
+
+    # Convert the native request to JSON and invoke the model directly.
+    request = json.dumps(native_request)
+    response = bedrock_client.invoke_model(body=request,
+                                           modelId=MODEL_ID,
+                                           accept='application/json',
+                                           contentType='application/json')
+    response_body = json.loads(response.get('body').read())
+
+    # Llama models return the completion under 'generation'; the 'content'
+    # branch keeps compatibility with models that use a Messages-style body.
+    if response_body.get('content') and response_body['content'][0].get('text'):
+        response_text = response_body['content'][0]['text']
+    elif response_body.get('generation'):
+        response_text = response_body['generation']
+    else:
+        response_text = "Sorry, I didn't get that."
+
+    return response_text
+
+
+def retrieve(query, kbId, numberOfResults=3):
+    """Run a hybrid (semantic + keyword) search against the Bedrock
+    knowledge base."""
+    bedrock_config = Config(connect_timeout=120, read_timeout=120,
+                            retries={'max_attempts': 0})
+    bedrock_agent_client = boto3.client("bedrock-agent-runtime", config=bedrock_config)
+
+    return bedrock_agent_client.retrieve(
+        retrievalQuery={
+            'text': query
+        },
+        knowledgeBaseId=kbId,
+        retrievalConfiguration={
+            'vectorSearchConfiguration': {
+                'numberOfResults': numberOfResults,
+                'overrideSearchType': "HYBRID"
+            }
+        }
+    )
+
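+# Illustrative only: a sketch of the retrieve() response as consumed above.
+# 'retrievalResults' and 'content.text' come from the code; the other
+# fields and all values are made up for illustration:
+#   {"retrievalResults": [
+#       {"content": {"text": "Amazon's 2023 free cash flow improved to..."},
+#        "location": {"s3Location": {"uri": "s3://bucket/shareholder-letter.pdf"}},
+#        "score": 0.72},
+#       ...
+#   ]}
+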
+def converse_with_model(message_history, new_text=None):
+    """Append the user's turn, call Converse with the tool config, resolve
+    any tool use, and append the assistant's reply to the history."""
+    session = boto3.Session()
+    bedrock = session.client(service_name='bedrock-runtime')
+
+    tool_list = get_tools()
+
+    new_text_message = ChatMessage('user', text=new_text)
+    message_history.append(new_text_message)
+
+    number_of_messages = len(message_history)
+
+    # Trim the oldest user/assistant pairs so the history stays bounded
+    # by MAX_MESSAGES.
+    if number_of_messages > MAX_MESSAGES:
+        del message_history[0:(number_of_messages - MAX_MESSAGES) * 2]
+
+    messages = transform_messages_for_api(message_history)
+
+    response = bedrock.converse(
+        modelId=MODEL_ID,
+        messages=messages,
+        inferenceConfig={
+            "maxTokens": 2000,
+            "temperature": 0,
+            "topP": 0.9,
+            "stopSequences": []
+        },
+        toolConfig={
+            "tools": tool_list
+        }
+    )
+
+    response_message = response['output']['message']
+
+    tool_used, output = process_tool(response_message, messages, bedrock, tool_list)
+
+    if not tool_used:
+        output = response['output']['message']['content'][0]['text']
+
+    response_chat_message = ChatMessage('assistant', output)
+    message_history.append(response_chat_message)
diff --git a/llama3-bedrock-rag-streamlit/rag_qna_app.py b/llama3-bedrock-rag-streamlit/rag_qna_app.py
new file mode 100644
index 0000000..4825709
--- /dev/null
+++ b/llama3-bedrock-rag-streamlit/rag_qna_app.py
@@ -0,0 +1,30 @@
+import streamlit as st  # all Streamlit commands are available through the "st" alias
+import rag_llama31_qna as glib  # reference to the local library script
+
+st.set_page_config(page_title="Amazon Shareholder RAG QnA Chatbot - powered by Amazon Bedrock Llama 3.1")  # browser tab title
+st.title("RAG based QnA Chatbot")  # page title
+st.subheader("(Powered by Amazon Bedrock Knowledge Bases, Llama 3.1)")  # page subheader
+
+message = st.chat_message("assistant")
+message.write("Hello 👋 I am a friendly chatbot who can help you answer questions related to Amazon shareholder information.")
+
+if 'chat_history' not in st.session_state:  # initialize the chat history on first run
+    st.session_state.chat_history = []
+
+chat_container = st.container()
+
+input_text = st.chat_input("Type your question here...")  # display a chat input box
+
+if input_text:
+    glib.converse_with_model(message_history=st.session_state.chat_history, new_text=input_text)
+
+# Re-render the chat history (Streamlit re-runs this script on every
+# interaction, so this preserves previous chat messages).
+for message in st.session_state.chat_history:  # loop through the chat history
+    with chat_container.chat_message(message.role):  # render a chat line for the given role
+        st.markdown(message.text)  # display the chat content
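
To try the change locally (assuming AWS credentials with access to Bedrock and the knowledge base above, plus boto3, langchain, and streamlit installed):

    cd llama3-bedrock-rag-streamlit
    streamlit run rag_qna_app.py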