-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
66 lines (52 loc) · 1.89 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from flask import Flask, request, jsonify
from operator import itemgetter
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI
import requests
# Flask application serving a single POST endpoint (see generate_route below).
app = Flask(__name__)
# Chat model served by a local Ollama instance via its OpenAI-compatible API.
# (default: mistral:instruct)
# NOTE(review): the api_key is a placeholder literal — Ollama ignores it, but
# if this is ever pointed at a real OpenAI endpoint it should come from an
# environment variable, not source code.
model = ChatOpenAI(
    temperature=0,  # deterministic replies
    model_name="mistral:instruct",
    openai_api_base="http://localhost:11434/v1",
    openai_api_key="insert your api key here",
    max_tokens=80,  # keep answers short (system prompt asks for ~30 words)
)
# define the prompt and system message
# Template: system persona, then prior turns injected under "history",
# then the current user message under "input".
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You're Siri-GPT, an open source AI smarter than Siri that runs on user's devices. You're helping a user with tasks, for any question answer very briefly (answer is about 30 words) and informatively. else, ask for more information.",
        ),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{input}"),
    ]
)
# define memory type
# Sliding window of the last k=5 exchanges, returned as message objects
# (return_messages=True) so MessagesPlaceholder can consume them.
memory = ConversationBufferWindowMemory(k=5, return_messages=True)
# define the chain
# Pipeline: inject "history" from memory into the input dict, render the
# prompt, then call the model.
chain = (
    RunnablePassthrough.assign(
        history=RunnableLambda(memory.load_memory_variables) | itemgetter("history")
    )
    | prompt
    | model
)
def generate(user_input="Test"):
    """Run one conversational turn through the LLM chain.

    Args:
        user_input: The user's message. Empty/None ends the conversation.

    Returns:
        The model's reply text, or the sentinel "End of conversation"
        when no input was given.
    """
    # BUG FIX: load_memory_variables expects a dict of inputs; the original
    # passed the set literal {""}.
    print(f'current memory:\n{memory.load_memory_variables({})}')
    # Treat None the same as "" — request.json.get("prompt", "") can yield
    # None when the body contains {"prompt": null}.
    if not user_input:
        return "End of conversation"
    inputs = {"input": user_input}
    response = chain.invoke(inputs)
    # Persist this turn so the k=5 windowed memory includes it next call.
    memory.save_context(inputs, {"output": response.content})
    return response.content
@app.route("/", methods=["POST"])
def generate_route():
prompt = request.json.get("prompt", "")
response = generate(prompt)
return response
if __name__ == "__main__":
app.run(host="0.0.0.0")