-
Notifications
You must be signed in to change notification settings - Fork 139
/
Copy pathparallel_pipeline.py
389 lines (317 loc) · 14.5 KB
/
parallel_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
# Copyright (c) OpenMMLab. All rights reserved.
"""Pipeline."""
import argparse
import asyncio
import json
import copy
from typing import List, Tuple, Union, Generator, AsyncGenerator
import pytoml
from loguru import logger
from huixiangdou.primitive import Query, Chunk
from .helper import ErrorCode
from .llm_client import ChatClient
from .retriever import CacheRetriever, Retriever
from .session import Session
from .web_search import WebSearch
from .prompt import (INTENTION_TEMPLATE_CN, CR_CN, SCORING_RELAVANCE_TEMPLATE_CN, KEYWORDS_TEMPLATE_CN)
from .prompt import (INTENTION_TEMPLATE_EN, CR_EN, SCORING_RELAVANCE_TEMPLATE_EN, KEYWORDS_TEMPLATE_EN)
from .prompt import CitationGeneratePrompt
class PreprocNode:
    """PreprocNode filters out non-questions and, when enabled, performs
    coreference resolution (CR) of the query against the group-chat history.

    See https://arxiv.org/abs/2405.02817
    """

    def __init__(self, config: dict, llm: ChatClient, language: str):
        """Pick prompt templates and read the CR switch.

        Args:
            config: Parsed configuration; ``config['worker']['enable_cr']``
                toggles coreference resolution.
            llm: Client used for the intention and CR prompts.
            language: 'zh' selects the Chinese templates, anything else the
                English ones.
        """
        self.llm = llm
        self.enable_cr = config['worker']['enable_cr']
        if language == 'zh':
            self.INTENTION_TEMPLATE = INTENTION_TEMPLATE_CN
            self.CR = CR_CN
        else:
            self.INTENTION_TEMPLATE = INTENTION_TEMPLATE_EN
            self.CR = CR_EN

    @staticmethod
    def _strip_json_fence(text: str) -> str:
        """Remove surrounding whitespace and a Markdown ```json ... ``` fence."""
        text = text.strip()
        if text.startswith('```json'):
            text = text[len('```json'):]
        if text.endswith('```'):
            text = text[:-3]
        return text

    @staticmethod
    def _strip_quotes(text: str) -> str:
        """Remove one pair of curly quotes, then one pair of ASCII quotes."""
        if text.startswith('“') and text.endswith('”'):
            text = text[1:-1]
        if text.startswith('"') and text.endswith('"'):
            text = text[1:-1]
        return text

    def process(self, sess: Session) -> Generator[Session, None, None]:
        """Validate and optionally rewrite ``sess.query``; yields ``sess`` once.

        ``sess.code`` is set to ``ErrorCode.QUESTION_TOO_SHORT`` for queries
        shorter than two characters, and to ``ErrorCode.NOT_A_QUESTION`` when
        the LLM classifies the intention/topic as greeting, identity or
        undefined. When CR runs, ``sess.query`` is replaced by a copy whose
        text is the original query followed by the resolved query on a new
        line.
        """
        # check input
        if sess.query.text is None or len(sess.query.text) < 2:
            sess.code = ErrorCode.QUESTION_TOO_SHORT
            yield sess
            return

        prompt = self.INTENTION_TEMPLATE.format(sess.query.text)
        json_str = self.llm.generate_response(prompt=prompt, backend='remote')
        sess.debug['PreprocNode_intention_response'] = json_str
        logger.info('intention response {}'.format(json_str))

        blocked = False
        try:
            json_obj = json.loads(self._strip_json_fence(json_str))
            intention = json_obj['intention']
            intention = intention.lower() if intention is not None else 'undefine'
            topic = json_obj['topic']
            topic = topic.lower() if topic is not None else 'undefine'
            blocked = (any(word in intention for word in ('问候', 'greeting', 'undefine'))
                       or any(word in topic for word in ('身份', 'identity', 'undefine')))
        except Exception as e:
            # Best effort: an unparseable LLM answer must not block the query.
            logger.error(str(e))

        if blocked:
            sess.code = ErrorCode.NOT_A_QUESTION
            yield sess
            return

        if not self.enable_cr:
            yield sess
            return

        if len(sess.groupchats) < 1:
            logger.debug('history conversation empty, skip CR')
            yield sess
            return

        # Rewrite sender ids to A, B, C, ... before showing them to the LLM.
        name_map = dict()
        talks = []
        for msg in sess.groupchats:
            if msg.sender not in name_map:
                name_map[msg.sender] = chr(ord('A') + len(name_map))
            talks.append({'sender': name_map[msg.sender], 'content': msg.query})
        talk_str = json.dumps(talks, ensure_ascii=False)

        prompt = self.CR.format(talk_str, sess.query.text)
        response = self.llm.generate_response(prompt=prompt, backend='remote')
        self.cr = self._strip_quotes(response or '')
        sess.debug['cr'] = self.cr

        # rewrite query: keep the original text and append the resolved one
        if self.cr:
            self.query = '\n'.join([sess.query.text, self.cr])
        else:
            self.query = sess.query.text
        logger.debug('merge query and cr, query: {} cr: {}'.format(
            self.query, self.cr))
        # Copy first so the caller's Query object is not mutated.
        sess.query = copy.copy(sess.query)
        sess.query.text = self.query
        yield sess
