-
Notifications
You must be signed in to change notification settings - Fork 153
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(huixiangdou): add code retrieval (#379)
* feat(huixiangdou): add bm25 retriever * feat(huixiangdou): support code search
- Loading branch information
1 parent
d301d37
commit 96c1f5d
Showing
18 changed files
with
381 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,3 +57,4 @@ evaluation/rejection/gt_good.txt | |
workdir832/ | ||
workdir.bak/ | ||
workdir-20240729-kg-included/ | ||
bm25.pkl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ For bce-embedding-base_v1 | |
For bge-large-zh-v1.5 | ||
|
||
- The chunksize range should be (423, 1240) | ||
- The compression rate of embedding.tokenzier is slightly lower | ||
- The compression rate of embedding.tokenizer is slightly lower | ||
- The best F1@throttle obtained on the right value is [email protected] | ||
|
||
The basis for choosing splitter is: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ print(result) | |
对 bge-large-zh-v1.5 | ||
|
||
- chunksize 范围应在 (423, 1240) | ||
- embedding.tokenzier 的压缩率略低 | ||
- embedding.tokenizer 的压缩率略低 | ||
- 右值取到的最佳 F1@throttle 为 [email protected] | ||
|
||
splitter 选择依据 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
#!/usr/bin/env python | ||
# heavily modified from https://github.com/dorianbrown/rank_bm25/blob/master/rank_bm25.py | ||
import math | ||
import numpy as np | ||
import pickle as pkl | ||
import os | ||
import jieba.analyse | ||
|
||
from loguru import logger | ||
from typing import List, Union | ||
from .chunk import Chunk | ||
|
||
""" | ||
All of these algorithms have been taken from the paper: | ||
Trotmam et al, Improvements to BM25 and Language Models Examined | ||
Here we implement all the BM25 variations mentioned. | ||
""" | ||
|
||
class BM25Okapi: | ||
def __init__(self, k1=1.5, b=0.75, epsilon=0.25): | ||
# BM25Okapi parameters | ||
self.k1 = k1 | ||
self.b = b | ||
self.epsilon = epsilon | ||
|
||
# dump to pickle | ||
self.corpus_size = 0 | ||
self.avgdl = 0 | ||
self.doc_freqs = [] | ||
self.idf = {} | ||
self.doc_len = [] | ||
self.average_idf = 0.0 | ||
self.chunks = [] | ||
|
||
# option | ||
self.tokenizer = jieba.analyse.extract_tags | ||
|
||
def _initialize(self, corpus): | ||
nd = {} # word -> number of documents with word | ||
num_doc = 0 | ||
for document in corpus: | ||
self.doc_len.append(len(document)) | ||
num_doc += len(document) | ||
|
||
frequencies = {} | ||
for word in document: | ||
if word not in frequencies: | ||
frequencies[word] = 0 | ||
frequencies[word] += 1 | ||
self.doc_freqs.append(frequencies) | ||
|
||
for word, freq in frequencies.items(): | ||
try: | ||
nd[word]+=1 | ||
except KeyError: | ||
nd[word] = 1 | ||
|
||
self.corpus_size += 1 | ||
|
||
self.avgdl = num_doc / self.corpus_size | ||
return nd | ||
|
||
def _tokenize_corpus(self, corpus): | ||
tokenized_corpus = self.tokenizer(corpus) | ||
return tokenized_corpus | ||
|
||
def save(self, chunks:List[Chunk], filedir:str): | ||
# generate idf with corpus | ||
self.chunks = chunks | ||
|
||
filtered_corpus = [] | ||
for c in chunks: | ||
content = c.content_or_path | ||
if self.tokenizer is not None: | ||
# input str, output list of str | ||
corpus = self.tokenizer(content) | ||
if content not in corpus: | ||
corpus.append(content) | ||
else: | ||
logger.warning('No tokenizer, use naive split') | ||
corpus = content.split(' ') | ||
filtered_corpus.append(corpus) | ||
|
||
nd = self._initialize(filtered_corpus) | ||
self._calc_idf(nd) | ||
|
||
# dump to `filepath` | ||
data = { | ||
'corpus_size': self.corpus_size, | ||
'avgdl': self.avgdl, | ||
'doc_freqs': self.doc_freqs, | ||
'idf': self.idf, | ||
'doc_len': self.doc_len, | ||
'average_idf': self.average_idf, | ||
'chunks': chunks | ||
} | ||
logger.info('bm250kpi dump..') | ||
# logger.info(data) | ||
|
||
if not os.path.exists(filedir): | ||
os.makedirs(filedir) | ||
|
||
filepath = os.path.join(filedir, 'bm25.pkl') | ||
with open(filepath, 'wb') as f: | ||
pkl.dump(data, f) | ||
|
||
def load(self, filedir: str, tokenizer=None): | ||
self.tokenizer = tokenizer | ||
filepath = os.path.join(filedir, 'bm25.pkl') | ||
with open(filepath, 'rb') as f: | ||
data = pkl.load(f) | ||
self.corpus_size = data['corpus_size'] | ||
self.avgdl = data['avgdl'] | ||
self.doc_freqs = data['doc_freqs'] | ||
self.idf = data['idf'] | ||
self.doc_len = data['doc_len'] | ||
self.average_idf = data['average_idf'] | ||
self.chunks = data['chunks'] | ||
|
||
def _calc_idf(self, nd): | ||
""" | ||
Calculates frequencies of terms in documents and in corpus. | ||
This algorithm sets a floor on the idf values to eps * average_idf | ||
""" | ||
# collect idf sum to calculate an average idf for epsilon value | ||
idf_sum = 0 | ||
# collect words with negative idf to set them a special epsilon value. | ||
# idf can be negative if word is contained in more than half of documents | ||
negative_idfs = [] | ||
for word, freq in nd.items(): | ||
idf = math.log(self.corpus_size - freq + 0.5) - math.log(freq + 0.5) | ||
self.idf[word] = idf | ||
idf_sum += idf | ||
if idf < 0: | ||
negative_idfs.append(word) | ||
self.average_idf = idf_sum / len(self.idf) | ||
|
||
eps = self.epsilon * self.average_idf | ||
for word in negative_idfs: | ||
self.idf[word] = eps | ||
|
||
def get_scores(self, query: List): | ||
""" | ||
The ATIRE BM25 variant uses an idf function which uses a log(idf) score. To prevent negative idf scores, | ||
this algorithm also adds a floor to the idf value of epsilon. | ||
See [Trotman, A., X. Jia, M. Crane, Towards an Efficient and Effective Search Engine] for more info | ||
:param query: | ||
:return: | ||
""" | ||
if type(query) is not list: | ||
raise ValueError('query must be list, tokenize it byself.') | ||
score = np.zeros(self.corpus_size) | ||
doc_len = np.array(self.doc_len) | ||
for q in query: | ||
q_freq = np.array([(doc.get(q) or 0) for doc in self.doc_freqs]) | ||
score += (self.idf.get(q) or 0) * (q_freq * (self.k1 + 1) / | ||
(q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl))) | ||
return score | ||
|
||
def get_batch_scores(self, query, doc_ids): | ||
""" | ||
Calculate bm25 scores between query and subset of all docs | ||
""" | ||
assert all(di < len(self.doc_freqs) for di in doc_ids) | ||
score = np.zeros(len(doc_ids)) | ||
doc_len = np.array(self.doc_len)[doc_ids] | ||
for q in query: | ||
q_freq = np.array([(self.doc_freqs[di].get(q) or 0) for di in doc_ids]) | ||
score += (self.idf.get(q) or 0) * (q_freq * (self.k1 + 1) / | ||
(q_freq + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl))) | ||
return score.tolist() | ||
|
||
def get_top_n(self, query: Union[List,str], n=5): | ||
if type(query) is str: | ||
if self.tokenizer is not None: | ||
queries = self.tokenizer(query) | ||
else: | ||
queries = query.split(' ') | ||
else: | ||
queries = query | ||
|
||
scores = self.get_scores(queries) | ||
top_n = np.argsort(scores)[::-1][:n] | ||
logger.info('{} {}'.format(scores, top_n)) | ||
if abs(scores[top_n[0]]) < 1e-5: | ||
# not match, quit | ||
return [] | ||
return [self.chunks[i] for i in top_n] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.