diff --git a/capreolus/__init__.py b/capreolus/__init__.py index 9056e08db..ac644fe35 100644 --- a/capreolus/__init__.py +++ b/capreolus/__init__.py @@ -2,9 +2,9 @@ import os from pathlib import Path -from profane import ConfigOption, Dependency, constants, config_list_to_dict +from profane import ConfigOption, Dependency, ModuleBase, constants, config_list_to_dict, module_registry -__version__ = "0.2.1" +__version__ = "0.2.2" # specify a base package that we should look for modules under (e.g., .task) # constants must be specified before importing Task (or any other modules!) @@ -21,16 +21,18 @@ jnius_config.set_classpath(Anserini.get_fat_jar()) -# import capreolus.evaluator as evaluator -# from capreolus.benchmark import Benchmark -# from capreolus.collection import Collection -# from capreolus.extractor import Extractor -# from capreolus.index import Index -# from capreolus.reranker.base import Reranker -# from capreolus.searcher import Searcher -# from capreolus.task.base import Task -# from capreolus.tokenizer import Tokenizer -# from capreolus.trainer import Trainer +# note: order is important to avoid circular imports +from capreolus.utils.loginit import get_logger +import capreolus.evaluator as evaluator +from capreolus.benchmark import Benchmark +from capreolus.collection import Collection +from capreolus.index import Index +from capreolus.searcher import Searcher +from capreolus.extractor import Extractor +from capreolus.reranker import Reranker +from capreolus.tokenizer import Tokenizer +from capreolus.trainer import Trainer +from capreolus.task import Task def parse_config_string(s): diff --git a/capreolus/benchmark/__init__.py b/capreolus/benchmark/__init__.py index b7db5a9ef..20f2dac29 100644 --- a/capreolus/benchmark/__init__.py +++ b/capreolus/benchmark/__init__.py @@ -1,37 +1,27 @@ -from profane import import_all_modules - import json -import re -import os -import gzip -import pickle - -from tqdm import tqdm -from zipfile import ZipFile -from pathlib import Path -from collections import defaultdict -from bs4 import BeautifulSoup -from profane import ModuleBase, Dependency, ConfigOption, constants - -from capreolus.utils.loginit import get_logger -from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt -from capreolus.utils.common import download_file, remove_newline, get_udel_query_expander -logger = get_logger(__name__) -PACKAGE_PATH = constants["PACKAGE_PATH"] +from capreolus import ModuleBase +from capreolus.utils.trec import load_qrels, load_trec_topics class Benchmark(ModuleBase): - """the module base class""" + """Base class for Benchmark modules. The purpose of a Benchmark is to provide the data needed to run an experiment, such as queries, folds, and relevance judgments. 
+ + Modules should provide: + - a ``topics`` dict mapping query ids (*qids*) to *queries* + - a ``qrels`` dict mapping *qids* to *docids* and *relevance labels* + - a ``folds`` dict mapping a fold name to *training*, *dev* (validation), and *testing* qids + - if these can be loaded from files in standard formats, they can be specified by setting the ``topic_file``, ``qrel_file``, and ``fold_file``, respectively, rather than by setting the above attributes directly + """ module_type = "benchmark" qrel_file = None topic_file = None fold_file = None query_type = None - # documents with a relevance label >= relevance_level will be considered relevant - # corresponds to trec_eval's --level_for_rel (and passed to pytrec_eval as relevance_level) relevance_level = 1 + """ Documents with a relevance label >= relevance_level will be considered relevant. + This corresponds to trec_eval's --level_for_rel (and is passed to pytrec_eval as relevance_level). """ @property def qrels(self): @@ -52,546 +42,8 @@ def folds(self): return self._folds -@Benchmark.register -class DummyBenchmark(Benchmark): - module_name = "dummy" - dependencies = [Dependency(key="collection", module="collection", name="dummy")] - qrel_file = PACKAGE_PATH / "data" / "qrels.dummy.txt" - topic_file = PACKAGE_PATH / "data" / "topics.dummy.txt" - fold_file = PACKAGE_PATH / "data" / "dummy_folds.json" - query_type = "title" - - -@Benchmark.register -class WSDM20Demo(Benchmark): - """ Robust04 benchmark equivalent to robust04.yang19 """ - - module_name = "wsdm20demo" - dependencies = [Dependency(key="collection", module="collection", name="robust04")] - qrel_file = PACKAGE_PATH / "data" / "qrels.robust2004.txt" - topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt" - fold_file = PACKAGE_PATH / "data" / "rob04_yang19_folds.json" - query_type = "title" - - -@Benchmark.register -class Robust04Yang19(Benchmark): - """Robust04 benchmark using the folds from Yang et al. [1] - - [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. 2019. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019. - """ - - module_name = "robust04.yang19" - dependencies = [Dependency(key="collection", module="collection", name="robust04")] - qrel_file = PACKAGE_PATH / "data" / "qrels.robust2004.txt" - topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt" - fold_file = PACKAGE_PATH / "data" / "rob04_yang19_folds.json" - query_type = "title" - - -@Benchmark.register -class NF(Benchmark): - """ A Full-Text Learning to Rank Dataset for Medical Information Retrieval [1] - - [1] Vera Boteva, Demian Gholipour, Artem Sokolov and Stefan Riezler. 
A Full-Text Learning to Rank Dataset for Medical Information Retrieval Proceedings of the 38th European Conference on Information Retrieval (ECIR), Padova, Italy, 2016 - """ - - module_name = "nf" - dependencies = [Dependency(key="collection", module="collection", name="nf")] - config_spec = [ - ConfigOption(key="labelrange", default_value="0-2", description="range of dataset qrels, options: 0-2, 1-3"), - ConfigOption( - key="fields", - default_value="all_fields", - description="query fields included in topic file, " - "options: 'all_fields', 'all_titles', 'nontopics', 'vid_title', 'vid_desc'", - ), - ] - - fold_file = PACKAGE_PATH / "data" / "nf.json" - - query_type = "title" - - def __init__(self, config, provide, share_dependency_objects): - super().__init__(config, provide, share_dependency_objects) - fields, label_range = self.config["fields"], self.config["labelrange"] - self.field2kws = { - "all_fields": ["all"], - "nontopics": ["nontopic-titles"], - "vid_title": ["vid-titles"], - "vid_desc": ["vid-desc"], - "all_titles": ["nontopic-titles", "vid-titles", "nontopic-titles"], - } - self.labelrange2kw = {"0-2": "2-1-0", "1-3": "3-2-1"} - - if fields not in self.field2kws: - raise ValueError(f"Unexpected fields value: {fields}, expect: {', '.join(self.field2kws.keys())}") - if label_range not in self.labelrange2kw: - raise ValueError(f"Unexpected label range: {label_range}, expect: {', '.join(self.field2kws.keys())}") - - self.qrel_file = PACKAGE_PATH / "data" / f"qrels.nf.{label_range}.txt" - self.test_qrel_file = PACKAGE_PATH / "data" / f"test.qrels.nf.{label_range}.txt" - self.topic_file = PACKAGE_PATH / "data" / f"topics.nf.{fields}.txt" - self.download_if_missing() - - def _transform_qid(self, raw): - """ NFCorpus dataset specific, remove prefix in query id since anserini convert all qid to integer """ - return raw.replace("PLAIN-", "") - - def download_if_missing(self): - if all([f.exists() for f in [self.topic_file, self.fold_file, self.qrel_file]]): - return - - tmp_corpus_dir = self.collection.download_raw() - topic_f = open(self.topic_file, "w", encoding="utf-8") - qrel_f = open(self.qrel_file, "w", encoding="utf-8") - test_qrel_f = open(self.test_qrel_file, "w", encoding="utf-8") - - set_names = ["train", "dev", "test"] - folds = {s: set() for s in set_names} - qrel_kw = self.labelrange2kw[self.config["labelrange"]] - for set_name in set_names: - with open(tmp_corpus_dir / f"{set_name}.{qrel_kw}.qrel") as f: - for line in f: - line = self._transform_qid(line) - qid = line.strip().split()[0] - folds[set_name].add(qid) - if set_name == "test": - test_qrel_f.write(line) - qrel_f.write(line) - - files = [tmp_corpus_dir / f"{set_name}.{keyword}.queries" for keyword in self.field2kws[self.config["fields"]]] - qids2topics = self._align_queries(files, "title") - - for qid, txts in qids2topics.items(): - topic_f.write(topic_to_trectxt(qid, txts["title"])) - - json.dump( - {"s1": {"train_qids": list(folds["train"]), "predict": {"dev": list(folds["dev"]), "test": list(folds["test"])}}}, - open(self.fold_file, "w"), - ) - - topic_f.close() - qrel_f.close() - test_qrel_f.close() - logger.info(f"nf benchmark prepared") - - def _align_queries(self, files, field, qid2queries=None): - if not qid2queries: - qid2queries = {} - for fn in files: - with open(fn, "r", encoding="utf-8") as f: - for line in f: - qid, txt = line.strip().split("\t") - qid = self._transform_qid(qid) - txt = " ".join(re.sub("[^A-Za-z]", " ", txt).split()[:1020]) - if qid not in qid2queries: - qid2queries[qid] = {field: 
txt} - else: - if field in qid2queries[qid]: - logger.warning(f"Overwriting title for query {qid}") - qid2queries[qid][field] = txt - return qid2queries - - -@Benchmark.register -class ANTIQUE(Benchmark): - """A Non-factoid Question Answering Benchmark from Hashemi et al. [1] - - [1] Helia Hashemi, Mohammad Aliannejadi, Hamed Zamani, and W. Bruce Croft. 2020. ANTIQUE: A non-factoid question answering benchmark. ECIR 2020. - """ - - module_name = "antique" - dependencies = [Dependency(key="collection", module="collection", name="antique")] - qrel_file = PACKAGE_PATH / "data" / "qrels.antique.txt" - topic_file = PACKAGE_PATH / "data" / "topics.antique.txt" - fold_file = PACKAGE_PATH / "data" / "antique.json" - query_type = "title" - relevance_level = 2 - - -@Benchmark.register -class MSMarcoPassage(Benchmark): - module_name = "msmarcopassage" - dependencies = [Dependency(key="collection", module="collection", name="msmarco")] - qrel_file = PACKAGE_PATH / "data" / "qrels.msmarcopassage.txt" - topic_file = PACKAGE_PATH / "data" / "topics.msmarcopassage.txt" - fold_file = PACKAGE_PATH / "data" / "msmarcopassage.folds.json" - query_type = "title" - - -@Benchmark.register -class CodeSearchNetCorpus(Benchmark): - module_name = "codesearchnet_corpus" - dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")] - url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2" - query_type = "title" - - file_fn = PACKAGE_PATH / "data" / "csn_corpus" - - qrel_dir = file_fn / "qrels" - topic_dir = file_fn / "topics" - fold_dir = file_fn / "folds" - - qidmap_dir = file_fn / "qidmap" - docidmap_dir = file_fn / "docidmap" - - config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")] - - def build(self): - lang = self.config["lang"] - - self.qid_map_file = self.qidmap_dir / f"{lang}.json" - self.docid_map_file = self.docidmap_dir / f"{lang}.json" - - self.qrel_file = self.qrel_dir / f"{lang}.txt" - self.topic_file = self.topic_dir / f"{lang}.txt" - self.fold_file = self.fold_dir / f"{lang}.json" - - for file in [var for var in vars(self) if var.endswith("file")]: - getattr(self, file).parent.mkdir(exist_ok=True, parents=True) - - self.download_if_missing() - - @property - def qid_map(self): - if not hasattr(self, "_qid_map"): - if not self.qid_map_file.exists(): - self.download_if_missing() - - self._qid_map = json.load(open(self.qid_map_file, "r")) - return self._qid_map - - @property - def docid_map(self): - if not hasattr(self, "_docid_map"): - if not self.docid_map_file.exists(): - self.download_if_missing() - - self._docid_map = json.load(open(self.docid_map_file, "r")) - return self._docid_map - - def download_if_missing(self): - files = [self.qid_map_file, self.docid_map_file, self.qrel_file, self.topic_file, self.fold_file] - if all([f.exists() for f in files]): - return - - lang = self.config["lang"] - - tmp_dir = Path("/tmp") - zip_fn = tmp_dir / f"{lang}.zip" - if not zip_fn.exists(): - download_file(f"{self.url}/{lang}.zip", zip_fn) - - with ZipFile(zip_fn, "r") as zipobj: - zipobj.extractall(tmp_dir) - - # prepare docid-url mapping from dedup.pkl - pkl_fn = tmp_dir / f"{lang}_dedupe_definitions_v2.pkl" - doc_objs = pickle.load(open(pkl_fn, "rb")) - self._docid_map = self._prep_docid_map(doc_objs) - assert self._get_n_docid() == len(doc_objs) - - # prepare folds, qrels, topics, docstring2qid # TODO: shall we add negative samples? 
- qrels, self._qid_map = defaultdict(dict), {} - qids = {s: [] for s in ["train", "valid", "test"]} - - topic_file = open(self.topic_file, "w", encoding="utf-8") - qrel_file = open(self.qrel_file, "w", encoding="utf-8") - - def gen_doc_from_gzdir(dir): - """ generate parsed dict-format doc from all jsonl.gz files under given directory """ - for fn in sorted(dir.glob("*.jsonl.gz")): - f = gzip.open(fn, "rb") - for doc in f: - yield json.loads(doc) - - for set_name in qids: - set_path = tmp_dir / lang / "final" / "jsonl" / set_name - for doc in gen_doc_from_gzdir(set_path): - code = remove_newline(" ".join(doc["code_tokens"])) - docstring = remove_newline(" ".join(doc["docstring_tokens"])) - n_words_in_docstring = len(docstring.split()) - if n_words_in_docstring >= 1024: - logger.warning( - f"chunk query to first 1000 words otherwise TooManyClause would be triggered " - f"at lucene at search stage, " - ) - docstring = " ".join(docstring.split()[:1020]) # for TooManyClause - - docid = self.get_docid(doc["url"], code) - qid = self._qid_map.get(docstring, str(len(self._qid_map))) - qrel_file.write(f"{qid} Q0 {docid} 1\n") - - if docstring not in self._qid_map: - self._qid_map[docstring] = qid - qids[set_name].append(qid) - topic_file.write(topic_to_trectxt(qid, docstring)) - - topic_file.close() - qrel_file.close() - - # write to qid_map.json, docid_map, fold.json - json.dump(self._qid_map, open(self.qid_map_file, "w")) - json.dump(self._docid_map, open(self.docid_map_file, "w")) - json.dump( - {"s1": {"train_qids": qids["train"], "predict": {"dev": qids["valid"], "test": qids["test"]}}}, - open(self.fold_file, "w"), - ) - - def _prep_docid_map(self, doc_objs): - """ - construct a nested dict to map each doc into a unique docid - which follows the structure: {url: {" ".join(code_tokens): docid, ...}} - - For all the lanugage datasets the url uniquely maps to a code_tokens yet it's not the case for but js and php - which requires a second-level mapping from raw_doc to docid - - :param doc_objs: a list of dict having keys ["nwo", "url", "sha", "identifier", "arguments" - "function", "function_tokens", "docstring", "doctring_tokens",], - :return: - """ - # TODO: any way to avoid the twice traversal of all url and make the return dict structure consistent - lang = self.config["lang"] - url2docid = defaultdict(dict) - for i, doc in tqdm(enumerate(doc_objs), desc=f"Preparing the {lang} docid_map"): - url, code_tokens = doc["url"], remove_newline(" ".join(doc["function_tokens"])) - url2docid[url][code_tokens] = f"{lang}-FUNCTION-{i}" - - # remove the code_tokens for the unique url-docid mapping - for url, docids in tqdm(url2docid.items(), desc=f"Compressing the {lang} docid_map"): - url2docid[url] = list(docids.values()) if len(docids) == 1 else docids # {code_tokens: docid} -> [docid] - return url2docid - - def _get_n_docid(self): - """ calculate the number of document ids contained in the nested docid map """ - lens = [len(docs) for url, docs in self._docid_map.items()] - return sum(lens) - - def get_docid(self, url, code_tokens): - """ retrieve the doc id according to the doc dict """ - docids = self.docid_map[url] - return docids[0] if len(docids) == 1 else docids[code_tokens] - - -@Benchmark.register -class CodeSearchNetChallenge(Benchmark): - """ - CodeSearchNetChallenge can only be used for training but not for evaluation since qrels is not provided - """ - - module_name = "codesearchnet_challenge" - dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")] - 
config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")] - - url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/queries.csv" - query_type = "title" - - file_fn = PACKAGE_PATH / "data" / "csn_challenge" - topic_file = file_fn / "topics.txt" - qid_map_file = file_fn / "qidmap.json" - - def download_if_missing(self): - """ download query.csv and prepare queryid - query mapping file """ - if self.topic_file.exists() and self.qid_map_file.exists(): - return - - tmp_dir = Path("/tmp") - tmp_dir.mkdir(exist_ok=True, parents=True) - self.file_fn.mkdir(exist_ok=True, parents=True) - - query_fn = tmp_dir / f"query.csv" - if not query_fn.exists(): - download_file(self.url, query_fn) - - # prepare qid - query - qid_map = {} - topic_file = open(self.topic_file, "w", encoding="utf-8") - query_file = open(query_fn) - for qid, line in enumerate(query_file): - if qid != 0: # ignore the first line "query" - topic_file.write(topic_to_trectxt(qid, line.strip())) - qid_map[qid] = line - topic_file.close() - json.dump(qid_map, open(self.qid_map_file, "w")) - - -@Benchmark.register -class COVID(Benchmark): - """ Ongoing TREC-COVID bechmark from https://ir.nist.gov/covidSubmit """ - - module_name = "covid" - dependencies = [Dependency(key="collection", module="collection", name="covid")] - data_dir = PACKAGE_PATH / "data" / "covid" - topic_url = "https://ir.nist.gov/covidSubmit/data/topics-rnd%d.xml" - qrel_url = "https://ir.nist.gov/covidSubmit/data/qrels-rnd%d.txt" - lastest_round = 3 - - config_spec = [ - ConfigOption("round", 3, "TREC-COVID round to use"), - ConfigOption("udelqexpand", False), - ConfigOption("excludeknown", True), - ] - - def build(self): - if self.config["round"] == self.lastest_round and not self.config["excludeknown"]: - logger.warning(f"No evaluation can be done for the lastest round in exclude-known mode") - - data_dir = self.get_cache_path() / "documents" - data_dir.mkdir(exist_ok=True, parents=True) - - self.qrel_ignore = f"{data_dir}/ignore.qrel.txt" - self.qrel_file = f"{data_dir}/qrel.txt" - self.topic_file = f"{data_dir}/topic.txt" - self.fold_file = f"{data_dir}/fold.json" - - self.download_if_missing() - - def download_if_missing(self): - if all([os.path.exists(fn) for fn in [self.qrel_file, self.qrel_ignore, self.topic_file, self.fold_file]]): - return - - rnd_i, excludeknown = self.config["round"], self.config["excludeknown"] - if rnd_i > self.lastest_round: - raise ValueError(f"round {rnd_i} is unavailable") - - logger.info(f"Preparing files for covid round-{rnd_i}") - - topic_url = self.topic_url % rnd_i - qrel_ignore_urls = [self.qrel_url % i for i in range(1, rnd_i)] # download all the qrels before current run - - # topic file - tmp_dir = Path("/tmp") - topic_tmp = tmp_dir / f"topic.round.{rnd_i}.xml" - if not os.path.exists(topic_tmp): - download_file(topic_url, topic_tmp) - all_qids = self.xml2trectopic(topic_tmp) # will update self.topic_file - - if excludeknown: - qrel_fn = open(self.qrel_file, "w") - for i, qrel_url in enumerate(qrel_ignore_urls): - qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1) - if not os.path.exists(qrel_tmp): - download_file(qrel_url, qrel_tmp) - with open(qrel_tmp) as f: - for line in f: - qrel_fn.write(line) - qrel_fn.close() - - f = open(self.qrel_ignore, "w") # empty ignore file - f.close() - else: - qrel_fn = open(self.qrel_ignore, "w") - for i, qrel_url in enumerate(qrel_ignore_urls): - qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1) - if not os.path.exists(qrel_tmp): - 
download_file(qrel_url, qrel_tmp) - with open(qrel_tmp) as f: - for line in f: - qrel_fn.write(line) - qrel_fn.close() - - if rnd_i == self.lastest_round: - f = open(self.qrel_file, "w") - f.close() - else: - with open(tmp_dir / f"qrel-{rnd_i}") as fin, open(self.qrel_file, "w") as fout: - for line in fin: - fout.write(line) - - # folds: use all labeled query for train, valid, and use all of them for test set - labeled_qids = list(load_qrels(self.qrel_ignore).keys()) - folds = {"s1": {"train_qids": labeled_qids, "predict": {"dev": labeled_qids, "test": all_qids}}} - json.dump(folds, open(self.fold_file, "w")) - - def xml2trectopic(self, xmlfile): - with open(xmlfile, "r") as f: - topic = f.read() - - all_qids = [] - soup = BeautifulSoup(topic, "lxml") - topics = soup.find_all("topic") - expand_query = get_udel_query_expander() - - with open(self.topic_file, "w") as fout: - for topic in topics: - qid = topic["number"] - title = topic.find_all("query")[0].text.strip() - desc = topic.find_all("question")[0].text.strip() - narr = topic.find_all("narrative")[0].text.strip() - - if self.config["udelqexpand"]: - title = expand_query(title, rm_sw=True) - desc = expand_query(desc, rm_sw=False) - - title = title + " " + desc - desc = " " - - topic_line = topic_to_trectxt(qid, title, desc=desc, narr=narr) - fout.write(topic_line) - all_qids.append(qid) - return all_qids - - -@Benchmark.register -class CovidQA(Benchmark): - module_name = "covidqa" - dependencies = [Dependency(key="collection", module="collection", name="covidqa")] - url = "https://raw.githubusercontent.com/castorini/pygaggle/master/data/kaggle-lit-review-%s.json" - available_versions = ["0.1", "0.2"] - - datadir = PACKAGE_PATH / "data" / "covidqa" - - config_spec = [ConfigOption("version", "0.1+0.2")] - - def build(self): - os.makedirs(self.datadir, exist_ok=True) - - version = self.config["version"] - self.qrel_file = self.datadir / f"qrels.v{version}.txt" - self.topic_file = self.datadir / f"topics.v{version}.txt" - self.fold_file = self.datadir / f"v{version}.json" # HOW TO SPLIT THE FOLD HERE? 
- - self.download_if_missing() - - def download_if_missing(self): - if all([os.path.exists(f) for f in [self.qrel_file, self.topic_file, self.fold_file]]): - return - - tmp_dir = Path("/tmp") - topic_f = open(self.topic_file, "w", encoding="utf-8") - qrel_f = open(self.qrel_file, "w", encoding="utf-8") - - all_qids = [] - qid = 2001 # to distingsuish queries here from queries in TREC-covid - versions = self.config["version"].split("+") if isinstance(self.config["version"], str) else str(self.config["version"]) - for v in versions: - if v not in self.available_versions: - vs = " ".join(self.available_versions) - logger.warning(f"Invalid version {v}, should be one of {vs}") - continue - - url = self.url % v - target_fn = tmp_dir / f"covidqa-v{v}.json" - if not os.path.exists(target_fn): - download_file(url, target_fn) - qa = json.load(open(target_fn)) - for subcate in qa["categories"]: - name = subcate["name"] - - for qa in subcate["sub_categories"]: - nq_name, kq_name = qa["nq_name"], qa["kq_name"] - query_line = topic_to_trectxt(qid, kq_name, nq_name) # kq_name == "query", nq_name == "question" - topic_f.write(query_line) - for ans in qa["answers"]: - docid = ans["id"] - qrel_f.write(f"{qid} Q0 {docid} 1\n") - all_qids.append(qid) - qid += 1 - - json.dump({"s1": {"train_qids": all_qids, "predict": {"dev": all_qids, "test": all_qids}}}, open(self.fold_file, "w")) - topic_f.close() - qrel_f.close() +from profane import import_all_modules +from .dummy import DummyBenchmark import_all_modules(__file__, __package__) diff --git a/capreolus/benchmark/antique.py b/capreolus/benchmark/antique.py new file mode 100644 index 000000000..c6601b0bd --- /dev/null +++ b/capreolus/benchmark/antique.py @@ -0,0 +1,22 @@ +from . import Benchmark +from capreolus import constants, ConfigOption, Dependency +from capreolus.utils.loginit import get_logger + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Benchmark.register +class ANTIQUE(Benchmark): + """A Non-factoid Question Answering Benchmark from Hashemi et al. [1] + + [1] Helia Hashemi, Mohammad Aliannejadi, Hamed Zamani, and W. Bruce Croft. 2020. ANTIQUE: A non-factoid question answering benchmark. ECIR 2020. + """ + + module_name = "antique" + dependencies = [Dependency(key="collection", module="collection", name="antique")] + qrel_file = PACKAGE_PATH / "data" / "qrels.antique.txt" + topic_file = PACKAGE_PATH / "data" / "topics.antique.txt" + fold_file = PACKAGE_PATH / "data" / "antique.json" + query_type = "title" + relevance_level = 2 diff --git a/capreolus/benchmark/codesearchnet.py b/capreolus/benchmark/codesearchnet.py new file mode 100644 index 000000000..662a031ac --- /dev/null +++ b/capreolus/benchmark/codesearchnet.py @@ -0,0 +1,220 @@ +import gzip +import pickle +import json +from collections import defaultdict +from pathlib import Path +from zipfile import ZipFile + +from tqdm import tqdm + +from . import Benchmark +from capreolus import constants, ConfigOption, Dependency +from capreolus.utils.loginit import get_logger +from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt +from capreolus.utils.common import download_file, remove_newline + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Benchmark.register +class CodeSearchNetCorpus(Benchmark): + """CodeSearchNet Corpus. [1] + + [1] Hamel Husain, Ho-Hsiang Wu, Tiferet Gazit, Miltiadis Allamanis, and Marc Brockschmidt. 2019. CodeSearchNet Challenge: Evaluating the State of Semantic Code Search. arXiv 2019. 
+ """ + + module_name = "codesearchnet_corpus" + dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")] + url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2" + query_type = "title" + + file_fn = PACKAGE_PATH / "data" / "csn_corpus" + + qrel_dir = file_fn / "qrels" + topic_dir = file_fn / "topics" + fold_dir = file_fn / "folds" + + qidmap_dir = file_fn / "qidmap" + docidmap_dir = file_fn / "docidmap" + + config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")] + + def build(self): + lang = self.config["lang"] + + self.qid_map_file = self.qidmap_dir / f"{lang}.json" + self.docid_map_file = self.docidmap_dir / f"{lang}.json" + + self.qrel_file = self.qrel_dir / f"{lang}.txt" + self.topic_file = self.topic_dir / f"{lang}.txt" + self.fold_file = self.fold_dir / f"{lang}.json" + + for file in [var for var in vars(self) if var.endswith("file")]: + getattr(self, file).parent.mkdir(exist_ok=True, parents=True) + + self.download_if_missing() + + @property + def qid_map(self): + if not hasattr(self, "_qid_map"): + if not self.qid_map_file.exists(): + self.download_if_missing() + + self._qid_map = json.load(open(self.qid_map_file, "r")) + return self._qid_map + + @property + def docid_map(self): + if not hasattr(self, "_docid_map"): + if not self.docid_map_file.exists(): + self.download_if_missing() + + self._docid_map = json.load(open(self.docid_map_file, "r")) + return self._docid_map + + def download_if_missing(self): + files = [self.qid_map_file, self.docid_map_file, self.qrel_file, self.topic_file, self.fold_file] + if all([f.exists() for f in files]): + return + + lang = self.config["lang"] + + tmp_dir = Path("/tmp") + zip_fn = tmp_dir / f"{lang}.zip" + if not zip_fn.exists(): + download_file(f"{self.url}/{lang}.zip", zip_fn) + + with ZipFile(zip_fn, "r") as zipobj: + zipobj.extractall(tmp_dir) + + # prepare docid-url mapping from dedup.pkl + pkl_fn = tmp_dir / f"{lang}_dedupe_definitions_v2.pkl" + doc_objs = pickle.load(open(pkl_fn, "rb")) + self._docid_map = self._prep_docid_map(doc_objs) + assert self._get_n_docid() == len(doc_objs) + + # prepare folds, qrels, topics, docstring2qid # TODO: shall we add negative samples? 
+        qrels, self._qid_map = defaultdict(dict), {}
+        qids = {s: [] for s in ["train", "valid", "test"]}
+
+        topic_file = open(self.topic_file, "w", encoding="utf-8")
+        qrel_file = open(self.qrel_file, "w", encoding="utf-8")
+
+        def gen_doc_from_gzdir(dir):
+            """ generate parsed dict-format docs from all jsonl.gz files under the given directory """
+            for fn in sorted(dir.glob("*.jsonl.gz")):
+                f = gzip.open(fn, "rb")
+                for doc in f:
+                    yield json.loads(doc)
+
+        for set_name in qids:
+            set_path = tmp_dir / lang / "final" / "jsonl" / set_name
+            for doc in gen_doc_from_gzdir(set_path):
+                code = remove_newline(" ".join(doc["code_tokens"]))
+                docstring = remove_newline(" ".join(doc["docstring_tokens"]))
+                n_words_in_docstring = len(docstring.split())
+                if n_words_in_docstring >= 1024:
+                    logger.warning(
+                        "chunking query to its first 1020 words; longer queries would trigger "
+                        "Lucene's TooManyClauses exception at search time"
+                    )
+                    docstring = " ".join(docstring.split()[:1020])  # to avoid TooManyClauses
+
+                docid = self.get_docid(doc["url"], code)
+                qid = self._qid_map.get(docstring, str(len(self._qid_map)))
+                qrel_file.write(f"{qid} Q0 {docid} 1\n")
+
+                if docstring not in self._qid_map:
+                    self._qid_map[docstring] = qid
+                    qids[set_name].append(qid)
+                    topic_file.write(topic_to_trectxt(qid, docstring))
+
+        topic_file.close()
+        qrel_file.close()
+
+        # write to qid_map.json, docid_map, fold.json
+        json.dump(self._qid_map, open(self.qid_map_file, "w"))
+        json.dump(self._docid_map, open(self.docid_map_file, "w"))
+        json.dump(
+            {"s1": {"train_qids": qids["train"], "predict": {"dev": qids["valid"], "test": qids["test"]}}},
+            open(self.fold_file, "w"),
+        )
+
+    def _prep_docid_map(self, doc_objs):
+        """
+        construct a nested dict that maps each doc to a unique docid,
+        following the structure: {url: {" ".join(code_tokens): docid, ...}}
+
+        For most language datasets the url uniquely maps to a single code_tokens entry, but this
+        is not the case for js and php, which require a second-level mapping from raw_doc to docid
+
+        :param doc_objs: a list of dicts having keys ["nwo", "url", "sha", "identifier", "arguments",
+            "function", "function_tokens", "docstring", "docstring_tokens"]
+        :return:
+        """
+        # TODO: any way to avoid traversing all urls twice and make the return dict structure consistent?
+        lang = self.config["lang"]
+        url2docid = defaultdict(dict)
+        for i, doc in tqdm(enumerate(doc_objs), desc=f"Preparing the {lang} docid_map"):
+            url, code_tokens = doc["url"], remove_newline(" ".join(doc["function_tokens"]))
+            url2docid[url][code_tokens] = f"{lang}-FUNCTION-{i}"
+
+        # remove the code_tokens for the unique url-docid mapping
+        for url, docids in tqdm(url2docid.items(), desc=f"Compressing the {lang} docid_map"):
+            url2docid[url] = list(docids.values()) if len(docids) == 1 else docids  # {code_tokens: docid} -> [docid]
+        return url2docid
+
+    def _get_n_docid(self):
+        """ calculate the number of document ids contained in the nested docid map """
+        lens = [len(docs) for url, docs in self._docid_map.items()]
+        return sum(lens)
+
+    def get_docid(self, url, code_tokens):
+        """ retrieve the doc id according to the doc dict """
+        docids = self.docid_map[url]
+        return docids[0] if len(docids) == 1 else docids[code_tokens]
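+
+
+# Note: the docid_map built by _prep_docid_map maps each url either to a list
+# (url corresponds to a single function) or to a {code_tokens: docid} dict when
+# the url is duplicated; an illustrative sketch with hypothetical values:
+#     {"https://host/repo/a.rb": ["ruby-FUNCTION-0"],
+#      "https://host/repo/b.rb": {"def b ( )": "ruby-FUNCTION-1", "def b ( x )": "ruby-FUNCTION-2"}}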
+
+
+@Benchmark.register
+class CodeSearchNetChallenge(Benchmark):
+    """CodeSearchNet Challenge. [1]
+    This benchmark can only be used for training (and challenge submissions) because no qrels are provided.
+
+    [1] Hamel Husain, Ho-Hsiang Wu, Tiferet Gazit, Miltiadis Allamanis, and Marc Brockschmidt. 2019. CodeSearchNet Challenge: Evaluating the State of Semantic Code Search. arXiv 2019.
+    """
+
+    module_name = "codesearchnet_challenge"
+    dependencies = [Dependency(key="collection", module="collection", name="codesearchnet")]
+    config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")]
+
+    url = "https://raw.githubusercontent.com/github/CodeSearchNet/master/resources/queries.csv"
+    query_type = "title"
+
+    file_fn = PACKAGE_PATH / "data" / "csn_challenge"
+    topic_file = file_fn / "topics.txt"
+    qid_map_file = file_fn / "qidmap.json"
+
+    def download_if_missing(self):
+        """ download query.csv and prepare a queryid - query mapping file """
+        if self.topic_file.exists() and self.qid_map_file.exists():
+            return
+
+        tmp_dir = Path("/tmp")
+        tmp_dir.mkdir(exist_ok=True, parents=True)
+        self.file_fn.mkdir(exist_ok=True, parents=True)
+
+        query_fn = tmp_dir / "query.csv"
+        if not query_fn.exists():
+            download_file(self.url, query_fn)
+
+        # prepare qid - query
+        qid_map = {}
+        topic_file = open(self.topic_file, "w", encoding="utf-8")
+        query_file = open(query_fn)
+        for qid, line in enumerate(query_file):
+            if qid != 0:  # ignore the first line "query"
+                topic_file.write(topic_to_trectxt(qid, line.strip()))
+                qid_map[qid] = line
+        topic_file.close()
+        json.dump(qid_map, open(self.qid_map_file, "w"))
diff --git a/capreolus/benchmark/covid.py b/capreolus/benchmark/covid.py
new file mode 100644
index 000000000..97941166c
--- /dev/null
+++ b/capreolus/benchmark/covid.py
@@ -0,0 +1,192 @@
+import os
+import json
+from pathlib import Path
+
+from bs4 import BeautifulSoup
+
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt
+from capreolus.utils.common import download_file, remove_newline, get_udel_query_expander
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class COVID(Benchmark):
+    """ Ongoing TREC-COVID benchmark from https://ir.nist.gov/covidSubmit that uses documents from CORD-19, the COVID-19 Open Research Dataset (https://www.semanticscholar.org/cord19).
+    """
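+
+    # excludeknown=True: qrels from all rounds before the configured round are
+    # written to qrel_file (evaluation on known judgments) and nothing is ignored.
+    # excludeknown=False: those earlier qrels go to an ignore list instead, and
+    # only the configured round's qrels are used for evaluation (residual-style setup).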
""" + + module_name = "covid" + dependencies = [Dependency(key="collection", module="collection", name="covid")] + data_dir = PACKAGE_PATH / "data" / "covid" + topic_url = "https://ir.nist.gov/covidSubmit/data/topics-rnd%d.xml" + qrel_url = "https://ir.nist.gov/covidSubmit/data/qrels-rnd%d.txt" + lastest_round = 3 + + config_spec = [ + ConfigOption("round", 3, "TREC-COVID round to use"), + ConfigOption("udelqexpand", False), + ConfigOption("excludeknown", True), + ] + + def build(self): + if self.config["round"] == self.lastest_round and not self.config["excludeknown"]: + logger.warning(f"No evaluation can be done for the lastest round in exclude-known mode") + + data_dir = self.get_cache_path() / "documents" + data_dir.mkdir(exist_ok=True, parents=True) + + self.qrel_ignore = f"{data_dir}/ignore.qrel.txt" + self.qrel_file = f"{data_dir}/qrel.txt" + self.topic_file = f"{data_dir}/topic.txt" + self.fold_file = f"{data_dir}/fold.json" + + self.download_if_missing() + + def download_if_missing(self): + if all([os.path.exists(fn) for fn in [self.qrel_file, self.qrel_ignore, self.topic_file, self.fold_file]]): + return + + rnd_i, excludeknown = self.config["round"], self.config["excludeknown"] + if rnd_i > self.lastest_round: + raise ValueError(f"round {rnd_i} is unavailable") + + logger.info(f"Preparing files for covid round-{rnd_i}") + + topic_url = self.topic_url % rnd_i + qrel_ignore_urls = [self.qrel_url % i for i in range(1, rnd_i)] # download all the qrels before current run + + # topic file + tmp_dir = Path("/tmp") + topic_tmp = tmp_dir / f"topic.round.{rnd_i}.xml" + if not os.path.exists(topic_tmp): + download_file(topic_url, topic_tmp) + all_qids = self.xml2trectopic(topic_tmp) # will update self.topic_file + + if excludeknown: + qrel_fn = open(self.qrel_file, "w") + for i, qrel_url in enumerate(qrel_ignore_urls): + qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1) + if not os.path.exists(qrel_tmp): + download_file(qrel_url, qrel_tmp) + with open(qrel_tmp) as f: + for line in f: + qrel_fn.write(line) + qrel_fn.close() + + f = open(self.qrel_ignore, "w") # empty ignore file + f.close() + else: + qrel_fn = open(self.qrel_ignore, "w") + for i, qrel_url in enumerate(qrel_ignore_urls): + qrel_tmp = tmp_dir / f"qrel-{i+1}" # round_id = (i + 1) + if not os.path.exists(qrel_tmp): + download_file(qrel_url, qrel_tmp) + with open(qrel_tmp) as f: + for line in f: + qrel_fn.write(line) + qrel_fn.close() + + if rnd_i == self.lastest_round: + f = open(self.qrel_file, "w") + f.close() + else: + with open(tmp_dir / f"qrel-{rnd_i}") as fin, open(self.qrel_file, "w") as fout: + for line in fin: + fout.write(line) + + # folds: use all labeled query for train, valid, and use all of them for test set + labeled_qids = list(load_qrels(self.qrel_ignore).keys()) + folds = {"s1": {"train_qids": labeled_qids, "predict": {"dev": labeled_qids, "test": all_qids}}} + json.dump(folds, open(self.fold_file, "w")) + + def xml2trectopic(self, xmlfile): + with open(xmlfile, "r") as f: + topic = f.read() + + all_qids = [] + soup = BeautifulSoup(topic, "lxml") + topics = soup.find_all("topic") + expand_query = get_udel_query_expander() + + with open(self.topic_file, "w") as fout: + for topic in topics: + qid = topic["number"] + title = topic.find_all("query")[0].text.strip() + desc = topic.find_all("question")[0].text.strip() + narr = topic.find_all("narrative")[0].text.strip() + + if self.config["udelqexpand"]: + title = expand_query(title, rm_sw=True) + desc = expand_query(desc, rm_sw=False) + + title = title + " " 
+                desc = " "
+
+                topic_line = topic_to_trectxt(qid, title, desc=desc, narr=narr)
+                fout.write(topic_line)
+                all_qids.append(qid)
+        return all_qids
+
+
+@Benchmark.register
+class CovidQA(Benchmark):
+    module_name = "covidqa"
+    dependencies = [Dependency(key="collection", module="collection", name="covid")]
+    url = "https://raw.githubusercontent.com/castorini/pygaggle/master/data/kaggle-lit-review-%s.json"
+    available_versions = ["0.1", "0.2"]
+
+    datadir = PACKAGE_PATH / "data" / "covidqa"
+
+    config_spec = [ConfigOption("version", "0.1+0.2")]
+
+    def build(self):
+        os.makedirs(self.datadir, exist_ok=True)
+
+        version = self.config["version"]
+        self.qrel_file = self.datadir / f"qrels.v{version}.txt"
+        self.topic_file = self.datadir / f"topics.v{version}.txt"
+        self.fold_file = self.datadir / f"v{version}.json"  # HOW TO SPLIT THE FOLD HERE?
+
+        self.download_if_missing()
+
+    def download_if_missing(self):
+        if all([os.path.exists(f) for f in [self.qrel_file, self.topic_file, self.fold_file]]):
+            return
+
+        tmp_dir = Path("/tmp")
+        topic_f = open(self.topic_file, "w", encoding="utf-8")
+        qrel_f = open(self.qrel_file, "w", encoding="utf-8")
+
+        all_qids = []
+        qid = 2001  # to distinguish queries here from queries in TREC-covid
+        versions = self.config["version"].split("+") if isinstance(self.config["version"], str) else [str(self.config["version"])]
+        for v in versions:
+            if v not in self.available_versions:
+                vs = " ".join(self.available_versions)
+                logger.warning(f"Invalid version {v}, should be one of {vs}")
+                continue
+
+            url = self.url % v
+            target_fn = tmp_dir / f"covidqa-v{v}.json"
+            if not os.path.exists(target_fn):
+                download_file(url, target_fn)
+            qa = json.load(open(target_fn))
+            for subcate in qa["categories"]:
+                name = subcate["name"]
+
+                for subqa in subcate["sub_categories"]:
+                    nq_name, kq_name = subqa["nq_name"], subqa["kq_name"]
+                    query_line = topic_to_trectxt(qid, kq_name, nq_name)  # kq_name == "query", nq_name == "question"
+                    topic_f.write(query_line)
+                    for ans in subqa["answers"]:
+                        docid = ans["id"]
+                        qrel_f.write(f"{qid} Q0 {docid} 1\n")
+                    all_qids.append(qid)
+                    qid += 1
+
+        json.dump({"s1": {"train_qids": all_qids, "predict": {"dev": all_qids, "test": all_qids}}}, open(self.fold_file, "w"))
+        topic_f.close()
+        qrel_f.close()
diff --git a/capreolus/benchmark/dummy.py b/capreolus/benchmark/dummy.py
new file mode 100644
index 000000000..18d777283
--- /dev/null
+++ b/capreolus/benchmark/dummy.py
@@ -0,0 +1,16 @@
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class DummyBenchmark(Benchmark):
+    """ Tiny benchmark for testing """
+
+    module_name = "dummy"
+    dependencies = [Dependency(key="collection", module="collection", name="dummy")]
+    qrel_file = PACKAGE_PATH / "data" / "qrels.dummy.txt"
+    topic_file = PACKAGE_PATH / "data" / "topics.dummy.txt"
+    fold_file = PACKAGE_PATH / "data" / "dummy_folds.json"
+    query_type = "title"
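+
+# For reference, fold files like dummy_folds.json follow the structure used
+# throughout this package (qids here are hypothetical):
+#     {"s1": {"train_qids": ["q1"], "predict": {"dev": ["q2"], "test": ["q3"]}}}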
diff --git a/capreolus/benchmark/msmarco.py b/capreolus/benchmark/msmarco.py
new file mode 100644
index 000000000..39019abae
--- /dev/null
+++ b/capreolus/benchmark/msmarco.py
@@ -0,0 +1,15 @@
+from capreolus import constants, ConfigOption, Dependency
+from . import Benchmark
+
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+# TODO add download_if_missing and re-enable
+# @Benchmark.register
+# class MSMarcoPassage(Benchmark):
+#     module_name = "msmarcopassage"
+#     dependencies = [Dependency(key="collection", module="collection", name="msmarco")]
+#     qrel_file = PACKAGE_PATH / "data" / "qrels.msmarcopassage.txt"
+#     topic_file = PACKAGE_PATH / "data" / "topics.msmarcopassage.txt"
+#     fold_file = PACKAGE_PATH / "data" / "msmarcopassage.folds.json"
+#     query_type = "title"
diff --git a/capreolus/benchmark/nf.py b/capreolus/benchmark/nf.py
new file mode 100644
index 000000000..f8129b876
--- /dev/null
+++ b/capreolus/benchmark/nf.py
@@ -0,0 +1,114 @@
+import json
+import re
+
+from . import Benchmark
+from capreolus import constants, ConfigOption, Dependency
+from capreolus.utils.loginit import get_logger
+from capreolus.utils.trec import load_qrels, load_trec_topics, topic_to_trectxt
+
+logger = get_logger(__name__)
+PACKAGE_PATH = constants["PACKAGE_PATH"]
+
+
+@Benchmark.register
+class NF(Benchmark):
+    """ NFCorpus: A Full-Text Learning to Rank Dataset for Medical Information Retrieval [1]
+
+    [1] Vera Boteva, Demian Gholipour, Artem Sokolov and Stefan Riezler. A Full-Text Learning to Rank Dataset for Medical Information Retrieval. Proceedings of the 38th European Conference on Information Retrieval (ECIR), Padova, Italy, 2016. https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/
+    """
+
+    module_name = "nf"
+    dependencies = [Dependency(key="collection", module="collection", name="nf")]
+    config_spec = [
+        ConfigOption(key="labelrange", default_value="0-2", description="range of dataset qrels, options: 0-2, 1-3"),
+        ConfigOption(
+            key="fields",
+            default_value="all_titles",
+            description="query fields included in topic file, "
+            "options: 'all_fields', 'all_titles', 'nontopics', 'vid_title', 'vid_desc'",
+        ),
+    ]
+
+    fold_file = PACKAGE_PATH / "data" / "nf.json"
+
+    query_type = "title"
+
+    def build(self):
+        fields, label_range = self.config["fields"], self.config["labelrange"]
+        self.field2kws = {
+            "all_fields": ["all"],
+            "nontopics": ["nontopic-titles"],
+            "vid_title": ["vid-titles"],
+            "vid_desc": ["vid-desc"],
+            "all_titles": ["nontopic-titles", "vid-titles", "nontopic-titles"],
+        }
+        self.labelrange2kw = {"0-2": "2-1-0", "1-3": "3-2-1"}
+
+        if fields not in self.field2kws:
+            raise ValueError(f"Unexpected fields value: {fields}, expect: {', '.join(self.field2kws.keys())}")
+        if label_range not in self.labelrange2kw:
+            raise ValueError(f"Unexpected label range: {label_range}, expect: {', '.join(self.labelrange2kw.keys())}")
+
+        self.qrel_file = PACKAGE_PATH / "data" / f"qrels.nf.{label_range}.txt"
+        self.test_qrel_file = PACKAGE_PATH / "data" / f"test.qrels.nf.{label_range}.txt"
+        self.topic_file = PACKAGE_PATH / "data" / f"topics.nf.{fields}.txt"
+        self.download_if_missing()
+
+    def _transform_qid(self, raw):
+        """ NFCorpus-specific: remove the prefix in query ids, since Anserini converts all qids to integers """
+        return raw.replace("PLAIN-", "")
+
+    def download_if_missing(self):
+        if all([f.exists() for f in [self.topic_file, self.fold_file, self.qrel_file]]):
+            return
+
+        tmp_corpus_dir = self.collection.download_raw()
+        topic_f = open(self.topic_file, "w", encoding="utf-8")
+        qrel_f = open(self.qrel_file, "w", encoding="utf-8")
+        test_qrel_f = open(self.test_qrel_file, "w", encoding="utf-8")
+
+        set_names = ["train", "dev", "test"]
+        folds = {s: set() for s in set_names}
+        qrel_kw = self.labelrange2kw[self.config["labelrange"]]
+        for 
set_name in set_names: + with open(tmp_corpus_dir / f"{set_name}.{qrel_kw}.qrel") as f: + for line in f: + line = self._transform_qid(line) + qid = line.strip().split()[0] + folds[set_name].add(qid) + if set_name == "test": + test_qrel_f.write(line) + qrel_f.write(line) + + files = [tmp_corpus_dir / f"{set_name}.{keyword}.queries" for keyword in self.field2kws[self.config["fields"]]] + qids2topics = self._align_queries(files, "title") + + for qid, txts in qids2topics.items(): + topic_f.write(topic_to_trectxt(qid, txts["title"])) + + json.dump( + {"s1": {"train_qids": list(folds["train"]), "predict": {"dev": list(folds["dev"]), "test": list(folds["test"])}}}, + open(self.fold_file, "w"), + ) + + topic_f.close() + qrel_f.close() + test_qrel_f.close() + logger.info(f"nf benchmark prepared") + + def _align_queries(self, files, field, qid2queries=None): + if not qid2queries: + qid2queries = {} + for fn in files: + with open(fn, "r", encoding="utf-8") as f: + for line in f: + qid, txt = line.strip().split("\t") + qid = self._transform_qid(qid) + txt = " ".join(re.sub("[^A-Za-z]", " ", txt).split()[:1020]) + if qid not in qid2queries: + qid2queries[qid] = {field: txt} + else: + if field in qid2queries[qid]: + logger.warning(f"Overwriting title for query {qid}") + qid2queries[qid][field] = txt + return qid2queries diff --git a/capreolus/benchmark/robust04.py b/capreolus/benchmark/robust04.py index d3545aaec..e13254782 100644 --- a/capreolus/benchmark/robust04.py +++ b/capreolus/benchmark/robust04.py @@ -1,4 +1,4 @@ -from profane import ModuleBase, Dependency, ConfigOption, constants +from capreolus import constants, ConfigOption, Dependency from . import Benchmark PACKAGE_PATH = constants["PACKAGE_PATH"] @@ -10,6 +10,7 @@ class Robust04(Benchmark): Given the remaining four folds, we split them into the same train and dev sets used in recent work. [2] [1] Samuel Huston and W. Bruce Croft. 2014. Parameters learned in the comparison of retrieval models using term dependencies. Technical Report. + [2] Sean MacAvaney, Andrew Yates, Arman Cohan, Nazli Goharian. 2019. CEDR: Contextualized Embeddings for Document Ranking. SIGIR 2019. """ @@ -19,3 +20,18 @@ class Robust04(Benchmark): topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt" fold_file = PACKAGE_PATH / "data" / "rob04_cedr_folds.json" query_type = "title" + + +@Benchmark.register +class Robust04Yang19(Benchmark): + """Robust04 benchmark using the folds from Yang et al. [1] + + [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. 2019. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019. 
+ """ + + module_name = "robust04.yang19" + dependencies = [Dependency(key="collection", module="collection", name="robust04")] + qrel_file = PACKAGE_PATH / "data" / "qrels.robust2004.txt" + topic_file = PACKAGE_PATH / "data" / "topics.robust04.301-450.601-700.txt" + fold_file = PACKAGE_PATH / "data" / "rob04_yang19_folds.json" + query_type = "title" diff --git a/capreolus/collection/__init__.py b/capreolus/collection/__init__.py index f559c9643..d5bc5e048 100644 --- a/capreolus/collection/__init__.py +++ b/capreolus/collection/__init__.py @@ -1,44 +1,38 @@ -from profane import import_all_modules - -# import_all_modules(__file__, __package__) - import os -import math -import shutil -import pickle -import tarfile -import filecmp -from tqdm import tqdm -from zipfile import ZipFile -from pathlib import Path -import pandas as pd -from profane import ModuleBase, Dependency, ConfigOption, constants +from capreolus import ModuleBase, Dependency, ConfigOption, constants -from capreolus.utils.common import download_file, hash_file, remove_newline -from capreolus.utils.loginit import get_logger -from capreolus.utils.trec import anserini_index_to_trec_docs, document_to_trectxt -logger = get_logger(__name__) -PACKAGE_PATH = constants["PACKAGE_PATH"] +class Collection(ModuleBase): + """Base class for Collection modules. The purpose of a Collection is to describe a document collection's location and its format. + Determining the document collection's location on disk: + - The *path* config option will be used if it contains a valid loation. + - If not, the ``_path`` attribute is used if it is valid. This is primarily used with :class:`~.dummy.DummyCollection`. + - If not, the class' ``download_if_missing`` method will be called. + + Modules should provide: + - the ``collection_type`` and ``generator_type`` class attributes, corresponding to Anserini types + - a ``download_if_missing`` method, if the collection is publicly available + - a ``_validate_document_path`` method. See :func:`~capreolus.collection.Collection.validate_document_path`. + """ -class Collection(ModuleBase): module_type = "collection" is_large_collection = False _path = None def get_path_and_types(self): + """ Returns a ``(path, collection_type, generator_type)`` tuple. """ if not self.validate_document_path(self._path): self._path = self.find_document_path() return self._path, self.collection_type, self.generator_type def validate_document_path(self, path): - """ Attempt to validate the document collection at `path`. + """ Attempt to validate the document collection at ``path``. - By default, this will only check whether `path` exists. Subclasses should override - `_validate_document_path(path)` with their own logic to perform more detailed checks. + By default, this will only check whether ``path`` exists. Subclasses should override + ``_validate_document_path(path)`` with their own logic to perform more detailed checks. Returns: True if the path is valid following the logic described above, or False if it is not @@ -61,11 +55,11 @@ def _validate_document_path(self, path): def find_document_path(self): """ Find the location of this collection's documents (i.e., the raw document collection). - We first check the collection's config for a path key. If found, `self.validate_document_path` checks - whether the path is valid. Subclasses should override the private method `self._validate_document_path` - with custom logic for performing checks further than existence of the directory. See `Robust04`. 
-class Collection(ModuleBase):
     module_type = "collection"
     is_large_collection = False
     _path = None
 
     def get_path_and_types(self):
+        """ Returns a ``(path, collection_type, generator_type)`` tuple. """
         if not self.validate_document_path(self._path):
             self._path = self.find_document_path()
 
         return self._path, self.collection_type, self.generator_type
 
     def validate_document_path(self, path):
-        """ Attempt to validate the document collection at `path`.
+        """ Attempt to validate the document collection at ``path``.
 
-        By default, this will only check whether `path` exists. Subclasses should override
-        `_validate_document_path(path)` with their own logic to perform more detailed checks.
+        By default, this will only check whether ``path`` exists. Subclasses should override
+        ``_validate_document_path(path)`` with their own logic to perform more detailed checks.
 
         Returns:
             True if the path is valid following the logic described above, or False if it is not
@@ -61,11 +55,11 @@ def _validate_document_path(self, path):
     def find_document_path(self):
         """ Find the location of this collection's documents (i.e., the raw document collection).
 
-        We first check the collection's config for a path key. If found, `self.validate_document_path` checks
-        whether the path is valid. Subclasses should override the private method `self._validate_document_path`
-        with custom logic for performing checks further than existence of the directory. See `Robust04`.
+        We first check the collection's config for a path key. If found, ``self.validate_document_path`` checks
+        whether the path is valid. Subclasses should override the private method ``self._validate_document_path``
+        with custom logic for performing checks beyond the directory's existence.
 
-        If a valid path was not found, call `download_if_missing`.
+        If a valid path was not found, call ``download_if_missing``.
         Subclasses should override this method if downloading the needed documents is possible.
 
         If a valid document path cannot be found, an exception is thrown.
@@ -78,451 +72,25 @@
         if "path" in self.config and self.validate_document_path(self.config["path"]):
             return self.config["path"]
 
+        # see if the path is hardcoded (e.g., for the dummy collection)
+        if self._path and self.validate_document_path(self._path):
+            return self._path
+
+        # if not, see if the collection can be obtained through its download_if_missing method
         return self.download_if_missing()
 
     def download_if_missing(self):
+        """ Download the collection and return its path. Subclasses should override this. """
         raise IOError(
             f"a download URL is not configured for collection={self.module_name} and the collection path does not exist; you must manually place the document collection at this path in order to use this collection"
         )
 
 
-@Collection.register
-class Robust04(Collection):
-    module_name = "robust04"
-    collection_type = "TrecCollection"
-    generator_type = "DefaultLuceneDocumentGenerator"
-    config_keys_not_in_path = ["path"]
-    config_spec = [ConfigOption("path", "Aquaint-TREC-3-4", "path to corpus")]
-
-    def download_if_missing(self):
-        return self.download_index(
-            url="https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-robust04-20191213.tar.gz",
-            sha256="dddb81f16d70ea6b9b0f94d6d6b888ed2ef827109a14ca21fd82b2acd6cbd450",
-            index_directory_inside="index-robust04-20191213/",
-            # this string should match how the index was built (i.e., Anserini, stopwords removed, Porter stemming)
-            index_cache_path_string="index-anserini_indexstops-False_stemmer-porter",
-            index_expected_document_count=528_030,
-            cachedir=self.get_cache_path(),
-        )
-
-    def _validate_document_path(self, path):
-        """ Validate that the document path appears to contain robust04's documents (Aquaint-TREC-3-4).
-
-        Validation is performed by looking for four directories (case-insensitive): `FBIS`, `FR94`, `FT`, and `LATIMES`.
-        These directories may either be at the root of `path` or they may be in `path/NEWS_data` (case-insensitive).
-
-        Returns:
-            True if the Aquaint-TREC-3-4 document directories are found or False if not
-        """
-
-        if not os.path.isdir(path):
-            return False
-
-        contents = {fn.lower(): fn for fn in os.listdir(path)}
-        if "news_data" in contents:
-            contents = {fn.lower(): fn for fn in os.listdir(os.path.join(path, contents["news_data"]))}
-
-        if "fbis" in contents and "fr94" in contents and "ft" in contents and "latimes" in contents:
-            return True
-
-        return False
-
-    def download_index(
-        self, cachedir, url, sha256, index_directory_inside, index_cache_path_string, index_expected_document_count
-    ):
-        # Download the collection from URL and extract into a path in the cache directory.
-        # To avoid re-downloading every call, we create an empty '/done' file in this directory on success.
-        done_file = os.path.join(cachedir, "done")
-        document_dir = os.path.join(cachedir, "documents")
-
-        # already downloaded?
-        if os.path.exists(done_file):
-            return document_dir
-
-        # 1. 
Download and extract Anserini index to a temporary location - tmp_dir = os.path.join(cachedir, "tmp_download") - archive_file = os.path.join(tmp_dir, "archive_file") - os.makedirs(document_dir, exist_ok=True) - os.makedirs(tmp_dir, exist_ok=True) - logger.info("downloading index for missing collection %s to temporary file %s", self.module_name, archive_file) - download_file(url, archive_file, expected_hash=sha256) - - logger.info("extracting index to %s (before moving to correct cache path)", tmp_dir) - with tarfile.open(archive_file) as tar: - tar.extractall(path=tmp_dir) - - extracted_dir = os.path.join(tmp_dir, index_directory_inside) - if not (os.path.exists(extracted_dir) and os.path.isdir(extracted_dir)): - raise ValueError(f"could not find expected index directory {extracted_dir} in {tmp_dir}") - - # 2. Move index to its correct location in the cache - index_dir = os.path.join(cachedir, index_cache_path_string, "index") - if not os.path.exists(os.path.join("index_dir", "done")): - if os.path.exists(index_dir): - shutil.rmtree(index_dir) - shutil.move(extracted_dir, index_dir) - - # 3. Extract raw documents from the Anserini index to document_dir - anserini_index_to_trec_docs(index_dir, document_dir, index_expected_document_count) - - # remove temporary files and create a /done we can use to verify extraction was successful - shutil.rmtree(tmp_dir) - with open(done_file, "wt") as outf: - print("", file=outf) - - return document_dir - - -@Collection.register -class DummyCollection(Collection): - module_name = "dummy" - _path = PACKAGE_PATH / "data" / "dummy" / "data" - collection_type = "TrecCollection" - generator_type = "DefaultLuceneDocumentGenerator" - - def _validate_document_path(self, path): - """ Validate that the document path contains `dummy_trec_doc` """ - return "dummy_trec_doc" in os.listdir(path) - - -@Collection.register -class NF(Collection): - module_name = "nf" - _path = PACKAGE_PATH / "data" / "nf-collection" - url = "http://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/nfcorpus.tar.gz" - - collection_type = "TrecCollection" - generator_type = "DefaultLuceneDocumentGenerator" - - def download_raw(self): - cachedir = self.get_cache_path() - tmp_dir = cachedir / "tmp" - tmp_tar_fn, tmp_corpus_dir = tmp_dir / "nfcorpus.tar.gz", tmp_dir / "nfcorpus" - - os.makedirs(tmp_dir, exist_ok=True) - - if not tmp_tar_fn.exists(): - download_file(self.url, tmp_tar_fn) - - with tarfile.open(tmp_tar_fn) as f: - f.extractall(tmp_dir) - return tmp_corpus_dir - - def download_if_missing(self): - cachedir = self.get_cache_path() - document_dir = os.path.join(cachedir, "documents") - coll_filename = os.path.join(document_dir, "nf-collection.txt") - if os.path.exists(coll_filename): - return document_dir - - os.makedirs(document_dir, exist_ok=True) - tmp_corpus_dir = self.download_raw() - - inp_fns = [tmp_corpus_dir / f"{set_name}.docs" for set_name in ["train", "dev", "test"]] - print(inp_fns) - with open(coll_filename, "w", encoding="utf-8") as outp_file: - self._convert_to_trec(inp_fns, outp_file) - logger.info(f"nf collection file prepared, stored at {coll_filename}") - - return document_dir - - def _convert_to_trec(self, inp_fns, outp_file): - for inp_fn in inp_fns: - assert os.path.exists(inp_fn) - - with open(inp_fn, "rt", encoding="utf-8") as f: - for line in f: - docid, doc = line.strip().split("\t") - outp_file.write(f"\n{docid}\n\n{doc}\n\n\n") - - -@Collection.register -class ANTIQUE(Collection): - module_name = "antique" - _path = PACKAGE_PATH / "data" / "antique-collection" 
- - collection_type = "TrecCollection" - generator_type = "DefaultLuceneDocumentGenerator" - - def download_if_missing(self): - url = "http://ciir.cs.umass.edu/downloads/Antique/antique-collection.txt" - cachedir = self.get_cache_path() - document_dir = os.path.join(cachedir, "documents") - coll_filename = os.path.join(document_dir, "antique-collection.txt") - - if os.path.exists(coll_filename): - return document_dir - - tmp_dir = cachedir / "tmp" - tmp_filename = os.path.join(tmp_dir, "tmp.anqique.file") - - os.makedirs(tmp_dir, exist_ok=True) - os.makedirs(document_dir, exist_ok=True) - - download_file(url, tmp_filename, expected_hash="68b6688f5f2668c93f0e8e43384f66def768c4da46da4e9f7e2629c1c47a0c36") - self._convert_to_trec(inp_path=tmp_filename, outp_path=coll_filename) - logger.info(f"antique collection file prepared, stored at {coll_filename}") - - for file in os.listdir(tmp_dir): # in case there are legacy files - os.remove(os.path.join(tmp_dir, file)) - shutil.rmtree(tmp_dir) - - return document_dir - - def _convert_to_trec(self, inp_path, outp_path): - assert os.path.exists(inp_path) - - fout = open(outp_path, "wt", encoding="utf-8") - with open(inp_path, "rt", encoding="utf-8") as f: - for line in f: - docid, doc = line.strip().split("\t") - fout.write(f"\n{docid}\n\n{doc}\n\n\n") - fout.close() - logger.debug(f"Converted file {os.path.basename(inp_path)} to TREC format, output to: {outp_path}") - - def _validate_document_path(self, path): - """ Checks that the sha256sum is correct """ - return ( - hash_file(os.path.join(path, "antique-collection.txt")) - == "409e0960f918970977ceab9e5b1d372f45395af25d53b95644bdc9ccbbf973da" - ) - - -@Collection.register -class MSMarco(Collection): - module_name = "msmarco" - config_keys_not_in_path = ["path"] - collection_type = "TrecCollection" - generator_type = "DefaultLuceneDocumentGenerator" - config_spec = [ConfigOption("path", "/GW/NeuralIR/nobackup/msmarco/trec_format", "path to corpus")] - - -@Collection.register -class CodeSearchNet(Collection): - module_name = "codesearchnet" - url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2" - collection_type = "TrecCollection" # TODO: any other supported type? 
- generator_type = "DefaultLuceneDocumentGenerator" - config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")] - - def download_if_missing(self): - cachedir = self.get_cache_path() - document_dir = cachedir / "documents" - coll_filename = document_dir / ("csn-" + self.config["lang"] + "-collection.txt") - - if coll_filename.exists(): - return document_dir - - zipfile = self.config["lang"] + ".zip" - lang_url = f"{self.url}/{zipfile}" - tmp_dir = cachedir / "tmp" - zip_path = tmp_dir / zipfile - - if zip_path.exists(): - logger.info(f"{zipfile} already exist under directory {tmp_dir}, skip downloaded") - else: - tmp_dir.mkdir(exist_ok=True, parents=True) - download_file(lang_url, zip_path) - - document_dir.mkdir(exist_ok=True, parents=True) # tmp - with ZipFile(zip_path, "r") as zipobj: - zipobj.extractall(tmp_dir) - - pkl_path = tmp_dir / (self.config["lang"] + "_dedupe_definitions_v2.pkl") - self._pkl2trec(pkl_path, coll_filename) - return document_dir - - def _pkl2trec(self, pkl_path, trec_path): - lang = self.config["lang"] - with open(pkl_path, "rb") as f: - codes = pickle.load(f) - - fout = open(trec_path, "w", encoding="utf-8") - for i, code in tqdm(enumerate(codes), desc=f"Preparing the {lang} collection file"): - docno = f"{lang}-FUNCTION-{i}" - doc = remove_newline(" ".join(code["function_tokens"])) - fout.write(document_to_trectxt(docno, doc)) - fout.close() - - -@Collection.register -class COVID(Collection): - module_name = "covid" - url = "https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/historical_releases/cord-19_%s.tar.gz" - generator_type = "Cord19Generator" - config_spec = [ConfigOption("coll_type", "abstract", "one of: abstract, fulltext, paragraph"), ConfigOption("round", 3)] - - def build(self): - coll_type, round = self.config["coll_type"], self.config["round"] - type2coll = { - "abstract": "Cord19AbstractCollection", - "fulltext": "Cord19FullTextCollection", - "paragraph": "Cord19ParagraphCollection", - } - dates = ["2020-04-10", "2020-05-01", "2020-05-19"] - - if coll_type not in type2coll: - raise ValueError(f"Unexpected coll_type: {coll_type}; expeced one of: {' '.join(type2coll.keys())}") - if round > len(dates): - raise ValueError(f"Unexpected round number: {round}; only {len(dates)} number of rounds are provided") - - self.collection_type = type2coll[coll_type] - self.date = dates[round - 1] - - def download_if_missing(self): - cachedir = self.get_cache_path() - tmp_dir, document_dir = Path("/tmp"), cachedir / "documents" - expected_fns = [document_dir / "metadata.csv", document_dir / "document_parses"] - if all([os.path.exists(f) for f in expected_fns]): - return document_dir - - url = self.url % self.date - tar_file = tmp_dir / f"covid-19-{self.date}.tar.gz" - if not tar_file.exists(): - download_file(url, tar_file) - - with tarfile.open(tar_file) as f: - f.extractall(path=cachedir) # emb.tar.gz, metadata.csv, doc.tar.gz, changelog - os.rename(cachedir / self.date, document_dir) - - doc_fn = "document_parses" - if f"{doc_fn}.tar.gz" in os.listdir(document_dir): - with tarfile.open(document_dir / f"{doc_fn}.tar.gz") as f: - f.extractall(path=document_dir) - else: - self.transform_metadata(document_dir) - - # only document_parses and metadata.csv are expected - for fn in os.listdir(document_dir): - if (document_dir / fn) not in expected_fns: - os.remove(document_dir / fn) - return document_dir - - def transform_metadata(self, root_path): - """ - the transformation is necessary for dataset round 1 and 2 according to - 
https://discourse.cord-19.semanticscholar.org/t/faqs-about-cord-19-dataset/94 - - the assumed directory under root_path: - ./root_path - ./metadata.csv - ./comm_use_subset - ./noncomm_use_subset - ./custom_license - ./biorxiv_medrxiv - ./archive - - In a nutshell: - 1. renaming: - Microsoft Academic Paper ID -> mag_id; - WHO #Covidence -> who_covidence_id - 2. update: - has_pdf_parse -> pdf_json_files # e.g. document_parses/pmc_json/PMC125340.xml.json - has_pmc_xml_parse -> pmc_json_files - """ - metadata_csv = str(root_path / "metadata.csv") - orifiles = ["arxiv", "custom_license", "biorxiv_medrxiv", "comm_use_subset", "noncomm_use_subset"] - for fn in orifiles: - if (root_path / fn).exists(): - continue - - tar_fn = root_path / f"{fn}.tar.gz" - if not tar_fn.exists(): - continue - - with tarfile.open(str(tar_fn)) as f: - f.extractall(path=root_path) - os.remove(tar_fn) - - metadata = pd.read_csv(metadata_csv, header=0) - columns = metadata.columns.values - cols_before = [ - "cord_uid", - "sha", - "source_x", - "title", - "doi", - "pmcid", - "pubmed_id", - "license", - "abstract", - "publish_time", - "authors", - "journal", - "Microsoft Academic Paper ID", - "WHO #Covidence", - "arxiv_id", - "has_pdf_parse", - "has_pmc_xml_parse", - "full_text_file", - "url", - ] - assert all(columns == cols_before) - - # step 1: rename column - cols_to_rename = {"Microsoft Academic Paper ID": "mag_id", "WHO #Covidence": "who_covidence_id"} - metadata.columns = [cols_to_rename.get(c, c) for c in columns] - - # step 2: parse path & move json file - doc_outp = root_path / "document_parses" - pdf_dir, pmc_dir = doc_outp / "pdf_json", doc_outp / "pmc_json" - pdf_dir.mkdir(exist_ok=True, parents=True) - pmc_dir.mkdir(exist_ok=True, parents=True) - - new_cols = ["pdf_json_files", "pmc_json_files"] - for col in new_cols: - metadata[col] = "" - metadata["s2_id"] = math.nan # tmp, what's this column?? 
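Step 1 above renames columns by mapping each name through cols_to_rename with a dict .get fallback, which renames the two listed columns and passes every other column through unchanged. A self-contained toy illustration of the idiom:

import pandas as pd

meta = pd.DataFrame(columns=["cord_uid", "Microsoft Academic Paper ID", "WHO #Covidence"])
renames = {"Microsoft Academic Paper ID": "mag_id", "WHO #Covidence": "who_covidence_id"}
meta.columns = [renames.get(c, c) for c in meta.columns]
assert list(meta.columns) == ["cord_uid", "mag_id", "who_covidence_id"]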
- - iterbar = tqdm(desc="transforming data", total=len(metadata)) - for i, row in metadata.iterrows(): - dir = row["full_text_file"] - - if row["has_pmc_xml_parse"]: - name = row["pmcid"] + ".xml.json" - ori_fn = root_path / dir / "pmc_json" / name - pmc_fn = f"document_parses/pmc_json/{name}" - metadata.at[i, "pmc_json_files"] = pmc_fn - pmc_fn = root_path / pmc_fn - if not pmc_fn.exists(): - os.rename(ori_fn, pmc_fn) # check - else: - metadata.at[i, "pmc_json_files"] = math.nan - - if row["has_pdf_parse"]: - shas = str(row["sha"]).split(";") - pdf_fn_final = "" - for sha in shas: - name = sha.strip() + ".json" - ori_fn = root_path / dir / "pdf_json" / name - pdf_fn = f"document_parses/pdf_json/{name}" - pdf_fn_final = f"{pdf_fn_final};{pdf_fn}" if pdf_fn_final else pdf_fn - pdf_fn = root_path / pdf_fn - if not pdf_fn.exists(): - os.rename(ori_fn, pdf_fn) # check - else: - if ori_fn.exists(): - assert filecmp.cmp(ori_fn, pdf_fn) - os.remove(ori_fn) - - metadata.at[i, "pdf_json_files"] = pdf_fn_final - else: - metadata.at[i, "pdf_json_files"] = math.nan - - iterbar.update() - - # step 3: remove deprecated columns, remove unwanted directories - cols_to_remove = ["has_pdf_parse", "has_pmc_xml_parse", "full_text_file"] - metadata.drop(columns=cols_to_remove) +from profane import import_all_modules - dir_to_remove = ["comm_use_subset", "noncomm_use_subset", "custom_license", "biorxiv_medrxiv", "arxiv"] - for dir in dir_to_remove: - dir = root_path / dir - for subdir in os.listdir(dir): - os.rmdir(dir / subdir) # since we are supposed to move away all the files - os.rmdir(dir) +from .dummy import DummyCollection +from .antique import ANTIQUE +from .nf import NF +from .robust04 import Robust04 - # assert len(metadata.columns) == 19 - # step 4: save back - metadata.to_csv(metadata_csv, index=False) +import_all_modules(__file__, __package__) diff --git a/capreolus/collection/antique.py b/capreolus/collection/antique.py new file mode 100644 index 000000000..a8eca5be3 --- /dev/null +++ b/capreolus/collection/antique.py @@ -0,0 +1,67 @@ +import os +import shutil + +from . import Collection +from capreolus import ModuleBase, Dependency, ConfigOption, constants +from capreolus.utils.common import download_file, hash_file +from capreolus.utils.loginit import get_logger + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Collection.register +class ANTIQUE(Collection): + """A Non-factoid Question Answering Benchmark from Hashemi et al. [1] + + [1] Helia Hashemi, Mohammad Aliannejadi, Hamed Zamani, and W. Bruce Croft. 2020. ANTIQUE: A non-factoid question answering benchmark. ECIR 2020. 
+ """ + + module_name = "antique" + _path = PACKAGE_PATH / "data" / "antique-collection" + + collection_type = "TrecCollection" + generator_type = "DefaultLuceneDocumentGenerator" + + def download_if_missing(self): + url = "http://ciir.cs.umass.edu/downloads/Antique/antique-collection.txt" + cachedir = self.get_cache_path() + document_dir = os.path.join(cachedir, "documents") + coll_filename = os.path.join(document_dir, "antique-collection.txt") + + if os.path.exists(coll_filename): + return document_dir + + tmp_dir = cachedir / "tmp" + tmp_filename = os.path.join(tmp_dir, "tmp.anqique.file") + + os.makedirs(tmp_dir, exist_ok=True) + os.makedirs(document_dir, exist_ok=True) + + download_file(url, tmp_filename, expected_hash="68b6688f5f2668c93f0e8e43384f66def768c4da46da4e9f7e2629c1c47a0c36") + self._convert_to_trec(inp_path=tmp_filename, outp_path=coll_filename) + logger.info(f"antique collection file prepared, stored at {coll_filename}") + + for file in os.listdir(tmp_dir): # in case there are legacy files + os.remove(os.path.join(tmp_dir, file)) + shutil.rmtree(tmp_dir) + + return document_dir + + def _convert_to_trec(self, inp_path, outp_path): + assert os.path.exists(inp_path) + + fout = open(outp_path, "wt", encoding="utf-8") + with open(inp_path, "rt", encoding="utf-8") as f: + for line in f: + docid, doc = line.strip().split("\t") + fout.write(f"\n{docid}\n\n{doc}\n\n\n") + fout.close() + logger.debug(f"Converted file {os.path.basename(inp_path)} to TREC format, output to: {outp_path}") + + def _validate_document_path(self, path): + """ Checks that the sha256sum is correct """ + return ( + hash_file(os.path.join(path, "antique-collection.txt")) + == "409e0960f918970977ceab9e5b1d372f45395af25d53b95644bdc9ccbbf973da" + ) diff --git a/capreolus/collection/codesearchnet.py b/capreolus/collection/codesearchnet.py new file mode 100644 index 000000000..48c0f3675 --- /dev/null +++ b/capreolus/collection/codesearchnet.py @@ -0,0 +1,66 @@ +import pickle + +from tqdm import tqdm +from zipfile import ZipFile + +from . import Collection +from capreolus import ModuleBase, Dependency, ConfigOption, constants +from capreolus.utils.common import download_file, hash_file, remove_newline +from capreolus.utils.loginit import get_logger +from capreolus.utils.trec import document_to_trectxt + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Collection.register +class CodeSearchNet(Collection): + """CodeSearchNet Corpus. [1] + + [1] Hamel Husain, Ho-Hsiang Wu, Tiferet Gazit, Miltiadis Allamanis, and Marc Brockschmidt. 2019. CodeSearchNet Challenge: Evaluating the State of Semantic Code Search. arXiv 2019. + """ + + module_name = "codesearchnet" + url = "https://s3.amazonaws.com/code-search-net/CodeSearchNet/v2" + collection_type = "TrecCollection" # TODO: any other supported type? 
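ANTIQUE's integrity checks above are two-sided: download_file verifies the transfer against expected_hash, and _validate_document_path re-hashes the on-disk file before trusting a cached copy. hash_file in capreolus.utils.common presumably computes a chunked SHA-256 digest along these lines (the function name and buffer size below are ours):

import hashlib

def sha256_of_file(path, bufsize=1 << 20):
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        # hash in 1 MiB chunks so large collection files never sit in memory at once
        for chunk in iter(lambda: f.read(bufsize), b""):
            digest.update(chunk)
    return digest.hexdigest()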
+ generator_type = "DefaultLuceneDocumentGenerator" + config_spec = [ConfigOption("lang", "ruby", "CSN language dataset to use")] + + def download_if_missing(self): + cachedir = self.get_cache_path() + document_dir = cachedir / "documents" + coll_filename = document_dir / ("csn-" + self.config["lang"] + "-collection.txt") + + if coll_filename.exists(): + return document_dir + + zipfile = self.config["lang"] + ".zip" + lang_url = f"{self.url}/{zipfile}" + tmp_dir = cachedir / "tmp" + zip_path = tmp_dir / zipfile + + if zip_path.exists(): + logger.info(f"{zipfile} already exist under directory {tmp_dir}, skip downloaded") + else: + tmp_dir.mkdir(exist_ok=True, parents=True) + download_file(lang_url, zip_path) + + document_dir.mkdir(exist_ok=True, parents=True) # tmp + with ZipFile(zip_path, "r") as zipobj: + zipobj.extractall(tmp_dir) + + pkl_path = tmp_dir / (self.config["lang"] + "_dedupe_definitions_v2.pkl") + self._pkl2trec(pkl_path, coll_filename) + return document_dir + + def _pkl2trec(self, pkl_path, trec_path): + lang = self.config["lang"] + with open(pkl_path, "rb") as f: + codes = pickle.load(f) + + fout = open(trec_path, "w", encoding="utf-8") + for i, code in tqdm(enumerate(codes), desc=f"Preparing the {lang} collection file"): + docno = f"{lang}-FUNCTION-{i}" + doc = remove_newline(" ".join(code["function_tokens"])) + fout.write(document_to_trectxt(docno, doc)) + fout.close() diff --git a/capreolus/collection/covid.py b/capreolus/collection/covid.py new file mode 100644 index 000000000..324d3c174 --- /dev/null +++ b/capreolus/collection/covid.py @@ -0,0 +1,200 @@ +import os +import math +import tarfile +import filecmp + +from tqdm import tqdm +from pathlib import Path +import pandas as pd + +from . import Collection +from capreolus import ModuleBase, Dependency, ConfigOption, constants +from capreolus.utils.common import download_file, hash_file, remove_newline +from capreolus.utils.loginit import get_logger + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Collection.register +class COVID(Collection): + """ The COVID-19 Open Research Dataset (https://www.semanticscholar.org/cord19) """ + + module_name = "covid" + url = "https://ai2-semanticscholar-cord-19.s3-us-west-2.amazonaws.com/historical_releases/cord-19_%s.tar.gz" + generator_type = "Cord19Generator" + config_spec = [ConfigOption("coll_type", "abstract", "one of: abstract, fulltext, paragraph"), ConfigOption("round", 3)] + + def build(self): + coll_type, round = self.config["coll_type"], self.config["round"] + type2coll = { + "abstract": "Cord19AbstractCollection", + "fulltext": "Cord19FullTextCollection", + "paragraph": "Cord19ParagraphCollection", + } + dates = ["2020-04-10", "2020-05-01", "2020-05-19"] + + if coll_type not in type2coll: + raise ValueError(f"Unexpected coll_type: {coll_type}; expeced one of: {' '.join(type2coll.keys())}") + if round > len(dates): + raise ValueError(f"Unexpected round number: {round}; only {len(dates)} number of rounds are provided") + + self.collection_type = type2coll[coll_type] + self.date = dates[round - 1] + + def download_if_missing(self): + cachedir = self.get_cache_path() + tmp_dir, document_dir = Path("/tmp"), cachedir / "documents" + expected_fns = [document_dir / "metadata.csv", document_dir / "document_parses"] + if all([os.path.exists(f) for f in expected_fns]): + return document_dir + + url = self.url % self.date + tar_file = tmp_dir / f"covid-19-{self.date}.tar.gz" + if not tar_file.exists(): + download_file(url, tar_file) + + with 
tarfile.open(tar_file) as f: + f.extractall(path=cachedir) # emb.tar.gz, metadata.csv, doc.tar.gz, changelog + os.rename(cachedir / self.date, document_dir) + + doc_fn = "document_parses" + if f"{doc_fn}.tar.gz" in os.listdir(document_dir): + with tarfile.open(document_dir / f"{doc_fn}.tar.gz") as f: + f.extractall(path=document_dir) + else: + self.transform_metadata(document_dir) + + # only document_parses and metadata.csv are expected + for fn in os.listdir(document_dir): + if (document_dir / fn) not in expected_fns: + os.remove(document_dir / fn) + return document_dir + + def transform_metadata(self, root_path): + """ + the transformation is necessary for dataset round 1 and 2 according to + https://discourse.cord-19.semanticscholar.org/t/faqs-about-cord-19-dataset/94 + + the assumed directory under root_path: + ./root_path + ./metadata.csv + ./comm_use_subset + ./noncomm_use_subset + ./custom_license + ./biorxiv_medrxiv + ./archive + + In a nutshell: + 1. renaming: + Microsoft Academic Paper ID -> mag_id; + WHO #Covidence -> who_covidence_id + 2. update: + has_pdf_parse -> pdf_json_files # e.g. document_parses/pmc_json/PMC125340.xml.json + has_pmc_xml_parse -> pmc_json_files + """ + metadata_csv = str(root_path / "metadata.csv") + orifiles = ["arxiv", "custom_license", "biorxiv_medrxiv", "comm_use_subset", "noncomm_use_subset"] + for fn in orifiles: + if (root_path / fn).exists(): + continue + + tar_fn = root_path / f"{fn}.tar.gz" + if not tar_fn.exists(): + continue + + with tarfile.open(str(tar_fn)) as f: + f.extractall(path=root_path) + os.remove(tar_fn) + + metadata = pd.read_csv(metadata_csv, header=0) + columns = metadata.columns.values + cols_before = [ + "cord_uid", + "sha", + "source_x", + "title", + "doi", + "pmcid", + "pubmed_id", + "license", + "abstract", + "publish_time", + "authors", + "journal", + "Microsoft Academic Paper ID", + "WHO #Covidence", + "arxiv_id", + "has_pdf_parse", + "has_pmc_xml_parse", + "full_text_file", + "url", + ] + assert all(columns == cols_before) + + # step 1: rename column + cols_to_rename = {"Microsoft Academic Paper ID": "mag_id", "WHO #Covidence": "who_covidence_id"} + metadata.columns = [cols_to_rename.get(c, c) for c in columns] + + # step 2: parse path & move json file + doc_outp = root_path / "document_parses" + pdf_dir, pmc_dir = doc_outp / "pdf_json", doc_outp / "pmc_json" + pdf_dir.mkdir(exist_ok=True, parents=True) + pmc_dir.mkdir(exist_ok=True, parents=True) + + new_cols = ["pdf_json_files", "pmc_json_files"] + for col in new_cols: + metadata[col] = "" + metadata["s2_id"] = math.nan # tmp, what's this column?? 
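The iterrows loop below rebuilds the document_parses/pdf_json paths from the semicolon-delimited sha column: each sha yields one JSON filename, and the results are re-joined with semicolons into pdf_json_files. The core string manipulation, factored into a pure function for clarity (the function name is ours):

def pdf_json_paths(sha_field):
    shas = [sha.strip() for sha in str(sha_field).split(";")]
    return ";".join(f"document_parses/pdf_json/{sha}.json" for sha in shas)

assert pdf_json_paths("abc; def") == "document_parses/pdf_json/abc.json;document_parses/pdf_json/def.json"

One pandas caveat for step 3 further down: DataFrame.drop(columns=...) returns a new frame rather than mutating in place, so its result must be assigned (or inplace=True passed) for the deprecated columns to actually be gone when metadata.to_csv runs.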
+ + iterbar = tqdm(desc="transforming data", total=len(metadata)) + for i, row in metadata.iterrows(): + dir = row["full_text_file"] + + if row["has_pmc_xml_parse"]: + name = row["pmcid"] + ".xml.json" + ori_fn = root_path / dir / "pmc_json" / name + pmc_fn = f"document_parses/pmc_json/{name}" + metadata.at[i, "pmc_json_files"] = pmc_fn + pmc_fn = root_path / pmc_fn + if not pmc_fn.exists(): + os.rename(ori_fn, pmc_fn) # check + else: + metadata.at[i, "pmc_json_files"] = math.nan + + if row["has_pdf_parse"]: + shas = str(row["sha"]).split(";") + pdf_fn_final = "" + for sha in shas: + name = sha.strip() + ".json" + ori_fn = root_path / dir / "pdf_json" / name + pdf_fn = f"document_parses/pdf_json/{name}" + pdf_fn_final = f"{pdf_fn_final};{pdf_fn}" if pdf_fn_final else pdf_fn + pdf_fn = root_path / pdf_fn + if not pdf_fn.exists(): + os.rename(ori_fn, pdf_fn) # check + else: + if ori_fn.exists(): + assert filecmp.cmp(ori_fn, pdf_fn) + os.remove(ori_fn) + + metadata.at[i, "pdf_json_files"] = pdf_fn_final + else: + metadata.at[i, "pdf_json_files"] = math.nan + + iterbar.update() + + # step 3: remove deprecated columns, remove unwanted directories + cols_to_remove = ["has_pdf_parse", "has_pmc_xml_parse", "full_text_file"] + metadata.drop(columns=cols_to_remove) + + dir_to_remove = ["comm_use_subset", "noncomm_use_subset", "custom_license", "biorxiv_medrxiv", "arxiv"] + for dir in dir_to_remove: + dir = root_path / dir + for subdir in os.listdir(dir): + os.rmdir(dir / subdir) # since we are supposed to move away all the files + os.rmdir(dir) + + # assert len(metadata.columns) == 19 + # step 4: save back + metadata.to_csv(metadata_csv, index=False) diff --git a/capreolus/collection/dummy.py b/capreolus/collection/dummy.py new file mode 100644 index 000000000..baa9d41ca --- /dev/null +++ b/capreolus/collection/dummy.py @@ -0,0 +1,21 @@ +import os + +from . import Collection +from capreolus import constants, get_logger + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Collection.register +class DummyCollection(Collection): + """ Tiny collection for testing """ + + module_name = "dummy" + _path = PACKAGE_PATH / "data" / "dummy" / "data" + collection_type = "TrecCollection" + generator_type = "DefaultLuceneDocumentGenerator" + + def _validate_document_path(self, path): + """ Validate that the document path contains `dummy_trec_doc` """ + return "dummy_trec_doc" in os.listdir(path) diff --git a/capreolus/collection/msmarco.py b/capreolus/collection/msmarco.py new file mode 100644 index 000000000..e1f8ecbe4 --- /dev/null +++ b/capreolus/collection/msmarco.py @@ -0,0 +1,10 @@ +from . import Collection +from capreolus import ConfigOption + +# @Collection.register +# class MSMarco(Collection): +# module_name = "msmarco" +# config_keys_not_in_path = ["path"] +# collection_type = "TrecCollection" +# generator_type = "DefaultLuceneDocumentGenerator" +# config_spec = [ConfigOption("path", "/GW/NeuralIR/nobackup/msmarco/trec_format", "path to corpus")] diff --git a/capreolus/collection/nf.py b/capreolus/collection/nf.py new file mode 100644 index 000000000..a1afc40cd --- /dev/null +++ b/capreolus/collection/nf.py @@ -0,0 +1,71 @@ +import os +import tarfile + +from . 
import Collection +from capreolus import constants, get_logger +from capreolus.utils.common import download_file + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Collection.register +class NF(Collection): + """ NFCorpus: A Full-Text Learning to Rank Dataset for Medical Information Retrieval [1] + + [1] Vera Boteva, Demian Gholipour, Artem Sokolov and Stefan Riezler. A Full-Text Learning to Rank Dataset for Medical Information Retrieval Proceedings of the 38th European Conference on Information Retrieval (ECIR), Padova, Italy, 2016. https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/ + """ + + module_name = "nf" + _path = PACKAGE_PATH / "data" / "nf-collection" + url = "http://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/nfcorpus.tar.gz" + + collection_type = "TrecCollection" + generator_type = "DefaultLuceneDocumentGenerator" + + def download_raw(self): + cachedir = self.get_cache_path() + tmp_dir = cachedir / "tmp" + tmp_tar_fn, tmp_corpus_dir = tmp_dir / "nfcorpus.tar.gz", tmp_dir / "nfcorpus" + + os.makedirs(tmp_dir, exist_ok=True) + + if not tmp_tar_fn.exists(): + download_file(self.url, tmp_tar_fn, "ebc026d4a8bef3f866148b727e945a2073eb4045ede9b7de95dd50fd086b4256") + + with tarfile.open(tmp_tar_fn) as f: + f.extractall(tmp_dir) + return tmp_corpus_dir + + def download_if_missing(self): + cachedir = self.get_cache_path() + document_dir = os.path.join(cachedir, "documents") + coll_filename = os.path.join(document_dir, "nf-collection.txt") + if os.path.exists(coll_filename): + return document_dir + + os.makedirs(document_dir, exist_ok=True) + tmp_corpus_dir = self.download_raw() + + inp_fns = [tmp_corpus_dir / f"{set_name}.docs" for set_name in ["train", "dev", "test"]] + print(inp_fns) + with open(coll_filename, "w", encoding="utf-8") as outp_file: + self._convert_to_trec(inp_fns, outp_file) + logger.info(f"nf collection file prepared, stored at {coll_filename}") + + return document_dir + + def _convert_to_trec(self, inp_fns, outp_file): + # train.docs, dev.docs, and test.docs have some overlap, so we check for duplicate docids + seen_docids = set() + + for inp_fn in inp_fns: + assert os.path.exists(inp_fn) + + with open(inp_fn, "rt", encoding="utf-8") as f: + for line in f: + docid, doc = line.strip().split("\t") + + if docid not in seen_docids: + outp_file.write(f"\n{docid}\n\n{doc}\n\n\n") + seen_docids.add(docid) diff --git a/capreolus/collection/robust04.py b/capreolus/collection/robust04.py new file mode 100644 index 000000000..b5ac211c0 --- /dev/null +++ b/capreolus/collection/robust04.py @@ -0,0 +1,101 @@ +import os +import shutil +import tarfile + +from . 
import Collection +from capreolus import ModuleBase, Dependency, ConfigOption, constants +from capreolus.utils.common import download_file, hash_file, remove_newline +from capreolus.utils.loginit import get_logger +from capreolus.utils.trec import anserini_index_to_trec_docs, document_to_trectxt + +logger = get_logger(__name__) +PACKAGE_PATH = constants["PACKAGE_PATH"] + + +@Collection.register +class Robust04(Collection): + """ TREC Robust04 (TREC disks 4 and 5 without the Congressional Record documents) """ + + module_name = "robust04" + collection_type = "TrecCollection" + generator_type = "DefaultLuceneDocumentGenerator" + config_keys_not_in_path = ["path"] + config_spec = [ConfigOption("path", "Aquaint-TREC-3-4", "path to corpus")] + + def download_if_missing(self): + return self.download_index( + url="https://git.uwaterloo.ca/jimmylin/anserini-indexes/raw/master/index-robust04-20191213.tar.gz", + sha256="dddb81f16d70ea6b9b0f94d6d6b888ed2ef827109a14ca21fd82b2acd6cbd450", + index_directory_inside="index-robust04-20191213/", + # this string should match how the index was built (i.e., Anserini, stopwords removed, Porter stemming) + index_cache_path_string="index-anserini_indexstops-False_stemmer-porter", + index_expected_document_count=528_030, + cachedir=self.get_cache_path(), + ) + + def _validate_document_path(self, path): + """ Validate that the document path appears to contain robust04's documents (Aquaint-TREC-3-4). + + Validation is performed by looking for four directories (case-insensitive): `FBIS`, `FR94`, `FT`, and `LATIMES`. + These directories may either be at the root of `path` or they may be in `path/NEWS_data` (case-insensitive). + + Returns: + True if the Aquaint-TREC-3-4 document directories are found or False if not + """ + + if not os.path.isdir(path): + return False + + contents = {fn.lower(): fn for fn in os.listdir(path)} + if "news_data" in contents: + contents = {fn.lower(): fn for fn in os.listdir(os.path.join(path, contents["news_data"]))} + + if "fbis" in contents and "fr94" in contents and "ft" in contents and "latimes" in contents: + return True + + return False + + def download_index( + self, cachedir, url, sha256, index_directory_inside, index_cache_path_string, index_expected_document_count + ): + # Download the collection from URL and extract into a path in the cache directory. + # To avoid re-downloading every call, we create an empty '/done' file in this directory on success. + done_file = os.path.join(cachedir, "done") + document_dir = os.path.join(cachedir, "documents") + + # already downloaded? + if os.path.exists(done_file): + return document_dir + + # 1. Download and extract Anserini index to a temporary location + tmp_dir = os.path.join(cachedir, "tmp_download") + archive_file = os.path.join(tmp_dir, "archive_file") + os.makedirs(document_dir, exist_ok=True) + os.makedirs(tmp_dir, exist_ok=True) + logger.info("downloading index for missing collection %s to temporary file %s", self.module_name, archive_file) + download_file(url, archive_file, expected_hash=sha256) + + logger.info("extracting index to %s (before moving to correct cache path)", tmp_dir) + with tarfile.open(archive_file) as tar: + tar.extractall(path=tmp_dir) + + extracted_dir = os.path.join(tmp_dir, index_directory_inside) + if not (os.path.exists(extracted_dir) and os.path.isdir(extracted_dir)): + raise ValueError(f"could not find expected index directory {extracted_dir} in {tmp_dir}") + + # 2. 
Move index to its correct location in the cache + index_dir = os.path.join(cachedir, index_cache_path_string, "index") + if not os.path.exists(os.path.join("index_dir", "done")): + if os.path.exists(index_dir): + shutil.rmtree(index_dir) + shutil.move(extracted_dir, index_dir) + + # 3. Extract raw documents from the Anserini index to document_dir + anserini_index_to_trec_docs(index_dir, document_dir, index_expected_document_count) + + # remove temporary files and create a /done we can use to verify extraction was successful + shutil.rmtree(tmp_dir) + with open(done_file, "wt") as outf: + print("", file=outf) + + return document_dir diff --git a/capreolus/extractor/__init__.py b/capreolus/extractor/__init__.py index 5590bf2ea..0f05df178 100644 --- a/capreolus/extractor/__init__.py +++ b/capreolus/extractor/__init__.py @@ -1,27 +1,18 @@ -from profane import import_all_modules - -# import_all_modules(__file__, __package__) - -import pickle -from collections import defaultdict -import tensorflow as tf - -import os -import numpy as np import hashlib -from pymagnitude import Magnitude, MagnitudeUtils -from tqdm import tqdm -from profane import ModuleBase, Dependency, ConfigOption, constants - +import os -from capreolus.utils.loginit import get_logger -from capreolus.utils.common import padlist -from capreolus.utils.exceptions import MissingDocError +from capreolus import ModuleBase, get_logger logger = get_logger(__name__) class Extractor(ModuleBase): + """Base class for Extractor modules. The purpose of an Extractor is to convert queries and documents to a representation suitable for use with a :class:`~capreolus.reranker.Reranker` module. + + Modules should provide: + - an ``id2vec(qid, posid, negid=None)`` method that converts the given query and document ids to an appropriate representation + """ + module_type = "extractor" def _extend_stoi(self, toks_list, calc_idf=False): @@ -75,349 +66,7 @@ def build_from_benchmark(self, *args, **kwargs): raise NotImplementedError -@Extractor.register -class EmbedText(Extractor): - module_name = "embedtext" - requires_random_seed = True - dependencies = [ - Dependency( - key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"} - ), - Dependency(key="tokenizer", module="tokenizer", name="anserini"), - ] - config_spec = [ - ConfigOption("embeddings", "glove6b"), - ConfigOption("zerounk", False), - ConfigOption("calcidf", True), - ConfigOption("maxqlen", 4), - ConfigOption("maxdoclen", 800), - ConfigOption("usecache", False), - ] - - pad = 0 - pad_tok = "" - embed_paths = { - "glove6b": "glove/light/glove.6B.300d", - "glove6b.50d": "glove/light/glove.6B.50d", - "w2vnews": "word2vec/light/GoogleNews-vectors-negative300", - "fasttext": "fasttext/light/wiki-news-300d-1M-subword", - } - - def _get_pretrained_emb(self): - magnitude_cache = constants["CACHE_BASE_PATH"] / "magnitude/" - return Magnitude(MagnitudeUtils.download_model(self.embed_paths[self.config["embeddings"]], download_dir=magnitude_cache)) - - def load_state(self, qids, docids): - with open(self.get_state_cache_file_path(qids, docids), "rb") as f: - state_dict = pickle.load(f) - self.qid2toks = state_dict["qid2toks"] - self.docid2toks = state_dict["docid2toks"] - self.stoi = state_dict["stoi"] - self.itos = state_dict["itos"] - - def cache_state(self, qids, docids): - os.makedirs(self.get_cache_path(), exist_ok=True) - with open(self.get_state_cache_file_path(qids, docids), "wb") as f: - state_dict = {"qid2toks": self.qid2toks, "docid2toks": 
self.docid2toks, "stoi": self.stoi, "itos": self.itos} - pickle.dump(state_dict, f, protocol=-1) - - def get_tf_feature_description(self): - feature_description = { - "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64), - "query_idf": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.float32), - "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), - "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), - "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], dtype=tf.float32)), - } - - return feature_description - - def create_tf_feature(self, sample): - """ - sample - output from self.id2vec() - return - a tensorflow feature - """ - query, query_idf, posdoc, negdoc = (sample["query"], sample["query_idf"], sample["posdoc"], sample["negdoc"]) - feature = { - "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)), - "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)), - "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)), - "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)), - } - - return feature - - def parse_tf_example(self, example_proto): - feature_description = self.get_tf_feature_description() - parsed_example = tf.io.parse_example(example_proto, feature_description) - posdoc = parsed_example["posdoc"] - negdoc = parsed_example["negdoc"] - query = parsed_example["query"] - query_idf = parsed_example["query_idf"] - label = parsed_example["label"] - - return (posdoc, negdoc, query, query_idf), label - - def _build_vocab(self, qids, docids, topics): - if self.is_state_cached(qids, docids) and self.config["usecache"]: - self.load_state(qids, docids) - logger.info("Vocabulary loaded from cache") - else: - tokenize = self.tokenizer.tokenize - self.qid2toks = {qid: tokenize(topics[qid]) for qid in qids} - self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in docids} - self._extend_stoi(self.qid2toks.values(), calc_idf=self.config["calcidf"]) - self._extend_stoi(self.docid2toks.values(), calc_idf=self.config["calcidf"]) - self.itos = {i: s for s, i in self.stoi.items()} - logger.info(f"vocabulary constructed, with {len(self.itos)} terms in total") - if self.config["usecache"]: - self.cache_state(qids, docids) - - def _get_idf(self, toks): - return [self.idf.get(tok, 0) for tok in toks] - - def _build_embedding_matrix(self): - assert len(self.stoi) > 1 # needs more vocab than self.pad_tok - - magnitude_emb = self._get_pretrained_emb() - emb_dim = magnitude_emb.dim - embed_vocab = set(term for term, _ in magnitude_emb) - embed_matrix = np.zeros((len(self.stoi), emb_dim), dtype=np.float32) - - n_missed = 0 - for term, idx in tqdm(self.stoi.items()): - if term in embed_vocab: - embed_matrix[idx] = magnitude_emb.query(term) - elif term == self.pad_tok: - embed_matrix[idx] = np.zeros(emb_dim) - else: - n_missed += 1 - embed_matrix[idx] = np.zeros(emb_dim) if self.config["zerounk"] else np.random.normal(scale=0.5, size=emb_dim) - - logger.info(f"embedding matrix {self.config['embeddings']} constructed, with shape {embed_matrix.shape}") - if n_missed > 0: - logger.warning(f"{n_missed}/{len(self.stoi)} (%.3f) term missed" % (n_missed / len(self.stoi))) - - self.embeddings = embed_matrix - - def exist(self): - return ( - hasattr(self, "embeddings") - and self.embeddings is not None - and isinstance(self.embeddings, np.ndarray) - and 0 < len(self.stoi) == self.embeddings.shape[0] - ) - - def preprocess(self, qids, 
docids, topics): - if self.exist(): - return - - self.index.create_index() - - self.itos = {self.pad: self.pad_tok} - self.stoi = {self.pad_tok: self.pad} - self.qid2toks = defaultdict(list) - self.docid2toks = defaultdict(list) - self.idf = defaultdict(lambda: 0) - self.embeddings = None - # self.cache = self.load_cache() # TODO - - self._build_vocab(qids, docids, topics) - self._build_embedding_matrix() - - def _tok2vec(self, toks): - # return [self.embeddings[self.stoi[tok]] for tok in toks] - return [self.stoi[tok] for tok in toks] - - def id2vec(self, qid, posid, negid=None): - query = self.qid2toks[qid] - - # TODO find a way to calculate qlen/doclen stats earlier, so we can log them and check sanity of our values - qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"] - posdoc = self.docid2toks.get(posid, None) - if not posdoc: - raise MissingDocError(qid, posid) - - idfs = padlist(self._get_idf(query), qlen, 0) - query = self._tok2vec(padlist(query, qlen, self.pad_tok)) - posdoc = self._tok2vec(padlist(posdoc, doclen, self.pad_tok)) - - # TODO determine whether pin_memory is happening. may not be because we don't place the strings in a np or torch object - data = { - "qid": qid, - "posdocid": posid, - "idfs": np.array(idfs, dtype=np.float32), - "query": np.array(query, dtype=np.long), - "posdoc": np.array(posdoc, dtype=np.long), - "query_idf": np.array(idfs, dtype=np.float32), - "negdocid": "", - "negdoc": np.zeros(self.config["maxdoclen"], dtype=np.long), - } - - if negid: - negdoc = self.docid2toks.get(negid, None) - if not negdoc: - raise MissingDocError(qid, negid) - - negdoc = self._tok2vec(padlist(negdoc, doclen, self.pad_tok)) - data["negdocid"] = negid - data["negdoc"] = np.array(negdoc, dtype=np.long) - - return data - - -@Extractor.register -class BertText(Extractor): - module_name = "berttext" - dependencies = [ - Dependency( - key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"} - ), - Dependency(key="tokenizer", module="tokenizer", name="berttokenizer"), - ] - config_spec = [ConfigOption("maxqlen", 4), ConfigOption("maxdoclen", 800), ConfigOption("usecache", False)] - - pad = 0 - pad_tok = "" - - @staticmethod - def config(): - maxqlen = 4 - maxdoclen = 800 - usecache = False - - def load_state(self, qids, docids): - with open(self.get_state_cache_file_path(qids, docids), "rb") as f: - state_dict = pickle.load(f) - self.qid2toks = state_dict["qid2toks"] - self.docid2toks = state_dict["docid2toks"] - self.clsidx = state_dict["clsidx"] - self.sepidx = state_dict["sepidx"] - - def cache_state(self, qids, docids): - os.makedirs(self.get_cache_path(), exist_ok=True) - with open(self.get_state_cache_file_path(qids, docids), "wb") as f: - state_dict = {"qid2toks": self.qid2toks, "docid2toks": self.docid2toks, "clsidx": self.clsidx, "sepidx": self.sepidx} - pickle.dump(state_dict, f, protocol=-1) - - def get_tf_feature_description(self): - feature_description = { - "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64), - "query_mask": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64), - "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), - "posdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), - "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), - "negdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), - "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], 
dtype=tf.float32)), - } - - return feature_description - - def create_tf_feature(self, sample): - """ - sample - output from self.id2vec() - return - a tensorflow feature - """ - query, posdoc, negdoc, negdoc_id = sample["query"], sample["posdoc"], sample["negdoc"], sample["negdocid"] - query_mask, posdoc_mask, negdoc_mask = sample["query_mask"], sample["posdoc_mask"], sample["negdoc_mask"] - - feature = { - "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)), - "query_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=query_mask)), - "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)), - "posdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc_mask)), - "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)), - "negdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc_mask)), - } - - return feature - - def parse_tf_example(self, example_proto): - feature_description = self.get_tf_feature_description() - parsed_example = tf.io.parse_example(example_proto, feature_description) - posdoc = parsed_example["posdoc"] - posdoc_mask = parsed_example["posdoc_mask"] - negdoc = parsed_example["negdoc"] - negdoc_mask = parsed_example["negdoc_mask"] - query = parsed_example["query"] - query_mask = parsed_example["query_mask"] - label = parsed_example["label"] - - return (posdoc, posdoc_mask, negdoc, negdoc_mask, query, query_mask), label - - def _build_vocab(self, qids, docids, topics): - if self.is_state_cached(qids, docids) and self.config["usecache"]: - self.load_state(qids, docids) - logger.info("Vocabulary loaded from cache") - else: - logger.info("Building bertext vocabulary") - tokenize = self.tokenizer.tokenize - self.qid2toks = {qid: tokenize(topics[qid]) for qid in tqdm(qids, desc="querytoks")} - self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in tqdm(docids, desc="doctoks")} - self.clsidx, self.sepidx = self.tokenizer.convert_tokens_to_ids(["CLS", "SEP"]) - - self.cache_state(qids, docids) - - def exist(self): - return hasattr(self, "docid2toks") and len(self.docid2toks) - - def preprocess(self, qids, docids, topics): - if self.exist(): - return - - self.index.create_index() - self.qid2toks = defaultdict(list) - self.docid2toks = defaultdict(list) - self.clsidx = None - self.sepidx = None - - self._build_vocab(qids, docids, topics) - - def id2vec(self, qid, posid, negid=None): - tokenizer = self.tokenizer - qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"] - - query_toks = tokenizer.convert_tokens_to_ids(self.qid2toks[qid]) - query_mask = self.get_mask(query_toks, qlen) - query = padlist(query_toks, qlen) - - posdoc_toks = tokenizer.convert_tokens_to_ids(self.docid2toks[posid]) - posdoc_mask = self.get_mask(posdoc_toks, doclen) - posdoc = padlist(posdoc_toks, doclen) - - data = { - "qid": qid, - "posdocid": posid, - "idfs": np.zeros(qlen, dtype=np.float32), - "query": np.array(query, dtype=np.long), - "query_mask": np.array(query_mask, dtype=np.long), - "posdoc": np.array(posdoc, dtype=np.long), - "posdoc_mask": np.array(posdoc_mask, dtype=np.long), - "query_idf": np.array(query, dtype=np.float32), - "negdocid": "", - "negdoc": np.zeros(doclen, dtype=np.long), - "negdoc_mask": np.zeros(doclen, dtype=np.long), - } - - if negid: - negdoc_toks = tokenizer.convert_tokens_to_ids(self.docid2toks.get(negid, None)) - negdoc_mask = self.get_mask(negdoc_toks, doclen) - negdoc = padlist(negdoc_toks, doclen) - - if not negdoc: - raise MissingDocError(qid, negid) - - 
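id2vec in both extractors leans on two small building blocks: padlist, which pads or truncates a token list to a fixed length, and get_mask, which marks real tokens with 1 and padding with 0 so a model can ignore pad positions. Their assumed semantics, inferred from the call sites here (padlist lives in capreolus.utils.common; this is a sketch of its behavior, not its source). Note also that pad_tok = "" above likely read "<pad>" before the angle-bracket token was stripped, the same rendering damage the TREC tags suffered:

def padlist(lst, length, pad=0):
    # right-pad with `pad` up to `length`, then truncate anything longer
    return (list(lst) + [pad] * length)[:length]

def get_mask(doc, to_len):
    s = doc[:to_len]
    return [1] * len(s) + [0] * (to_len - len(s))

assert padlist([3, 9], 4) == [3, 9, 0, 0]
assert padlist([3, 9, 7, 1, 5], 4) == [3, 9, 7, 1]
assert get_mask([3, 9], 4) == [1, 1, 0, 0]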
data["negdocid"] = negid - data["negdoc"] = np.array(negdoc, dtype=np.long) - data["negdoc_mask"] = np.array(negdoc_mask, dtype=np.long) +from profane import import_all_modules - return data - def get_mask(self, doc, to_len): - """ - Returns a mask where it is 1 for actual toks and 0 for pad toks - """ - s = doc[:to_len] - padlen = to_len - len(s) - mask = [1 for _ in s] + [0 for _ in range(padlen)] - return mask +import_all_modules(__file__, __package__) diff --git a/capreolus/extractor/bagofwords.py b/capreolus/extractor/bagofwords.py index 53ebceaf9..94fd8c7df 100644 --- a/capreolus/extractor/bagofwords.py +++ b/capreolus/extractor/bagofwords.py @@ -1,9 +1,9 @@ import pickle import os import time -from profane import Dependency, ConfigOption -from capreolus.extractor import Extractor +from capreolus import Dependency, ConfigOption +from . import Extractor from capreolus.tokenizer import Tokenizer from capreolus.utils.loginit import get_logger from tqdm import tqdm diff --git a/capreolus/extractor/berttext.py b/capreolus/extractor/berttext.py new file mode 100644 index 000000000..6b464663d --- /dev/null +++ b/capreolus/extractor/berttext.py @@ -0,0 +1,165 @@ +import os +import pickle +from collections import defaultdict + +import numpy as np +import tensorflow as tf +from tqdm import tqdm + +from . import Extractor +from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger +from capreolus.utils.common import padlist +from capreolus.utils.exceptions import MissingDocError + +logger = get_logger(__name__) + + +@Extractor.register +class BertText(Extractor): + module_name = "berttext" + dependencies = [ + Dependency( + key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"} + ), + Dependency(key="tokenizer", module="tokenizer", name="berttokenizer"), + ] + config_spec = [ConfigOption("maxqlen", 4), ConfigOption("maxdoclen", 800), ConfigOption("usecache", False)] + + pad = 0 + pad_tok = "" + + def load_state(self, qids, docids): + with open(self.get_state_cache_file_path(qids, docids), "rb") as f: + state_dict = pickle.load(f) + self.qid2toks = state_dict["qid2toks"] + self.docid2toks = state_dict["docid2toks"] + self.clsidx = state_dict["clsidx"] + self.sepidx = state_dict["sepidx"] + + def cache_state(self, qids, docids): + os.makedirs(self.get_cache_path(), exist_ok=True) + with open(self.get_state_cache_file_path(qids, docids), "wb") as f: + state_dict = {"qid2toks": self.qid2toks, "docid2toks": self.docid2toks, "clsidx": self.clsidx, "sepidx": self.sepidx} + pickle.dump(state_dict, f, protocol=-1) + + def get_tf_feature_description(self): + feature_description = { + "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64), + "query_mask": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64), + "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), + "posdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), + "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), + "negdoc_mask": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), + "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], dtype=tf.float32)), + } + + return feature_description + + def create_tf_feature(self, sample): + """ + sample - output from self.id2vec() + return - a tensorflow feature + """ + query, posdoc, negdoc, negdoc_id = sample["query"], sample["posdoc"], sample["negdoc"], sample["negdocid"] + query_mask, posdoc_mask, 
negdoc_mask = sample["query_mask"], sample["posdoc_mask"], sample["negdoc_mask"] + + feature = { + "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)), + "query_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=query_mask)), + "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)), + "posdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc_mask)), + "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)), + "negdoc_mask": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc_mask)), + } + + return feature + + def parse_tf_example(self, example_proto): + feature_description = self.get_tf_feature_description() + parsed_example = tf.io.parse_example(example_proto, feature_description) + posdoc = parsed_example["posdoc"] + posdoc_mask = parsed_example["posdoc_mask"] + negdoc = parsed_example["negdoc"] + negdoc_mask = parsed_example["negdoc_mask"] + query = parsed_example["query"] + query_mask = parsed_example["query_mask"] + label = parsed_example["label"] + + return (posdoc, posdoc_mask, negdoc, negdoc_mask, query, query_mask), label + + def _build_vocab(self, qids, docids, topics): + if self.is_state_cached(qids, docids) and self.config["usecache"]: + self.load_state(qids, docids) + logger.info("Vocabulary loaded from cache") + else: + logger.info("Building bertext vocabulary") + tokenize = self.tokenizer.tokenize + self.qid2toks = {qid: tokenize(topics[qid]) for qid in tqdm(qids, desc="querytoks")} + self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in tqdm(docids, desc="doctoks")} + self.clsidx, self.sepidx = self.tokenizer.convert_tokens_to_ids(["CLS", "SEP"]) + + self.cache_state(qids, docids) + + def exist(self): + return hasattr(self, "docid2toks") and len(self.docid2toks) + + def preprocess(self, qids, docids, topics): + if self.exist(): + return + + self.index.create_index() + self.qid2toks = defaultdict(list) + self.docid2toks = defaultdict(list) + self.clsidx = None + self.sepidx = None + + self._build_vocab(qids, docids, topics) + + def id2vec(self, qid, posid, negid=None): + tokenizer = self.tokenizer + qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"] + + query_toks = tokenizer.convert_tokens_to_ids(self.qid2toks[qid]) + query_mask = self.get_mask(query_toks, qlen) + query = padlist(query_toks, qlen) + + posdoc_toks = tokenizer.convert_tokens_to_ids(self.docid2toks[posid]) + posdoc_mask = self.get_mask(posdoc_toks, doclen) + posdoc = padlist(posdoc_toks, doclen) + + data = { + "qid": qid, + "posdocid": posid, + "idfs": np.zeros(qlen, dtype=np.float32), + "query": np.array(query, dtype=np.long), + "query_mask": np.array(query_mask, dtype=np.long), + "posdoc": np.array(posdoc, dtype=np.long), + "posdoc_mask": np.array(posdoc_mask, dtype=np.long), + "query_idf": np.array(query, dtype=np.float32), + "negdocid": "", + "negdoc": np.zeros(doclen, dtype=np.long), + "negdoc_mask": np.zeros(doclen, dtype=np.long), + } + + if negid: + negdoc_toks = tokenizer.convert_tokens_to_ids(self.docid2toks.get(negid, None)) + negdoc_mask = self.get_mask(negdoc_toks, doclen) + negdoc = padlist(negdoc_toks, doclen) + + if not negdoc: + raise MissingDocError(qid, negid) + + data["negdocid"] = negid + data["negdoc"] = np.array(negdoc, dtype=np.long) + data["negdoc_mask"] = np.array(negdoc_mask, dtype=np.long) + + return data + + def get_mask(self, doc, to_len): + """ + Returns a mask where it is 1 for actual toks and 0 for pad toks + """ + s = doc[:to_len] + padlen = to_len - len(s) + 
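+        # e.g. doc = [101, 2054] with to_len = 4 -> s = [101, 2054], padlen = 2, mask = [1, 1, 0, 0]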
mask = [1 for _ in s] + [0 for _ in range(padlen)] + return mask diff --git a/capreolus/extractor/deeptileextractor.py b/capreolus/extractor/deeptileextractor.py index 0eb0ce469..a947eb576 100644 --- a/capreolus/extractor/deeptileextractor.py +++ b/capreolus/extractor/deeptileextractor.py @@ -13,7 +13,7 @@ from tqdm import tqdm from profane import ConfigOption, Dependency, constants -from capreolus.extractor import Extractor +from . import Extractor from capreolus.utils.common import padlist from capreolus.utils.loginit import get_logger diff --git a/capreolus/extractor/embedtext.py b/capreolus/extractor/embedtext.py new file mode 100644 index 000000000..0e90198c5 --- /dev/null +++ b/capreolus/extractor/embedtext.py @@ -0,0 +1,206 @@ +import os +import pickle +from collections import defaultdict + +import numpy as np +import tensorflow as tf +from pymagnitude import Magnitude, MagnitudeUtils +from tqdm import tqdm + +from . import Extractor +from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger +from capreolus.utils.common import padlist +from capreolus.utils.exceptions import MissingDocError + +logger = get_logger(__name__) + + +@Extractor.register +class EmbedText(Extractor): + module_name = "embedtext" + requires_random_seed = True + dependencies = [ + Dependency( + key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"} + ), + Dependency(key="tokenizer", module="tokenizer", name="anserini"), + ] + config_spec = [ + ConfigOption("embeddings", "glove6b"), + ConfigOption("zerounk", False), + ConfigOption("calcidf", True), + ConfigOption("maxqlen", 4), + ConfigOption("maxdoclen", 800), + ConfigOption("usecache", False), + ] + + pad = 0 + pad_tok = "" + embed_paths = { + "glove6b": "glove/light/glove.6B.300d", + "glove6b.50d": "glove/light/glove.6B.50d", + "w2vnews": "word2vec/light/GoogleNews-vectors-negative300", + "fasttext": "fasttext/light/wiki-news-300d-1M-subword", + } + + def _get_pretrained_emb(self): + magnitude_cache = constants["CACHE_BASE_PATH"] / "magnitude/" + return Magnitude(MagnitudeUtils.download_model(self.embed_paths[self.config["embeddings"]], download_dir=magnitude_cache)) + + def load_state(self, qids, docids): + with open(self.get_state_cache_file_path(qids, docids), "rb") as f: + state_dict = pickle.load(f) + self.qid2toks = state_dict["qid2toks"] + self.docid2toks = state_dict["docid2toks"] + self.stoi = state_dict["stoi"] + self.itos = state_dict["itos"] + + def cache_state(self, qids, docids): + os.makedirs(self.get_cache_path(), exist_ok=True) + with open(self.get_state_cache_file_path(qids, docids), "wb") as f: + state_dict = {"qid2toks": self.qid2toks, "docid2toks": self.docid2toks, "stoi": self.stoi, "itos": self.itos} + pickle.dump(state_dict, f, protocol=-1) + + def get_tf_feature_description(self): + feature_description = { + "query": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.int64), + "query_idf": tf.io.FixedLenFeature([self.config["maxqlen"]], tf.float32), + "posdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), + "negdoc": tf.io.FixedLenFeature([self.config["maxdoclen"]], tf.int64), + "label": tf.io.FixedLenFeature([2], tf.float32, default_value=tf.convert_to_tensor([1, 0], dtype=tf.float32)), + } + + return feature_description + + def create_tf_feature(self, sample): + """ + sample - output from self.id2vec() + return - a tensorflow feature + """ + query, query_idf, posdoc, negdoc = (sample["query"], sample["query_idf"], sample["posdoc"], 
sample["negdoc"]) + feature = { + "query": tf.train.Feature(int64_list=tf.train.Int64List(value=query)), + "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)), + "posdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=posdoc)), + "negdoc": tf.train.Feature(int64_list=tf.train.Int64List(value=negdoc)), + } + + return feature + + def parse_tf_example(self, example_proto): + feature_description = self.get_tf_feature_description() + parsed_example = tf.io.parse_example(example_proto, feature_description) + posdoc = parsed_example["posdoc"] + negdoc = parsed_example["negdoc"] + query = parsed_example["query"] + query_idf = parsed_example["query_idf"] + label = parsed_example["label"] + + return (posdoc, negdoc, query, query_idf), label + + def _build_vocab(self, qids, docids, topics): + if self.is_state_cached(qids, docids) and self.config["usecache"]: + self.load_state(qids, docids) + logger.info("Vocabulary loaded from cache") + else: + tokenize = self.tokenizer.tokenize + self.qid2toks = {qid: tokenize(topics[qid]) for qid in qids} + self.docid2toks = {docid: tokenize(self.index.get_doc(docid)) for docid in docids} + self._extend_stoi(self.qid2toks.values(), calc_idf=self.config["calcidf"]) + self._extend_stoi(self.docid2toks.values(), calc_idf=self.config["calcidf"]) + self.itos = {i: s for s, i in self.stoi.items()} + logger.info(f"vocabulary constructed, with {len(self.itos)} terms in total") + if self.config["usecache"]: + self.cache_state(qids, docids) + + def _get_idf(self, toks): + return [self.idf.get(tok, 0) for tok in toks] + + def _build_embedding_matrix(self): + assert len(self.stoi) > 1 # needs more vocab than self.pad_tok + + magnitude_emb = self._get_pretrained_emb() + emb_dim = magnitude_emb.dim + embed_vocab = set(term for term, _ in magnitude_emb) + embed_matrix = np.zeros((len(self.stoi), emb_dim), dtype=np.float32) + + n_missed = 0 + for term, idx in tqdm(self.stoi.items()): + if term in embed_vocab: + embed_matrix[idx] = magnitude_emb.query(term) + elif term == self.pad_tok: + embed_matrix[idx] = np.zeros(emb_dim) + else: + n_missed += 1 + embed_matrix[idx] = np.zeros(emb_dim) if self.config["zerounk"] else np.random.normal(scale=0.5, size=emb_dim) + + logger.info(f"embedding matrix {self.config['embeddings']} constructed, with shape {embed_matrix.shape}") + if n_missed > 0: + logger.warning(f"{n_missed}/{len(self.stoi)} (%.3f) term missed" % (n_missed / len(self.stoi))) + + self.embeddings = embed_matrix + + def exist(self): + return ( + hasattr(self, "embeddings") + and self.embeddings is not None + and isinstance(self.embeddings, np.ndarray) + and 0 < len(self.stoi) == self.embeddings.shape[0] + ) + + def preprocess(self, qids, docids, topics): + if self.exist(): + return + + self.index.create_index() + + self.itos = {self.pad: self.pad_tok} + self.stoi = {self.pad_tok: self.pad} + self.qid2toks = defaultdict(list) + self.docid2toks = defaultdict(list) + self.idf = defaultdict(lambda: 0) + self.embeddings = None + # self.cache = self.load_cache() # TODO + + self._build_vocab(qids, docids, topics) + self._build_embedding_matrix() + + def _tok2vec(self, toks): + # return [self.embeddings[self.stoi[tok]] for tok in toks] + return [self.stoi[tok] for tok in toks] + + def id2vec(self, qid, posid, negid=None): + query = self.qid2toks[qid] + + # TODO find a way to calculate qlen/doclen stats earlier, so we can log them and check sanity of our values + qlen, doclen = self.config["maxqlen"], self.config["maxdoclen"] + posdoc = 
self.docid2toks.get(posid, None) + if not posdoc: + raise MissingDocError(qid, posid) + + idfs = padlist(self._get_idf(query), qlen, 0) + query = self._tok2vec(padlist(query, qlen, self.pad_tok)) + posdoc = self._tok2vec(padlist(posdoc, doclen, self.pad_tok)) + + # TODO determine whether pin_memory is happening. may not be because we don't place the strings in a np or torch object + data = { + "qid": qid, + "posdocid": posid, + "idfs": np.array(idfs, dtype=np.float32), + "query": np.array(query, dtype=np.long), + "posdoc": np.array(posdoc, dtype=np.long), + "query_idf": np.array(idfs, dtype=np.float32), + "negdocid": "", + "negdoc": np.zeros(self.config["maxdoclen"], dtype=np.long), + } + + if negid: + negdoc = self.docid2toks.get(negid, None) + if not negdoc: + raise MissingDocError(qid, negid) + + negdoc = self._tok2vec(padlist(negdoc, doclen, self.pad_tok)) + data["negdocid"] = negid + data["negdoc"] = np.array(negdoc, dtype=np.long) + + return data diff --git a/capreolus/index/__init__.py b/capreolus/index/__init__.py index 8ed070c6b..be0ca3df5 100644 --- a/capreolus/index/__init__.py +++ b/capreolus/index/__init__.py @@ -1,22 +1,18 @@ -from profane import import_all_modules - -# import_all_modules(__file__, __package__) - -import logging -import math -import os -import subprocess - -from profane import ModuleBase, Dependency, ConfigOption, constants +from capreolus import ModuleBase, Dependency, ConfigOption, get_logger -from capreolus.utils.common import Anserini -from capreolus.utils.loginit import get_logger logger = get_logger(__name__) # pylint: disable=invalid-name -MAX_THREADS = constants["MAX_THREADS"] class Index(ModuleBase): + """Base class for Index modules. The purpose of an Index module is to represent an inverted index that can be queried with a :class:`~capreolus.searcher.Searcher` module and used to obtain documents and collection statistics. + + Modules should provide: + - a ``_create_index`` method that creates an index on the ``Collection`` dependency + - a ``get_doc(docid)`` and a ``get_docs(docid)`` method + - a ``get_df(term)`` method + """ + module_type = "index" dependencies = [Dependency(key="collection", module="collection")] @@ -46,82 +42,8 @@ def get_docs(self, doc_ids): raise NotImplementedError() -@Index.register -class AnseriniIndex(Index): - module_name = "anserini" - config_spec = [ - ConfigOption("indexstops", False, "should stopwords be indexed? 
(if False, stopwords are removed)"), - ConfigOption("stemmer", "porter", "stemmer: porter, krovetz, or none"), - ] - - def _create_index(self): - outdir = self.get_index_path() - stops = "-keepStopwords" if self.config["indexstops"] else "" - stemmer = "none" if self.config["stemmer"] is None else self.config["stemmer"] - - collection_path, document_type, generator_type = self.collection.get_path_and_types() - - anserini_fat_jar = Anserini.get_fat_jar() - if self.collection.is_large_collection: - cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -stemmer {stemmer} {stops}" - else: - cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -storePositions -storeDocvectors -storeContents -stemmer {stemmer} {stops}" - - logger.info("building index %s", outdir) - logger.debug(cmd) - os.makedirs(os.path.basename(outdir), exist_ok=True) - - app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) - - # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger - for line in app.stdout: - Anserini.filter_and_log_anserini_output(line, logger) +from profane import import_all_modules - app.wait() - if app.returncode != 0: - raise RuntimeError("command failed") +from .anserini import AnseriniIndex - def get_docs(self, doc_ids): - # if self.collection.is_large_collection: - # return self.get_documents_from_disk(doc_ids) - return [self.get_doc(doc_id) for doc_id in doc_ids] - - def get_doc(self, docid): - try: - if not hasattr(self, "index_utils") or self.index_utils is None: - self.open() - return self.index_reader_utils.documentContents(self.reader, self.JString(docid)) - except Exception as e: - raise - - def get_df(self, term): - # returns 0 for missing terms - if not hasattr(self, "reader") or self.reader is None: - self.open() - jterm = self.JTerm("contents", term) - return self.reader.docFreq(jterm) - - def get_idf(self, term): - """ BM25's IDF with a floor of 0 """ - df = self.get_df(term) - idf = (self.numdocs - df + 0.5) / (df + 0.5) - idf = math.log(1 + idf) - return max(idf, 0) - - def open(self): - from jnius import autoclass - - index_path = self.get_index_path().as_posix() - - JIndexUtils = autoclass("io.anserini.index.IndexUtils") - JIndexReaderUtils = autoclass("io.anserini.index.IndexReaderUtils") - self.index_utils = JIndexUtils(index_path) - self.index_reader_utils = JIndexReaderUtils() - - JFile = autoclass("java.io.File") - JFSDirectory = autoclass("org.apache.lucene.store.FSDirectory") - fsdir = JFSDirectory.open(JFile(index_path).toPath()) - self.reader = autoclass("org.apache.lucene.index.DirectoryReader").open(fsdir) - self.numdocs = self.reader.numDocs() - self.JTerm = autoclass("org.apache.lucene.index.Term") - self.JString = autoclass("java.lang.String") +import_all_modules(__file__, __package__) diff --git a/capreolus/index/anserini.py b/capreolus/index/anserini.py new file mode 100644 index 000000000..ca3e3db39 --- /dev/null +++ b/capreolus/index/anserini.py @@ -0,0 +1,91 @@ +import math +import os +import subprocess + +from . 
import Index +from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger +from capreolus.utils.common import Anserini + +logger = get_logger(__name__) # pylint: disable=invalid-name +MAX_THREADS = constants["MAX_THREADS"] + + +@Index.register +class AnseriniIndex(Index): + module_name = "anserini" + config_spec = [ + ConfigOption("indexstops", False, "should stopwords be indexed? (if False, stopwords are removed)"), + ConfigOption("stemmer", "porter", "stemmer: porter, krovetz, or none"), + ] + + def _create_index(self): + outdir = self.get_index_path() + stops = "-keepStopwords" if self.config["indexstops"] else "" + stemmer = "none" if self.config["stemmer"] is None else self.config["stemmer"] + + collection_path, document_type, generator_type = self.collection.get_path_and_types() + + anserini_fat_jar = Anserini.get_fat_jar() + if self.collection.is_large_collection: + cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -stemmer {stemmer} {stops}" + else: + cmd = f"java -classpath {anserini_fat_jar} -Xms512M -Xmx31G -Dapp.name='IndexCollection' io.anserini.index.IndexCollection -collection {document_type} -generator {generator_type} -threads {MAX_THREADS} -input {collection_path} -index {outdir} -storePositions -storeDocvectors -storeContents -stemmer {stemmer} {stops}" + + logger.info("building index %s", outdir) + logger.debug(cmd) + os.makedirs(os.path.basename(outdir), exist_ok=True) + + app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) + + # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger + for line in app.stdout: + Anserini.filter_and_log_anserini_output(line, logger) + + app.wait() + if app.returncode != 0: + raise RuntimeError("command failed") + + def get_docs(self, doc_ids): + # if self.collection.is_large_collection: + # return self.get_documents_from_disk(doc_ids) + return [self.get_doc(doc_id) for doc_id in doc_ids] + + def get_doc(self, docid): + try: + if not hasattr(self, "index_utils") or self.index_utils is None: + self.open() + return self.index_reader_utils.documentContents(self.reader, self.JString(docid)) + except Exception as e: + raise + + def get_df(self, term): + # returns 0 for missing terms + if not hasattr(self, "reader") or self.reader is None: + self.open() + jterm = self.JTerm("contents", term) + return self.reader.docFreq(jterm) + + def get_idf(self, term): + """ BM25's IDF with a floor of 0 """ + df = self.get_df(term) + idf = (self.numdocs - df + 0.5) / (df + 0.5) + idf = math.log(1 + idf) + return max(idf, 0) + + def open(self): + from jnius import autoclass + + index_path = self.get_index_path().as_posix() + + JIndexUtils = autoclass("io.anserini.index.IndexUtils") + JIndexReaderUtils = autoclass("io.anserini.index.IndexReaderUtils") + self.index_utils = JIndexUtils(index_path) + self.index_reader_utils = JIndexReaderUtils() + + JFile = autoclass("java.io.File") + JFSDirectory = autoclass("org.apache.lucene.store.FSDirectory") + fsdir = JFSDirectory.open(JFile(index_path).toPath()) + self.reader = autoclass("org.apache.lucene.index.DirectoryReader").open(fsdir) + self.numdocs = self.reader.numDocs() + self.JTerm = autoclass("org.apache.lucene.index.Term") + self.JString = autoclass("java.lang.String") diff --git a/capreolus/index/tests/test_index.py 
b/capreolus/index/tests/test_index.py index 7228d3c12..d42dd66ef 100644 --- a/capreolus/index/tests/test_index.py +++ b/capreolus/index/tests/test_index.py @@ -1,13 +1,18 @@ import pytest +from capreolus import module_registry from capreolus.collection import Collection, DummyCollection -from capreolus.index import Index -from capreolus.index import AnseriniIndex +from capreolus.index import Index, AnseriniIndex from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index -def test_anserini_create_index(tmpdir_as_cache): - index = AnseriniIndex({"name": "anserini", "indexstops": False, "stemmer": "porter", "collection": {"name": "dummy"}}) +indexs = set(module_registry.get_module_names("index")) + + +@pytest.mark.parametrize("index_name", indexs) +def test_create_index(tmpdir_as_cache, index_name): + provide = {"collection": DummyCollection()} + index = Index.create(index_name, provide=provide) assert not index.exists() index.create_index() assert index.exists() diff --git a/capreolus/reranker/CDSSM.py b/capreolus/reranker/CDSSM.py index 33e4cbe75..4e4dcfa5d 100644 --- a/capreolus/reranker/CDSSM.py +++ b/capreolus/reranker/CDSSM.py @@ -1,7 +1,7 @@ import torch -from profane import ConfigOption, Dependency from torch import nn +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import create_emb_layer @@ -74,7 +74,8 @@ def forward(self, sentence, query): @Reranker.register class CDSSM(Reranker): - description = """Yelong Shen, Xiaodong He, Jianfeng Gao, Li Deng, and Grégoire Mesnil. 2014. A Latent Semantic Model with Convolutional-Pooling Structure for Information Retrieval. In CIKM'14.""" + """Yelong Shen, Xiaodong He, Jianfeng Gao, Li Deng, and Grégoire Mesnil. 2014. A Latent Semantic Model with Convolutional-Pooling Structure for Information Retrieval. In CIKM'14.""" + module_name = "CDSSM" config_spec = [ diff --git a/capreolus/reranker/ConvKNRM.py b/capreolus/reranker/ConvKNRM.py index b7a7c4968..936022131 100644 --- a/capreolus/reranker/ConvKNRM.py +++ b/capreolus/reranker/ConvKNRM.py @@ -1,7 +1,7 @@ import torch -from profane import ConfigOption, Dependency from torch import nn +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import RbfKernelBank, SimilarityMatrix, create_emb_layer from capreolus.utils.loginit import get_logger @@ -79,8 +79,9 @@ def forward(self, sentence, query_sentence, query_idf): @Reranker.register class ConvKNRM(Reranker): + """Zhuyun Dai, Chenyan Xiong, Jamie Callan, and Zhiyuan Liu. 2018. Convolutional Neural Networks for Soft-Matching N-Grams in Ad-hoc Search. In WSDM'18.""" + module_name = "ConvKNRM" - description = """Zhuyun Dai, Chenyan Xiong, Jamie Callan, and Zhiyuan Liu. 2018. Convolutional Neural Networks for Soft-Matching N-Grams in Ad-hoc Search. 
In WSDM'18.""" config_spec = [ ConfigOption("gradkernels", True, "backprop through mus and sigmas"), diff --git a/capreolus/reranker/DRMM.py b/capreolus/reranker/DRMM.py index a079b54f9..e93f79da3 100644 --- a/capreolus/reranker/DRMM.py +++ b/capreolus/reranker/DRMM.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from profane import ConfigOption, Dependency +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import create_emb_layer from capreolus.utils.loginit import get_logger @@ -122,8 +122,9 @@ def forward(self, sentence, query_sentence, query_idf): @Reranker.register class DRMM(Reranker): + """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16.""" + module_name = "DRMM" - description = """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16.""" config_spec = [ ConfigOption("nbins", 29, "number of bins in matching histogram"), diff --git a/capreolus/reranker/DRMMTKS.py b/capreolus/reranker/DRMMTKS.py index e9bf70517..eddbe7e74 100644 --- a/capreolus/reranker/DRMMTKS.py +++ b/capreolus/reranker/DRMMTKS.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from profane import ConfigOption, Dependency +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import create_emb_layer from capreolus.utils.loginit import get_logger @@ -86,9 +86,10 @@ def forward(self, doc, query, query_idf): @Reranker.register class DRMMTKS(Reranker): - # refernce: https://github.com/NTMC-Community/MatchZoo-py/blob/master/matchzoo/models/drmmtks.py + """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16.""" + + # reference: https://github.com/NTMC-Community/MatchZoo-py/blob/master/matchzoo/models/drmmtks.py module_name = "DRMMTKS" - description = """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16.""" config_spec = [ ConfigOption("topk", 10, "number of bins in matching histogram"), diff --git a/capreolus/reranker/DSSM.py b/capreolus/reranker/DSSM.py index b53aa4f85..60a5be00c 100644 --- a/capreolus/reranker/DSSM.py +++ b/capreolus/reranker/DSSM.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from profane import ConfigOption, Dependency +from capreolus import ConfigOption, Dependency from capreolus.extractor.bagofwords import BagOfWords from capreolus.reranker import Reranker from capreolus.utils.loginit import get_logger @@ -46,7 +46,8 @@ def forward(self, sentence, query, query_idf): @Reranker.register class DSSM(Reranker): - description = """Po-Sen Huang, Xiaodong He, Jianfeng Gao, Li Deng, Alex Acero, and Larry Heck. 2013. Learning deep structured semantic models for web search using clickthrough data. In CIKM'13.""" + """Po-Sen Huang, Xiaodong He, Jianfeng Gao, Li Deng, Alex Acero, and Larry Heck. 2013. Learning deep structured semantic models for web search using clickthrough data. 
In CIKM'13.""" + module_name = "DSSM" dependencies = [ Dependency(key="extractor", module="extractor", name="bagofwords"), diff --git a/capreolus/reranker/DUET.py b/capreolus/reranker/DUET.py index 0a5f4ef4e..95103f221 100644 --- a/capreolus/reranker/DUET.py +++ b/capreolus/reranker/DUET.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn -from profane import ConfigOption, Dependency +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import create_emb_layer from capreolus.utils.loginit import get_logger @@ -131,8 +131,9 @@ def forward(self, documents, queries, query_idf): @Reranker.register class DUET(Reranker): + """Bhaskar Mitra, Fernando Diaz, and Nick Craswell. 2017. Learning to Match using Local and Distributed Representations of Text for Web Search. In WWW'17.""" + module_name = "DUET" - description = """Bhaskar Mitra, Fernando Diaz, and Nick Craswell. 2017. Learning to Match using Local and Distributed Representations of Text for Web Search. In WWW'17.""" config_spec = [ ConfigOption("nfilter", 10, "number of filters for both local and distrbuted model"), diff --git a/capreolus/reranker/DeepTileBar.py b/capreolus/reranker/DeepTileBar.py index 3fd141b57..af49aed54 100644 --- a/capreolus/reranker/DeepTileBar.py +++ b/capreolus/reranker/DeepTileBar.py @@ -2,10 +2,10 @@ import torch import torch.nn.functional as F -from profane import ConfigOption, Dependency from torch import nn from torch.autograd import Variable +from capreolus import ConfigOption, Dependency from capreolus.extractor.deeptileextractor import DeepTileExtractor from capreolus.reranker import Reranker from capreolus.utils.loginit import get_logger @@ -172,7 +172,8 @@ def test_forward(self, pos_tile_matrix): @Reranker.register class DeepTileBar(Reranker): - description = """Zhiwen Tang and Grace Hui Yang. 2019. DeepTileBars: Visualizing Term Distribution for Neural Information Retrieval. In AAAI'19.""" + """Zhiwen Tang and Grace Hui Yang. 2019. DeepTileBars: Visualizing Term Distribution for Neural Information Retrieval. In AAAI'19.""" + module_name = "DeepTileBar" dependencies = [ diff --git a/capreolus/reranker/HINT.py b/capreolus/reranker/HINT.py index 888f36a3d..9e568fa28 100644 --- a/capreolus/reranker/HINT.py +++ b/capreolus/reranker/HINT.py @@ -2,10 +2,10 @@ import torch import torch.nn.functional as F -from profane import ConfigOption, Dependency from torch import nn from torch.autograd import Variable +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import create_emb_layer @@ -323,8 +323,9 @@ def test_forward(self, query_sentence, query_idf, pos_sentence): @Reranker.register class HINT(Reranker): + """Yixing Fan, Jiafeng Guo, Yanyan Lan, Jun Xu, Chengxiang Zhai, and Xueqi Cheng. 2018. Modeling Diverse Relevance Patterns in Ad-hoc Retrieval. In SIGIR'18.""" + module_name = "HINT" - description = """Yixing Fan, Jiafeng Guo, Yanyan Lan, Jun Xu, Chengxiang Zhai, and Xueqi Cheng. 2018. Modeling Diverse Relevance Patterns in Ad-hoc Retrieval. 
In SIGIR'18.""" config_spec = [ConfigOption("spatialGRU", 2), ConfigOption("LSTMdim", 6), ConfigOption("kmax", 10)] diff --git a/capreolus/reranker/KNRM.py b/capreolus/reranker/KNRM.py index 29924284d..153ac0d8a 100644 --- a/capreolus/reranker/KNRM.py +++ b/capreolus/reranker/KNRM.py @@ -1,8 +1,8 @@ import matplotlib.pyplot as plt import torch -from profane import ConfigOption, Dependency from torch import nn +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import RbfKernelBank, SimilarityMatrix, create_emb_layer from capreolus.utils.loginit import get_logger @@ -60,9 +60,9 @@ def forward(self, doctoks, querytoks, query_idf): @Reranker.register class KNRM(Reranker): + """Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017. End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17.""" + module_name = "KNRM" - description = """Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017. - End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17.""" config_spec = [ ConfigOption("gradkernels", True, "backprop through mus and sigmas"), diff --git a/capreolus/reranker/PACRR.py b/capreolus/reranker/PACRR.py index ad8d97d74..3ee47f601 100644 --- a/capreolus/reranker/PACRR.py +++ b/capreolus/reranker/PACRR.py @@ -1,8 +1,8 @@ import torch -from profane import ConfigOption, Dependency from torch import nn from torch.nn import functional as F +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker # TODO add shuffle, cascade, disambig? @@ -85,9 +85,9 @@ def forward(self, simmat): @Reranker.register class PACRR(Reranker): + """Kai Hui, Andrew Yates, Klaus Berberich, and Gerard de Melo. 2017. PACRR: A Position-Aware Neural IR Model for Relevance Matching. EMNLP 2017. """ + module_name = "PACRR" - description = """Kai Hui, Andrew Yates, Klaus Berberich, and Gerard de Melo. EMNLP 2017. - PACRR: A Position-Aware Neural IR Model for Relevance Matching. """ config_spec = [ ConfigOption("mingram", 1, "minimum length of ngram used"), diff --git a/capreolus/reranker/POSITDRMM.py b/capreolus/reranker/POSITDRMM.py index 1359eeba0..6839fe550 100644 --- a/capreolus/reranker/POSITDRMM.py +++ b/capreolus/reranker/POSITDRMM.py @@ -1,9 +1,9 @@ import torch import torch.nn.functional as F -from profane import ConfigOption, Dependency from torch import nn from torch.autograd import Variable +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import create_emb_layer from capreolus.utils.loginit import get_logger @@ -127,8 +127,8 @@ def test_forward(self, query_sentence, query_idf, pos_sentence, extras): @Reranker.register class POSITDRMM(Reranker): - description = """Ryan McDonald, George Brokos, and Ion Androutsopoulos. 2018. - Deep Relevance Ranking Using Enhanced Document-Query Interactions. In EMNLP'18.""" + """Ryan McDonald, George Brokos, and Ion Androutsopoulos. 2018. Deep Relevance Ranking Using Enhanced Document-Query Interactions. 
In EMNLP'18.""" + module_name = "POSITDRMM" def build_model(self): diff --git a/capreolus/reranker/TFKNRM.py b/capreolus/reranker/TFKNRM.py index 71e392f01..8fe08050f 100644 --- a/capreolus/reranker/TFKNRM.py +++ b/capreolus/reranker/TFKNRM.py @@ -1,6 +1,6 @@ import tensorflow as tf -from profane import ConfigOption, Dependency +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import RbfKernelBankTF, similarity_matrix_tf @@ -57,7 +57,13 @@ def call(self, x, **kwargs): @Reranker.register class TFKNRM(Reranker): + """TensorFlow implementation of KNRM. + + Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017. End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17. + """ + module_name = "TFKNRM" + dependencies = [ Dependency(key="extractor", module="extractor", name="embedtext"), Dependency(key="trainer", module="trainer", name="tensorflow"), diff --git a/capreolus/reranker/TFVanillaBert.py b/capreolus/reranker/TFVanillaBert.py index 2eea4b194..c67412bd6 100644 --- a/capreolus/reranker/TFVanillaBert.py +++ b/capreolus/reranker/TFVanillaBert.py @@ -1,7 +1,7 @@ import tensorflow as tf -from profane import ConfigOption, Dependency from transformers import TFBertForSequenceClassification +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.utils.loginit import get_logger @@ -53,7 +53,10 @@ def call(self, x, **kwargs): @Reranker.register class TFVanillaBERT(Reranker): + """TensorFlow implementation of Vanilla BERT.""" + module_name = "TFVanillaBERT" + dependencies = [ Dependency(key="extractor", module="extractor", name="berttext"), Dependency(key="trainer", module="trainer", name="tensorflow"), diff --git a/capreolus/reranker/TK.py b/capreolus/reranker/TK.py index 16f5c1da0..f0476022e 100644 --- a/capreolus/reranker/TK.py +++ b/capreolus/reranker/TK.py @@ -1,11 +1,11 @@ import math import torch -from profane import ConfigOption, Dependency from torch import nn from torch.nn import TransformerEncoder, TransformerEncoderLayer - from allennlp.modules.matrix_attention import CosineMatrixAttention + +from capreolus import ConfigOption, Dependency from capreolus.reranker import Reranker from capreolus.reranker.common import SimilarityMatrix, create_emb_layer from capreolus.utils.loginit import get_logger @@ -146,9 +146,9 @@ def forward(self, doctoks, querytoks, query_idf): @Reranker.register class TK(Reranker): + """Sebastian Hofstätter, Markus Zlabinger, and Allan Hanbury. 2019. TU Wien @ TREC Deep Learning '19 -- Simple Contextualization for Re-ranking. In TREC '19.""" + module_name = "TK" - description = """Sebastian Hofstätter, Markus Zlabinger, and Allan Hanbury. 2019. - TU Wien @ TREC Deep Learning '19 -- Simple Contextualization for Re-ranking. In TREC '19.""" config_spec = [ ConfigOption("gradkernels", True, "backprop through mus and sigmas"), diff --git a/capreolus/reranker/__init__.py b/capreolus/reranker/__init__.py index 77645a0dc..9248dcc3f 100644 --- a/capreolus/reranker/__init__.py +++ b/capreolus/reranker/__init__.py @@ -1,5 +1,61 @@ +import os +import pickle + +from capreolus import ConfigOption, Dependency, ModuleBase + + +class Reranker(ModuleBase): + """Base class for Reranker modules. The purpose of a Reranker is to predict relevance scores for input documents. Rerankers are generally supervised methods implemented in PyTorch or TensorFlow. 
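+
+    For example, a Reranker can be instantiated by name with its dependencies provided directly.
+    A minimal sketch mirroring the test suite, where ``dummy_index`` is assumed to be an existing
+    :class:`~capreolus.index.Index` instance::
+
+        from capreolus import Reranker
+
+        # the extractor and trainer dependencies are created automatically
+        reranker = Reranker.create("KNRM", provide={"index": dummy_index, "collection": dummy_index.collection})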
+ + + Modules should provide: + - a ``build_model`` method that initializes the model used + - a ``score`` and a ``test`` method that take a representation created by an :class:`~capreolus.extractor.Extractor` module as input and return document scores + - a ``load_weights`` and a ``save_weights`` method, if the base class' PyTorch methods cannot be used + """ + + module_type = "reranker" + dependencies = [ + Dependency(key="extractor", module="extractor", name="embedtext"), + Dependency(key="trainer", module="trainer", name="pytorch"), + ] + + def add_summary(self, summary_writer, niter): + """ + Write custom visualizations/data specific to this reranker to the summary_writer + """ + for name, weight in self.model.named_parameters(): + summary_writer.add_histogram(name, weight.data.cpu(), niter) + # summary_writer.add_histogram(f'{name}.grad', weight.grad, niter) + + def save_weights(self, weights_fn, optimizer): + if not os.path.exists(os.path.dirname(weights_fn)): + os.makedirs(os.path.dirname(weights_fn)) + + d = {k: v for k, v in self.model.state_dict().items() if ("embedding.weight" not in k and "_nosave_" not in k)} + with open(weights_fn, "wb") as outf: + pickle.dump(d, outf, protocol=-1) + + optimizer_fn = weights_fn.as_posix() + ".optimizer" + with open(optimizer_fn, "wb") as outf: + pickle.dump(optimizer.state_dict(), outf, protocol=-1) + + def load_weights(self, weights_fn, optimizer): + with open(weights_fn, "rb") as f: + d = pickle.load(f) + + cur_keys = set(k for k in self.model.state_dict().keys() if not ("embedding.weight" in k or "_nosave_" in k)) + missing = cur_keys - set(d.keys()) + if len(missing) > 0: + raise RuntimeError("loading state_dict with keys that do not match current model: %s" % missing) + + self.model.load_state_dict(d, strict=False) + + optimizer_fn = weights_fn.as_posix() + ".optimizer" + with open(optimizer_fn, "rb") as f: + optimizer.load_state_dict(pickle.load(f)) + + from profane import import_all_modules -from .base import Reranker import_all_modules(__file__, __package__) diff --git a/capreolus/reranker/base.py b/capreolus/reranker/base.py deleted file mode 100644 index 832714e1e..000000000 --- a/capreolus/reranker/base.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -import pickle - -from profane import ConfigOption, Dependency, ModuleBase - - -class Reranker(ModuleBase): - module_type = "reranker" - dependencies = [ - Dependency(key="extractor", module="extractor", name="embedtext"), - Dependency(key="trainer", module="trainer", name="pytorch"), - ] - - def add_summary(self, summary_writer, niter): - """ - Write to the summay_writer custom visualizations/data specific to this reranker - """ - for name, weight in self.model.named_parameters(): - summary_writer.add_histogram(name, weight.data.cpu(), niter) - # summary_writer.add_histogram(f'{name}.grad', weight.grad, niter) - - def save_weights(self, weights_fn, optimizer): - if not os.path.exists(os.path.dirname(weights_fn)): - os.makedirs(os.path.dirname(weights_fn)) - - d = {k: v for k, v in self.model.state_dict().items() if ("embedding.weight" not in k and "_nosave_" not in k)} - with open(weights_fn, "wb") as outf: - pickle.dump(d, outf, protocol=-1) - - optimizer_fn = weights_fn.as_posix() + ".optimizer" - with open(optimizer_fn, "wb") as outf: - pickle.dump(optimizer.state_dict(), outf, protocol=-1) - - def load_weights(self, weights_fn, optimizer): - with open(weights_fn, "rb") as f: - d = pickle.load(f) - - cur_keys = set(k for k in self.model.state_dict().keys() if not ("embedding.weight" in 
k or "_nosave_" in k)) - missing = cur_keys - set(d.keys()) - if len(missing) > 0: - raise RuntimeError("loading state_dict with keys that do not match current model: %s" % missing) - - self.model.load_state_dict(d, strict=False) - - optimizer_fn = weights_fn.as_posix() + ".optimizer" - with open(optimizer_fn, "rb") as f: - optimizer.load_state_dict(pickle.load(f)) diff --git a/capreolus/reranker/tests/test_rerankers.py b/capreolus/reranker/tests/test_rerankers.py index b8b0ce41b..682503492 100644 --- a/capreolus/reranker/tests/test_rerankers.py +++ b/capreolus/reranker/tests/test_rerankers.py @@ -5,8 +5,9 @@ import torch from pymagnitude import Magnitude +from capreolus import Reranker, module_registry from capreolus.benchmark import DummyBenchmark -from capreolus.extractor import EmbedText +from capreolus.extractor.embedtext import EmbedText from capreolus.extractor.bagofwords import BagOfWords from capreolus.extractor.deeptileextractor import DeepTileExtractor from capreolus.reranker.CDSSM import CDSSM @@ -24,6 +25,15 @@ from capreolus.trainer import PytorchTrainer, TensorFlowTrainer +rerankers = set(module_registry.get_module_names("reranker")) + + +@pytest.mark.parametrize("reranker_name", rerankers) +def test_reranker_creatable(tmpdir_as_cache, dummy_index, reranker_name): + provide = {"collection": dummy_index.collection, "index": dummy_index} + reranker = Reranker.create(reranker_name, provide=provide) + + def test_knrm_pytorch(dummy_index, tmpdir, tmpdir_as_cache, monkeypatch): def fake_magnitude_embedding(*args, **kwargs): return Magnitude(None) diff --git a/capreolus/run.py b/capreolus/run.py index f716cd610..39ef2048c 100644 --- a/capreolus/run.py +++ b/capreolus/run.py @@ -39,28 +39,42 @@ def prepare_task(fullcommand, config): help = """ - Usage: - run.py COMMAND [(with CONFIG...)] [options] - run.py help [COMMAND] - run.py (-h | --help) +Usage: + capreolus COMMAND [(with CONFIG...)] [options] + capreolus help [COMMAND] + capreolus (-h | --help) - Options: - -h --help Print this help message and exit. - -l VALUE --loglevel=VALUE Set the log level: DEBUG, INFO, WARNING, ERROR, or CRITICAL. - -p VALUE --priority=VALUE Sets the priority for a queued up experiment. No effect without -q flag. - -q --queue Only queue this run, do not start it. + Options: + -h --help Print this help message and exit. + -l VALUE --loglevel=VALUE Set the log level: DEBUG, INFO, WARNING, ERROR, or CRITICAL. + -p VALUE --priority=VALUE Sets the priority for a queued up experiment. No effect without -q flag. + -q --queue Queue this run, and do not start it. - Arguments: - COMMAND Name of command to run (see below for list of commands) - CONFIG Configuration assignments of the form foo.bar=17 + Arguments: + PIPELINE Name of pipeline to run, which consists of a Task and a command (see below for a list) + CONFIG Configuration assignments of the form foo.bar=17 - Commands: (TODO expand/generate) - rank.run ...description here... - rank.describe ...description here... 
- """ + Tasks and their commands: + rank.search search a collection using queries from a benchmark + rank.evaluate evaluate the result of rank.search + rank.searcheval run rank.search followed by rank.evaluate + + rerank.train run rank.search and train a model to rerank the results + rerank.evaluate evaluate the result of rerank.train + rerank.traineval run rerank.train followed by rerank.evaluate + + rererank.train run rerank.train and train a (second) model to rerank the results + rererank.evaluate evaluate the result of rererank.train + rererank.traineval run rererank.train followed by rererank.evaluate + + tutorial.run task from the "Getting Started" tutorial + + All tasks additionally support the following help commands: describe, print_config, print_pipeline + e.g., capreolus rank.print_config with searcher=BM25 +""" if __name__ == "__main__": # hack to make docopt print full help message if no arguments are give diff --git a/capreolus/sampler/tests/test_sampler.py b/capreolus/sampler/tests/test_sampler.py index 93f6b9179..59a159b6c 100644 --- a/capreolus/sampler/tests/test_sampler.py +++ b/capreolus/sampler/tests/test_sampler.py @@ -4,7 +4,7 @@ import numpy as np from capreolus.benchmark import DummyBenchmark -from capreolus.extractor import EmbedText +from capreolus.extractor.embedtext import EmbedText from capreolus.sampler import TrainDataset, PredDataset from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index diff --git a/capreolus/searcher/__init__.py b/capreolus/searcher/__init__.py index c4b1735a2..34da19ef0 100644 --- a/capreolus/searcher/__init__.py +++ b/capreolus/searcher/__init__.py @@ -1,17 +1,7 @@ -from profane import import_all_modules - - -# import_all_modules(__file__, __package__) - import os -import math -import subprocess from collections import defaultdict, OrderedDict -import numpy as np -from profane import ModuleBase, Dependency, ConfigOption, constants - -from capreolus.utils.common import Anserini +from capreolus import ModuleBase, Dependency, ConfigOption, constants from capreolus.utils.loginit import get_logger from capreolus.utils.trec import topic_to_trectxt @@ -24,6 +14,15 @@ def list2str(l, delimiter="-"): class Searcher(ModuleBase): + """Base class for Searcher modules. The purpose of a Searcher is to query a collection via an :class:`~capreolus.index.Index` module. + + Similar to Rerankers, Searchers return a list of documents and their relevance scores for a given query. + Searchers are unsupervised and efficient, whereas Rerankers are supervised and do not use an inverted index directly. 
+ + Modules should provide: + - a ``query(string)`` and a ``query_from_file(path)`` method that return document scores + """ + module_type = "searcher" @staticmethod @@ -83,495 +82,8 @@ def query(self, query, **kwargs): return config2runs["searcher"] if len(config2runs) == 1 else config2runs -class AnseriniSearcherMixIn: - """ MixIn for searchers that use Anserini's SearchCollection script """ - - def _anserini_query_from_file(self, topicsfn, anserini_param_str, output_base_path, topicfield): - if not os.path.exists(topicsfn): - raise IOError(f"could not find topics file: {topicsfn}") - - # for covid: - field2querytype = {"query": "title", "question": "description", "narrative": "narrative"} - for k, v in field2querytype.items(): - topicfield = topicfield.replace(k, v) - - donefn = os.path.join(output_base_path, "done") - if os.path.exists(donefn): - logger.debug(f"skipping Anserini SearchCollection call because path already exists: {donefn}") - return - - # create index if it does not exist. the call returns immediately if the index does exist. - self.index.create_index() - - os.makedirs(output_base_path, exist_ok=True) - output_path = os.path.join(output_base_path, "searcher") - - # add stemmer and stop options to match underlying index - indexopts = "-stemmer " - indexopts += "none" if self.index.config["stemmer"] is None else self.index.config["stemmer"] - if self.index.config["indexstops"]: - indexopts += " -keepstopwords" - - index_path = self.index.get_index_path() - anserini_fat_jar = Anserini.get_fat_jar() - cmd = ( - f"java -classpath {anserini_fat_jar} " - f"-Xms512M -Xmx31G -Dapp.name=SearchCollection io.anserini.search.SearchCollection " - f"-topicreader Trec -index {index_path} {indexopts} -topics {topicsfn} -output {output_path} " - f"-topicfield {topicfield} -inmem -threads {MAX_THREADS} {anserini_param_str}" - ) - logger.info("Anserini writing runs to %s", output_path) - logger.debug(cmd) - - app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) - - # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger - for line in app.stdout: - Anserini.filter_and_log_anserini_output(line, logger) - - app.wait() - if app.returncode != 0: - raise RuntimeError("command failed") - - with open(donefn, "wt") as donef: - print("done", file=donef) - - -class PostprocessMixin: - def _keep_topn(self, runs, topn): - queries = sorted(list(runs.keys()), key=lambda k: int(k)) - for q in queries: - docs = runs[q] - if len(docs) <= topn: - continue - docs = sorted(docs.items(), key=lambda kv: kv[1], reverse=True)[:topn] - runs[q] = {k: v for k, v in docs} - return runs - - def filter(self, run_dir, docs_to_remove=None, docs_to_keep=None, topn=None): - if (not docs_to_keep) and (not docs_to_remove): - raise - - for fn in os.listdir(run_dir): - if fn == "done": - continue - - run_fn = os.path.join(run_dir, fn) - self._filter(run_fn, docs_to_remove, docs_to_keep, topn) - return run_dir - - def _filter(self, runfile, docs_to_remove, docs_to_keep, topn): - runs = Searcher.load_trec_run(runfile) - - # filtering - if docs_to_remove: # prioritize docs_to_remove - if isinstance(docs_to_remove, list): - docs_to_remove = {q: docs_to_remove for q in runs} - runs = {q: {d: v for d, v in docs.items() if d not in docs_to_remove.get(q, [])} for q, docs in runs.items()} - elif docs_to_keep: - if isinstance(docs_to_keep, list): - docs_to_keep = {q: docs_to_keep for q in runs} - runs = {q: {d: v for d, v in docs.items() if d in docs_to_keep[q]} for 
q, docs in runs.items()} - - if topn: - runs = self._keep_topn(runs, topn) - Searcher.write_trec_run(runs, runfile) # overwrite runfile - - def dedup(self, run_dir, topn=None): - for fn in os.listdir(run_dir): - if fn == "done": - continue - run_fn = os.path.join(run_dir, fn) - self._dedup(run_fn, topn) - return run_dir - - def _dedup(self, runfile, topn): - runs = Searcher.load_trec_run(runfile) - new_runs = {q: {} for q in runs} - - # use the sum of each passage score as the document score, no sorting is done here - for q, psg in runs.items(): - for pid, score in psg.items(): - docid = pid.split(".")[0] - new_runs[q][docid] = max(new_runs[q].get(docid, -math.inf), score) - runs = new_runs - - if topn: - runs = self._keep_topn(runs, topn) - Searcher.write_trec_run(runs, runfile) - - -@Searcher.register -class BM25(Searcher, AnseriniSearcherMixIn): - """ BM25 with fixed k1 and b. """ - - module_name = "BM25" - - dependencies = [Dependency(key="index", module="index", name="anserini")] - config_spec = [ - ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), - ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - """ - Runs BM25 search. Takes a query from the topic files, and fires it against the index - Args: - topicsfn: Path to a topics file - output_path: Path where the results of the search (i.e the run file) should be stored - - Returns: Path to the run file where the results of the search are stored - - """ - bstr, k1str = list2str(config["b"], delimiter=" "), list2str(config["k1"], delimiter=" ") - hits = config["hits"] - anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}" - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class BM25Grid(Searcher, AnseriniSearcherMixIn): - """ BM25 with a grid search for k1 and b. 
Search is from 0.1 to bmax/k1max in 0.1 increments """ - - module_name = "BM25Grid" - dependencies = [Dependency(key="index", module="index", name="anserini")] - config_spec = [ - ConfigOption("k1max", 1.0, "maximum k1 value to include in grid search (starting at 0.1)"), - ConfigOption("bmax", 1.0, "maximum b value to include in grid search (starting at 0.1)"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - bs = np.around(np.arange(0.1, config["bmax"] + 0.1, 0.1), 1) - k1s = np.around(np.arange(0.1, config["k1max"] + 0.1, 0.1), 1) - bstr = " ".join(str(x) for x in bs) - k1str = " ".join(str(x) for x in k1s) - hits = config["hits"] - anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}" - - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class BM25RM3(Searcher, AnseriniSearcherMixIn): - - module_name = "BM25RM3" - dependencies = [Dependency(key="index", module="index", name="anserini")] - config_spec = [ - ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"), - ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"), - ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"), - ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"), - ConfigOption("originalQueryWeight", [0.5], "the weight of unexpended query", value_type="floatlist"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - hits = str(config["hits"]) - - anserini_param_str = ( - "-rm3 " - + " ".join(f"-rm3.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "originalQueryWeight"]) - + " -bm25 " - + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"]) - + f" -hits {hits}" - ) - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class BM25PostProcess(BM25, PostprocessMixin): - module_name = "BM25Postprocess" - - config_spec = [ - ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), - ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("topn", 1000), - ConfigOption("fields", "title"), - ConfigOption("dedep", False), - ] - - def query_from_file(self, topicsfn, output_path, docs_to_remove=None): - output_path = super().query_from_file(topicsfn, output_path) # will call _query_from_file() from BM25 - - if docs_to_remove: - output_path = self.filter(output_path, docs_to_remove=docs_to_remove, topn=self.config["topn"]) - if self.config["dedup"]: - output_path = self.dedup(output_path, topn=self.config["topn"]) - - return output_path - - -@Searcher.register -class StaticBM25RM3Rob04Yang19(Searcher): - """ Tuned BM25+RM3 run used by Yang et al. in [1]. This should be used only with a benchmark using the same folds and queries. - - [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019. 
- """ - - module_name = "bm25staticrob04yang19" - - def _query_from_file(self, topicsfn, output_path, config): - import shutil - - outfn = os.path.join(output_path, "static.run") - os.makedirs(output_path, exist_ok=True) - shutil.copy2(constants["PACKAGE_PATH"] / "data" / "rob04_yang19_rm3.run", outfn) - - return output_path - - def query(self, *args, **kwargs): - raise NotImplementedError("this searcher uses a static run file, so it cannot handle new queries") - - -@Searcher.register -class BM25PRF(Searcher, AnseriniSearcherMixIn): - """ - BM25 with PRF - """ - - module_name = "BM25PRF" - - dependencies = [Dependency(key="index", module="index", name="anserini")] - config_spec = [ - ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"), - ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"), - ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"), - ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"), - ConfigOption("newTermWeight", [0.2, 0.25], value_type="floatlist"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - hits = str(config["hits"]) - - anserini_param_str = ( - "-bm25prf " - + " ".join(f"-bm25prf.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "newTermWeight", "k1", "b"]) - + " -bm25 " - + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"]) - + f" -hits {hits}" - ) - print(output_path) - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class AxiomaticSemanticMatching(Searcher, AnseriniSearcherMixIn): - """ - TODO: Add more info on retrieval method - Also, BM25 is hard-coded to be the scoring model - """ - - module_name = "axiomatic" - dependencies = [Dependency(key="index", module="index", name="anserini")] - config_spec = [ - ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), - ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), - ConfigOption("r", 20, value_type="intlist"), - ConfigOption("n", 30, value_type="intlist"), - ConfigOption("beta", 0.4, value_type="floatlist"), - ConfigOption("top", 20, value_type="intlist"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - hits = str(config["hits"]) - conditionals = "" - - anserini_param_str = "-axiom -axiom.deterministic -axiom.r {0} -axiom.n {1} -axiom.beta {2} -axiom.top {3}".format( - *[list2str(config[k], " ") for k in ["r", "n", "beta", "top"]] - ) - anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1} ".format(*[list2str(config[k], " ") for k in ["k1", "b"]]) - anserini_param_str += f" -hits {hits}" - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class DirichletQL(Searcher, AnseriniSearcherMixIn): - """ Dirichlet QL with a fixed mu """ - - module_name = "DirichletQL" - dependencies = [Dependency(key="index", module="index", name="anserini")] - - config_spec = [ - ConfigOption("mu", 1000, "smoothing parameter", value_type="intlist"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - 
def _query_from_file(self, topicsfn, output_path, config): - """ - Runs Dirichlet QL search. Takes a query from the topic files, and fires it against the index - Args: - topicsfn: Path to a topics file - output_path: Path where the results of the search (i.e the run file) should be stored - - Returns: Path to the run file where the results of the search are stored - - """ - mustr = list2str(config["mu"], delimiter=" ") - hits = config["hits"] - anserini_param_str = f"-qld -qld.mu {mustr} -hits {hits}" - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class QLJM(Searcher, AnseriniSearcherMixIn): - """ - QL with Jelinek-Mercer smoothing - """ - - module_name = "QLJM" - dependencies = [Dependency(key="index", module="index", name="anserini")] - config_spec = [ - ConfigOption("lam", 0.1, value_type="floatlist"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - anserini_param_str = "-qljm -qljm.lambda {0} -hits {1}".format(list2str(config["lam"], delimiter=" "), config["hits"]) - - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class INL2(Searcher, AnseriniSearcherMixIn): - """ - I(n)L2 scoring model - """ - - module_name = "INL2" - dependencies = [Dependency(key="index", module="index", name="anserini")] - config_spec = [ - ConfigOption("c", 0.1), # array input of this parameter is not support by anserini.SearchCollection - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - anserini_param_str = "-inl2 -inl2.c {0} -hits {1}".format(config["c"], config["hits"]) - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - return output_path - - -@Searcher.register -class SPL(Searcher, AnseriniSearcherMixIn): - """ - SPL scoring model - """ - - module_name = "SPL" - dependencies = [Dependency(key="index", module="index", name="anserini")] - - config_spec = [ - ConfigOption("c", 0.1), # array input of this parameter is not support by anserini.SearchCollection - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - anserini_param_str = "-spl -spl.c {0} -hits {1}".format(config["c"], config["hits"]) - - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class F2Exp(Searcher, AnseriniSearcherMixIn): - """ - F2Exp scoring model - """ - - module_name = "F2Exp" - dependencies = [Dependency(key="index", module="index", name="anserini")] - - config_spec = [ - ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - anserini_param_str = "-f2exp -f2exp.s {0} -hits {1}".format(config["s"], config["hits"]) - - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class F2Log(Searcher, AnseriniSearcherMixIn): - """ - F2Log scoring model - """ - - module_name = "F2Log" - dependencies = 
[Dependency(key="index", module="index", name="anserini")] - - config_spec = [ - ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] - - def _query_from_file(self, topicsfn, output_path, config): - anserini_param_str = "-f2log -f2log.s {0} -hits {1}".format(config["s"], config["hits"]) - - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) - - return output_path - - -@Searcher.register -class SDM(Searcher, AnseriniSearcherMixIn): - """ - Sequential Dependency Model - The scoring model is hardcoded to be BM25 (TODO: Make it configurable?) - """ - - module_name = "SDM" - dependencies = [Dependency(key="index", module="index", name="anserini")] - - # array input of (tw, ow, uw) is not support by anserini.SearchCollection - config_spec = [ - ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), - ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), - ConfigOption("tw", 0.85, "term weight"), - ConfigOption("ow", 0.15, "ordered window weight"), - ConfigOption("uw", 0.05, "unordered window weight"), - ConfigOption("hits", 1000, "number of results to return"), - ConfigOption("fields", "title"), - ] +from profane import import_all_modules - def _query_from_file(self, topicsfn, output_path, config): - hits = config["hits"] - anserini_param_str = "-sdm -sdm.tw {0} -sdm.ow {1} -sdm.uw {2}".format(*[config[k] for k in ["tw", "ow", "uw"]]) - anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1}".format(*[list2str(config[k], " ") for k in ["k1", "b"]]) - anserini_param_str += f" -hits {hits}" - self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) +from .anserini import BM25, BM25RM3, SDM - return output_path +import_all_modules(__file__, __package__) diff --git a/capreolus/searcher/anserini.py b/capreolus/searcher/anserini.py new file mode 100644 index 000000000..25ff6c780 --- /dev/null +++ b/capreolus/searcher/anserini.py @@ -0,0 +1,504 @@ +import os +import math +import subprocess +from collections import defaultdict, OrderedDict + +import numpy as np + +from . import Searcher +from capreolus import ModuleBase, Dependency, ConfigOption, constants +from capreolus.utils.common import Anserini +from capreolus.utils.loginit import get_logger +from capreolus.utils.trec import topic_to_trectxt + +logger = get_logger(__name__) # pylint: disable=invalid-name +MAX_THREADS = constants["MAX_THREADS"] + + +def list2str(l, delimiter="-"): + return delimiter.join(str(x) for x in l) + + +class AnseriniSearcherMixIn: + """ MixIn for searchers that use Anserini's SearchCollection script """ + + def _anserini_query_from_file(self, topicsfn, anserini_param_str, output_base_path, topicfield): + if not os.path.exists(topicsfn): + raise IOError(f"could not find topics file: {topicsfn}") + + # for covid: + field2querytype = {"query": "title", "question": "description", "narrative": "narrative"} + for k, v in field2querytype.items(): + topicfield = topicfield.replace(k, v) + + donefn = os.path.join(output_base_path, "done") + if os.path.exists(donefn): + logger.debug(f"skipping Anserini SearchCollection call because path already exists: {donefn}") + return + + # create index if it does not exist. the call returns immediately if the index does exist. 
+ self.index.create_index() + + os.makedirs(output_base_path, exist_ok=True) + output_path = os.path.join(output_base_path, "searcher") + + # add stemmer and stop options to match underlying index + indexopts = "-stemmer " + indexopts += "none" if self.index.config["stemmer"] is None else self.index.config["stemmer"] + if self.index.config["indexstops"]: + indexopts += " -keepstopwords" + + index_path = self.index.get_index_path() + anserini_fat_jar = Anserini.get_fat_jar() + cmd = ( + f"java -classpath {anserini_fat_jar} " + f"-Xms512M -Xmx31G -Dapp.name=SearchCollection io.anserini.search.SearchCollection " + f"-topicreader Trec -index {index_path} {indexopts} -topics {topicsfn} -output {output_path} " + f"-topicfield {topicfield} -inmem -threads {MAX_THREADS} {anserini_param_str}" + ) + logger.info("Anserini writing runs to %s", output_path) + logger.debug(cmd) + + app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True) + + # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger + for line in app.stdout: + Anserini.filter_and_log_anserini_output(line, logger) + + app.wait() + if app.returncode != 0: + raise RuntimeError("command failed") + + with open(donefn, "wt") as donef: + print("done", file=donef) + + +class PostprocessMixin: + def _keep_topn(self, runs, topn): + queries = sorted(list(runs.keys()), key=lambda k: int(k)) + for q in queries: + docs = runs[q] + if len(docs) <= topn: + continue + docs = sorted(docs.items(), key=lambda kv: kv[1], reverse=True)[:topn] + runs[q] = {k: v for k, v in docs} + return runs + + def filter(self, run_dir, docs_to_remove=None, docs_to_keep=None, topn=None): + if (not docs_to_keep) and (not docs_to_remove): + raise ValueError("docs_to_remove and/or docs_to_keep must be provided") + + for fn in os.listdir(run_dir): + if fn == "done": + continue + + run_fn = os.path.join(run_dir, fn) + self._filter(run_fn, docs_to_remove, docs_to_keep, topn) + return run_dir + + def _filter(self, runfile, docs_to_remove, docs_to_keep, topn): + runs = Searcher.load_trec_run(runfile) + + # filtering + if docs_to_remove: # prioritize docs_to_remove + if isinstance(docs_to_remove, list): + docs_to_remove = {q: docs_to_remove for q in runs} + runs = {q: {d: v for d, v in docs.items() if d not in docs_to_remove.get(q, [])} for q, docs in runs.items()} + elif docs_to_keep: + if isinstance(docs_to_keep, list): + docs_to_keep = {q: docs_to_keep for q in runs} + runs = {q: {d: v for d, v in docs.items() if d in docs_to_keep[q]} for q, docs in runs.items()} + + if topn: + runs = self._keep_topn(runs, topn) + Searcher.write_trec_run(runs, runfile) # overwrite runfile + + def dedup(self, run_dir, topn=None): + for fn in os.listdir(run_dir): + if fn == "done": + continue + run_fn = os.path.join(run_dir, fn) + self._dedup(run_fn, topn) + return run_dir + + def _dedup(self, runfile, topn): + runs = Searcher.load_trec_run(runfile) + new_runs = {q: {} for q in runs} + + # use the maximum passage score as the document score; no sorting is done here + for q, psg in runs.items(): + for pid, score in psg.items(): + docid = pid.split(".")[0] + new_runs[q][docid] = max(new_runs[q].get(docid, -math.inf), score) + runs = new_runs + + if topn: + runs = self._keep_topn(runs, topn) + Searcher.write_trec_run(runs, runfile) + + +@Searcher.register +class BM25(Searcher, AnseriniSearcherMixIn): + """ Anserini BM25. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). 
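+
+    For example, ``k1=0.4,0.6,0.8`` requests a grid search over three k1 values. A hedged sketch
+    of the equivalent config dict (passing the values as strings is an assumption based on the
+    ``floatlist`` option type)::
+
+        searcher = BM25.create("BM25", config={"k1": "0.4,0.6,0.8", "b": "0.4..1,0.2"}, provide={"index": index})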
""" + + module_name = "BM25" + + dependencies = [Dependency(key="index", module="index", name="anserini")] + config_spec = [ + ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), + ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + """ + Runs BM25 search. Takes a query from the topic files, and fires it against the index + Args: + topicsfn: Path to a topics file + output_path: Path where the results of the search (i.e the run file) should be stored + + Returns: Path to the run file where the results of the search are stored + + """ + bstr, k1str = list2str(config["b"], delimiter=" "), list2str(config["k1"], delimiter=" ") + hits = config["hits"] + anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}" + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path + + +@Searcher.register +class BM25Grid(Searcher, AnseriniSearcherMixIn): + """ Deprecated. BM25 with a grid search for k1 and b. Search is from 0.1 to bmax/k1max in 0.1 increments """ + + module_name = "BM25Grid" + dependencies = [Dependency(key="index", module="index", name="anserini")] + config_spec = [ + ConfigOption("k1max", 1.0, "maximum k1 value to include in grid search (starting at 0.1)"), + ConfigOption("bmax", 1.0, "maximum b value to include in grid search (starting at 0.1)"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + bs = np.around(np.arange(0.1, config["bmax"] + 0.1, 0.1), 1) + k1s = np.around(np.arange(0.1, config["k1max"] + 0.1, 0.1), 1) + bstr = " ".join(str(x) for x in bs) + k1str = " ".join(str(x) for x in k1s) + hits = config["hits"] + anserini_param_str = f"-bm25 -bm25.b {bstr} -bm25.k1 {k1str} -hits {hits}" + + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path + + +@Searcher.register +class BM25RM3(Searcher, AnseriniSearcherMixIn): + """ Anserini BM25 with RM3 expansion. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). 
""" + + module_name = "BM25RM3" + dependencies = [Dependency(key="index", module="index", name="anserini")] + config_spec = [ + ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"), + ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"), + ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"), + ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"), + ConfigOption("originalQueryWeight", [0.5], "the weight of unexpended query", value_type="floatlist"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + hits = str(config["hits"]) + + anserini_param_str = ( + "-rm3 " + + " ".join(f"-rm3.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "originalQueryWeight"]) + + " -bm25 " + + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"]) + + f" -hits {hits}" + ) + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path + + +@Searcher.register +class BM25PostProcess(BM25, PostprocessMixin): + module_name = "BM25Postprocess" + + config_spec = [ + ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), + ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("topn", 1000), + ConfigOption("fields", "title"), + ConfigOption("dedep", False), + ] + + def query_from_file(self, topicsfn, output_path, docs_to_remove=None): + output_path = super().query_from_file(topicsfn, output_path) # will call _query_from_file() from BM25 + + if docs_to_remove: + output_path = self.filter(output_path, docs_to_remove=docs_to_remove, topn=self.config["topn"]) + if self.config["dedup"]: + output_path = self.dedup(output_path, topn=self.config["topn"]) + + return output_path + + +@Searcher.register +class StaticBM25RM3Rob04Yang19(Searcher): + """ Tuned BM25+RM3 run used by Yang et al. in [1]. This should be used only with a benchmark using the same folds and queries. + + [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019. + """ + + module_name = "bm25staticrob04yang19" + + def _query_from_file(self, topicsfn, output_path, config): + import shutil + + outfn = os.path.join(output_path, "static.run") + os.makedirs(output_path, exist_ok=True) + shutil.copy2(constants["PACKAGE_PATH"] / "data" / "rob04_yang19_rm3.run", outfn) + + return output_path + + def query(self, *args, **kwargs): + raise NotImplementedError("this searcher uses a static run file, so it cannot handle new queries") + + +@Searcher.register +class BM25PRF(Searcher, AnseriniSearcherMixIn): + """ Anserini BM25 PRF. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). 
""" + + module_name = "BM25PRF" + + dependencies = [Dependency(key="index", module="index", name="anserini")] + config_spec = [ + ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"), + ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"), + ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"), + ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"), + ConfigOption("newTermWeight", [0.2, 0.25], value_type="floatlist"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + hits = str(config["hits"]) + + anserini_param_str = ( + "-bm25prf " + + " ".join(f"-bm25prf.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "newTermWeight", "k1", "b"]) + + " -bm25 " + + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"]) + + f" -hits {hits}" + ) + print(output_path) + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path + + +@Searcher.register +class AxiomaticSemanticMatching(Searcher, AnseriniSearcherMixIn): + """ Anserini BM25 with Axiomatic query expansion. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """ + + module_name = "axiomatic" + dependencies = [Dependency(key="index", module="index", name="anserini")] + config_spec = [ + ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), + ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), + ConfigOption("r", 20, value_type="intlist"), + ConfigOption("n", 30, value_type="intlist"), + ConfigOption("beta", 0.4, value_type="floatlist"), + ConfigOption("top", 20, value_type="intlist"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + hits = str(config["hits"]) + conditionals = "" + + anserini_param_str = "-axiom -axiom.deterministic -axiom.r {0} -axiom.n {1} -axiom.beta {2} -axiom.top {3}".format( + *[list2str(config[k], " ") for k in ["r", "n", "beta", "top"]] + ) + anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1} ".format(*[list2str(config[k], " ") for k in ["k1", "b"]]) + anserini_param_str += f" -hits {hits}" + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path + + +@Searcher.register +class DirichletQL(Searcher, AnseriniSearcherMixIn): + """ Anserini QL with Dirichlet smoothing. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """ + + module_name = "DirichletQL" + dependencies = [Dependency(key="index", module="index", name="anserini")] + + config_spec = [ + ConfigOption("mu", 1000, "smoothing parameter", value_type="intlist"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + """ + Runs Dirichlet QL search. 
+@Searcher.register
+class DirichletQL(Searcher, AnseriniSearcherMixIn):
+    """ Anserini QL with Dirichlet smoothing. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+    module_name = "DirichletQL"
+    dependencies = [Dependency(key="index", module="index", name="anserini")]
+
+    config_spec = [
+        ConfigOption("mu", 1000, "smoothing parameter", value_type="intlist"),
+        ConfigOption("hits", 1000, "number of results to return"),
+        ConfigOption("fields", "title"),
+    ]
+
+    def _query_from_file(self, topicsfn, output_path, config):
+        """Runs Dirichlet QL search: reads queries from the topics file and executes them against the index.
+
+        Args:
+            topicsfn: Path to a topics file
+            output_path: Path where the results of the search (i.e., the run file) should be stored
+
+        Returns: Path to the run file where the results of the search are stored
+
+        """
+        mustr = list2str(config["mu"], delimiter=" ")
+        hits = config["hits"]
+        anserini_param_str = f"-qld -qld.mu {mustr} -hits {hits}"
+        self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+        return output_path
+
+
+@Searcher.register
+class QLJM(Searcher, AnseriniSearcherMixIn):
+    """ Anserini QL with Jelinek-Mercer smoothing. This searcher's parameters can also be specified as lists indicating parameters to grid search (e.g., ``"0.4,0.6,0.8"`` or ``"0.4..1,0.2"``). """
+
+    module_name = "QLJM"
+    dependencies = [Dependency(key="index", module="index", name="anserini")]
+    config_spec = [
+        ConfigOption("lam", 0.1, "smoothing parameter lambda", value_type="floatlist"),
+        ConfigOption("hits", 1000, "number of results to return"),
+        ConfigOption("fields", "title"),
+    ]
+
+    def _query_from_file(self, topicsfn, output_path, config):
+        anserini_param_str = "-qljm -qljm.lambda {0} -hits {1}".format(list2str(config["lam"], delimiter=" "), config["hits"])
+
+        self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+        return output_path
+
+
+@Searcher.register
+class INL2(Searcher, AnseriniSearcherMixIn):
+    """ Anserini I(n)L2 scoring model. This searcher does not support list parameters. """
+
+    module_name = "INL2"
+    dependencies = [Dependency(key="index", module="index", name="anserini")]
+    config_spec = [
+        ConfigOption("c", 0.1),  # array input of this parameter is not supported by anserini.SearchCollection
+        ConfigOption("hits", 1000, "number of results to return"),
+        ConfigOption("fields", "title"),
+    ]
+
+    def _query_from_file(self, topicsfn, output_path, config):
+        anserini_param_str = "-inl2 -inl2.c {0} -hits {1}".format(config["c"], config["hits"])
+        self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+        return output_path
+
+
+@Searcher.register
+class SPL(Searcher, AnseriniSearcherMixIn):
+    """
+    Anserini SPL scoring model. This searcher does not support list parameters.
+    """
+
+    module_name = "SPL"
+    dependencies = [Dependency(key="index", module="index", name="anserini")]
+
+    config_spec = [
+        ConfigOption("c", 0.1),  # array input of this parameter is not supported by anserini.SearchCollection
+        ConfigOption("hits", 1000, "number of results to return"),
+        ConfigOption("fields", "title"),
+    ]
+
+    def _query_from_file(self, topicsfn, output_path, config):
+        anserini_param_str = "-spl -spl.c {0} -hits {1}".format(config["c"], config["hits"])
+
+        self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
+
+        return output_path
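Every searcher above ultimately writes TREC-format run files. A small round-trip sketch using the ``Searcher`` helpers seen in ``PostprocessMixin`` (the run dict and output path are hypothetical):

from capreolus import Searcher

runs = {"301": {"FBIS3-1": 12.5, "FBIS3-2": 11.0}}  # {qid: {docid: score}}
Searcher.write_trec_run(runs, "/tmp/demo.run")      # hypothetical output path
assert Searcher.load_trec_run("/tmp/demo.run")["301"]["FBIS3-1"] == 12.5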
+ """ + + module_name = "F2Exp" + dependencies = [Dependency(key="index", module="index", name="anserini")] + + config_spec = [ + ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + anserini_param_str = "-f2exp -f2exp.s {0} -hits {1}".format(config["s"], config["hits"]) + + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path + + +@Searcher.register +class F2Log(Searcher, AnseriniSearcherMixIn): + """ + F2Log scoring model. This searcher does not support list parameters. + """ + + module_name = "F2Log" + dependencies = [Dependency(key="index", module="index", name="anserini")] + + config_spec = [ + ConfigOption("s", 0.5), # array input of this parameter is not support by anserini.SearchCollection + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + anserini_param_str = "-f2log -f2log.s {0} -hits {1}".format(config["s"], config["hits"]) + + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path + + +@Searcher.register +class SDM(Searcher, AnseriniSearcherMixIn): + """ + Anserini BM25 with the Sequential Dependency Model. This searcher supports list parameters for only k1 and b. + """ + + module_name = "SDM" + dependencies = [Dependency(key="index", module="index", name="anserini")] + + # array input of (tw, ow, uw) is not support by anserini.SearchCollection + config_spec = [ + ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"), + ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"), + ConfigOption("tw", 0.85, "term weight"), + ConfigOption("ow", 0.15, "ordered window weight"), + ConfigOption("uw", 0.05, "unordered window weight"), + ConfigOption("hits", 1000, "number of results to return"), + ConfigOption("fields", "title"), + ] + + def _query_from_file(self, topicsfn, output_path, config): + hits = config["hits"] + anserini_param_str = "-sdm -sdm.tw {0} -sdm.ow {1} -sdm.uw {2}".format(*[config[k] for k in ["tw", "ow", "uw"]]) + anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1}".format(*[list2str(config[k], " ") for k in ["k1", "b"]]) + anserini_param_str += f" -hits {hits}" + self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"]) + + return output_path diff --git a/capreolus/searcher/tests/test_searcher.py b/capreolus/searcher/tests/test_searcher.py index 469f5dc62..0a87d61f8 100644 --- a/capreolus/searcher/tests/test_searcher.py +++ b/capreolus/searcher/tests/test_searcher.py @@ -1,11 +1,11 @@ import os import numpy as np import pytest -from profane import module_registry +from capreolus import module_registry from capreolus.utils.trec import load_trec_topics from capreolus.benchmark import DummyBenchmark -from capreolus.searcher import Searcher, BM25, BM25Grid +from capreolus.searcher.anserini import Searcher, BM25, BM25Grid from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index skip_searchers = {"bm25staticrob04yang19", "BM25Grid", "BM25Postprocess", "axiomatic"} diff --git a/capreolus/task/__init__.py b/capreolus/task/__init__.py index 6412f8b0e..b64a361d2 100644 --- a/capreolus/task/__init__.py +++ b/capreolus/task/__init__.py 
diff --git a/capreolus/task/__init__.py b/capreolus/task/__init__.py
index 6412f8b0e..b64a361d2 100644
--- a/capreolus/task/__init__.py
+++ b/capreolus/task/__init__.py
@@ -1,5 +1,62 @@
+from capreolus import ModuleBase, Dependency, ConfigOption, constants, module_registry
+
+
+class Task(ModuleBase):
+    """Base class for Task modules. The purpose of a Task is to describe a Capreolus pipeline and serve as the pipeline's entry point. Tasks provide one or more commands that provide entry points while sharing the Task's configuration options and dependencies.
+
+    Modules should provide:
+        - a ``commands`` attribute containing the names of methods that can serve as pipeline entry points (*Task commands*). Each command will be accessible via the CLI using the syntax ``capreolus <task name>.<command name> ...``
+        - a ``default_command`` attribute containing the name of a command to run if none is given
+        - methods (taking only the *self* argument) that correspond to each command defined
+    """
+
+    module_type = "task"
+    commands = []
+    help_commands = ["describe", "print_config", "print_paths", "print_pipeline"]
+    default_command = "describe"
+    requires_random_seed = True
+
+    def print_config(self):
+        print("Configuration:")
+        self.print_module_config(prefix="  ")
+
+    def print_paths(self):  # TODO
+        pass
+
+    def print_pipeline(self):
+        print("Module graph:")
+        self.print_module_graph(prefix="  ")
+
+    def describe(self):
+        self.print_pipeline()
+        print("\n")
+        self.print_config()
+
+    def get_results_path(self):
+        """ Return an absolute path that can be used for storing results.
+        The path is a function of the module's config and the configs of its dependencies.
+        """
+
+        return constants["RESULTS_BASE_PATH"] / self.get_module_path()
+
+
+@Task.register
+class ModulesTask(Task):
+    module_name = "modules"
+    commands = ["list_modules"]
+    default_command = "list_modules"
+
+    def list_modules(self):
+        for module_type in module_registry.get_module_types():
+            print(f"module type={module_type}")
+
+            for module_name in module_registry.get_module_names(module_type):
+                print(f"  name={module_name}")
+
+
 from profane import import_all_modules
 
-from .base import Task
+from .rank import RankTask
+from .rerank import RerankTask
 
 import_all_modules(__file__, __package__)
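As a quick sketch of the command plumbing above, the ``modules`` Task can be driven from the CLI (``capreolus modules.list_modules``, following the syntax described in the docstring) or programmatically:

from capreolus import Task

# create the registered "modules" Task and invoke its only command
task = Task.create("modules")
task.list_modules()  # prints every registered module type and name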
- """ - - return constants["RESULTS_BASE_PATH"] / self.get_module_path() - - -@Task.register -class ModulesTask(Task): - module_name = "modules" - commands = ["list_modules"] - default_command = "list_modules" - - def list_modules(self): - for module_type in module_registry.get_module_types(): - print(f"module type={module_type}") - - for module_name in module_registry.get_module_names(module_type): - print(f" name={module_name}") diff --git a/capreolus/task/rank.py b/capreolus/task/rank.py index ff664fe4d..526aa0899 100644 --- a/capreolus/task/rank.py +++ b/capreolus/task/rank.py @@ -20,7 +20,9 @@ class RankTask(Task): config_keys_not_in_path = ["optimize", "metrics"] # affect only evaluation but not search() dependencies = [ - Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]), + Dependency( + key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"] + ), Dependency(key="searcher", module="searcher", name="BM25"), ] diff --git a/capreolus/task/rerank.py b/capreolus/task/rerank.py index 980eee019..963464dcc 100644 --- a/capreolus/task/rerank.py +++ b/capreolus/task/rerank.py @@ -23,7 +23,9 @@ class RerankTask(Task): ConfigOption("optimize", "map", "metric to maximize on the dev set"), # affects train() because we check to save weights ] dependencies = [ - Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]), + Dependency( + key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"] + ), Dependency(key="rank", module="task", name="rank"), Dependency(key="reranker", module="reranker", name="KNRM"), ] diff --git a/capreolus/task/rererank.py b/capreolus/task/rererank.py index 5928eb075..e8acd04a0 100644 --- a/capreolus/task/rererank.py +++ b/capreolus/task/rererank.py @@ -24,7 +24,9 @@ class ReRerankTask(Task): ConfigOption("topn", 100, "number of stage two results to rerank"), ] dependencies = [ - Dependency(key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]), + Dependency( + key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"] + ), Dependency(key="rank", module="task", name="rank", provide_this=True), Dependency(key="rerank1", module="task", name="rerank"), Dependency(key="rerank2", module="task", name="rerank"), diff --git a/capreolus/task/tests/test_task.py b/capreolus/task/tests/test_task.py new file mode 100644 index 000000000..b9e7e0df2 --- /dev/null +++ b/capreolus/task/tests/test_task.py @@ -0,0 +1,13 @@ +import pytest + +from capreolus import Benchmark, Task, module_registry +from capreolus.tests.common_fixtures import dummy_index, tmpdir_as_cache + + +tasks = set(module_registry.get_module_names("task")) + + +@pytest.mark.parametrize("task_name", tasks) +def test_task_creatable(tmpdir_as_cache, dummy_index, task_name): + provide = {"index": dummy_index, "benchmark": Benchmark.create("dummy"), "collection": dummy_index.collection} + task = Task.create(task_name, provide=provide) diff --git a/capreolus/task/tutorial.py b/capreolus/task/tutorial.py index 1df31c4e6..0b3488af8 100644 --- a/capreolus/task/tutorial.py +++ b/capreolus/task/tutorial.py @@ -11,9 +11,7 @@ class TutorialTask(Task): module_name = "tutorial" config_spec = [ConfigOption("optimize", "map", "metric to maximize on the validation set")] dependencies = [ - Dependency( 
- key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"] - ), + Dependency(key="benchmark", module="benchmark", name="nf", provide_this=True, provide_children=["collection"]), Dependency(key="searcher1", module="searcher", name="BM25RM3"), Dependency(key="searcher2", module="searcher", name="SDM"), ] @@ -37,7 +35,7 @@ def run(self): ) for fold, path in best_results["path"].items(): - shortpath = "..." + path[:-20] + shortpath = "..." + path[-40:] logger.info("fold=%s best run: %s", fold, shortpath) logger.info("cross-validated results when optimizing for '%s':", self.config["optimize"]) diff --git a/capreolus/tests/test_benchmark.py b/capreolus/tests/test_benchmark.py index 0ade5c188..6a177be25 100644 --- a/capreolus/tests/test_benchmark.py +++ b/capreolus/tests/test_benchmark.py @@ -3,14 +3,26 @@ import pytest from tqdm import tqdm +from capreolus import Benchmark, module_registry from capreolus.utils.loginit import get_logger from capreolus.utils.common import remove_newline -from capreolus.benchmark import CodeSearchNetCorpus as CodeSearchNetCodeSearchNetCorpusBenchmark -from capreolus.benchmark import CodeSearchNetChallenge as CodeSearchNetCodeSearchNetChallengeBenchmark -from capreolus.collection import CodeSearchNet as CodeSearchNetCollection +from capreolus.benchmark.codesearchnet import CodeSearchNetCorpus as CodeSearchNetCodeSearchNetCorpusBenchmark +from capreolus.benchmark.codesearchnet import CodeSearchNetChallenge as CodeSearchNetCodeSearchNetChallengeBenchmark +from capreolus.collection.codesearchnet import CodeSearchNet as CodeSearchNetCollection +from capreolus.tests.common_fixtures import tmpdir_as_cache logger = get_logger(__name__) +benchmarks = set(module_registry.get_module_names("benchmark")) + + +@pytest.mark.parametrize("benchmark_name", benchmarks) +@pytest.mark.download +def test_benchmark_creatable(tmpdir_as_cache, benchmark_name): + benchmark = Benchmark.create(benchmark_name) + if hasattr(benchmark, "download_if_missing"): + benchmark.download_if_missing() + @pytest.mark.download def test_csn_corpus_benchmark_downloadifmissing(): diff --git a/capreolus/tests/test_collection.py b/capreolus/tests/test_collection.py index 34e455ecf..54a348107 100644 --- a/capreolus/tests/test_collection.py +++ b/capreolus/tests/test_collection.py @@ -3,8 +3,25 @@ import pytest +from capreolus import Collection, module_registry from capreolus.index import AnseriniIndex -from capreolus.collection import ANTIQUE, CodeSearchNet +from capreolus.collection.antique import ANTIQUE +from capreolus.collection.codesearchnet import CodeSearchNet +from capreolus.tests.common_fixtures import tmpdir_as_cache + +collections = set(module_registry.get_module_names("collection")) + + +@pytest.mark.parametrize("collection_name", collections) +def test_collection_creatable(tmpdir_as_cache, collection_name): + collection = Collection.create(collection_name) + + +@pytest.mark.parametrize("collection_name", collections) +@pytest.mark.download +def test_collection_downloadable(tmpdir_as_cache, collection_name): + collection = Collection.create(collection_name) + collection.find_document_path() @pytest.mark.download diff --git a/capreolus/tests/test_extractor.py b/capreolus/tests/test_extractor.py index 87f0d8333..e0ef28f26 100644 --- a/capreolus/tests/test_extractor.py +++ b/capreolus/tests/test_extractor.py @@ -4,12 +4,14 @@ from nltk import TextTilingTokenizer from pymagnitude import Magnitude, MagnitudeUtils import numpy as np +import pytest 
+from capreolus import Extractor, module_registry from capreolus.collection import DummyCollection from capreolus.index import AnseriniIndex from capreolus.tokenizer import AnseriniTokenizer from capreolus.benchmark import DummyBenchmark -from capreolus.extractor import EmbedText +from capreolus.extractor.embedtext import EmbedText from capreolus.tests.common_fixtures import tmpdir_as_cache, dummy_index from capreolus.utils.exceptions import MissingDocError @@ -19,6 +21,14 @@ MAXQLEN = 8 MAXDOCLEN = 7 +extractors = set(module_registry.get_module_names("extractor")) + + +@pytest.mark.parametrize("extractor_name", extractors) +def test_extractor_creatable(tmpdir_as_cache, dummy_index, extractor_name): + provide = {"index": dummy_index, "collection": dummy_index.collection} + extractor = Extractor.create(extractor_name, provide=provide) + def test_embedtext_creation(monkeypatch): def fake_magnitude_embedding(*args, **kwargs): diff --git a/capreolus/tokenizer/__init__.py b/capreolus/tokenizer/__init__.py index ccd471b54..0ce810e08 100644 --- a/capreolus/tokenizer/__init__.py +++ b/capreolus/tokenizer/__init__.py @@ -1,70 +1,19 @@ -from profane import import_all_modules - -# import_all_modules(__file__, __package__) - -from profane import ModuleBase, Dependency, ConfigOption - -from transformers import BertTokenizer as HFBertTokenizer +from capreolus import ModuleBase, Dependency, ConfigOption class Tokenizer(ModuleBase): - module_type = "tokenizer" - - -@Tokenizer.register -class AnseriniTokenizer(Tokenizer): - module_name = "anserini" - config_spec = [ - ConfigOption("keepstops", True, "keep stopwords if True"), - ConfigOption("stemmer", "none", "stemmer: porter, krovetz, or none"), - ] - - def build(self): - self._tokenize = self._get_tokenize_fn() - - def _get_tokenize_fn(self): - from jnius import autoclass - - stemmer, keepstops = self.config["stemmer"], self.config["keepstops"] - if stemmer is None: - stemmer = "none" + """Base class for Tokenizer modules. The purpose of a Tokenizer is to tokenize strings of text (e.g., as required by an :class:`~capreolus.extractor.Extractor`). 
- emptyjchar = autoclass("org.apache.lucene.analysis.CharArraySet").EMPTY_SET - Analyzer = autoclass("io.anserini.analysis.DefaultEnglishAnalyzer") - analyzer = Analyzer.newStemmingInstance(stemmer, emptyjchar) if keepstops else Analyzer.newStemmingInstance(stemmer) - tokenizefn = autoclass("io.anserini.analysis.AnalyzerUtils").analyze + Modules should provide: + - a ``tokenize(strings)`` method that takes a list of strings and returns tokenized versions + """ - def _tokenize(sentence): - return tokenizefn(analyzer, sentence).toArray() - - return _tokenize - - def tokenize(self, sentences): - if not sentences or len(sentences) == 0: # either "" or [] - return [] - - if isinstance(sentences, str): - return self._tokenize(sentences) - - return [self._tokenize(s) for s in sentences] - - -@Tokenizer.register -class BertTokenizer(Tokenizer): - module_name = "berttokenizer" - config_spec = [ConfigOption("pretrained", "bert-base-uncased", "pretrained model to load vocab from")] - - def build(self): - self.bert_tokenizer = HFBertTokenizer.from_pretrained(self.config["pretrained"]) + module_type = "tokenizer" - def convert_tokens_to_ids(self, tokens): - return self.bert_tokenizer.convert_tokens_to_ids(tokens) - def tokenize(self, sentences): - if not sentences or len(sentences) == 0: # either "" or [] - return [] +from profane import import_all_modules - if isinstance(sentences, str): - return self.bert_tokenizer.tokenize(sentences) +from .anserini import AnseriniTokenizer +from .bert import BertTokenizer - return [self.bert_tokenizer.tokenize(s) for s in sentences] +import_all_modules(__file__, __package__) diff --git a/capreolus/tokenizer/anserini.py b/capreolus/tokenizer/anserini.py new file mode 100644 index 000000000..13eb83a73 --- /dev/null +++ b/capreolus/tokenizer/anserini.py @@ -0,0 +1,40 @@ +from . import Tokenizer +from capreolus import ModuleBase, Dependency, ConfigOption + + +@Tokenizer.register +class AnseriniTokenizer(Tokenizer): + module_name = "anserini" + config_spec = [ + ConfigOption("keepstops", True, "keep stopwords if True"), + ConfigOption("stemmer", "none", "stemmer: porter, krovetz, or none"), + ] + + def build(self): + self._tokenize = self._get_tokenize_fn() + + def _get_tokenize_fn(self): + from jnius import autoclass + + stemmer, keepstops = self.config["stemmer"], self.config["keepstops"] + if stemmer is None: + stemmer = "none" + + emptyjchar = autoclass("org.apache.lucene.analysis.CharArraySet").EMPTY_SET + Analyzer = autoclass("io.anserini.analysis.DefaultEnglishAnalyzer") + analyzer = Analyzer.newStemmingInstance(stemmer, emptyjchar) if keepstops else Analyzer.newStemmingInstance(stemmer) + tokenizefn = autoclass("io.anserini.analysis.AnalyzerUtils").analyze + + def _tokenize(sentence): + return tokenizefn(analyzer, sentence).toArray() + + return _tokenize + + def tokenize(self, sentences): + if not sentences or len(sentences) == 0: # either "" or [] + return [] + + if isinstance(sentences, str): + return self._tokenize(sentences) + + return [self._tokenize(s) for s in sentences] diff --git a/capreolus/tokenizer/bert.py b/capreolus/tokenizer/bert.py new file mode 100644 index 000000000..29f698ec7 --- /dev/null +++ b/capreolus/tokenizer/bert.py @@ -0,0 +1,25 @@ +from . 
import Tokenizer +from capreolus import ModuleBase, Dependency, ConfigOption + +from transformers import BertTokenizer as HFBertTokenizer + + +@Tokenizer.register +class BertTokenizer(Tokenizer): + module_name = "berttokenizer" + config_spec = [ConfigOption("pretrained", "bert-base-uncased", "pretrained model to load vocab from")] + + def build(self): + self.bert_tokenizer = HFBertTokenizer.from_pretrained(self.config["pretrained"]) + + def convert_tokens_to_ids(self, tokens): + return self.bert_tokenizer.convert_tokens_to_ids(tokens) + + def tokenize(self, sentences): + if not sentences or len(sentences) == 0: # either "" or [] + return [] + + if isinstance(sentences, str): + return self.bert_tokenizer.tokenize(sentences) + + return [self.bert_tokenizer.tokenize(s) for s in sentences] diff --git a/capreolus/trainer/__init__.py b/capreolus/trainer/__init__.py index 2ca2011bf..a0bc0fde2 100644 --- a/capreolus/trainer/__init__.py +++ b/capreolus/trainer/__init__.py @@ -1,37 +1,18 @@ -from profane import import_all_modules - -# import_all_modules(__file__, __package__) - -import hashlib -import math import os -import sys -import time -import uuid -from collections import defaultdict -from copy import copy -from profane import ModuleBase, Dependency, ConfigOption, constants -import tensorflow as tf -import tensorflow_ranking as tfr -import numpy as np -import torch -from keras import Sequential, layers -from keras.layers import Dense -from torch.utils.tensorboard import SummaryWriter -from tqdm import tqdm - -from capreolus.reranker.common import pair_hinge_loss, pair_softmax_loss -from capreolus.searcher import Searcher -from capreolus.utils.loginit import get_logger -from capreolus.utils.common import plot_metrics, plot_loss -from capreolus import evaluator +from capreolus import ModuleBase, Dependency, ConfigOption, constants, get_logger logger = get_logger(__name__) # pylint: disable=invalid-name -RESULTS_BASE_PATH = constants["RESULTS_BASE_PATH"] class Trainer(ModuleBase): + """Base class for Trainer modules. The purpose of a Trainer is to train a :class:`~capreolus.reranker.Reranker` module and use it to make predictions. 
Capreolus provides two trainers: :class:`~capreolus.trainer.pytorch.PytorchTrainer` and :class:`~capreolus.trainer.tensorflow.TensorFlowTrainer` + + Modules should provide: + - a ``train`` method that trains a reranker on training and dev (validation) data + - a ``predict`` method that uses a reranker to make predictions on data + """ + module_type = "trainer" requires_random_seed = True @@ -49,649 +30,9 @@ def get_paths_for_early_stopping(self, train_output_path, dev_output_path): return dev_best_weight_fn, weights_output_path, info_output_path, loss_fn -@Trainer.register -class PytorchTrainer(Trainer): - module_name = "pytorch" - config_spec = [ - ConfigOption("batch", 32, "batch size"), - ConfigOption("niters", 20, "number of iterations to train for"), - ConfigOption("itersize", 512, "number of training instances in one iteration"), - ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"), - ConfigOption("lr", 0.001, "learning rate"), - ConfigOption("softmaxloss", False, "True to use softmax loss (over pairs) or False to use hinge loss"), - ConfigOption("fastforward", False), - ConfigOption("validatefreq", 1), - ConfigOption("boardname", "default"), - ] - config_keys_not_in_path = ["fastforward", "boardname"] - - def build(self): - # sanity checks - if self.config["batch"] < 1: - raise ValueError("batch must be >= 1") - - if self.config["niters"] <= 0: - raise ValueError("niters must be > 0") - - if self.config["itersize"] < self.config["batch"]: - raise ValueError("itersize must be >= batch") - - if self.config["gradacc"] < 1 or not float(self.config["gradacc"]).is_integer(): - raise ValueError("gradacc must be an integer >= 1") - - if self.config["lr"] <= 0: - raise ValueError("lr must be > 0") - - torch.manual_seed(self.config["seed"]) - torch.cuda.manual_seed_all(self.config["seed"]) - - def single_train_iteration(self, reranker, train_dataloader): - """Train model for one iteration using instances from train_dataloader. - - Args: - model (Reranker): a PyTorch Reranker - train_dataloader (DataLoader): a PyTorch DataLoader that iterates over training instances - - Returns: - float: average loss over the iteration - - """ - - iter_loss = [] - batches_since_update = 0 - batches_per_epoch = (self.config["itersize"] // self.config["batch"]) or 1 - batches_per_step = self.config["gradacc"] - - for bi, batch in tqdm(enumerate(train_dataloader), desc="Iter progression"): - # TODO make sure _prepare_batch_with_strings equivalent is happening inside the sampler - batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()} - doc_scores = reranker.score(batch) - loss = self.loss(doc_scores) - iter_loss.append(loss) - loss.backward() - - batches_since_update += 1 - if batches_since_update == batches_per_step: - batches_since_update = 0 - self.optimizer.step() - self.optimizer.zero_grad() - - if (bi + 1) % batches_per_epoch == 0: - break - - return torch.stack(iter_loss).mean() - - def load_loss_file(self, fn): - """Loads loss history from fn - - Args: - fn (Path): path to a loss.txt file - - Returns: - a list of losses ordered by iterations - - """ - - loss = [] - with fn.open(mode="rt") as f: - for lineidx, line in enumerate(f): - line = line.strip() - if not line: - continue - - iteridx, iterloss = line.rstrip().split() - - if int(iteridx) != lineidx: - raise IOError(f"malformed loss file {fn} ... 
did two processes write to it?") - - loss.append(float(iterloss)) - - return loss - - def fastforward_training(self, reranker, weights_path, loss_fn): - """Skip to the last training iteration whose weights were saved. - - If saved model and optimizer weights are available, this method will load those weights into model - and optimizer, and then return the next iteration to be run. For example, if weights are available for - iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and - this method will return 11. - - If an error or inconsistency is encountered when checking for weights, this method returns 0. - - This method checks several files to determine if weights "are available". First, loss_fn is read to - determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.) - Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer. - If this is successful, the method returns `1 + last recorded iteration`. If not, it returns 0. - (We consider loss_fn because it is written at the end of every training iteration.) - - Args: - model (Reranker): a PyTorch Reranker whose state should be loaded - weights_path (Path): directory containing model and optimizer weights - loss_fn (Path): file containing loss history - - Returns: - int: the next training iteration after fastforwarding. If successful, this is > 0. - If no weights are available or they cannot be loaded, 0 is returned. - - """ - - if not (weights_path.exists() and loss_fn.exists()): - return 0 - - try: - loss = self.load_loss_file(loss_fn) - except IOError: - return 0 - - last_loss_iteration = len(loss) - 1 - weights_fn = weights_path / f"{last_loss_iteration}.p" - - try: - reranker.load_weights(weights_fn, self.optimizer) - return last_loss_iteration + 1 - except: - logger.info("attempted to load weights from %s but failed, starting at iteration 0", weights_fn) - return 0 - - def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1): - """Train a model following the trainer's config (specifying batch size, number of iterations, etc). - - Args: - train_dataset (IterableDataset): training dataset - train_output_path (Path): directory under which train_dataset runs and training loss will be saved - dev_data (IterableDataset): dev dataset - dev_output_path (Path): directory where dev_data runs and metrics will be saved - - """ - # Set up logging - # TODO why not put this under train_output_path? 
- summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" / self.config["boardname"], comment=train_output_path) - - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - model = reranker.model.to(self.device) - self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.config["lr"]) - - if self.config["softmaxloss"]: - self.loss = pair_softmax_loss - else: - self.loss = pair_hinge_loss - - dev_best_weight_fn, weights_output_path, info_output_path, loss_fn = self.get_paths_for_early_stopping( - train_output_path, dev_output_path - ) - - initial_iter = self.fastforward_training(reranker, weights_output_path, loss_fn) if self.config["fastforward"] else 0 - logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"]) - - train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=self.config["batch"], pin_memory=True, num_workers=0 - ) - # dataiter = iter(train_dataloader) - # sample_input = dataiter.next() - # summary_writer.add_graph( - # reranker.model, - # [ - # sample_input["query"].to(self.device), - # sample_input["posdoc"].to(self.device), - # sample_input["negdoc"].to(self.device), - # ], - # ) - - train_loss = [] - # are we resuming training? - if initial_iter > 0: - train_loss = self.load_loss_file(loss_fn) - - # are we done training? - if initial_iter < self.config["niters"]: - logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter) - batches_per_epoch = self.config["itersize"] // self.config["batch"] - for niter in range(initial_iter): - for bi, batch in enumerate(train_dataloader): - if (bi + 1) % batches_per_epoch == 0: - break - - dev_best_metric = -np.inf - validation_frequency = self.config["validatefreq"] - train_start_time = time.time() - for niter in range(initial_iter, self.config["niters"]): - model.train() - - iter_start_time = time.time() - iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader) - logger.info("A single iteration takes {}".format(time.time() - iter_start_time)) - train_loss.append(iter_loss_tensor.item()) - logger.info("iter = %d loss = %f", niter, train_loss[-1]) - - # write model weights to file - weights_fn = weights_output_path / f"{niter}.p" - reranker.save_weights(weights_fn, self.optimizer) - # predict performance on dev set - - if niter % validation_frequency == 0: - pred_fn = dev_output_path / f"{niter}.run" - preds = self.predict(reranker, dev_data, pred_fn) - - # log dev metrics - metrics = evaluator.eval_runs(preds, qrels, evaluator.DEFAULT_METRICS, relevance_level) - logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())])) - summary_writer.add_scalar("ndcg_cut_20", metrics["ndcg_cut_20"], niter) - summary_writer.add_scalar("map", metrics["map"], niter) - summary_writer.add_scalar("P_20", metrics["P_20"], niter) - # write best dev weights to file - if metrics[metric] > dev_best_metric: - reranker.save_weights(dev_best_weight_fn, self.optimizer) - - # write train_loss to file - loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss))) - - summary_writer.add_scalar("training_loss", iter_loss_tensor.item(), niter) - reranker.add_summary(summary_writer, niter) - summary_writer.flush() - logger.info("training loss: %s", train_loss) - logger.info("Training took {}".format(time.time() - train_start_time)) - summary_writer.close() - - # TODO should we write a /done so that training can be skipped if possible when 
fastforward=False? or in Task? - - def load_best_model(self, reranker, train_output_path): - self.optimizer = torch.optim.Adam( - filter(lambda param: param.requires_grad, reranker.model.parameters()), lr=self.config["lr"] - ) - - dev_best_weight_fn = train_output_path / "dev.best" - reranker.load_weights(dev_best_weight_fn, self.optimizer) - - def predict(self, reranker, pred_data, pred_fn): - """Predict query-document scores on `pred_data` using `model` and write a corresponding run file to `pred_fn` - - Args: - model (Reranker): a PyTorch Reranker - pred_data (IterableDataset): data to predict on - pred_fn (Path): path to write the prediction run file to - - Returns: - TREC Run - - """ - - self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # save to pred_fn - model = reranker.model.to(self.device) - model.eval() - - preds = {} - pred_dataloader = torch.utils.data.DataLoader(pred_data, batch_size=self.config["batch"], pin_memory=True, num_workers=0) - with torch.autograd.no_grad(): - for batch in tqdm(pred_dataloader, desc="Predicting on dev"): - if len(batch["qid"]) != self.config["batch"]: - batch = self.fill_incomplete_batch(batch) - - batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()} - scores = reranker.test(batch) - scores = scores.view(-1).cpu().numpy() - for qid, docid, score in zip(batch["qid"], batch["posdocid"], scores): - # Need to use float16 because pytrec_eval's c function call crashes with higher precision floats - preds.setdefault(qid, {})[docid] = score.astype(np.float16).item() - - os.makedirs(os.path.dirname(pred_fn), exist_ok=True) - Searcher.write_trec_run(preds, pred_fn) - - return preds - - def fill_incomplete_batch(self, batch): - """ - If a batch is incomplete (i.e shorter than the desired batch size), this method fills in the batch with some data. - How the data is chosen: - If the values are just a simple list, use the first element of the list to pad the batch - If the values are tensors/numpy arrays, use repeat() along the batch dimension - """ - # logger.debug("filling in an incomplete batch") - repeat_times = math.ceil(self.config["batch"] / len(batch["qid"])) - diff = self.config["batch"] - len(batch["qid"]) - - def pad(v): - if isinstance(v, np.ndarray) or torch.is_tensor(v): - _v = v.repeat((repeat_times,) + tuple([1 for x in range(len(v.shape) - 1)])) - else: - _v = v + [v[0]] * diff - - return _v[: self.config["batch"]] - - batch = {k: pad(v) for k, v in batch.items()} - return batch - - -class TrecCheckpointCallback(tf.keras.callbacks.Callback): - """ - A callback that runs after every epoch and calculates pytrec_eval style metrics for the dev dataset. 
- See TensorflowTrainer.train() for the invocation - Also saves the best model to disk - """ - - def __init__(self, qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs): - super(TrecCheckpointCallback, self).__init__(*args, **kwargs) - """ - qrels - a qrels dict - dev_data - a torch.utils.IterableDataset - dev_records - a BatchedDataset instance - """ - self.best_metric = -np.inf - self.qrels = qrels - self.dev_data = dev_data - self.dev_records = dev_records - self.output_path = output_path - self.iter_start_time = time.time() - self.metric = metric - self.validate_freq = validate_freq - self.relevance_level = relevance_level - - def save_model(self): - self.model.save_weights("{0}/dev.best".format(self.output_path)) - - def on_epoch_begin(self, epoch, logs=None): - self.iter_start_time = time.time() - - def on_epoch_end(self, epoch, logs=None): - logger.debug("Epoch {} took {}".format(epoch, time.time() - self.iter_start_time)) - if (epoch + 1) % self.validate_freq == 0: - predictions = self.model.predict(self.dev_records, verbose=1, workers=8, use_multiprocessing=True) - trec_preds = self.get_preds_in_trec_format(predictions, self.dev_data) - metrics = evaluator.eval_runs(trec_preds, dict(self.qrels), evaluator.DEFAULT_METRICS, self.relevance_level) - logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())])) - - if metrics[self.metric] > self.best_metric: - self.best_metric = metrics[self.metric] - # TODO: Prevent the embedding layer weights from being saved - self.save_model() - - @staticmethod - def get_preds_in_trec_format(predictions, dev_data): - """ - Takes in a list of predictions and returns a dict that can be fed into pytrec_eval - As a side effect, also writes the predictions into a file in the trec format - """ - pred_dict = defaultdict(lambda: dict()) - - for i, (qid, docid) in enumerate(dev_data.get_qid_docid_pairs()): - # Pytrec_eval has problems with high precision floats - pred_dict[qid][docid] = predictions[i][0].astype(np.float16).item() - - return dict(pred_dict) - - -@Trainer.register -class TensorFlowTrainer(Trainer): - module_name = "tensorflow" - - config_spec = [ - ConfigOption("batch", 32, "batch size"), - ConfigOption("niters", 20, "number of iterations to train for"), - ConfigOption("itersize", 512, "number of training instances in one iteration"), - # ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"), - ConfigOption("lr", 0.001, "learning rate"), - ConfigOption("loss", "pairwise_hinge_loss", "must be one of tfr.losses.RankingLossKey"), - # ConfigOption("fastforward", False), - ConfigOption("validatefreq", 1), - ConfigOption("boardname", "default"), - ConfigOption("usecache", False), - ConfigOption("tpuname", None), - ConfigOption("tpuzone", None), - ConfigOption("storage", None), - ] - config_keys_not_in_path = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"] - - def build(self): - tf.random.set_seed(self.config["seed"]) - - # Use TPU if available, otherwise resort to GPU/CPU - try: - self.tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=self.config["tpuname"], zone=self.config["tpuzone"]) - except ValueError: - self.tpu = None - logger.info("Could not find the tpu") - - # TPUStrategy for distributed training - if self.tpu: - logger.info("Utilizing TPUs") - tf.config.experimental_connect_to_cluster(self.tpu) - tf.tpu.experimental.initialize_tpu_system(self.tpu) - self.strategy = 
tf.distribute.experimental.TPUStrategy(self.tpu) - else: # default strategy that works on CPU and single GPU - self.strategy = tf.distribute.get_strategy() - - # Defining some props that we will later initialize - self.optimizer = None - self.loss = None - self.validate() - - def validate(self): - if self.tpu and any([self.config["storage"] is None, self.config["tpuname"] is None, self.config["tpuzone"] is None]): - raise ValueError("storage, tpuname and tpuzone configs must be provided when training on TPU") - if self.tpu and self.config["storage"] and not self.config["storage"].startswith("gs://"): - raise ValueError("For TPU utilization, the storage config should start with 'gs://'") - - def get_optimizer(self): - return tf.keras.optimizers.Adam(learning_rate=self.config["lr"]) - - def fastforward_training(self, reranker, weights_path, loss_fn): - # TODO: Fix fast forwarding - return 0 - - def load_best_model(self, reranker, train_output_path): - # TODO: Do the train_output_path modification at one place? - if self.tpu: - train_output_path = "{0}/{1}/{2}".format( - self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest() - ) - - reranker.model.load_weights("{0}/dev.best".format(train_output_path)) - - def apply_gradients(self, weights, grads): - self.optimizer.apply_gradients(zip(grads, weights)) - - def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1): - # summary_writer = tf.summary.create_file_writer("{0}/capreolus_tensorboard/{1}".format(self.config["storage"], self.config["boardname"])) - - # Because TPUs can't work with local files - if self.tpu: - train_output_path = "{0}/{1}/{2}".format( - self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest() - ) - - os.makedirs(dev_output_path, exist_ok=True) - initial_iter = self.fastforward_training(reranker, dev_output_path, None) - logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"]) - - strategy_scope = self.strategy.scope() - with strategy_scope: - train_records = self.get_tf_train_records(reranker, train_dataset) - dev_records = self.get_tf_dev_records(reranker, dev_data) - trec_callback = TrecCheckpointCallback( - qrels, - dev_data, - dev_records, - train_output_path, - metric, - self.config["validatefreq"], - relevance_level=relevance_level, - ) - tensorboard_callback = tf.keras.callbacks.TensorBoard( - log_dir="{0}/capreolus_tensorboard/{1}".format(self.config["storage"], self.config["boardname"]) - ) - reranker.build_model() # TODO needed here? 
- - self.optimizer = self.get_optimizer() - loss = tfr.keras.losses.get(self.config["loss"]) - reranker.model.compile(optimizer=self.optimizer, loss=loss) - - train_start_time = time.time() - reranker.model.fit( - train_records.prefetch(tf.data.experimental.AUTOTUNE), - epochs=self.config["niters"], - steps_per_epoch=self.config["itersize"], - callbacks=[tensorboard_callback, trec_callback], - workers=8, - use_multiprocessing=True, - ) - logger.info("Training took {}".format(time.time() - train_start_time)) - - # Skipping dumping metrics and plotting loss since that should be done through tensorboard - - def create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc): - """ - Creates a single tf.train.Feature instance (i.e, a single sample) - """ - feature = { - "qid": tf.train.Feature(bytes_list=tf.train.BytesList(value=[qid.encode("utf-8")])), - "query": tf.train.Feature(float_list=tf.train.FloatList(value=query)), - "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)), - "posdoc_id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[posdoc_id.encode("utf-8")])), - "posdoc": tf.train.Feature(float_list=tf.train.FloatList(value=posdoc)), - } - - if negdoc_id: - feature["negdoc_id"] = (tf.train.Feature(bytes_list=tf.train.BytesList(value=[negdoc_id.encode("utf-8")])),) - feature["negdoc"] = tf.train.Feature(float_list=tf.train.FloatList(value=negdoc)) - - return feature - - def write_tf_record_to_file(self, dir_name, tf_features): - """ - Actually write the tf record to file. The destination can also be a gcs bucket. - TODO: Use generators to optimize memory usage - """ - filename = "{0}/{1}.tfrecord".format(dir_name, str(uuid.uuid4())) - examples = [tf.train.Example(features=tf.train.Features(feature=feature)) for feature in tf_features] - - if not os.path.isdir(dir_name): - os.makedirs(dir_name, exist_ok=True) - - examples = [example.SerializeToString() for example in examples] - with tf.io.TFRecordWriter(filename) as writer: - for example in examples: - writer.write(example) - - logger.info("Wrote tf record file: {}".format(filename)) - - return str(filename) - - def convert_to_tf_dev_record(self, reranker, dataset): - """ - Similar to self.convert_to_tf_train_record(), but won't result in multiple files - """ - dir_name = self.get_tf_record_cache_path(dataset) - - tf_features = [reranker.extractor.create_tf_feature(sample) for sample in dataset] - - return [self.write_tf_record_to_file(dir_name, tf_features)] - - def convert_to_tf_train_record(self, reranker, dataset): - """ - Tensorflow works better if the input data is fed in as tfrecords - Takes in a dataset, iterates through it, and creates multiple tf records from it. - The exact structure of the tfrecords is defined by reranker.extractor. 
For example, see EmbedText.get_tf_feature() - """ - dir_name = self.get_tf_record_cache_path(dataset) - - total_samples = dataset.get_total_samples() - tf_features = [] - tf_record_filenames = [] - - for niter in tqdm(range(0, self.config["niters"]), desc="Converting data to tf records"): - for sample_idx, sample in enumerate(dataset): - tf_features.append(reranker.extractor.create_tf_feature(sample)) - - if len(tf_features) > 20000: - tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features)) - tf_features = [] - - if sample_idx + 1 >= self.config["itersize"] * self.config["batch"]: - break - - if len(tf_features): - tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features)) - - return tf_record_filenames - - def get_tf_record_cache_path(self, dataset): - """ - Get the path to the directory where tf records are written to. - If using TPUs, this will be a gcs path. - """ - if self.tpu: - return "{0}/capreolus_tfrecords/{1}".format(self.config["storage"], dataset.get_hash()) - else: - base_path = self.get_cache_path() - return "{0}/{1}".format(base_path, dataset.get_hash()) - - def cache_exists(self, dataset): - # TODO: Add checks to make sure that the number of files in the directory is correct - cache_dir = self.get_tf_record_cache_path(dataset) - logger.info("The cache path is {0} and does it exist? : {1}".format(cache_dir, tf.io.gfile.exists(cache_dir))) - - return tf.io.gfile.isdir(cache_dir) - - def load_tf_records_from_file(self, reranker, filenames, batch_size): - raw_dataset = tf.data.TFRecordDataset(filenames) - tf_records_dataset = raw_dataset.batch(batch_size, drop_remainder=True).map( - reranker.extractor.parse_tf_example, num_parallel_calls=tf.data.experimental.AUTOTUNE - ) - - return tf_records_dataset - - def load_cached_tf_records(self, reranker, dataset, batch_size): - logger.info("Loading TF records from cache") - cache_dir = self.get_tf_record_cache_path(dataset) - filenames = tf.io.gfile.listdir(cache_dir) - filenames = ["{0}/{1}".format(cache_dir, name) for name in filenames] - - return self.load_tf_records_from_file(reranker, filenames, batch_size) - - def get_tf_dev_records(self, reranker, dataset): - """ - 1. Returns tf records from cache (disk) if applicable - 2. Else, converts the dataset into tf records, writes them to disk, and returns them - """ - if self.config["usecache"] and self.cache_exists(dataset): - return self.load_cached_tf_records(reranker, dataset, 1) - else: - tf_record_filenames = self.convert_to_tf_dev_record(reranker, dataset) - # TODO use actual batch size here. see issue #52 - return self.load_tf_records_from_file(reranker, tf_record_filenames, 1) # self.config["batch"]) - - def get_tf_train_records(self, reranker, dataset): - """ - 1. Returns tf records from cache (disk) if applicable - 2. 
Else, converts the dataset into tf records, writes them to disk, and returns them - """ - - if self.config["usecache"] and self.cache_exists(dataset): - return self.load_cached_tf_records(reranker, dataset, self.config["batch"]) - else: - tf_record_filenames = self.convert_to_tf_train_record(reranker, dataset) - return self.load_tf_records_from_file(reranker, tf_record_filenames, self.config["batch"]) - - def predict(self, reranker, pred_data, pred_fn): - """Predict query-document scores on `pred_data` using `model` and write a corresponding run file to `pred_fn` - - Args: - model (Reranker): a PyTorch Reranker - pred_data (IterableDataset): data to predict on - pred_fn (Path): path to write the prediction run file to - - Returns: - TREC Run - - """ - - strategy_scope = self.strategy.scope() - with strategy_scope: - pred_records = self.get_tf_dev_records(reranker, pred_data) - predictions = reranker.model.predict(pred_records) - trec_preds = TrecCheckpointCallback.get_preds_in_trec_format(predictions, pred_data) +from profane import import_all_modules - os.makedirs(os.path.dirname(pred_fn), exist_ok=True) - Searcher.write_trec_run(trec_preds, pred_fn) +from .pytorch import PytorchTrainer +from .tensorflow import TensorFlowTrainer - return trec_preds +import_all_modules(__file__, __package__) diff --git a/capreolus/trainer/pytorch.py b/capreolus/trainer/pytorch.py new file mode 100644 index 000000000..78938bace --- /dev/null +++ b/capreolus/trainer/pytorch.py @@ -0,0 +1,329 @@ +import math +import os +import time + +import numpy as np +import torch +from torch.utils.tensorboard import SummaryWriter +from tqdm import tqdm + +from . import Trainer +from capreolus import ModuleBase, Dependency, ConfigOption, Searcher, constants, evaluator, get_logger +from capreolus.reranker.common import pair_hinge_loss, pair_softmax_loss +from capreolus.utils.common import plot_metrics, plot_loss + +logger = get_logger(__name__) # pylint: disable=invalid-name +RESULTS_BASE_PATH = constants["RESULTS_BASE_PATH"] + + +@Trainer.register +class PytorchTrainer(Trainer): + module_name = "pytorch" + config_spec = [ + ConfigOption("batch", 32, "batch size"), + ConfigOption("niters", 20, "number of iterations to train for"), + ConfigOption("itersize", 512, "number of training instances in one iteration"), + ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"), + ConfigOption("lr", 0.001, "learning rate"), + ConfigOption("softmaxloss", False, "True to use softmax loss (over pairs) or False to use hinge loss"), + ConfigOption("fastforward", False), + ConfigOption("validatefreq", 1), + ConfigOption("boardname", "default"), + ] + config_keys_not_in_path = ["fastforward", "boardname"] + + def build(self): + # sanity checks + if self.config["batch"] < 1: + raise ValueError("batch must be >= 1") + + if self.config["niters"] <= 0: + raise ValueError("niters must be > 0") + + if self.config["itersize"] < self.config["batch"]: + raise ValueError("itersize must be >= batch") + + if self.config["gradacc"] < 1 or not float(self.config["gradacc"]).is_integer(): + raise ValueError("gradacc must be an integer >= 1") + + if self.config["lr"] <= 0: + raise ValueError("lr must be > 0") + + torch.manual_seed(self.config["seed"]) + torch.cuda.manual_seed_all(self.config["seed"]) + + def single_train_iteration(self, reranker, train_dataloader): + """Train model for one iteration using instances from train_dataloader. 
+
+    def single_train_iteration(self, reranker, train_dataloader):
+        """Train model for one iteration using instances from train_dataloader.
+
+        Args:
+            reranker (Reranker): a PyTorch Reranker
+            train_dataloader (DataLoader): a PyTorch DataLoader that iterates over training instances
+
+        Returns:
+            float: average loss over the iteration
+
+        """
+
+        iter_loss = []
+        batches_since_update = 0
+        batches_per_epoch = (self.config["itersize"] // self.config["batch"]) or 1
+        batches_per_step = self.config["gradacc"]
+
+        for bi, batch in tqdm(enumerate(train_dataloader), desc="Iter progression"):
+            # TODO make sure _prepare_batch_with_strings equivalent is happening inside the sampler
+            batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()}
+            doc_scores = reranker.score(batch)
+            loss = self.loss(doc_scores)
+            iter_loss.append(loss)
+            loss.backward()
+
+            batches_since_update += 1
+            if batches_since_update == batches_per_step:
+                batches_since_update = 0
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+            if (bi + 1) % batches_per_epoch == 0:
+                break
+
+        return torch.stack(iter_loss).mean()
+
+    def load_loss_file(self, fn):
+        """Loads loss history from fn
+
+        Args:
+            fn (Path): path to a loss.txt file
+
+        Returns:
+            a list of losses ordered by iterations
+
+        """
+
+        loss = []
+        with fn.open(mode="rt") as f:
+            for lineidx, line in enumerate(f):
+                line = line.strip()
+                if not line:
+                    continue
+
+                iteridx, iterloss = line.rstrip().split()
+
+                if int(iteridx) != lineidx:
+                    raise IOError(f"malformed loss file {fn} ... did two processes write to it?")
+
+                loss.append(float(iterloss))
+
+        return loss
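+
+    # Illustration (assumed layout, matching what train() writes below): a loss file
+    # contains one "<iteration> <loss>" pair per zero-indexed line, e.g.,
+    #   0 0.693147
+    #   1 0.540231
+    # load_loss_file() parses this format, and fastforward_training() uses it to locate
+    # the last completed iteration.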
+ + """ + + if not (weights_path.exists() and loss_fn.exists()): + return 0 + + try: + loss = self.load_loss_file(loss_fn) + except IOError: + return 0 + + last_loss_iteration = len(loss) - 1 + weights_fn = weights_path / f"{last_loss_iteration}.p" + + try: + reranker.load_weights(weights_fn, self.optimizer) + return last_loss_iteration + 1 + except: + logger.info("attempted to load weights from %s but failed, starting at iteration 0", weights_fn) + return 0 + + def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1): + """Train a model following the trainer's config (specifying batch size, number of iterations, etc). + + Args: + train_dataset (IterableDataset): training dataset + train_output_path (Path): directory under which train_dataset runs and training loss will be saved + dev_data (IterableDataset): dev dataset + dev_output_path (Path): directory where dev_data runs and metrics will be saved + + """ + # Set up logging + # TODO why not put this under train_output_path? + summary_writer = SummaryWriter(RESULTS_BASE_PATH / "runs" / self.config["boardname"], comment=train_output_path) + + self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model = reranker.model.to(self.device) + self.optimizer = torch.optim.Adam(filter(lambda param: param.requires_grad, model.parameters()), lr=self.config["lr"]) + + if self.config["softmaxloss"]: + self.loss = pair_softmax_loss + else: + self.loss = pair_hinge_loss + + dev_best_weight_fn, weights_output_path, info_output_path, loss_fn = self.get_paths_for_early_stopping( + train_output_path, dev_output_path + ) + + initial_iter = self.fastforward_training(reranker, weights_output_path, loss_fn) if self.config["fastforward"] else 0 + logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"]) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=self.config["batch"], pin_memory=True, num_workers=0 + ) + # dataiter = iter(train_dataloader) + # sample_input = dataiter.next() + # summary_writer.add_graph( + # reranker.model, + # [ + # sample_input["query"].to(self.device), + # sample_input["posdoc"].to(self.device), + # sample_input["negdoc"].to(self.device), + # ], + # ) + + train_loss = [] + # are we resuming training? + if initial_iter > 0: + train_loss = self.load_loss_file(loss_fn) + + # are we done training? 
+        # if training is not yet finished, fastforward train_dataloader to where we left off
+        if initial_iter < self.config["niters"]:
+            logger.debug("fastforwarding train_dataloader to iteration %s", initial_iter)
+            batches_per_epoch = self.config["itersize"] // self.config["batch"]
+            for niter in range(initial_iter):
+                for bi, batch in enumerate(train_dataloader):
+                    if (bi + 1) % batches_per_epoch == 0:
+                        break
+
+        dev_best_metric = -np.inf
+        validation_frequency = self.config["validatefreq"]
+        train_start_time = time.time()
+        for niter in range(initial_iter, self.config["niters"]):
+            model.train()
+
+            iter_start_time = time.time()
+            iter_loss_tensor = self.single_train_iteration(reranker, train_dataloader)
+            logger.info("a single iteration took %0.3f seconds", time.time() - iter_start_time)
+            train_loss.append(iter_loss_tensor.item())
+            logger.info("iter = %d loss = %f", niter, train_loss[-1])
+
+            # write model weights to file
+            weights_fn = weights_output_path / f"{niter}.p"
+            reranker.save_weights(weights_fn, self.optimizer)
+
+            # predict performance on dev set
+            if niter % validation_frequency == 0:
+                pred_fn = dev_output_path / f"{niter}.run"
+                preds = self.predict(reranker, dev_data, pred_fn)
+
+                # log dev metrics
+                metrics = evaluator.eval_runs(preds, qrels, evaluator.DEFAULT_METRICS, relevance_level)
+                logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
+                summary_writer.add_scalar("ndcg_cut_20", metrics["ndcg_cut_20"], niter)
+                summary_writer.add_scalar("map", metrics["map"], niter)
+                summary_writer.add_scalar("P_20", metrics["P_20"], niter)
+                # update the best dev metric seen so far and write the best dev weights to file
+                if metrics[metric] > dev_best_metric:
+                    dev_best_metric = metrics[metric]
+                    reranker.save_weights(dev_best_weight_fn, self.optimizer)
+
+            # write train_loss to file
+            loss_fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(train_loss)))
+
+            summary_writer.add_scalar("training_loss", iter_loss_tensor.item(), niter)
+            reranker.add_summary(summary_writer, niter)
+            summary_writer.flush()
+        logger.info("training loss: %s", train_loss)
+        logger.info("training took %0.3f seconds", time.time() - train_start_time)
+        summary_writer.close()
+
+        # TODO should we write a /done so that training can be skipped if possible when fastforward=False? or in Task?
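+
+    # Sketch of a typical invocation (illustrative only; a Task normally wires this up):
+    #   trainer.train(reranker, train_dataset, train_output_path, dev_data,
+    #                 dev_output_path, benchmark.qrels, metric="map",
+    #                 relevance_level=benchmark.relevance_level)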
+
+    def load_best_model(self, reranker, train_output_path):
+        self.optimizer = torch.optim.Adam(
+            filter(lambda param: param.requires_grad, reranker.model.parameters()), lr=self.config["lr"]
+        )
+
+        dev_best_weight_fn = train_output_path / "dev.best"
+        reranker.load_weights(dev_best_weight_fn, self.optimizer)
+
+    def predict(self, reranker, pred_data, pred_fn):
+        """Predict query-document scores on `pred_data` using `reranker` and write a corresponding run file to `pred_fn`
+
+        Args:
+            reranker (Reranker): a PyTorch Reranker
+            pred_data (IterableDataset): data to predict on
+            pred_fn (Path): path to write the prediction run file to
+
+        Returns:
+            TREC Run
+
+        """
+
+        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+        # save to pred_fn
+        model = reranker.model.to(self.device)
+        model.eval()
+
+        preds = {}
+        pred_dataloader = torch.utils.data.DataLoader(pred_data, batch_size=self.config["batch"], pin_memory=True, num_workers=0)
+        with torch.no_grad():
+            for batch in tqdm(pred_dataloader, desc="Predicting on dev"):
+                if len(batch["qid"]) != self.config["batch"]:
+                    batch = self.fill_incomplete_batch(batch)
+
+                batch = {k: v.to(self.device) if not isinstance(v, list) else v for k, v in batch.items()}
+                scores = reranker.test(batch)
+                scores = scores.view(-1).cpu().numpy()
+                for qid, docid, score in zip(batch["qid"], batch["posdocid"], scores):
+                    # Need to use float16 because pytrec_eval's c function call crashes with higher precision floats
+                    preds.setdefault(qid, {})[docid] = score.astype(np.float16).item()
+
+        os.makedirs(os.path.dirname(pred_fn), exist_ok=True)
+        Searcher.write_trec_run(preds, pred_fn)
+
+        return preds
+
+    def fill_incomplete_batch(self, batch):
+        """
+        If a batch is incomplete (i.e., shorter than the desired batch size), this method fills in the batch with some data.
+        How the data is chosen:
+        If the values are just a simple list, use the first element of the list to pad the batch
+        If the values are tensors/numpy arrays, use repeat() along the batch dimension
+        """
+        # logger.debug("filling in an incomplete batch")
+        repeat_times = math.ceil(self.config["batch"] / len(batch["qid"]))
+        diff = self.config["batch"] - len(batch["qid"])
+
+        def pad(v):
+            if isinstance(v, np.ndarray) or torch.is_tensor(v):
+                _v = v.repeat((repeat_times,) + tuple([1 for x in range(len(v.shape) - 1)]))
+            else:
+                _v = v + [v[0]] * diff
+
+            return _v[: self.config["batch"]]
+
+        batch = {k: pad(v) for k, v in batch.items()}
+        return batch
diff --git a/capreolus/trainer/tensorflow.py b/capreolus/trainer/tensorflow.py
new file mode 100644
index 000000000..301e83c89
--- /dev/null
+++ b/capreolus/trainer/tensorflow.py
@@ -0,0 +1,352 @@
+import hashlib
+import os
+import time
+import uuid
+from collections import defaultdict
+
+import tensorflow as tf
+import tensorflow_ranking as tfr
+import numpy as np
+from tqdm import tqdm
+
+from . import Trainer
+from capreolus import ModuleBase, Dependency, ConfigOption, Searcher, constants, evaluator, get_logger
+from capreolus.utils.common import plot_metrics, plot_loss
+
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+RESULTS_BASE_PATH = constants["RESULTS_BASE_PATH"]
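+
+
+# For reference, the callback below follows the standard tf.keras callback pattern:
+# hooks such as on_epoch_begin/on_epoch_end are invoked by model.fit(), e.g.,
+#   class MyCallback(tf.keras.callbacks.Callback):
+#       def on_epoch_end(self, epoch, logs=None):
+#           ...
+#   model.fit(..., callbacks=[MyCallback()])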
+class TrecCheckpointCallback(tf.keras.callbacks.Callback):
+    """
+    A callback that runs after every epoch and calculates pytrec_eval-style metrics for the dev dataset.
+    See TensorFlowTrainer.train() for the invocation.
+    Also saves the best model to disk.
+    """
+
+    def __init__(self, qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs):
+        """
+        Args:
+            qrels - a qrels dict
+            dev_data - a torch.utils.IterableDataset
+            dev_records - a BatchedDataset instance
+        """
+        super(TrecCheckpointCallback, self).__init__(*args, **kwargs)
+        self.best_metric = -np.inf
+        self.qrels = qrels
+        self.dev_data = dev_data
+        self.dev_records = dev_records
+        self.output_path = output_path
+        self.iter_start_time = time.time()
+        self.metric = metric
+        self.validate_freq = validate_freq
+        self.relevance_level = relevance_level
+
+    def save_model(self):
+        self.model.save_weights("{0}/dev.best".format(self.output_path))
+
+    def on_epoch_begin(self, epoch, logs=None):
+        self.iter_start_time = time.time()
+
+    def on_epoch_end(self, epoch, logs=None):
+        logger.debug("Epoch {} took {}".format(epoch, time.time() - self.iter_start_time))
+        if (epoch + 1) % self.validate_freq == 0:
+            predictions = self.model.predict(self.dev_records, verbose=1, workers=8, use_multiprocessing=True)
+            trec_preds = self.get_preds_in_trec_format(predictions, self.dev_data)
+            metrics = evaluator.eval_runs(trec_preds, dict(self.qrels), evaluator.DEFAULT_METRICS, self.relevance_level)
+            logger.info("dev metrics: %s", " ".join([f"{metric}={v:0.3f}" for metric, v in sorted(metrics.items())]))
+
+            if metrics[self.metric] > self.best_metric:
+                self.best_metric = metrics[self.metric]
+                # TODO: Prevent the embedding layer weights from being saved
+                self.save_model()
+
+    @staticmethod
+    def get_preds_in_trec_format(predictions, dev_data):
+        """
+        Takes in a list of predictions and returns a dict that can be fed into pytrec_eval.
+        (Writing the predictions to a trec-format run file is left to the caller, e.g. TensorFlowTrainer.predict().)
+        """
+        pred_dict = defaultdict(lambda: dict())
+
+        for i, (qid, docid) in enumerate(dev_data.get_qid_docid_pairs()):
+            # Pytrec_eval has problems with high precision floats
+            pred_dict[qid][docid] = predictions[i][0].astype(np.float16).item()
+
+        return dict(pred_dict)
+
+
+@Trainer.register
+class TensorFlowTrainer(Trainer):
+    module_name = "tensorflow"
+
+    config_spec = [
+        ConfigOption("batch", 32, "batch size"),
+        ConfigOption("niters", 20, "number of iterations to train for"),
+        ConfigOption("itersize", 512, "number of training instances in one iteration"),
+        # ConfigOption("gradacc", 1, "number of batches to accumulate over before updating weights"),
+        ConfigOption("lr", 0.001, "learning rate"),
+        ConfigOption("loss", "pairwise_hinge_loss", "must be one of tfr.losses.RankingLossKey"),
+        # ConfigOption("fastforward", False),
+        ConfigOption("validatefreq", 1, "validate on the dev set every validatefreq iterations"),
+        ConfigOption("boardname", "default", "name of the tensorboard run directory"),
+        ConfigOption("usecache", False, "True to cache tf records (to disk, or to GCS when using TPUs)"),
+        ConfigOption("tpuname", None, "name of the TPU to use, if any"),
+        ConfigOption("tpuzone", None, "zone the TPU is located in"),
+        ConfigOption("storage", None, "gs:// path used for tf records and outputs when training on a TPU"),
+    ]
+    config_keys_not_in_path = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"]
+
+    def build(self):
+        tf.random.set_seed(self.config["seed"])
+
+        # Use TPU if available, otherwise resort to GPU/CPU
+        try:
+            self.tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=self.config["tpuname"], zone=self.config["tpuzone"])
+        except ValueError:
+            self.tpu = None
+            logger.info("Could not find a TPU; falling back to the default GPU/CPU strategy")
+
+        # TPUStrategy for distributed training
+        if self.tpu:
+            logger.info("Utilizing TPUs")
+            tf.config.experimental_connect_to_cluster(self.tpu)
+            tf.tpu.experimental.initialize_tpu_system(self.tpu)
+            self.strategy = tf.distribute.experimental.TPUStrategy(self.tpu)
+        else:  # default strategy that works on CPU and single GPU
+            self.strategy = tf.distribute.get_strategy()
+
+        # Defining some props that we will later initialize
+        self.optimizer = None
+        self.loss = None
+        self.validate()
+
+    def validate(self):
+        if self.tpu and any([self.config["storage"] is None, self.config["tpuname"] is None, self.config["tpuzone"] is None]):
+            raise ValueError("storage, tpuname and tpuzone configs must be provided when training on TPU")
+        if self.tpu and self.config["storage"] and not self.config["storage"].startswith("gs://"):
+            raise ValueError("For TPU utilization, the storage config should start with 'gs://'")
+
+    def get_optimizer(self):
+        return tf.keras.optimizers.Adam(learning_rate=self.config["lr"])
+
+    def fastforward_training(self, reranker, weights_path, loss_fn):
+        # TODO: Fix fast forwarding
+        return 0
+
+    def load_best_model(self, reranker, train_output_path):
+        # TODO: Do the train_output_path modification at one place?
+        if self.tpu:
+            train_output_path = "{0}/{1}/{2}".format(
+                self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
+            )
+
+        reranker.model.load_weights("{0}/dev.best".format(train_output_path))
+
+    def apply_gradients(self, weights, grads):
+        self.optimizer.apply_gradients(zip(grads, weights))
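+
+    # Example TPU configuration (values are illustrative):
+    #   reranker.trainer.tpuname=my-tpu reranker.trainer.tpuzone=us-central1-f \
+    #   reranker.trainer.storage=gs://my-bucket
+    # Without a TPU, the default strategy from build() runs on CPU or a single GPU.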
+
+    def train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1):
+        # summary_writer = tf.summary.create_file_writer("{0}/capreolus_tensorboard/{1}".format(self.config["storage"], self.config["boardname"]))
+
+        # Because TPUs can't work with local files
+        if self.tpu:
+            train_output_path = "{0}/{1}/{2}".format(
+                self.config["storage"], "train_output", hashlib.md5(str(train_output_path).encode("utf-8")).hexdigest()
+            )
+
+        os.makedirs(dev_output_path, exist_ok=True)
+        initial_iter = self.fastforward_training(reranker, dev_output_path, None)
+        logger.info("starting training from iteration %s/%s", initial_iter, self.config["niters"])
+
+        strategy_scope = self.strategy.scope()
+        with strategy_scope:
+            train_records = self.get_tf_train_records(reranker, train_dataset)
+            dev_records = self.get_tf_dev_records(reranker, dev_data)
+            trec_callback = TrecCheckpointCallback(
+                qrels,
+                dev_data,
+                dev_records,
+                train_output_path,
+                metric,
+                self.config["validatefreq"],
+                relevance_level=relevance_level,
+            )
+            tensorboard_callback = tf.keras.callbacks.TensorBoard(
+                # fall back to the local results path when not training on a TPU (storage may be None)
+                log_dir="{0}/capreolus_tensorboard/{1}".format(
+                    self.config["storage"] if self.tpu else RESULTS_BASE_PATH, self.config["boardname"]
+                )
+            )
+            reranker.build_model()  # TODO needed here?
+
+            self.optimizer = self.get_optimizer()
+            loss = tfr.keras.losses.get(self.config["loss"])
+            reranker.model.compile(optimizer=self.optimizer, loss=loss)
+
+            train_start_time = time.time()
+            reranker.model.fit(
+                train_records.prefetch(tf.data.experimental.AUTOTUNE),
+                epochs=self.config["niters"],
+                steps_per_epoch=self.config["itersize"],
+                callbacks=[tensorboard_callback, trec_callback],
+                workers=8,
+                use_multiprocessing=True,
+            )
+            logger.info("Training took {}".format(time.time() - train_start_time))
+
+            # Skipping dumping metrics and plotting loss since that should be done through tensorboard
+
+    def create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc):
+        """
+        Creates a single tf.train.Feature instance (i.e., a single sample)
+        """
+        feature = {
+            "qid": tf.train.Feature(bytes_list=tf.train.BytesList(value=[qid.encode("utf-8")])),
+            "query": tf.train.Feature(float_list=tf.train.FloatList(value=query)),
+            "query_idf": tf.train.Feature(float_list=tf.train.FloatList(value=query_idf)),
+            "posdoc_id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[posdoc_id.encode("utf-8")])),
+            "posdoc": tf.train.Feature(float_list=tf.train.FloatList(value=posdoc)),
+        }
+
+        if negdoc_id:
+            # note: no trailing comma here, which would incorrectly wrap the Feature in a tuple
+            feature["negdoc_id"] = tf.train.Feature(bytes_list=tf.train.BytesList(value=[negdoc_id.encode("utf-8")]))
+            feature["negdoc"] = tf.train.Feature(float_list=tf.train.FloatList(value=negdoc))
+
+        return feature
+
+    def write_tf_record_to_file(self, dir_name, tf_features):
+        """
+        Actually write the tf record to file. The destination can also be a gcs bucket.
+        TODO: Use generators to optimize memory usage
+        """
+        filename = "{0}/{1}.tfrecord".format(dir_name, str(uuid.uuid4()))
+        examples = [tf.train.Example(features=tf.train.Features(feature=feature)) for feature in tf_features]
+
+        if not os.path.isdir(dir_name):
+            os.makedirs(dir_name, exist_ok=True)
+
+        examples = [example.SerializeToString() for example in examples]
+        with tf.io.TFRecordWriter(filename) as writer:
+            for example in examples:
+                writer.write(example)
+
+        logger.info("Wrote tf record file: {}".format(filename))
+
+        return str(filename)
+
+    def convert_to_tf_dev_record(self, reranker, dataset):
+        """
+        Similar to self.convert_to_tf_train_record(), but won't result in multiple files
+        """
+        dir_name = self.get_tf_record_cache_path(dataset)
+
+        tf_features = [reranker.extractor.create_tf_feature(sample) for sample in dataset]
+
+        return [self.write_tf_record_to_file(dir_name, tf_features)]
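+
+    # For reference, records written above can be parsed back with tf.io (schema illustrative):
+    #   desc = {"qid": tf.io.FixedLenFeature([], tf.string), "query": tf.io.VarLenFeature(tf.float32)}
+    #   example = tf.io.parse_single_example(serialized_record, desc)
+    # In this trainer, parsing is delegated to reranker.extractor.parse_tf_example.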
+
+    def convert_to_tf_train_record(self, reranker, dataset):
+        """
+        TensorFlow works better if the input data is fed in as tfrecords.
+        Takes in a dataset, iterates through it, and creates multiple tf records from it.
+        The exact structure of the tfrecords is defined by reranker.extractor.
+        For example, see EmbedText.get_tf_feature()
+        """
+        dir_name = self.get_tf_record_cache_path(dataset)
+
+        tf_features = []
+        tf_record_filenames = []
+
+        for niter in tqdm(range(0, self.config["niters"]), desc="Converting data to tf records"):
+            for sample_idx, sample in enumerate(dataset):
+                tf_features.append(reranker.extractor.create_tf_feature(sample))
+
+                if len(tf_features) > 20000:
+                    tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features))
+                    tf_features = []
+
+                if sample_idx + 1 >= self.config["itersize"] * self.config["batch"]:
+                    break
+
+        if len(tf_features):
+            tf_record_filenames.append(self.write_tf_record_to_file(dir_name, tf_features))
+
+        return tf_record_filenames
+
+    def get_tf_record_cache_path(self, dataset):
+        """
+        Get the path to the directory where tf records are written to.
+        If using TPUs, this will be a gcs path.
+        """
+        if self.tpu:
+            return "{0}/capreolus_tfrecords/{1}".format(self.config["storage"], dataset.get_hash())
+        else:
+            base_path = self.get_cache_path()
+            return "{0}/{1}".format(base_path, dataset.get_hash())
+
+    def cache_exists(self, dataset):
+        # TODO: Add checks to make sure that the number of files in the directory is correct
+        cache_dir = self.get_tf_record_cache_path(dataset)
+        exists = tf.io.gfile.isdir(cache_dir)
+        logger.info("tf records cache path is {0} (exists: {1})".format(cache_dir, exists))
+
+        return exists
+
+    def load_tf_records_from_file(self, reranker, filenames, batch_size):
+        raw_dataset = tf.data.TFRecordDataset(filenames)
+        tf_records_dataset = raw_dataset.batch(batch_size, drop_remainder=True).map(
+            reranker.extractor.parse_tf_example, num_parallel_calls=tf.data.experimental.AUTOTUNE
+        )
+
+        return tf_records_dataset
+
+    def load_cached_tf_records(self, reranker, dataset, batch_size):
+        logger.info("Loading TF records from cache")
+        cache_dir = self.get_tf_record_cache_path(dataset)
+        filenames = tf.io.gfile.listdir(cache_dir)
+        filenames = ["{0}/{1}".format(cache_dir, name) for name in filenames]
+
+        return self.load_tf_records_from_file(reranker, filenames, batch_size)
+
+    def get_tf_dev_records(self, reranker, dataset):
+        """
+        1. Returns tf records from cache (disk) if applicable
+        2. Else, converts the dataset into tf records, writes them to disk, and returns them
+        """
+        if self.config["usecache"] and self.cache_exists(dataset):
+            return self.load_cached_tf_records(reranker, dataset, 1)
+        else:
+            tf_record_filenames = self.convert_to_tf_dev_record(reranker, dataset)
+            # TODO use actual batch size here. see issue #52
+            return self.load_tf_records_from_file(reranker, tf_record_filenames, 1)  # self.config["batch"])
+
+    def get_tf_train_records(self, reranker, dataset):
+        """
+        1. Returns tf records from cache (disk) if applicable
+        2. Else, converts the dataset into tf records, writes them to disk, and returns them
+        """
+
+        if self.config["usecache"] and self.cache_exists(dataset):
+            return self.load_cached_tf_records(reranker, dataset, self.config["batch"])
+        else:
+            tf_record_filenames = self.convert_to_tf_train_record(reranker, dataset)
+            return self.load_tf_records_from_file(reranker, tf_record_filenames, self.config["batch"])
+
+    def predict(self, reranker, pred_data, pred_fn):
+        """Predict query-document scores on `pred_data` using `reranker` and write a corresponding run file to `pred_fn`
+
+        Args:
+            reranker (Reranker): a Reranker with a TensorFlow implementation
+            pred_data (IterableDataset): data to predict on
+            pred_fn (Path): path to write the prediction run file to
+
+        Returns:
+            TREC Run
+
+        """
+
+        strategy_scope = self.strategy.scope()
+        with strategy_scope:
+            pred_records = self.get_tf_dev_records(reranker, pred_data)
+            predictions = reranker.model.predict(pred_records)
+            trec_preds = TrecCheckpointCallback.get_preds_in_trec_format(predictions, pred_data)
+
+        os.makedirs(os.path.dirname(pred_fn), exist_ok=True)
+        Searcher.write_trec_run(trec_preds, pred_fn)
+
+        return trec_preds
diff --git a/capreolus/trainer/test_trainer.py b/capreolus/trainer/tests/test_trainer.py
similarity index 97%
rename from capreolus/trainer/test_trainer.py
rename to capreolus/trainer/tests/test_trainer.py
index cea4f3443..e22723de0 100644
--- a/capreolus/trainer/test_trainer.py
+++ b/capreolus/trainer/tests/test_trainer.py
@@ -3,7 +3,7 @@
 import os
 import tensorflow as tf
 from capreolus.benchmark import DummyBenchmark
-from capreolus.extractor import EmbedText
+from capreolus.extractor.embedtext import EmbedText
 from capreolus.sampler import TrainDataset
 from capreolus.trainer import TensorFlowTrainer
diff --git a/docs/cli.md b/docs/cli.md
index d3df61d53..65262cc99 100644
--- a/docs/cli.md
+++ b/docs/cli.md
@@ -1,4 +1,4 @@
-# Command Line Interface
+# Running Pipelines with the CLI
 
 Capreolus provides a command line interface for running experiments using pipelines that are described by `Task` modules. To create a new pipeline, you'll need to create a new `Task` before using the CLI.
 
 Capreolus takes a functional approach to describing an experiment. An experiment is simply a pipeline plus a set of configuration options specifying both classes to use for the pipeline's modules and configuration options associated with each module.
@@ -22,6 +22,7 @@ module type=benchmark
    name=antique
    name=dummy
    name=robust04.yang19
+   name=nf
    ...
 module type=reranker
    name=CDSSM
@@ -37,30 +38,33 @@ module type=reranker
 .. note:: Results and cached objects are stored in ``~/.capreolus/results/`` and ``~/.capreolus/cache/`` by default. Set the ``CAPREOLUS_RESULTS`` and ``CAPREOLUS_CACHE`` environment variables to change these locations. For example: ``export CAPREOLUS_CACHE=/data/capreolus/cache``
 ```
 
-- Use `RankTask` to search for the *robust04* topics in a robust04 index (which will be downloaded if it does not automatically exist), and then evaluate the results. The `Benchmark` specifies a dependency on `collection.name=robust04` and provides the corresponding topics and relevance judgments.
+### RankTask
+- Use `RankTask` with the [NFCorpus](https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/) `Benchmark`. This will download the collection, create an index, and search for the NFCorpus topics. The `Benchmark` specifies a dependency on `collection.name=nf` and provides the corresponding topics and relevance judgments.
 ```
 $ capreolus rank.searcheval with searcher.name=BM25 \
-    searcher.index.stemmer=porter benchmark.name=robust04.yang19
+    searcher.index.stemmer=porter benchmark.name=nf
 ```
 
-- Use a similar pipeline, but with RM3 query expansion and a small grid search over expansion parameters. The evaluation command will report cross-validated results using the folds specified by `robust04.yang19`.
+- Use a similar pipeline, but with RM3 query expansion and a small grid search over expansion parameters. The evaluation command will report cross-validated results using the folds specified by the `Benchmark`.
 
 ```
 $ capreolus rank.searcheval with \
-    searcher.index.stemmer=porter benchmark.name=robust04.yang19 \
+    searcher.index.stemmer=porter benchmark.name=nf \
     searcher.name=BM25RM3 searcher.fbDocs=5-10-15 searcher.fbTerms=5-25-50
 ```
 
+### RerankTask
 - Use `RerankTask` to run the same `RankTask` pipeline optimized for recall@1000, and then train a `Reranker` optimized for P@20 on the first fold provided by the `Benchmark`. We limit training to two iterations (`niters`) of size `itersize` to keep the training process from taking too long.
 
 ```
 $ capreolus rerank.traineval with \
-    rank.searcher.index.stemmer=porter benchmark.name=robust04.yang19 \
+    rank.searcher.index.stemmer=porter benchmark.name=nf \
     rank.searcher.name=BM25RM3 rank.searcher.fbDocs=5-10-15 rank.searcher.fbTerms=5-25-50 \
     rank.optimize=recall_1000 reranker.name=KNRM reranker.trainer.niters=2 optimize=P_20
 ```
 
-- The `ReRerankTask` demonstrates pipeline flexibility by adding a second reranking step on top of the output from `RerankTask`. Run `capreolus rererank.traineval` to see the configuration options it expects. *(Hint: it consists of a `RankTask` named `rank` as before, followed by a `RerankTask` named `rerank1`, followed by another `RerankTask` named `rerank2`.)*
+### ReRerankTask
+- The `ReRerankTask` demonstrates pipeline flexibility by adding a second reranking step on top of the output from `RerankTask`. Run `capreolus rererank.print_config` to see the configuration options it expects. *(Hint: it consists of a `RankTask` named `rank` as before, followed by a `RerankTask` named `rerank1`, followed by another `RerankTask` named `rerank2`.)*
diff --git a/docs/conf.py b/docs/conf.py
index 97db7c46d..1e2707f0f 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -37,7 +37,7 @@ def get_version(rel_path):
 # -- Project information -----------------------------------------------------
 
 project = "Capreolus"
-copyright = "2020"
+copyright = "2020 Andrew Yates"
 author = "Andrew Yates"
 
 # The full version, including alpha/beta/rc tags
@@ -189,6 +189,7 @@ def get_version(rel_path):
 
 napoleon_google_docstring = True
 
+# autoapi_keep_files = True
 autoapi_type = "python"
 autoapi_dirs = ["../capreolus"]
 autoapi_ignore = ["*tests/*", "flycheck_*"]
diff --git a/docs/index.rst b/docs/index.rst
index 315780ebd..8e0b28a00 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -17,6 +17,7 @@ Looking for the code? `Find Capreolus on GitHub. <https://github.com/capreolus-ir/capreolus>`_
 
    installation
    quick
+   modules
    cli
diff --git a/docs/modules.md b/docs/modules.md
new file mode 100644
index 000000000..3a03819b6
--- /dev/null
+++ b/docs/modules.md
@@ -0,0 +1,329 @@
+# Available Modules
+
+The `Benchmark`, `Reranker`, and `Searcher` module types are most often configured by the end user.
+For a complete list of modules, run the command `capreolus modules` or see the API Reference.
+
+```eval_rst
+.. important:: When using Capreolus' configuration system, modules are selected by specifying their ``module_name``.
+   For example, the ``NF`` benchmark can be selected with the ``benchmark.name=nf`` config string or the equivalent config dictionary ``{"benchmark": {"name": "nf"}}``.
+
+   The corresponding class can be created as ``benchmark.nf.NF(config=..., provide=...)`` or created by name with ``Benchmark.create("nf", config=..., provide=...)``.
+```
+
+## Benchmarks
+
+### ANTIQUE
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.antique.ANTIQUE
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+
+### CodeSearchNet
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.codesearchnet.CodeSearchNetCorpus
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.codesearchnet.CodeSearchNetChallenge
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### (TREC) COVID
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.covid.COVID
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+
+### Dummy
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.dummy.DummyBenchmark
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### NF Corpus
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.nf.NF
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### (TREC) Robust04
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.robust04.Robust04
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+```eval_rst
+.. autoapiclass:: capreolus.benchmark.robust04.Robust04Yang19
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+## Searchers
+
+```eval_rst
+.. note:: Some searchers (e.g., BM25) automatically perform a cross-validated grid search when their parameters are provided as lists. For example, ``searcher.b=0.4,0.6,0.8 searcher.k1=1.0,1.5``.
+```
+
+### BM25
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.BM25
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### BM25 with Axiomatic expansion
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.AxiomaticSemanticMatching
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### BM25 with RM3 expansion
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.BM25RM3
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+
+### BM25 PRF
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.BM25PRF
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+
+### F2Exp
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.F2Exp
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### F2Log
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.F2Log
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### I(n)L2
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.INL2
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### QL with Dirichlet smoothing
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.DirichletQL
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### QL with J-M smoothing
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.QLJM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### SDM
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.SDM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### SPL
+```eval_rst
+.. autoapiclass:: capreolus.searcher.anserini.SPL
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+## Rerankers
+
+```eval_rst
+.. note:: Rerankers are implemented in PyTorch or TensorFlow. Rerankers with TensorFlow implementations can run on both GPUs and TPUs.
+```
+
+
+### CDSSM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.CDSSM.CDSSM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### ConvKNRM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.ConvKNRM.ConvKNRM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### DRMM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DRMM.DRMM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### DRMMTKS
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DRMMTKS.DRMMTKS
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### DSSM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DSSM.DSSM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### DUET
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DUET.DUET
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### DeepTileBars
+```eval_rst
+.. autoapiclass:: capreolus.reranker.DeepTileBar.DeepTileBar
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### HiNT
+```eval_rst
+.. autoapiclass:: capreolus.reranker.HINT.HINT
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### KNRM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.KNRM.KNRM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### PACRR
+```eval_rst
+.. autoapiclass:: capreolus.reranker.PACRR.PACRR
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### POSITDRMM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.POSITDRMM.POSITDRMM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### TK
+```eval_rst
+.. autoapiclass:: capreolus.reranker.TK.TK
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### TensorFlow KNRM
+```eval_rst
+.. autoapiclass:: capreolus.reranker.TFKNRM.TFKNRM
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
+### TensorFlow VanillaBERT
+```eval_rst
+.. autoapiclass:: capreolus.reranker.TFVanillaBert.TFVanillaBERT
+   :noindex:
+
+   .. autoapiattribute:: module_name
+      :noindex:
+```
+
diff --git a/docs/quick.md b/docs/quick.md
index 793f3d3ca..5a6fa1081 100644
--- a/docs/quick.md
+++ b/docs/quick.md
@@ -9,27 +9,44 @@
 
 ## Command Line Interface
 
-Use the `RankTask` pipeline to rank documents using a `Searcher` on an [Anserini](https://anserini.io) `Index` built on robust04. (The index will be automatically downloaded if `benchmark.collection.path` is invalid.)
+Use the `RankTask` pipeline to rank documents using a `Searcher` on an [Anserini](https://anserini.io) `Index` built on [NFCorpus](https://www.cl.uni-heidelberg.de/statnlpgroup/nfcorpus/), which contains biomedical documents and queries. NFCorpus was published by Boteva et al. in ECIR 2016. This dataset is publicly available and will be automatically downloaded by Capreolus.
+
+```bash
+$ capreolus rank.searcheval with benchmark.name=nf \
+    searcher.name=BM25 searcher.index.stemmer=porter searcher.b=0.8
 ```
-$ capreolus rank.searcheval with searcher.name=BM25 \
-    searcher.index.stemmer=porter searcher.b=0.8 \
-    benchmark.name=robust04.yang19 benchmark.collection.path=/path/to/trec45
+
+The `searcheval` command instructs `RankTask` to query NFCorpus and evaluate the Searcher's performance on NFCorpus' test queries. The command will output results like this:
+```bash
+INFO - capreolus.task.rank.evaluate - rank: fold=s1 best run: ...searcher-BM25_b-0.8_fields-title_hits-1000_k1-0.9/task-rank_filter-False/searcher
+INFO - capreolus.task.rank.evaluate - rank: cross-validated results when optimizing for 'map':
+INFO - capreolus.task.rank.evaluate -            map: 0.1361
+INFO - capreolus.task.rank.evaluate -    ndcg_cut_10: 0.2906
+...
+```
+
+These results are comparable with the *all titles* results in the [NFCorpus paper](https://www.cl.uni-heidelberg.de/~riezler/publications/papers/ECIR2016.pdf), which reports a MAP of 0.1251 for BM25 (Table 2). The Benchmark's ``fields`` config option can be used to issue other types of queries as well (e.g., ``benchmark.fields=all_fields``).
+
+```eval_rst
+.. important:: Capreolus Benchmarks define *folds* to use; each fold specifies training, dev (validation), and test queries.
+   Tasks respect these folds when calculating metrics.
+   NFCorpus defines a fixed test set, which corresponds to having a *single fold* in Capreolus.
+   When running a benchmark that uses multiple folds with cross-validation, like *robust04*, the results reported are averaged over the benchmark's test sets.
+```
 
 ## Python API
 
 Let's run the same pipeline using the Python API:
 ```python
-from capreolus.task.rank import RankTask
+from capreolus.task import RankTask
 
 task = RankTask({'searcher': {'name': 'BM25', 'index': {'stemmer': 'porter'}, 'b': '0.8'},
-                 'benchmark': {'name': 'robust04.yang19',
-                               'collection': {'path': '/path/to/trec45'}}})
+                 'benchmark': {'name': 'nf'}})
 task.searcheval()
 ```
 
 ```eval_rst
-.. note:: The ``capreolus.parse_config_string`` convenience method can transform a config string like ``searcher.name=BM25 benchmark.name=robust04.yang`` into a config dict as shown above.
+.. note:: The ``capreolus.parse_config_string`` convenience method can transform a config string like ``searcher.name=BM25 benchmark.name=nf`` into a config dict as shown above.
 ```
 
@@ -43,34 +60,38 @@ Capreolus pipelines are composed of self-contained modules corresponding to "IR
 
 RankTask declares dependencies on a Searcher module and a Benchmark module, which it uses to query a document collection and to obtain experimental data (i.e., topics, relevance judgments, and folds), respectively. The Searcher depends on an Index. Both the Index and Benchmark depend on a Collection. In this example, RankTask requires that the same Collection be provided to both.
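+
+Visually, the dependency graph for this example looks roughly like this (the module names in parentheses are the ones configured above):
+
+```
+RankTask
+├── benchmark (nf) ────────────────────────> collection (nf)
+└── searcher (BM25) ──> index (anserini) ──> collection (nf)
+```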

+```python
+from capreolus import Benchmark, Collection, Index, Searcher
+```
+
 Let's construct this graph one module at a time.
 ```python
-# Previously, the Benchmark specified a dependency on the 'robust04' collection specifically.
-# Now we specify "robust04" ourselves.
->>> collection = Collection.create("robust04", config={'path': '/path/to/trec45'})
+# Previously, the Benchmark specified a dependency on the 'nf' collection specifically.
+# Now we create this Collection directly.
+>>> collection = Collection.create("nf")
 >>> collection.get_path_and_types()
-    ("/path/to/trec45", "TrecCollection", "DefaultLuceneDocumentGenerator")
+    ("/path/to/collection-nf/documents", "TrecCollection", "DefaultLuceneDocumentGenerator")
 
 # Next, create a Benchmark and pass it the collection object directly.
 # This is an alternative to automatically creating the collection as a dependency.
->>> benchmark = Benchmark.create("robust04.yang19", provide={'collection': collection})
->>> benchmark.topics
-    {'title': {'301': 'International Organized Crime', '302': 'Poliomyelitis and Post-Polio', ... }
+>>> benchmark = Benchmark.create("nf", provide={'collection': collection})
+>>> benchmark.topics["title"]
+    {'56': 'foods for glaucoma', '68': 'what is actually in chicken nuggets', ... }
 ```
 
Next, we can build `Index` and `Searcher`. These module types do more than just point to data.
 ```python
 >>> index = Index.create("anserini", {"stemmer": "porter"}, provide={"collection": collection})
 >>> index.create_index()  # returns immediately if the index already exists
->>> index.get_df("organized")
+>>> index.get_df("foods")
 0
->>> index.get_df("organiz")
-3048
+>>> index.get_df("food")
+1011
 
 # Next, a Searcher to query the index
 >>> searcher = Searcher.create("BM25", {"hits": 3}, provide={"index": index})
->>> searcher.query("organized")
-OrderedDict([('FBIS4-2046', 4.867800235748291),
-             ('FBIS3-2553', 4.822000026702881),
-             ('FBIS3-23578', 4.754199981689453)])
+>>> searcher.query("foods")
+OrderedDict([('MED-1761', 1.213),
+             ('MED-2742', 1.212),
+             ('MED-1046', 1.2058)])
 ```
 
 Finally, we can emulate the `RankTask.search()` method we called earlier:
@@ -82,6 +103,7 @@ Finally, we can emulate the `RankTask.search()` method we called earlier:
 To get metrics, we could then pass `results` to `capreolus.evaluator.eval_runs()`:
 ```eval_rst
 .. autoapifunction:: capreolus.evaluator.eval_runs
+   :noindex:
 ```
 
@@ -90,9 +112,9 @@ To get metrics, we could then pass `results` to `capreolus.evaluator.eval_runs()
 Capreolus modules implement the Capreolus module API plus an API specific to the module type. The module API consists of four attributes:
 
 - `module_type`: a string indicating the module's type, like "index" or "benchmark"
-- `module_name`: a string indicating the module's name, like "anserini" or "robust04.yang19"
-- `config_spec`: a list of `ConfigOption` objects, for example, `ConfigOption("stemmer", default_value="none", description="stemmer to use")`
-- `dependencies` a list of `Dependency` objects; for example, `Dependency(key="collection", module="collection", name="robust04")`
+- `module_name`: a string indicating the module's name, like "anserini" or "nf"
+- `config_spec`: a list of `ConfigOption` objects. For example, `[ConfigOption("stemmer", default_value="none", description="stemmer to use")]`
+- `dependencies`: a list of `Dependency` objects. For example, `[Dependency(key="collection", module="collection", name="nf")]`
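+
+Putting these attributes together, a minimal module might look like the following sketch (the `ToyIndex` class and its option are illustrative, not part of Capreolus):
+
+```python
+from capreolus import ConfigOption, Dependency, Index
+
+@Index.register
+class ToyIndex(Index):
+    module_name = "toy"
+    config_spec = [ConfigOption("stemmer", "none", "stemmer to use")]
+    dependencies = [Dependency(key="collection", module="collection", name="nf")]
+```
+
+A real module would additionally implement the module type's own API (e.g., an `Index` must be able to build and query an index).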
 
 When the module is created, any dependencies that are not explicitly passed with `provide={key: object}` are automatically created. The module's config options in `config_spec` and those of its dependencies are exposed as Capreolus configuration options.
 
@@ -104,10 +126,8 @@ The `Task` module API specifies two additional class attributes: `commands` and
 
 Let's create a new task that mirrors the graph we constructed manually, except with two separate `Searcher` objects. We'll save the results from both searchers and measure their effectiveness on the validation queries to decide which searcher to report test set results on.
 
 ```python
-from capreolus import evaluator, Dependency, ConfigOption
-from capreolus.searcher import Searcher
+from capreolus import evaluator, get_logger, Dependency, ConfigOption
 from capreolus.task import Task
-from capreolus.utils.loginit import get_logger
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -118,7 +138,7 @@ class TutorialTask(Task):
     config_spec = [ConfigOption("optimize", "map", "metric to maximize on the validation set")]
     dependencies = [
         Dependency(
-            key="benchmark", module="benchmark", name="robust04.yang19", provide_this=True, provide_children=["collection"]
+            key="benchmark", module="benchmark", name="nf", provide_this=True, provide_children=["collection"]
         ),
         Dependency(key="searcher1", module="searcher", name="BM25RM3"),
         Dependency(key="searcher2", module="searcher", name="SDM"),
@@ -153,4 +173,34 @@ class TutorialTask(Task):
 
         return best_results
 ```
-
+
+```eval_rst
+.. note:: The module needs to be registered in order for Capreolus to find it. Registration happens when the ``@Task.register`` decorator is applied, so no additional steps are needed to use the new Task via the Python API. When using the Task via the CLI, the ``tutorial.py`` file containing it needs to be imported in order for the Task to be registered. This can be accomplished by placing the file inside the ``capreolus.task`` package (see ``capreolus.task.__path__``). However, in this case, the above Task is already provided with Capreolus as ``task/tutorial.py``.
+```
+
+Let's try running the Task we just declared via the Python API.
+
+```python
+>>> task = TutorialTask()
+>>> results = task.run()
+>>> results['score']['map']
+0.14798308699242727
+```
+
+### Module APIs
+Each module type's base class describes the module API that should be implemented to create new modules of that type.
+Check out the API documentation to learn more:
+Benchmark,
+Collection,
+Extractor,
+Index,
+Reranker,
+Searcher,
+Task,
+Tokenizer, and
+Trainer.
+
+
+## Next Steps
+- Learn more about [running pipelines using the command line interface](cli.md)
+- View what [Capreolus modules](modules.md) are available