feat(huixiangdou): add code retrieval (#379)
* feat(huixiangdou): add bm25 retriever

* feat(huixiangdou): support code search
tpoisonooo authored Sep 2, 2024
1 parent d301d37 commit 96c1f5d
Showing 18 changed files with 381 additions and 43 deletions.
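The diff below wires a new `enable_code_search` flag from `gradio_ui.py` into `ParallelPipeline.generate`. As a rough sketch of how that flag is consumed, based only on the calls visible in this commit (the work dir, config path, and query text are placeholder assumptions, and the fields of the yielded `sess` object are not shown in this diff):

```python
import asyncio

from huixiangdou import ParallelPipeline  # re-exported in huixiangdou/__init__.py


async def ask(query: str):
    # Placeholder paths -- point these at your own feature-store work dir and config.
    assistant = ParallelPipeline(work_dir='workdir', config_path='config.ini')
    args = {'query': query, 'history': [], 'language': 'en'}
    args['enable_web_search'] = False
    args['enable_code_search'] = True  # the flag introduced by this commit
    async for sess in assistant.generate(**args):
        # `sess` is the streamed pipeline state; its exact fields are outside this diff.
        print(sess)


asyncio.run(ask('how does the bm25 code retriever work?'))
```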
1 change: 1 addition & 0 deletions .gitignore
@@ -57,3 +57,4 @@ evaluation/rejection/gt_good.txt
workdir832/
workdir.bak/
workdir-20240729-kg-included/
bm25.pkl
5 changes: 4 additions & 1 deletion README.md
@@ -48,6 +48,7 @@ If this helps you, please give it a star ⭐

Our Web version has been released to [OpenXLab](https://openxlab.org.cn/apps/detail/tpoisonooo/huixiangdou-web), where you can create a knowledge base, update positive and negative examples, turn on web search, test chat, and integrate into Feishu/WeChat groups. See [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn) and [YouTube](https://www.youtube.com/watch?v=ylXrT-Tei-Y)!

- \[2024/09\] [code retrieval](./huixiangdou/service/parallel_pipeline.py)
- \[2024/08\] [chat_with_readthedocs](https://huixiangdou.readthedocs.io/en/latest/), see [how to integrate](./docs/zh/doc_add_readthedocs.md) 👍
- \[2024/07\] Image and text retrieval & Removal of `langchain` 👍
- \[2024/07\] [Hybrid Knowledge Graph and Dense Retrieval](./docs/en/doc_knowledge_graph.md) improve 1.7% F1 score 🎯
@@ -117,10 +118,12 @@ Our Web version has been released to [OpenXLab](https://openxlab.org.cn/apps/det

<td>

- Dense for Document
- Sparse for Code
- [Knowledge Graph](./docs/en/doc_knowledge_graph.md)
- [Internet Search](./huixiangdou/service/web_search.py)
- [SourceGraph](https://sourcegraph.com)
- Image and text (only markdown)
- Image and Text

</td>

4 changes: 3 additions & 1 deletion README_zh.md
@@ -50,6 +50,7 @@

For Web version video tutorials, see [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn) and [YouTube](https://www.youtube.com/watch?v=ylXrT-Tei-Y)

- \[2024/09\] [Code retrieval](./huixiangdou/service/parallel_pipeline.py) implemented with a sparse method
- \[2024/08\] ["chat_with_readthedocs"](https://huixiangdou.readthedocs.io/zh-cn/latest/), see [how to integrate](./docs/zh/doc_add_readthedocs.md)
- \[2024/07\] Image and text retrieval & removal of `langchain` 👍
- \[2024/07\] [Hybrid Knowledge Graph and Dense Retrieval, improving F1 score by 1.7%](./docs/zh/doc_knowledge_graph.md) 🎯
@@ -119,10 +120,11 @@ For Web version video tutorials, see [BiliBili](https://www.bilibili.com/video/BV1S2421N7mn)

<td>

- Dense for documents, sparse for code
- [Knowledge Graph](./docs/zh/doc_knowledge_graph.md)
- [Internet Search](./huixiangdou/service/web_search.py)
- [SourceGraph](https://sourcegraph.com)
- Image and text (markdown only)
- Image and text

</td>

2 changes: 1 addition & 1 deletion evaluation/README.md
@@ -66,7 +66,7 @@ For bce-embedding-base_v1
For bge-large-zh-v1.5

- The chunksize range should be (423, 1240)
- The compression rate of embedding.tokenzier is slightly lower
- The compression rate of embedding.tokenizer is slightly lower
- The best F1@throttle obtained at the right endpoint is [email protected]

The basis for choosing splitter is:
2 changes: 1 addition & 1 deletion evaluation/README_zh.md
@@ -66,7 +66,7 @@ print(result)
For bge-large-zh-v1.5

- The chunksize range should be (423, 1240)
- The compression rate of embedding.tokenzier is slightly lower
- The compression rate of embedding.tokenizer is slightly lower
- The best F1@throttle obtained at the right endpoint is [email protected]

The basis for choosing splitter is:
2 changes: 1 addition & 1 deletion huixiangdou/__init__.py
@@ -9,4 +9,4 @@
from .service import SerialPipeline, ParallelPipeline # no E401
from .service import build_reply_text # noqa E401
from .service import llm_serve # noqa E401
from .version import __version__
from .version import __version__
20 changes: 17 additions & 3 deletions huixiangdou/gradio_ui.py
@@ -53,6 +53,7 @@ def parse_args():

language='en'
enable_web_search=False
enable_code_search=True
pipeline='chat_with_repo'
main_args = None
paralle_assistant = None
@@ -76,6 +77,13 @@ def on_web_search_changed(value: str):
    else:
        enable_web_search = True

def on_code_search_changed(value: str):
    global enable_code_search
    print(value)
    if 'no' in value:
        enable_code_search = False
    else:
        enable_code_search = True

def format_refs(refs: List[str]):
    refs_filter = list(set(refs))
@@ -92,7 +100,6 @@ def format_refs(refs: List[str]):
text += '\r\n'
return text


async def predict(text:str, image:str):
global language
global enable_web_search
@@ -142,6 +149,7 @@ async def predict(text:str, image:str):
paralle_assistant = ParallelPipeline(work_dir=main_args.work_dir, config_path=main_args.config_path)
args = {'query':query, 'history':[], 'language':language}
args['enable_web_search'] = enable_web_search
args['enable_code_search'] = enable_code_search

sentence = ''
async for sess in paralle_assistant.generate(**args):
@@ -212,6 +220,7 @@ def build_feature_store(main_args):
gr.Markdown("""
#### [HuixiangDou](https://github.com/internlm/huixiangdou) AI assistant
""", label='Reply', header_links=True, line_breaks=True,)

with gr.Row():
if len(radio_options) > 1:
with gr.Column():
@@ -222,17 +231,22 @@
ui_language.change(fn=on_language_changed, inputs=ui_language, outputs=[])
with gr.Column():
ui_web_search = gr.Radio(["no", "yes"], label="Enable web search", info="Disable by default ")
ui_web_search.change(on_web_search_changed, inputs=ui_web_search, outputs=[])
ui_web_search.change(fn=on_web_search_changed, inputs=ui_web_search, outputs=[])
with gr.Column():
ui_code_search = gr.Radio(["yes", "no"], label="Enable code search", info="Enable by default ")
ui_code_search.change(fn=on_code_search_changed, inputs=ui_code_search, outputs=[])

with gr.Row():
input_question = gr.TextArea(label='Input your question', placeholder=main_args.placeholder, show_copy_button=True, lines=9)
input_image = gr.Image(label='[Optional] Image-text retrieval needs `config-multimodal.ini`', render=show_image)

with gr.Row():
run_button = gr.Button()

with gr.Row():
result = gr.Markdown('>Text reply or inner status callback here, depends on `pipeline type`', label='Reply', show_label=True, header_links=True, line_breaks=True, show_copy_button=True)
# result = gr.TextArea(label='Reply', show_copy_button=True, placeholder='Text Reply or inner status callback, depends on `pipeline type`')

run_button.click(predict, [input_question, input_image], [result])
demo.queue()
demo.launch(share=False, server_name='0.0.0.0', debug=True)
3 changes: 2 additions & 1 deletion huixiangdou/primitive/__init__.py
@@ -12,5 +12,6 @@
MarkdownHeaderTextSplitter,
MarkdownTextRefSplitter,
RecursiveCharacterTextSplitter,
nested_split_markdown)
nested_split_markdown, split_python_code)
from .rpm import RPM
from .bm250kapi import BM25Okapi
189 changes: 189 additions & 0 deletions huixiangdou/primitive/bm250kapi.py
@@ -0,0 +1,189 @@
#!/usr/bin/env python
# heavily modified from https://github.com/dorianbrown/rank_bm25/blob/master/rank_bm25.py
import math
import numpy as np
import pickle as pkl
import os
import jieba.analyse

from loguru import logger
from typing import List, Union
from .chunk import Chunk

"""
All of these algorithms have been taken from the paper:
Trotmam et al, Improvements to BM25 and Language Models Examined
Here we implement all the BM25 variations mentioned.
"""

class BM25Okapi:
    def __init__(self, k1=1.5, b=0.75, epsilon=0.25):
        # BM25Okapi parameters
        self.k1 = k1
        self.b = b
        self.epsilon = epsilon

        # dump to pickle
        self.corpus_size = 0
        self.avgdl = 0
        self.doc_freqs = []
        self.idf = {}
        self.doc_len = []
        self.average_idf = 0.0
        self.chunks = []

        # option
        self.tokenizer = jieba.analyse.extract_tags

    def _initialize(self, corpus):
        nd = {}  # word -> number of documents with word
        num_doc = 0
        for document in corpus:
            self.doc_len.append(len(document))
            num_doc += len(document)

            frequencies = {}
            for word in document:
                if word not in frequencies:
                    frequencies[word] = 0
                frequencies[word] += 1
            self.doc_freqs.append(frequencies)

            for word, freq in frequencies.items():
                try:
                    nd[word] += 1
                except KeyError:
                    nd[word] = 1

            self.corpus_size += 1

        self.avgdl = num_doc / self.corpus_size
        return nd

    def _tokenize_corpus(self, corpus):
        tokenized_corpus = self.tokenizer(corpus)
        return tokenized_corpus

    def save(self, chunks: List[Chunk], filedir: str):
        # generate idf with corpus
        self.chunks = chunks

        filtered_corpus = []
        for c in chunks:
            content = c.content_or_path
            if self.tokenizer is not None:
                # input str, output list of str
                corpus = self.tokenizer(content)
                if content not in corpus:
                    corpus.append(content)
            else:
                logger.warning('No tokenizer, use naive split')
                corpus = content.split(' ')
            filtered_corpus.append(corpus)

        nd = self._initialize(filtered_corpus)
        self._calc_idf(nd)

        # dump to `filepath`
        data = {
            'corpus_size': self.corpus_size,
            'avgdl': self.avgdl,
            'doc_freqs': self.doc_freqs,
            'idf': self.idf,
            'doc_len': self.doc_len,
            'average_idf': self.average_idf,
            'chunks': chunks
        }
        logger.info('bm25okapi dump..')
        # logger.info(data)

        if not os.path.exists(filedir):
            os.makedirs(filedir)

        filepath = os.path.join(filedir, 'bm25.pkl')
        with open(filepath, 'wb') as f:
            pkl.dump(data, f)

    def load(self, filedir: str, tokenizer=None):
        self.tokenizer = tokenizer
        filepath = os.path.join(filedir, 'bm25.pkl')
        with open(filepath, 'rb') as f:
            data = pkl.load(f)
            self.corpus_size = data['corpus_size']
            self.avgdl = data['avgdl']
            self.doc_freqs = data['doc_freqs']
            self.idf = data['idf']
            self.doc_len = data['doc_len']
            self.average_idf = data['average_idf']
            self.chunks = data['chunks']

    def _calc_idf(self, nd):
        """
        Calculates frequencies of terms in documents and in corpus.
        This algorithm sets a floor on the idf values to eps * average_idf
        """
        # collect idf sum to calculate an average idf for epsilon value
        idf_sum = 0
        # collect words with negative idf to set them a special epsilon value.
        # idf can be negative if word is contained in more than half of documents
        negative_idfs = []
        for word, freq in nd.items():
            idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5)
            self.idf[word] = idf
            idf_sum += idf
            if idf < 0:
                negative_idfs.append(word)
        self.average_idf = idf_sum / len(self.idf)

        eps = self.epsilon * self.average_idf
        for word in negative_idfs:
            self.idf[word] = eps

    def get_scores(self, query: List):
        """
        The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores,
        this algorithm also adds a floor to the idf value of epsilon.
        See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info
        :param query:
        :return:
        """
        if type(query) is not list:
            raise ValueError('query must be a list; tokenize it yourself.')
        score = np.zeros(self.corpus_size)
        doc_len = np.array(self.doc_len)
        for q in query:
            q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs])
            score += (self.idf.get(q) or 0) * (q_freq * (self.k1 + 1) /
                                               (q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)))
        return score

    def get_batch_scores(self, query, doc_ids):
        """
        Calculate bm25 scores between query and subset of all docs
        """
        assert all(di < len(self.doc_freqs) for di in doc_ids)
        score = np.zeros(len(doc_ids))
        doc_len = np.array(self.doc_len)[doc_ids]
        for q in query:
            q_freq = np.array([(self.doc_freqs[di].get(q) or 0) for di in doc_ids])
            score += (self.idf.get(q) or 0) * (q_freq * (self.k1 + 1) /
                                               (q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)))
        return score.tolist()

    def get_top_n(self, query: Union[List, str], n=5):
        if type(query) is str:
            if self.tokenizer is not None:
                queries = self.tokenizer(query)
            else:
                queries = query.split(' ')
        else:
            queries = query

        scores = self.get_scores(queries)
        top_n = np.argsort(scores)[::-1][:n]
        logger.info('{} {}'.format(scores, top_n))
        if abs(scores[top_n[0]]) < 1e-5:
            # not match, quit
            return []
        return [self.chunks[i] for i in top_n]
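A minimal usage sketch of the retriever above, using only the methods shown in this file. The `Chunk` constructor arguments are an assumption (the class only reads `content_or_path` in `save`), and the on-disk artifact is the `bm25.pkl` added to `.gitignore` earlier in this diff:

```python
from huixiangdou.primitive.bm250kapi import BM25Okapi
from huixiangdou.primitive.chunk import Chunk

# Assumption: Chunk accepts `content_or_path` as a keyword argument; BM25Okapi.save()
# only reads that attribute when building the index.
chunks = [
    Chunk(content_or_path='def build_index(work_dir): ...'),
    Chunk(content_or_path='class WebSearch: ...'),
]

bm25 = BM25Okapi()                    # k1=1.5, b=0.75, epsilon=0.25 by default
bm25.save(chunks, 'workdir/db_code')  # writes workdir/db_code/bm25.pkl

# Reload (possibly in another process) and query. With tokenizer=None the query
# string is split on spaces, which is usually adequate for code identifiers.
retriever = BM25Okapi()
retriever.load('workdir/db_code', tokenizer=None)
for chunk in retriever.get_top_n('build_index', n=5):
    print(chunk.content_or_path)
```

`get_top_n` returns an empty list when the best score is effectively zero, so callers can treat "no BM25 hit" as a signal to fall back to dense retrieval.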
26 changes: 16 additions & 10 deletions huixiangdou/primitive/file_operation.py
@@ -42,7 +42,7 @@ def __init__(self):
self.ppt_suffix = '.pptx'
self.html_suffix = ['.html', '.htm', '.shtml', '.xhtml']
self.word_suffix = ['.docx', '.doc']
# self.code_suffix = ['.py', '.cpp', '.h']
self.code_suffix = ['.py']
self.normal_suffix = [self.md_suffix
] + self.text_suffix + self.excel_suffix + [
self.pdf_suffix
@@ -111,9 +111,9 @@ def get_type(self, filepath: str):
if filepath.endswith(suffix):
return 'html'

# for suffix in self.code_suffix:
# if filepath.endswith(suffix):
# return 'code'
for suffix in self.code_suffix:
if filepath.endswith(suffix):
return 'code'
return None

def md5(self, filepath: str):
@@ -216,15 +216,21 @@ def read(self, filepath: str):
soup = BeautifulSoup(f.read(), 'html.parser')
text += soup.text

elif file_type == 'code':
with open(filepath, errors="ignore") as f:
text += f.read()

except Exception as e:
logger.error((filepath, str(e)))
return '', e
text = text.replace('\n\n', '\n')
text = text.replace('\n\n', '\n')
text = text.replace('\n\n', '\n')
text = text.replace('  ', ' ')
text = text.replace('  ', ' ')
text = text.replace('  ', ' ')

if file_type != 'code':
text = text.replace('\n\n', '\n')
text = text.replace('\n\n', '\n')
text = text.replace('\n\n', '\n')
text = text.replace('  ', ' ')
text = text.replace('  ', ' ')
text = text.replace('  ', ' ')
return text, None


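A rough sketch of how the new `'code'` file type behaves, assuming the class in `file_operation.py` is named `FileOperation` (the class name lies outside the hunks shown above):

```python
from huixiangdou.primitive.file_operation import FileOperation  # class name assumed

opr = FileOperation()

# '.py' now maps to the 'code' type introduced by this commit.
print(opr.get_type('huixiangdou/service/web_search.py'))  # -> 'code'

# read() returns a (text, error) tuple; for 'code' files the whitespace-collapsing
# passes are skipped so indentation survives intact.
text, error = opr.read('huixiangdou/service/web_search.py')
if error is None:
    print(text[:200])
```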
(Diffs for the remaining 8 changed files are not shown.)
