-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpygamgee.py
151 lines (125 loc) · 4.92 KB
/
pygamgee.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# -*- coding: utf-8 -*-
# @Author: boyac
# @Date: 2025-02-20 08:18:18
# @Last Modified by: boyac
# @Last Modified time: 2025-02-20 08:18:18
import os
import fitz # PyMuPDF
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_ollama import OllamaEmbeddings
from langchain_ollama import OllamaLLM
from langchain.chains import RetrievalQA
from langchain.schema import Document
from langchain.memory import ConversationBufferMemory # ADD THIS IMPORT
# --- Configuration ------------------------------------------------------
# Folder holding the source PDFs (relative path; replace with your own).
data_dir = r"data"
# Directory where the FAISS index is persisted between runs.
faiss_index_dir = "faiss_index"
# Ollama model name used for both embeddings and generation.
mymodel = "deepseek-r1:1.5b"

# Record the on-disk index size before the run (0 when no index exists yet).
initial_size = 0
index_file = os.path.join(faiss_index_dir, "index.faiss")
if os.path.exists(index_file):
    initial_size = os.path.getsize(index_file)
print(f"Initial FAISS index size: {initial_size} bytes")
# --- Load documents -----------------------------------------------------
# Read every PDF in data_dir with PyMuPDF, concatenate the text of all its
# pages, and wrap the result in a LangChain Document; metadata["source"]
# keeps the originating filename for later attribution.
documents = []
print(f"Loading documents, folder: {data_dir}")
for filename in os.listdir(data_dir):
    if filename.endswith(".pdf"):
        try:
            filepath = os.path.join(data_dir, filename)
            print(f"Attempting to load file: {filepath}")
            # Use PyMuPDF to extract the text page by page.
            with fitz.open(filepath) as doc:
                text = ""
                for page in doc:
                    text += page.get_text()
            # Create a LangChain Document object for downstream splitting.
            documents.append(Document(page_content=text, metadata={"source": filename}))
            # BUGFIX: both messages below printed the literal placeholder
            # "(unknown)" instead of the file actually being processed.
            print(f"Successfully loaded {filename}")
        except Exception as e:
            print(f"❌ Failed to load file {filename}: {e}")
print(f"Total number of documents loaded: {len(documents)}")
if not documents:
    print("❌ No documents loaded, please check folder and file format")
    exit()
# --- Split documents into chunks ----------------------------------------
# Small chunks (200 chars with 20 chars of overlap) keep retrieval granular;
# adjust chunk_size / chunk_overlap to trade context size against recall.
text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=20)
documents = text_splitter.split_documents(documents)
print(f"Number of split documents: {len(documents)}")
if not documents:
    print("❌ No documents after splitting, please check document content and splitting settings")
    exit()
# --- Vector store: load a cached FAISS index or build a fresh one -------
print("Creating vector database")
embeddings = OllamaEmbeddings(model=mymodel)

if os.path.exists(faiss_index_dir):
    # A saved index exists: reuse it and skip re-embedding every document.
    print("Loading FAISS index from disk...")
    try:
        # allow_dangerous_deserialization is required because FAISS indexes
        # are pickled on disk; only load indexes you created yourself.
        db = FAISS.load_local(faiss_index_dir, embeddings, allow_dangerous_deserialization=True)
        print("FAISS index loaded successfully.")
    except Exception as e:
        print(f"❌ Failed to load FAISS index: {e}")
        exit()
else:
    # First run: embed the split documents and persist the index for reuse.
    try:
        db = FAISS.from_documents(documents, embeddings)
        print("Successfully created vector database")
        os.makedirs(faiss_index_dir, exist_ok=True)  # ensure target dir exists
        db.save_local(faiss_index_dir)
        print(f"FAISS index saved to: {faiss_index_dir}")
    except Exception as e:
        print(f"❌ Failed to create vector database: {e}")
        exit()
# --- Build the RetrievalQA chain ----------------------------------------
print("Creating QA Chain")
llm = OllamaLLM(model=mymodel)

# Conversational memory is optional; flip this flag to disable it.
use_memory = True
memory = None
if use_memory:
    # Buffer memory accumulates the chat history; output_key='result' tells
    # it which of the chain's outputs to record.
    memory = ConversationBufferMemory(
        llm=llm,
        memory_key="chat_history",
        return_messages=True,
        output_key='result',
    )

qa_chain = RetrievalQA.from_chain_type(
    llm,
    chain_type="stuff",
    retriever=db.as_retriever(),
    memory=memory if use_memory else None,  # memory is already None when disabled
    return_source_documents=True,
)
def pretty_print_docs(documents):
    """Print each document's page_content, separated by a 100-dash divider."""
    separator = f"\n{'-' * 100}\n"
    rendered = [
        f"Document {index + 1}:\n\n" + doc.page_content
        for index, doc in enumerate(documents)
    ]
    print(separator.join(rendered))
if __name__ == "__main__":
    # Question posed to the retrieval-QA chain.
    # BUGFIX: "best why" -> "best way" (typo in the original prompt text).
    prompt = """
    What is the best way to understand business consolidation?
    """
    query = prompt  # Simplified question
    print(f"Question: {query}")
    try:
        result = qa_chain.invoke({"query": query})
        print(f"Answer: {result['result']}")
        if 'source_documents' in result:
            print("\nSource Documents:")
            pretty_print_docs(result['source_documents'])
        else:
            print("\nNo source documents found.")
        if use_memory:
            # Dump the accumulated conversation history (only meaningful
            # when use_memory is True, so memory is not None here).
            print("\n--- Conversation History ---")
            print(memory.load_memory_variables({}))
    except Exception as e:
        print(f"❌ QA execution failed: {e}")
    print("Program execution finished")

    # Report the on-disk index size after the run, mirroring the initial check.
    final_size = 0
    if os.path.exists(os.path.join(faiss_index_dir, "index.faiss")):
        final_size = os.path.getsize(os.path.join(faiss_index_dir, "index.faiss"))
    print(f"Final FAISS index size: {final_size} bytes")