diff --git a/capreolus/__init__.py b/capreolus/__init__.py
index 9056e08db..ac644fe35 100644
--- a/capreolus/__init__.py
+++ b/capreolus/__init__.py
@@ -2,9 +2,9 @@
import os
from pathlib import Path
-from profane import ConfigOption, Dependency, constants, config_list_to_dict
+from profane import ConfigOption, Dependency, ModuleBase, constants, config_list_to_dict, module_registry
-__version__ = "0.2.1"
+__version__ = "0.2.2"
# specify a base package that we should look for modules under (e.g., .task)
# constants must be specified before importing Task (or any other modules!)
@@ -21,16 +21,18 @@
jnius_config.set_classpath(Anserini.get_fat_jar())
-# import capreolus.evaluator as evaluator
-# from capreolus.benchmark import Benchmark
-# from capreolus.collection import Collection
-# from capreolus.extractor import Extractor
-# from capreolus.index import Index
-# from capreolus.reranker.base import Reranker
-# from capreolus.searcher import Searcher
-# from capreolus.task.base import Task
-# from capreolus.tokenizer import Tokenizer
-# from capreolus.trainer import Trainer
+# note: order is important to avoid circular imports
+from capreolus.utils.loginit import get_logger
+import capreolus.evaluator as evaluator
+from capreolus.benchmark import Benchmark
+from capreolus.collection import Collection
+from capreolus.index import Index
+from capreolus.searcher import Searcher
+from capreolus.extractor import Extractor
+from capreolus.reranker import Reranker
+from capreolus.tokenizer import Tokenizer
+from capreolus.trainer import Trainer
+from capreolus.task import Task
def parse_config_string(s):
diff --git a/capreolus/benchmark/__init__.py b/capreolus/benchmark/__init__.py
index b7db5a9ef..20f2dac29 100644
--- a/capreolus/benchmark/__init__.py
+++ b/capreolus/benchmark/__init__.py
@@ -1,37 +1,27 @@
-from profane import import_all_modules
-
import json
-import re
-import os
-import gzip
-import pickle
-
-from tqdm import tqdm
-from zipfile import ZipFile
-from pathlib import Path
-from collections import defaultdict
-from bs4 import BeautifulSoup
-from profane import ModuleBase, Dependency, ConfigOption, constants
-
-from capreolus.utils.loginit import get_logger
-from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt
-from capreolus.utils.common import download_file, remove_newline, get_udel_query_expander
-logger = get_logger(__name__)
-PACKAGE_PATH = constants["PACKAGE_PATH"]
+from capreolus import ModuleBase
+from capreolus.utils.trec import load_qrels, load_trec_topics
class Benchmark(ModuleBase):
- """the module base class"""
+ """Base class for Benchmark modules. The purpose of a Benchmark is to provide the data needed to run an experiment, such as queries, folds, and relevance judgments.
+
+ Modules should provide:
+ - a ``topics`` dict mapping query ids (*qids*) to *queries*
+ - a ``qrels`` dict mapping *qids* to *docids* and *relevance labels*
+ - a ``folds`` dict mapping a fold name to *training*, *dev* (validation), and *testing* qids
+ - if these can be loaded from files in standard formats, they can be specified by setting the ``topic_file``, ``qrel_file``, and ``fold_file``, respectively, rather than by setting the above attributes directly
+ """
module_type = "benchmark"
qrel_file = None
topic_file = None
fold_file = None
query_type = None
- # documents with a relevance label >= relevance_level will be considered relevant
- # corresponds to trec_eval's --level_for_rel (and passed to pytrec_eval as relevance_level)
relevance_level = 1
+ """ Documents with a relevance label >= relevance_level will be considered relevant.
+ This corresponds to trec_eval's --level_for_rel (and is passed to pytrec_eval as relevance_level). """
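+    # For example, the antique benchmark uses graded labels and sets relevance_level = 2,
+    # so that its label-1 documents are treated as non-relevant.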
@property
def qrels(self):
@@ -52,546 +42,8 @@ def folds(self):
return self._folds
-@Benchmark.register
-class DummyBenchmark(Benchmark):
- module_name = "dummy"
- dependencies = [Dependency(key="collection", module="collection", name="dummy")]
- qrel_file = PACKAGE_PATH / "data" / "qrels.dummy.txt"
- topic_file = PACKAGE_PATH / "data" / "topics.dummy.txt"
- fold_file = PACKAGE_PATH / "data" / "dummy_folds.json"
- query_type = "title"
-
-
-@Benchmark.register
-class WSDM20Demo(Benchmark):
- """ Robust04 benchmark equivalent to robust04.yang19 """
-
- module_name = "wsdm20demo"
- dependencies = [Dependency(key="collection", module="collection", name="robust04")]
- qrel_file = PACKAGE_PATH / "data" / "qrels.robust2004.txt"
- topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt"
- fold_file = PACKAGE_PATH / "data" / "rob04_yang19_folds.json"
- query_type = "title"
-
-
-@Benchmark.register
-class Robust04Yang19(Benchmark):
- """Robust04 benchmark using the folds from Yang et al. [1]
-
- [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. 2019. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019.
- """
-
- module_name = "robust04.yang19"
- dependencies = [Dependency(key="collection", module="collection", name="robust04")]
- qrel_file = PACKAGE_PATH / "data" / "qrels.robust2004.txt"
- topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt"
- fold_file = PACKAGE_PATH / "data" / "rob04_yang19_folds.json"
- query_type = "title"
-
-
-@Benchmark.register
-class NF(Benchmark):
- """ A Full-Text Learning to Rank Dataset for Medical Information Retrieval [1]
-
- [1] Vera Boteva, Demian Gholipour, Artem Sokolov and Stefan Riezler. A Full-Text Learning to Rank Dataset for Medical Information Retrieval Proceedings of the 38th European Conference on Information Retrieval (ECIR), Padova, Italy, 2016
- """
-
- module_name = "nf"
- dependencies = [Dependency(key="collection", module="collection", name="nf")]
- config_spec = [
- ConfigOption(key="labelrange", default_value="0-2", description="range of dataset qrels, options: 0-2, 1-3"),
- ConfigOption(
- key="fields",
- default_value="all_fields",
- description="query fields included in topic file, "
- "options: 'all_fields', 'all_titles', 'nontopics', 'vid_title', 'vid_desc'",
- ),
- ]
-
- fold_file = PACKAGE_PATH / "data" / "nf.json"
-
- query_type = "title"
-
- def __init__(self, config, provide, share_dependency_objects):
- super().__init__(config, provide, share_dependency_objects)
- fields, label_range = self.config["fields"], self.config["labelrange"]
- self.field2kws = {
- "all_fields": ["all"],
- "nontopics": ["nontopic-titles"],
- "vid_title": ["vid-titles"],
- "vid_desc": ["vid-desc"],
- "all_titles": ["nontopic-titles", "vid-titles", "nontopic-titles"],
- }
- self.labelrange2kw = {"0-2": "2-1-0", "1-3": "3-2-1"}
-
- if fields not in self.field2kws:
- raise ValueError(f"Unexpected fields value: {fields}, expect: {', '.join(self.field2kws.keys())}")
- if label_range not in self.labelrange2kw:
- raise ValueError(f"Unexpected label range: {label_range}, expect: {', '.join(self.field2kws.keys())}")
-
- self.qrel_file = PACKAGE_PATH / "data" / f"qrels.nf.{label_range}.txt"
- self.test_qrel_file = PACKAGE_PATH / "data" / f"test.qrels.nf.{label_range}.txt"
- self.topic_file = PACKAGE_PATH / "data" / f"topics.nf.{fields}.txt"
- self.download_if_missing()
-
- def _transform_qid(self, raw):
- """ NFCorpus dataset specific, remove prefix in query id since anserini convert all qid to integer """
- return raw.replace("PLAIN-", "")
-
- def download_if_missing(self):
- if all([f.exists() for f in [self.topic_file, self.fold_file, self.qrel_file]]):
- return
-
- tmp_corpus_dir = self.collection.download_raw()
- topic_f = open(self.topic_file, "w", encoding="utf-8")
- qrel_f = open(self.qrel_file, "w", encoding="utf-8")
- test_qrel_f = open(self.test_qrel_file, "w", encoding="utf-8")
-
- set_names = ["train", "dev", "test"]
- folds = {s: set() for s in set_names}
- qrel_kw = self.labelrange2kw[self.config["labelrange"]]
- for set_name in set_names:
- with open(tmp_corpus_dir / f"{set_name}.{qrel_kw}.qrel") as f:
- for line in f:
- line = self._transform_qid(line)
- qid = line.strip().split()[0]
- folds[set_name].add(qid)
- if set_name == "test":
- test_qrel_f.write(line)
- qrel_f.write(line)
-
- files = [tmp_corpus_dir / f"{set_name}.{keyword}.queries" for keyword in self.field2kws[self.config["fields"]]]
- qids2topics = self._align_queries(files, "title")
-
- for qid, txts in qids2topics.items():
- topic_f.write(topic_to_trectxt(qid, txts["title"]))
-
- json.dump(
- {"s1": {"train_qids": list(folds["train"]), "predict": {"dev": list(folds["dev"]), "test": list(folds["test"])}}},
- open(self.fold_file, "w"),
- )
-
- topic_f.close()
- qrel_f.close()
- test_qrel_f.close()
- logger.info(f"nf benchmark prepared")
-
- def _align_queries(self, files, field, qid2queries=None):
- if not qid2queries:
- qid2queries = {}
- for fn in files:
- with open(fn, "r", encoding="utf-8") as f:
- for line in f:
- qid, txt = line.strip().split("\t")
- qid = self._transform_qid(qid)
- txt = " ".join(re.sub("[^A-Za-z]", " ", txt).split()[:1020])
- if qid not in qid2queries:
- qid2queries[qid] = {field: txt}
- else:
- if field in qid2queries[qid]:
- logger.warning(f"Overwriting title for query {qid}")
- qid2queries[qid][field] = txt
- return qid2queries
-
-
-@Benchmark.register
-class ANTIQUE(Benchmark):
- """A Non-factoid Question Answering Benchmark from Hashemi et al. [1]
-
- [1] Helia Hashemi, Mohammad Aliannejadi, Hamed Zamani, and W. Bruce Croft. 2020. ANTIQUE: A non-factoid question answering benchmark. ECIR 2020.
- """
-
- module_name = "antique"
- dependencies = [Dependency(key="collection", module="collection", name="antique")]
- qrel_file = PACKAGE_PATH / "data" / "qrels.antique.txt"
- topic_file = PACKAGE_PATH / "data" / "topics.antique.txt"
- fold_file = PACKAGE_PATH / "data" / "antique.json"
- query_type = "title"
- relevance_level = 2
-
-
-@Benchmark.register
-class MSMarcoPassage(Benchmark):
- module_name = "msmarcopassage"
- dependencies = [Dependency(key="collection", module="collection", name="msmarco")]
- qrel_file = PACKAGE_PATH / "data" / "qrels.msmarcopassage.txt"
- topic_file = PACKAGE_PATH / "data" / "topics.msmarcopassage.txt"
- fold_file = PACKAGE_PATH / "data" / "msmarcopassage.folds.json"
- query_type = "title"
-
-
-@Benchmark.register
-class CodeSearchNetCorpus(Benchmark):
- module_name = "codesearchnet_corpus"
- dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")]
- url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2"
- query_type = "title"
-
- file_fn = PACKAGE_PATH / "data" / "csn_corpus"
-
- qrel_dir = file_fn / "qrels"
- topic_dir = file_fn / "topics"
- fold_dir = file_fn / "folds"
-
- qidmap_dir = file_fn / "qidmap"
- docidmap_dir = file_fn / "docidmap"
-
- config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]
-
- def build(self):
- lang = self.config["lang"]
-
- self.qid_map_file = self.qidmap_dir / f"{lang}.json"
- self.docid_map_file = self.docidmap_dir / f"{lang}.json"
-
- self.qrel_file = self.qrel_dir / f"{lang}.txt"
- self.topic_file = self.topic_dir / f"{lang}.txt"
- self.fold_file = self.fold_dir / f"{lang}.json"
-
- for file in [var for var in vars(self) if var.endswith("file")]:
- getattr(self, file).parent.mkdir(exist_ok=True, parents=True)
-
- self.download_if_missing()
-
- @property
- def qid_map(self):
- if not hasattr(self, "_qid_map"):
- if not self.qid_map_file.exists():
- self.download_if_missing()
-
- self._qid_map = json.load(open(self.qid_map_file, "r"))
- return self._qid_map
-
- @property
- def docid_map(self):
- if not hasattr(self, "_docid_map"):
- if not self.docid_map_file.exists():
- self.download_if_missing()
-
- self._docid_map = json.load(open(self.docid_map_file, "r"))
- return self._docid_map
-
- def download_if_missing(self):
- files = [self.qid_map_file, self.docid_map_file, self.qrel_file, self.topic_file, self.fold_file]
- if all([f.exists() for f in files]):
- return
-
- lang = self.config["lang"]
-
- tmp_dir = Path("/tmp")
- zip_fn = tmp_dir / f"{lang}.zip"
- if not zip_fn.exists():
- download_file(f"{self.url}/{lang}.zip", zip_fn)
-
- with ZipFile(zip_fn, "r") as zipobj:
- zipobj.extractall(tmp_dir)
-
- # prepare docid-url mapping from dedup.pkl
- pkl_fn = tmp_dir / f"{lang}_dedupe_definitions_v2.pkl"
- doc_objs = pickle.load(open(pkl_fn, "rb"))
- self._docid_map = self._prep_docid_map(doc_objs)
- assert self._get_n_docid() == len(doc_objs)
-
- # prepare folds, qrels, topics, docstring2qid # TODO: shall we add negative samples?
- qrels, self._qid_map = defaultdict(dict), {}
- qids = {s: [] for s in ["train", "valid", "test"]}
-
- topic_file = open(self.topic_file, "w", encoding="utf-8")
- qrel_file = open(self.qrel_file, "w", encoding="utf-8")
-
- def gen_doc_from_gzdir(dir):
- """ generate parsed dict-format doc from all jsonl.gz files under given directory """
- for fn in sorted(dir.glob("*.jsonl.gz")):
- f = gzip.open(fn, "rb")
- for doc in f:
- yield json.loads(doc)
-
- for set_name in qids:
- set_path = tmp_dir / lang / "final" / "jsonl" / set_name
- for doc in gen_doc_from_gzdir(set_path):
- code = remove_newline(" ".join(doc["code_tokens"]))
- docstring = remove_newline(" ".join(doc["docstring_tokens"]))
- n_words_in_docstring = len(docstring.split())
- if n_words_in_docstring >= 1024:
- logger.warning(
- f"chunk query to first 1000 words otherwise TooManyClause would be triggered "
- f"at lucene at search stage, "
- )
- docstring = " ".join(docstring.split()[:1020]) # for TooManyClause
-
- docid = self.get_docid(doc["url"], code)
- qid = self._qid_map.get(docstring, str(len(self._qid_map)))
- qrel_file.write(f"{qid} Q0 {docid} 1\n")
-
- if docstring not in self._qid_map:
- self._qid_map[docstring] = qid
- qids[set_name].append(qid)
- topic_file.write(topic_to_trectxt(qid, docstring))
-
- topic_file.close()
- qrel_file.close()
-
- # write to qid_map.json, docid_map, fold.json
- json.dump(self._qid_map, open(self.qid_map_file, "w"))
- json.dump(self._docid_map, open(self.docid_map_file, "w"))
- json.dump(
- {"s1": {"train_qids": qids["train"], "predict": {"dev": qids["valid"], "test": qids["test"]}}},
- open(self.fold_file, "w"),
- )
-
- def _prep_docid_map(self, doc_objs):
- """
- construct a nested dict to map each doc into a unique docid
- which follows the structure: {url: {" ".join(code_tokens): docid, ...}}
-
- For all the lanugage datasets the url uniquely maps to a code_tokens yet it's not the case for but js and php
- which requires a second-level mapping from raw_doc to docid
-
- :param doc_objs: a list of dict having keys ["nwo", "url", "sha", "identifier", "arguments"
- "function", "function_tokens", "docstring", "doctring_tokens",],
- :return:
- """
- # TODO: any way to avoid the twice traversal of all url and make the return dict structure consistent
- lang = self.config["lang"]
- url2docid = defaultdict(dict)
- for i, doc in tqdm(enumerate(doc_objs), desc=f"Preparing the {lang} docid_map"):
- url, code_tokens = doc["url"], remove_newline(" ".join(doc["function_tokens"]))
- url2docid[url][code_tokens] = f"{lang}-FUNCTION-{i}"
-
- # remove the code_tokens for the unique url-docid mapping
- for url, docids in tqdm(url2docid.items(), desc=f"Compressing the {lang} docid_map"):
- url2docid[url] = list(docids.values()) if len(docids) == 1 else docids # {code_tokens: docid} -> [docid]
- return url2docid
-
- def _get_n_docid(self):
- """ calculate the number of document ids contained in the nested docid map """
- lens = [len(docs) for url, docs in self._docid_map.items()]
- return sum(lens)
-
- def get_docid(self, url, code_tokens):
- """ retrieve the doc id according to the doc dict """
- docids = self.docid_map[url]
- return docids[0] if len(docids) == 1 else docids[code_tokens]
-
-
-@Benchmark.register
-class CodeSearchNetChallenge(Benchmark):
- """
- CodeSearchNetChallenge can only be used for training but not for evaluation since qrels is not provided
- """
-
- module_name = "codesearchnet_challenge"
- dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")]
- config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]
-
- url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/queries.csv"
- query_type = "title"
-
- file_fn = PACKAGE_PATH / "data" / "csn_challenge"
- topic_file = file_fn / "topics.txt"
- qid_map_file = file_fn / "qidmap.json"
-
- def download_if_missing(self):
- """ download query.csv and prepare queryid - query mapping file """
- if self.topic_file.exists() and self.qid_map_file.exists():
- return
-
- tmp_dir = Path("/tmp")
- tmp_dir.mkdir(exist_ok=True, parents=True)
- self.file_fn.mkdir(exist_ok=True, parents=True)
-
- query_fn = tmp_dir / f"query.csv"
- if not query_fn.exists():
- download_file(self.url, query_fn)
-
- # prepare qid - query
- qid_map = {}
- topic_file = open(self.topic_file, "w", encoding="utf-8")
- query_file = open(query_fn)
- for qid, line in enumerate(query_file):
- if qid != 0: # ignore the first line "query"
- topic_file.write(topic_to_trectxt(qid, line.strip()))
- qid_map[qid] = line
- topic_file.close()
- json.dump(qid_map, open(self.qid_map_file, "w"))
-
-
-@Benchmark.register
-class COVID(Benchmark):
- """ Ongoing TREC-COVID bechmark from https://ir.nist.gov/covidSubmit """
-
- module_name = "covid"
- dependencies = [Dependency(key="collection", module="collection", name="covid")]
- data_dir = PACKAGE_PATH / "data" / "covid"
- topic_url = "https://ir.nist.gov/covidSubmit/data/topics-rnd%d.xml"
- qrel_url = "https://ir.nist.gov/covidSubmit/data/qrels-rnd%d.txt"
- lastest_round = 3
-
- config_spec = [
- ConfigOption("round", 3, "TREC-COVID round to use"),
- ConfigOption("udelqexpand", False),
- ConfigOption("excludeknown", True),
- ]
-
- def build(self):
- if self.config["round"] == self.lastest_round and not self.config["excludeknown"]:
- logger.warning(f"No evaluation can be done for the lastest round in exclude-known mode")
-
- data_dir = self.get_cache_path() / "documents"
- data_dir.mkdir(exist_ok=True, parents=True)
-
- self.qrel_ignore = f"{data_dir}/ignore.qrel.txt"
- self.qrel_file = f"{data_dir}/qrel.txt"
- self.topic_file = f"{data_dir}/topic.txt"
- self.fold_file = f"{data_dir}/fold.json"
-
- self.download_if_missing()
-
- def download_if_missing(self):
- if all([os.path.exists(fn) for fn in [self.qrel_file, self.qrel_ignore, self.topic_file, self.fold_file]]):
- return
-
- rnd_i, excludeknown = self.config["round"], self.config["excludeknown"]
- if rnd_i > self.lastest_round:
- raise ValueError(f"round {rnd_i} is unavailable")
-
- logger.info(f"Preparing files for covid round-{rnd_i}")
-
- topic_url = self.topic_url % rnd_i
- qrel_ignore_urls = [self.qrel_url % i for i in range(1, rnd_i)] # download all the qrels before current run
-
- # topic file
- tmp_dir = Path("/tmp")
- topic_tmp = tmp_dir / f"topic.round.{rnd_i}.xml"
- if not os.path.exists(topic_tmp):
- download_file(topic_url, topic_tmp)
- all_qids = self.xml2trectopic(topic_tmp) # will update self.topic_file
-
- if excludeknown:
- qrel_fn = open(self.qrel_file, "w")
- for i, qrel_url in enumerate(qrel_ignore_urls):
- qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1)
- if not os.path.exists(qrel_tmp):
- download_file(qrel_url, qrel_tmp)
- with open(qrel_tmp) as f:
- for line in f:
- qrel_fn.write(line)
- qrel_fn.close()
-
- f = open(self.qrel_ignore, "w") # empty ignore file
- f.close()
- else:
- qrel_fn = open(self.qrel_ignore, "w")
- for i, qrel_url in enumerate(qrel_ignore_urls):
- qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1)
- if not os.path.exists(qrel_tmp):
- download_file(qrel_url, qrel_tmp)
- with open(qrel_tmp) as f:
- for line in f:
- qrel_fn.write(line)
- qrel_fn.close()
-
- if rnd_i == self.lastest_round:
- f = open(self.qrel_file, "w")
- f.close()
- else:
- with open(tmp_dir / f"qrel-{rnd_i}") as fin, open(self.qrel_file, "w") as fout:
- for line in fin:
- fout.write(line)
-
- # folds: use all labeled query for train, valid, and use all of them for test set
- labeled_qids = list(load_qrels(self.qrel_ignore).keys())
- folds = {"s1": {"train_qids": labeled_qids, "predict": {"dev": labeled_qids, "test": all_qids}}}
- json.dump(folds, open(self.fold_file, "w"))
-
- def xml2trectopic(self, xmlfile):
- with open(xmlfile, "r") as f:
- topic = f.read()
-
- all_qids = []
- soup = BeautifulSoup(topic, "lxml")
- topics = soup.find_all("topic")
- expand_query = get_udel_query_expander()
-
- with open(self.topic_file, "w") as fout:
- for topic in topics:
- qid = topic["number"]
- title = topic.find_all("query")[0].text.strip()
- desc = topic.find_all("question")[0].text.strip()
- narr = topic.find_all("narrative")[0].text.strip()
-
- if self.config["udelqexpand"]:
- title = expand_query(title, rm_sw=True)
- desc = expand_query(desc, rm_sw=False)
-
- title = title + " " + desc
- desc = " "
-
- topic_line = topic_to_trectxt(qid, title, desc=desc, narr=narr)
- fout.write(topic_line)
- all_qids.append(qid)
- return all_qids
-
-
-@Benchmark.register
-class CovidQA(Benchmark):
- module_name = "covidqa"
- dependencies = [Dependency(key="collection", module="collection", name="covidqa")]
- url = "https://raw.githubusercontent.com/castorini/pygaggle/master/data/kaggle-lit-review-%s.json"
- available_versions = ["0.1", "0.2"]
-
- datadir = PACKAGE_PATH / "data" / "covidqa"
-
- config_spec = [ConfigOption("version", "0.1+0.2")]
-
- def build(self):
- os.makedirs(self.datadir, exist_ok=True)
-
- version = self.config["version"]
- self.qrel_file = self.datadir / f"qrels.v{version}.txt"
- self.topic_file = self.datadir / f"topics.v{version}.txt"
- self.fold_file = self.datadir / f"v{version}.json" # HOW TO SPLIT THE FOLD HERE?
-
- self.download_if_missing()
-
- def download_if_missing(self):
- if all([os.path.exists(f) for f in [self.qrel_file, self.topic_file, self.fold_file]]):
- return
-
- tmp_dir = Path("/tmp")
- topic_f = open(self.topic_file, "w", encoding="utf-8")
- qrel_f = open(self.qrel_file, "w", encoding="utf-8")
-
- all_qids = []
- qid = 2001 # to distingsuish queries here from queries in TREC-covid
- versions = self.config["version"].split("+") if isinstance(self.config["version"], str) else str(self.config["version"])
- for v in versions:
- if v not in self.available_versions:
- vs = " ".join(self.available_versions)
- logger.warning(f"Invalid version {v}, should be one of {vs}")
- continue
-
- url = self.url % v
- target_fn = tmp_dir / f"covidqa-v{v}.json"
- if not os.path.exists(target_fn):
- download_file(url, target_fn)
- qa = json.load(open(target_fn))
- for subcate in qa["categories"]:
- name = subcate["name"]
-
- for qa in subcate["sub_categories"]:
- nq_name, kq_name = qa["nq_name"], qa["kq_name"]
- query_line = topic_to_trectxt(qid, kq_name, nq_name) # kq_name == "query", nq_name == "question"
- topic_f.write(query_line)
- for ans in qa["answers"]:
- docid = ans["id"]
- qrel_f.write(f"{qid} Q0 {docid} 1\n")
- all_qids.append(qid)
- qid += 1
-
- json.dump({"s1": {"train_qids": all_qids, "predict": {"dev": all_qids, "test": all_qids}}}, open(self.fold_file, "w"))
- topic_f.close()
- qrel_f.close()
+from profane import import_all_modules
+from .dummy import DummyBenchmark
import_all_modules(__file__, __package__)
diff --git a/capreolus/benchmark/antique.py b/capreolus/benchmark/antique.py
new file mode 100644
index 000000000..c6601b0bd
--- /dev/null
+++ b/capreolus/benchmark/antique.py
@@ -0,0 +1,22 @@
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+from capreolus.utils.loginit import get_logger
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class ANTIQUE(Benchmark):
+ """A Non-factoid Question Answering Benchmark from Hashemi et al. [1]
+
+ [1] Helia Hashemi, Mohammad Aliannejadi, Hamed Zamani, and W. Bruce Croft. 2020. ANTIQUE: A non-factoid question answering benchmark. ECIR 2020.
+ """
+
+ module_name = "antique"
+ dependencies = [Dependency(key="collection", module="collection", name="antique")]
+ qrel_file = PACKAGE_PATH / "data" / "qrels.antique.txt"
+ topic_file = PACKAGE_PATH / "data" / "topics.antique.txt"
+ fold_file = PACKAGE_PATH / "data" / "antique.json"
+ query_type = "title"
+ relevance_level = 2
diff --git a/capreolus/benchmark/codesearchnet.py b/capreolus/benchmark/codesearchnet.py
new file mode 100644
index 000000000..662a031ac
--- /dev/null
+++ b/capreolus/benchmark/codesearchnet.py
@@ -0,0 +1,220 @@
+import gzip
+import pickle
+import json
+from collections import defaultdict
+from pathlib import Path
+from zipfile import ZipFile
+
+from tqdm import tqdm
+
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt
+from capreolus.utils.common import download_file, remove_newline
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class CodeSearchNetCorpus(Benchmark):
+ """CodeSearchNet Corpus. [1]
+
+ [1] Hamel Husain, Ho-Hsiang Wu, Tiferet Gazit, Miltiadis Allamanis, and Marc Brockschmidt. 2019. CodeSearchNet Challenge: Evaluating the State of Semantic Code Search. arXiv 2019.
+ """
+
+ module_name = "codesearchnet_corpus"
+ dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")]
+ url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2"
+ query_type = "title"
+
+ file_fn = PACKAGE_PATH / "data" / "csn_corpus"
+
+ qrel_dir = file_fn / "qrels"
+ topic_dir = file_fn / "topics"
+ fold_dir = file_fn / "folds"
+
+ qidmap_dir = file_fn / "qidmap"
+ docidmap_dir = file_fn / "docidmap"
+
+ config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]
+
+ def build(self):
+ lang = self.config["lang"]
+
+ self.qid_map_file = self.qidmap_dir / f"{lang}.json"
+ self.docid_map_file = self.docidmap_dir / f"{lang}.json"
+
+ self.qrel_file = self.qrel_dir / f"{lang}.txt"
+ self.topic_file = self.topic_dir / f"{lang}.txt"
+ self.fold_file = self.fold_dir / f"{lang}.json"
+
+ for file in [var for var in vars(self) if var.endswith("file")]:
+ getattr(self, file).parent.mkdir(exist_ok=True, parents=True)
+
+ self.download_if_missing()
+
+ @property
+ def qid_map(self):
+ if not hasattr(self, "_qid_map"):
+ if not self.qid_map_file.exists():
+ self.download_if_missing()
+
+ self._qid_map = json.load(open(self.qid_map_file, "r"))
+ return self._qid_map
+
+ @property
+ def docid_map(self):
+ if not hasattr(self, "_docid_map"):
+ if not self.docid_map_file.exists():
+ self.download_if_missing()
+
+ self._docid_map = json.load(open(self.docid_map_file, "r"))
+ return self._docid_map
+
+ def download_if_missing(self):
+ files = [self.qid_map_file, self.docid_map_file, self.qrel_file, self.topic_file, self.fold_file]
+ if all([f.exists() for f in files]):
+ return
+
+ lang = self.config["lang"]
+
+ tmp_dir = Path("/tmp")
+ zip_fn = tmp_dir / f"{lang}.zip"
+ if not zip_fn.exists():
+ download_file(f"{self.url}/{lang}.zip", zip_fn)
+
+ with ZipFile(zip_fn, "r") as zipobj:
+ zipobj.extractall(tmp_dir)
+
+ # prepare docid-url mapping from dedup.pkl
+ pkl_fn = tmp_dir / f"{lang}_dedupe_definitions_v2.pkl"
+ doc_objs = pickle.load(open(pkl_fn, "rb"))
+ self._docid_map = self._prep_docid_map(doc_objs)
+ assert self._get_n_docid() == len(doc_objs)
+
+ # prepare folds, qrels, topics, docstring2qid # TODO: shall we add negative samples?
+ qrels, self._qid_map = defaultdict(dict), {}
+ qids = {s: [] for s in ["train", "valid", "test"]}
+
+ topic_file = open(self.topic_file, "w", encoding="utf-8")
+ qrel_file = open(self.qrel_file, "w", encoding="utf-8")
+
+ def gen_doc_from_gzdir(dir):
+ """ generate parsed dict-format doc from all jsonl.gz files under given directory """
+ for fn in sorted(dir.glob("*.jsonl.gz")):
+            with gzip.open(fn, "rb") as f:
+                for doc in f:
+                    yield json.loads(doc)
+
+ for set_name in qids:
+ set_path = tmp_dir / lang / "final" / "jsonl" / set_name
+ for doc in gen_doc_from_gzdir(set_path):
+ code = remove_newline(" ".join(doc["code_tokens"]))
+ docstring = remove_newline(" ".join(doc["docstring_tokens"]))
+ n_words_in_docstring = len(docstring.split())
+ if n_words_in_docstring >= 1024:
+                    logger.warning(
+                        "truncating docstring query to its first 1020 words; longer queries would trigger "
+                        "Lucene's TooManyClauses exception at search time"
+                    )
+                    docstring = " ".join(docstring.split()[:1020])  # avoid TooManyClauses
+
+ docid = self.get_docid(doc["url"], code)
+ qid = self._qid_map.get(docstring, str(len(self._qid_map)))
+ qrel_file.write(f"{qid} Q0 {docid} 1\n")
+
+ if docstring not in self._qid_map:
+ self._qid_map[docstring] = qid
+ qids[set_name].append(qid)
+ topic_file.write(topic_to_trectxt(qid, docstring))
+
+ topic_file.close()
+ qrel_file.close()
+
+ # write to qid_map.json, docid_map, fold.json
+ json.dump(self._qid_map, open(self.qid_map_file, "w"))
+ json.dump(self._docid_map, open(self.docid_map_file, "w"))
+ json.dump(
+ {"s1": {"train_qids": qids["train"], "predict": {"dev": qids["valid"], "test": qids["test"]}}},
+ open(self.fold_file, "w"),
+ )
+
+ def _prep_docid_map(self, doc_objs):
+ """
+        Construct a nested dict mapping each doc to a unique docid,
+        following the structure: {url: {" ".join(code_tokens): docid, ...}}
+
+        For most language datasets a url maps to a single code_tokens value, but this does not hold for js and php,
+        which therefore require a second-level mapping from the raw doc to its docid.
+
+        :param doc_objs: a list of dicts with keys ["nwo", "url", "sha", "identifier", "arguments",
+            "function", "function_tokens", "docstring", "docstring_tokens"]
+        :return: the url2docid dict described above
+ """
+        # TODO: can we avoid traversing all urls twice and make the returned dict structure consistent?
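+        # The returned structure looks like this (hypothetical values):
+        #   {"https://.../a.rb": ["ruby-FUNCTION-0"],                    # url mapping to a single function
+        #    "https://.../b.php": {"tok1 tok2": "php-FUNCTION-1",
+        #                          "tok3 tok4": "php-FUNCTION-2"}}       # url shared by several functions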
+ lang = self.config["lang"]
+ url2docid = defaultdict(dict)
+ for i, doc in tqdm(enumerate(doc_objs), desc=f"Preparing the {lang} docid_map"):
+ url, code_tokens = doc["url"], remove_newline(" ".join(doc["function_tokens"]))
+ url2docid[url][code_tokens] = f"{lang}-FUNCTION-{i}"
+
+ # remove the code_tokens for the unique url-docid mapping
+ for url, docids in tqdm(url2docid.items(), desc=f"Compressing the {lang} docid_map"):
+ url2docid[url] = list(docids.values()) if len(docids) == 1 else docids # {code_tokens: docid} -> [docid]
+ return url2docid
+
+ def _get_n_docid(self):
+ """ calculate the number of document ids contained in the nested docid map """
+ lens = [len(docs) for url, docs in self._docid_map.items()]
+ return sum(lens)
+
+ def get_docid(self, url, code_tokens):
+ """ retrieve the doc id according to the doc dict """
+ docids = self.docid_map[url]
+ return docids[0] if len(docids) == 1 else docids[code_tokens]
+
+
+@Benchmark.register
+class CodeSearchNetChallenge(Benchmark):
+ """CodeSearchNet Challenge. [1]
+ This benchmark can only be used for training (and challenge submissions) because no qrels are provided.
+
+ [1] Hamel Husain, Ho-Hsiang Wu, Tiferet Gazit, Miltiadis Allamanis, and Marc Brockschmidt. 2019. CodeSearchNet Challenge: Evaluating the State of Semantic Code Search. arXiv 2019.
+ """
+
+ module_name = "codesearchnet_challenge"
+ dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")]
+ config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]
+
+ url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/queries.csv"
+ query_type = "title"
+
+ file_fn = PACKAGE_PATH / "data" / "csn_challenge"
+ topic_file = file_fn / "topics.txt"
+ qid_map_file = file_fn / "qidmap.json"
+
+ def download_if_missing(self):
+ """ download query.csv and prepare queryid - query mapping file """
+ if self.topic_file.exists() and self.qid_map_file.exists():
+ return
+
+ tmp_dir = Path("/tmp")
+ tmp_dir.mkdir(exist_ok=True, parents=True)
+ self.file_fn.mkdir(exist_ok=True, parents=True)
+
+    query_fn = tmp_dir / "query.csv"
+ if not query_fn.exists():
+ download_file(self.url, query_fn)
+
+ # prepare qid - query
+ qid_map = {}
+ topic_file = open(self.topic_file, "w", encoding="utf-8")
+ query_file = open(query_fn)
+ for qid, line in enumerate(query_file):
+ if qid != 0: # ignore the first line "query"
+ topic_file.write(topic_to_trectxt(qid, line.strip()))
+            qid_map[qid] = line.strip()  # drop the trailing newline
+ topic_file.close()
+ json.dump(qid_map, open(self.qid_map_file, "w"))
diff --git a/capreolus/benchmark/covid.py b/capreolus/benchmark/covid.py
new file mode 100644
index 000000000..97941166c
--- /dev/null
+++ b/capreolus/benchmark/covid.py
@@ -0,0 +1,192 @@
+import os
+import json
+from pathlib import Path
+
+from bs4 import BeautifulSoup
+
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt
+from capreolus.utils.common import download_file, remove_newline, get_udel_query_expander
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class COVID(Benchmark):
+ """ Ongoing TREC-COVID bechmark from https://ir.nist.gov/covidSubmit that uses documents from CORD, the COVID-19 Open Research Dataset (https://www.semanticscholar.org/cord19). """
+
+ module_name = "covid"
+ dependencies = [Dependency(key="collection", module="collection", name="covid")]
+ data_dir = PACKAGE_PATH / "data" / "covid"
+ topic_url = "https://ir.nist.gov/covidSubmit/data/topics-rnd%d.xml"
+ qrel_url = "https://ir.nist.gov/covidSubmit/data/qrels-rnd%d.txt"
+    latest_round = 3
+
+ config_spec = [
+ ConfigOption("round", 3, "TREC-COVID round to use"),
+ ConfigOption("udelqexpand", False),
+ ConfigOption("excludeknown", True),
+ ]
+
+ def build(self):
+ if self.config["round"] == self.lastest_round and not self.config["excludeknown"]:
+ logger.warning(f"No evaluation can be done for the lastest round in exclude-known mode")
+
+ data_dir = self.get_cache_path() / "documents"
+ data_dir.mkdir(exist_ok=True, parents=True)
+
+ self.qrel_ignore = f"{data_dir}/ignore.qrel.txt"
+ self.qrel_file = f"{data_dir}/qrel.txt"
+ self.topic_file = f"{data_dir}/topic.txt"
+ self.fold_file = f"{data_dir}/fold.json"
+
+ self.download_if_missing()
+
+ def download_if_missing(self):
+ if all([os.path.exists(fn) for fn in [self.qrel_file, self.qrel_ignore, self.topic_file, self.fold_file]]):
+ return
+
+ rnd_i, excludeknown = self.config["round"], self.config["excludeknown"]
+        if rnd_i > self.latest_round:
+ raise ValueError(f"round {rnd_i} is unavailable")
+
+ logger.info(f"Preparing files for covid round-{rnd_i}")
+
+ topic_url = self.topic_url % rnd_i
+        qrel_ignore_urls = [self.qrel_url % i for i in range(1, rnd_i)]  # qrels from all rounds before the current one
+
+ # topic file
+ tmp_dir = Path("/tmp")
+ topic_tmp = tmp_dir / f"topic.round.{rnd_i}.xml"
+ if not os.path.exists(topic_tmp):
+ download_file(topic_url, topic_tmp)
+ all_qids = self.xml2trectopic(topic_tmp) # will update self.topic_file
+
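+        # excludeknown: qrel_file collects the qrels from all earlier rounds and nothing is ignored;
+        # otherwise: earlier rounds' qrels go into qrel_ignore and qrel_file holds only the current
+        # round's qrels (empty for the latest round, whose qrels are not yet available)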
+ if excludeknown:
+ qrel_fn = open(self.qrel_file, "w")
+ for i, qrel_url in enumerate(qrel_ignore_urls):
+ qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1)
+ if not os.path.exists(qrel_tmp):
+ download_file(qrel_url, qrel_tmp)
+ with open(qrel_tmp) as f:
+ for line in f:
+ qrel_fn.write(line)
+ qrel_fn.close()
+
+ f = open(self.qrel_ignore, "w") # empty ignore file
+ f.close()
+ else:
+ qrel_fn = open(self.qrel_ignore, "w")
+ for i, qrel_url in enumerate(qrel_ignore_urls):
+ qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1)
+ if not os.path.exists(qrel_tmp):
+ download_file(qrel_url, qrel_tmp)
+ with open(qrel_tmp) as f:
+ for line in f:
+ qrel_fn.write(line)
+ qrel_fn.close()
+
+            if rnd_i == self.latest_round:
+ f = open(self.qrel_file, "w")
+ f.close()
+ else:
+ with open(tmp_dir / f"qrel-{rnd_i}") as fin, open(self.qrel_file, "w") as fout:
+ for line in fin:
+ fout.write(line)
+
+        # folds: use all labeled queries for train and dev, and all queries for the test set
+ labeled_qids = list(load_qrels(self.qrel_ignore).keys())
+ folds = {"s1": {"train_qids": labeled_qids, "predict": {"dev": labeled_qids, "test": all_qids}}}
+ json.dump(folds, open(self.fold_file, "w"))
+
+ def xml2trectopic(self, xmlfile):
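+        # Expected topic XML shape, abridged from the TREC-COVID topic files:
+        #   <topics>
+        #     <topic number="1">
+        #       <query>...</query>
+        #       <question>...</question>
+        #       <narrative>...</narrative>
+        #     </topic>
+        #   </topics>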
+ with open(xmlfile, "r") as f:
+ topic = f.read()
+
+ all_qids = []
+ soup = BeautifulSoup(topic, "lxml")
+ topics = soup.find_all("topic")
+ expand_query = get_udel_query_expander()
+
+ with open(self.topic_file, "w") as fout:
+ for topic in topics:
+ qid = topic["number"]
+ title = topic.find_all("query")[0].text.strip()
+ desc = topic.find_all("question")[0].text.strip()
+ narr = topic.find_all("narrative")[0].text.strip()
+
+ if self.config["udelqexpand"]:
+ title = expand_query(title, rm_sw=True)
+ desc = expand_query(desc, rm_sw=False)
+
+ title = title + " " + desc
+ desc = " "
+
+ topic_line = topic_to_trectxt(qid, title, desc=desc, narr=narr)
+ fout.write(topic_line)
+ all_qids.append(qid)
+ return all_qids
+
+
+@Benchmark.register
+class CovidQA(Benchmark):
+ module_name = "covidqa"
+ dependencies = [Dependency(key="collection", module="collection", name="covid")]
+ url = "https://raw.githubusercontent.com/castorini/pygaggle/master/data/kaggle-lit-review-%s.json"
+ available_versions = ["0.1", "0.2"]
+
+ datadir = PACKAGE_PATH / "data" / "covidqa"
+
+ config_spec = [ConfigOption("version", "0.1+0.2")]
+
+ def build(self):
+ os.makedirs(self.datadir, exist_ok=True)
+
+ version = self.config["version"]
+ self.qrel_file = self.datadir / f"qrels.v{version}.txt"
+ self.topic_file = self.datadir / f"topics.v{version}.txt"
+        self.fold_file = self.datadir / f"v{version}.json"  # TODO: how should the folds be split here?
+
+ self.download_if_missing()
+
+ def download_if_missing(self):
+ if all([os.path.exists(f) for f in [self.qrel_file, self.topic_file, self.fold_file]]):
+ return
+
+ tmp_dir = Path("/tmp")
+ topic_f = open(self.topic_file, "w", encoding="utf-8")
+ qrel_f = open(self.qrel_file, "w", encoding="utf-8")
+
+ all_qids = []
+        qid = 2001  # to distinguish queries here from queries in TREC-COVID
+ versions = self.config["version"].split("+") if isinstance(self.config["version"], str) else str(self.config["version"])
+ for v in versions:
+ if v not in self.available_versions:
+ vs = " ".join(self.available_versions)
+ logger.warning(f"Invalid version {v}, should be one of {vs}")
+ continue
+
+ url = self.url % v
+ target_fn = tmp_dir / f"covidqa-v{v}.json"
+ if not os.path.exists(target_fn):
+ download_file(url, target_fn)
+ qa = json.load(open(target_fn))
+ for subcate in qa["categories"]:
+ name = subcate["name"]
+
+ for qa in subcate["sub_categories"]:
+ nq_name, kq_name = qa["nq_name"], qa["kq_name"]
+ query_line = topic_to_trectxt(qid, kq_name, nq_name) # kq_name == "query", nq_name == "question"
+ topic_f.write(query_line)
+ for ans in qa["answers"]:
+ docid = ans["id"]
+ qrel_f.write(f"{qid} Q0 {docid} 1\n")
+ all_qids.append(qid)
+ qid += 1
+
+ json.dump({"s1": {"train_qids": all_qids, "predict": {"dev": all_qids, "test": all_qids}}}, open(self.fold_file, "w"))
+ topic_f.close()
+ qrel_f.close()
diff --git a/capreolus/benchmark/dummy.py b/capreolus/benchmark/dummy.py
new file mode 100644
index 000000000..18d777283
--- /dev/null
+++ b/capreolus/benchmark/dummy.py
@@ -0,0 +1,16 @@
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class DummyBenchmark(Benchmark):
+ """ Tiny benchmark for testing """
+
+ module_name = "dummy"
+ dependencies = [Dependency(key="collection", module="collection", name="dummy")]
+ qrel_file = PACKAGE_PATH / "data" / "qrels.dummy.txt"
+ topic_file = PACKAGE_PATH / "data" / "topics.dummy.txt"
+ fold_file = PACKAGE_PATH / "data" / "dummy_folds.json"
+ query_type = "title"
diff --git a/capreolus/benchmark/msmarco.py b/capreolus/benchmark/msmarco.py
new file mode 100644
index 000000000..39019abae
--- /dev/null
+++ b/capreolus/benchmark/msmarco.py
@@ -0,0 +1,15 @@
+from capreolus import constants, ConfigOption, Dependency
+from . import Benchmark
+
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+# TODO add download_if_missing and re-enable
+# @Benchmark.register
+# class MSMarcoPassage(Benchmark):
+# module_name = "msmarcopassage"
+# dependencies = [Dependency(key="collection", module="collection", name="msmarco")]
+# qrel_file = PACKAGE_PATH / "data" / "qrels.msmarcopassage.txt"
+# topic_file = PACKAGE_PATH / "data" / "topics.msmarcopassage.txt"
+# fold_file = PACKAGE_PATH / "data" / "msmarcopassage.folds.json"
+# query_type = "title"
diff --git a/capreolus/benchmark/nf.py b/capreolus/benchmark/nf.py
new file mode 100644
index 000000000..f8129b876
--- /dev/null
+++ b/capreolus/benchmark/nf.py
@@ -0,0 +1,114 @@
+import json
+import re
+
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class NF(Benchmark):
+ """ NFCorpus: A Full-Text Learning to Rank Dataset for Medical Information Retrieval [1]
+
+    [1] Vera Boteva, Demian Gholipour, Artem Sokolov, and Stefan Riezler. A Full-Text Learning to Rank Dataset for Medical Information Retrieval. Proceedings of the 38th European Conference on Information Retrieval (ECIR), Padova, Italy, 2016. https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/
+ """
+
+ module_name = "nf"
+ dependencies = [Dependency(key="collection", module="collection", name="nf")]
+ config_spec = [
+ ConfigOption(key="labelrange", default_value="0-2", description="range of dataset qrels, options: 0-2, 1-3"),
+ ConfigOption(
+ key="fields",
+ default_value="all_titles",
+ description="query fields included in topic file, "
+ "options: 'all_fields', 'all_titles', 'nontopics', 'vid_title', 'vid_desc'",
+ ),
+ ]
+
+ fold_file = PACKAGE_PATH / "data" / "nf.json"
+
+ query_type = "title"
+
+ def build(self):
+ fields, label_range = self.config["fields"], self.config["labelrange"]
+ self.field2kws = {
+ "all_fields": ["all"],
+ "nontopics": ["nontopic-titles"],
+ "vid_title": ["vid-titles"],
+ "vid_desc": ["vid-desc"],
+ "all_titles": ["nontopic-titles", "vid-titles", "nontopic-titles"],
+ }
+ self.labelrange2kw = {"0-2": "2-1-0", "1-3": "3-2-1"}
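+        # NFCorpus ships two graded-qrel variants; labelrange "0-2" selects the "*.2-1-0.qrel"
+        # files and "1-3" the "*.3-2-1.qrel" files (see download_if_missing below)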
+
+ if fields not in self.field2kws:
+ raise ValueError(f"Unexpected fields value: {fields}, expect: {', '.join(self.field2kws.keys())}")
+ if label_range not in self.labelrange2kw:
+ raise ValueError(f"Unexpected label range: {label_range}, expect: {', '.join(self.field2kws.keys())}")
+
+ self.qrel_file = PACKAGE_PATH / "data" / f"qrels.nf.{label_range}.txt"
+ self.test_qrel_file = PACKAGE_PATH / "data" / f"test.qrels.nf.{label_range}.txt"
+ self.topic_file = PACKAGE_PATH / "data" / f"topics.nf.{fields}.txt"
+ self.download_if_missing()
+
+ def _transform_qid(self, raw):
+ """ NFCorpus dataset specific, remove prefix in query id since anserini convert all qid to integer """
+ return raw.replace("PLAIN-", "")
+
+ def download_if_missing(self):
+ if all([f.exists() for f in [self.topic_file, self.fold_file, self.qrel_file]]):
+ return
+
+ tmp_corpus_dir = self.collection.download_raw()
+ topic_f = open(self.topic_file, "w", encoding="utf-8")
+ qrel_f = open(self.qrel_file, "w", encoding="utf-8")
+ test_qrel_f = open(self.test_qrel_file, "w", encoding="utf-8")
+
+ set_names = ["train", "dev", "test"]
+ folds = {s: set() for s in set_names}
+ qrel_kw = self.labelrange2kw[self.config["labelrange"]]
+ for set_name in set_names:
+ with open(tmp_corpus_dir / f"{set_name}.{qrel_kw}.qrel") as f:
+ for line in f:
+ line = self._transform_qid(line)
+ qid = line.strip().split()[0]
+ folds[set_name].add(qid)
+ if set_name == "test":
+ test_qrel_f.write(line)
+ qrel_f.write(line)
+
+ files = [tmp_corpus_dir / f"{set_name}.{keyword}.queries" for keyword in self.field2kws[self.config["fields"]]]
+ qids2topics = self._align_queries(files, "title")
+
+ for qid, txts in qids2topics.items():
+ topic_f.write(topic_to_trectxt(qid, txts["title"]))
+
+ json.dump(
+ {"s1": {"train_qids": list(folds["train"]), "predict": {"dev": list(folds["dev"]), "test": list(folds["test"])}}},
+ open(self.fold_file, "w"),
+ )
+
+ topic_f.close()
+ qrel_f.close()
+ test_qrel_f.close()
+ logger.info(f"nf benchmark prepared")
+
+ def _align_queries(self, files, field, qid2queries=None):
+ if not qid2queries:
+ qid2queries = {}
+ for fn in files:
+ with open(fn, "r", encoding="utf-8") as f:
+ for line in f:
+ qid, txt = line.strip().split("\t")
+ qid = self._transform_qid(qid)
+ txt = " ".join(re.sub("[^A-Za-z]", " ", txt).split()[:1020])
+ if qid not in qid2queries:
+ qid2queries[qid] = {field: txt}
+ else:
+ if field in qid2queries[qid]:
+ logger.warning(f"Overwriting title for query {qid}")
+ qid2queries[qid][field] = txt
+ return qid2queries
diff --git a/capreolus/benchmark/robust04.py b/capreolus/benchmark/robust04.py
index d3545aaec..e13254782 100644
--- a/capreolus/benchmark/robust04.py
+++ b/capreolus/benchmark/robust04.py
@@ -1,4 +1,4 @@
-from profane import ModuleBase, Dependency, ConfigOption, constants
+from capreolus import constants, ConfigOption, Dependency
from . import Benchmark
PACKAGE_PATH = constants["PACKAGE_PATH"]
@@ -10,6 +10,7 @@ class Robust04(Benchmark):
Given the remaining four folds, we split them into the same train and dev sets used in recent work. [2]
[1] Samuel Huston and W. Bruce Croft. 2014. Parameters learned in the comparison of retrieval models using term dependencies. Technical Report.
+
[2] Sean MacAvaney, Andrew Yates, Arman Cohan, Nazli Goharian. 2019. CEDR: Contextualized Embeddings for Document Ranking. SIGIR 2019.
"""
@@ -19,3 +20,18 @@ class Robust04(Benchmark):
topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt"
fold_file = PACKAGE_PATH / "data" / "rob04_cedr_folds.json"
query_type = "title"
+
+
+@Benchmark.register
+class Robust04Yang19(Benchmark):
+ """Robust04 benchmark using the folds from Yang et al. [1]
+
+ [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. 2019. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019.
+ """
+
+ module_name = "robust04.yang19"
+ dependencies = [Dependency(key="collection", module="collection", name="robust04")]
+ qrel_file = PACKAGE_PATH / "data" / "qrels.robust2004.txt"
+ topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt"
+ fold_file = PACKAGE_PATH / "data" / "rob04_yang19_folds.json"
+ query_type = "title"
diff --git a/capreolus/collection/__init__.py b/capreolus/collection/__init__.py
index f559c9643..d5bc5e048 100644
--- a/capreolus/collection/__init__.py
+++ b/capreolus/collection/__init__.py
@@ -1,44 +1,38 @@
-from profane import import_all_modules
-
-# import_all_modules(__file__, __package__)
-
import os
-import math
-import shutil
-import pickle
-import tarfile
-import filecmp
-from tqdm import tqdm
-from zipfile import ZipFile
-from pathlib import Path
-import pandas as pd
-from profane import ModuleBase, Dependency, ConfigOption, constants
+from capreolus import ModuleBase, Dependency, ConfigOption, constants
-from capreolus.utils.common import download_file, hash_file, remove_newline
-from capreolus.utils.loginit import get_logger
-from capreolus.utils.trec import anserini_index_to_trec_docs, document_to_trectxt
-logger = get_logger(__name__)
-PACKAGE_PATH = constants["PACKAGE_PATH"]
+class Collection(ModuleBase):
+ """Base class for Collection modules. The purpose of a Collection is to describe a document collection's location and its format.
+ Determining the document collection's location on disk:
+    - The *path* config option will be used if it contains a valid location.
+ - If not, the ``_path`` attribute is used if it is valid. This is primarily used with :class:`~.dummy.DummyCollection`.
+ - If not, the class' ``download_if_missing`` method will be called.
+
+ Modules should provide:
+ - the ``collection_type`` and ``generator_type`` class attributes, corresponding to Anserini types
+ - a ``download_if_missing`` method, if the collection is publicly available
+ - a ``_validate_document_path`` method. See :func:`~capreolus.collection.Collection.validate_document_path`.
+ """
-class Collection(ModuleBase):
module_type = "collection"
is_large_collection = False
_path = None
def get_path_and_types(self):
+ """ Returns a ``(path, collection_type, generator_type)`` tuple. """
if not self.validate_document_path(self._path):
self._path = self.find_document_path()
return self._path, self.collection_type, self.generator_type
def validate_document_path(self, path):
- """ Attempt to validate the document collection at `path`.
+ """ Attempt to validate the document collection at ``path``.
- By default, this will only check whether `path` exists. Subclasses should override
- `_validate_document_path(path)` with their own logic to perform more detailed checks.
+ By default, this will only check whether ``path`` exists. Subclasses should override
+ ``_validate_document_path(path)`` with their own logic to perform more detailed checks.
Returns:
True if the path is valid following the logic described above, or False if it is not
@@ -61,11 +55,11 @@ def _validate_document_path(self, path):
def find_document_path(self):
""" Find the location of this collection's documents (i.e., the raw document collection).
- We first check the collection's config for a path key. If found, `self.validate_document_path` checks
- whether the path is valid. Subclasses should override the private method `self._validate_document_path`
- with custom logic for performing checks further than existence of the directory. See `Robust04`.
+ We first check the collection's config for a path key. If found, ``self.validate_document_path`` checks
+ whether the path is valid. Subclasses should override the private method ``self._validate_document_path``
+ with custom logic for performing checks further than existence of the directory.
- If a valid path was not found, call `download_if_missing`.
+ If a valid path was not found, call ``download_if_missing``.
Subclasses should override this method if downloading the needed documents is possible.
If a valid document path cannot be found, an exception is thrown.
@@ -78,451 +72,25 @@ def find_document_path(self):
if "path" in self.config and self.validate_document_path(self.config["path"]):
return self.config["path"]
+    # see if the path is hardcoded (e.g., for the dummy collection)
+ if self._path and self.validate_document_path(self._path):
+ return self._path
+
# if not, see if the collection can be obtained through its download_if_missing method
return self.download_if_missing()
def download_if_missing(self):
+ """ Download the collection and return its path. Subclasses should override this. """
raise IOError(
f"a download URL is not configured for collection={self.module_name} and the collection path does not exist; you must manually place the document collection at this path in order to use this collection"
)
-@Collection.register
-class Robust04(Collection):
- module_name = "robust04"
- collection_type = "TrecCollection"
- generator_type = "DefaultLuceneDocumentGenerator"
- config_keys_not_in_path = ["path"]
- config_spec = [ConfigOption("path", "Aquaint-TREC-3-4", "path to corpus")]
-
- def download_if_missing(self):
- return self.download_index(
- url="https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-robust04-20191213.tar.gz",
- sha256="dddb81f16d70ea6b9b0f94d6d6b888ed2ef827109a14ca21fd82b2acd6cbd450",
- index_directory_inside="index-robust04-20191213/",
- # this string should match how the index was built (i.e., Anserini, stopwords removed, Porter stemming)
- index_cache_path_string="index-anserini_indexstops-False_stemmer-porter",
- index_expected_document_count=528_030,
- cachedir=self.get_cache_path(),
- )
-
- def _validate_document_path(self, path):
- """ Validate that the document path appears to contain robust04's documents (Aquaint-TREC-3-4).
-
- Validation is performed by looking for four directories (case-insensitive): `FBIS`, `FR94`, `FT`, and `LATIMES`.
- These directories may either be at the root of `path` or they may be in `path/NEWS_data` (case-insensitive).
-
- Returns:
- True if the Aquaint-TREC-3-4 document directories are found or False if not
- """
-
- if not os.path.isdir(path):
- return False
-
- contents = {fn.lower(): fn for fn in os.listdir(path)}
- if "news_data" in contents:
- contents = {fn.lower(): fn for fn in os.listdir(os.path.join(path, contents["news_data"]))}
-
- if "fbis" in contents and "fr94" in contents and "ft" in contents and "latimes" in contents:
- return True
-
- return False
-
- def download_index(
- self, cachedir, url, sha256, index_directory_inside, index_cache_path_string, index_expected_document_count
- ):
- # Download the collection from URL and extract into a path in the cache directory.
- # To avoid re-downloading every call, we create an empty '/done' file in this directory on success.
- done_file = os.path.join(cachedir, "done")
- document_dir = os.path.join(cachedir, "documents")
-
- # already downloaded?
- if os.path.exists(done_file):
- return document_dir
-
- # 1. Download and extract Anserini index to a temporary location
- tmp_dir = os.path.join(cachedir, "tmp_download")
- archive_file = os.path.join(tmp_dir, "archive_file")
- os.makedirs(document_dir, exist_ok=True)
- os.makedirs(tmp_dir, exist_ok=True)
- logger.info("downloading index for missing collection %s to temporary file %s", self.module_name, archive_file)
- download_file(url, archive_file, expected_hash=sha256)
-
- logger.info("extracting index to %s (before moving to correct cache path)", tmp_dir)
- with tarfile.open(archive_file) as tar:
- tar.extractall(path=tmp_dir)
-
- extracted_dir = os.path.join(tmp_dir, index_directory_inside)
- if not (os.path.exists(extracted_dir) and os.path.isdir(extracted_dir)):
- raise ValueError(f"could not find expected index directory {extracted_dir} in {tmp_dir}")
-
- # 2. Move index to its correct location in the cache
- index_dir = os.path.join(cachedir, index_cache_path_string, "index")
- if not os.path.exists(os.path.join("index_dir", "done")):
- if os.path.exists(index_dir):
- shutil.rmtree(index_dir)
- shutil.move(extracted_dir, index_dir)
-
- # 3. Extract raw documents from the Anserini index to document_dir
- anserini_index_to_trec_docs(index_dir, document_dir, index_expected_document_count)
-
- # remove temporary files and create a /done we can use to verify extraction was successful
- shutil.rmtree(tmp_dir)
- with open(done_file, "wt") as outf:
- print("", file=outf)
-
- return document_dir
-
-
-@Collection.register
-class DummyCollection(Collection):
- module_name = "dummy"
- _path = PACKAGE_PATH / "data" / "dummy" / "data"
- collection_type = "TrecCollection"
- generator_type = "DefaultLuceneDocumentGenerator"
-
- def _validate_document_path(self, path):
- """ Validate that the document path contains `dummy_trec_doc` """
- return "dummy_trec_doc" in os.listdir(path)
-
-
-@Collection.register
-class NF(Collection):
- module_name = "nf"
- _path = PACKAGE_PATH / "data" / "nf-collection"
- url = "http://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/nfcorpus.tar.gz"
-
- collection_type = "TrecCollection"
- generator_type = "DefaultLuceneDocumentGenerator"
-
- def download_raw(self):
- cachedir = self.get_cache_path()
- tmp_dir = cachedir / "tmp"
- tmp_tar_fn, tmp_corpus_dir = tmp_dir / "nfcorpus.tar.gz", tmp_dir / "nfcorpus"
-
- os.makedirs(tmp_dir, exist_ok=True)
-
- if not tmp_tar_fn.exists():
- download_file(self.url, tmp_tar_fn)
-
- with tarfile.open(tmp_tar_fn) as f:
- f.extractall(tmp_dir)
- return tmp_corpus_dir
-
- def download_if_missing(self):
- cachedir = self.get_cache_path()
- document_dir = os.path.join(cachedir, "documents")
- coll_filename = os.path.join(document_dir, "nf-collection.txt")
- if os.path.exists(coll_filename):
- return document_dir
-
- os.makedirs(document_dir, exist_ok=True)
- tmp_corpus_dir = self.download_raw()
-
- inp_fns = [tmp_corpus_dir / f"{set_name}.docs" for set_name in ["train", "dev", "test"]]
- print(inp_fns)
- with open(coll_filename, "w", encoding="utf-8") as outp_file:
- self._convert_to_trec(inp_fns, outp_file)
- logger.info(f"nf collection file prepared, stored at {coll_filename}")
-
- return document_dir
-
- def _convert_to_trec(self, inp_fns, outp_file):
- for inp_fn in inp_fns:
- assert os.path.exists(inp_fn)
-
- with open(inp_fn, "rt", encoding="utf-8") as f:
- for line in f:
- docid, doc = line.strip().split("\t")
- outp_file.write(f"\n{docid}\n\n{doc}\n\n\n")
-
-
-@Collection.register
-class ANTIQUE(Collection):
- module_name = "antique"
- _path = PACKAGE_PATH / "data" / "antique-collection"
-
- collection_type = "TrecCollection"
- generator_type = "DefaultLuceneDocumentGenerator"
-
- def download_if_missing(self):
- url = "http://ciir.cs.umass.edu/downloads/Antique/antique-collection.txt"
- cachedir = self.get_cache_path()
- document_dir = os.path.join(cachedir, "documents")
- coll_filename = os.path.join(document_dir, "antique-collection.txt")
-
- if os.path.exists(coll_filename):
- return document_dir
-
- tmp_dir = cachedir / "tmp"
- tmp_filename = os.path.join(tmp_dir, "tmp.anqique.file")
-
- os.makedirs(tmp_dir, exist_ok=True)
- os.makedirs(document_dir, exist_ok=True)
-
- download_file(url, tmp_filename, expected_hash="68b6688f5f2668c93f0e8e43384f66def768c4da46da4e9f7e2629c1c47a0c36")
- self._convert_to_trec(inp_path=tmp_filename, outp_path=coll_filename)
- logger.info(f"antique collection file prepared, stored at {coll_filename}")
-
- for file in os.listdir(tmp_dir): # in case there are legacy files
- os.remove(os.path.join(tmp_dir, file))
- shutil.rmtree(tmp_dir)
-
- return document_dir
-
- def _convert_to_trec(self, inp_path, outp_path):
- assert os.path.exists(inp_path)
-
- fout = open(outp_path, "wt", encoding="utf-8")
- with open(inp_path, "rt", encoding="utf-8") as f:
- for line in f:
- docid, doc = line.strip().split("\t")
- fout.write(f"\n{docid}\n\n{doc}\n\n\n")
- fout.close()
- logger.debug(f"Converted file {os.path.basename(inp_path)} to TREC format, output to: {outp_path}")
-
- def _validate_document_path(self, path):
- """ Checks that the sha256sum is correct """
- return (
- hash_file(os.path.join(path, "antique-collection.txt"))
- == "409e0960f918970977ceab9e5b1d372f45395af25d53b95644bdc9ccbbf973da"
- )
-
-
-@Collection.register
-class MSMarco(Collection):
- module_name = "msmarco"
- config_keys_not_in_path = ["path"]
- collection_type = "TrecCollection"
- generator_type = "DefaultLuceneDocumentGenerator"
- config_spec = [ConfigOption("path", "/GW/NeuralIR/nobackup/msmarco/trec_format", "path to corpus")]
-
-
-@Collection.register
-class CodeSearchNet(Collection):
- module_name = "codesearchnet"
- url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2"
- collection_type = "TrecCollection" # TODO: any other supported type?
- generator_type = "DefaultLuceneDocumentGenerator"
- config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]
-
- def download_if_missing(self):
- cachedir = self.get_cache_path()
- document_dir = cachedir / "documents"
- coll_filename = document_dir / ("csn-" + self.config["lang"] + "-collection.txt")
-
- if coll_filename.exists():
- return document_dir
-
- zipfile = self.config["lang"] + ".zip"
- lang_url = f"{self.url}/{zipfile}"
- tmp_dir = cachedir / "tmp"
- zip_path = tmp_dir / zipfile
-
- if zip_path.exists():
- logger.info(f"{zipfile} already exist under directory {tmp_dir}, skip downloaded")
- else:
- tmp_dir.mkdir(exist_ok=True, parents=True)
- download_file(lang_url, zip_path)
-
- document_dir.mkdir(exist_ok=True, parents=True) # tmp
- with ZipFile(zip_path, "r") as zipobj:
- zipobj.extractall(tmp_dir)
-
- pkl_path = tmp_dir / (self.config["lang"] + "_dedupe_definitions_v2.pkl")
- self._pkl2trec(pkl_path, coll_filename)
- return document_dir
-
- def _pkl2trec(self, pkl_path, trec_path):
- lang = self.config["lang"]
- with open(pkl_path, "rb") as f:
- codes = pickle.load(f)
-
- fout = open(trec_path, "w", encoding="utf-8")
- for i, code in tqdm(enumerate(codes), desc=f"Preparing the {lang} collection file"):
- docno = f"{lang}-FUNCTION-{i}"
- doc = remove_newline(" ".join(code["function_tokens"]))
- fout.write(document_to_trectxt(docno, doc))
- fout.close()
-
-
-@Collection.register
-class COVID(Collection):
- module_name = "covid"
- url = "https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/historical_releases/cord-19_%s.tar.gz"
- generator_type = "Cord19Generator"
- config_spec = [ConfigOption("coll_type", "abstract", "one of: abstract, fulltext, paragraph"), ConfigOption("round", 3)]
-
- def build(self):
- coll_type, round = self.config["coll_type"], self.config["round"]
- type2coll = {
- "abstract": "Cord19AbstractCollection",
- "fulltext": "Cord19FullTextCollection",
- "paragraph": "Cord19ParagraphCollection",
- }
- dates = ["2020-04-10", "2020-05-01", "2020-05-19"]
-
- if coll_type not in type2coll:
- raise ValueError(f"Unexpected coll_type: {coll_type}; expeced one of: {' '.join(type2coll.keys())}")
- if round > len(dates):
- raise ValueError(f"Unexpected round number: {round}; only {len(dates)} number of rounds are provided")
-
- self.collection_type = type2coll[coll_type]
- self.date = dates[round - 1]
-
- def download_if_missing(self):
- cachedir = self.get_cache_path()
- tmp_dir, document_dir = Path("/tmp"), cachedir / "documents"
- expected_fns = [document_dir / "metadata.csv", document_dir / "document_parses"]
- if all([os.path.exists(f) for f in expected_fns]):
- return document_dir
-
- url = self.url % self.date
- tar_file = tmp_dir / f"covid-19-{self.date}.tar.gz"
- if not tar_file.exists():
- download_file(url, tar_file)
-
- with tarfile.open(tar_file) as f:
- f.extractall(path=cachedir) # emb.tar.gz, metadata.csv, doc.tar.gz, changelog
- os.rename(cachedir / self.date, document_dir)
-
- doc_fn = "document_parses"
- if f"{doc_fn}.tar.gz" in os.listdir(document_dir):
- with tarfile.open(document_dir / f"{doc_fn}.tar.gz") as f:
- f.extractall(path=document_dir)
- else:
- self.transform_metadata(document_dir)
-
- # only document_parses and metadata.csv are expected
- for fn in os.listdir(document_dir):
- if (document_dir / fn) not in expected_fns:
- os.remove(document_dir / fn)
- return document_dir
-
- def transform_metadata(self, root_path):
- """
- the transformation is necessary for dataset round 1 and 2 according to
- https://discourse.cord-19.semanticscholar.org/t/faqs-about-cord-19-dataset/94
-
- the assumed directory under root_path:
- ./root_path
- ./metadata.csv
- ./comm_use_subset
- ./noncomm_use_subset
- ./custom_license
- ./biorxiv_medrxiv
- ./archive
-
- In a nutshell:
- 1. renaming:
- Microsoft Academic Paper ID -> mag_id;
- WHO #Covidence -> who_covidence_id
- 2. update:
- has_pdf_parse -> pdf_json_files # e.g. document_parses/pmc_json/PMC125340.xml.json
- has_pmc_xml_parse -> pmc_json_files
- """
- metadata_csv = str(root_path / "metadata.csv")
- orifiles = ["arxiv", "custom_license", "biorxiv_medrxiv", "comm_use_subset", "noncomm_use_subset"]
- for fn in orifiles:
- if (root_path / fn).exists():
- continue
-
- tar_fn = root_path / f"{fn}.tar.gz"
- if not tar_fn.exists():
- continue
-
- with tarfile.open(str(tar_fn)) as f:
- f.extractall(path=root_path)
- os.remove(tar_fn)
-
- metadata = pd.read_csv(metadata_csv, header=0)
- columns = metadata.columns.values
- cols_before = [
- "cord_uid",
- "sha",
- "source_x",
- "title",
- "doi",
- "pmcid",
- "pubmed_id",
- "license",
- "abstract",
- "publish_time",
- "authors",
- "journal",
- "Microsoft Academic Paper ID",
- "WHO #Covidence",
- "arxiv_id",
- "has_pdf_parse",
- "has_pmc_xml_parse",
- "full_text_file",
- "url",
- ]
- assert all(columns == cols_before)
-
- # step 1: rename column
- cols_to_rename = {"Microsoft Academic Paper ID": "mag_id", "WHO #Covidence": "who_covidence_id"}
- metadata.columns = [cols_to_rename.get(c, c) for c in columns]
-
- # step 2: parse path & move json file
- doc_outp = root_path / "document_parses"
- pdf_dir, pmc_dir = doc_outp / "pdf_json", doc_outp / "pmc_json"
- pdf_dir.mkdir(exist_ok=True, parents=True)
- pmc_dir.mkdir(exist_ok=True, parents=True)
-
- new_cols = ["pdf_json_files", "pmc_json_files"]
- for col in new_cols:
- metadata[col] = ""
- metadata["s2_id"] = math.nan # tmp, what's this column??
-
- iterbar = tqdm(desc="transforming data", total=len(metadata))
- for i, row in metadata.iterrows():
- dir = row["full_text_file"]
-
- if row["has_pmc_xml_parse"]:
- name = row["pmcid"] + ".xml.json"
- ori_fn = root_path / dir / "pmc_json" / name
- pmc_fn = f"document_parses/pmc_json/{name}"
- metadata.at[i, "pmc_json_files"] = pmc_fn
- pmc_fn = root_path / pmc_fn
- if not pmc_fn.exists():
- os.rename(ori_fn, pmc_fn) # check
- else:
- metadata.at[i, "pmc_json_files"] = math.nan
-
- if row["has_pdf_parse"]:
- shas = str(row["sha"]).split(";")
- pdf_fn_final = ""
- for sha in shas:
- name = sha.strip() + ".json"
- ori_fn = root_path / dir / "pdf_json" / name
- pdf_fn = f"document_parses/pdf_json/{name}"
- pdf_fn_final = f"{pdf_fn_final};{pdf_fn}" if pdf_fn_final else pdf_fn
- pdf_fn = root_path / pdf_fn
- if not pdf_fn.exists():
- os.rename(ori_fn, pdf_fn) # check
- else:
- if ori_fn.exists():
- assert filecmp.cmp(ori_fn, pdf_fn)
- os.remove(ori_fn)
-
- metadata.at[i, "pdf_json_files"] = pdf_fn_final
- else:
- metadata.at[i, "pdf_json_files"] = math.nan
-
- iterbar.update()
-
- # step 3: remove deprecated columns, remove unwanted directories
- cols_to_remove = ["has_pdf_parse", "has_pmc_xml_parse", "full_text_file"]
- metadata.drop(columns=cols_to_remove)
+from profane import import_all_modules
- dir_to_remove = ["comm_use_subset", "noncomm_use_subset", "custom_license", "biorxiv_medrxiv", "arxiv"]
- for dir in dir_to_remove:
- dir = root_path / dir
- for subdir in os.listdir(dir):
- os.rmdir(dir / subdir) # since we are supposed to move away all the files
- os.rmdir(dir)
+from .dummy import DummyCollection
+from .antique import ANTIQUE
+from .nf import NF
+from .robust04 import Robust04
- # assert len(metadata.columns) == 19
- # step 4: save back
- metadata.to_csv(metadata_csv, index=False)
+import_all_modules(__file__, __package__)
diff --git a/capreolus/collection/antique.py b/capreolus/collection/antique.py
new file mode 100644
index 000000000..a8eca5be3
--- /dev/null
+++ b/capreolus/collection/antique.py
@@ -0,0 +1,67 @@
+import os
+import shutil
+
+from . import Collection
+from capreolus import ModuleBase, Dependency, ConfigOption, constants
+from capreolus.utils.common import download_file, hash_file
+from capreolus.utils.loginit import get_logger
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Collection.register
+class ANTIQUE(Collection):
+ """A Non-factoid Question Answering Benchmark from Hashemi et al. [1]
+
+ [1] Helia Hashemi, Mohammad Aliannejadi, Hamed Zamani, and W. Bruce Croft. 2020. ANTIQUE: A non-factoid question answering benchmark. ECIR 2020.
+ """
+
+ module_name = "antique"
+ _path = PACKAGE_PATH / "data" / "antique-collection"
+
+ collection_type = "TrecCollection"
+ generator_type = "DefaultLuceneDocumentGenerator"
+
+ def download_if_missing(self):
+ url = "http://ciir.cs.umass.edu/downloads/Antique/antique-collection.txt"
+ cachedir = self.get_cache_path()
+ document_dir = os.path.join(cachedir, "documents")
+ coll_filename = os.path.join(document_dir, "antique-collection.txt")
+
+ if os.path.exists(coll_filename):
+ return document_dir
+
+ tmp_dir = cachedir / "tmp"
+ tmp_filename = os.path.join(tmp_dir, "tmp.antique.file")
+
+ os.makedirs(tmp_dir, exist_ok=True)
+ os.makedirs(document_dir, exist_ok=True)
+
+ download_file(url, tmp_filename, expected_hash="68b6688f5f2668c93f0e8e43384f66def768c4da46da4e9f7e2629c1c47a0c36")
+ self._convert_to_trec(inp_path=tmp_filename, outp_path=coll_filename)
+ logger.info(f"antique collection file prepared, stored at {coll_filename}")
+
+ for file in os.listdir(tmp_dir): # in case there are legacy files
+ os.remove(os.path.join(tmp_dir, file))
+ shutil.rmtree(tmp_dir)
+
+ return document_dir
+
+ def _convert_to_trec(self, inp_path, outp_path):
+ assert os.path.exists(inp_path)
+
+ fout = open(outp_path, "wt", encoding="utf-8")
+ with open(inp_path, "rt", encoding="utf-8") as f:
+ for line in f:
+ docid, doc = line.strip().split("\t")
+ fout.write(f"\n{docid}\n\n{doc}\n\n\n")
+ fout.close()
+ logger.debug(f"Converted file {os.path.basename(inp_path)} to TREC format, output to: {outp_path}")
+
+ def _validate_document_path(self, path):
+ """ Checks that the sha256sum is correct """
+ return (
+ hash_file(os.path.join(path, "antique-collection.txt"))
+ == "409e0960f918970977ceab9e5b1d372f45395af25d53b95644bdc9ccbbf973da"
+ )
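+
+
+# For reference, a sketch of one TREC record written by _convert_to_trec above
+# for the (made-up) input line "1234_0\tA short answer passage.":
+#
+# <DOC>
+# <DOCNO>1234_0</DOCNO>
+# <TEXT>
+# A short answer passage.
+# </TEXT>
+# </DOC>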
diff --git a/capreolus/collection/codesearchnet.py b/capreolus/collection/codesearchnet.py
new file mode 100644
index 000000000..48c0f3675
--- /dev/null
+++ b/capreolus/collection/codesearchnet.py
@@ -0,0 +1,66 @@
+import pickle
+
+from tqdm import tqdm
+from zipfile import ZipFile
+
+from . import Collection
+from capreolus import ModuleBase, Dependency, ConfigOption, constants
+from capreolus.utils.common import download_file, hash_file, remove_newline
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import document_to_trectxt
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Collection.register
+class CodeSearchNet(Collection):
+ """CodeSearchNet Corpus. [1]
+
+ [1] Hamel Husain, Ho-Hsiang Wu, Tiferet Gazit, Miltiadis Allamanis, and Marc Brockschmidt. 2019. CodeSearchNet Challenge: Evaluating the State of Semantic Code Search. arXiv 2019.
+ """
+
+ module_name = "codesearchnet"
+ url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2"
+ collection_type = "TrecCollection" # TODO: any other supported type?
+ generator_type = "DefaultLuceneDocumentGenerator"
+ config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]
+
+ def download_if_missing(self):
+ cachedir = self.get_cache_path()
+ document_dir = cachedir / "documents"
+ coll_filename = document_dir / ("csn-" + self.config["lang"] + "-collection.txt")
+
+ if coll_filename.exists():
+ return document_dir
+
+ zipfile = self.config["lang"] + ".zip"
+ lang_url = f"{self.url}/{zipfile}"
+ tmp_dir = cachedir / "tmp"
+ zip_path = tmp_dir / zipfile
+
+ if zip_path.exists():
+ logger.info(f"{zipfile} already exist under directory {tmp_dir}, skip downloaded")
+ else:
+ tmp_dir.mkdir(exist_ok=True, parents=True)
+ download_file(lang_url, zip_path)
+
+ document_dir.mkdir(exist_ok=True, parents=True) # tmp
+ with ZipFile(zip_path, "r") as zipobj:
+ zipobj.extractall(tmp_dir)
+
+ pkl_path = tmp_dir / (self.config["lang"] + "_dedupe_definitions_v2.pkl")
+ self._pkl2trec(pkl_path, coll_filename)
+ return document_dir
+
+ def _pkl2trec(self, pkl_path, trec_path):
+ lang = self.config["lang"]
+ with open(pkl_path, "rb") as f:
+ codes = pickle.load(f)
+
+ fout = open(trec_path, "w", encoding="utf-8")
+ for i, code in tqdm(enumerate(codes), desc=f"Preparing the {lang} collection file"):
+ docno = f"{lang}-FUNCTION-{i}"
+ doc = remove_newline(" ".join(code["function_tokens"]))
+ fout.write(document_to_trectxt(docno, doc))
+ fout.close()
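+
+
+# Sketch (illustration only): each pickled entry is a dict whose
+# "function_tokens" list is joined into a single line of text, so e.g. the
+# first ruby function becomes the TREC document with docno "ruby-FUNCTION-0".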
diff --git a/capreolus/collection/covid.py b/capreolus/collection/covid.py
new file mode 100644
index 000000000..324d3c174
--- /dev/null
+++ b/capreolus/collection/covid.py
@@ -0,0 +1,200 @@
+import os
+import math
+import tarfile
+import filecmp
+
+from tqdm import tqdm
+from pathlib import Path
+import pandas as pd
+
+from . import Collection
+from capreolus import ModuleBase, Dependency, ConfigOption, constants
+from capreolus.utils.common import download_file, hash_file, remove_newline
+from capreolus.utils.loginit import get_logger
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Collection.register
+class COVID(Collection):
+ """ The COVID-19 Open Research Dataset (https://www.semanticscholar.org/cord19) """
+
+ module_name = "covid"
+ url = "https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/historical_releases/cord-19_%s.tar.gz"
+ generator_type = "Cord19Generator"
+ config_spec = [ConfigOption("coll_type", "abstract", "one of: abstract, fulltext, paragraph"), ConfigOption("round", 3)]
+
+ def build(self):
+ coll_type, round = self.config["coll_type"], self.config["round"]
+ type2coll = {
+ "abstract": "Cord19AbstractCollection",
+ "fulltext": "Cord19FullTextCollection",
+ "paragraph": "Cord19ParagraphCollection",
+ }
+ dates = ["2020-04-10", "2020-05-01", "2020-05-19"]
+
+ if coll_type not in type2coll:
+ raise ValueError(f"Unexpected coll_type: {coll_type}; expeced one of: {' '.join(type2coll.keys())}")
+ if round > len(dates):
+ raise ValueError(f"Unexpected round number: {round}; only {len(dates)} number of rounds are provided")
+
+ self.collection_type = type2coll[coll_type]
+ self.date = dates[round - 1]
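+ # e.g., round=3 with coll_type="fulltext" selects Cord19FullTextCollection
+ # and the 2020-05-19 release (a reference note mirroring the mappings above)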
+
+ def download_if_missing(self):
+ cachedir = self.get_cache_path()
+ tmp_dir, document_dir = Path("/tmp"), cachedir / "documents"
+ expected_fns = [document_dir / "metadata.csv", document_dir / "document_parses"]
+ if all([os.path.exists(f) for f in expected_fns]):
+ return document_dir
+
+ url = self.url % self.date
+ tar_file = tmp_dir / f"covid-19-{self.date}.tar.gz"
+ if not tar_file.exists():
+ download_file(url, tar_file)
+
+ with tarfile.open(tar_file) as f:
+ f.extractall(path=cachedir) # emb.tar.gz, metadata.csv, doc.tar.gz, changelog
+ os.rename(cachedir / self.date, document_dir)
+
+ doc_fn = "document_parses"
+ if f"{doc_fn}.tar.gz" in os.listdir(document_dir):
+ with tarfile.open(document_dir / f"{doc_fn}.tar.gz") as f:
+ f.extractall(path=document_dir)
+ else:
+ self.transform_metadata(document_dir)
+
+ # only document_parses and metadata.csv are expected
+ for fn in os.listdir(document_dir):
+ if (document_dir / fn) not in expected_fns:
+ os.remove(document_dir / fn)
+ return document_dir
+
+ def transform_metadata(self, root_path):
+ """
+ the transformation is necessary for dataset round 1 and 2 according to
+ https://discourse.cord-19.semanticscholar.org/t/faqs-about-cord-19-dataset/94
+
+ the assumed directory under root_path:
+ ./root_path
+ ./metadata.csv
+ ./comm_use_subset
+ ./noncomm_use_subset
+ ./custom_license
+ ./biorxiv_medrxiv
+ ./archive
+
+ In a nutshell:
+ 1. renaming:
+ Microsoft Academic Paper ID -> mag_id;
+ WHO #Covidence -> who_covidence_id
+ 2. update:
+ has_pdf_parse -> pdf_json_files # e.g. document_parses/pmc_json/PMC125340.xml.json
+ has_pmc_xml_parse -> pmc_json_files
+ """
+ metadata_csv = str(root_path / "metadata.csv")
+ orifiles = ["arxiv", "custom_license", "biorxiv_medrxiv", "comm_use_subset", "noncomm_use_subset"]
+ for fn in orifiles:
+ if (root_path / fn).exists():
+ continue
+
+ tar_fn = root_path / f"{fn}.tar.gz"
+ if not tar_fn.exists():
+ continue
+
+ with tarfile.open(str(tar_fn)) as f:
+ f.extractall(path=root_path)
+ os.remove(tar_fn)
+
+ metadata = pd.read_csv(metadata_csv, header=0)
+ columns = metadata.columns.values
+ cols_before = [
+ "cord_uid",
+ "sha",
+ "source_x",
+ "title",
+ "doi",
+ "pmcid",
+ "pubmed_id",
+ "license",
+ "abstract",
+ "publish_time",
+ "authors",
+ "journal",
+ "Microsoft Academic Paper ID",
+ "WHO #Covidence",
+ "arxiv_id",
+ "has_pdf_parse",
+ "has_pmc_xml_parse",
+ "full_text_file",
+ "url",
+ ]
+ assert all(columns == cols_before)
+
+ # step 1: rename column
+ cols_to_rename = {"Microsoft Academic Paper ID": "mag_id", "WHO #Covidence": "who_covidence_id"}
+ metadata.columns = [cols_to_rename.get(c, c) for c in columns]
+
+ # step 2: parse path & move json file
+ doc_outp = root_path / "document_parses"
+ pdf_dir, pmc_dir = doc_outp / "pdf_json", doc_outp / "pmc_json"
+ pdf_dir.mkdir(exist_ok=True, parents=True)
+ pmc_dir.mkdir(exist_ok=True, parents=True)
+
+ new_cols = ["pdf_json_files", "pmc_json_files"]
+ for col in new_cols:
+ metadata[col] = ""
+ metadata["s2_id"] = math.nan # tmp, what's this column??
+
+ iterbar = tqdm(desc="transforming data", total=len(metadata))
+ for i, row in metadata.iterrows():
+ dir = row["full_text_file"]
+
+ if row["has_pmc_xml_parse"]:
+ name = row["pmcid"] + ".xml.json"
+ ori_fn = root_path / dir / "pmc_json" / name
+ pmc_fn = f"document_parses/pmc_json/{name}"
+ metadata.at[i, "pmc_json_files"] = pmc_fn
+ pmc_fn = root_path / pmc_fn
+ if not pmc_fn.exists():
+ os.rename(ori_fn, pmc_fn) # check
+ else:
+ metadata.at[i, "pmc_json_files"] = math.nan
+
+ if row["has_pdf_parse"]:
+ shas = str(row["sha"]).split(";")
+ pdf_fn_final = ""
+ for sha in shas:
+ name = sha.strip() + ".json"
+ ori_fn = root_path / dir / "pdf_json" / name
+ pdf_fn = f"document_parses/pdf_json/{name}"
+ pdf_fn_final = f"{pdf_fn_final};{pdf_fn}" if pdf_fn_final else pdf_fn
+ pdf_fn = root_path / pdf_fn
+ if not pdf_fn.exists():
+ os.rename(ori_fn, pdf_fn) # check
+ else:
+ if ori_fn.exists():
+ assert filecmp.cmp(ori_fn, pdf_fn)
+ os.remove(ori_fn)
+
+ metadata.at[i, "pdf_json_files"] = pdf_fn_final
+ else:
+ metadata.at[i, "pdf_json_files"] = math.nan
+
+ iterbar.update()
+
+ # step 3: remove deprecated columns, remove unwanted directories
+ cols_to_remove = ["has_pdf_parse", "has_pmc_xml_parse", "full_text_file"]
+ metadata = metadata.drop(columns=cols_to_remove) # drop() returns a copy, so reassign
+
+ dir_to_remove = ["comm_use_subset", "noncomm_use_subset", "custom_license", "biorxiv_medrxiv", "arxiv"]
+ for dir in dir_to_remove:
+ dir = root_path / dir
+ for subdir in os.listdir(dir):
+ os.rmdir(dir / subdir) # all json files should already have been moved out above
+ os.rmdir(dir)
+
+ # assert len(metadata.columns) == 19
+ # step 4: save back
+ metadata.to_csv(metadata_csv, index=False)
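+
+
+if __name__ == "__main__":
+ # Smoke-test sketch (illustration only; the path is an assumption and should
+ # point at a transformed documents directory). After transform_metadata, the
+ # new-style columns should be present in metadata.csv.
+ df = pd.read_csv("documents/metadata.csv", header=0)
+ assert {"mag_id", "who_covidence_id", "pdf_json_files", "pmc_json_files"} <= set(df.columns)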
diff --git a/capreolus/collection/dummy.py b/capreolus/collection/dummy.py
new file mode 100644
index 000000000..baa9d41ca
--- /dev/null
+++ b/capreolus/collection/dummy.py
@@ -0,0 +1,21 @@
+import os
+
+from . import Collection
+from capreolus import constants, get_logger
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Collection.register
+class DummyCollection(Collection):
+ """ Tiny collection for testing """
+
+ module_name = "dummy"
+ _path = PACKAGE_PATH / "data" / "dummy" / "data"
+ collection_type = "TrecCollection"
+ generator_type = "DefaultLuceneDocumentGenerator"
+
+ def _validate_document_path(self, path):
+ """ Validate that the document path contains `dummy_trec_doc` """
+ return "dummy_trec_doc" in os.listdir(path)
diff --git a/capreolus/collection/msmarco.py b/capreolus/collection/msmarco.py
new file mode 100644
index 000000000..e1f8ecbe4
--- /dev/null
+++ b/capreolus/collection/msmarco.py
@@ -0,0 +1,10 @@
+from . import Collection
+from capreolus import ConfigOption
+
+# @Collection.register
+# class MSMarco(Collection):
+# module_name = "msmarco"
+# config_keys_not_in_path = ["path"]
+# collection_type = "TrecCollection"
+# generator_type = "DefaultLuceneDocumentGenerator"
+# config_spec = [ConfigOption("path", "/GW/NeuralIR/nobackup/msmarco/trec_format", "path to corpus")]
diff --git a/capreolus/collection/nf.py b/capreolus/collection/nf.py
new file mode 100644
index 000000000..a1afc40cd
--- /dev/null
+++ b/capreolus/collection/nf.py
@@ -0,0 +1,71 @@
+import os
+import tarfile
+
+from . import Collection
+from capreolus import constants, get_logger
+from capreolus.utils.common import download_file
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Collection.register
+class NF(Collection):
+ """ NFCorpus: A Full-Text Learning to Rank Dataset for Medical Information Retrieval [1]
+
+ [1] Vera Boteva, Demian Gholipour, Artem Sokolov, and Stefan Riezler. A Full-Text Learning to Rank Dataset for Medical Information Retrieval. Proceedings of the 38th European Conference on Information Retrieval (ECIR), Padova, Italy, 2016. https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/
+ """
+
+ module_name = "nf"
+ _path = PACKAGE_PATH / "data" / "nf-collection"
+ url = "http://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/nfcorpus.tar.gz"
+
+ collection_type = "TrecCollection"
+ generator_type = "DefaultLuceneDocumentGenerator"
+
+ def download_raw(self):
+ cachedir = self.get_cache_path()
+ tmp_dir = cachedir / "tmp"
+ tmp_tar_fn, tmp_corpus_dir = tmp_dir / "nfcorpus.tar.gz", tmp_dir / "nfcorpus"
+
+ os.makedirs(tmp_dir, exist_ok=True)
+
+ if not tmp_tar_fn.exists():
+ download_file(self.url, tmp_tar_fn, "ebc026d4a8bef3f866148b727e945a2073eb4045ede9b7de95dd50fd086b4256")
+
+ with tarfile.open(tmp_tar_fn) as f:
+ f.extractall(tmp_dir)
+ return tmp_corpus_dir
+
+ def download_if_missing(self):
+ cachedir = self.get_cache_path()
+ document_dir = os.path.join(cachedir, "documents")
+ coll_filename = os.path.join(document_dir, "nf-collection.txt")
+ if os.path.exists(coll_filename):
+ return document_dir
+
+ os.makedirs(document_dir, exist_ok=True)
+ tmp_corpus_dir = self.download_raw()
+
+ inp_fns = [tmp_corpus_dir / f"{set_name}.docs" for set_name in ["train", "dev", "test"]]
+ logger.debug("converting %s to TREC format", inp_fns)
+ with open(coll_filename, "w", encoding="utf-8") as outp_file:
+ self._convert_to_trec(inp_fns, outp_file)
+ logger.info(f"nf collection file prepared, stored at {coll_filename}")
+
+ return document_dir
+
+ def _convert_to_trec(self, inp_fns, outp_file):
+ # train.docs, dev.docs, and test.docs have some overlap, so we check for duplicate docids
+ seen_docids = set()
+
+ for inp_fn in inp_fns:
+ assert os.path.exists(inp_fn)
+
+ with open(inp_fn, "rt", encoding="utf-8") as f:
+ for line in f:
+ docid, doc = line.strip().split("\t")
+
+ if docid not in seen_docids:
+ outp_file.write(f"\n{docid}\n\n{doc}\n\n\n")
+ seen_docids.add(docid)
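+
+
+# Note (a sketch of the intended behavior): because train.docs, dev.docs, and
+# test.docs overlap, the seen_docids check above writes each docid to
+# nf-collection.txt exactly once, so duplicate documents never reach the index.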
diff --git a/capreolus/collection/robust04.py b/capreolus/collection/robust04.py
new file mode 100644
index 000000000..b5ac211c0
--- /dev/null
+++ b/capreolus/collection/robust04.py
@@ -0,0 +1,101 @@
+import os
+import shutil
+import tarfile
+
+from . import Collection
+from capreolus import ModuleBase, Dependency, ConfigOption, constants
+from capreolus.utils.common import download_file, hash_file, remove_newline
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import anserini_index_to_trec_docs, document_to_trectxt
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Collection.register
+class Robust04(Collection):
+ """ TREC Robust04 (TREC disks 4 and 5 without the Congressional Record documents) """
+
+ module_name = "robust04"
+ collection_type = "TrecCollection"
+ generator_type = "DefaultLuceneDocumentGenerator"
+ config_keys_not_in_path = ["path"]
+ config_spec = [ConfigOption("path", "Aquaint-TREC-3-4", "path to corpus")]
+
+ def download_if_missing(self):
+ return self.download_index(
+ url="https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-robust04-20191213.tar.gz",
+ sha256="dddb81f16d70ea6b9b0f94d6d6b888ed2ef827109a14ca21fd82b2acd6cbd450",
+ index_directory_inside="index-robust04-20191213/",
+ # this string should match how the index was built (i.e., Anserini, stopwords removed, Porter stemming)
+ index_cache_path_string="index-anserini_indexstops-False_stemmer-porter",
+ index_expected_document_count=528_030,
+ cachedir=self.get_cache_path(),
+ )
+
+ def _validate_document_path(self, path):
+ """ Validate that the document path appears to contain robust04's documents (Aquaint-TREC-3-4).
+
+ Validation is performed by looking for four directories (case-insensitive): `FBIS`, `FR94`, `FT`, and `LATIMES`.
+ These directories may either be at the root of `path` or they may be in `path/NEWS_data` (case-insensitive).
+
+ Returns:
+ True if the Aquaint-TREC-3-4 document directories are found or False if not
+ """
+
+ if not os.path.isdir(path):
+ return False
+
+ contents = {fn.lower(): fn for fn in os.listdir(path)}
+ if "news_data" in contents:
+ contents = {fn.lower(): fn for fn in os.listdir(os.path.join(path, contents["news_data"]))}
+
+ if "fbis" in contents and "fr94" in contents and "ft" in contents and "latimes" in contents:
+ return True
+
+ return False
+
+ def download_index(
+ self, cachedir, url, sha256, index_directory_inside, index_cache_path_string, index_expected_document_count
+ ):
+ # Download the collection from the given URL and extract it into a path in the cache directory.
+ # To avoid re-downloading on every call, we create an empty 'done' file in this directory on success.
+ done_file = os.path.join(cachedir, "done")
+ document_dir = os.path.join(cachedir, "documents")
+
+ # already downloaded?
+ if os.path.exists(done_file):
+ return document_dir
+
+ # 1. Download and extract Anserini index to a temporary location
+ tmp_dir = os.path.join(cachedir, "tmp_download")
+ archive_file = os.path.join(tmp_dir, "archive_file")
+ os.makedirs(document_dir, exist_ok=True)
+ os.makedirs(tmp_dir, exist_ok=True)
+ logger.info("downloading index for missing collection %s to temporary file %s", self.module_name, archive_file)
+ download_file(url, archive_file, expected_hash=sha256)
+
+ logger.info("extracting index to %s (before moving to correct cache path)", tmp_dir)
+ with tarfile.open(archive_file) as tar:
+ tar.extractall(path=tmp_dir)
+
+ extracted_dir = os.path.join(tmp_dir, index_directory_inside)
+ if not (os.path.exists(extracted_dir) and os.path.isdir(extracted_dir)):
+ raise ValueError(f"could not find expected index directory {extracted_dir} in {tmp_dir}")
+
+ # 2. Move index to its correct location in the cache
+ index_dir = os.path.join(cachedir, index_cache_path_string, "index")
+ if not os.path.exists(os.path.join(index_dir, "done")):
+ if os.path.exists(index_dir):
+ shutil.rmtree(index_dir)
+ shutil.move(extracted_dir, index_dir)
+
+ # 3. Extract raw documents from the Anserini index to document_dir
+ anserini_index_to_trec_docs(index_dir, document_dir, index_expected_document_count)
+
+ # remove temporary files and create a /done we can use to verify extraction was successful
+ shutil.rmtree(tmp_dir)
+ with open(done_file, "wt") as outf:
+ print("", file=outf)
+
+ return document_dir
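+
+
+# Resulting cache layout on success (a sketch; names taken from the code above):
+# <cachedir>/done -- empty marker file, created last
+# <cachedir>/documents/ -- TREC-format docs extracted from the index
+# <cachedir>/index-anserini_indexstops-False_stemmer-porter/index/ -- the downloaded Anserini index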
diff --git a/capreolus/extractor/__init__.py b/capreolus/extractor/__init__.py
index 5590bf2ea..0f05df178 100644
--- a/capreolus/extractor/__init__.py
+++ b/capreolus/extractor/__init__.py
@@ -1,27 +1,18 @@
-from profane import import_all_modules
-
-# import_all_modules(__file__, __package__)
-
-import pickle
-from collections import defaultdict
-import tensorflow as tf
-
-import os
-import numpy as np
import hashlib
-from pymagnitude import Magnitude, MagnitudeUtils
-from tqdm import tqdm
-from profane import ModuleBase, Dependency, ConfigOption, constants
-
+import os
-from capreolus.utils.loginit import get_logger
-from capreolus.utils.common import padlist
-from capreolus.utils.exceptions import MissingDocError
+from capreolus import ModuleBase, get_logger
logger = get_logger(__name__)
class Extractor(ModuleBase):
+ """Base class for Extractor modules. The purpose of an Extractor is to convert queries and documents to a representation suitable for use with a :class:`~capreolus.reranker.Reranker` module.
+
+ Modules should provide:
+ - an ``id2vec(qid, posid, negid=None)`` method that converts the given query and document ids to an appropriate representation
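+
+ Example (a sketch; the key names follow the EmbedText and BertText implementations)::
+
+ data = extractor.id2vec(qid, posid, negid)
+ query, posdoc, negdoc = data["query"], data["posdoc"], data["negdoc"]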
+ """
+
module_type = "extractor"
def _extend_stoi(self, toks_list, calc_idf=False):
@@ -75,349 +66,7 @@ def build_from_benchmark(self, *args, **kwargs):
raise NotImplementedError
-@Extractor.register
-class EmbedText(Extractor):
- module_name = "embedtext"
- requires_random_seed = True
- dependencies = [
- Dependency(
- key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"}
- ),
- Dependency(key="tokenizer", module="tokenizer", name="anserini"),
- ]
- config_spec = [
- ConfigOption("embeddings", "glove6b"),
- ConfigOption("zerounk", False),
- ConfigOption("calcidf", True),
- ConfigOption("maxqlen", 4),
- ConfigOption("maxdoclen", 800),
- ConfigOption("usecache", False),
- ]
-
- pad = 0
- pad_tok = ""
- embed_paths = {
- "glove6b": "glove/light/glove.6B.300d",
- "glove6b.50d": "glove/light/glove.6B.50d",
- "w2vnews": "word2vec/light/GoogleNews-vectors-negative300",
- "fasttext": "fasttext/light/wiki-news-300d-1M-subword",
- }
-
- def _get_pretrained_emb(self):
- magnitude_cache = constants["CACHE_BASE_PATH"] / "magnitude/"
- return Magnitude(MagnitudeUtils.download_model(self.embed_paths[self.config["embeddings"]], download_dir=magnitude_cache))
-
- def load_state(self, qids, docids):
- with open(self.get_state_cache_file_path(qids, docids), "rb") as f:
- state_dict = pickle.load(f)
- self.qid2toks = state_dict["qid2toks"]
- self.docid2toks = state_dict["docid2toks"]
- self.stoi = state_dict["stoi"]
- self.itos = state_dict["itos"]
-
- def cache_state(self, qids, docids):
- os.makedirs(self.get_cache_path(), exist_ok=True)
- with open(self.get_state_cache_file_path(qids, docids), "wb") as f:
- state_dict = {"qid2toks": self.qid2toks, "docid2toks": self.docid2toks, "stoi": self.stoi, "itos": self.itos}
- pickle.dump(state_dict, f, protocol=-1)
-
- def get_tf_feature_description(self):
- feature_description = {
- "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64),
- "query_idf": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.float32),
- "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
- "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
- "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], dtype=tf.float32)),
- }
-
- return feature_description
-
- def create_tf_feature(self, sample):
- """
- sample - output from self.id2vec()
- return - a tensorflow feature
- """
- query, query_idf, posdoc, negdoc = (sample["query"], sample["query_idf"], sample["posdoc"], sample["negdoc"])
- feature = {
- "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)),
- "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)),
- "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)),
- "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)),
- }
-
- return feature
-
- def parse_tf_example(self, example_proto):
- feature_description = self.get_tf_feature_description()
- parsed_example = tf.io.parse_example(example_proto, feature_description)
- posdoc = parsed_example["posdoc"]
- negdoc = parsed_example["negdoc"]
- query = parsed_example["query"]
- query_idf = parsed_example["query_idf"]
- label = parsed_example["label"]
-
- return (posdoc, negdoc, query, query_idf), label
-
- def _build_vocab(self, qids, docids, topics):
- if self.is_state_cached(qids, docids) and self.config["usecache"]:
- self.load_state(qids, docids)
- logger.info("Vocabulary loaded from cache")
- else:
- tokenize = self.tokenizer.tokenize
- self.qid2toks = {qid: tokenize(topics[qid]) for qid in qids}
- self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in docids}
- self._extend_stoi(self.qid2toks.values(), calc_idf=self.config["calcidf"])
- self._extend_stoi(self.docid2toks.values(), calc_idf=self.config["calcidf"])
- self.itos = {i: s for s, i in self.stoi.items()}
- logger.info(f"vocabulary constructed, with {len(self.itos)} terms in total")
- if self.config["usecache"]:
- self.cache_state(qids, docids)
-
- def _get_idf(self, toks):
- return [self.idf.get(tok, 0) for tok in toks]
-
- def _build_embedding_matrix(self):
- assert len(self.stoi) > 1 # needs more vocab than self.pad_tok
-
- magnitude_emb = self._get_pretrained_emb()
- emb_dim = magnitude_emb.dim
- embed_vocab = set(term for term, _ in magnitude_emb)
- embed_matrix = np.zeros((len(self.stoi), emb_dim), dtype=np.float32)
-
- n_missed = 0
- for term, idx in tqdm(self.stoi.items()):
- if term in embed_vocab:
- embed_matrix[idx] = magnitude_emb.query(term)
- elif term == self.pad_tok:
- embed_matrix[idx] = np.zeros(emb_dim)
- else:
- n_missed += 1
- embed_matrix[idx] = np.zeros(emb_dim) if self.config["zerounk"] else np.random.normal(scale=0.5, size=emb_dim)
-
- logger.info(f"embedding matrix {self.config['embeddings']} constructed, with shape {embed_matrix.shape}")
- if n_missed > 0:
- logger.warning(f"{n_missed}/{len(self.stoi)} (%.3f) term missed" % (n_missed / len(self.stoi)))
-
- self.embeddings = embed_matrix
-
- def exist(self):
- return (
- hasattr(self, "embeddings")
- and self.embeddings is not None
- and isinstance(self.embeddings, np.ndarray)
- and 0 < len(self.stoi) == self.embeddings.shape[0]
- )
-
- def preprocess(self, qids, docids, topics):
- if self.exist():
- return
-
- self.index.create_index()
-
- self.itos = {self.pad: self.pad_tok}
- self.stoi = {self.pad_tok: self.pad}
- self.qid2toks = defaultdict(list)
- self.docid2toks = defaultdict(list)
- self.idf = defaultdict(lambda: 0)
- self.embeddings = None
- # self.cache = self.load_cache() # TODO
-
- self._build_vocab(qids, docids, topics)
- self._build_embedding_matrix()
-
- def _tok2vec(self, toks):
- # return [self.embeddings[self.stoi[tok]] for tok in toks]
- return [self.stoi[tok] for tok in toks]
-
- def id2vec(self, qid, posid, negid=None):
- query = self.qid2toks[qid]
-
- # TODO find a way to calculate qlen/doclen stats earlier, so we can log them and check sanity of our values
- qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"]
- posdoc = self.docid2toks.get(posid, None)
- if not posdoc:
- raise MissingDocError(qid, posid)
-
- idfs = padlist(self._get_idf(query), qlen, 0)
- query = self._tok2vec(padlist(query, qlen, self.pad_tok))
- posdoc = self._tok2vec(padlist(posdoc, doclen, self.pad_tok))
-
- # TODO determine whether pin_memory is happening. may not be because we don't place the strings in a np or torch object
- data = {
- "qid": qid,
- "posdocid": posid,
- "idfs": np.array(idfs, dtype=np.float32),
- "query": np.array(query, dtype=np.long),
- "posdoc": np.array(posdoc, dtype=np.long),
- "query_idf": np.array(idfs, dtype=np.float32),
- "negdocid": "",
- "negdoc": np.zeros(self.config["maxdoclen"], dtype=np.long),
- }
-
- if negid:
- negdoc = self.docid2toks.get(negid, None)
- if not negdoc:
- raise MissingDocError(qid, negid)
-
- negdoc = self._tok2vec(padlist(negdoc, doclen, self.pad_tok))
- data["negdocid"] = negid
- data["negdoc"] = np.array(negdoc, dtype=np.long)
-
- return data
-
-
-@Extractor.register
-class BertText(Extractor):
- module_name = "berttext"
- dependencies = [
- Dependency(
- key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"}
- ),
- Dependency(key="tokenizer", module="tokenizer", name="berttokenizer"),
- ]
- config_spec = [ConfigOption("maxqlen", 4), ConfigOption("maxdoclen", 800), ConfigOption("usecache", False)]
-
- pad = 0
- pad_tok = ""
-
- @staticmethod
- def config():
- maxqlen = 4
- maxdoclen = 800
- usecache = False
-
- def load_state(self, qids, docids):
- with open(self.get_state_cache_file_path(qids, docids), "rb") as f:
- state_dict = pickle.load(f)
- self.qid2toks = state_dict["qid2toks"]
- self.docid2toks = state_dict["docid2toks"]
- self.clsidx = state_dict["clsidx"]
- self.sepidx = state_dict["sepidx"]
-
- def cache_state(self, qids, docids):
- os.makedirs(self.get_cache_path(), exist_ok=True)
- with open(self.get_state_cache_file_path(qids, docids), "wb") as f:
- state_dict = {"qid2toks": self.qid2toks, "docid2toks": self.docid2toks, "clsidx": self.clsidx, "sepidx": self.sepidx}
- pickle.dump(state_dict, f, protocol=-1)
-
- def get_tf_feature_description(self):
- feature_description = {
- "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64),
- "query_mask": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64),
- "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
- "posdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
- "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
- "negdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
- "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], dtype=tf.float32)),
- }
-
- return feature_description
-
- def create_tf_feature(self, sample):
- """
- sample - output from self.id2vec()
- return - a tensorflow feature
- """
- query, posdoc, negdoc, negdoc_id = sample["query"], sample["posdoc"], sample["negdoc"], sample["negdocid"]
- query_mask, posdoc_mask, negdoc_mask = sample["query_mask"], sample["posdoc_mask"], sample["negdoc_mask"]
-
- feature = {
- "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)),
- "query_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=query_mask)),
- "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)),
- "posdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc_mask)),
- "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)),
- "negdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc_mask)),
- }
-
- return feature
-
- def parse_tf_example(self, example_proto):
- feature_description = self.get_tf_feature_description()
- parsed_example = tf.io.parse_example(example_proto, feature_description)
- posdoc = parsed_example["posdoc"]
- posdoc_mask = parsed_example["posdoc_mask"]
- negdoc = parsed_example["negdoc"]
- negdoc_mask = parsed_example["negdoc_mask"]
- query = parsed_example["query"]
- query_mask = parsed_example["query_mask"]
- label = parsed_example["label"]
-
- return (posdoc, posdoc_mask, negdoc, negdoc_mask, query, query_mask), label
-
- def _build_vocab(self, qids, docids, topics):
- if self.is_state_cached(qids, docids) and self.config["usecache"]:
- self.load_state(qids, docids)
- logger.info("Vocabulary loaded from cache")
- else:
- logger.info("Building bertext vocabulary")
- tokenize = self.tokenizer.tokenize
- self.qid2toks = {qid: tokenize(topics[qid]) for qid in tqdm(qids, desc="querytoks")}
- self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in tqdm(docids, desc="doctoks")}
- self.clsidx, self.sepidx = self.tokenizer.convert_tokens_to_ids(["CLS", "SEP"])
-
- self.cache_state(qids, docids)
-
- def exist(self):
- return hasattr(self, "docid2toks") and len(self.docid2toks)
-
- def preprocess(self, qids, docids, topics):
- if self.exist():
- return
-
- self.index.create_index()
- self.qid2toks = defaultdict(list)
- self.docid2toks = defaultdict(list)
- self.clsidx = None
- self.sepidx = None
-
- self._build_vocab(qids, docids, topics)
-
- def id2vec(self, qid, posid, negid=None):
- tokenizer = self.tokenizer
- qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"]
-
- query_toks = tokenizer.convert_tokens_to_ids(self.qid2toks[qid])
- query_mask = self.get_mask(query_toks, qlen)
- query = padlist(query_toks, qlen)
-
- posdoc_toks = tokenizer.convert_tokens_to_ids(self.docid2toks[posid])
- posdoc_mask = self.get_mask(posdoc_toks, doclen)
- posdoc = padlist(posdoc_toks, doclen)
-
- data = {
- "qid": qid,
- "posdocid": posid,
- "idfs": np.zeros(qlen, dtype=np.float32),
- "query": np.array(query, dtype=np.long),
- "query_mask": np.array(query_mask, dtype=np.long),
- "posdoc": np.array(posdoc, dtype=np.long),
- "posdoc_mask": np.array(posdoc_mask, dtype=np.long),
- "query_idf": np.array(query, dtype=np.float32),
- "negdocid": "",
- "negdoc": np.zeros(doclen, dtype=np.long),
- "negdoc_mask": np.zeros(doclen, dtype=np.long),
- }
-
- if negid:
- negdoc_toks = tokenizer.convert_tokens_to_ids(self.docid2toks.get(negid, None))
- negdoc_mask = self.get_mask(negdoc_toks, doclen)
- negdoc = padlist(negdoc_toks, doclen)
-
- if not negdoc:
- raise MissingDocError(qid, negid)
-
- data["negdocid"] = negid
- data["negdoc"] = np.array(negdoc, dtype=np.long)
- data["negdoc_mask"] = np.array(negdoc_mask, dtype=np.long)
+from profane import import_all_modules
- return data
- def get_mask(self, doc, to_len):
- """
- Returns a mask where it is 1 for actual toks and 0 for pad toks
- """
- s = doc[:to_len]
- padlen = to_len - len(s)
- mask = [1 for _ in s] + [0 for _ in range(padlen)]
- return mask
+import_all_modules(__file__, __package__)
diff --git a/capreolus/extractor/bagofwords.py b/capreolus/extractor/bagofwords.py
index 53ebceaf9..94fd8c7df 100644
--- a/capreolus/extractor/bagofwords.py
+++ b/capreolus/extractor/bagofwords.py
@@ -1,9 +1,9 @@
import pickle
import os
import time
-from profane import Dependency, ConfigOption
-from capreolus.extractor import Extractor
+from capreolus import Dependency, ConfigOption
+from . import Extractor
from capreolus.tokenizer import Tokenizer
from capreolus.utils.loginit import get_logger
from tqdm import tqdm
diff --git a/capreolus/extractor/berttext.py b/capreolus/extractor/berttext.py
new file mode 100644
index 000000000..6b464663d
--- /dev/null
+++ b/capreolus/extractor/berttext.py
@@ -0,0 +1,165 @@
+import os
+import pickle
+from collections import defaultdict
+
+import numpy as np
+import tensorflow as tf
+from tqdm import tqdm
+
+from . import Extractor
+from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger
+from capreolus.utils.common import padlist
+from capreolus.utils.exceptions import MissingDocError
+
+logger = get_logger(__name__)
+
+
+@Extractor.register
+class BertText(Extractor):
+ module_name = "berttext"
+ dependencies = [
+ Dependency(
+ key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"}
+ ),
+ Dependency(key="tokenizer", module="tokenizer", name="berttokenizer"),
+ ]
+ config_spec = [ConfigOption("maxqlen", 4), ConfigOption("maxdoclen", 800), ConfigOption("usecache", False)]
+
+ pad = 0
+ pad_tok = ""
+
+ def load_state(self, qids, docids):
+ with open(self.get_state_cache_file_path(qids, docids), "rb") as f:
+ state_dict = pickle.load(f)
+ self.qid2toks = state_dict["qid2toks"]
+ self.docid2toks = state_dict["docid2toks"]
+ self.clsidx = state_dict["clsidx"]
+ self.sepidx = state_dict["sepidx"]
+
+ def cache_state(self, qids, docids):
+ os.makedirs(self.get_cache_path(), exist_ok=True)
+ with open(self.get_state_cache_file_path(qids, docids), "wb") as f:
+ state_dict = {"qid2toks": self.qid2toks, "docid2toks": self.docid2toks, "clsidx": self.clsidx, "sepidx": self.sepidx}
+ pickle.dump(state_dict, f, protocol=-1)
+
+ def get_tf_feature_description(self):
+ feature_description = {
+ "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64),
+ "query_mask": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64),
+ "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
+ "posdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
+ "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
+ "negdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
+ "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], dtype=tf.float32)),
+ }
+
+ return feature_description
+
+ def create_tf_feature(self, sample):
+ """
+ sample - output from self.id2vec()
+ return - a tensorflow feature
+ """
+ query, posdoc, negdoc, negdoc_id = sample["query"], sample["posdoc"], sample["negdoc"], sample["negdocid"]
+ query_mask, posdoc_mask, negdoc_mask = sample["query_mask"], sample["posdoc_mask"], sample["negdoc_mask"]
+
+ feature = {
+ "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)),
+ "query_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=query_mask)),
+ "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)),
+ "posdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc_mask)),
+ "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)),
+ "negdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc_mask)),
+ }
+
+ return feature
+
+ def parse_tf_example(self, example_proto):
+ feature_description = self.get_tf_feature_description()
+ parsed_example = tf.io.parse_example(example_proto, feature_description)
+ posdoc = parsed_example["posdoc"]
+ posdoc_mask = parsed_example["posdoc_mask"]
+ negdoc = parsed_example["negdoc"]
+ negdoc_mask = parsed_example["negdoc_mask"]
+ query = parsed_example["query"]
+ query_mask = parsed_example["query_mask"]
+ label = parsed_example["label"]
+
+ return (posdoc, posdoc_mask, negdoc, negdoc_mask, query, query_mask), label
+
+ def _build_vocab(self, qids, docids, topics):
+ if self.is_state_cached(qids, docids) and self.config["usecache"]:
+ self.load_state(qids, docids)
+ logger.info("Vocabulary loaded from cache")
+ else:
+ logger.info("Building bertext vocabulary")
+ tokenize = self.tokenizer.tokenize
+ self.qid2toks = {qid: tokenize(topics[qid]) for qid in tqdm(qids, desc="querytoks")}
+ self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in tqdm(docids, desc="doctoks")}
+ self.clsidx, self.sepidx = self.tokenizer.convert_tokens_to_ids(["CLS", "SEP"])
+
+ self.cache_state(qids, docids)
+
+ def exist(self):
+ return hasattr(self, "docid2toks") and len(self.docid2toks)
+
+ def preprocess(self, qids, docids, topics):
+ if self.exist():
+ return
+
+ self.index.create_index()
+ self.qid2toks = defaultdict(list)
+ self.docid2toks = defaultdict(list)
+ self.clsidx = None
+ self.sepidx = None
+
+ self._build_vocab(qids, docids, topics)
+
+ def id2vec(self, qid, posid, negid=None):
+ tokenizer = self.tokenizer
+ qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"]
+
+ query_toks = tokenizer.convert_tokens_to_ids(self.qid2toks[qid])
+ query_mask = self.get_mask(query_toks, qlen)
+ query = padlist(query_toks, qlen)
+
+ posdoc_toks = tokenizer.convert_tokens_to_ids(self.docid2toks[posid])
+ posdoc_mask = self.get_mask(posdoc_toks, doclen)
+ posdoc = padlist(posdoc_toks, doclen)
+
+ data = {
+ "qid": qid,
+ "posdocid": posid,
+ "idfs": np.zeros(qlen, dtype=np.float32),
+ "query": np.array(query, dtype=np.long),
+ "query_mask": np.array(query_mask, dtype=np.long),
+ "posdoc": np.array(posdoc, dtype=np.long),
+ "posdoc_mask": np.array(posdoc_mask, dtype=np.long),
+ "query_idf": np.array(query, dtype=np.float32),
+ "negdocid": "",
+ "negdoc": np.zeros(doclen, dtype=np.long),
+ "negdoc_mask": np.zeros(doclen, dtype=np.long),
+ }
+
+ if negid:
+ # look up the document before tokenizing so a missing negid fails fast;
+ # after padlist pads to doclen, a post-hoc "if not negdoc" check could never fire
+ negdoc_toks = self.docid2toks.get(negid, None)
+ if negdoc_toks is None:
+ raise MissingDocError(qid, negid)
+
+ negdoc_toks = tokenizer.convert_tokens_to_ids(negdoc_toks)
+ negdoc_mask = self.get_mask(negdoc_toks, doclen)
+ negdoc = padlist(negdoc_toks, doclen)
+
+ data["negdocid"] = negid
+ data["negdoc"] = np.array(negdoc, dtype=np.long)
+ data["negdoc_mask"] = np.array(negdoc_mask, dtype=np.long)
+
+ return data
+
+ def get_mask(self, doc, to_len):
+ """
+ Returns a mask where it is 1 for actual toks and 0 for pad toks
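+ e.g., get_mask([3, 7, 9], 5) -> [1, 1, 1, 0, 0] (a sketch with made-up token ids)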
+ """
+ s = doc[:to_len]
+ padlen = to_len - len(s)
+ mask = [1 for _ in s] + [0 for _ in range(padlen)]
+ return mask
diff --git a/capreolus/extractor/deeptileextractor.py b/capreolus/extractor/deeptileextractor.py
index 0eb0ce469..a947eb576 100644
--- a/capreolus/extractor/deeptileextractor.py
+++ b/capreolus/extractor/deeptileextractor.py
@@ -13,7 +13,7 @@
from tqdm import tqdm
from profane import ConfigOption, Dependency, constants
-from capreolus.extractor import Extractor
+from . import Extractor
from capreolus.utils.common import padlist
from capreolus.utils.loginit import get_logger
diff --git a/capreolus/extractor/embedtext.py b/capreolus/extractor/embedtext.py
new file mode 100644
index 000000000..0e90198c5
--- /dev/null
+++ b/capreolus/extractor/embedtext.py
@@ -0,0 +1,206 @@
+import os
+import pickle
+from collections import defaultdict
+
+import numpy as np
+import tensorflow as tf
+from pymagnitude import Magnitude, MagnitudeUtils
+from tqdm import tqdm
+
+from . import Extractor
+from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger
+from capreolus.utils.common import padlist
+from capreolus.utils.exceptions import MissingDocError
+
+logger = get_logger(__name__)
+
+
+@Extractor.register
+class EmbedText(Extractor):
+ module_name = "embedtext"
+ requires_random_seed = True
+ dependencies = [
+ Dependency(
+ key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"}
+ ),
+ Dependency(key="tokenizer", module="tokenizer", name="anserini"),
+ ]
+ config_spec = [
+ ConfigOption("embeddings", "glove6b"),
+ ConfigOption("zerounk", False),
+ ConfigOption("calcidf", True),
+ ConfigOption("maxqlen", 4),
+ ConfigOption("maxdoclen", 800),
+ ConfigOption("usecache", False),
+ ]
+
+ pad = 0
+ pad_tok = ""
+ embed_paths = {
+ "glove6b": "glove/light/glove.6B.300d",
+ "glove6b.50d": "glove/light/glove.6B.50d",
+ "w2vnews": "word2vec/light/GoogleNews-vectors-negative300",
+ "fasttext": "fasttext/light/wiki-news-300d-1M-subword",
+ }
+
+ def _get_pretrained_emb(self):
+ magnitude_cache = constants["CACHE_BASE_PATH"] / "magnitude/"
+ return Magnitude(MagnitudeUtils.download_model(self.embed_paths[self.config["embeddings"]], download_dir=magnitude_cache))
+
+ def load_state(self, qids, docids):
+ with open(self.get_state_cache_file_path(qids, docids), "rb") as f:
+ state_dict = pickle.load(f)
+ self.qid2toks = state_dict["qid2toks"]
+ self.docid2toks = state_dict["docid2toks"]
+ self.stoi = state_dict["stoi"]
+ self.itos = state_dict["itos"]
+
+ def cache_state(self, qids, docids):
+ os.makedirs(self.get_cache_path(), exist_ok=True)
+ with open(self.get_state_cache_file_path(qids, docids), "wb") as f:
+ state_dict = {"qid2toks": self.qid2toks, "docid2toks": self.docid2toks, "stoi": self.stoi, "itos": self.itos}
+ pickle.dump(state_dict, f, protocol=-1)
+
+ def get_tf_feature_description(self):
+ feature_description = {
+ "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64),
+ "query_idf": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.float32),
+ "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
+ "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64),
+ "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], dtype=tf.float32)),
+ }
+
+ return feature_description
+
+ def create_tf_feature(self, sample):
+ """
+ sample - output from self.id2vec()
+ return - a tensorflow feature
+ """
+ query, query_idf, posdoc, negdoc = (sample["query"], sample["query_idf"], sample["posdoc"], sample["negdoc"])
+ feature = {
+ "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)),
+ "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)),
+ "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)),
+ "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)),
+ }
+
+ return feature
+
+ def parse_tf_example(self, example_proto):
+ feature_description = self.get_tf_feature_description()
+ parsed_example = tf.io.parse_example(example_proto, feature_description)
+ posdoc = parsed_example["posdoc"]
+ negdoc = parsed_example["negdoc"]
+ query = parsed_example["query"]
+ query_idf = parsed_example["query_idf"]
+ label = parsed_example["label"]
+
+ return (posdoc, negdoc, query, query_idf), label
+
+ def _build_vocab(self, qids, docids, topics):
+ if self.is_state_cached(qids, docids) and self.config["usecache"]:
+ self.load_state(qids, docids)
+ logger.info("Vocabulary loaded from cache")
+ else:
+ tokenize = self.tokenizer.tokenize
+ self.qid2toks = {qid: tokenize(topics[qid]) for qid in qids}
+ self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in docids}
+ self._extend_stoi(self.qid2toks.values(), calc_idf=self.config["calcidf"])
+ self._extend_stoi(self.docid2toks.values(), calc_idf=self.config["calcidf"])
+ self.itos = {i: s for s, i in self.stoi.items()}
+ logger.info(f"vocabulary constructed, with {len(self.itos)} terms in total")
+ if self.config["usecache"]:
+ self.cache_state(qids, docids)
+
+ def _get_idf(self, toks):
+ return [self.idf.get(tok, 0) for tok in toks]
+
+ def _build_embedding_matrix(self):
+ assert len(self.stoi) > 1 # needs more vocab than self.pad_tok
+
+ magnitude_emb = self._get_pretrained_emb()
+ emb_dim = magnitude_emb.dim
+ embed_vocab = set(term for term, _ in magnitude_emb)
+ embed_matrix = np.zeros((len(self.stoi), emb_dim), dtype=np.float32)
+
+ n_missed = 0
+ for term, idx in tqdm(self.stoi.items()):
+ if term in embed_vocab:
+ embed_matrix[idx] = magnitude_emb.query(term)
+ elif term == self.pad_tok:
+ embed_matrix[idx] = np.zeros(emb_dim)
+ else:
+ n_missed += 1
+ embed_matrix[idx] = np.zeros(emb_dim) if self.config["zerounk"] else np.random.normal(scale=0.5, size=emb_dim)
+
+ logger.info(f"embedding matrix {self.config['embeddings']} constructed, with shape {embed_matrix.shape}")
+ if n_missed > 0:
+ logger.warning(f"{n_missed}/{len(self.stoi)} (%.3f) term missed" % (n_missed / len(self.stoi)))
+
+ self.embeddings = embed_matrix
+
+ def exist(self):
+ return (
+ hasattr(self, "embeddings")
+ and self.embeddings is not None
+ and isinstance(self.embeddings, np.ndarray)
+ and 0 < len(self.stoi) == self.embeddings.shape[0]
+ )
+
+ def preprocess(self, qids, docids, topics):
+ if self.exist():
+ return
+
+ self.index.create_index()
+
+ self.itos = {self.pad: self.pad_tok}
+ self.stoi = {self.pad_tok: self.pad}
+ self.qid2toks = defaultdict(list)
+ self.docid2toks = defaultdict(list)
+ self.idf = defaultdict(lambda: 0)
+ self.embeddings = None
+ # self.cache = self.load_cache() # TODO
+
+ self._build_vocab(qids, docids, topics)
+ self._build_embedding_matrix()
+
+ def _tok2vec(self, toks):
+ # return [self.embeddings[self.stoi[tok]] for tok in toks]
+ return [self.stoi[tok] for tok in toks]
+
+ def id2vec(self, qid, posid, negid=None):
+ query = self.qid2toks[qid]
+
+ # TODO find a way to calculate qlen/doclen stats earlier, so we can log them and check sanity of our values
+ qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"]
+ posdoc = self.docid2toks.get(posid, None)
+ if not posdoc:
+ raise MissingDocError(qid, posid)
+
+ idfs = padlist(self._get_idf(query), qlen, 0)
+ query = self._tok2vec(padlist(query, qlen, self.pad_tok))
+ posdoc = self._tok2vec(padlist(posdoc, doclen, self.pad_tok))
+
+ # TODO determine whether pin_memory is happening. may not be because we don't place the strings in a np or torch object
+ data = {
+ "qid": qid,
+ "posdocid": posid,
+ "idfs": np.array(idfs, dtype=np.float32),
+ "query": np.array(query, dtype=np.long),
+ "posdoc": np.array(posdoc, dtype=np.long),
+ "query_idf": np.array(idfs, dtype=np.float32),
+ "negdocid": "",
+ "negdoc": np.zeros(self.config["maxdoclen"], dtype=np.long),
+ }
+
+ if negid:
+ negdoc = self.docid2toks.get(negid, None)
+ if not negdoc:
+ raise MissingDocError(qid, negid)
+
+ negdoc = self._tok2vec(padlist(negdoc, doclen, self.pad_tok))
+ data["negdocid"] = negid
+ data["negdoc"] = np.array(negdoc, dtype=np.long)
+
+ return data
diff --git a/capreolus/index/__init__.py b/capreolus/index/__init__.py
index 8ed070c6b..be0ca3df5 100644
--- a/capreolus/index/__init__.py
+++ b/capreolus/index/__init__.py
@@ -1,22 +1,18 @@
-from profane import import_all_modules
-
-# import_all_modules(__file__, __package__)
-
-import logging
-import math
-import os
-import subprocess
-
-from profane import ModuleBase, Dependency, ConfigOption, constants
+from capreolus import ModuleBase, Dependency, ConfigOption, get_logger
-from capreolus.utils.common import Anserini
-from capreolus.utils.loginit import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
-MAX_THREADS = constants["MAX_THREADS"]
class Index(ModuleBase):
+ """Base class for Index modules. The purpose of an Index module is to represent an inverted index that can be queried with a :class:`~capreolus.searcher.Searcher` module and used to obtain documents and collection statistics.
+
+ Modules should provide:
+ - a ``_create_index`` method that creates an index on the ``Collection`` dependency
+    - a ``get_doc(docid)`` and a ``get_docs(doc_ids)`` method
+ - a ``get_df(term)`` method
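+
+    A hypothetical usage sketch (mirroring the unit tests)::
+
+        index = Index.create("anserini", provide={"collection": collection})
+        index.create_index()  # returns immediately if the index already exists
+        doc = index.get_doc("some_docid")  # "some_docid" is a placeholder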
+ """
+
module_type = "index"
dependencies = [Dependency(key="collection", module="collection")]
@@ -46,82 +42,8 @@ def get_docs(self, doc_ids):
raise NotImplementedError()
-@Index.register
-class AnseriniIndex(Index):
- module_name = "anserini"
- config_spec = [
- ConfigOption("indexstops", False, "should stopwords be indexed? (if False, stopwords are removed)"),
- ConfigOption("stemmer", "porter", "stemmer: porter, krovetz, or none"),
- ]
-
- def _create_index(self):
- outdir = self.get_index_path()
- stops = "-keepStopwords" if self.config["indexstops"] else ""
- stemmer = "none" if self.config["stemmer"] is None else self.config["stemmer"]
-
- collection_path, document_type, generator_type = self.collection.get_path_and_types()
-
- anserini_fat_jar = Anserini.get_fat_jar()
- if self.collection.is_large_collection:
- cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -stemmer {stemmer} {stops}"
- else:
- cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -storePositions -storeDocvectors -storeContents -stemmer {stemmer} {stops}"
-
- logger.info("building index %s", outdir)
- logger.debug(cmd)
- os.makedirs(os.path.basename(outdir), exist_ok=True)
-
- app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True)
-
- # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger
- for line in app.stdout:
- Anserini.filter_and_log_anserini_output(line, logger)
+from profane import import_all_modules
- app.wait()
- if app.returncode != 0:
- raise RuntimeError("command failed")
+from .anserini import AnseriniIndex
- def get_docs(self, doc_ids):
- # if self.collection.is_large_collection:
- # return self.get_documents_from_disk(doc_ids)
- return [self.get_doc(doc_id) for doc_id in doc_ids]
-
- def get_doc(self, docid):
- try:
- if not hasattr(self, "index_utils") or self.index_utils is None:
- self.open()
- return self.index_reader_utils.documentContents(self.reader, self.JString(docid))
- except Exception as e:
- raise
-
- def get_df(self, term):
- # returns 0 for missing terms
- if not hasattr(self, "reader") or self.reader is None:
- self.open()
- jterm = self.JTerm("contents", term)
- return self.reader.docFreq(jterm)
-
- def get_idf(self, term):
- """ BM25's IDF with a floor of 0 """
- df = self.get_df(term)
- idf = (self.numdocs - df + 0.5) / (df + 0.5)
- idf = math.log(1 + idf)
- return max(idf, 0)
-
- def open(self):
- from jnius import autoclass
-
- index_path = self.get_index_path().as_posix()
-
- JIndexUtils = autoclass("io.anserini.index.IndexUtils")
- JIndexReaderUtils = autoclass("io.anserini.index.IndexReaderUtils")
- self.index_utils = JIndexUtils(index_path)
- self.index_reader_utils = JIndexReaderUtils()
-
- JFile = autoclass("java.io.File")
- JFSDirectory = autoclass("org.apache.lucene.store.FSDirectory")
- fsdir = JFSDirectory.open(JFile(index_path).toPath())
- self.reader = autoclass("org.apache.lucene.index.DirectoryReader").open(fsdir)
- self.numdocs = self.reader.numDocs()
- self.JTerm = autoclass("org.apache.lucene.index.Term")
- self.JString = autoclass("java.lang.String")
+import_all_modules(__file__, __package__)
diff --git a/capreolus/index/anserini.py b/capreolus/index/anserini.py
new file mode 100644
index 000000000..ca3e3db39
--- /dev/null
+++ b/capreolus/index/anserini.py
@@ -0,0 +1,91 @@
+import math
+import os
+import subprocess
+
+from . import Index
+from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger
+from capreolus.utils.common import Anserini
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+MAX_THREADS = constants["MAX_THREADS"]
+
+
+@Index.register
+class AnseriniIndex(Index):
+ module_name = "anserini"
+ config_spec = [
+ ConfigOption("indexstops", False, "should stopwords be indexed? (if False, stopwords are removed)"),
+ ConfigOption("stemmer", "porter", "stemmer: porter, krovetz, or none"),
+ ]
+
+ def _create_index(self):
+ outdir = self.get_index_path()
+ stops = "-keepStopwords" if self.config["indexstops"] else ""
+ stemmer = "none" if self.config["stemmer"] is None else self.config["stemmer"]
+
+ collection_path, document_type, generator_type = self.collection.get_path_and_types()
+
+ anserini_fat_jar = Anserini.get_fat_jar()
+ if self.collection.is_large_collection:
+ cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -stemmer {stemmer} {stops}"
+ else:
+ cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -storePositions -storeDocvectors -storeContents -stemmer {stemmer} {stops}"
+
+ logger.info("building index %s", outdir)
+ logger.debug(cmd)
+ os.makedirs(os.path.basename(outdir), exist_ok=True)
+
+ app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True)
+
+ # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger
+ for line in app.stdout:
+ Anserini.filter_and_log_anserini_output(line, logger)
+
+ app.wait()
+ if app.returncode != 0:
+ raise RuntimeError("command failed")
+
+ def get_docs(self, doc_ids):
+ # if self.collection.is_large_collection:
+ # return self.get_documents_from_disk(doc_ids)
+ return [self.get_doc(doc_id) for doc_id in doc_ids]
+
+    def get_doc(self, docid):
+        if not hasattr(self, "index_utils") or self.index_utils is None:
+            self.open()
+        return self.index_reader_utils.documentContents(self.reader, self.JString(docid))
+
+ def get_df(self, term):
+ # returns 0 for missing terms
+ if not hasattr(self, "reader") or self.reader is None:
+ self.open()
+ jterm = self.JTerm("contents", term)
+ return self.reader.docFreq(jterm)
+
+ def get_idf(self, term):
+ """ BM25's IDF with a floor of 0 """
+ df = self.get_df(term)
+ idf = (self.numdocs - df + 0.5) / (df + 0.5)
+ idf = math.log(1 + idf)
+ return max(idf, 0)
+
+ def open(self):
+ from jnius import autoclass
+
+ index_path = self.get_index_path().as_posix()
+
+ JIndexUtils = autoclass("io.anserini.index.IndexUtils")
+ JIndexReaderUtils = autoclass("io.anserini.index.IndexReaderUtils")
+ self.index_utils = JIndexUtils(index_path)
+ self.index_reader_utils = JIndexReaderUtils()
+
+ JFile = autoclass("java.io.File")
+ JFSDirectory = autoclass("org.apache.lucene.store.FSDirectory")
+ fsdir = JFSDirectory.open(JFile(index_path).toPath())
+ self.reader = autoclass("org.apache.lucene.index.DirectoryReader").open(fsdir)
+ self.numdocs = self.reader.numDocs()
+ self.JTerm = autoclass("org.apache.lucene.index.Term")
+ self.JString = autoclass("java.lang.String")
diff --git a/capreolus/index/tests/test_index.py b/capreolus/index/tests/test_index.py
index 7228d3c12..d42dd66ef 100644
--- a/capreolus/index/tests/test_index.py
+++ b/capreolus/index/tests/test_index.py
@@ -1,13 +1,18 @@
import pytest
+from capreolus import module_registry
from capreolus.collection import Collection, DummyCollection
-from capreolus.index import Index
-from capreolus.index import AnseriniIndex
+from capreolus.index import Index, AnseriniIndex
from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index
-def test_anserini_create_index(tmpdir_as_cache):
- index = AnseriniIndex({"name": "anserini", "indexstops": False, "stemmer": "porter", "collection": {"name": "dummy"}})
+indexes = set(module_registry.get_module_names("index"))
+
+
+@pytest.mark.parametrize("index_name", indexes)
+def test_create_index(tmpdir_as_cache, index_name):
+ provide = {"collection": DummyCollection()}
+ index = Index.create(index_name, provide=provide)
assert not index.exists()
index.create_index()
assert index.exists()
diff --git a/capreolus/reranker/CDSSM.py b/capreolus/reranker/CDSSM.py
index 33e4cbe75..4e4dcfa5d 100644
--- a/capreolus/reranker/CDSSM.py
+++ b/capreolus/reranker/CDSSM.py
@@ -1,7 +1,7 @@
import torch
-from profane import ConfigOption, Dependency
from torch import nn
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer
@@ -74,7 +74,8 @@ def forward(self, sentence, query):
@Reranker.register
class CDSSM(Reranker):
- description = """Yelong Shen, Xiaodong He, Jianfeng Gao, Li Deng, and Grégoire Mesnil. 2014. A Latent Semantic Model with Convolutional-Pooling Structure for Information Retrieval. In CIKM'14."""
+ """Yelong Shen, Xiaodong He, Jianfeng Gao, Li Deng, and Grégoire Mesnil. 2014. A Latent Semantic Model with Convolutional-Pooling Structure for Information Retrieval. In CIKM'14."""
+
module_name = "CDSSM"
config_spec = [
diff --git a/capreolus/reranker/ConvKNRM.py b/capreolus/reranker/ConvKNRM.py
index b7a7c4968..936022131 100644
--- a/capreolus/reranker/ConvKNRM.py
+++ b/capreolus/reranker/ConvKNRM.py
@@ -1,7 +1,7 @@
import torch
-from profane import ConfigOption, Dependency
from torch import nn
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import RbfKernelBank, SimilarityMatrix, create_emb_layer
from capreolus.utils.loginit import get_logger
@@ -79,8 +79,9 @@ def forward(self, sentence, query_sentence, query_idf):
@Reranker.register
class ConvKNRM(Reranker):
+ """Zhuyun Dai, Chenyan Xiong, Jamie Callan, and Zhiyuan Liu. 2018. Convolutional Neural Networks for Soft-Matching N-Grams in Ad-hoc Search. In WSDM'18."""
+
module_name = "ConvKNRM"
- description = """Zhuyun Dai, Chenyan Xiong, Jamie Callan, and Zhiyuan Liu. 2018. Convolutional Neural Networks for Soft-Matching N-Grams in Ad-hoc Search. In WSDM'18."""
config_spec = [
ConfigOption("gradkernels", True, "backprop through mus and sigmas"),
diff --git a/capreolus/reranker/DRMM.py b/capreolus/reranker/DRMM.py
index a079b54f9..e93f79da3 100644
--- a/capreolus/reranker/DRMM.py
+++ b/capreolus/reranker/DRMM.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from profane import ConfigOption, Dependency
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer
from capreolus.utils.loginit import get_logger
@@ -122,8 +122,9 @@ def forward(self, sentence, query_sentence, query_idf):
@Reranker.register
class DRMM(Reranker):
+ """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16."""
+
module_name = "DRMM"
- description = """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16."""
config_spec = [
ConfigOption("nbins", 29, "number of bins in matching histogram"),
diff --git a/capreolus/reranker/DRMMTKS.py b/capreolus/reranker/DRMMTKS.py
index e9bf70517..eddbe7e74 100644
--- a/capreolus/reranker/DRMMTKS.py
+++ b/capreolus/reranker/DRMMTKS.py
@@ -1,8 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from profane import ConfigOption, Dependency
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer
from capreolus.utils.loginit import get_logger
@@ -86,9 +86,10 @@ def forward(self, doc, query, query_idf):
@Reranker.register
class DRMMTKS(Reranker):
- # refernce: https://github.com/NTMC-Community/MatchZoo-py/blob/master/matchzoo/models/drmmtks.py
+ """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16."""
+
+ # reference: https://github.com/NTMC-Community/MatchZoo-py/blob/master/matchzoo/models/drmmtks.py
module_name = "DRMMTKS"
- description = """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16."""
config_spec = [
ConfigOption("topk", 10, "number of bins in matching histogram"),
diff --git a/capreolus/reranker/DSSM.py b/capreolus/reranker/DSSM.py
index b53aa4f85..60a5be00c 100644
--- a/capreolus/reranker/DSSM.py
+++ b/capreolus/reranker/DSSM.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from profane import ConfigOption, Dependency
+from capreolus import ConfigOption, Dependency
from capreolus.extractor.bagofwords import BagOfWords
from capreolus.reranker import Reranker
from capreolus.utils.loginit import get_logger
@@ -46,7 +46,8 @@ def forward(self, sentence, query, query_idf):
@Reranker.register
class DSSM(Reranker):
- description = """Po-Sen Huang, Xiaodong He, Jianfeng Gao, Li Deng, Alex Acero, and Larry Heck. 2013. Learning deep structured semantic models for web search using clickthrough data. In CIKM'13."""
+ """Po-Sen Huang, Xiaodong He, Jianfeng Gao, Li Deng, Alex Acero, and Larry Heck. 2013. Learning deep structured semantic models for web search using clickthrough data. In CIKM'13."""
+
module_name = "DSSM"
dependencies = [
Dependency(key="extractor", module="extractor", name="bagofwords"),
diff --git a/capreolus/reranker/DUET.py b/capreolus/reranker/DUET.py
index 0a5f4ef4e..95103f221 100644
--- a/capreolus/reranker/DUET.py
+++ b/capreolus/reranker/DUET.py
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
-from profane import ConfigOption, Dependency
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer
from capreolus.utils.loginit import get_logger
@@ -131,8 +131,9 @@ def forward(self, documents, queries, query_idf):
@Reranker.register
class DUET(Reranker):
+ """Bhaskar Mitra, Fernando Diaz, and Nick Craswell. 2017. Learning to Match using Local and Distributed Representations of Text for Web Search. In WWW'17."""
+
module_name = "DUET"
- description = """Bhaskar Mitra, Fernando Diaz, and Nick Craswell. 2017. Learning to Match using Local and Distributed Representations of Text for Web Search. In WWW'17."""
config_spec = [
ConfigOption("nfilter", 10, "number of filters for both local and distrbuted model"),
diff --git a/capreolus/reranker/DeepTileBar.py b/capreolus/reranker/DeepTileBar.py
index 3fd141b57..af49aed54 100644
--- a/capreolus/reranker/DeepTileBar.py
+++ b/capreolus/reranker/DeepTileBar.py
@@ -2,10 +2,10 @@
import torch
import torch.nn.functional as F
-from profane import ConfigOption, Dependency
from torch import nn
from torch.autograd import Variable
+from capreolus import ConfigOption, Dependency
from capreolus.extractor.deeptileextractor import DeepTileExtractor
from capreolus.reranker import Reranker
from capreolus.utils.loginit import get_logger
@@ -172,7 +172,8 @@ def test_forward(self, pos_tile_matrix):
@Reranker.register
class DeepTileBar(Reranker):
- description = """Zhiwen Tang and Grace Hui Yang. 2019. DeepTileBars: Visualizing Term Distribution for Neural Information Retrieval. In AAAI'19."""
+ """Zhiwen Tang and Grace Hui Yang. 2019. DeepTileBars: Visualizing Term Distribution for Neural Information Retrieval. In AAAI'19."""
+
module_name = "DeepTileBar"
dependencies = [
diff --git a/capreolus/reranker/HINT.py b/capreolus/reranker/HINT.py
index 888f36a3d..9e568fa28 100644
--- a/capreolus/reranker/HINT.py
+++ b/capreolus/reranker/HINT.py
@@ -2,10 +2,10 @@
import torch
import torch.nn.functional as F
-from profane import ConfigOption, Dependency
from torch import nn
from torch.autograd import Variable
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer
@@ -323,8 +323,9 @@ def test_forward(self, query_sentence, query_idf, pos_sentence):
@Reranker.register
class HINT(Reranker):
+ """Yixing Fan, Jiafeng Guo, Yanyan Lan, Jun Xu, Chengxiang Zhai, and Xueqi Cheng. 2018. Modeling Diverse Relevance Patterns in Ad-hoc Retrieval. In SIGIR'18."""
+
module_name = "HINT"
- description = """Yixing Fan, Jiafeng Guo, Yanyan Lan, Jun Xu, Chengxiang Zhai, and Xueqi Cheng. 2018. Modeling Diverse Relevance Patterns in Ad-hoc Retrieval. In SIGIR'18."""
config_spec = [ConfigOption("spatialGRU", 2), ConfigOption("LSTMdim", 6), ConfigOption("kmax", 10)]
diff --git a/capreolus/reranker/KNRM.py b/capreolus/reranker/KNRM.py
index 29924284d..153ac0d8a 100644
--- a/capreolus/reranker/KNRM.py
+++ b/capreolus/reranker/KNRM.py
@@ -1,8 +1,8 @@
import matplotlib.pyplot as plt
import torch
-from profane import ConfigOption, Dependency
from torch import nn
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import RbfKernelBank, SimilarityMatrix, create_emb_layer
from capreolus.utils.loginit import get_logger
@@ -60,9 +60,9 @@ def forward(self, doctoks, querytoks, query_idf):
@Reranker.register
class KNRM(Reranker):
+ """Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017. End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17."""
+
module_name = "KNRM"
- description = """Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017.
- End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17."""
config_spec = [
ConfigOption("gradkernels", True, "backprop through mus and sigmas"),
diff --git a/capreolus/reranker/PACRR.py b/capreolus/reranker/PACRR.py
index ad8d97d74..3ee47f601 100644
--- a/capreolus/reranker/PACRR.py
+++ b/capreolus/reranker/PACRR.py
@@ -1,8 +1,8 @@
import torch
-from profane import ConfigOption, Dependency
from torch import nn
from torch.nn import functional as F
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
# TODO add shuffle, cascade, disambig?
@@ -85,9 +85,9 @@ def forward(self, simmat):
@Reranker.register
class PACRR(Reranker):
+ """Kai Hui, Andrew Yates, Klaus Berberich, and Gerard de Melo. 2017. PACRR: A Position-Aware Neural IR Model for Relevance Matching. EMNLP 2017. """
+
module_name = "PACRR"
- description = """Kai Hui, Andrew Yates, Klaus Berberich, and Gerard de Melo. EMNLP 2017.
- PACRR: A Position-Aware Neural IR Model for Relevance Matching. """
config_spec = [
ConfigOption("mingram", 1, "minimum length of ngram used"),
diff --git a/capreolus/reranker/POSITDRMM.py b/capreolus/reranker/POSITDRMM.py
index 1359eeba0..6839fe550 100644
--- a/capreolus/reranker/POSITDRMM.py
+++ b/capreolus/reranker/POSITDRMM.py
@@ -1,9 +1,9 @@
import torch
import torch.nn.functional as F
-from profane import ConfigOption, Dependency
from torch import nn
from torch.autograd import Variable
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer
from capreolus.utils.loginit import get_logger
@@ -127,8 +127,8 @@ def test_forward(self, query_sentence, query_idf, pos_sentence, extras):
@Reranker.register
class POSITDRMM(Reranker):
- description = """Ryan McDonald, George Brokos, and Ion Androutsopoulos. 2018.
- Deep Relevance Ranking Using Enhanced Document-Query Interactions. In EMNLP'18."""
+ """Ryan McDonald, George Brokos, and Ion Androutsopoulos. 2018. Deep Relevance Ranking Using Enhanced Document-Query Interactions. In EMNLP'18."""
+
module_name = "POSITDRMM"
def build_model(self):
diff --git a/capreolus/reranker/TFKNRM.py b/capreolus/reranker/TFKNRM.py
index 71e392f01..8fe08050f 100644
--- a/capreolus/reranker/TFKNRM.py
+++ b/capreolus/reranker/TFKNRM.py
@@ -1,6 +1,6 @@
import tensorflow as tf
-from profane import ConfigOption, Dependency
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import RbfKernelBankTF, similarity_matrix_tf
@@ -57,7 +57,13 @@ def call(self, x, **kwargs):
@Reranker.register
class TFKNRM(Reranker):
+ """TensorFlow implementation of KNRM.
+
+ Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017. End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17.
+ """
+
module_name = "TFKNRM"
+
dependencies = [
Dependency(key="extractor", module="extractor", name="embedtext"),
Dependency(key="trainer", module="trainer", name="tensorflow"),
diff --git a/capreolus/reranker/TFVanillaBert.py b/capreolus/reranker/TFVanillaBert.py
index 2eea4b194..c67412bd6 100644
--- a/capreolus/reranker/TFVanillaBert.py
+++ b/capreolus/reranker/TFVanillaBert.py
@@ -1,7 +1,7 @@
import tensorflow as tf
-from profane import ConfigOption, Dependency
from transformers import TFBertForSequenceClassification
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.utils.loginit import get_logger
@@ -53,7 +53,10 @@ def call(self, x, **kwargs):
@Reranker.register
class TFVanillaBERT(Reranker):
+ """TensorFlow implementation of Vanilla BERT."""
+
module_name = "TFVanillaBERT"
+
dependencies = [
Dependency(key="extractor", module="extractor", name="berttext"),
Dependency(key="trainer", module="trainer", name="tensorflow"),
diff --git a/capreolus/reranker/TK.py b/capreolus/reranker/TK.py
index 16f5c1da0..f0476022e 100644
--- a/capreolus/reranker/TK.py
+++ b/capreolus/reranker/TK.py
@@ -1,11 +1,11 @@
import math
import torch
-from profane import ConfigOption, Dependency
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
-
from allennlp.modules.matrix_attention import CosineMatrixAttention
+
+from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import SimilarityMatrix, create_emb_layer
from capreolus.utils.loginit import get_logger
@@ -146,9 +146,9 @@ def forward(self, doctoks, querytoks, query_idf):
@Reranker.register
class TK(Reranker):
+ """Sebastian Hofstätter, Markus Zlabinger, and Allan Hanbury. 2019. TU Wien @ TREC Deep Learning '19 -- Simple Contextualization for Re-ranking. In TREC '19."""
+
module_name = "TK"
- description = """Sebastian Hofstätter, Markus Zlabinger, and Allan Hanbury. 2019.
- TU Wien @ TREC Deep Learning '19 -- Simple Contextualization for Re-ranking. In TREC '19."""
config_spec = [
ConfigOption("gradkernels", True, "backprop through mus and sigmas"),
diff --git a/capreolus/reranker/__init__.py b/capreolus/reranker/__init__.py
index 77645a0dc..9248dcc3f 100644
--- a/capreolus/reranker/__init__.py
+++ b/capreolus/reranker/__init__.py
@@ -1,5 +1,61 @@
+import os
+import pickle
+
+from capreolus import ConfigOption, Dependency, ModuleBase
+
+
+class Reranker(ModuleBase):
+ """Base class for Reranker modules. The purpose of a Reranker is to predict relevance scores for input documents. Rerankers are generally supervised methods implemented in PyTorch or TensorFlow.
+
+ Modules should provide:
+ - a ``build_model`` method that initializes the model used
+ - a ``score`` and a ``test`` method that take a representation created by an :class:`~capreolus.extractor.Extractor` module as input and return document scores
+ - a ``load_weights`` and a ``save_weights`` method, if the base class' PyTorch methods cannot be used
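+
+    A hypothetical usage sketch (mirroring the unit tests; tasks normally drive this
+    through a Trainer)::
+
+        reranker = Reranker.create("KNRM", provide={"index": index, "collection": index.collection})
+        reranker.build_model()
+        scores = reranker.score(batch)  # batch comes from the extractor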
+ """
+
+ module_type = "reranker"
+ dependencies = [
+ Dependency(key="extractor", module="extractor", name="embedtext"),
+ Dependency(key="trainer", module="trainer", name="pytorch"),
+ ]
+
+ def add_summary(self, summary_writer, niter):
+ """
+        Write custom visualizations/data specific to this reranker to the summary_writer
+ """
+ for name, weight in self.model.named_parameters():
+ summary_writer.add_histogram(name, weight.data.cpu(), niter)
+ # summary_writer.add_histogram(f'{name}.grad', weight.grad, niter)
+
+ def save_weights(self, weights_fn, optimizer):
+ if not os.path.exists(os.path.dirname(weights_fn)):
+ os.makedirs(os.path.dirname(weights_fn))
+
+ d = {k: v for k, v in self.model.state_dict().items() if ("embedding.weight" not in k and "_nosave_" not in k)}
+ with open(weights_fn, "wb") as outf:
+ pickle.dump(d, outf, protocol=-1)
+
+ optimizer_fn = weights_fn.as_posix() + ".optimizer"
+ with open(optimizer_fn, "wb") as outf:
+ pickle.dump(optimizer.state_dict(), outf, protocol=-1)
+
+ def load_weights(self, weights_fn, optimizer):
+ with open(weights_fn, "rb") as f:
+ d = pickle.load(f)
+
+ cur_keys = set(k for k in self.model.state_dict().keys() if not ("embedding.weight" in k or "_nosave_" in k))
+ missing = cur_keys - set(d.keys())
+ if len(missing) > 0:
+ raise RuntimeError("loading state_dict with keys that do not match current model: %s" % missing)
+
+ self.model.load_state_dict(d, strict=False)
+
+ optimizer_fn = weights_fn.as_posix() + ".optimizer"
+ with open(optimizer_fn, "rb") as f:
+ optimizer.load_state_dict(pickle.load(f))
+
+
from profane import import_all_modules
-from .base import Reranker
import_all_modules(__file__, __package__)
diff --git a/capreolus/reranker/base.py b/capreolus/reranker/base.py
deleted file mode 100644
index 832714e1e..000000000
--- a/capreolus/reranker/base.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import os
-import pickle
-
-from profane import ConfigOption, Dependency, ModuleBase
-
-
-class Reranker(ModuleBase):
- module_type = "reranker"
- dependencies = [
- Dependency(key="extractor", module="extractor", name="embedtext"),
- Dependency(key="trainer", module="trainer", name="pytorch"),
- ]
-
- def add_summary(self, summary_writer, niter):
- """
- Write to the summay_writer custom visualizations/data specific to this reranker
- """
- for name, weight in self.model.named_parameters():
- summary_writer.add_histogram(name, weight.data.cpu(), niter)
- # summary_writer.add_histogram(f'{name}.grad', weight.grad, niter)
-
- def save_weights(self, weights_fn, optimizer):
- if not os.path.exists(os.path.dirname(weights_fn)):
- os.makedirs(os.path.dirname(weights_fn))
-
- d = {k: v for k, v in self.model.state_dict().items() if ("embedding.weight" not in k and "_nosave_" not in k)}
- with open(weights_fn, "wb") as outf:
- pickle.dump(d, outf, protocol=-1)
-
- optimizer_fn = weights_fn.as_posix() + ".optimizer"
- with open(optimizer_fn, "wb") as outf:
- pickle.dump(optimizer.state_dict(), outf, protocol=-1)
-
- def load_weights(self, weights_fn, optimizer):
- with open(weights_fn, "rb") as f:
- d = pickle.load(f)
-
- cur_keys = set(k for k in self.model.state_dict().keys() if not ("embedding.weight" in k or "_nosave_" in k))
- missing = cur_keys - set(d.keys())
- if len(missing) > 0:
- raise RuntimeError("loading state_dict with keys that do not match current model: %s" % missing)
-
- self.model.load_state_dict(d, strict=False)
-
- optimizer_fn = weights_fn.as_posix() + ".optimizer"
- with open(optimizer_fn, "rb") as f:
- optimizer.load_state_dict(pickle.load(f))
diff --git a/capreolus/reranker/tests/test_rerankers.py b/capreolus/reranker/tests/test_rerankers.py
index b8b0ce41b..682503492 100644
--- a/capreolus/reranker/tests/test_rerankers.py
+++ b/capreolus/reranker/tests/test_rerankers.py
@@ -5,8 +5,9 @@
import torch
from pymagnitude import Magnitude
+from capreolus import Reranker, module_registry
from capreolus.benchmark import DummyBenchmark
-from capreolus.extractor import EmbedText
+from capreolus.extractor.embedtext import EmbedText
from capreolus.extractor.bagofwords import BagOfWords
from capreolus.extractor.deeptileextractor import DeepTileExtractor
from capreolus.reranker.CDSSM import CDSSM
@@ -24,6 +25,15 @@
from capreolus.trainer import PytorchTrainer, TensorFlowTrainer
+rerankers = set(module_registry.get_module_names("reranker"))
+
+
+@pytest.mark.parametrize("reranker_name", rerankers)
+def test_reranker_creatable(tmpdir_as_cache, dummy_index, reranker_name):
+ provide = {"collection": dummy_index.collection, "index": dummy_index}
+ reranker = Reranker.create(reranker_name, provide=provide)
+
+
def test_knrm_pytorch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch):
def fake_magnitude_embedding(*args, **kwargs):
return Magnitude(None)
diff --git a/capreolus/run.py b/capreolus/run.py
index f716cd610..39ef2048c 100644
--- a/capreolus/run.py
+++ b/capreolus/run.py
@@ -39,28 +39,42 @@ def prepare_task(fullcommand, config):
help = """
- Usage:
- run.py COMMAND [(with CONFIG...)] [options]
- run.py help [COMMAND]
- run.py (-h | --help)
+Usage:
+ capreolus COMMAND [(with CONFIG...)] [options]
+ capreolus help [COMMAND]
+ capreolus (-h | --help)
- Options:
- -h --help Print this help message and exit.
- -l VALUE --loglevel=VALUE Set the log level: DEBUG, INFO, WARNING, ERROR, or CRITICAL.
- -p VALUE --priority=VALUE Sets the priority for a queued up experiment. No effect without -q flag.
- -q --queue Only queue this run, do not start it.
+ Options:
+ -h --help Print this help message and exit.
+ -l VALUE --loglevel=VALUE Set the log level: DEBUG, INFO, WARNING, ERROR, or CRITICAL.
+    -p VALUE --priority=VALUE  Set the priority for a queued experiment (no effect without the -q flag).
+    -q --queue                 Queue this run without starting it.
- Arguments:
- COMMAND Name of command to run (see below for list of commands)
- CONFIG Configuration assignments of the form foo.bar=17
+ Arguments:
+ PIPELINE Name of pipeline to run, which consists of a Task and a command (see below for a list)
+ CONFIG Configuration assignments of the form foo.bar=17
- Commands: (TODO expand/generate)
- rank.run ...description here...
- rank.describe ...description here...
- """
+ Tasks and their commands:
+ rank.search search a collection using queries from a benchmark
+ rank.evaluate evaluate the result of rank.search
+ rank.searcheval run rank.search followed by rank.evaluate
+
+ rerank.train run rank.search and train a model to rerank the results
+ rerank.evaluate evaluate the result of rerank.train
+ rerank.traineval run rerank.train followed by rerank.evaluate
+
+ rererank.train run rerank.train and train a (second) model to rerank the results
+ rererank.evaluate evaluate the result of rererank.train
+ rererank.traineval run rererank.train followed by rererank.evaluate
+
+    tutorial.run           run the task from the "Getting Started" tutorial
+
+ All tasks additionally support the following help commands: describe, print_config, print_pipeline
+ e.g., capreolus rank.print_config with searcher=BM25
+"""
if __name__ == "__main__":
# hack to make docopt print full help message if no arguments are give
diff --git a/capreolus/sampler/tests/test_sampler.py b/capreolus/sampler/tests/test_sampler.py
index 93f6b9179..59a159b6c 100644
--- a/capreolus/sampler/tests/test_sampler.py
+++ b/capreolus/sampler/tests/test_sampler.py
@@ -4,7 +4,7 @@
import numpy as np
from capreolus.benchmark import DummyBenchmark
-from capreolus.extractor import EmbedText
+from capreolus.extractor.embedtext import EmbedText
from capreolus.sampler import TrainDataset, PredDataset
from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index
diff --git a/capreolus/searcher/__init__.py b/capreolus/searcher/__init__.py
index c4b1735a2..34da19ef0 100644
--- a/capreolus/searcher/__init__.py
+++ b/capreolus/searcher/__init__.py
@@ -1,17 +1,7 @@
-from profane import import_all_modules
-
-
-# import_all_modules(__file__, __package__)
-
import os
-import math
-import subprocess
from collections import defaultdict, OrderedDict
-import numpy as np
-from profane import ModuleBase, Dependency, ConfigOption, constants
-
-from capreolus.utils.common import Anserini
+from capreolus import ModuleBase, Dependency, ConfigOption, constants
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import topic_to_trectxt
@@ -24,6 +14,15 @@ def list2str(l, delimiter="-"):
class Searcher(ModuleBase):
+ """Base class for Searcher modules. The purpose of a Searcher is to query a collection via an :class:`~capreolus.index.Index` module.
+
+ Similar to Rerankers, Searchers return a list of documents and their relevance scores for a given query.
+ Searchers are unsupervised and efficient, whereas Rerankers are supervised and do not use an inverted index directly.
+
+ Modules should provide:
+ - a ``query(string)`` and a ``query_from_file(path)`` method that return document scores
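+
+    A hypothetical usage sketch (paths are placeholders)::
+
+        searcher = Searcher.create("BM25", provide={"index": index})
+        run_dir = searcher.query_from_file(topics_path, output_path)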
+ """
+
module_type = "searcher"
@staticmethod
@@ -83,495 +82,8 @@ def query(self, query, **kwargs):
return config2runs["searcher"] if len(config2runs) == 1 else config2runs
-class AnseriniSearcherMixIn:
- """ MixIn for searchers that use Anserini's SearchCollection script """
-
- def _anserini_query_from_file(self, topicsfn, anserini_param_str, output_base_path, topicfield):
- if not os.path.exists(topicsfn):
- raise IOError(f"could not find topics file: {topicsfn}")
-
- # for covid:
- field2querytype = {"query": "title", "question": "description", "narrative": "narrative"}
- for k, v in field2querytype.items():
- topicfield = topicfield.replace(k, v)
-
- donefn = os.path.join(output_base_path, "done")
- if os.path.exists(donefn):
- logger.debug(f"skipping Anserini SearchCollection call because path already exists: {donefn}")
- return
-
- # create index if it does not exist. the call returns immediately if the index does exist.
- self.index.create_index()
-
- os.makedirs(output_base_path, exist_ok=True)
- output_path = os.path.join(output_base_path, "searcher")
-
- # add stemmer and stop options to match underlying index
- indexopts = "-stemmer "
- indexopts += "none" if self.index.config["stemmer"] is None else self.index.config["stemmer"]
- if self.index.config["indexstops"]:
- indexopts += " -keepstopwords"
-
- index_path = self.index.get_index_path()
- anserini_fat_jar = Anserini.get_fat_jar()
- cmd = (
- f"java -classpath {anserini_fat_jar} "
- f"-Xms512M -Xmx31G -Dapp.name=SearchCollection io.anserini.search.SearchCollection "
- f"-topicreader Trec -index {index_path} {indexopts} -topics {topicsfn} -output {output_path} "
- f"-topicfield {topicfield} -inmem -threads {MAX_THREADS} {anserini_param_str}"
- )
- logger.info("Anserini writing runs to %s", output_path)
- logger.debug(cmd)
-
- app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True)
-
- # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger
- for line in app.stdout:
- Anserini.filter_and_log_anserini_output(line, logger)
-
- app.wait()
- if app.returncode != 0:
- raise RuntimeError("command failed")
-
- with open(donefn, "wt") as donef:
- print("done", file=donef)
-
-
-class PostprocessMixin:
- def _keep_topn(self, runs, topn):
- queries = sorted(list(runs.keys()), key=lambda k: int(k))
- for q in queries:
- docs = runs[q]
- if len(docs) <= topn:
- continue
- docs = sorted(docs.items(), key=lambda kv: kv[1], reverse=True)[:topn]
- runs[q] = {k: v for k, v in docs}
- return runs
-
- def filter(self, run_dir, docs_to_remove=None, docs_to_keep=None, topn=None):
- if (not docs_to_keep) and (not docs_to_remove):
- raise
-
- for fn in os.listdir(run_dir):
- if fn == "done":
- continue
-
- run_fn = os.path.join(run_dir, fn)
- self._filter(run_fn, docs_to_remove, docs_to_keep, topn)
- return run_dir
-
- def _filter(self, runfile, docs_to_remove, docs_to_keep, topn):
- runs = Searcher.load_trec_run(runfile)
-
- # filtering
- if docs_to_remove: # prioritize docs_to_remove
- if isinstance(docs_to_remove, list):
- docs_to_remove = {q: docs_to_remove for q in runs}
- runs = {q: {d: v for d, v in docs.items() if d not in docs_to_remove.get(q, [])} for q, docs in runs.items()}
- elif docs_to_keep:
- if isinstance(docs_to_keep, list):
- docs_to_keep = {q: docs_to_keep for q in runs}
- runs = {q: {d: v for d, v in docs.items() if d in docs_to_keep[q]} for q, docs in runs.items()}
-
- if topn:
- runs = self._keep_topn(runs, topn)
- Searcher.write_trec_run(runs, runfile) # overwrite runfile
-
- def dedup(self, run_dir, topn=None):
- for fn in os.listdir(run_dir):
- if fn == "done":
- continue
- run_fn = os.path.join(run_dir, fn)
- self._dedup(run_fn, topn)
- return run_dir
-
- def _dedup(self, runfile, topn):
- runs = Searcher.load_trec_run(runfile)
- new_runs = {q: {} for q in runs}
-
- # use the sum of each passage score as the document score, no sorting is done here
- for q, psg in runs.items():
- for pid, score in psg.items():
- docid = pid.split(".")[0]
- new_runs[q][docid] = max(new_runs[q].get(docid, -math.inf), score)
- runs = new_runs
-
- if topn:
- runs = self._keep_topn(runs, topn)
- Searcher.write_trec_run(runs, runfile)
-
-
-@Searcher.register
-class BM25(Searcher, AnseriniSearcherMixIn):
- """ BM25 with fixed k1 and b. """
-
- module_name = "BM25"
-
- dependencies = [Dependency(key="index", module="index", name="anserini")]
- config_spec = [
- ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
- ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- """
- Runs BM25 search. Takes a query from the topic files, and fires it against the index
- Args:
- topicsfn: Path to a topics file
- output_path: Path where the results of the search (i.e the run file) should be stored
-
- Returns: Path to the run file where the results of the search are stored
-
- """
- bstr, k1str = list2str(config["b"], delimiter=" "), list2str(config["k1"], delimiter=" ")
- hits = config["hits"]
- anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}"
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class BM25Grid(Searcher, AnseriniSearcherMixIn):
- """ BM25 with a grid search for k1 and b. Search is from 0.1 to bmax/k1max in 0.1 increments """
-
- module_name = "BM25Grid"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
- config_spec = [
- ConfigOption("k1max", 1.0, "maximum k1 value to include in grid search (starting at 0.1)"),
- ConfigOption("bmax", 1.0, "maximum b value to include in grid search (starting at 0.1)"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- bs = np.around(np.arange(0.1, config["bmax"] + 0.1, 0.1), 1)
- k1s = np.around(np.arange(0.1, config["k1max"] + 0.1, 0.1), 1)
- bstr = " ".join(str(x) for x in bs)
- k1str = " ".join(str(x) for x in k1s)
- hits = config["hits"]
- anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}"
-
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class BM25RM3(Searcher, AnseriniSearcherMixIn):
-
- module_name = "BM25RM3"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
- config_spec = [
- ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"),
- ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"),
- ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"),
- ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"),
- ConfigOption("originalQueryWeight", [0.5], "the weight of unexpended query", value_type="floatlist"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- hits = str(config["hits"])
-
- anserini_param_str = (
- "-rm3 "
- + " ".join(f"-rm3.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "originalQueryWeight"])
- + " -bm25 "
- + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"])
- + f" -hits {hits}"
- )
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class BM25PostProcess(BM25, PostprocessMixin):
- module_name = "BM25Postprocess"
-
- config_spec = [
- ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
- ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("topn", 1000),
- ConfigOption("fields", "title"),
- ConfigOption("dedep", False),
- ]
-
- def query_from_file(self, topicsfn, output_path, docs_to_remove=None):
- output_path = super().query_from_file(topicsfn, output_path) # will call _query_from_file() from BM25
-
- if docs_to_remove:
- output_path = self.filter(output_path, docs_to_remove=docs_to_remove, topn=self.config["topn"])
- if self.config["dedup"]:
- output_path = self.dedup(output_path, topn=self.config["topn"])
-
- return output_path
-
-
-@Searcher.register
-class StaticBM25RM3Rob04Yang19(Searcher):
- """ Tuned BM25+RM3 run used by Yang et al. in [1]. This should be used only with a benchmark using the same folds and queries.
-
- [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019.
- """
-
- module_name = "bm25staticrob04yang19"
-
- def _query_from_file(self, topicsfn, output_path, config):
- import shutil
-
- outfn = os.path.join(output_path, "static.run")
- os.makedirs(output_path, exist_ok=True)
- shutil.copy2(constants["PACKAGE_PATH"] / "data" / "rob04_yang19_rm3.run", outfn)
-
- return output_path
-
- def query(self, *args, **kwargs):
- raise NotImplementedError("this searcher uses a static run file, so it cannot handle new queries")
-
-
-@Searcher.register
-class BM25PRF(Searcher, AnseriniSearcherMixIn):
- """
- BM25 with PRF
- """
-
- module_name = "BM25PRF"
-
- dependencies = [Dependency(key="index", module="index", name="anserini")]
- config_spec = [
- ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"),
- ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"),
- ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"),
- ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"),
- ConfigOption("newTermWeight", [0.2, 0.25], value_type="floatlist"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- hits = str(config["hits"])
-
- anserini_param_str = (
- "-bm25prf "
- + " ".join(f"-bm25prf.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "newTermWeight", "k1", "b"])
- + " -bm25 "
- + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"])
- + f" -hits {hits}"
- )
- print(output_path)
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class AxiomaticSemanticMatching(Searcher, AnseriniSearcherMixIn):
- """
- TODO: Add more info on retrieval method
- Also, BM25 is hard-coded to be the scoring model
- """
-
- module_name = "axiomatic"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
- config_spec = [
- ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
- ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
- ConfigOption("r", 20, value_type="intlist"),
- ConfigOption("n", 30, value_type="intlist"),
- ConfigOption("beta", 0.4, value_type="floatlist"),
- ConfigOption("top", 20, value_type="intlist"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- hits = str(config["hits"])
- conditionals = ""
-
- anserini_param_str = "-axiom -axiom.deterministic -axiom.r {0} -axiom.n {1} -axiom.beta {2} -axiom.top {3}".format(
- *[list2str(config[k], " ") for k in ["r", "n", "beta", "top"]]
- )
- anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1} ".format(*[list2str(config[k], " ") for k in ["k1", "b"]])
- anserini_param_str += f" -hits {hits}"
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class DirichletQL(Searcher, AnseriniSearcherMixIn):
- """ Dirichlet QL with a fixed mu """
-
- module_name = "DirichletQL"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
-
- config_spec = [
- ConfigOption("mu", 1000, "smoothing parameter", value_type="intlist"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- """
- Runs Dirichlet QL search. Takes a query from the topic files, and fires it against the index
- Args:
- topicsfn: Path to a topics file
- output_path: Path where the results of the search (i.e the run file) should be stored
-
- Returns: Path to the run file where the results of the search are stored
-
- """
- mustr = list2str(config["mu"], delimiter=" ")
- hits = config["hits"]
- anserini_param_str = f"-qld -qld.mu {mustr} -hits {hits}"
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class QLJM(Searcher, AnseriniSearcherMixIn):
- """
- QL with Jelinek-Mercer smoothing
- """
-
- module_name = "QLJM"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
- config_spec = [
- ConfigOption("lam", 0.1, value_type="floatlist"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- anserini_param_str = "-qljm -qljm.lambda {0} -hits {1}".format(list2str(config["lam"], delimiter=" "), config["hits"])
-
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class INL2(Searcher, AnseriniSearcherMixIn):
- """
- I(n)L2 scoring model
- """
-
- module_name = "INL2"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
- config_spec = [
- ConfigOption("c", 0.1), # array input of this parameter is not support by anserini.SearchCollection
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- anserini_param_str = "-inl2 -inl2.c {0} -hits {1}".format(config["c"], config["hits"])
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
- return output_path
-
-
-@Searcher.register
-class SPL(Searcher, AnseriniSearcherMixIn):
- """
- SPL scoring model
- """
-
- module_name = "SPL"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
-
- config_spec = [
- ConfigOption("c", 0.1), # array input of this parameter is not support by anserini.SearchCollection
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- anserini_param_str = "-spl -spl.c {0} -hits {1}".format(config["c"], config["hits"])
-
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class F2Exp(Searcher, AnseriniSearcherMixIn):
- """
- F2Exp scoring model
- """
-
- module_name = "F2Exp"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
-
- config_spec = [
- ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- anserini_param_str = "-f2exp -f2exp.s {0} -hits {1}".format(config["s"], config["hits"])
-
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class F2Log(Searcher, AnseriniSearcherMixIn):
- """
- F2Log scoring model
- """
-
- module_name = "F2Log"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
-
- config_spec = [
- ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
-
- def _query_from_file(self, topicsfn, output_path, config):
- anserini_param_str = "-f2log -f2log.s {0} -hits {1}".format(config["s"], config["hits"])
-
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
-
- return output_path
-
-
-@Searcher.register
-class SDM(Searcher, AnseriniSearcherMixIn):
- """
- Sequential Dependency Model
- The scoring model is hardcoded to be BM25 (TODO: Make it configurable?)
- """
-
- module_name = "SDM"
- dependencies = [Dependency(key="index", module="index", name="anserini")]
-
- # array input of (tw, ow, uw) is not support by anserini.SearchCollection
- config_spec = [
- ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
- ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
- ConfigOption("tw", 0.85, "term weight"),
- ConfigOption("ow", 0.15, "ordered window weight"),
- ConfigOption("uw", 0.05, "unordered window weight"),
- ConfigOption("hits", 1000, "number of results to return"),
- ConfigOption("fields", "title"),
- ]
+from profane import import_all_modules
- def _query_from_file(self, topicsfn, output_path, config):
- hits = config["hits"]
- anserini_param_str = "-sdm -sdm.tw {0} -sdm.ow {1} -sdm.uw {2}".format(*[config[k] for k in ["tw", "ow", "uw"]])
- anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1}".format(*[list2str(config[k], " ") for k in ["k1", "b"]])
- anserini_param_str += f" -hits {hits}"
- self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+from .anserini import BM25, BM25RM3, SDM
- return output_path
+import_all_modules(__file__, __package__)
diff --git a/capreolus/searcher/anserini.py b/capreolus/searcher/anserini.py
new file mode 100644
index 000000000..25ff6c780
--- /dev/null
+++ b/capreolus/searcher/anserini.py
@@ -0,0 +1,504 @@
+import os
+import math
+import subprocess
+from collections import defaultdict, OrderedDict
+
+import numpy as np
+
+from . import Searcher
+from capreolus import ModuleBase, Dependency, ConfigOption, constants
+from capreolus.utils.common import Anserini
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import topic_to_trectxt
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+MAX_THREADS = constants["MAX_THREADS"]
+
+
+def list2str(l, delimiter="-"):
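+    """Join list values into a string, e.g. list2str([0.65, 0.7], " ") -> "0.65 0.7"."""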
+ return delimiter.join(str(x) for x in l)
+
+
+class AnseriniSearcherMixIn:
+ """ MixIn for searchers that use Anserini's SearchCollection script """
+
+ def _anserini_query_from_file(self, topicsfn, anserini_param_str, output_base_path, topicfield):
+ if not os.path.exists(topicsfn):
+ raise IOError(f"could not find topics file: {topicsfn}")
+
+ # for covid:
+ field2querytype = {"query": "title", "question": "description", "narrative": "narrative"}
+ for k, v in field2querytype.items():
+ topicfield = topicfield.replace(k, v)
+
+ donefn = os.path.join(output_base_path, "done")
+ if os.path.exists(donefn):
+ logger.debug(f"skipping Anserini SearchCollection call because path already exists: {donefn}")
+ return
+
+ # create index if it does not exist. the call returns immediately if the index does exist.
+ self.index.create_index()
+
+ os.makedirs(output_base_path, exist_ok=True)
+ output_path = os.path.join(output_base_path, "searcher")
+
+ # add stemmer and stop options to match underlying index
+ indexopts = "-stemmer "
+ indexopts += "none" if self.index.config["stemmer"] is None else self.index.config["stemmer"]
+ if self.index.config["indexstops"]:
+ indexopts += " -keepstopwords"
+
+ index_path = self.index.get_index_path()
+ anserini_fat_jar = Anserini.get_fat_jar()
+ cmd = (
+ f"java -classpath {anserini_fat_jar} "
+ f"-Xms512M -Xmx31G -Dapp.name=SearchCollection io.anserini.search.SearchCollection "
+ f"-topicreader Trec -index {index_path} {indexopts} -topics {topicsfn} -output {output_path} "
+ f"-topicfield {topicfield} -inmem -threads {MAX_THREADS} {anserini_param_str}"
+ )
+ logger.info("Anserini writing runs to %s", output_path)
+ logger.debug(cmd)
+
+ app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True)
+
+ # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger
+ for line in app.stdout:
+ Anserini.filter_and_log_anserini_output(line, logger)
+
+ app.wait()
+ if app.returncode != 0:
+ raise RuntimeError("command failed")
+
+ with open(donefn, "wt") as donef:
+ print("done", file=donef)
+
+
+class PostprocessMixin:
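+    """Mixin with utilities for filtering, deduplicating, and truncating the run files written by a Searcher."""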
+ def _keep_topn(self, runs, topn):
+ queries = sorted(list(runs.keys()), key=lambda k: int(k))
+ for q in queries:
+ docs = runs[q]
+ if len(docs) <= topn:
+ continue
+ docs = sorted(docs.items(), key=lambda kv: kv[1], reverse=True)[:topn]
+ runs[q] = {k: v for k, v in docs}
+ return runs
+
+ def filter(self, run_dir, docs_to_remove=None, docs_to_keep=None, topn=None):
+ if (not docs_to_keep) and (not docs_to_remove):
+            raise ValueError("either docs_to_remove or docs_to_keep must be provided")
+
+ for fn in os.listdir(run_dir):
+ if fn == "done":
+ continue
+
+ run_fn = os.path.join(run_dir, fn)
+ self._filter(run_fn, docs_to_remove, docs_to_keep, topn)
+ return run_dir
+
+ def _filter(self, runfile, docs_to_remove, docs_to_keep, topn):
+ runs = Searcher.load_trec_run(runfile)
+
+ # filtering
+ if docs_to_remove: # prioritize docs_to_remove
+ if isinstance(docs_to_remove, list):
+ docs_to_remove = {q: docs_to_remove for q in runs}
+ runs = {q: {d: v for d, v in docs.items() if d not in docs_to_remove.get(q, [])} for q, docs in runs.items()}
+ elif docs_to_keep:
+ if isinstance(docs_to_keep, list):
+ docs_to_keep = {q: docs_to_keep for q in runs}
+ runs = {q: {d: v for d, v in docs.items() if d in docs_to_keep[q]} for q, docs in runs.items()}
+
+ if topn:
+ runs = self._keep_topn(runs, topn)
+ Searcher.write_trec_run(runs, runfile) # overwrite runfile
+
+ def dedup(self, run_dir, topn=None):
+ for fn in os.listdir(run_dir):
+ if fn == "done":
+ continue
+ run_fn = os.path.join(run_dir, fn)
+ self._dedup(run_fn, topn)
+ return run_dir
+
+ def _dedup(self, runfile, topn):
+ runs = Searcher.load_trec_run(runfile)
+ new_runs = {q: {} for q in runs}
+
+ # use the maximum passage score as the document score; no sorting is done here
+ for q, psg in runs.items():
+ for pid, score in psg.items():
+ docid = pid.split(".")[0]
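+ # e.g., passage id "D123.4" collapses into document id "D123"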
+ new_runs[q][docid] = max(new_runs[q].get(docid, -math.inf), score)
+ runs = new_runs
+
+ if topn:
+ runs = self._keep_topn(runs, topn)
+ Searcher.write_trec_run(runs, runfile)
+
+
+@Searcher.register
+class BM25(Searcher, AnseriniSearcherMixIn):
+ """ Anserini BM25. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+ module_name = "BM25"
+
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+ config_spec = [
+ ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
+ ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
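+ # Illustrative usage (a sketch; the create() call follows profane's conventions, and the
+ # config values are arbitrary examples of the list syntax described in the docstring):
+ #   searcher = Searcher.create("BM25", config={"k1": "0.9", "b": "0.4,0.6,0.8"}, provide={"index": index})
+ #   searcher.query_from_file(topicsfn, output_path)
+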
+ def _query_from_file(self, topicsfn, output_path, config):
+ """
+ Runs BM25 search. Takes a query from the topic files, and fires it against the index
+ Args:
+ topicsfn: Path to a topics file
+ output_path: Path where the results of the search (i.e the run file) should be stored
+
+ Returns: Path to the run file where the results of the search are stored
+
+ """
+ bstr, k1str = list2str(config["b"], delimiter=" "), list2str(config["k1"], delimiter=" ")
+ hits = config["hits"]
+ anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}"
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class BM25Grid(Searcher, AnseriniSearcherMixIn):
+ """ Deprecated. BM25 with a grid search for k1 and b. Search is from 0.1 to bmax/k1max in 0.1 increments """
+
+ module_name = "BM25Grid"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+ config_spec = [
+ ConfigOption("k1max", 1.0, "maximum k1 value to include in grid search (starting at 0.1)"),
+ ConfigOption("bmax", 1.0, "maximum b value to include in grid search (starting at 0.1)"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ bs = np.around(np.arange(0.1, config["bmax"] + 0.1, 0.1), 1)
+ k1s = np.around(np.arange(0.1, config["k1max"] + 0.1, 0.1), 1)
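+ # e.g., with bmax=0.4, bs is [0.1, 0.2, 0.3, 0.4]; k1s is built the same way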
+ bstr = " ".join(str(x) for x in bs)
+ k1str = " ".join(str(x) for x in k1s)
+ hits = config["hits"]
+ anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}"
+
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class BM25RM3(Searcher, AnseriniSearcherMixIn):
+ """ Anserini BM25 with RM3 expansion. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+ module_name = "BM25RM3"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+ config_spec = [
+ ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"),
+ ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"),
+ ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"),
+ ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"),
+ ConfigOption("originalQueryWeight", [0.5], "the weight of unexpended query", value_type="floatlist"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ hits = str(config["hits"])
+
+ anserini_param_str = (
+ "-rm3 "
+ + " ".join(f"-rm3.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "originalQueryWeight"])
+ + " -bm25 "
+ + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"])
+ + f" -hits {hits}"
+ )
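+ # e.g., fbTerms=[65, 70] yields "-rm3 -rm3.fbTerms 65 70 ...", with each list joined by spaces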
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class BM25PostProcess(BM25, PostprocessMixin):
+ module_name = "BM25Postprocess"
+
+ config_spec = [
+ ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
+ ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("topn", 1000),
+ ConfigOption("fields", "title"),
+ ConfigOption("dedep", False),
+ ]
+
+ def query_from_file(self, topicsfn, output_path, docs_to_remove=None):
+ output_path = super().query_from_file(topicsfn, output_path) # will call _query_from_file() from BM25
+
+ if docs_to_remove:
+ output_path = self.filter(output_path, docs_to_remove=docs_to_remove, topn=self.config["topn"])
+ if self.config["dedup"]:
+ output_path = self.dedup(output_path, topn=self.config["topn"])
+
+ return output_path
+
+
+@Searcher.register
+class StaticBM25RM3Rob04Yang19(Searcher):
+ """ Tuned BM25+RM3 run used by Yang et al. in [1]. This should be used only with a benchmark using the same folds and queries.
+
+ [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019.
+ """
+
+ module_name = "bm25staticrob04yang19"
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ import shutil
+
+ outfn = os.path.join(output_path, "static.run")
+ os.makedirs(output_path, exist_ok=True)
+ shutil.copy2(constants["PACKAGE_PATH"] / "data" / "rob04_yang19_rm3.run", outfn)
+
+ return output_path
+
+ def query(self, *args, **kwargs):
+ raise NotImplementedError("this searcher uses a static run file, so it cannot handle new queries")
+
+
+@Searcher.register
+class BM25PRF(Searcher, AnseriniSearcherMixIn):
+ """ Anserini BM25 PRF. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+ module_name = "BM25PRF"
+
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+ config_spec = [
+ ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"),
+ ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"),
+ ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"),
+ ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"),
+ ConfigOption("newTermWeight", [0.2, 0.25], value_type="floatlist"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ hits = str(config["hits"])
+
+ anserini_param_str = (
+ "-bm25prf "
+ + " ".join(f"-bm25prf.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "newTermWeight", "k1", "b"])
+ + " -bm25 "
+ + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"])
+ + f" -hits {hits}"
+ )
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class AxiomaticSemanticMatching(Searcher, AnseriniSearcherMixIn):
+ """ Anserini BM25 with Axiomatic query expansion. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+ module_name = "axiomatic"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+ config_spec = [
+ ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
+ ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
+ ConfigOption("r", 20, value_type="intlist"),
+ ConfigOption("n", 30, value_type="intlist"),
+ ConfigOption("beta", 0.4, value_type="floatlist"),
+ ConfigOption("top", 20, value_type="intlist"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ hits = str(config["hits"])
+ conditionals = ""
+
+ anserini_param_str = "-axiom -axiom.deterministic -axiom.r {0} -axiom.n {1} -axiom.beta {2} -axiom.top {3}".format(
+ *[list2str(config[k], " ") for k in ["r", "n", "beta", "top"]]
+ )
+ anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1} ".format(*[list2str(config[k], " ") for k in ["k1", "b"]])
+ anserini_param_str += f" -hits {hits}"
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class DirichletQL(Searcher, AnseriniSearcherMixIn):
+ """ Anserini QL with Dirichlet smoothing. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+ module_name = "DirichletQL"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+
+ config_spec = [
+ ConfigOption("mu", 1000, "smoothing parameter", value_type="intlist"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ """
+ Runs Dirichlet QL search. Takes a query from the topic files, and fires it against the index
+ Args:
+ topicsfn: Path to a topics file
+ output_path: Path where the results of the search (i.e the run file) should be stored
+
+ Returns: Path to the run file where the results of the search are stored
+
+ """
+ mustr = list2str(config["mu"], delimiter=" ")
+ hits = config["hits"]
+ anserini_param_str = f"-qld -qld.mu {mustr} -hits {hits}"
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class QLJM(Searcher, AnseriniSearcherMixIn):
+ """ Anserini QL with Jelinek-Mercer smoothing. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+ module_name = "QLJM"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+ config_spec = [
+ ConfigOption("lam", 0.1, value_type="floatlist"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ anserini_param_str = "-qljm -qljm.lambda {0} -hits {1}".format(list2str(config["lam"], delimiter=" "), config["hits"])
+
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class INL2(Searcher, AnseriniSearcherMixIn):
+ """ Anserini I(n)L2 scoring model. This searcher does not support list parameters. """
+
+ module_name = "INL2"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+ config_spec = [
+ ConfigOption("c", 0.1), # array input of this parameter is not support by anserini.SearchCollection
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ anserini_param_str = "-inl2 -inl2.c {0} -hits {1}".format(config["c"], config["hits"])
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+ return output_path
+
+
+@Searcher.register
+class SPL(Searcher, AnseriniSearcherMixIn):
+ """
+ Anserini SPL scoring model. This searcher does not support list parameters.
+ """
+
+ module_name = "SPL"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+
+ config_spec = [
+ ConfigOption("c", 0.1), # array input of this parameter is not support by anserini.SearchCollection
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ anserini_param_str = "-spl -spl.c {0} -hits {1}".format(config["c"], config["hits"])
+
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class F2Exp(Searcher, AnseriniSearcherMixIn):
+ """
+ F2Exp scoring model. This searcher does not support list parameters.
+ """
+
+ module_name = "F2Exp"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+
+ config_spec = [
+ ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ anserini_param_str = "-f2exp -f2exp.s {0} -hits {1}".format(config["s"], config["hits"])
+
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class F2Log(Searcher, AnseriniSearcherMixIn):
+ """
+ F2Log scoring model. This searcher does not support list parameters.
+ """
+
+ module_name = "F2Log"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+
+ config_spec = [
+ ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ anserini_param_str = "-f2log -f2log.s {0} -hits {1}".format(config["s"], config["hits"])
+
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
+
+
+@Searcher.register
+class SDM(Searcher, AnseriniSearcherMixIn):
+ """
+ Anserini BM25 with the Sequential Dependency Model. This searcher supports list parameters for only k1 and b.
+ """
+
+ module_name = "SDM"
+ dependencies = [Dependency(key="index", module="index", name="anserini")]
+
+ # array input of (tw, ow, uw) is not supported by anserini.SearchCollection
+ config_spec = [
+ ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
+ ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
+ ConfigOption("tw", 0.85, "term weight"),
+ ConfigOption("ow", 0.15, "ordered window weight"),
+ ConfigOption("uw", 0.05, "unordered window weight"),
+ ConfigOption("hits", 1000, "number of results to return"),
+ ConfigOption("fields", "title"),
+ ]
+
+ def _query_from_file(self, topicsfn, output_path, config):
+ hits = config["hits"]
+ anserini_param_str = "-sdm -sdm.tw {0} -sdm.ow {1} -sdm.uw {2}".format(*[config[k] for k in ["tw", "ow", "uw"]])
+ anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1}".format(*[list2str(config[k], " ") for k in ["k1", "b"]])
+ anserini_param_str += f" -hits {hits}"
+ self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+ return output_path
diff --git a/capreolus/searcher/tests/test_searcher.py b/capreolus/searcher/tests/test_searcher.py
index 469f5dc62..0a87d61f8 100644
--- a/capreolus/searcher/tests/test_searcher.py
+++ b/capreolus/searcher/tests/test_searcher.py
@@ -1,11 +1,11 @@
import os
import numpy as np
import pytest
-from profane import module_registry
+from capreolus import module_registry
from capreolus.utils.trec import load_trec_topics
from capreolus.benchmark import DummyBenchmark
-from capreolus.searcher import Searcher, BM25, BM25Grid
+from capreolus.searcher.anserini import Searcher, BM25, BM25Grid
from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index
skip_searchers = {"bm25staticrob04yang19", "BM25Grid", "BM25Postprocess", "axiomatic"}
diff --git a/capreolus/task/__init__.py b/capreolus/task/__init__.py
index 6412f8b0e..b64a361d2 100644
--- a/capreolus/task/__init__.py
+++ b/capreolus/task/__init__.py
@@ -1,5 +1,62 @@
+from capreolus import ModuleBase, Dependency, ConfigOption, constants, module_registry
+
+
+class Task(ModuleBase):
+ """Base class for Task modules. The purpose of a Task is to describe a Capreolus pipeline and serve as the pipeline's entry point. Tasks provide one or more commands that provide entry points while sharing the Task's configuration options and dependencies.
+
+ Modules should provide:
+ - a ``commands`` attribute containing the names of methods that can serve as pipeline entry points (*Task commands*). Each command will be accessible via the CLI using the syntax ``capreolus <task name>.<command name> ...``
+ - a ``default_command`` attribute containing the name of a command to run if none is given
+ - methods (taking only the *self* argument) that correspond to each command defined
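+
+ For example, a minimal hypothetical Task (module and command names are illustrative only)::
+
+     @Task.register
+     class EchoTask(Task):
+         module_name = "echo"
+         commands = ["run"]
+         default_command = "run"
+
+         def run(self):
+             print("running the echo command")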
+ """
+
+ module_type = "task"
+ commands = []
+ help_commands = ["describe", "print_config", "print_paths", "print_pipeline"]
+ default_command = "describe"
+ requires_random_seed = True
+
+ def print_config(self):
+ print("Configuration:")
+ self.print_module_config(prefix=" ")
+
+ def print_paths(self): # TODO
+ pass
+
+ def print_pipeline(self):
+ print(f"Module graph:")
+ self.print_module_graph(prefix=" ")
+
+ def describe(self):
+ self.print_pipeline()
+ print("\n")
+ self.print_config()
+
+ def get_results_path(self):
+ """ Return an absolute path that can be used for storing results.
+ The path is a function of the module's config and the configs of its dependencies.
+ """
+
+ return constants["RESULTS_BASE_PATH"] / self.get_module_path()
+
+
+@Task.register
+class ModulesTask(Task):
+ module_name = "modules"
+ commands = ["list_modules"]
+ default_command = "list_modules"
+
+ def list_modules(self):
+ for module_type in module_registry.get_module_types():
+ print(f"module type={module_type}")
+
+ for module_name in module_registry.get_module_names(module_type):
+ print(f" name={module_name}")
+
+
from profane import import_all_modules
-from .base import Task
+from .rank import RankTask
+from .rerank import RerankTask
import_all_modules(__file__, __package__)
diff --git a/capreolus/task/base.py b/capreolus/task/base.py
deleted file mode 100644
index 694713063..000000000
--- a/capreolus/task/base.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from profane import ModuleBase, Dependency, ConfigOption, constants, module_registry
-
-
-class Task(ModuleBase):
- module_type = "task"
- commands = []
- help_commands = ["describe", "print_config", "print_paths", "print_pipeline"]
- default_command = "describe"
- requires_random_seed = True
-
- def print_config(self):
- print("Configuration:")
- self.print_module_config(prefix=" ")
-
- def print_paths(self): # TODO
- pass
-
- def print_pipeline(self):
- print(f"Module graph:")
- self.print_module_graph(prefix=" ")
-
- def describe(self):
- self.print_pipeline()
- print("\n")
- self.print_config()
-
- def get_results_path(self):
- """ Return an absolute path that can be used for storing results.
- The path is a function of the module's config and the configs of its dependencies.
- """
-
- return constants["RESULTS_BASE_PATH"] / self.get_module_path()
-
-
-@Task.register
-class ModulesTask(Task):
- module_name = "modules"
- commands = ["list_modules"]
- default_command = "list_modules"
-
- def list_modules(self):
- for module_type in module_registry.get_module_types():
- print(f"module type={module_type}")
-
- for module_name in module_registry.get_module_names(module_type):
- print(f" name={module_name}")
diff --git a/capreolus/task/rank.py b/capreolus/task/rank.py
index ff664fe4d..526aa0899 100644
--- a/capreolus/task/rank.py
+++ b/capreolus/task/rank.py
@@ -20,7 +20,9 @@ class RankTask(Task):
config_keys_not_in_path = ["optimize", "metrics"] # affect only evaluation but not search()
dependencies = [
- Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]),
+ Dependency(
+ key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"]
+ ),
Dependency(key="searcher", module="searcher", name="BM25"),
]
diff --git a/capreolus/task/rerank.py b/capreolus/task/rerank.py
index 980eee019..963464dcc 100644
--- a/capreolus/task/rerank.py
+++ b/capreolus/task/rerank.py
@@ -23,7 +23,9 @@ class RerankTask(Task):
ConfigOption("optimize", "map", "metric to maximize on the dev set"), # affects train() because we check to save weights
]
dependencies = [
- Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]),
+ Dependency(
+ key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"]
+ ),
Dependency(key="rank", module="task", name="rank"),
Dependency(key="reranker", module="reranker", name="KNRM"),
]
diff --git a/capreolus/task/rererank.py b/capreolus/task/rererank.py
index 5928eb075..e8acd04a0 100644
--- a/capreolus/task/rererank.py
+++ b/capreolus/task/rererank.py
@@ -24,7 +24,9 @@ class ReRerankTask(Task):
ConfigOption("topn", 100, "number of stage two results to rerank"),
]
dependencies = [
- Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]),
+ Dependency(
+ key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"]
+ ),
Dependency(key="rank", module="task", name="rank", provide_this=True),
Dependency(key="rerank1", module="task", name="rerank"),
Dependency(key="rerank2", module="task", name="rerank"),
diff --git a/capreolus/task/tests/test_task.py b/capreolus/task/tests/test_task.py
new file mode 100644
index 000000000..b9e7e0df2
--- /dev/null
+++ b/capreolus/task/tests/test_task.py
@@ -0,0 +1,13 @@
+import pytest
+
+from capreolus import Benchmark, Task, module_registry
+from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache
+
+
+tasks = set(module_registry.get_module_names("task"))
+
+
+@pytest.mark.parametrize("task_name", tasks)
+def test_task_creatable(tmpdir_as_cache, dummy_index, task_name):
+ provide = {"index": dummy_index, "benchmark": Benchmark.create("dummy"), "collection": dummy_index.collection}
+ task = Task.create(task_name, provide=provide)
diff --git a/capreolus/task/tutorial.py b/capreolus/task/tutorial.py
index 1df31c4e6..0b3488af8 100644
--- a/capreolus/task/tutorial.py
+++ b/capreolus/task/tutorial.py
@@ -11,9 +11,7 @@ class TutorialTask(Task):
module_name = "tutorial"
config_spec = [ConfigOption("optimize", "map", "metric to maximize on the validation set")]
dependencies = [
- Dependency(
- key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"]
- ),
+ Dependency(key="benchmark", module="benchmark", name="nf", provide_this=True, provide_children=["collection"]),
Dependency(key="searcher1", module="searcher", name="BM25RM3"),
Dependency(key="searcher2", module="searcher", name="SDM"),
]
@@ -37,7 +35,7 @@ def run(self):
)
for fold, path in best_results["path"].items():
- shortpath = "..." + path[:-20]
+ shortpath = "..." + path[-40:]
logger.info("fold=%s best run: %s", fold, shortpath)
logger.info("cross-validated results when optimizing for '%s':", self.config["optimize"])
diff --git a/capreolus/tests/test_benchmark.py b/capreolus/tests/test_benchmark.py
index 0ade5c188..6a177be25 100644
--- a/capreolus/tests/test_benchmark.py
+++ b/capreolus/tests/test_benchmark.py
@@ -3,14 +3,26 @@
import pytest
from tqdm import tqdm
+from capreolus import Benchmark, module_registry
from capreolus.utils.loginit import get_logger
from capreolus.utils.common import remove_newline
-from capreolus.benchmark import CodeSearchNetCorpus as CodeSearchNetCodeSearchNetCorpusBenchmark
-from capreolus.benchmark import CodeSearchNetChallenge as CodeSearchNetCodeSearchNetChallengeBenchmark
-from capreolus.collection import CodeSearchNet as CodeSearchNetCollection
+from capreolus.benchmark.codesearchnet import CodeSearchNetCorpus as CodeSearchNetCodeSearchNetCorpusBenchmark
+from capreolus.benchmark.codesearchnet import CodeSearchNetChallenge as CodeSearchNetCodeSearchNetChallengeBenchmark
+from capreolus.collection.codesearchnet import CodeSearchNet as CodeSearchNetCollection
+from capreolus.tests.common_fixtures import tmpdir_as_cache
logger = get_logger(__name__)
+benchmarks = set(module_registry.get_module_names("benchmark"))
+
+
+@pytest.mark.parametrize("benchmark_name", benchmarks)
+@pytest.mark.download
+def test_benchmark_creatable(tmpdir_as_cache, benchmark_name):
+ benchmark = Benchmark.create(benchmark_name)
+ if hasattr(benchmark, "download_if_missing"):
+ benchmark.download_if_missing()
+
@pytest.mark.download
def test_csn_corpus_benchmark_downloadifmissing():
diff --git a/capreolus/tests/test_collection.py b/capreolus/tests/test_collection.py
index 34e455ecf..54a348107 100644
--- a/capreolus/tests/test_collection.py
+++ b/capreolus/tests/test_collection.py
@@ -3,8 +3,25 @@
import pytest
+from capreolus import Collection, module_registry
from capreolus.index import AnseriniIndex
-from capreolus.collection import ANTIQUE, CodeSearchNet
+from capreolus.collection.antique import ANTIQUE
+from capreolus.collection.codesearchnet import CodeSearchNet
+from capreolus.tests.common_fixtures import tmpdir_as_cache
+
+collections = set(module_registry.get_module_names("collection"))
+
+
+@pytest.mark.parametrize("collection_name", collections)
+def test_collection_creatable(tmpdir_as_cache, collection_name):
+ collection = Collection.create(collection_name)
+
+
+@pytest.mark.parametrize("collection_name", collections)
+@pytest.mark.download
+def test_collection_downloadable(tmpdir_as_cache, collection_name):
+ collection = Collection.create(collection_name)
+ collection.find_document_path()
@pytest.mark.download
diff --git a/capreolus/tests/test_extractor.py b/capreolus/tests/test_extractor.py
index 87f0d8333..e0ef28f26 100644
--- a/capreolus/tests/test_extractor.py
+++ b/capreolus/tests/test_extractor.py
@@ -4,12 +4,14 @@
from nltk import TextTilingTokenizer
from pymagnitude import Magnitude, MagnitudeUtils
import numpy as np
+import pytest
+from capreolus import Extractor, module_registry
from capreolus.collection import DummyCollection
from capreolus.index import AnseriniIndex
from capreolus.tokenizer import AnseriniTokenizer
from capreolus.benchmark import DummyBenchmark
-from capreolus.extractor import EmbedText
+from capreolus.extractor.embedtext import EmbedText
from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index
from capreolus.utils.exceptions import MissingDocError
@@ -19,6 +21,14 @@
MAXQLEN = 8
MAXDOCLEN = 7
+extractors = set(module_registry.get_module_names("extractor"))
+
+
+@pytest.mark.parametrize("extractor_name", extractors)
+def test_extractor_creatable(tmpdir_as_cache, dummy_index, extractor_name):
+ provide = {"index": dummy_index, "collection": dummy_index.collection}
+ extractor = Extractor.create(extractor_name, provide=provide)
+
def test_embedtext_creation(monkeypatch):
def fake_magnitude_embedding(*args, **kwargs):
diff --git a/capreolus/tokenizer/__init__.py b/capreolus/tokenizer/__init__.py
index ccd471b54..0ce810e08 100644
--- a/capreolus/tokenizer/__init__.py
+++ b/capreolus/tokenizer/__init__.py
@@ -1,70 +1,19 @@
-from profane import import_all_modules
-
-# import_all_modules(__file__, __package__)
-
-from profane import ModuleBase, Dependency, ConfigOption
-
-from transformers import BertTokenizer as HFBertTokenizer
+from capreolus import ModuleBase, Dependency, ConfigOption
class Tokenizer(ModuleBase):
- module_type = "tokenizer"
-
-
-@Tokenizer.register
-class AnseriniTokenizer(Tokenizer):
- module_name = "anserini"
- config_spec = [
- ConfigOption("keepstops", True, "keep stopwords if True"),
- ConfigOption("stemmer", "none", "stemmer: porter, krovetz, or none"),
- ]
-
- def build(self):
- self._tokenize = self._get_tokenize_fn()
-
- def _get_tokenize_fn(self):
- from jnius import autoclass
-
- stemmer, keepstops = self.config["stemmer"], self.config["keepstops"]
- if stemmer is None:
- stemmer = "none"
+ """Base class for Tokenizer modules. The purpose of a Tokenizer is to tokenize strings of text (e.g., as required by an :class:`~capreolus.extractor.Extractor`).
- emptyjchar = autoclass("org.apache.lucene.analysis.CharArraySet").EMPTY_SET
- Analyzer = autoclass("io.anserini.analysis.DefaultEnglishAnalyzer")
- analyzer = Analyzer.newStemmingInstance(stemmer, emptyjchar) if keepstops else Analyzer.newStemmingInstance(stemmer)
- tokenizefn = autoclass("io.anserini.analysis.AnalyzerUtils").analyze
+ Modules should provide:
+ - a ``tokenize(strings)`` method that takes a list of strings and returns tokenized versions
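+
+ For example (a sketch; the create() call follows profane's conventions)::
+
+     tokenizer = Tokenizer.create("anserini", config={"keepstops": False})
+     tokenizer.tokenize(["The quick brown fox"])  # roughly [["quick", "brown", "fox"]]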
+ """
- def _tokenize(sentence):
- return tokenizefn(analyzer, sentence).toArray()
-
- return _tokenize
-
- def tokenize(self, sentences):
- if not sentences or len(sentences) == 0: # either "" or []
- return []
-
- if isinstance(sentences, str):
- return self._tokenize(sentences)
-
- return [self._tokenize(s) for s in sentences]
-
-
-@Tokenizer.register
-class BertTokenizer(Tokenizer):
- module_name = "berttokenizer"
- config_spec = [ConfigOption("pretrained", "bert-base-uncased", "pretrained model to load vocab from")]
-
- def build(self):
- self.bert_tokenizer = HFBertTokenizer.from_pretrained(self.config["pretrained"])
+ module_type = "tokenizer"
- def convert_tokens_to_ids(self, tokens):
- return self.bert_tokenizer.convert_tokens_to_ids(tokens)
- def tokenize(self, sentences):
- if not sentences or len(sentences) == 0: # either "" or []
- return []
+from profane import import_all_modules
- if isinstance(sentences, str):
- return self.bert_tokenizer.tokenize(sentences)
+from .anserini import AnseriniTokenizer
+from .bert import BertTokenizer
- return [self.bert_tokenizer.tokenize(s) for s in sentences]
+import_all_modules(__file__, __package__)
diff --git a/capreolus/tokenizer/anserini.py b/capreolus/tokenizer/anserini.py
new file mode 100644
index 000000000..13eb83a73
--- /dev/null
+++ b/capreolus/tokenizer/anserini.py
@@ -0,0 +1,40 @@
+from . import Tokenizer
+from capreolus import ModuleBase, Dependency, ConfigOption
+
+
+@Tokenizer.register
+class AnseriniTokenizer(Tokenizer):
+ module_name = "anserini"
+ config_spec = [
+ ConfigOption("keepstops", True, "keep stopwords if True"),
+ ConfigOption("stemmer", "none", "stemmer: porter, krovetz, or none"),
+ ]
+
+ def build(self):
+ self._tokenize = self._get_tokenize_fn()
+
+ def _get_tokenize_fn(self):
+ from jnius import autoclass
+
+ stemmer, keepstops = self.config["stemmer"], self.config["keepstops"]
+ if stemmer is None:
+ stemmer = "none"
+
+ emptyjchar = autoclass("org.apache.lucene.analysis.CharArraySet").EMPTY_SET
+ Analyzer = autoclass("io.anserini.analysis.DefaultEnglishAnalyzer")
+ analyzer = Analyzer.newStemmingInstance(stemmer, emptyjchar) if keepstops else Analyzer.newStemmingInstance(stemmer)
+ tokenizefn = autoclass("io.anserini.analysis.AnalyzerUtils").analyze
+
+ def _tokenize(sentence):
+ return tokenizefn(analyzer, sentence).toArray()
+
+ return _tokenize
+
+ def tokenize(self, sentences):
+ if not sentences or len(sentences) == 0: # either "" or []
+ return []
+
+ if isinstance(sentences, str):
+ return self._tokenize(sentences)
+
+ return [self._tokenize(s) for s in sentences]
diff --git a/capreolus/tokenizer/bert.py b/capreolus/tokenizer/bert.py
new file mode 100644
index 000000000..29f698ec7
--- /dev/null
+++ b/capreolus/tokenizer/bert.py
@@ -0,0 +1,25 @@
+from . import Tokenizer
+from capreolus import ModuleBase, Dependency, ConfigOption
+
+from transformers import BertTokenizer as HFBertTokenizer
+
+
+@Tokenizer.register
+class BertTokenizer(Tokenizer):
+ module_name = "berttokenizer"
+ config_spec = [ConfigOption("pretrained", "bert-base-uncased", "pretrained model to load vocab from")]
+
+ def build(self):
+ self.bert_tokenizer = HFBertTokenizer.from_pretrained(self.config["pretrained"])
+
+ def convert_tokens_to_ids(self, tokens):
+ return self.bert_tokenizer.convert_tokens_to_ids(tokens)
+
+ def tokenize(self, sentences):
+ if not sentences or len(sentences) == 0: # either "" or []
+ return []
+
+ if isinstance(sentences, str):
+ return self.bert_tokenizer.tokenize(sentences)
+
+ return [self.bert_tokenizer.tokenize(s) for s in sentences]
diff --git a/capreolus/trainer/__init__.py b/capreolus/trainer/__init__.py
index 2ca2011bf..a0bc0fde2 100644
--- a/capreolus/trainer/__init__.py
+++ b/capreolus/trainer/__init__.py
@@ -1,37 +1,18 @@
-from profane import import_all_modules
-
-# import_all_modules(__file__, __package__)
-
-import hashlib
-import math
import os
-import sys
-import time
-import uuid
-from collections import defaultdict
-from copy import copy
-from profane import ModuleBase, Dependency, ConfigOption, constants
-import tensorflow as tf
-import tensorflow_ranking as tfr
-import numpy as np
-import torch
-from keras import Sequential, layers
-from keras.layers import Dense
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-
-from capreolus.reranker.common import pair_hinge_loss, pair_softmax_loss
-from capreolus.searcher import Searcher
-from capreolus.utils.loginit import get_logger
-from capreolus.utils.common import plot_metrics, plot_loss
-from capreolus import evaluator
+from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
-RESULTS_BASE_PATH = constants["RESULTS_BASE_PATH"]
class Trainer(ModuleBase):
+ """Base class for Trainer modules. The purpose of a Trainer is to train a :class:`~capreolus.reranker.Reranker` module and use it to make predictions. Capreolus provides two trainers: :class:`~capreolus.trainer.pytorch.PytorchTrainer` and :class:`~capreolus.trainer.tensorflow.TensorFlowTrainer`
+
+ Modules should provide:
+ - a ``train`` method that trains a reranker on training and dev (validation) data
+ - a ``predict`` method that uses a reranker to make predictions on data
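+
+ For example (a sketch based on the signatures of the included trainers)::
+
+     trainer = Trainer.create("pytorch", config={"niters": 5})
+     trainer.train(reranker, train_dataset, train_path, dev_data, dev_path, qrels, metric="map")
+     run = trainer.predict(reranker, dev_data, pred_fn)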
+ """
+
module_type = "trainer"
requires_random_seed = True
@@ -49,649 +30,9 @@ def get_paths_for_early_stopping(self, train_output_path, dev_output_path):
return dev_best_weight_fn, weights_output_path, info_output_path, loss_fn
-@Trainer.register
-class PytorchTrainer(Trainer):
- module_name = "pytorch"
- config_spec = [
- ConfigOption("batch", 32, "batch size"),
- ConfigOption("niters", 20, "number of iterations to train for"),
- ConfigOption("itersize", 512, "number of training instances in one iteration"),
- ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"),
- ConfigOption("lr", 0.001, "learning rate"),
- ConfigOption("softmaxloss", False, "True to use softmax loss (over pairs) or False to use hinge loss"),
- ConfigOption("fastforward", False),
- ConfigOption("validatefreq", 1),
- ConfigOption("boardname", "default"),
- ]
- config_keys_not_in_path = ["fastforward", "boardname"]
-
- def build(self):
- # sanity checks
- if self.config["batch"] < 1:
- raise ValueError("batch must be >= 1")
-
- if self.config["niters"] <= 0:
- raise ValueError("niters must be > 0")
-
- if self.config["itersize"] < self.config["batch"]:
- raise ValueError("itersize must be >= batch")
-
- if self.config["gradacc"] < 1 or not float(self.config["gradacc"]).is_integer():
- raise ValueError("gradacc must be an integer >= 1")
-
- if self.config["lr"] <= 0:
- raise ValueError("lr must be > 0")
-
- torch.manual_seed(self.config["seed"])
- torch.cuda.manual_seed_all(self.config["seed"])
-
- def single_train_iteration(self, reranker, train_dataloader):
- """Train model for one iteration using instances from train_dataloader.
-
- Args:
- model (Reranker): a PyTorch Reranker
- train_dataloader (DataLoader): a PyTorch DataLoader that iterates over training instances
-
- Returns:
- float: average loss over the iteration
-
- """
-
- iter_loss = []
- batches_since_update = 0
- batches_per_epoch = (self.config["itersize"] // self.config["batch"]) or 1
- batches_per_step = self.config["gradacc"]
-
- for bi, batch in tqdm(enumerate(train_dataloader), desc="Iter progression"):
- # TODO make sure _prepare_batch_with_strings equivalent is happening inside the sampler
- batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()}
- doc_scores = reranker.score(batch)
- loss = self.loss(doc_scores)
- iter_loss.append(loss)
- loss.backward()
-
- batches_since_update += 1
- if batches_since_update == batches_per_step:
- batches_since_update = 0
- self.optimizer.step()
- self.optimizer.zero_grad()
-
- if (bi + 1) % batches_per_epoch == 0:
- break
-
- return torch.stack(iter_loss).mean()
-
- def load_loss_file(self, fn):
- """Loads loss history from fn
-
- Args:
- fn (Path): path to a loss.txt file
-
- Returns:
- a list of losses ordered by iterations
-
- """
-
- loss = []
- with fn.open(mode="rt") as f:
- for lineidx, line in enumerate(f):
- line = line.strip()
- if not line:
- continue
-
- iteridx, iterloss = line.rstrip().split()
-
- if int(iteridx) != lineidx:
- raise IOError(f"malformed loss file {fn} ... did two processes write to it?")
-
- loss.append(float(iterloss))
-
- return loss
-
- def fastforward_training(self, reranker, weights_path, loss_fn):
- """Skip to the last training iteration whose weights were saved.
-
- If saved model and optimizer weights are available, this method will load those weights into model
- and optimizer, and then return the next iteration to be run. For example, if weights are available for
- iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and
- this method will return 11.
-
- If an error or inconsistency is encountered when checking for weights, this method returns 0.
-
- This method checks several files to determine if weights "are available". First, loss_fn is read to
- determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.)
- Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer.
- If this is successful, the method returns `1 + last recorded iteration`. If not, it returns 0.
- (We consider loss_fn because it is written at the end of every training iteration.)
-
- Args:
- model (Reranker): a PyTorch Reranker whose state should be loaded
- weights_path (Path): directory containing model and optimizer weights
- loss_fn (Path): file containing loss history
-
- Returns:
- int: the next training iteration after fastforwarding. If successful, this is > 0.
- If no weights are available or they cannot be loaded, 0 is returned.
-
- """
-
- if not (weights_path.exists() and loss_fn.exists()):
- return 0
-
- try:
- loss = self.load_loss_file(loss_fn)
- except IOError:
- return 0
-
- last_loss_iteration = len(loss) - 1
- weights_fn = weights_path / f"{last_loss_iteration}.p"
-
- try:
- reranker.load_weights(weights_fn, self.optimizer)
- return last_loss_iteration + 1
- except:
- logger.info("attempted to load weights from %s but failed, starting at iteration 0", weights_fn)
- return 0
-
- def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
- """Train a model following the trainer's config (specifying batch size, number of iterations, etc).
-
- Args:
- train_dataset (IterableDataset): training dataset
- train_output_path (Path): directory under which train_dataset runs and training loss will be saved
- dev_data (IterableDataset): dev dataset
- dev_output_path (Path): directory where dev_data runs and metrics will be saved
-
- """
- # Set up logging
- # TODO why not put this under train_output_path?
- summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" / self.config["boardname"], comment=train_output_path)
-
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- model = reranker.model.to(self.device)
- self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.config["lr"])
-
- if self.config["softmaxloss"]:
- self.loss = pair_softmax_loss
- else:
- self.loss = pair_hinge_loss
-
- dev_best_weight_fn, weights_output_path, info_output_path, loss_fn = self.get_paths_for_early_stopping(
- train_output_path, dev_output_path
- )
-
- initial_iter = self.fastforward_training(reranker, weights_output_path, loss_fn) if self.config["fastforward"] else 0
- logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"])
-
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset, batch_size=self.config["batch"], pin_memory=True, num_workers=0
- )
- # dataiter = iter(train_dataloader)
- # sample_input = dataiter.next()
- # summary_writer.add_graph(
- # reranker.model,
- # [
- # sample_input["query"].to(self.device),
- # sample_input["posdoc"].to(self.device),
- # sample_input["negdoc"].to(self.device),
- # ],
- # )
-
- train_loss = []
- # are we resuming training?
- if initial_iter > 0:
- train_loss = self.load_loss_file(loss_fn)
-
- # are we done training?
- if initial_iter < self.config["niters"]:
- logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter)
- batches_per_epoch = self.config["itersize"] // self.config["batch"]
- for niter in range(initial_iter):
- for bi, batch in enumerate(train_dataloader):
- if (bi + 1) % batches_per_epoch == 0:
- break
-
- dev_best_metric = -np.inf
- validation_frequency = self.config["validatefreq"]
- train_start_time = time.time()
- for niter in range(initial_iter, self.config["niters"]):
- model.train()
-
- iter_start_time = time.time()
- iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader)
- logger.info("A single iteration takes {}".format(time.time() - iter_start_time))
- train_loss.append(iter_loss_tensor.item())
- logger.info("iter = %d loss = %f", niter, train_loss[-1])
-
- # write model weights to file
- weights_fn = weights_output_path / f"{niter}.p"
- reranker.save_weights(weights_fn, self.optimizer)
- # predict performance on dev set
-
- if niter % validation_frequency == 0:
- pred_fn = dev_output_path / f"{niter}.run"
- preds = self.predict(reranker, dev_data, pred_fn)
-
- # log dev metrics
- metrics = evaluator.eval_runs(preds, qrels, evaluator.DEFAULT_METRICS, relevance_level)
- logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
- summary_writer.add_scalar("ndcg_cut_20", metrics["ndcg_cut_20"], niter)
- summary_writer.add_scalar("map", metrics["map"], niter)
- summary_writer.add_scalar("P_20", metrics["P_20"], niter)
- # write best dev weights to file
- if metrics[metric] > dev_best_metric:
- reranker.save_weights(dev_best_weight_fn, self.optimizer)
-
- # write train_loss to file
- loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))
-
- summary_writer.add_scalar("training_loss", iter_loss_tensor.item(), niter)
- reranker.add_summary(summary_writer, niter)
- summary_writer.flush()
- logger.info("training loss: %s", train_loss)
- logger.info("Training took {}".format(time.time() - train_start_time))
- summary_writer.close()
-
- # TODO should we write a /done so that training can be skipped if possible when fastforward=False? or in Task?
-
- def load_best_model(self, reranker, train_output_path):
- self.optimizer = torch.optim.Adam(
- filter(lambda param: param.requires_grad, reranker.model.parameters()), lr=self.config["lr"]
- )
-
- dev_best_weight_fn = train_output_path / "dev.best"
- reranker.load_weights(dev_best_weight_fn, self.optimizer)
-
- def predict(self, reranker, pred_data, pred_fn):
- """Predict query-document scores on `pred_data` using `model` and write a corresponding run file to `pred_fn`
-
- Args:
- model (Reranker): a PyTorch Reranker
- pred_data (IterableDataset): data to predict on
- pred_fn (Path): path to write the prediction run file to
-
- Returns:
- TREC Run
-
- """
-
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- # save to pred_fn
- model = reranker.model.to(self.device)
- model.eval()
-
- preds = {}
- pred_dataloader = torch.utils.data.DataLoader(pred_data, batch_size=self.config["batch"], pin_memory=True, num_workers=0)
- with torch.autograd.no_grad():
- for batch in tqdm(pred_dataloader, desc="Predicting on dev"):
- if len(batch["qid"]) != self.config["batch"]:
- batch = self.fill_incomplete_batch(batch)
-
- batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()}
- scores = reranker.test(batch)
- scores = scores.view(-1).cpu().numpy()
- for qid, docid, score in zip(batch["qid"], batch["posdocid"], scores):
- # Need to use float16 because pytrec_eval's c function call crashes with higher precision floats
- preds.setdefault(qid, {})[docid] = score.astype(np.float16).item()
-
- os.makedirs(os.path.dirname(pred_fn), exist_ok=True)
- Searcher.write_trec_run(preds, pred_fn)
-
- return preds
-
- def fill_incomplete_batch(self, batch):
- """
- If a batch is incomplete (i.e shorter than the desired batch size), this method fills in the batch with some data.
- How the data is chosen:
- If the values are just a simple list, use the first element of the list to pad the batch
- If the values are tensors/numpy arrays, use repeat() along the batch dimension
- """
- # logger.debug("filling in an incomplete batch")
- repeat_times = math.ceil(self.config["batch"] / len(batch["qid"]))
- diff = self.config["batch"] - len(batch["qid"])
-
- def pad(v):
- if isinstance(v, np.ndarray) or torch.is_tensor(v):
- _v = v.repeat((repeat_times,) + tuple([1 for x in range(len(v.shape) - 1)]))
- else:
- _v = v + [v[0]] * diff
-
- return _v[: self.config["batch"]]
-
- batch = {k: pad(v) for k, v in batch.items()}
- return batch
-
-
-class TrecCheckpointCallback(tf.keras.callbacks.Callback):
- """
- A callback that runs after every epoch and calculates pytrec_eval style metrics for the dev dataset.
- See TensorflowTrainer.train() for the invocation
- Also saves the best model to disk
- """
-
- def __init__(self, qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs):
- super(TrecCheckpointCallback, self).__init__(*args, **kwargs)
- """
- qrels - a qrels dict
- dev_data - a torch.utils.IterableDataset
- dev_records - a BatchedDataset instance
- """
- self.best_metric = -np.inf
- self.qrels = qrels
- self.dev_data = dev_data
- self.dev_records = dev_records
- self.output_path = output_path
- self.iter_start_time = time.time()
- self.metric = metric
- self.validate_freq = validate_freq
- self.relevance_level = relevance_level
-
- def save_model(self):
- self.model.save_weights("{0}/dev.best".format(self.output_path))
-
- def on_epoch_begin(self, epoch, logs=None):
- self.iter_start_time = time.time()
-
- def on_epoch_end(self, epoch, logs=None):
- logger.debug("Epoch {} took {}".format(epoch, time.time() - self.iter_start_time))
- if (epoch + 1) % self.validate_freq == 0:
- predictions = self.model.predict(self.dev_records, verbose=1, workers=8, use_multiprocessing=True)
- trec_preds = self.get_preds_in_trec_format(predictions, self.dev_data)
- metrics = evaluator.eval_runs(trec_preds, dict(self.qrels), evaluator.DEFAULT_METRICS, self.relevance_level)
- logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
-
- if metrics[self.metric] > self.best_metric:
- self.best_metric = metrics[self.metric]
- # TODO: Prevent the embedding layer weights from being saved
- self.save_model()
-
- @staticmethod
- def get_preds_in_trec_format(predictions, dev_data):
- """
- Takes in a list of predictions and returns a dict that can be fed into pytrec_eval
- As a side effect, also writes the predictions into a file in the trec format
- """
- pred_dict = defaultdict(lambda: dict())
-
- for i, (qid, docid) in enumerate(dev_data.get_qid_docid_pairs()):
- # Pytrec_eval has problems with high precision floats
- pred_dict[qid][docid] = predictions[i][0].astype(np.float16).item()
-
- return dict(pred_dict)
-
-
-@Trainer.register
-class TensorFlowTrainer(Trainer):
- module_name = "tensorflow"
-
- config_spec = [
- ConfigOption("batch", 32, "batch size"),
- ConfigOption("niters", 20, "number of iterations to train for"),
- ConfigOption("itersize", 512, "number of training instances in one iteration"),
- # ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"),
- ConfigOption("lr", 0.001, "learning rate"),
- ConfigOption("loss", "pairwise_hinge_loss", "must be one of tfr.losses.RankingLossKey"),
- # ConfigOption("fastforward", False),
- ConfigOption("validatefreq", 1),
- ConfigOption("boardname", "default"),
- ConfigOption("usecache", False),
- ConfigOption("tpuname", None),
- ConfigOption("tpuzone", None),
- ConfigOption("storage", None),
- ]
- config_keys_not_in_path = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"]
-
- def build(self):
- tf.random.set_seed(self.config["seed"])
-
- # Use TPU if available, otherwise resort to GPU/CPU
- try:
- self.tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=self.config["tpuname"], zone=self.config["tpuzone"])
- except ValueError:
- self.tpu = None
- logger.info("Could not find the tpu")
-
- # TPUStrategy for distributed training
- if self.tpu:
- logger.info("Utilizing TPUs")
- tf.config.experimental_connect_to_cluster(self.tpu)
- tf.tpu.experimental.initialize_tpu_system(self.tpu)
- self.strategy = tf.distribute.experimental.TPUStrategy(self.tpu)
- else: # default strategy that works on CPU and single GPU
- self.strategy = tf.distribute.get_strategy()
-
- # Defining some props that we will later initialize
- self.optimizer = None
- self.loss = None
- self.validate()
-
- def validate(self):
- if self.tpu and any([self.config["storage"] is None, self.config["tpuname"] is None, self.config["tpuzone"] is None]):
- raise ValueError("storage, tpuname and tpuzone configs must be provided when training on TPU")
- if self.tpu and self.config["storage"] and not self.config["storage"].startswith("gs://"):
- raise ValueError("For TPU utilization, the storage config should start with 'gs://'")
-
- def get_optimizer(self):
- return tf.keras.optimizers.Adam(learning_rate=self.config["lr"])
-
- def fastforward_training(self, reranker, weights_path, loss_fn):
- # TODO: Fix fast forwarding
- return 0
-
- def load_best_model(self, reranker, train_output_path):
- # TODO: Do the train_output_path modification at one place?
- if self.tpu:
- train_output_path = "{0}/{1}/{2}".format(
- self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
- )
-
- reranker.model.load_weights("{0}/dev.best".format(train_output_path))
-
- def apply_gradients(self, weights, grads):
- self.optimizer.apply_gradients(zip(grads, weights))
-
- def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
- # summary_writer = tf.summary.create_file_writer("{0}/capreolus_tensorboard/{1}".format(self.config["storage"], self.config["boardname"]))
-
- # Because TPUs can't work with local files
- if self.tpu:
- train_output_path = "{0}/{1}/{2}".format(
- self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
- )
-
- os.makedirs(dev_output_path, exist_ok=True)
- initial_iter = self.fastforward_training(reranker, dev_output_path, None)
- logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"])
-
- strategy_scope = self.strategy.scope()
- with strategy_scope:
- train_records = self.get_tf_train_records(reranker, train_dataset)
- dev_records = self.get_tf_dev_records(reranker, dev_data)
- trec_callback = TrecCheckpointCallback(
- qrels,
- dev_data,
- dev_records,
- train_output_path,
- metric,
- self.config["validatefreq"],
- relevance_level=relevance_level,
- )
- tensorboard_callback = tf.keras.callbacks.TensorBoard(
- log_dir="{0}/capreolus_tensorboard/{1}".format(self.config["storage"], self.config["boardname"])
- )
- reranker.build_model() # TODO needed here?
-
- self.optimizer = self.get_optimizer()
- loss = tfr.keras.losses.get(self.config["loss"])
- reranker.model.compile(optimizer=self.optimizer, loss=loss)
-
- train_start_time = time.time()
- reranker.model.fit(
- train_records.prefetch(tf.data.experimental.AUTOTUNE),
- epochs=self.config["niters"],
- steps_per_epoch=self.config["itersize"],
- callbacks=[tensorboard_callback, trec_callback],
- workers=8,
- use_multiprocessing=True,
- )
- logger.info("Training took {}".format(time.time() - train_start_time))
-
- # Skipping dumping metrics and plotting loss since that should be done through tensorboard
-
- def create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc):
- """
- Creates a single tf.train.Feature instance (i.e, a single sample)
- """
- feature = {
- "qid": tf.train.Feature(bytes_list=tf.train.BytesList(value=[qid.encode("utf-8")])),
- "query": tf.train.Feature(float_list=tf.train.FloatList(value=query)),
- "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)),
- "posdoc_id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[posdoc_id.encode("utf-8")])),
- "posdoc": tf.train.Feature(float_list=tf.train.FloatList(value=posdoc)),
- }
-
- if negdoc_id:
- feature["negdoc_id"] = (tf.train.Feature(bytes_list=tf.train.BytesList(value=[negdoc_id.encode("utf-8")])),)
- feature["negdoc"] = tf.train.Feature(float_list=tf.train.FloatList(value=negdoc))
-
- return feature
-
- def write_tf_record_to_file(self, dir_name, tf_features):
- """
- Actually write the tf record to file. The destination can also be a gcs bucket.
- TODO: Use generators to optimize memory usage
- """
- filename = "{0}/{1}.tfrecord".format(dir_name, str(uuid.uuid4()))
- examples = [tf.train.Example(features=tf.train.Features(feature=feature)) for feature in tf_features]
-
- if not os.path.isdir(dir_name):
- os.makedirs(dir_name, exist_ok=True)
-
- examples = [example.SerializeToString() for example in examples]
- with tf.io.TFRecordWriter(filename) as writer:
- for example in examples:
- writer.write(example)
-
- logger.info("Wrote tf record file: {}".format(filename))
-
- return str(filename)
-
- def convert_to_tf_dev_record(self, reranker, dataset):
- """
- Similar to self.convert_to_tf_train_record(), but won't result in multiple files
- """
- dir_name = self.get_tf_record_cache_path(dataset)
-
- tf_features = [reranker.extractor.create_tf_feature(sample) for sample in dataset]
-
- return [self.write_tf_record_to_file(dir_name, tf_features)]
-
- def convert_to_tf_train_record(self, reranker, dataset):
- """
- Tensorflow works better if the input data is fed in as tfrecords
- Takes in a dataset, iterates through it, and creates multiple tf records from it.
- The exact structure of the tfrecords is defined by reranker.extractor. For example, see EmbedText.get_tf_feature()
- """
- dir_name = self.get_tf_record_cache_path(dataset)
-
- total_samples = dataset.get_total_samples()
- tf_features = []
- tf_record_filenames = []
-
- for niter in tqdm(range(0, self.config["niters"]), desc="Converting data to tf records"):
- for sample_idx, sample in enumerate(dataset):
- tf_features.append(reranker.extractor.create_tf_feature(sample))
-
- if len(tf_features) > 20000:
- tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features))
- tf_features = []
-
- if sample_idx + 1 >= self.config["itersize"] * self.config["batch"]:
- break
-
- if len(tf_features):
- tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features))
-
- return tf_record_filenames
-
- def get_tf_record_cache_path(self, dataset):
- """
- Get the path to the directory where tf records are written to.
- If using TPUs, this will be a gcs path.
- """
- if self.tpu:
- return "{0}/capreolus_tfrecords/{1}".format(self.config["storage"], dataset.get_hash())
- else:
- base_path = self.get_cache_path()
- return "{0}/{1}".format(base_path, dataset.get_hash())
-
- def cache_exists(self, dataset):
- # TODO: Add checks to make sure that the number of files in the directory is correct
- cache_dir = self.get_tf_record_cache_path(dataset)
- logger.info("The cache path is {0} and does it exist? : {1}".format(cache_dir, tf.io.gfile.exists(cache_dir)))
-
- return tf.io.gfile.isdir(cache_dir)
-
- def load_tf_records_from_file(self, reranker, filenames, batch_size):
- raw_dataset = tf.data.TFRecordDataset(filenames)
- tf_records_dataset = raw_dataset.batch(batch_size, drop_remainder=True).map(
- reranker.extractor.parse_tf_example, num_parallel_calls=tf.data.experimental.AUTOTUNE
- )
-
- return tf_records_dataset
-
- def load_cached_tf_records(self, reranker, dataset, batch_size):
- logger.info("Loading TF records from cache")
- cache_dir = self.get_tf_record_cache_path(dataset)
- filenames = tf.io.gfile.listdir(cache_dir)
- filenames = ["{0}/{1}".format(cache_dir, name) for name in filenames]
-
- return self.load_tf_records_from_file(reranker, filenames, batch_size)
-
- def get_tf_dev_records(self, reranker, dataset):
- """
- 1. Returns tf records from cache (disk) if applicable
- 2. Else, converts the dataset into tf records, writes them to disk, and returns them
- """
- if self.config["usecache"] and self.cache_exists(dataset):
- return self.load_cached_tf_records(reranker, dataset, 1)
- else:
- tf_record_filenames = self.convert_to_tf_dev_record(reranker, dataset)
- # TODO use actual batch size here. see issue #52
- return self.load_tf_records_from_file(reranker, tf_record_filenames, 1) # self.config["batch"])
-
- def get_tf_train_records(self, reranker, dataset):
- """
- 1. Returns tf records from cache (disk) if applicable
- 2. Else, converts the dataset into tf records, writes them to disk, and returns them
- """
-
- if self.config["usecache"] and self.cache_exists(dataset):
- return self.load_cached_tf_records(reranker, dataset, self.config["batch"])
- else:
- tf_record_filenames = self.convert_to_tf_train_record(reranker, dataset)
- return self.load_tf_records_from_file(reranker, tf_record_filenames, self.config["batch"])
-
- def predict(self, reranker, pred_data, pred_fn):
- """Predict query-document scores on `pred_data` using `model` and write a corresponding run file to `pred_fn`
-
- Args:
- model (Reranker): a PyTorch Reranker
- pred_data (IterableDataset): data to predict on
- pred_fn (Path): path to write the prediction run file to
-
- Returns:
- TREC Run
-
- """
-
- strategy_scope = self.strategy.scope()
- with strategy_scope:
- pred_records = self.get_tf_dev_records(reranker, pred_data)
- predictions = reranker.model.predict(pred_records)
- trec_preds = TrecCheckpointCallback.get_preds_in_trec_format(predictions, pred_data)
+from profane import import_all_modules
- os.makedirs(os.path.dirname(pred_fn), exist_ok=True)
- Searcher.write_trec_run(trec_preds, pred_fn)
+from .pytorch import PytorchTrainer
+from .tensorflow import TensorFlowTrainer
- return trec_preds
+import_all_modules(__file__, __package__)
diff --git a/capreolus/trainer/pytorch.py b/capreolus/trainer/pytorch.py
new file mode 100644
index 000000000..78938bace
--- /dev/null
+++ b/capreolus/trainer/pytorch.py
@@ -0,0 +1,329 @@
+import math
+import os
+import time
+
+import numpy as np
+import torch
+from torch.utils.tensorboard import SummaryWriter
+from tqdm import tqdm
+
+from . import Trainer
+from capreolus import ModuleBase, Dependency, ConfigOption, Searcher, constants, evaluator, get_logger
+from capreolus.reranker.common import pair_hinge_loss, pair_softmax_loss
+from capreolus.utils.common import plot_metrics, plot_loss
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+RESULTS_BASE_PATH = constants["RESULTS_BASE_PATH"]
+
+
+@Trainer.register
+class PytorchTrainer(Trainer):
+ module_name = "pytorch"
+ config_spec = [
+ ConfigOption("batch", 32, "batch size"),
+ ConfigOption("niters", 20, "number of iterations to train for"),
+ ConfigOption("itersize", 512, "number of training instances in one iteration"),
+ ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"),
+ ConfigOption("lr", 0.001, "learning rate"),
+ ConfigOption("softmaxloss", False, "True to use softmax loss (over pairs) or False to use hinge loss"),
+ ConfigOption("fastforward", False),
+ ConfigOption("validatefreq", 1),
+ ConfigOption("boardname", "default"),
+ ]
+ config_keys_not_in_path = ["fastforward", "boardname"]
+
+ def build(self):
+ # sanity checks
+ if self.config["batch"] < 1:
+ raise ValueError("batch must be >= 1")
+
+ if self.config["niters"] <= 0:
+ raise ValueError("niters must be > 0")
+
+ if self.config["itersize"] < self.config["batch"]:
+ raise ValueError("itersize must be >= batch")
+
+ if self.config["gradacc"] < 1 or not float(self.config["gradacc"]).is_integer():
+ raise ValueError("gradacc must be an integer >= 1")
+
+ if self.config["lr"] <= 0:
+ raise ValueError("lr must be > 0")
+
+ torch.manual_seed(self.config["seed"])
+ torch.cuda.manual_seed_all(self.config["seed"])
+
+ def single_train_iteration(self, reranker, train_dataloader):
+ """Train model for one iteration using instances from train_dataloader.
+
+ Args:
+ reranker (Reranker): a PyTorch Reranker
+ train_dataloader (DataLoader): a PyTorch DataLoader that iterates over training instances
+
+ Returns:
+ float: average loss over the iteration
+
+ """
+
+ iter_loss = []
+ batches_since_update = 0
+ batches_per_epoch = (self.config["itersize"] // self.config["batch"]) or 1
+ batches_per_step = self.config["gradacc"]
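+ # gradient accumulation: the optimizer steps once every batches_per_step (i.e., gradacc) batches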
+
+ for bi, batch in tqdm(enumerate(train_dataloader), desc="Iter progression"):
+ # TODO make sure _prepare_batch_with_strings equivalent is happening inside the sampler
+ batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()}
+ doc_scores = reranker.score(batch)
+ loss = self.loss(doc_scores)
+ iter_loss.append(loss)
+ loss.backward()
+
+ batches_since_update += 1
+ if batches_since_update == batches_per_step:
+ batches_since_update = 0
+ self.optimizer.step()
+ self.optimizer.zero_grad()
+
+ if (bi + 1) % batches_per_epoch == 0:
+ break
+
+ return torch.stack(iter_loss).mean()
+
+ def load_loss_file(self, fn):
+ """Loads loss history from fn
+
+ Args:
+ fn (Path): path to a loss.txt file
+
+ Returns:
+ a list of losses ordered by iterations
+
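+ Example:
+ a loss file containing the lines "0 1.5" and "1 1.2" yields [1.5, 1.2]
+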
+ """
+
+ loss = []
+ with fn.open(mode="rt") as f:
+ for lineidx, line in enumerate(f):
+ line = line.strip()
+ if not line:
+ continue
+
+ iteridx, iterloss = line.rstrip().split()
+
+ if int(iteridx) != lineidx:
+ raise IOError(f"malformed loss file {fn} ... did two processes write to it?")
+
+ loss.append(float(iterloss))
+
+ return loss
+
+ def fastforward_training(self, reranker, weights_path, loss_fn):
+ """Skip to the last training iteration whose weights were saved.
+
+ If saved model and optimizer weights are available, this method will load those weights into model
+ and optimizer, and then return the next iteration to be run. For example, if weights are available for
+ iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and
+ this method will return 11.
+
+ If an error or inconsistency is encountered when checking for weights, this method returns 0.
+
+ This method checks several files to determine if weights "are available". First, loss_fn is read to
+ determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.)
+ Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer.
+ If this is successful, the method returns `1 + last recorded iteration`. If not, it returns 0.
+ (We consider loss_fn because it is written at the end of every training iteration.)
+
+ Args:
+ reranker (Reranker): a PyTorch Reranker whose state should be loaded
+ weights_path (Path): directory containing model and optimizer weights
+ loss_fn (Path): file containing loss history
+
+ Returns:
+ int: the next training iteration after fastforwarding. If successful, this is > 0.
+ If no weights are available or they cannot be loaded, 0 is returned.
+
+ """
+
+ if not (weights_path.exists() and loss_fn.exists()):
+ return 0
+
+ try:
+ loss = self.load_loss_file(loss_fn)
+ except IOError:
+ return 0
+
+ last_loss_iteration = len(loss) - 1
+ weights_fn = weights_path / f"{last_loss_iteration}.p"
+
+ try:
+ reranker.load_weights(weights_fn, self.optimizer)
+ return last_loss_iteration + 1
+ except Exception:
+ logger.info("attempted to load weights from %s but failed, starting at iteration 0", weights_fn)
+ return 0
+
+ def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
+ """Train a model following the trainer's config (specifying batch size, number of iterations, etc).
+
+ Args:
+ reranker (Reranker): the Reranker to train
+ train_dataset (IterableDataset): training dataset
+ train_output_path (Path): directory under which train_dataset runs and training loss will be saved
+ dev_data (IterableDataset): dev dataset
+ dev_output_path (Path): directory where dev_data runs and metrics will be saved
+ qrels (dict): dev relevance judgments used to compute validation metrics
+ metric (str): metric to maximize on the dev set
+ relevance_level (int): relevance label threshold passed to the evaluator
+
+ """
+ # Set up logging
+ # TODO why not put this under train_output_path?
+ summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" / self.config["boardname"], comment=train_output_path)
+
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ model = reranker.model.to(self.device)
+ self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.config["lr"])
+
+ if self.config["softmaxloss"]:
+ self.loss = pair_softmax_loss
+ else:
+ self.loss = pair_hinge_loss
+
+ dev_best_weight_fn, weights_output_path, info_output_path, loss_fn = self.get_paths_for_early_stopping(
+ train_output_path, dev_output_path
+ )
+
+ initial_iter = self.fastforward_training(reranker, weights_output_path, loss_fn) if self.config["fastforward"] else 0
+ logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"])
+
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=self.config["batch"], pin_memory=True, num_workers=0
+ )
+ # dataiter = iter(train_dataloader)
+ # sample_input = dataiter.next()
+ # summary_writer.add_graph(
+ # reranker.model,
+ # [
+ # sample_input["query"].to(self.device),
+ # sample_input["posdoc"].to(self.device),
+ # sample_input["negdoc"].to(self.device),
+ # ],
+ # )
+
+ train_loss = []
+ # are we resuming training?
+ if initial_iter > 0:
+ train_loss = self.load_loss_file(loss_fn)
+
+ # are we done training?
+ if initial_iter < self.config["niters"]:
+ logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter)
+ batches_per_epoch = self.config["itersize"] // self.config["batch"]
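+ # consume the batches used by earlier iterations so the resumed run sees the same data order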
+ for niter in range(initial_iter):
+ for bi, batch in enumerate(train_dataloader):
+ if (bi + 1) % batches_per_epoch == 0:
+ break
+
+ dev_best_metric = -np.inf
+ validation_frequency = self.config["validatefreq"]
+ train_start_time = time.time()
+ for niter in range(initial_iter, self.config["niters"]):
+ model.train()
+
+ iter_start_time = time.time()
+ iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader)
+ logger.info("A single iteration takes {}".format(time.time() - iter_start_time))
+ train_loss.append(iter_loss_tensor.item())
+ logger.info("iter = %d loss = %f", niter, train_loss[-1])
+
+ # write model weights to file
+ weights_fn = weights_output_path / f"{niter}.p"
+ reranker.save_weights(weights_fn, self.optimizer)
+ # predict performance on dev set
+
+ if niter % validation_frequency == 0:
+ pred_fn = dev_output_path / f"{niter}.run"
+ preds = self.predict(reranker, dev_data, pred_fn)
+
+ # log dev metrics
+ metrics = evaluator.eval_runs(preds, qrels, evaluator.DEFAULT_METRICS, relevance_level)
+ logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
+ summary_writer.add_scalar("ndcg_cut_20", metrics["ndcg_cut_20"], niter)
+ summary_writer.add_scalar("map", metrics["map"], niter)
+ summary_writer.add_scalar("P_20", metrics["P_20"], niter)
+ # write best dev weights to file
+ if metrics[metric] > dev_best_metric:
+ dev_best_metric = metrics[metric]
+ reranker.save_weights(dev_best_weight_fn, self.optimizer)
+
+ # write train_loss to file
+ loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))
+
+ summary_writer.add_scalar("training_loss", iter_loss_tensor.item(), niter)
+ reranker.add_summary(summary_writer, niter)
+ summary_writer.flush()
+ logger.info("training loss: %s", train_loss)
+ logger.info("Training took {}".format(time.time() - train_start_time))
+ summary_writer.close()
+
+ # TODO should we write a /done so that training can be skipped if possible when fastforward=False? or in Task?
+
+ def load_best_model(self, reranker, train_output_path):
+ self.optimizer = torch.optim.Adam(
+ filter(lambda param: param.requires_grad, reranker.model.parameters()), lr=self.config["lr"]
+ )
+
+ dev_best_weight_fn = train_output_path / "dev.best"
+ reranker.load_weights(dev_best_weight_fn, self.optimizer)
+
+ def predict(self, reranker, pred_data, pred_fn):
+ """Predict query-document scores on `pred_data` using `model` and write a corresponding run file to `pred_fn`
+
+ Args:
+ model (Reranker): a PyTorch Reranker
+ pred_data (IterableDataset): data to predict on
+ pred_fn (Path): path to write the prediction run file to
+
+ Returns:
+ TREC Run
+
+ """
+
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # save to pred_fn
+ model = reranker.model.to(self.device)
+ model.eval()
+
+ preds = {}
+ pred_dataloader = torch.utils.data.DataLoader(pred_data, batch_size=self.config["batch"], pin_memory=True, num_workers=0)
+ with torch.autograd.no_grad():
+ for batch in tqdm(pred_dataloader, desc="Predicting on dev"):
+ if len(batch["qid"]) != self.config["batch"]:
+ batch = self.fill_incomplete_batch(batch)
+
+ batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()}
+ scores = reranker.test(batch)
+ scores = scores.view(-1).cpu().numpy()
+ for qid, docid, score in zip(batch["qid"], batch["posdocid"], scores):
+ # Need to use float16 because pytrec_eval's c function call crashes with higher precision floats
+ preds.setdefault(qid, {})[docid] = score.astype(np.float16).item()
+
+ os.makedirs(os.path.dirname(pred_fn), exist_ok=True)
+ Searcher.write_trec_run(preds, pred_fn)
+
+ return preds
+
+ def fill_incomplete_batch(self, batch):
+ """
+ If a batch is incomplete (i.e., shorter than the desired batch size), this method pads it to the full batch size.
+ How the padding is chosen:
+ If a value is a plain list, the batch is padded with copies of the list's first element
+ If a value is a tensor/numpy array, it is repeated along the batch dimension with repeat()
+ """
+ # logger.debug("filling in an incomplete batch")
+ repeat_times = math.ceil(self.config["batch"] / len(batch["qid"]))
+ diff = self.config["batch"] - len(batch["qid"])
+
+ def pad(v):
+ if isinstance(v, np.ndarray) or torch.is_tensor(v):
+ _v = v.repeat((repeat_times,) + tuple([1 for x in range(len(v.shape) - 1)]))
+ else:
+ _v = v + [v[0]] * diff
+
+ return _v[: self.config["batch"]]
+
+ batch = {k: pad(v) for k, v in batch.items()}
+ return batch
diff --git a/capreolus/trainer/tensorflow.py b/capreolus/trainer/tensorflow.py
new file mode 100644
index 000000000..301e83c89
--- /dev/null
+++ b/capreolus/trainer/tensorflow.py
@@ -0,0 +1,352 @@
+import hashlib
+import os
+import time
+import uuid
+from collections import defaultdict
+
+import tensorflow as tf
+import tensorflow_ranking as tfr
+import numpy as np
+from tqdm import tqdm
+
+from . import Trainer
+from capreolus import ModuleBase, Dependency, ConfigOption, Searcher, constants, evaluator, get_logger
+from capreolus.utils.common import plot_metrics, plot_loss
+
+logger = get_logger(__name__) # pylint: disable=invalid-name
+RESULTS_BASE_PATH = constants["RESULTS_BASE_PATH"]
+
+
+class TrecCheckpointCallback(tf.keras.callbacks.Callback):
+ """
+ A callback that runs after every epoch and calculates pytrec_eval-style metrics for the dev dataset.
+ See TensorFlowTrainer.train() for the invocation.
+ Also saves the best model to disk.
+ """
+
+ def __init__(self, qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs):
+ """
+ qrels - a qrels dict
+ dev_data - a torch.utils.IterableDataset
+ dev_records - a BatchedDataset instance
+ """
+ super(TrecCheckpointCallback, self).__init__(*args, **kwargs)
+ self.best_metric = -np.inf
+ self.qrels = qrels
+ self.dev_data = dev_data
+ self.dev_records = dev_records
+ self.output_path = output_path
+ self.iter_start_time = time.time()
+ self.metric = metric
+ self.validate_freq = validate_freq
+ self.relevance_level = relevance_level
+
+ def save_model(self):
+ self.model.save_weights("{0}/dev.best".format(self.output_path))
+
+ def on_epoch_begin(self, epoch, logs=None):
+ self.iter_start_time = time.time()
+
+ def on_epoch_end(self, epoch, logs=None):
+ logger.debug("Epoch {} took {}".format(epoch, time.time() - self.iter_start_time))
+ if (epoch + 1) % self.validate_freq == 0:
+ predictions = self.model.predict(self.dev_records, verbose=1, workers=8, use_multiprocessing=True)
+ trec_preds = self.get_preds_in_trec_format(predictions, self.dev_data)
+ metrics = evaluator.eval_runs(trec_preds, dict(self.qrels), evaluator.DEFAULT_METRICS, self.relevance_level)
+ logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
+
+ if metrics[self.metric] > self.best_metric:
+ self.best_metric = metrics[self.metric]
+ # TODO: Prevent the embedding layer weights from being saved
+ self.save_model()
+
+ @staticmethod
+ def get_preds_in_trec_format(predictions, dev_data):
+ """
+ Takes in a list of predictions and returns a qid -> {docid: score} dict that can be fed into pytrec_eval
+ """
+ pred_dict = defaultdict(lambda: dict())
+
+ for i, (qid, docid) in enumerate(dev_data.get_qid_docid_pairs()):
+ # Pytrec_eval has problems with high precision floats
+ pred_dict[qid][docid] = predictions[i][0].astype(np.float16).item()
+
+ return dict(pred_dict)
+
+
+@Trainer.register
+class TensorFlowTrainer(Trainer):
+ module_name = "tensorflow"
+
+ config_spec = [
+ ConfigOption("batch", 32, "batch size"),
+ ConfigOption("niters", 20, "number of iterations to train for"),
+ ConfigOption("itersize", 512, "number of training instances in one iteration"),
+ # ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"),
+ ConfigOption("lr", 0.001, "learning rate"),
+ ConfigOption("loss", "pairwise_hinge_loss", "must be one of tfr.losses.RankingLossKey"),
+ # ConfigOption("fastforward", False),
+ ConfigOption("validatefreq", 1),
+ ConfigOption("boardname", "default"),
+ ConfigOption("usecache", False),
+ ConfigOption("tpuname", None),
+ ConfigOption("tpuzone", None),
+ ConfigOption("storage", None),
+ ]
+ config_keys_not_in_path = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"]
+
+ def build(self):
+ tf.random.set_seed(self.config["seed"])
+
+ # Use TPU if available, otherwise resort to GPU/CPU
+ try:
+ self.tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=self.config["tpuname"], zone=self.config["tpuzone"])
+ except ValueError:
+ self.tpu = None
+ logger.info("Could not find the tpu")
+
+ # TPUStrategy for distributed training
+ if self.tpu:
+ logger.info("Utilizing TPUs")
+ tf.config.experimental_connect_to_cluster(self.tpu)
+ tf.tpu.experimental.initialize_tpu_system(self.tpu)
+ self.strategy = tf.distribute.experimental.TPUStrategy(self.tpu)
+ else: # default strategy that works on CPU and single GPU
+ self.strategy = tf.distribute.get_strategy()
+
+ # Defining some props that we will later initialize
+ self.optimizer = None
+ self.loss = None
+ self.validate()
+
+ def validate(self):
+ if self.tpu and any([self.config["storage"] is None, self.config["tpuname"] is None, self.config["tpuzone"] is None]):
+ raise ValueError("storage, tpuname and tpuzone configs must be provided when training on TPU")
+ if self.tpu and self.config["storage"] and not self.config["storage"].startswith("gs://"):
+ raise ValueError("For TPU utilization, the storage config should start with 'gs://'")
+
+ def get_optimizer(self):
+ return tf.keras.optimizers.Adam(learning_rate=self.config["lr"])
+
+ def fastforward_training(self, reranker, weights_path, loss_fn):
+ # TODO: Fix fast forwarding
+ return 0
+
+ def load_best_model(self, reranker, train_output_path):
+ # TODO: Do the train_output_path modification at one place?
+ if self.tpu:
+ train_output_path = "{0}/{1}/{2}".format(
+ self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
+ )
+
+ reranker.model.load_weights("{0}/dev.best".format(train_output_path))
+
+ def apply_gradients(self, weights, grads):
+ self.optimizer.apply_gradients(zip(grads, weights))
+
+ def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
+ # summary_writer = tf.summary.create_file_writer("{0}/capreolus_tensorboard/{1}".format(self.config["storage"], self.config["boardname"]))
+
+ # Because TPUs can't work with local files
+ if self.tpu:
+ train_output_path = "{0}/{1}/{2}".format(
+ self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
+ )
+
+ os.makedirs(dev_output_path, exist_ok=True)
+ initial_iter = self.fastforward_training(reranker, dev_output_path, None)
+ logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"])
+
+ strategy_scope = self.strategy.scope()
+ with strategy_scope:
+ train_records = self.get_tf_train_records(reranker, train_dataset)
+ dev_records = self.get_tf_dev_records(reranker, dev_data)
+ trec_callback = TrecCheckpointCallback(
+ qrels,
+ dev_data,
+ dev_records,
+ train_output_path,
+ metric,
+ self.config["validatefreq"],
+ relevance_level=relevance_level,
+ )
+ # self.config["storage"] is only set when using TPUs; write TensorBoard logs locally otherwise
+ log_dir = "{0}/capreolus_tensorboard/{1}".format(self.config["storage"] if self.tpu else RESULTS_BASE_PATH, self.config["boardname"])
+ tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)
+ reranker.build_model() # TODO needed here?
+
+ self.optimizer = self.get_optimizer()
+ loss = tfr.keras.losses.get(self.config["loss"])
+ reranker.model.compile(optimizer=self.optimizer, loss=loss)
+
+ train_start_time = time.time()
+ reranker.model.fit(
+ train_records.prefetch(tf.data.experimental.AUTOTUNE),
+ epochs=self.config["niters"],
+ steps_per_epoch=self.config["itersize"],
+ callbacks=[tensorboard_callback, trec_callback],
+ workers=8,
+ use_multiprocessing=True,
+ )
+ logger.info("Training took {}".format(time.time() - train_start_time))
+
+ # Skipping dumping metrics and plotting loss since that should be done through tensorboard
+
+ def create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc):
+ """
+ Creates a feature dict (i.e., a single sample) that can be wrapped in a tf.train.Example
+ """
+ feature = {
+ "qid": tf.train.Feature(bytes_list=tf.train.BytesList(value=[qid.encode("utf-8")])),
+ "query": tf.train.Feature(float_list=tf.train.FloatList(value=query)),
+ "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)),
+ "posdoc_id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[posdoc_id.encode("utf-8")])),
+ "posdoc": tf.train.Feature(float_list=tf.train.FloatList(value=posdoc)),
+ }
+
+ if negdoc_id:
+ feature["negdoc_id"] = (tf.train.Feature(bytes_list=tf.train.BytesList(value=[negdoc_id.encode("utf-8")])),)
+ feature["negdoc"] = tf.train.Feature(float_list=tf.train.FloatList(value=negdoc))
+
+ return feature
+
+ def write_tf_record_to_file(self, dir_name, tf_features):
+ """
+ Write the given features to a tf record file. The destination can also be a GCS bucket.
+ TODO: Use generators to optimize memory usage
+ """
+ filename = "{0}/{1}.tfrecord".format(dir_name, str(uuid.uuid4()))
+ examples = [tf.train.Example(features=tf.train.Features(feature=feature)) for feature in tf_features]
+
+ if not os.path.isdir(dir_name):
+ os.makedirs(dir_name, exist_ok=True)
+
+ examples = [example.SerializeToString() for example in examples]
+ with tf.io.TFRecordWriter(filename) as writer:
+ for example in examples:
+ writer.write(example)
+
+ logger.info("Wrote tf record file: {}".format(filename))
+
+ return str(filename)
+
+ def convert_to_tf_dev_record(self, reranker, dataset):
+ """
+ Similar to self.convert_to_tf_train_record(), but writes a single file rather than several
+ """
+ dir_name = self.get_tf_record_cache_path(dataset)
+
+ tf_features = [reranker.extractor.create_tf_feature(sample) for sample in dataset]
+
+ return [self.write_tf_record_to_file(dir_name, tf_features)]
+
+ def convert_to_tf_train_record(self, reranker, dataset):
+ """
+ TensorFlow performs better when the input data is fed in as tfrecords.
+ Takes in a dataset, iterates through it, and creates multiple tf record files from it.
+ The exact structure of the tfrecords is defined by reranker.extractor. For example, see EmbedText.create_tf_feature()
+ """
+ dir_name = self.get_tf_record_cache_path(dataset)
+
+ tf_features = []
+ tf_record_filenames = []
+
+ for niter in tqdm(range(0, self.config["niters"]), desc="Converting data to tf records"):
+ for sample_idx, sample in enumerate(dataset):
+ tf_features.append(reranker.extractor.create_tf_feature(sample))
+
+ if len(tf_features) > 20000:
+ tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features))
+ tf_features = []
+
+ if sample_idx + 1 >= self.config["itersize"] * self.config["batch"]:
+ break
+
+ if len(tf_features):
+ tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features))
+
+ return tf_record_filenames
+
+ def get_tf_record_cache_path(self, dataset):
+ """
+ Get the path to the directory where tf records are written to.
+ If using TPUs, this will be a gcs path.
+ """
+ if self.tpu:
+ return "{0}/capreolus_tfrecords/{1}".format(self.config["storage"], dataset.get_hash())
+ else:
+ base_path = self.get_cache_path()
+ return "{0}/{1}".format(base_path, dataset.get_hash())
+
+ def cache_exists(self, dataset):
+ # TODO: Add checks to make sure that the number of files in the directory is correct
+ cache_dir = self.get_tf_record_cache_path(dataset)
+ logger.info("The cache path is {0} and does it exist? : {1}".format(cache_dir, tf.io.gfile.exists(cache_dir)))
+
+ return tf.io.gfile.isdir(cache_dir)
+
+ def load_tf_records_from_file(self, reranker, filenames, batch_size):
+ raw_dataset = tf.data.TFRecordDataset(filenames)
+ tf_records_dataset = raw_dataset.batch(batch_size, drop_remainder=True).map(
+ reranker.extractor.parse_tf_example, num_parallel_calls=tf.data.experimental.AUTOTUNE
+ )
+
+ return tf_records_dataset
+
+ def load_cached_tf_records(self, reranker, dataset, batch_size):
+ logger.info("Loading TF records from cache")
+ cache_dir = self.get_tf_record_cache_path(dataset)
+ filenames = tf.io.gfile.listdir(cache_dir)
+ filenames = ["{0}/{1}".format(cache_dir, name) for name in filenames]
+
+ return self.load_tf_records_from_file(reranker, filenames, batch_size)
+
+ def get_tf_dev_records(self, reranker, dataset):
+ """
+ 1. Returns tf records from cache (disk) if applicable
+ 2. Else, converts the dataset into tf records, writes them to disk, and returns them
+ """
+ if self.config["usecache"] and self.cache_exists(dataset):
+ return self.load_cached_tf_records(reranker, dataset, 1)
+ else:
+ tf_record_filenames = self.convert_to_tf_dev_record(reranker, dataset)
+ # TODO use actual batch size here. see issue #52
+ return self.load_tf_records_from_file(reranker, tf_record_filenames, 1) # self.config["batch"])
+
+ def get_tf_train_records(self, reranker, dataset):
+ """
+ 1. Returns tf records from cache (disk) if applicable
+ 2. Else, converts the dataset into tf records, writes them to disk, and returns them
+ """
+
+ if self.config["usecache"] and self.cache_exists(dataset):
+ return self.load_cached_tf_records(reranker, dataset, self.config["batch"])
+ else:
+ tf_record_filenames = self.convert_to_tf_train_record(reranker, dataset)
+ return self.load_tf_records_from_file(reranker, tf_record_filenames, self.config["batch"])
+
+ def predict(self, reranker, pred_data, pred_fn):
+ """Predict query-document scores on `pred_data` using `model` and write a corresponding run file to `pred_fn`
+
+ Args:
+ model (Reranker): a PyTorch Reranker
+ pred_data (IterableDataset): data to predict on
+ pred_fn (Path): path to write the prediction run file to
+
+ Returns:
+ TREC Run
+
+ """
+
+ strategy_scope = self.strategy.scope()
+ with strategy_scope:
+ pred_records = self.get_tf_dev_records(reranker, pred_data)
+ predictions = reranker.model.predict(pred_records)
+ trec_preds = TrecCheckpointCallback.get_preds_in_trec_format(predictions, pred_data)
+
+ os.makedirs(os.path.dirname(pred_fn), exist_ok=True)
+ Searcher.write_trec_run(trec_preds, pred_fn)
+
+ return trec_preds
diff --git a/capreolus/trainer/test_trainer.py b/capreolus/trainer/tests/test_trainer.py
similarity index 97%
rename from capreolus/trainer/test_trainer.py
rename to capreolus/trainer/tests/test_trainer.py
index cea4f3443..e22723de0 100644
--- a/capreolus/trainer/test_trainer.py
+++ b/capreolus/trainer/tests/test_trainer.py
@@ -3,7 +3,7 @@
import os
import tensorflow as tf
from capreolus.benchmark import DummyBenchmark
-from capreolus.extractor import EmbedText
+from capreolus.extractor.embedtext import EmbedText
from capreolus.sampler import TrainDataset
from capreolus.trainer import TensorFlowTrainer
diff --git a/docs/cli.md b/docs/cli.md
index d3df61d53..65262cc99 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -1,4 +1,4 @@
-# Command Line Interface
+# Running Pipelines with the CLI
Capreolus provides a command line interface for running experiments using pipelines that are described by `Task` modules. To create a new pipeline, you'll need to create a new `Task` before using the CLI.
Capreolus takes a functional approach to describing an experiment. An experiment is simply a pipeline plus a set of configuration options that specify both the classes to use for the pipeline's modules and the configuration options associated with each module.
@@ -22,6 +22,7 @@ module type=benchmark
name=antique
name=dummy
name=robust04.yang19
+ name=nf
...
module type=reranker
name=CDSSM
@@ -37,30 +38,33 @@ module type=reranker
.. note:: Results and cached objects are stored in ``~/.capreolus/results/`` and ``~/.capreolus/cache/`` by default. Set the ``CAPREOLUS_RESULTS`` and ``CAPREOLUS_CACHE`` environment variables to change these locations. For example: ``export CAPREOLUS_CACHE=/data/capreolus/cache``
```
-- Use `RankTask` to search for the *robust04* topics in a robust04 index (which will be downloaded if it does not automatically exist), and then evaluate the results. The `Benchmark` specifies a dependency on `collection.name=robust04` and provides the corresponding topics and relevance judgments.
+### RankTask
+- Use `RankTask` with the [NFCorpus](https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/) `Benchmark`. This will download the collection, create an index, and search for the NFCorpus topics. The `Benchmark` specifies a dependency on `collection.name=nf` and provides the corresponding topics and relevance judgments.
```
$ capreolus rank.searcheval with searcher.name=BM25 \
- searcher.index.stemmer=porter benchmark.name=robust04.yang19
+ searcher.index.stemmer=porter benchmark.name=nf
```
-- Use a similar pipeline, but with RM3 query expansion and a small grid search over expansion parameters. The evaluation command will report cross-validated results using the folds specified by `robust04.yang19`.
+- Use a similar pipeline, but with RM3 query expansion and a small grid search over expansion parameters. The evaluation command will report cross-validated results using the folds specified by the `Benchmark`.
```
$ capreolus rank.searcheval with \
- searcher.index.stemmer=porter benchmark.name=robust04.yang19 \
+ searcher.index.stemmer=porter benchmark.name=nf \
searcher.name=BM25RM3 searcher.fbDocs=5-10-15 searcher.fbTerms=5-25-50
```
+### RerankTask
- Use `RerankTask` to run the same `RankTask` pipeline optimized for recall@1000, and then train a `Reranker` optimized for P@20 on the first fold provided by the `Benchmark`. We limit training to two iterations (`niters`) of size `itersize` to keep the training process from taking too long.
```
$ capreolus rerank.traineval with \
- rank.searcher.index.stemmer=porter benchmark.name=robust04.yang19 \
+ rank.searcher.index.stemmer=porter benchmark.name=nf \
rank.searcher.name=BM25RM3 rank.searcher.fbDocs=5-10-15 rank.searcher.fbTerms=5-25-50 \
rank.optimize=recall_1000 reranker.name=KNRM reranker.trainer.niters=2 optimize=P_20
```
-- The `ReRerankTask` demonstrates pipeline flexibility by adding a second reranking step on top of the output from `RerankTask`. Run `capreolus rererank.traineval` to see the configuration options it expects. *(Hint: it consists of a `RankTask` name `rank` as before, followed by a `RerankTask` named `rerank1`, followed by another `RerankTask` named `rerank2`.)*
+### ReRerankTask
+- The `ReRerankTask` demonstrates pipeline flexibility by adding a second reranking step on top of the output from `RerankTask`. Run `capreolus rererank.print_config` to see the configuration options it expects. *(Hint: it consists of a `RankTask` named `rank` as before, followed by a `RerankTask` named `rerank1`, followed by another `RerankTask` named `rerank2`.)*
diff --git a/docs/conf.py b/docs/conf.py
index 97db7c46d..1e2707f0f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -37,7 +37,7 @@ def get_version(rel_path):
# -- Project information -----------------------------------------------------
project = "Capreolus"
-copyright = "2020"
+copyright = "2020 Andrew Yates"
author = "Andrew Yates"
# The full version, including alpha/beta/rc tags
@@ -189,6 +189,7 @@ def get_version(rel_path):
napoleon_google_docstring = True
+# autoapi_keep_files = True
autoapi_type = "python"
autoapi_dirs = ["../capreolus"]
autoapi_ignore = ["*tests/*", "flycheck_*"]
diff --git a/docs/index.rst b/docs/index.rst
index 315780ebd..8e0b28a00 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -17,6 +17,7 @@ Looking for the code? `Find Capreolus on GitHub. `_
diff --git a/docs/modules.md b/docs/modules.md
new file mode 100644
index 000000000..3a03819b6
--- /dev/null
+++ b/docs/modules.md
@@ -0,0 +1,329 @@
+# Available Modules
+
+The `Benchmark`, `Reranker`, and `Searcher` module types are most often configured by the end user.
+For a complete list of modules, run the command `capreolus modules` or see the API Reference.
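+
+For example, `capreolus modules` lists each module type along with its registered module names (abbreviated here):
+
+```
+module type=benchmark
+   name=antique
+   name=dummy
+   name=nf
+...
+```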
+
+```eval_rst
+.. important:: When using Capreolus' configuration system, modules are selected by specifying their ``module_name``.
+ For example, the ``NF`` benchmark can be selected with the ``benchmark.name=nf`` config string or the equivalent config dictionary ``{"benchmark": {"name": "nf"}}``.
+
+ The corresponding class can be created as ``benchmark.nf.NF(config=..., provide=...)`` or created by name with ``Benchmark.create("nf", config=..., provide=...)``.
+```
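+
+For example, a minimal Python sketch (using the same ``nf`` benchmark and attributes shown in the quick start):
+
+```python
+from capreolus import Benchmark
+
+# create the NF benchmark by name; unprovided dependencies (e.g., its collection) are created automatically
+benchmark = Benchmark.create("nf")
+print(benchmark.topics["title"]["56"])  # 'foods for glaucoma'
+```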
+
+## Benchmarks
+
+### ANTIQUE
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.antique.ANTIQUE
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+
+### CodeSearchNet
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.codesearchnet.CodeSearchNetCorpus
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.codesearchnet.CodeSearchNetChallenge
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### (TREC) COVID
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.covid.COVID
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+
+### Dummy
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.dummy.DummyBenchmark
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### NF Corpus
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.nf.NF
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### (TREC) Robust04
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.robust04.Robust04
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.robust04.Robust04Yang19
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+## Searchers
+
+```eval_rst
+.. note:: Some searchers (e.g., BM25) automatically perform a cross-validated grid search when their parameters are provided as lists. For example, ``searcher.b=0.4,0.6,0.8 searcher.k1=1.0,1.5``.
+```
+
+### BM25
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.BM25
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### BM25 with Axiomatic expansion
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.AxiomaticSemanticMatching
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### BM25 with RM3 expansion
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.BM25RM3
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+
+### BM25 PRF
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.BM25PRF
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+
+### F2Exp
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.F2Exp
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### F2Log
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.F2Log
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### I(n)L2
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.INL2
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### QL with Dirichlet smoothing
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.DirichletQL
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### QL with J-M smoothing
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.QLJM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### SDM
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.SDM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### SPL
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.SPL
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+## Rerankers
+
+```eval_rst
+.. note:: Rerankers are implemented in PyTorch or TensorFlow. Rerankers with TensorFlow implementations can run on both GPUs and TPUs.
+```
+
+
+### CDSSM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.CDSSM.CDSSM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### ConvKNRM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.ConvKNRM.ConvKNRM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### DRMM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DRMM.DRMM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### DRMMTKS
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DRMMTKS.DRMMTKS
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### DSSM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DSSM.DSSM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### DUET
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DUET.DUET
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### DeepTileBars
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DeepTileBar.DeepTileBar
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### HiNT
+```eval_rst
+.. autoapiclass:: capreolus.reranker.HINT.HINT
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### KNRM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.KNRM.KNRM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### PACRR
+```eval_rst
+.. autoapiclass:: capreolus.reranker.PACRR.PACRR
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### POSITDRMM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.POSITDRMM.POSITDRMM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### TK
+```eval_rst
+.. autoapiclass:: capreolus.reranker.TK.TK
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### TensorFlow KNRM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.TFKNRM.TFKNRM
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
+### TensorFlow VanillaBERT
+```eval_rst
+.. autoapiclass:: capreolus.reranker.TFVanillaBert.TFVanillaBERT
+ :noindex:
+
+ .. autoapiattribute:: module_name
+ :noindex:
+```
+
diff --git a/docs/quick.md b/docs/quick.md
index 793f3d3ca..5a6fa1081 100644
--- a/docs/quick.md
+++ b/docs/quick.md
@@ -9,27 +9,44 @@
## Command Line Interface
-Use the `RankTask` pipeline to rank documents using a `Searcher` on an [Anserini](https://anserini.io) `Index` built on robust04. (The index will be automatically downloaded if `benchmark.collection.path` is invalid.)
+Use the `RankTask` pipeline to rank documents using a `Searcher` on an [Anserini](https://anserini.io) `Index` built on [NFCorpus](https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/), which contains biomedical documents and queries. NFCorpus was published by Boteva et al. in ECIR 2016. This dataset is publicly available and will be automatically downloaded by Capreolus.
+
+```bash
+$ capreolus rank.searcheval with benchmark.name=nf \
+ searcher.name=BM25 searcher.index.stemmer=porter searcher.b=0.8
```
-$ capreolus rank.searcheval with searcher.name=BM25 \
- searcher.index.stemmer=porter searcher.b=0.8 \
- benchmark.name=robust04.yang19 benchmark.collection.path=/path/to/trec45
+
+The `searcheval` command instructs `RankTask` to query NFCorpus and evaluate the Searcher's performance on NFCorpus' test queries. The command will output results like this:
+```bash
+INFO - capreolus.task.rank.evaluate - rank: fold=s1 best run: ...searcher-BM25_b-0.8_fields-title_hits-1000_k1-0.9/task-rank_filter-False/searcher
+INFO - capreolus.task.rank.evaluate - rank: cross-validated results when optimizing for 'map':
+INFO - capreolus.task.rank.evaluate - map: 0.1361
+INFO - capreolus.task.rank.evaluate - ndcg_cut_10: 0.2906
+...
+```
+
+These results are comparable with the *all titles* results in the [NFCorpus paper](https://www.cl.uni-heidelberg.de/~riezler/publications/papers/ECIR2016.pdf), which reports a MAP of 0.1251 for BM25 (Table 2). The Benchmark's ``fields`` config option can be used to issue other types of queries as well (e.g., ``benchmark.fields=all_fields``).
+
+```eval_rst
+.. important:: Capreolus Benchmarks define *folds* to use; each fold specifies training, dev (validation), and test queries.
+ Tasks respect these folds when calculating metrics.
+ NFCorpus defines a fixed test set, which corresponds to having a *single fold* in Capreolus.
+ When running a benchmark that uses multiple folds with cross-validation, like *robust04*, the results reported are averaged over the benchmark's test sets.
```
## Python API
Let's run the same pipeline using the Python API:
```python
-from capreolus.task.rank import RankTask
+from capreolus.task import RankTask
task = RankTask({'searcher': {'name': 'BM25', 'index': {'stemmer': 'porter'}, 'b': '0.8'},
- 'benchmark': {'name': 'robust04.yang19',
- 'collection': {'path': '/path/to/trec45'}}})
+ 'benchmark': {'name': 'nf'}})
task.searcheval()
```
```eval_rst
-.. note:: The ``capreolus.parse_config_string`` convenience method can transform a config string like ``searcher.name=BM25 benchmark.name=robust04.yang`` into a config dict as shown above.
+.. note:: The ``capreolus.parse_config_string`` convenience method can transform a config string like ``searcher.name=BM25 benchmark.name=nf`` into a config dict as shown above.
```
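+
+For example, a short sketch of the conversion (key order follows the input string):
+
+```python
+>>> from capreolus import parse_config_string
+>>> parse_config_string("searcher.name=BM25 benchmark.name=nf")
+{'searcher': {'name': 'BM25'}, 'benchmark': {'name': 'nf'}}
+```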
@@ -43,34 +60,38 @@ Capreolus pipelines are composed of self-contained modules corresponding to "IR
RankTask declares dependencies on a Searcher module and a Benchmark module, which it uses to query a document collection and to obtain experimental data (i.e., topics, relevance judgments, and folds), respectively. The Searcher depends on an Index. Both the Index and Benchmark depend on a Collection. In this example, RankTask requires that the same Collection be provided to both.
+```python
+from capreolus import Benchmark, Collection, Index, Searcher
+```
+
Let's construct this graph one module at a time.
```python
-# Previously, the Benchmark specified a dependency on the 'robust04' collection specifically.
-# Now we specify "robust04" ourselves.
->>> collection = Collection.create("robust04", config={'path': '/path/to/trec45'})
+# Previously, the Benchmark specified a dependency on the 'nf' collection specifically.
+# Now we create this Collection directly.
+>>> collection = Collection.create("nf")
>>> collection.get_path_and_types()
- ("/path/to/trec45", "TrecCollection", "DefaultLuceneDocumentGenerator")
+ ("/path/to/collection-nf/documents", "TrecCollection", "DefaultLuceneDocumentGenerator")
# Next, create a Benchmark and pass it the collection object directly.
# This is an alternative to automatically creating the collection as a dependency.
->>> benchmark = Benchmark.create("robust04.yang19", provide={'collection': collection})
->>> benchmark.topics
- {'title': {'301': 'International Organized Crime', '302': 'Poliomyelitis and Post-Polio', ... }
+>>> benchmark = Benchmark.create("nf", provide={'collection': collection})
+>>> benchmark.topics["title"]
+ {'56': 'foods for glaucoma', '68': 'what is actually in chicken nuggets', ... }
```
Next, we can build `Index` and `Searcher`. These module types do more than just point to data.
```python
>>> index = Index.create("anserini", {"stemmer": "porter"}, provide={"collection": collection})
>>> index.create_index() # returns immediately if the index already exists
->>> index.get_df("organized")
+>>> index.get_df("foods")
0
->>> index.get_df("organiz")
-3048
+>>> index.get_df("food")
+1011
# Next, a Searcher to query the index
>>> searcher = Searcher.create("BM25", {"hits": 3}, provide={"index": index})
->>> searcher.query("organized")
-OrderedDict([('FBIS4-2046', 4.867800235748291),
- ('FBIS3-2553', 4.822000026702881),
- ('FBIS3-23578', 4.754199981689453)])
+>>> searcher.query("foods")
+OrderedDict([('MED-1761', 1.213),
+ ('MED-2742', 1.212),
+ ('MED-1046', 1.2058)])
```
Finally, we can emulate the `RankTask.search()` method we called earlier:
@@ -82,6 +103,7 @@ Finally, we can emulate the `RankTask.search()` method we called earlier:
To get metrics, we could then pass `results` to `capreolus.evaluator.eval_runs()`:
```eval_rst
.. autoapifunction:: capreolus.evaluator.eval_runs
+ :noindex:
```
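+
+For example, a minimal sketch (assuming `results` is the run dict produced above and `benchmark` is the NF benchmark object created earlier):
+
+```python
+from capreolus import evaluator
+
+# compute metrics for the run against the benchmark's relevance judgments
+metrics = evaluator.eval_runs(results, benchmark.qrels, evaluator.DEFAULT_METRICS, benchmark.relevance_level)
+print(metrics["map"], metrics["ndcg_cut_10"])
+```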
@@ -90,9 +112,9 @@ To get metrics, we could then pass `results` to `capreolus.evaluator.eval_runs()
Capreolus modules implement the Capreolus module API plus an API specific to the module type.
The module API consists of four attributes:
- `module_type`: a string indicating the module's type, like "index" or "benchmark"
-- `module_name`: a string indicating the module's name, like "anserini" or "robust04.yang19"
-- `config_spec`: a list of `ConfigOption` objects, for example, `ConfigOption("stemmer", default_value="none", description="stemmer to use")`
-- `dependencies` a list of `Dependency` objects; for example, `Dependency(key="collection", module="collection", name="robust04")`
+- `module_name`: a string indicating the module's name, like "anserini" or "nf"
+- `config_spec`: a list of `ConfigOption` objects. For example, `[ConfigOption("stemmer", default_value="none", description="stemmer to use")]`
+- `dependencies`: a list of `Dependency` objects. For example, `[Dependency(key="collection", module="collection", name="nf")]`
When the module is created, any dependencies that are not explicitly passed with `provide={key: object}` are automatically created. The module's config options in `config_spec` and those of its dependencies are exposed as Capreolus configuration options.
@@ -104,10 +126,8 @@ The `Task` module API specifies two additional class attributes: `commands` and
Let's create a new task that mirrors the graph we constructed manually, except with two separate `Searcher` objects. We'll save the results from both searchers and measure their effectiveness on the validation queries to decide which searcher to report test set results on.
```python
-from capreolus import evaluator, Dependency, ConfigOption
-from capreolus.searcher import Searcher
+from capreolus import evaluator, get_logger, Dependency, ConfigOption
from capreolus.task import Task
-from capreolus.utils.loginit import get_logger
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -118,7 +138,7 @@ class TutorialTask(Task):
config_spec = [ConfigOption("optimize", "map", "metric to maximize on the validation set")]
dependencies = [
Dependency(
- key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"]
+ key="benchmark", module="benchmark", name="nf", provide_this=True, provide_children=["collection"]
),
Dependency(key="searcher1", module="searcher", name="BM25RM3"),
Dependency(key="searcher2", module="searcher", name="SDM"),
@@ -153,4 +173,34 @@ class TutorialTask(Task):
return best_results
```
-
+
+```eval_rst
+.. note:: The module needs to be registered in order for Capreolus to find it. Registration happens when the ``@Task.register`` decorator is applied, so no additional steps are needed to use the new Task via the Python API. When using the Task via the CLI, the ``tutorial.py`` file containing it needs to be imported in order for the Task to be registered. This can be accomplished by placing the file inside the ``capreolus.task`` package (see ``capreolus.task.__path__``). However, in this case, the above Task is already provided with Capreolus as ``task/tutorial.py``.
+```
+
+Let's try running the Task we just declared via the Python API.
+
+```python
+>>> task = TutorialTask()
+>>> results = task.run()
+>>> results['score']['map']
+0.14798308699242727
+```
+
+### Module APIs
+Each module type's base class describes the module API that should be implemented to create new modules of that type.
+Check out the API documentation to learn more: Benchmark, Collection, Extractor, Index, Reranker, Searcher, Task, Tokenizer, and Trainer.
+
+
+## Next Steps
+- Learn more about [running pipelines using the command line interface](cli.md)
+- View what [Capreolus modules](modules.md) are available