-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrepository.py
146 lines (128 loc) · 5.89 KB
/
repository.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import contextlib
import os
import sqlite3
import time

import faiss
import numpy as np

from memory import Memory
class Repository:
    """SQLite-backed persistence layer for one chat channel.

    Each channel gets its own database file under ``conversations/``. The
    database holds the raw message log, the (single-row) summarized
    conversation context, long-term-memory texts with CSV-serialized
    embeddings, and one serialized FAISS index blob.
    """

    def __init__(self, channel_id: int) -> None:
        """Resolve the channel's database path and create tables if missing."""
        self.db_path = self.__get_db_path(channel_id)
        self.__create_db_if_not_exists()

    def clear_messages(self) -> None:
        """Delete every row from the ``messages`` table."""
        # contextlib.closing guarantees the connection is closed even when a
        # statement raises — the original leaked the handle on error paths.
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute("DELETE FROM messages")
            conn.commit()

    def save_message(self, sender, content):
        """Append one (sender, content) row to the message log."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute(
                "INSERT INTO messages (sender, content) VALUES (?, ?)",
                (sender, content),
            )
            conn.commit()

    def save_conversation_context(self, conversation_context):
        """Replace the stored conversation context with a new value."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            # The table only ever holds one logical row: wipe, then insert.
            conn.execute("DELETE FROM conversation_context")
            conn.execute(
                "INSERT INTO conversation_context (context) VALUES (?)",
                (conversation_context,),
            )
            conn.commit()

    def save_long_term_memory(self, long_term_memory, unix_timestamp, serialized_embedding):
        """Insert one long-term-memory row and return its new rowid."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                "INSERT INTO long_term_memory_text "
                "(timestamp, memory_text, embedding_serialized_csv_text) "
                "VALUES (?, ?, ?)",
                (unix_timestamp, long_term_memory, serialized_embedding),
            )
            memory_id = cursor.lastrowid
            conn.commit()
        return memory_id

    def save_long_term_memory_index(self, faiss_index):
        """Persist the FAISS index blob, replacing any previously stored one."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute("DELETE FROM long_term_memory_index")
            conn.execute(
                "INSERT INTO long_term_memory_index (serialized_faiss_index) VALUES (?)",
                (faiss.serialize_index(faiss_index),),
            )
            conn.commit()

    def load_long_term_memory_index(self):
        """Return the deserialized FAISS index, or None when none is stored."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute(
                "SELECT serialized_faiss_index FROM long_term_memory_index"
            )
            serialized_index = cursor.fetchone()
        if serialized_index:
            # faiss deserializes from a uint8 numpy buffer, not raw bytes.
            serialized_index_np = np.frombuffer(serialized_index[0], dtype=np.uint8)
            return faiss.deserialize_index(serialized_index_np)
        return None

    def load_memory(self, id):
        """Fetch one long-term-memory row by id as a Memory, or None if absent."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute(
                "SELECT id, memory_text, timestamp, embedding_serialized_csv_text "
                "FROM long_term_memory_text WHERE id=?",
                (id,),
            )
            memory = cursor.fetchone()
        return Memory(memory[0], memory[1], memory[2], memory[3]) if memory else None

    # Load ordered embeddings ascending by id
    def load_embeddings(self):
        """Return every stored embedding as a list of floats, ordered by id.

        Ascending-id order presumably mirrors FAISS insertion order so index
        positions line up with database rows — TODO confirm against the index
        builder.
        """
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute(
                "SELECT embedding_serialized_csv_text "
                "FROM long_term_memory_text ORDER BY id ASC"
            )
            rows = cursor.fetchall()
        return [list(map(float, row[0].split(','))) for row in rows]

    def load_messages(self):
        """Return all (sender, content) tuples in insertion order."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute("SELECT sender, content FROM messages")
            messages = cursor.fetchall()
        return messages

    def load_conversation_context(self):
        """Return the stored context string, or '' when nothing is saved."""
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.execute("SELECT context FROM conversation_context")
            context = cursor.fetchone()
        return context[0] if context else ''

    def sync_conversation_context(self, conversation):
        """Rewrite the message log and context from an in-memory conversation.

        NOTE(review): the final save_long_term_memory call passes only one of
        three required arguments and will raise TypeError if this method is
        ever reached — confirm the intended timestamp/embedding values before
        relying on it.
        """
        self.clear_messages()
        for message in conversation.messages:
            self.save_message(message.sender, message.content)
        self.save_conversation_context(conversation.conversation_context)
        self.save_long_term_memory(conversation.long_term_memory)

    async def summarize_conversation(self, conversation, trigger_token_limit=400, conversation_window_tokens=100):
        """Run the summarizer under the conversation lock and trim the history.

        Returns True when a summary pass ran. NOTE(review): needed_summary is
        set True unconditionally, so the return value is always True on
        success — confirm whether a token-count guard was intended.
        """
        needed_summary = False
        await conversation.lock.acquire()
        try:
            needed_summary = True
            await conversation.run_summarizer()
            # Walk the history newest-first, keeping messages until either the
            # next one would push past trigger_token_limit or the kept window
            # already exceeds conversation_window_tokens.
            new_messages = []
            total_tokens = 0
            for message in reversed(conversation.conversation_history):
                potential_total_tokens = total_tokens + message.get_number_of_tokens()
                if potential_total_tokens > trigger_token_limit:
                    break
                new_messages.insert(0, message)
                total_tokens = potential_total_tokens
                if total_tokens > conversation_window_tokens:
                    break
            conversation.conversation_history = new_messages
            self.save_conversation_context(conversation.active_memory)
            self.clear_messages()
            for message in conversation.conversation_history:
                self.save_message(message.sender, message.content)
        finally:
            # Always release, even if the summarizer or a DB write raised.
            conversation.lock.release()
        return needed_summary

    def __get_db_path(self, channel_id: int) -> str:
        """Map a channel id to its per-channel database file path."""
        return os.path.join("conversations", f"{channel_id}.db")

    def __create_db_if_not_exists(self) -> None:
        """Create the data directory and all tables idempotently."""
        # sqlite3.connect cannot create parent directories; the original
        # crashed on a fresh checkout without a conversations/ folder.
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with contextlib.closing(sqlite3.connect(self.db_path)) as conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS messages
                            (id INTEGER PRIMARY KEY AUTOINCREMENT,
                             sender TEXT NOT NULL,
                             content TEXT NOT NULL);''')
            conn.execute('''CREATE TABLE IF NOT EXISTS conversation_context
                            (id INTEGER PRIMARY KEY AUTOINCREMENT,
                             context TEXT NOT NULL);''')
            conn.execute('''CREATE TABLE IF NOT EXISTS long_term_memory_text (
                                id INTEGER PRIMARY KEY,
                                timestamp INTEGER,
                                memory_text TEXT,
                                embedding_serialized_csv_text TEXT
                            )''')
            conn.execute('''CREATE TABLE IF NOT EXISTS long_term_memory_index (
                                id INTEGER PRIMARY KEY,
                                serialized_faiss_index BLOB
                            )''')
            conn.commit()