
Commit

n-gram sliding window support to take into account local token orders and combinations
Guest400123064 committed Apr 18, 2024
1 parent 0bc177d commit d24da13
Showing 3 changed files with 111 additions and 38 deletions.
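As context for the change below: the store now matches not only single tokens but also short token windows, so local token order and combinations contribute to scoring. A minimal sketch of that augmentation, not part of this commit; the token list and window sizes are invented for illustration:

    from itertools import chain


    def sliding_ngrams(tokens: list[str], n: int) -> list[tuple[str, ...]]:
        # Every contiguous window of width n over the token list.
        return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]


    # Invented token sequence; the store itself tokenizes with SentencePiece.
    tokens = ["new", "york", "city"]

    # Augment unigrams with bigrams so "new york" can score as a unit, too.
    augmented = list(chain(*(sliding_ngrams(tokens, n) for n in (1, 2))))
    print(augmented)
    # [('new',), ('york',), ('city',), ('new', 'york'), ('york', 'city')]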
2 changes: 1 addition & 1 deletion src/bbm25_haystack/__about__.py
@@ -2,4 +2,4 @@
#
# SPDX-License-Identifier: Apache-2.0

__version__ = "0.1.3"
__version__ = "0.2.0"
4 changes: 2 additions & 2 deletions src/bbm25_haystack/bbm25_retriever.py
@@ -115,10 +115,10 @@ def run(
        sim = self.document_store._retrieval(query, filters=filters, top_k=top_k)

        ret = []
-        for doc, score in sim:
+        for doc, scr in sim:
            data = doc.to_dict()
            if self.set_score:
-                data["score"] = score
+                data["score"] = scr
            ret.append(Document.from_dict(data))

        return {"documents": ret}
Expand Down
143 changes: 108 additions & 35 deletions src/bbm25_haystack/bbm25_store.py
@@ -4,9 +4,12 @@
import heapq
import math
import os
-from collections import Counter
+from collections import Counter, deque
+from collections.abc import Iterable
+from itertools import chain
from typing import Any, Final, Optional, Union

+import pandas as pd
from haystack import Document, default_from_dict, default_to_dict, logging
from haystack.document_stores.errors import (
    DuplicateDocumentError,
@@ -21,6 +24,27 @@
logger = logging.getLogger(__name__)


+def _n_grams(seq: Iterable[str], n: int):
+    """
+    Returns a sliding window (of width n) over data from the iterable.
+
+    :param seq: the input token sequence.
+    :type seq: Iterable[str]
+
+    :param n: the window size.
+    :type n: int
+
+    :return: the n-gram window generator.
+    :rtype: Generator[tuple[str], None, None]
+    """
+    it = iter(seq)
+    wd = deque((next(it, None) for _ in range(n)), maxlen=n)
+
+    yield tuple(wd)
+    for el in it:
+        wd.append(el)
+        yield tuple(wd)


class BetterBM25DocumentStore:
"""
An in-memory document store intended to improve the default BM25 document
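A quick standalone check of the `_n_grams` generator added above; the function body is copied from this hunk, and the sample sequences are invented:

    from collections import deque
    from collections.abc import Iterable


    def _n_grams(seq: Iterable[str], n: int):
        # Sliding window of width n; deque(maxlen=n) evicts the oldest token on append.
        it = iter(seq)
        wd = deque((next(it, None) for _ in range(n)), maxlen=n)

        yield tuple(wd)
        for el in it:
            wd.append(el)
            yield tuple(wd)


    print(list(_n_grams(["a", "b", "c", "d"], 2)))
    # [('a', 'b'), ('b', 'c'), ('c', 'd')]
    print(list(_n_grams(["a"], 2)))
    # [('a', None)] -- sequences shorter than n are padded with None

The deque with maxlen=n keeps each append O(1) while automatically dropping the token that falls out of the window.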
@@ -38,6 +62,7 @@ def __init__(
        b: float = 0.75,
        delta: float = 1.0,
        sp_file: Optional[str] = None,
+        n_grams: Union[int, tuple[int, int]] = 2,
        haystack_filter_logic: bool = True,
    ) -> None:
        """
@@ -63,6 +88,8 @@
        :param sp_file: the SentencePiece model file to use for
            tokenization.
        :type sp_file: Optional[str], optional
+        :param n_grams: the n-gram window size.
+        :type n_grams: Optional[Union[int, tuple[int, int]]], optional
        :param haystack_filter_logic: Whether to use the Haystack
            filter logic or the one implemented in this store,
            which is more conservative.
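A usage sketch for the new parameter; the import path is assumed from the package layout, and indexing a corpus is omitted:

    from bbm25_haystack import BetterBM25DocumentStore  # assumed export

    # An int n expands to window sizes 1..n (here unigrams + bigrams), per
    # _parse_n_grams below; a (min, max) tuple restricts the range explicitly.
    store_default = BetterBM25DocumentStore(n_grams=2)
    store_bi_tri = BetterBM25DocumentStore(n_grams=(2, 3))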
@@ -76,10 +103,8 @@
        # delete it; this will not affect the ranking
        self.delta = delta / (self.k + 1.0)

-        self._sp_file = sp_file
-        self._sp_inst = SentencePieceProcessor(
-            model_file=(self._sp_file or self.default_sp_file)
-        )
+        self._parse_sp_file(sp_file=sp_file)
+        self._parse_n_grams(n_grams=n_grams)

        self._haystack_filter_logic = haystack_filter_logic
        self._filter_func = (
@@ -90,58 +115,103 @@

        self._avg_doc_len: float = 0.0
        self._freq_doc: Counter = Counter()
-        self._index: dict[str, tuple[Document, dict[str, int], int]] = {}
+        self._index: dict[str, tuple[Document, dict[tuple[str], int], int]] = {}

+    def _parse_sp_file(self, sp_file: Optional[str]) -> None:
+        self._sp_file = sp_file

+        if sp_file is None:
+            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
+            return

+        if not os.path.exists(sp_file) or not os.path.isfile(sp_file):
+            msg = (
+                f"Tokenizer model file '{sp_file}' not accessible; "
+                f"fallback to default {self.default_sp_file}."
+            )
+            logger.warn(msg)
+            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
+            return

+        try:
+            self._sp_inst = SentencePieceProcessor(model_file=sp_file)
+        except Exception as exc:
+            msg = (
+                f"Failed to load tokenizer model file '{sp_file}': {exc}; "
+                f"fallback to default {self.default_sp_file}."
+            )
+            logger.error(msg)
+            self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)

+    def _parse_n_grams(self, n_grams: Optional[Union[int, tuple[int, int]]]) -> None:
+        self._n_grams = n_grams

+        if isinstance(n_grams, int):
+            self._n_grams_min = 1
+            self._n_grams_max = n_grams
+            return

+        if isinstance(n_grams, tuple):
+            self._n_grams_min, self._n_grams_max = n_grams
+            if not all(isinstance(n, int) for n in n_grams):
+                msg = f"Invalid n-gram window size: {n_grams}."
+                raise ValueError(msg)
+            return

+        msg = f"Invalid n-gram window size: {n_grams}; expected int or tuple."
+        raise ValueError(msg)

-    def _tokenize(self, texts: Union[str, list[str]]) -> list[list[str]]:
+    def _tokenize(self, texts: Union[str, list[str]]) -> list[list[tuple[str]]]:
        """
        Tokenize input text using SentencePiece model.
        The input text can either be a single string or a list of strings,
-        such as a single user query or a group of raw document.
+        such as a single user query or a group of raw documents. The tokenized
+        text will be augmented into sets of n-grams.
        :param texts: the input text to tokenize.
        :type texts: Union[str, list[str]]
-        :return: the tokenized text.
-        :rtype: list[list[str]]
-        """
-        if isinstance(texts, str):
-            texts = [texts]
-        return self._sp_inst.encode(texts, out_type=str)

-    def _compute_idf(self, tokens: list[str]) -> dict[str, float]:
+        :return: the tokenized text, with n-grams augmented.
+        :rtype: list[list[tuple[str]]]
        """
-        Calculate the inverse document frequency for each token.

-        :param tokens: the tokens to calculate the IDF for.
-        :type tokens: list[str]
+        def _augment_to_n_grams(tokens: list[str]) -> list[tuple[str]]:
+            it = (
+                _n_grams(tokens, n)
+                for n in range(self._n_grams_min, self._n_grams_max + 1)
+            )
+            return list(chain(*it))

-        :return: the IDF for each token.
-        :rtype: dict[str, float]
-        """
-        cnt = lambda token: self._freq_doc.get(token, 0)
-        idf = {
-            t: math.log(1 + (len(self._index) - cnt(t) + 0.5) / (cnt(t) + 0.5))
-            for t in tokens
-        }
-        return idf
+        if isinstance(texts, str):
+            texts = [texts]
+        return [
+            _augment_to_n_grams(tokens)
+            for tokens in self._sp_inst.encode(texts, out_type=str)
+        ]

    def _compute_bm25plus(
        self,
-        idf: dict[str, float],
+        query: str,
        documents: list[Document],
    ) -> list[tuple[Document, float]]:
        """
        Calculate the BM25+ score for all documents in this index.
-        :param idf: the IDF for each token.
-        :type idf: dict[str, float]
+        :param query: the query to calculate the BM25+ score for.
+        :type query: str
        :param documents: the pool of documents to calculate the BM25+ score for.
        :type documents: list[Document]
        :return: the BM25+ scores for all documents.
        :rtype: list[tuple[Document, float]]
        """
+        cnt = lambda ng: self._freq_doc.get(ng, 0)
+        idf = {
+            ng: math.log(1 + (len(self._index) - cnt(ng) + 0.5) / (cnt(ng) + 0.5))
+            for ng in self._tokenize(query)[0]
+        }

        sim = []
        for doc in documents:
            _, freq, doc_len = self._index[doc.id]
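To make the scoring change concrete: a small standalone sketch of the IDF term now computed inside `_compute_bm25plus`, using invented document frequencies rather than the store's real index:

    import math

    # Invented statistics: 100 indexed documents and per-n-gram document frequencies.
    num_docs = 100
    doc_freq = {("new",): 40, ("york",): 25, ("new", "york"): 10}


    def idf(ngram: tuple[str, ...]) -> float:
        # Same form as in the hunk: log(1 + (N - df + 0.5) / (df + 0.5)).
        df = doc_freq.get(ngram, 0)
        return math.log(1 + (num_docs - df + 0.5) / (df + 0.5))


    for ng in [("new",), ("new", "york"), ("unseen", "bigram")]:
        print(ng, round(idf(ng), 3))

Rarer n-grams (the bigram here, and unseen ones most of all) receive larger IDF weights, so matching a longer window contributes more to the BM25+ score than matching its individual tokens alone.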
@@ -183,9 +253,7 @@ def _retrieval(
        if not documents:
            return []

-        idf = self._compute_idf(self._tokenize(query)[0])
-        sim = self._compute_bm25plus(idf, documents)

+        sim = self._compute_bm25plus(query, documents)
        if top_k is None:
            return sorted(sim, key=lambda x: x[1], reverse=True)
        return heapq.nlargest(top_k, sim, key=lambda x: x[1])
@@ -261,7 +329,11 @@ def write_documents(
                )
                self.delete_documents([doc.id])

-            tokens = self._tokenize(doc.content or "")[0]
+            content = doc.content or ""
+            if content == "" and isinstance(doc.dataframe, pd.DataFrame):
+                content = doc.dataframe.astype(str).to_csv(index=False)

+            tokens = self._tokenize(content)[0]

            self._index[doc.id] = (doc, Counter(tokens), len(tokens))
            self._freq_doc.update(set(tokens))
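A sketch of the new content fallback: a Document with no text but with a dataframe now has the table serialized to CSV text before tokenization. The toy table below is invented; only the pandas conversion mirrors the hunk:

    import pandas as pd

    # Invented table standing in for Document.dataframe.
    df = pd.DataFrame({"city": ["Berlin", "Paris"], "population_m": [3.7, 2.1]})

    # Same conversion as above: stringify cells and dump to CSV without the index,
    # so the table becomes plain text the SentencePiece tokenizer can consume.
    content = df.astype(str).to_csv(index=False)
    print(content)
    # city,population_m
    # Berlin,3.7
    # Paris,2.1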
@@ -317,6 +389,7 @@ def to_dict(self) -> dict[str, Any]:
            b=self.b,
            delta=self.delta * (self.k + 1.0),  # Because we scaled it on init
            sp_file=self._sp_file,
+            n_grams=self._n_grams,
            haystack_filter_logic=self._haystack_filter_logic,
        )
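Because `n_grams` is now included in `to_dict()`, a serialization round trip should preserve it. A sketch, assuming the class exposes the usual Haystack `from_dict` counterpart built on `default_from_dict`:

    from bbm25_haystack import BetterBM25DocumentStore  # assumed export

    store = BetterBM25DocumentStore(n_grams=(1, 3))
    data = store.to_dict()

    # The init parameters, including the new n_grams setting, travel with the dict,
    # so a reconstructed store tokenizes queries and documents the same way.
    restored = BetterBM25DocumentStore.from_dict(data)
    print(restored._n_grams)  # (1, 3)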