class Text2vecRetrieval:
    """Retrieval node that queries the knowledge base by dense embeddings."""

    def __init__(self, retriever: Retriever):
        self.retriever = retriever

    async def process_coroutine(self, sess: Session) -> Session:
        """Store ``retriever.text2vec_retrieve(sess.query)`` on the session.

        The retriever call is synchronous, so it is dispatched to a worker
        thread; the same session object is returned with
        ``parallel_chunks`` replaced by the retrieval result.
        """
        search = self.retriever.text2vec_retrieve
        hits = await asyncio.to_thread(search, sess.query)
        sess.parallel_chunks = hits
        return sess
class InvertedIndexRetrieval:
    """Retrieval node that queries the knowledge base's inverted index."""

    def __init__(self, retriever: Retriever):
        self.retriever = retriever

    async def process_coroutine(self, sess: Session) -> Session:
        """Store ``retriever.inverted_index_retrieve(sess.query)`` on the session.

        The retriever call is synchronous, so it runs in a worker thread to
        let sibling retrieval coroutines progress concurrently. Returns the
        same session with ``parallel_chunks`` replaced by the result.
        """
        sess.parallel_chunks = await asyncio.to_thread(
            self.retriever.inverted_index_retrieve, sess.query)
        return sess
class CodeRetrieval:
    """Retrieval node that searches the codebase through the retriever's
    ``bm25`` index."""

    def __init__(self, retriever: Retriever):
        self.retriever = retriever

    async def process_coroutine(self, sess: Session) -> Session:
        """Store BM25 hits for ``sess.query.text`` on the session.

        ``sess.parallel_chunks`` becomes ``[]`` when the retriever has no
        BM25 index. Returns the same session object.
        """
        bm25 = self.retriever.bm25
        if bm25 is None:
            sess.parallel_chunks = []
            return sess
        # get_top_n is synchronous; run it in a worker thread like the other
        # retrieval nodes so it does not stall the event loop during gather().
        sess.parallel_chunks = await asyncio.to_thread(
            bm25.get_top_n, query=sess.query.text)
        return sess
