import os
from collections import defaultdict, OrderedDict
from capreolus import ModuleBase, constants
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import topic_to_trectxt
from capreolus.utils.common import OrderedDefaultDict
# Module-level logger; the lowercase name is deliberate, hence the pylint suppression.
logger = get_logger(__name__)  # pylint: disable=invalid-name
# Read once, at import time, from capreolus' shared ``constants`` mapping. Not referenced in the
# visible part of this module; presumably consumed by searcher submodules as a thread limit — verify.
MAX_THREADS = constants["MAX_THREADS"]
def list2str(l, delimiter="-"):
    """Return the items of ``l`` converted with ``str`` and joined by ``delimiter``.

    >>> list2str([1, 2, 3])
    '1-2-3'
    """
    return delimiter.join(map(str, l))
class Searcher(ModuleBase):
    """Base class for Searcher modules. The purpose of a Searcher is to query a collection via an :class:`~capreolus.index.Index` module.

    Similar to Rerankers, Searchers return a list of documents and their relevance scores for a given query.
    Searchers are unsupervised and efficient, whereas Rerankers are supervised and do not use an inverted index directly.

    Subclasses should implement ``_query_from_file(topicsfn, output_path, cfg)``; this base class then provides:
        - ``query(string)``, which searches a single query string and returns document scores
        - ``query_from_file(topicsfn, output_path)``, which runs a topics file using the module's own config
    """

    module_type = "searcher"

    @staticmethod
    def load_trec_run(fn):
        """Load a TREC-format run file into a nested ``{qid: {docid: score}}`` mapping.

        Each non-blank line must start with the whitespace-separated fields ``qid Q0 docid rank score``.
        A trailing run tag (or anything else after the fifth field) is ignored. Only ``qid``, ``docid``
        and ``score`` are kept.

        Args:
            fn: path of the run file to read (decoded as UTF-8).

        Returns:
            An ``OrderedDefaultDict`` mapping each qid to a docid -> float score mapping. Both levels keep
            the order in which entries appear in the file.

        Raises:
            ValueError: if a non-blank line has fewer than five fields or a score that is not a float.
        """
        # Docids in the run file appear according to decreasing score, hence it makes sense to preserve this order
        run = OrderedDefaultDict()
        with open(fn, "rt", encoding="utf-8") as f:
            for line in f:
                # split() with no separator tolerates tabs and repeated spaces between fields
                fields = line.split()
                if not fields:
                    continue
                qid, _, docid, _, score = fields[:5]
                run[qid][docid] = float(score)
        return run

    @staticmethod
    def _qid_sort_key(qid):
        """Sort key placing integer-like qids first (by numeric value), then any others as strings."""
        try:
            return (0, int(qid), "")
        except ValueError:
            return (1, 0, str(qid))

    @staticmethod
    def write_trec_run(preds, outfn):
        """Write ``{qid: {docid: score}}`` predictions to ``outfn`` in TREC run format.

        Queries are written in numeric qid order. Any qids that are not integer strings come after the
        numeric ones, in lexicographic order; a plain ``int`` sort would raise ValueError on them.
        Within each query, documents are ranked by decreasing score, ranks start at 1, and every line
        carries the run tag ``capreolus``. The file is encoded as UTF-8.

        Args:
            preds: mapping of qid -> mapping of docid -> score.
            outfn: path of the run file to (over)write.
        """
        with open(outfn, "wt", encoding="utf-8") as outf:
            for qid in sorted(preds, key=Searcher._qid_sort_key):
                ranked = sorted(preds[qid].items(), key=lambda x: x[1], reverse=True)
                for rank, (docid, score) in enumerate(ranked, start=1):
                    print(f"{qid} Q0 {docid} {rank} {score} capreolus", file=outf)

    def _query_from_file(self, topicsfn, output_path, cfg):
        """Run every topic in ``topicsfn`` with config ``cfg``, writing run file(s) into ``output_path``.

        Must be implemented by subclasses. ``query`` reads every file it finds in ``output_path`` with
        ``load_trec_run``, except a file named ``done``.
        """
        raise NotImplementedError()

    def query_from_file(self, topicsfn, output_path):
        """Run all topics in ``topicsfn`` using this module's own ``self.config``; see ``_query_from_file``."""
        return self._query_from_file(topicsfn, output_path, self.config)

    def query(self, query, **kwargs):
        """Search documents for a single query string, using the parameters in ``self.config`` as defaults.

        Keyword arguments whose names match keys of ``self.config`` override those values for this call
        only; other keyword arguments are ignored. The query is written as a one-topic file under this
        module's cache path and run through ``_query_from_file``. The resulting run files (and any
        ``done`` marker) are then read and deleted.

        Args:
            query: the query text.
            **kwargs: per-call overrides for entries of ``self.config``.

        Returns:
            If exactly one run file is produced, an ``OrderedDict`` of docid -> score for the query.
            Otherwise, a dict mapping each run file name (minus a leading ``searcher_``) to such an
            ``OrderedDict``.
        """
        config = {k: kwargs.get(k, self.config[k]) for k in self.config}
        cache_dir = self.get_cache_path()
        cache_dir.mkdir(parents=True, exist_ok=True)
        topic_fn, runfile_dir = cache_dir / "topic.txt", cache_dir / "runfiles"
        fake_qid = "1"
        with open(topic_fn, "w", encoding="utf-8") as f:
            f.write(topic_to_trectxt(fake_qid, query))

        self._query_from_file(topic_fn, runfile_dir, config)

        prefix = "searcher_"
        config2runs = {}
        for runfile in os.listdir(runfile_dir):
            if runfile == "done":
                continue
            runfile_fn = runfile_dir / runfile
            runs = self.load_trec_run(runfile_fn)
            # strip only a leading "searcher_", not every occurrence in the name
            name = runfile[len(prefix):] if runfile.startswith(prefix) else runfile
            # .get() so a query with no hits yields an empty ranking instead of a lookup error
            config2runs[name] = OrderedDict(runs.get(fake_qid, {}))
            os.remove(runfile_fn)  # remove it so run files do not accumulate across calls

        done_fn = runfile_dir / "done"
        if done_fn.exists():
            os.remove(done_fn)

        # A single run is returned directly whatever its file name, rather than only when it is named "searcher"
        if len(config2runs) == 1:
            return next(iter(config2runs.values()))
        return config2runs
from profane import import_all_modules
from .anserini import BM25, BM25RM3, SDM
import_all_modules(__file__, __package__)