-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtools.py
346 lines (261 loc) · 10.3 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Filename: tools.py
Description: Tool methods that the langchain agent can use to retrieve information the LLM does not know about.
Author: @alexdjulin
Date: 2024-07-25
"""
import os
from pathlib import Path
import requests
# langchain
from langchain_core.tools import tool
from langchain.schema import Document
from langchain_openai import OpenAIEmbeddings
# Logger: prefer the project's `get_logger` helper (from the ai_chatbot subrepo).
# If the `logger` module cannot be imported, fall back to a stdlib logger that
# writes to 'movie_advisor.log'.
try:
    from logger import get_logger
    LOG = get_logger(Path(__file__).stem)
except ImportError:
    import logging
    logging.basicConfig(
        level=10,  # debug
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        filename='movie_advisor.log',
        force=True  # replace any handlers already attached to the root logger
    )
    # Logger is named after this file (without extension), same as in the try branch.
    LOG = logging.getLogger(Path(__file__).stem)
# Plain-text file that log_tool_calls() appends tool-call entries to.
TOOL_CALLS_LOG = "tool_calls.log"
# Opening in 'w' mode creates the file or truncates an existing one, so each
# import of this module starts with an empty tool-call log.
with open(TOOL_CALLS_LOG, 'w') as log:
    pass
# xata
from xata.client import XataClient
from langchain_community.vectorstores.xata import XataVectorStore
# Xata client used for table management and record queries. It is created
# without arguments, so its configuration presumably comes from the
# environment (e.g. XATA_API_KEY / XATA_DATABASE_URL) — TODO confirm.
xata = XataClient()
# Name of the Xata table that stores the movie history.
TABLE_NAME = "movie_history"
# Vector store over the same table; documents are embedded with OpenAI embeddings.
vector_store = XataVectorStore(
    embedding=OpenAIEmbeddings(),
    api_key=os.getenv("XATA_API_KEY"),
    db_url=os.getenv("XATA_DATABASE_URL"),
    table_name=TABLE_NAME,
)
def log_tool_calls(tool_call: str = None) -> None:
"""Log the steps of the tool calls.
Useful to keep an eye on the tool calls agent verbose is set to false.
Args:
tool_call (str): tool call to log
"""
with open(TOOL_CALLS_LOG, 'a') as log:
log.write(tool_call + '\n')
def init_table() -> None:
    """Make sure the movie history table exists, creating it if necessary.

    Runs a query against ``TABLE_NAME``; if that query does not succeed, the
    table is assumed to be missing and ``create_table()`` is called.
    """
    # Explicit check rather than `assert`: assert statements are removed when
    # Python runs with -O, which would skip the probe (and the creation) entirely.
    if xata.data().query(TABLE_NAME).is_success():
        return
    print(f"Table '{TABLE_NAME}' does not exist. Creating table.")
    create_table()
def create_table() -> None:
    """Create the xata table that stores the movie history and set its schema.

    Failures are logged, not raised: if either the table creation or the
    schema update does not succeed, an error is logged and the function
    returns without going further.
    """
    table_schema = {
        "columns": [
            {"name": "title", "type": "text"},
            {"name": "status", "type": "text"},
            {"name": "comment", "type": "text"},
            {"name": "content", "type": "text"},
            {"name": "embedding", "type": "vector", "vector": {"dimension": 1536}},
        ]
    }
    # Explicit success checks rather than `assert`: under `python -O` asserts
    # are stripped, which would drop the create call itself and hide failures.
    create_resp = xata.table().create(TABLE_NAME)
    if not create_resp.is_success():
        LOG.error(f"Error creating or setting schema for table '{TABLE_NAME}': {create_resp}")
        return
    schema_resp = xata.table().set_schema(TABLE_NAME, table_schema)
    if not schema_resp.is_success():
        LOG.error(f"Error creating or setting schema for table '{TABLE_NAME}': {schema_resp}")
        return
    LOG.info(f"Table '{TABLE_NAME}' created successfully.")
def query_table(columns: list[str] | None = None, filter: dict | None = None) -> list[dict]:
    """Query the movie history table and return its records.

    Args:
        columns (list[str] | None): columns to retrieve. When omitted or
            empty, no column selection is sent with the query.
        filter (dict | None): filter to apply to the query. When omitted or
            empty, no filter is sent.

    Returns:
        list[dict]: list of record dicts. Only a single page of at most 1000
            records is requested; there is no pagination.

    Raises:
        KeyError: if the response has no "records" entry (presumably the case
            when the query failed — verify against the xata client).
    """
    payload = {"page": {"size": 1000}}  # limiting it to 1000 results
    if columns:
        payload["columns"] = columns
    if filter:
        payload["filter"] = filter
    response = xata.data().query(
        table_name=TABLE_NAME,
        payload=payload,
    )
    records = response["records"]
    # Count the records themselves: len() of the whole response would report
    # the number of top-level entries in the response, not the number of rows.
    LOG.debug(f"Retrieved {len(records)} records from table.")
    return records
def add_update_movie(movie_record: dict) -> None:
    """Add a record to the table, replacing any existing record with the same title.

    Existing records whose title equals the new record's title are deleted
    first, then the new version is added through the vector store.

    Args:
        movie_record (dict): new or updated record dict. Must contain the
            keys "title" and "content"; "content" becomes the document text
            and every other key is stored as document metadata.

    Raises:
        KeyError: if "title" or "content" is missing from movie_record.
    """
    title = movie_record["title"]
    # Let the database filter on the title instead of scanning all titles
    # client-side: query_table() only fetches one page of at most 1000 rows,
    # so a client-side scan could miss the existing record in a larger table
    # and leave a stale duplicate behind.
    for rec in query_table(columns=["title"], filter={"title": title}):
        xata.records().delete(TABLE_NAME, rec["id"])
    # add new or updated record
    metadata = {k: v for k, v in movie_record.items() if k != "content"}
    doc = Document(page_content=movie_record["content"], metadata=metadata)
    vector_store.add_documents([doc])
    LOG.debug(f"Added record '{title}' to table.")
def get_watch_lists() -> str:
    """Build a text summary of the four watch lists from the table records.

    One query per status is issued, in this order: watched_liked,
    watched_disliked, must_see, not_interested.

    Returns:
        str: header line followed by one numbered line per list, each with
            the number of titles and the comma-separated titles
    """
    def titles_with(status: str) -> list:
        # Titles of all records carrying the given status.
        return [row['title'] for row in query_table(columns=["title"], filter={"status": status})]

    liked = titles_with("watched_liked")
    disliked = titles_with("watched_disliked")
    wanted = titles_with("must_see")
    skipped = titles_with("not_interested")

    watch_lists = "".join([
        'Here is my movie history. Do not recommend any of these:\n',
        f"1. Watched and liked ({len(liked)}): " + ", ".join(liked) + "\n",
        f"2. Watched and disliked ({len(disliked)}): " + ", ".join(disliked) + "\n",
        f"3. Must see ({len(wanted)}): " + ", ".join(wanted) + "\n",
        f"4. Not interested ({len(skipped)}): " + ", ".join(skipped) + "\n",
    ])
    LOG.debug(f"Retrieved watch lists: {watch_lists}")
    return watch_lists
@tool
def add_title_to_movies_I_watched_and_liked(title: str, comment: str) -> None:
    """Add a movie title to the list of movies I have already watched and liked.

    Args:
        title (str): original movie title
        comment (str): my personal comment about the movie. Translate it in English if necessary.
    """
    log_tool_calls("[Tool Call] Add movie to watched and liked")
    # Log the actual tool name so debug traces can be matched to this function.
    LOG.debug("Tool call: add_title_to_movies_I_watched_and_liked")
    # "content" is the text stored as the document body by add_update_movie;
    # the other keys end up as record metadata.
    record = {
        "title": title,
        "status": "watched_liked",
        "comment": comment,
        "content": f"{title} (watched_liked) {comment}",
    }
    add_update_movie(record)
@tool
def add_title_to_movies_I_watched_and_disliked(title: str, comment: str) -> None:
    """Add a movie title to the list of movies I have already watched but did not like.

    Args:
        title (str): original movie title
        comment (str): my personal comment about the movie. Translate it in English if necessary.
    """
    log_tool_calls("[Tool Call] Add movie to watched and disliked")
    # Log the actual tool name so debug traces can be matched to this function.
    LOG.debug("Tool call: add_title_to_movies_I_watched_and_disliked")
    # "content" is the text stored as the document body by add_update_movie;
    # the other keys end up as record metadata.
    record = {
        "title": title,
        "status": "watched_disliked",
        "comment": comment,
        "content": f"{title} (watched_disliked) {comment}",
    }
    add_update_movie(record)
@tool
def add_title_to_movies_I_have_never_watched_but_want_to(title: str, comment: str) -> None:
    """Add a movie title to the list of movies I have never watched but I want to watch later.

    Args:
        title (str): movie title
        comment (str): my personal comment about the movie. Translate it in English if necessary.
    """
    # Trace the call in both the tool-call log and the debug log.
    log_tool_calls("[Tool Call] Add movie to must_see list")
    LOG.debug("Tool call: add_title_to_movies_I_have_never_watched_but_want_to")
    # Store (or replace) the record with the "must_see" status; "content" is
    # the text used as the document body.
    add_update_movie(
        dict(
            title=title,
            status="must_see",
            comment=comment,
            content=f"{title} (must_see) {comment}",
        )
    )
@tool
def add_title_to_movies_I_have_never_watched_and_dont_want_to(title: str, comment: str) -> None:
    """Add a movie title to the list of movies I have never watched and I don't want to watch them ever.

    Args:
        title (str): movie title
        comment (str): my personal comment about the movie. Translate it in English if necessary.
    """
    log_tool_calls("[Tool Call] Add movie to not_interested list")
    LOG.debug("Tool call: add_title_to_movies_I_have_never_watched_and_dont_want_to")
    # "content" is the text stored as the document body by add_update_movie;
    # the other keys end up as record metadata.
    record = {
        "title": title,
        "status": "not_interested",
        "comment": comment,
        "content": f"{title} (not_interested) {comment}",
    }
    add_update_movie(record)
@tool
def search_movie_history_for_info_and_preferences(query: str) -> list | None:
    """Search the database for any kind of information regarding my movie history, preferences, likes and dislikes.

    Args:
        query (str): text describing the information to look for

    Returns:
        list | None: text contents of the best matching records (at most 3), or None if nothing was found
    """
    log_tool_calls("[Tool Call] Vector search in movie history")
    LOG.debug("Tool call: search_movie_history_for_info_and_preferences")
    # Similarity search over the movie history, limited to the 3 closest documents.
    found_docs = vector_store.similarity_search(query, k=3)
    if not found_docs:
        LOG.debug("No docs found in table.")
        return None
    return [doc.page_content for doc in found_docs]
@tool
def query_tmdb_database_for_information_about_a_movie(query: str) -> list | None:
    """Query the TMDB database for a movie title and return some information about it.

    Args:
        query (str): movie title to search for

    Returns:
        list | None: list of matching movie info dicts, or None if the request failed
    """
    log_tool_calls("[Tool Call] Query TMDB")
    LOG.debug("Tool call: query_tmdb_database_for_information_about_a_movie")
    # Base URL for the search endpoint
    url = "https://api.themoviedb.org/3/search/movie"
    # Parameters for the API request
    params = {
        # NOTE(review): the environment variable is named like a bearer token
        # but is sent as the `api_key` query parameter — confirm it holds a
        # key this endpoint accepts.
        'api_key': os.getenv("TMDB_BEARER_TOKEN"),
        'query': query,
        'include_adult': False,
    }
    try:
        # A timeout keeps the agent from hanging forever on a stalled connection.
        response = requests.get(url, params=params, timeout=10)
    except requests.RequestException as e:
        # Network-level failure: report it like any other failed request.
        LOG.error(f"Error: TMDB request failed: {e}")
        return None
    if response.status_code != 200:
        LOG.error(f"Error: {response.status_code}")
        return None
    # Parsing the JSON response
    data = response.json()
    return data['results']
@tool
def get_all_movies_from_my_watch_lists() -> str:
    """Get all movie titles listed in my 4 watch lists.

    Returns:
        str: text listing the titles of the four lists (watched and liked, watched and disliked, must see, not interested).
    """
    log_tool_calls("[Tool Call] Get watch lists")
    LOG.debug("Tool call: get_all_movies_from_my_watch_lists")
    # Thin tool wrapper: the summary text is built by get_watch_lists().
    return get_watch_lists()
# List of tools: every @tool-decorated function defined in this module,
# gathered in one list (presumably handed to the langchain agent by the
# importing code — confirm against the caller).
agent_tools = [
    add_title_to_movies_I_watched_and_liked,
    add_title_to_movies_I_watched_and_disliked,
    add_title_to_movies_I_have_never_watched_but_want_to,
    add_title_to_movies_I_have_never_watched_and_dont_want_to,
    search_movie_history_for_info_and_preferences,
    query_tmdb_database_for_information_about_a_movie,
    get_all_movies_from_my_watch_lists,
]