class WebSearchRetrieval:
    """Retrieval node that searches the web through ``WebSearch`` (the
    original docs name `ddgs` or `serper` as backends)."""

    def __init__(self, config: dict, config_path: str, llm: ChatClient,
                 language: str):
        """Read web-search settings and pick prompt templates.

        Args:
            config: Parsed configuration; reads ``worker.enable_web_search``,
                ``llm.enable_remote`` and the local/remote max text length.
            config_path: Passed through to ``WebSearch``.
            llm: Client used to turn the query into search keywords.
            language: 'zh' selects the Chinese templates, anything else the
                English ones.
        """
        self.llm = llm
        self.config_path = config_path
        self.enable = config['worker']['enable_web_search']
        llm_config = config['llm']
        self.context_max_length = llm_config['server'][
            'local_llm_max_text_length']
        self.language = language
        if llm_config['enable_remote']:
            self.context_max_length = llm_config['server'][
                'remote_llm_max_text_length']
        if language == 'zh':
            self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_CN
            self.KEYWORDS_TEMPLATE = KEYWORDS_TEMPLATE_CN
        else:
            self.SCORING_RELAVANCE_TEMPLATE = SCORING_RELAVANCE_TEMPLATE_EN
            self.KEYWORDS_TEMPLATE = KEYWORDS_TEMPLATE_EN

    async def process(self, sess: Session) -> AsyncGenerator[Session, None]:
        """Search the web and append the articles as chunks; yields ``sess`` once.

        On failure ``sess.code`` is set to ``ErrorCode.WEB_SEARCH_FAIL`` or
        ``ErrorCode.NO_SEARCH_RESULT`` and ``sess.parallel_chunks`` is emptied.
        """
        if not self.enable:
            logger.debug('disable web_search')
            yield sess
            return

        engine = WebSearch(config_path=self.config_path, language=self.language)
        prompt = self.KEYWORDS_TEMPLATE.format(sess.groupname, sess.query.text)
        # generate_response is blocking; run it in a worker thread so the
        # sibling retrieval coroutines in asyncio.gather keep progressing.
        search_keywords = await asyncio.to_thread(self.llm.generate_response, prompt)
        search_keywords = search_keywords.replace('"', '')
        sess.debug['WebSearchNode_prompt'] = prompt
        sess.debug['WebSearchNode_keywords'] = search_keywords

        articles, error = await asyncio.to_thread(engine.get, search_keywords, 4)
        if error is not None:
            sess.code = ErrorCode.WEB_SEARCH_FAIL
            sess.parallel_chunks = []
            yield sess
            return

        if not articles:
            sess.code = ErrorCode.NO_SEARCH_RESULT
            sess.parallel_chunks = []
            yield sess
            return

        chunks = []
        for article in articles:
            # Truncate each article to fit the LLM context budget.
            article.cut(0, self.context_max_length)
            chunks.append(Chunk(content_or_path=article.content,
                                metadata={'source': article.source}))
        sess.parallel_chunks.extend(chunks)
        yield sess

    async def process_coroutine(self, sess: Session) -> Session:
        """Drain ``process`` and return the last yielded session."""
        results = [value async for value in self.process(sess=sess)]
        return results[-1]
class ReduceGenerate:
    """Final node: stream the LLM answer, grounded on retrieved chunks when
    any are present on the session."""

    def __init__(self, config: dict, llm: ChatClient, retriever: Retriever, language: str):
        """Store collaborators and the context budget.

        Args:
            config: Parsed configuration; reads ``llm.enable_remote`` and the
                local/remote max text length.
            llm: Client whose ``chat_stream`` produces the answer.
            retriever: Object providing ``rerank_fuse`` for the chunks.
            language: Passed to ``CitationGeneratePrompt``.
        """
        self.llm = llm
        self.retriever = retriever
        llm_config = config['llm']
        self.context_max_length = llm_config['server']['local_llm_max_text_length']
        if llm_config['enable_remote']:
            self.context_max_length = llm_config['server']['remote_llm_max_text_length']
        self.language = language

    async def process(self, sess: Session) -> AsyncGenerator[Session, None]:
        """Yield ``sess`` once per streamed part, with ``sess.delta`` set.

        Without chunks the raw question is sent; otherwise the chunks are
        reranked/fused, ``sess.references`` is filled, and a citation prompt
        is built from the fused context.
        """
        question = sess.query.text
        history = sess.history
        if len(sess.parallel_chunks) < 1:
            # direct chat
            prompt = question
        else:
            # rerank_fuse is synchronous; keep it off the event loop.
            _, _, references, context_texts = await asyncio.to_thread(
                self.retriever.rerank_fuse,
                query=sess.query,
                chunks=sess.parallel_chunks,
                context_max_length=self.context_max_length)
            sess.references = references
            citation = CitationGeneratePrompt(self.language)
            prompt = citation.build(texts=context_texts, question=question)

        async for part in self.llm.chat_stream(prompt=prompt, history=history):
            sess.delta = part
            yield sess
