Ragger.py
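"""RAG retrieval utilities built on LangChain, Chroma, and Ollama.

Defines RAGRetriever, which loads a text file, chunks it (recursively,
semantically, by fixed windows, or via an LLM), embeds the chunks with a
BGE model, and answers queries through single-vector or multi-vector
retrieval chains.
"""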
import json
import logging
import uuid

import bs4
import faiss
import ollama
from langchain import hub
from langchain.embeddings import HuggingFaceBgeEmbeddings, HuggingFaceEmbeddings, OllamaEmbeddings
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryByteStore
from langchain_chroma import Chroma
from langchain_community.chat_models import ChatOllama
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_experimental.text_splitter import SemanticChunker
from langchain_text_splitters import RecursiveCharacterTextSplitter

logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
class RAGRetriever():
    def __init__(self, text_path, device='cuda', embed_device='cpu', vecStoreDir='./',
                 GeneratorModel='llama3:70b', temp=0):
        # Keep both the raw text (for LLM-based chunking) and the loaded Document list.
        with open(text_path, "r") as f:
            self.text_txt = f.read()
        self.text = TextLoader(text_path).load()
        '''
        self.embedding_function = OllamaEmbeddings(
            model="llama3:8b", temperature=0)
        '''
        # BGE embeddings with normalized vectors so similarity scores are comparable.
        model_name = "BAAI/bge-large-en-v1.5"
        model_kwargs = {'device': embed_device}
        encode_kwargs = {"normalize_embeddings": True}
        self.embedding_function = HuggingFaceBgeEmbeddings(
            model_name=model_name,
            model_kwargs=model_kwargs,
            encode_kwargs=encode_kwargs)
        self.vecStoreDir = vecStoreDir
        self.chatModel = ChatOllama(model=GeneratorModel, temperature=temp,
                                    device=device)

    def update_chat_model(self, GeneratorModel='llama3:70b', temp=0, device='cuda'):
        # Swap the generator model without rebuilding embeddings or vector stores.
        self.chatModel = ChatOllama(model=GeneratorModel, temperature=temp, device=device)
    def get_chunks(self, chunk_size=1000, chunk_overlap=500):
        # Recursively split the loaded documents into overlapping character chunks.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            #separators=["\n\n"],
            length_function=len,
            is_separator_regex=False,
        )
        #texts = text_splitter.create_documents([self.text])
        texts = text_splitter.split_documents(self.text)
        '''
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size*.2,
            chunk_overlap=chunk_overlap*.2,
            length_function=len,
            is_separator_regex=False,
        )
        texts2 = text_splitter.split_documents(self.text)
        texts = texts + texts2
        '''
        self.chunks = texts
        return texts
    def get_chunks_2(self, chunk_size, chunk_overlap):
        # Plain fixed-size character chunking over the raw text, with overlap.
        chunks = []
        start = 0
        while start < len(self.text_txt):
            end = min(start + chunk_size, len(self.text_txt))
            chunks.append(self.text_txt[start:end])
            # Move the start to the next chunk, ensuring overlap
            start += chunk_size - chunk_overlap
        return chunks
    def semantic_split(self, docs: list[Document]) -> list[Document]:
        """Semantic chunking.

        Args:
            docs (list[Document]): List of documents to chunk.

        Returns:
            list[Document]: List of chunked documents.
        """
        splitter = SemanticChunker(
            self.embedding_function, breakpoint_threshold_type="gradient"
        )
        return splitter.split_documents(docs)
    def createVecStore_multiVec(self, chunk_size=1000, chunk_overlap=500):
        # Parent/child (multi-vector) setup: small child chunks are embedded for search,
        # while the full parent chunks are returned from the docstore at query time.
        store = InMemoryByteStore()
        id_key = "doc_id"
        # The retriever (empty to start)
        docs = self.get_chunks(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        retriever = MultiVectorRetriever(
            vectorstore=Chroma(collection_name="summaries", embedding_function=self.embedding_function),
            byte_store=store,
            id_key=id_key,
            search_kwargs={"k": 4},
            search_type='mmr')
        doc_ids = [str(uuid.uuid4()) for _ in docs]
        child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=200,
                                                             chunk_overlap=50,
                                                             length_function=len,
                                                             is_separator_regex=False)
        sub_docs = []
        for i, doc in enumerate(docs):
            _id = doc_ids[i]
            _sub_docs = child_text_splitter.split_documents([doc])
            for _doc in _sub_docs:
                _doc.metadata[id_key] = _id
                #_doc.metadata["original_chunk"] = doc.page_content # Store original chunk content
            sub_docs.extend(_sub_docs)
        retriever.vectorstore.add_documents(sub_docs)
        # Map each parent chunk to its id so a matching child retrieves the full parent.
        retriever.docstore.mset(list(zip(doc_ids, docs)))
        #retriever1 = MultiQueryRetriever.from_llm( retriever=retriever,
        #                                           llm=self.chatModel)
        self.MultiRetriever = retriever
        #return retriever, doc_ids
    def get_chunksLLM(self):
        # Ask an LLM to insert "chunk_here" separators at semantically coherent boundaries.
        prompt = """
        For processing text it is essential to chunk text for LLMs.
        Chunks must be coherent and related.
        Instructions:
        1- Related text must be in one chunk; just add the separator "chunk_here"
        before each chunk.
        2- You must not modify or delete any text and must reproduce the whole text.
        3- Read each line and understand its relation to the previous text.
        Adhere strictly to the instructions.
        Here is the text:
        """
        prompt = prompt + self.text_txt
        response = ollama.chat(model='llama3:8b', messages=[
            {
                'role': 'user',
                'content': prompt,
            },
        ])
        res = response['message']['content']
        with open("DocLLM.txt", "w") as text_file:
            text_file.write(res)
        print(res)
        return res
    def createVecStore(self, chunk_size=1000, chunk_overlap=500):
        # Build an in-memory Chroma store over the recursive chunks.
        self.db = Chroma.from_documents(
            self.get_chunks(chunk_size=chunk_size, chunk_overlap=chunk_overlap),
            self.embedding_function)  #, persist_directory=self.vecStoreDir)
        return self.db
    def retrieve(self, Q='query'):
        try:
            retriever = self.db.as_retriever(search_kwargs={"k": 2})
            #retriever1 = self.multiVec()
            #retriever = MultiQueryRetriever.from_llm( retriever=retriever,
            #                                          llm=self.chatModel)
            #query_result = retriever.retriever.vectorstore.search(Q) # Assuming a search method that returns results
            #document_ids_used = [doc.metadata['doc_id'] for doc in query_result]
            self.retrieverTest = retriever
        except AttributeError:
            # self.db has not been created yet, so there is nothing to retrieve from.
            print('Create or load the vector store first, '
                  'using createVecStore() or load_vectorStore()')
            return None
        self.prompt = hub.pull("rlm/rag-prompt")
        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | self.prompt
            | self.chatModel
            | StrOutputParser()
        )
        return rag_chain.invoke(Q)
    def mRetriever(self, Q='query', hint=''):
        #retriever_context, ids = self.multiVec(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        RETRIEVER = self.MultiRetriever
        #self.used_documents = retriever_context.get_used_documents()
        #CH.MultiRetriever.vectorstore.get()
        # Prompt template
        template = """Answer the question based only on the following context, which can include text and tables.
        {context}
        Question: {question}
        You are only a retriever and you should only output the text as requested (copy paste). If the answer is not found, only write 'not found'.
        """
        if len(hint) > 0:
            template = template + 'Hint: ' + hint
        prompt = ChatPromptTemplate.from_template(template)
        self.prompt = prompt
        self.template = template
        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)
        chain = ({"context": RETRIEVER | format_docs,
                  "question": RunnablePassthrough()} | prompt | self.chatModel | StrOutputParser())
        return chain.invoke(Q)
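

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how the class might be driven end to end. The file path
# and model name below are placeholders/assumptions: it presumes a local
# "document.txt" and a running Ollama server with the referenced model pulled.
if __name__ == "__main__":
    rag = RAGRetriever(
        "document.txt",              # placeholder path to the source text
        device="cuda",
        embed_device="cpu",
        GeneratorModel="llama3:8b",  # assumed smaller model for a quick test
    )

    # Single-vector retrieval over a plain Chroma store.
    rag.createVecStore(chunk_size=1000, chunk_overlap=500)
    print(rag.retrieve("What is the main topic of the document?"))

    # Multi-vector (parent/child) retrieval.
    rag.createVecStore_multiVec(chunk_size=1000, chunk_overlap=500)
    print(rag.mRetriever("Quote the passage that describes the methodology."))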