Skip to content

Commit

Permalink
Merge pull request #1340 from hanhainebula/master
Browse files: browse the repository at this point in the history
Fix bugs
  • Loading branch information
hanhainebula authored Jan 18, 2025
2 parents db83452 + 63324c3 commit 2163bea
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
7 changes: 0 additions & 7 deletions FlagEmbedding/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +0,0 @@
from .air_bench import AIRBenchEvalModelArgs, AIRBenchEvalArgs, AIRBenchEvalRunner
from .beir import *
# from miracle import *
# from mkqa import *
# from mldr import *
# from msmarco import *
from mteb import *
4 changes: 2 additions & 2 deletions FlagEmbedding/evaluation/beir/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _load_local_qrels(self, save_dir: str, dataset_name: Optional[str] = None, s
Returns:
datasets.DatasetDict: A dict of relevance of query and document.
"""
checked_split = self.check_splits(split)
checked_split = self.check_splits(split, dataset_name=dataset_name)
if len(checked_split) == 0:
raise ValueError(f"Split {split} not found in the dataset.")
split = checked_split[0]
Expand Down Expand Up @@ -450,7 +450,7 @@ def _load_local_queries(self, save_dir: str, dataset_name: Optional[str] = None,
Returns:
datasets.DatasetDict: A dict of queries with id as key, query text as value.
"""
checked_split = self.check_splits(split)
checked_split = self.check_splits(split, dataset_name=dataset_name)
if len(checked_split) == 0:
raise ValueError(f"Split {split} not found in the dataset.")
split = checked_split[0]
Expand Down
6 changes: 6 additions & 0 deletions scripts/hn_mine.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,12 @@ def find_knn_neg(
p_vecs = model.encode(corpus)
print(f'inferencing embedding for queries (number={len(queries)})--------------')
q_vecs = model.encode_queries(queries)

# some embedders (e.g. M3Embedder) return the embeddings as a dict; keep only the dense vectors
if isinstance(p_vecs, dict):
p_vecs = p_vecs["dense_vecs"]
if isinstance(q_vecs, dict):
q_vecs = q_vecs["dense_vecs"]

print('create index and search------------------')
index = create_index(p_vecs, use_gpu=use_gpu)
Expand Down

0 comments on commit 2163bea

Please sign in to comment.