class ParallelPipeline:
    """Orchestrate query preprocessing, parallel retrieval and streamed
    answer generation.

    Attributes:
        llm: ChatClient used for every language-model call.
        retriever: Retriever obtained from ``CacheRetriever(...).get(work_dir)``.
        config_path: Path of the configuration file.
        config: Dictionary parsed from ``config_path``.
    """

    def __init__(self, work_dir: str, config_path: str):
        """Constructs all the necessary attributes for the pipeline.

        Args:
            work_dir (str): The working directory where feature files are located.
            config_path (str): The location of the configuration file.

        Raises:
            Exception: if the configuration file parses to ``None``.
        """
        self.llm = ChatClient(config_path=config_path)
        self.retriever = CacheRetriever(config_path=config_path).get(work_dir=work_dir)
        self.config_path = config_path
        self.config = None
        with open(config_path, encoding='utf8') as f:
            self.config = pytoml.load(f)
        if self.config is None:
            raise Exception('worker config can not be None')

    async def generate(self,
                       query: Union[Query, str],
                       history: Union[List[Tuple[str]], None] = None,
                       language: str = 'zh',
                       enable_web_search: bool = True,
                       enable_code_search: bool = True):
        """Answer a user query, streaming the result.

        Steps: preprocess (validity / intention / optional CR), run the
        text2vec, inverted-index and optionally web and code retrievals
        concurrently, then stream the LLM answer.

        Args:
            query (Union[Query, str]): User's multimodal query; a plain string
                is wrapped in ``Query(text=...)``.
            history (List[Tuple[str]] | None): Chat history; ``None`` means empty.
            language (str): zh or en.
            enable_web_search (bool): enable web search or not, default value is True.
            enable_code_search (bool): enable code search or not, default value is True.

        Yields:
            Session: the same session once per streamed part, carrying
            ``code`` (ErrorCode), ``delta`` (latest generated text) and
            ``references`` (referenced filenames or web urls).
        """
        # A fresh list per call; a shared default list would leak history
        # between calls if anything downstream mutates it.
        if history is None:
            history = []
        if isinstance(query, str):
            query = Query(text=query)

        sess = Session(query=query,
                       history=history,
                       log_path=self.config['worker']['save_path'])

        # build pipeline
        preproc = PreprocNode(self.config, self.llm, language)
        text2vec = Text2vecRetrieval(self.retriever)
        inverted_index = InvertedIndexRetrieval(self.retriever)
        coderetrieval = CodeRetrieval(self.retriever)
        websearch = WebSearchRetrieval(self.config, self.config_path, self.llm, language)
        reduce = ReduceGenerate(self.config, self.llm, self.retriever, language)

        direct_chat_states = [
            ErrorCode.QUESTION_TOO_SHORT, ErrorCode.NOT_A_QUESTION,
            ErrorCode.NO_TOPIC, ErrorCode.UNRELATED
        ]

        # if not a good question, answer directly without retrieval
        for processed in preproc.process(sess):
            sess = processed
            if sess.code in direct_chat_states:
                async for resp in reduce.process(sess):
                    yield resp
                return

        # Run retrievals concurrently; each gets its own deep copy so they
        # never share mutable session state.
        tasks = [
            text2vec.process_coroutine(copy.deepcopy(sess)),
            inverted_index.process_coroutine(copy.deepcopy(sess)),
        ]
        if enable_web_search:
            tasks.append(websearch.process_coroutine(copy.deepcopy(sess)))
        if enable_code_search:
            tasks.append(coderetrieval.process_coroutine(copy.deepcopy(sess)))

        task_results = await asyncio.gather(*tasks, return_exceptions=True)
        for result in task_results:
            if isinstance(result, Session):
                # A retriever returning None must not break the merge.
                sess.parallel_chunks += result.parallel_chunks or []
            else:
                # return_exceptions=True: failed branches come back as exceptions.
                logger.error(result)

        async for resp in reduce.process(sess):
            yield resp
def parse_args(argv=None):
    """Parse command-line arguments for the ParallelPipeline demo.

    Args:
        argv: Argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``work_dir`` and ``config_path``.
    """
    parser = argparse.ArgumentParser(description='ParallelPipeline.')
    parser.add_argument('work_dir', type=str, help='Working directory.')
    parser.add_argument(
        '--config_path',
        default='config.ini',
        help='ParallelPipeline configuration path. Default value is config.ini')
    return parser.parse_args(argv)
if __name__ == '__main__':

    async def _run_demo(bot, queries):
        """Stream answers for each query, then print its references."""
        for q in queries:
            # Stays None if generate() yields nothing, instead of a NameError.
            references = None
            async for sess in bot.generate(query=q, history=[], enable_web_search=False):
                print(sess.delta, end='', flush=True)
                references = sess.references
            print('\n')
            print(references)

    args = parse_args()
    bot = ParallelPipeline(work_dir=args.work_dir, config_path=args.config_path)
    queries = ['茴香豆是什么?', 'HuixiangDou 是什么?']
    # asyncio.run replaces the deprecated get_event_loop() pattern.
    asyncio.run(_run_demo(bot, queries))