-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinterview.py
180 lines (145 loc) · 7.12 KB
/
interview.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import streamlit as st
from mistralai import Mistral, UserMessage, SystemMessage, AssistantMessage
import streamlit as st
import os
from PIL import Image
import mistral_files
import prompts
def interview(client: Mistral):
    """Render the interview page: optional photo upload, then a streamed chat.

    Streamlit re-runs this function on every interaction, so each stage is
    driven by ``st.session_state``:

    1. While ``photo_upload`` is ``None``, ask whether the user wants to
       upload photos.
    2. If they do, preview the uploaded images, save them under
       ``./stories/<session_id>/`` and ask ``pixtral_model`` to describe
       them (only once; guarded by the ``processed_uploaded_files`` flag).
    3. Once the pictures are processed, or the user declined to upload any,
       seed the conversation with the system prompt rendered from
       ``prompts/interview.md`` and chat with ``mistral_model``. The
       "End Conversation" button switches ``page`` to ``"story"``.

    This function reads the following ``st.session_state`` entries without
    creating them, so the caller must initialise them beforehand:
    ``photo_upload`` (``None`` until answered), ``messages``,
    ``picture_messages``, ``session_id``, ``pixtral_model`` and
    ``mistral_model``.

    Args:
        client: Mistral API client, used both to describe the pictures and
            to stream chat replies.
    """
    _render_header()

    if st.session_state.photo_upload is None:
        _ask_about_photos()

    if st.session_state.photo_upload:
        _handle_photo_upload(client)

    # The chat only starts once the pictures have been described, or when
    # the user explicitly declined to upload any.
    if (
        "processed_uploaded_files" in st.session_state
        or st.session_state.photo_upload is False
    ):
        _ensure_system_prompt()
        _render_chat(client)
        # Shown outside the chat-input condition so it is available on every rerun.
        _render_end_conversation_button()


def _render_header():
    """Show the page title, subtitle and usage instructions."""
    st.title("Your Personal JournalAIst")
    st.subheader("Generate a personal story by uploading pictures and answering a few questions.")
    st.write("""
    Hi! What did you get up to today?
    Upload some pictures and give me a short description of your day.
    Click on 'End Conversation' to move onto story generation.
    """)


def _ask_about_photos():
    """Ask whether the user wants to upload photos and record the answer.

    Resets ``st.session_state.n_pictures`` to 0 and, once a button is
    clicked, sets ``st.session_state.photo_upload`` to True ("Yes") or
    False ("No").
    """
    st.write("Would you like to upload photos?")
    st.session_state.n_pictures = 0
    yes_col, no_col = st.columns(2)
    with yes_col:
        wants_photos = st.button("Yes")
    with no_col:
        declines_photos = st.button("No")
    # st.button returns a plain bool, so test truthiness rather than `is True`.
    if wants_photos:
        st.session_state.photo_upload = True
    elif declines_photos:
        st.session_state.photo_upload = False


def _handle_photo_upload(client: Mistral):
    """Let the user upload images, preview them and have them described once.

    The uploads are stored in ``st.session_state.uploaded_files``. The first
    time files are present (no ``processed_uploaded_files`` flag yet) they
    are saved to disk, described by ``pixtral_model``, and the description is
    appended to ``st.session_state.picture_messages``.
    """
    uploaded_files = st.file_uploader(
        "Choose images...", type=["jpg", "png"], accept_multiple_files=True
    )
    if not uploaded_files:
        return
    st.session_state.uploaded_files = uploaded_files

    # Preview: one column per uploaded image.
    cols = st.columns(len(uploaded_files))
    for col, uploaded_file in zip(cols, uploaded_files):
        col.image(Image.open(uploaded_file), use_column_width=True)

    if "processed_uploaded_files" in st.session_state:
        return  # already saved and described on an earlier rerun

    _save_pictures(uploaded_files)

    file_info = mistral_files.handle_files(
        uploaded_files,
        client,
        model=st.session_state["pixtral_model"],
    )
    picture_response = ""
    if file_info:
        # Guard against a missing (None) message content so the string
        # concatenations later on never see None.
        picture_response = file_info.choices[0].message.content or ""
        st.session_state.picture_information = picture_response
    else:
        st.error("Failed to generate response")
    st.session_state.picture_messages.append(AssistantMessage(content=picture_response))
    st.session_state.processed_uploaded_files = True


