Fix query filtering and vocabulary dict (#96)
* update readme

* fix: token ID in the query higher than the number of tokens in the index (#92)

* Fix query filtering by using a set of vocab dict

* add edge case when all tokens are integers

* fix allow true

* update tests to match new changes

* Fix changes to test

* Fix error during yield

---------

Co-authored-by: Nguyễn Hoàng Nhật <[email protected]>
xhluca and mossbee authored Dec 29, 2024
1 parent ce8f886 commit 6dfb6ce
Showing 6 changed files with 215 additions and 40 deletions.
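To make the reported bug (#92) concrete: a query tokenized with `update_vocab=True` can contain token IDs the index has never seen, which previously surfaced only as the low-level "maximum token ID ... is higher than the number of tokens in the index" error. The following is a rough reproduction sketched from the updated tests further down — the corpus and query strings are invented, and only calls that appear in this diff are assumed:

```python
import bm25s
from bm25s.tokenization import Tokenizer

corpus = ["a cat is a feline", "a dog is a canine"]

# No stemmer or stopwords, plain whitespace splitting, as in the updated tests.
tokenizer = Tokenizer(stemmer=None, stopwords=None, splitter=lambda x: x.split())
corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple", allow_empty=False)

retriever = bm25s.BM25()
retriever.index(corpus_tokens, create_empty_token=False)

# update_vocab=True assigns a fresh ID to the unseen word, so the query now
# holds a token ID that does not exist in the index.
query_tokens = tokenizer.tokenize(["unknownword"], update_vocab=True, allow_empty=False)

try:
    retriever.retrieve(query_tokens, k=1)
except ValueError as err:
    # After this commit, the out-of-vocabulary query is rejected here with a
    # descriptive message instead of an out-of-range token ID failure.
    print(err)
```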
63 changes: 51 additions & 12 deletions bm25s/__init__.py
@@ -409,6 +409,7 @@ def build_index_from_tokens(
def index(
self,
corpus: Union[Iterable, Tuple, tokenization.Tokenized],
create_empty_token=True,
show_progress=True,
leave_progress=False,
):
@@ -427,6 +428,25 @@ def index(
The `vocab_dict` dictionary is a mapping from tokens to their index in the vocabulary. This is used to
create the sparse matrix representation of the BM25 scores, as well as during query time to convert the
tokens to their indices.
Parameters
----------
corpus : Iterable or Tuple or tokenization.Tokenized
The corpus of documents. This can be either:
- An iterable of documents, where each document is a list of tokens (strings).
- A tuple of two elements: the first is the list of unique token IDs (int), and the second is the vocabulary dictionary.
- An object with the `ids` and `vocab` attributes, which are the unique token IDs and the token IDs for each document, respectively.
create_empty_token : bool
If True, it will create an empty token, "", in the vocabulary if it is not already present.
This is added at the end of the vocabulary and is used to score documents that do not contain any tokens.
If False, it will not create an empty token, which may lead to an error if a query does not contain any tokens.
show_progress : bool
If True, a progress bar will be shown. If False, no progress bar will be shown.
leave_progress : bool
If True, the progress bars will remain after the function completes.
"""
inferred_corpus_obj = self._infer_corpus_object(corpus)

@@ -456,9 +476,20 @@ def index(
show_progress=show_progress,
)

if create_empty_token:
if all(isinstance(token, int) for token in vocab_dict):
# if all tokens are integers, we don't need to add an empty token
pass

if "" not in vocab_dict:
vocab_dict[""] = max(vocab_dict.values()) + 1

self.scores = scores
self.vocab_dict = vocab_dict

# we create unique token IDs from the vocab_dict for faster lookup
self.unique_token_ids_set = set(self.vocab_dict.values())

def get_tokens_ids(self, query_tokens: List[str]) -> List[int]:
"""
For a given list of tokens, return the list of token IDs, leaving out tokens
@@ -481,7 +512,7 @@ def get_scores_from_ids(
query_tokens_ids: np.ndarray = np.asarray(query_tokens_ids, dtype=int_dtype)

max_token_id = int(query_tokens_ids.max(initial=0))

if max_token_id >= len(indptr) - 1:
raise ValueError(
f"The maximum token ID in the query ({max_token_id}) is higher than the number of tokens in the index."
@@ -539,7 +570,7 @@ def _get_top_k_results(
This function is used to retrieve the top-k results for a single query.
Since it's a hidden function, the user should not call it directly and
may change in the future. Please use the `retrieve` function instead.
"""
"""
if len(query_tokens_single) == 0:
logger.info(
msg="The query is empty. This will result in a zero score for all documents."
@@ -629,18 +660,18 @@ def retrieve(
weight_mask : np.ndarray
A weight mask to filter the documents. If provided, the scores for the masked
documents will be set to 0 to avoid returning them in the results.
Returns
-------
Results or np.ndarray
If `return_as="tuple"`, a named tuple with two fields will be returned: `documents` and `scores`.
If `return_as="documents"`, only the retrieved documents (or indices if `corpus` is not provided) will be returned.
Raises
------
ValueError
If the `query_tokens` is not a list of list of tokens (str) or a tuple of two lists: the first list is the list of unique token IDs, and the second list is the list of token IDs for each document.
ImportError
If the numba backend is selected but numba is not installed.
"""
@@ -659,11 +690,19 @@
query_tokens_filtered = []
for query in query_tokens:
query_filtered = [
token_id for token_id in query if token_id in self.vocab_dict
token_id
for token_id in query
if token_id in self.unique_token_ids_set
]
if len(query_filtered) == 0:
if "" not in self.vocab_dict:
self.vocab_dict[""] = max(self.vocab_dict.values()) + 1
raise ValueError(
"The query does not contain any tokens that are in the vocabulary. "
"Please provide a query that contains at least one token that is in the vocabulary. "
"Alternatively, you can set `create_empty_token=True` when calling `index` to add an empty token to the vocabulary. "
"You can also manually add an empty token to the vocabulary by setting `retriever.vocab_dict[''] = max(retriever.vocab_dict.values()) + 1`. "
"Then, run `retriever.unique_token_ids_set = set(retriever.vocab_dict.values())` to update the unique token IDs."
)
query_filtered = [self.vocab_dict[""]]

query_tokens_filtered.append(query_filtered)
@@ -876,9 +915,9 @@ def save(
# Save the vocab dictionary
vocab_path = save_dir / vocab_name

with open(vocab_path, "wt", encoding='utf-8') as f:
with open(vocab_path, "wt", encoding="utf-8") as f:
f.write(json_functions.dumps(self.vocab_dict, ensure_ascii=False))

# Save the parameters
params_path = save_dir / params_name
params = dict(
@@ -899,7 +938,7 @@ def save(
corpus = corpus if corpus is not None else self.corpus

if corpus is not None:
with open(save_dir / corpus_name, "wt", encoding='utf-8') as f:
with open(save_dir / corpus_name, "wt", encoding="utf-8") as f:
# if it's not an iterable, we skip
if not isinstance(corpus, Iterable):
logging.warning(
@@ -1060,7 +1099,7 @@ def load(
# Load the vocab dictionary
if load_vocab:
vocab_path = save_dir / vocab_name
with open(vocab_path, "r", encoding='utf-8') as f:
with open(vocab_path, "r", encoding="utf-8") as f:
vocab_dict: dict = json_functions.loads(f.read())
else:
vocab_dict = None
@@ -1091,7 +1130,7 @@ def load(
corpus = utils.corpus.JsonlCorpus(corpus_file)
else:
corpus = []
with open(corpus_file, "r", encoding='utf-8') as f:
with open(corpus_file, "r", encoding="utf-8") as f:
for line in f:
doc = json_functions.loads(line)
corpus.append(doc)
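To complement the new error message above, here is a minimal sketch of the intended happy path, where the index keeps an empty token and a fully out-of-vocabulary query falls back to it instead of raising; the corpus and query are hypothetical:

```python
import bm25s
from bm25s.tokenization import Tokenizer

corpus = [
    "a cat is a feline and likes to purr",
    "a dog is the human's best friend and loves to play",
]

tokenizer = Tokenizer(stopwords="en")
corpus_tokens = tokenizer.tokenize(corpus)  # allow_empty defaults to True, so "" gets an ID

retriever = bm25s.BM25()
retriever.index(corpus_tokens, create_empty_token=True)  # keep "" in the vocabulary

# "potato" is out of vocabulary, so the query resolves to the empty token and
# retrieve() returns (zero-scored) results instead of raising a ValueError.
query_tokens = tokenizer.tokenize(["potato"])
results, scores = retriever.retrieve(query_tokens, k=1)

# Manual alternative quoted in the error message, for an index built without "":
# retriever.vocab_dict[""] = max(retriever.vocab_dict.values()) + 1
# retriever.unique_token_ids_set = set(retriever.vocab_dict.values())
```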
30 changes: 14 additions & 16 deletions bm25s/tokenization.py
@@ -220,7 +220,17 @@ def streaming_tokenize(
stopwords_set = set(self.stopwords) if self.stopwords is not None else None
using_stopwords = stopwords_set is not None
using_stemmer = self.stemmer is not None


if allow_empty is True and update_vocab is True and "" not in self.word_to_id:
idx = max(self.word_to_id.values(), default=-1) + 1
self.word_to_id[""] = idx

if using_stemmer:
if "" not in self.word_to_stem:
self.word_to_stem[""] = ""
if "" not in self.stem_to_sid:
self.stem_to_sid[""] = idx

for text in texts:
if self.lower:
text = text.lower()
@@ -271,21 +281,9 @@
self.word_to_id[word] = wid
doc_ids.append(wid)

if len(doc_ids) == 0 and allow_empty is True:
if update_vocab is True and "" not in self.word_to_id:
idx = max(self.word_to_id.values(), default=-1) + 1
self.word_to_id[""] = idx

if using_stemmer:
if "" not in self.word_to_stem:
self.word_to_stem[""] = ""
if "" not in self.stem_to_sid:
self.stem_to_sid[""] = idx

# get the ID for the empty string
if "" in self.word_to_id:
doc_ids = [self.word_to_id[""]]

if len(doc_ids) == 0 and allow_empty is True and "" in self.word_to_id:
doc_ids = [self.word_to_id[""]]

yield doc_ids

def tokenize(
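The practical effect of hoisting this bookkeeping to the top of `streaming_tokenize` is that the empty token is registered once, before any document is processed, so later calls can map empty or fully unknown documents to it. A condensed version of the new tokenizer test below illustrates this (same corpus and IDs as asserted there):

```python
import numpy as np

from bm25s.tokenization import Tokenizer

corpus = [
    "a cat is a feline and likes to purr",
    "a dog is the human's best friend and loves to play",
    "a bird is a beautiful animal that can fly",
    "a fish is a creature that lives in water and swims",
]

tokenizer = Tokenizer(stopwords="en")
tokenizer.tokenize(corpus)  # "" is added to the vocabulary up front (allow_empty=True by default)

# "cat" keeps its own ID, while the empty string and the unseen word "potato"
# both resolve to the empty token's ID.
new_docs_tokens = tokenizer.tokenize(["cat", "", "potato"])
assert np.all(new_docs_tokens == np.array([[1], [0], [0]]))
```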
64 changes: 58 additions & 6 deletions tests/README.md
@@ -2,20 +2,47 @@

Welcome to the test suite for BM25S! This test suite is designed to test the BM25S implementation in the `bm25s` package.

## Quick tests
## Core tests

To run the quick tests, simply run the following command:
To run the core tests of the library, run the following command:

```bash
python -m unittest tests/quick/*.py
python -m unittest tests/core/*.py
python -m unittest tests/stopwords/*.py
```

## Full tests
For the numba backend tests, run:

To run the full tests, simply run the following command:
```bash
python -m unittest tests/numba/*.py
```


## Basic Comparisons

To run the basic comparison tests (with other BM25 implementations), simply run the following command:

```bash
python -m unittest tests/comparison/*.py
```

## Multiple tests

To run several of the test suites in one go, run the following commands:

```bash
python -m unittest tests/full/*.py
python -m unittest tests/core/*.py
python -m unittest tests/stopwords/*.py
python -m unittest tests/numba/*.py
python -m unittest tests/comparison/*.py
```

## Full comparison tests

To run the full comparison tests, simply run the following command:

```bash
python -m unittest tests/comparison_full/*.py
```

## Artifacts
@@ -25,3 +52,28 @@ By default, the artifacts are stored in the `./artifacts` directory. This direct
```bash
export BM25_ARTIFACTS_DIR=/path/to/artifacts
```


## Adding new tests

First, create a new file in `tests/core`, `tests/comparison`, `tests/numba`, `tests/stopwords`, or `tests/comparison_full`. Then, add the following code to the file:

```python
import os
import shutil
from pathlib import Path
import unittest
import tempfile
import Stemmer # optional: for stemming
import unittest.mock
import json

import bm25s

class TestYourName(unittest.TestCase):
def test_your_name(self):
# Your test code here
pass
```

Modify the `test_your_name` function to test your code; you can use the `bm25s` package directly, and the `unittest.mock` package to mock objects where needed. A filled-in example is sketched below.
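For instance, a minimal test along these lines could look as follows — the class name, corpus, and assertions are illustrative and not part of the existing suite:

```python
import unittest

import bm25s
from bm25s.tokenization import Tokenizer


class TestEmptyQueryFallback(unittest.TestCase):
    def test_unknown_query_falls_back_to_empty_token(self):
        corpus = ["a cat is a feline", "a dog is a canine", "a bird can fly"]

        tokenizer = Tokenizer(stopwords="en")
        corpus_tokens = tokenizer.tokenize(corpus)

        retriever = bm25s.BM25()
        retriever.index(corpus_tokens, create_empty_token=True)

        # An out-of-vocabulary query should not raise; it falls back to the empty token.
        query_tokens = tokenizer.tokenize(["potato"])
        results, scores = retriever.retrieve(query_tokens, k=2)

        self.assertEqual(results.shape, (1, 2))
        self.assertEqual(scores.shape, (1, 2))


if __name__ == "__main__":
    unittest.main()
```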
16 changes: 16 additions & 0 deletions tests/core/test_tokenizer.py
@@ -6,6 +6,8 @@
import Stemmer
import re

import numpy as np

from bm25s.tokenization import Tokenizer

class TestTokenizer(unittest.TestCase):
@@ -210,6 +212,20 @@ def test_save_load_stopwords(self):
# Check if the stopwords are loaded correctly
self.assertEqual(stopwords, tuple(tokenizer2.stopwords))

def test_empty_sentence_and_unknown_word(self):
corpus = [
"a cat is a feline and likes to purr",
"a dog is the human's best friend and loves to play",
"a bird is a beautiful animal that can fly",
"a fish is a creature that lives in water and swims",
]
new_docs = ["cat", "", "potato"]
tokenizer = Tokenizer(stopwords="en")
corpus_tokens = tokenizer.tokenize(corpus)
new_docs_tokens = tokenizer.tokenize(new_docs)

self.assertTrue(np.all(new_docs_tokens == np.array([[1], [0], [0]])))

@classmethod
def tearDownClass(cls):
"""Cleans up resources after all tests have run (not required in this test case)."""
12 changes: 6 additions & 6 deletions tests/core/test_tokenizer_misc.py
@@ -58,13 +58,13 @@ def test_new_ids(self):
tokenizer = Tokenizer(
stemmer=None, stopwords=None, splitter=lambda x: x.split()
)
corpus_tokens = tokenizer.tokenize(corpus)
corpus_tokens = tokenizer.tokenize(corpus, allow_empty=False)

bm25 = bm25s.BM25()
bm25.index(corpus_tokens)
bm25.index(corpus_tokens, create_empty_token=False)

query = "What is a fly?"
query_tokens = tokenizer.tokenize([query], update_vocab=True)
query_tokens = tokenizer.tokenize([query], update_vocab=True, allow_empty=False)
self.assertListEqual([[27, 2, 0, 28]], query_tokens)

results, scores = bm25.retrieve(query_tokens, k=3)
@@ -84,13 +84,13 @@ def test_failing_after_adding_new_tokens_query(self):
tokenizer = Tokenizer(
stemmer=None, stopwords=None, splitter=lambda x: x.split()
)
corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple")
corpus_tokens = tokenizer.tokenize(corpus, return_as="tuple", allow_empty=False)

bm25 = bm25s.BM25()
bm25.index(corpus_tokens)
bm25.index(corpus_tokens, create_empty_token=False)

query = "unknownword"
query_tokens = tokenizer.tokenize([query], update_vocab=True)
query_tokens = tokenizer.tokenize([query], update_vocab=True, allow_empty=False)

# assert a valueError is raised
with self.assertRaises(ValueError):