-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconversation-retrieval.py
91 lines (75 loc) · 2.7 KB
/
conversation-retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from dotenv import load_dotenv
load_dotenv()
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.document_loaders import WebBaseLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores.faiss import FAISS
from langchain.chains import create_retrieval_chain
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.prompts import MessagesPlaceholder
from langchain.chains.history_aware_retriever import create_history_aware_retriever
def get_documents_from_web(url):
    """Download a web page and split its text into overlapping chunks.

    Args:
        url: Address of the page to fetch.

    Returns:
        A list of Document chunks (~400 chars each, 20-char overlap)
        suitable for embedding into a vector store.
    """
    raw_docs = WebBaseLoader(url).load()
    chunker = RecursiveCharacterTextSplitter(chunk_size=400, chunk_overlap=20)
    return chunker.split_documents(raw_docs)
def create_db(docs):
    """Build an in-memory FAISS vector store over *docs*.

    Uses OpenAI embeddings (requires OPENAI_API_KEY in the environment).

    Args:
        docs: Iterable of Document chunks to index.

    Returns:
        A FAISS vector store ready to be used as a retriever.
    """
    return FAISS.from_documents(docs, OpenAIEmbeddings())
def create_chain(vectorStore):
    """Assemble a history-aware retrieval chain over *vectorStore*.

    The chain first rewrites the user's question into a standalone search
    query using the conversation history, retrieves the top-3 matching
    chunks, then answers with those chunks stuffed into the prompt context.

    Args:
        vectorStore: A FAISS vector store to retrieve context from.

    Returns:
        A retrieval chain; invoke it with {"input": ..., "chat_history": ...}
        and read the reply from the "answer" key of the result.
    """
    llm = ChatOpenAI(model="gpt-4-turbo", temperature=0.4)

    # Answering prompt: retrieved documents are injected as {context}.
    answer_prompt = ChatPromptTemplate.from_messages([
        ("system", "Answer the user's question based on the context: {context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ])
    answer_chain = create_stuff_documents_chain(llm=llm, prompt=answer_prompt)

    # Query-rewriting prompt: condenses the conversation into one search query.
    rewrite_prompt = ChatPromptTemplate.from_messages([
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
        ("human", "Given the above conversation, generate a search query to look up in order to get information relevant to the conversation."),
    ])
    retriever = create_history_aware_retriever(
        llm=llm,
        retriever=vectorStore.as_retriever(search_kwargs={"k": 3}),
        prompt=rewrite_prompt,
    )

    return create_retrieval_chain(retriever, answer_chain)
def process_chat(chain, question, chat_history):
    """Run one conversational turn through the retrieval chain.

    Args:
        chain: A retrieval chain whose invoke() returns a dict with "answer".
        question: The user's current question.
        chat_history: List of prior HumanMessage/AIMessage objects.

    Returns:
        The chain's answer string for this turn.
    """
    result = chain.invoke({"input": question, "chat_history": chat_history})
    return result["answer"]
if __name__ == '__main__':
    # Build the RAG pipeline once: fetch + chunk the docs, embed them,
    # and wire up the history-aware retrieval chain.
    docs = get_documents_from_web('https://python.langchain.com/v0.1/docs/expression_language/')
    vectorStore = create_db(docs)
    chain = create_chain(vectorStore)

    chat_history = []
    while True:
        user_input = input("You: ")
        # strip() so "exit " or " exit" also terminates the loop.
        if user_input.strip().lower() == 'exit':
            break
        response = process_chat(chain, user_input, chat_history)
        # Record the turn so later questions can reference earlier ones.
        chat_history.append(HumanMessage(content=user_input))
        chat_history.append(AIMessage(content=response))
        # Fix: original printed the misspelled label "Asssitant:".
        print("Assistant:", response)