def _save_pictures(uploaded_files):
    """Save each upload as an RGB ``./stories/<session_id>/picture_<n>.jpg``.

    Numbering continues from ``st.session_state.n_pictures`` (0 if unset),
    which is incremented once per saved picture.
    """
    story_dir = f"./stories/{st.session_state.session_id}"
    # The per-session folder (and ./stories itself) must exist before PIL can
    # write into it; creating only ./stories made the save fail.
    os.makedirs(story_dir, exist_ok=True)
    first_number = st.session_state.get("n_pictures", 0) + 1
    for number, uploaded_file in enumerate(uploaded_files, start=first_number):
        Image.open(uploaded_file).convert("RGB").save(f"{story_dir}/picture_{number}.jpg")
        st.session_state.n_pictures = number


def _ensure_system_prompt():
    """Build the system prompt once and seed the conversation with it.

    The prompt is rendered from ``prompts/interview.md`` with the picture
    descriptions (when photos were uploaded). If ``st.session_state.messages``
    has no system message yet, the prompt is inserted at the front as a
    ``SystemMessage``, followed by an assistant greeting.
    """
    import logging

    if "system_prompt" not in st.session_state:
        picture_info = ""
        if st.session_state.photo_upload:
            descriptions = "".join(
                message.content for message in st.session_state.picture_messages
            )
            picture_info = "The information about each picture is provided below: \n" + descriptions
            logging.getLogger(__name__).debug("Picture info for system prompt: %s", picture_info)
        st.session_state["system_prompt"] = prompts.render_template_from_file(
            "prompts/interview.md", picture_info=picture_info
        )

    has_system_message = any(
        message.role == "system" for message in st.session_state.messages
    )
    if not st.session_state["system_prompt"] or has_system_message:
        return

    st.session_state.messages.insert(
        0, SystemMessage(content=st.session_state["system_prompt"])
    )

    n_pictures = st.session_state.get("n_pictures", 0)
    if n_pictures > 0:
        picture_word = "picture" if n_pictures == 1 else "pictures"
        # Implicit concatenation: a backslash continuation inside the literal
        # would embed the source indentation into the greeting.
        greeting = (
            f"Hi! I saw you uploaded {n_pictures} {picture_word}. "
            "Give me a short summary of the images and what you did today? "
            "Just give me the bullet points, at least five if you can."
        )
    else:
        greeting = "Hi! Tell me about a meaningful event from your life!"
    st.session_state.messages.append(AssistantMessage(content=greeting))


def _render_chat(client: Mistral):
    """Replay the visible conversation and stream a reply to new user input.

    The user's message and the full streamed reply are both appended to
    ``st.session_state.messages``; the whole history (including the system
    prompt) is sent to ``mistral_model``.
    """
    for message in st.session_state.messages:
        if message.role == "system":
            continue  # the system prompt is never shown in the UI
        with st.chat_message(message.role):
            st.markdown(message.content)

    prompt = st.chat_input("What event would you like me to write a story about?")
    if not prompt:
        return

    st.session_state.messages.append(UserMessage(role="user", content=prompt))
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        response_generator = client.chat.stream(
            model=st.session_state["mistral_model"],
            messages=st.session_state.messages,
        )
        if response_generator is not None:
            # Re-render the placeholder as each streamed chunk arrives.
            for response in response_generator:
                full_response += response.data.choices[0].delta.content or ""
                message_placeholder.markdown(full_response + " ")
        else:
            st.error("Failed to generate response")
        message_placeholder.markdown(full_response)
    st.session_state.messages.append(AssistantMessage(content=full_response))


def _render_end_conversation_button():
    """Show the 'End Conversation' button; clicking it moves to the story page."""
    if st.button("End Conversation"):
        st.session_state.page = "story"
        st.rerun()