from profane import import_all_modules
# import_all_modules(__file__, __package__)
import os
import math
import subprocess
from collections import defaultdict, OrderedDict
import numpy as np
from profane import ModuleBase, Dependency, ConfigOption, constants
from pyserini.search import pysearch
from capreolus.utils.common import Anserini
from capreolus.utils.loginit import get_logger
from capreolus.utils.trec import topic_to_trectxt
# Module-level logger, following the project convention of one logger per module.
logger = get_logger(__name__)  # pylint: disable=invalid-name

# Upper bound on the number of threads passed to Anserini's SearchCollection.
MAX_THREADS = constants["MAX_THREADS"]
def list2str(l, delimiter="-"):
    """Join the items of *l* into a single string separated by *delimiter*."""
    return delimiter.join(map(str, l))
class Searcher(ModuleBase):
    """Base class for searcher modules.

    A searcher ranks documents in an index with respect to a query.
    Subclasses implement ``_query_from_file`` to produce TREC-format
    runfiles; this base class provides runfile (de)serialization and a
    convenience ``query`` method for single ad-hoc queries.
    """

    module_type = "searcher"

    @staticmethod
    def load_trec_run(fn):
        """Parse a TREC-format run file into a nested dict.

        Args:
            fn: path to a run file with lines of the form
                ``qid Q0 docid rank score desc``

        Returns:
            dict mapping qid -> {docid: score (float)}
        """
        run = defaultdict(dict)
        with open(fn, "rt") as f:
            for line in f:
                line = line.strip()
                if len(line) > 0:
                    # split on any whitespace (not a single space) so that
                    # tab- or multi-space-separated run files also parse
                    qid, _, docid, rank, score, desc = line.split()
                    run[qid][docid] = float(score)
        return run

    @staticmethod
    def write_trec_run(preds, outfn):
        """Write a run dict (as produced by load_trec_run) to *outfn* in TREC format.

        Queries are written in ascending integer-qid order; each query's
        documents are written in descending score order with 1-based ranks.
        """
        with open(outfn, "wt") as outf:
            qids = sorted(preds.keys(), key=lambda k: int(k))
            for qid in qids:
                rank = 1
                for docid, score in sorted(preds[qid].items(), key=lambda x: x[1], reverse=True):
                    print(f"{qid} Q0 {docid} {rank} {score} capreolus", file=outf)
                    rank += 1

    def _query_from_file(self, topicsfn, output_path, cfg):
        # implemented by subclasses
        raise NotImplementedError()

    def query_from_file(self, topicsfn, output_path):
        """Run all topics in *topicsfn* using this searcher's config; return the run path."""
        return self._query_from_file(topicsfn, output_path, self.config)

    def query(self, query, **kwargs):
        """
        search document based on given query, using parameters in config as default
        """
        # kwargs may override individual config values for this query only
        config = {k: kwargs.get(k, self.config[k]) for k in self.config}

        cache_dir = self.get_cache_path()
        cache_dir.mkdir(exist_ok=True)
        topic_fn, runfile_dir = cache_dir / "topic.txt", cache_dir / "runfiles"

        # wrap the single query in a one-topic TREC topics file under a fake qid
        fake_qid = "1"
        with open(topic_fn, "w", encoding="utf-8") as f:
            f.write(topic_to_trectxt(fake_qid, query))
        self._query_from_file(topic_fn, runfile_dir, config)

        runfile_fns = [f for f in os.listdir(runfile_dir) if f != "done"]
        config2runs = {}
        for runfile in runfile_fns:
            runfile_fn = runfile_dir / runfile
            runs = self.load_trec_run(runfile_fn)
            # runfile names encode the config that produced them (prefixed "searcher_")
            config2runs[runfile.replace("searcher_", "")] = OrderedDict(runs[fake_qid])
            os.remove(runfile_fn)  # remove it so runfiles do not accumulate
        os.remove(runfile_dir / "done")

        # a single run is returned directly; a grid of configs returns a dict of runs
        return config2runs["searcher"] if len(config2runs) == 1 else config2runs
class AnseriniSearcherMixIn:
    """ MixIn for searchers that use Anserini's SearchCollection script """

    def _anserini_query_from_file(self, topicsfn, anserini_param_str, output_base_path, topicfield):
        """Run Anserini's SearchCollection over *topicsfn*, writing runfiles to *output_base_path*.

        Args:
            topicsfn: path to a TREC topics file
            anserini_param_str: ranking-model flags to append to the java command
            output_base_path: directory for the "searcher" runfile and the "done" marker
            topicfield: which topic field(s) to search with; short names are mapped below

        Raises:
            IOError: if the topics file does not exist
            RuntimeError: if the SearchCollection process exits non-zero
        """
        if not os.path.exists(topicsfn):
            raise IOError(f"could not find topics file: {topicsfn}")

        # for covid:
        # map short field names onto the topic field names SearchCollection expects
        field2querytype = {"query": "title", "question": "description", "narrative": "narrative"}
        for k, v in field2querytype.items():
            topicfield = topicfield.replace(k, v)

        # the "done" marker means a previous call completed; skip re-searching
        donefn = os.path.join(output_base_path, "done")
        if os.path.exists(donefn):
            logger.debug(f"skipping Anserini SearchCollection call because path already exists: {donefn}")
            return

        # create index if it does not exist. the call returns immediately if the index does exist.
        self.index.create_index()

        os.makedirs(output_base_path, exist_ok=True)
        output_path = os.path.join(output_base_path, "searcher")

        # add stemmer and stop options to match underlying index
        indexopts = "-stemmer "
        indexopts += "none" if self.index.config["stemmer"] is None else self.index.config["stemmer"]
        if self.index.config["indexstops"]:
            indexopts += " -keepstopwords"

        index_path = self.index.get_index_path()
        anserini_fat_jar = Anserini.get_fat_jar()
        cmd = (
            f"java -classpath {anserini_fat_jar} "
            f"-Xms512M -Xmx31G -Dapp.name=SearchCollection io.anserini.search.SearchCollection "
            f"-topicreader Trec -index {index_path} {indexopts} -topics {topicsfn} -output {output_path} "
            f"-topicfield {topicfield} -inmem -threads {MAX_THREADS} {anserini_param_str}"
        )
        logger.info("Anserini writing runs to %s", output_path)
        logger.debug(cmd)

        app = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, universal_newlines=True)

        # Anserini output is verbose, so ignore DEBUG log lines and send other output through our logger
        for line in app.stdout:
            Anserini.filter_and_log_anserini_output(line, logger)

        app.wait()
        if app.returncode != 0:
            raise RuntimeError("command failed")

        # write the marker only after a successful run so partial output is retried
        with open(donefn, "wt") as donef:
            print("done", file=donef)
class PostprocessMixin:
    """MixIn providing run post-processing: document filtering, top-n truncation, and passage de-duplication."""

    def _keep_topn(self, runs, topn):
        """Truncate each query's documents to the *topn* highest-scoring ones (in place)."""
        queries = sorted(list(runs.keys()), key=lambda k: int(k))
        for q in queries:
            docs = runs[q]
            if len(docs) <= topn:
                continue
            docs = sorted(docs.items(), key=lambda kv: kv[1], reverse=True)[:topn]
            runs[q] = {k: v for k, v in docs}
        return runs

    def filter(self, run_dir, docs_to_remove=None, docs_to_keep=None, topn=None):
        """Filter every runfile in *run_dir* (except the "done" marker) in place.

        Args:
            run_dir: directory containing runfiles
            docs_to_remove: docids to drop, as a list (applied to all queries) or a per-qid dict
            docs_to_keep: docids to retain, same forms as docs_to_remove; ignored if docs_to_remove is given
            topn: if set, keep only the topn highest-scoring documents per query

        Raises:
            ValueError: if neither docs_to_remove nor docs_to_keep is provided
        """
        if (not docs_to_keep) and (not docs_to_remove):
            # a bare `raise` here would itself fail ("no active exception"); raise explicitly
            raise ValueError("must provide docs_to_remove and/or docs_to_keep")

        for fn in os.listdir(run_dir):
            if fn == "done":
                continue
            run_fn = os.path.join(run_dir, fn)
            self._filter(run_fn, docs_to_remove, docs_to_keep, topn)
        return run_dir

    def _filter(self, runfile, docs_to_remove, docs_to_keep, topn):
        runs = Searcher.load_trec_run(runfile)

        # filtering
        if docs_to_remove:  # prioritize docs_to_remove
            if isinstance(docs_to_remove, list):
                # a flat list applies to every query
                docs_to_remove = {q: docs_to_remove for q in runs}
            runs = {q: {d: v for d, v in docs.items() if d not in docs_to_remove.get(q, [])} for q, docs in runs.items()}
        elif docs_to_keep:
            if isinstance(docs_to_keep, list):
                docs_to_keep = {q: docs_to_keep for q in runs}
            runs = {q: {d: v for d, v in docs.items() if d in docs_to_keep[q]} for q, docs in runs.items()}

        if topn:
            runs = self._keep_topn(runs, topn)
        Searcher.write_trec_run(runs, runfile)  # overwrite runfile

    def dedup(self, run_dir, topn=None):
        """De-duplicate passage-level runfiles in *run_dir* into document-level runs (in place)."""
        for fn in os.listdir(run_dir):
            if fn == "done":
                continue
            run_fn = os.path.join(run_dir, fn)
            self._dedup(run_fn, topn)
        return run_dir

    def _dedup(self, runfile, topn):
        runs = Searcher.load_trec_run(runfile)
        new_runs = {q: {} for q in runs}

        # use the maximum passage score as the document score; no sorting is done here
        for q, psg in runs.items():
            for pid, score in psg.items():
                # passage ids look like "<docid>.<passage index>"
                docid = pid.split(".")[0]
                new_runs[q][docid] = max(new_runs[q].get(docid, -math.inf), score)
        runs = new_runs

        if topn:
            runs = self._keep_topn(runs, topn)
        Searcher.write_trec_run(runs, runfile)
@Searcher.register
class BM25(Searcher, AnseriniSearcherMixIn):
    """ BM25 with fixed k1 and b. """

    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
        ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """
        Runs BM25 search. Takes a query from the topic files, and fires it against the index
        Args:
            topicsfn: Path to a topics file
            output_path: Path where the results of the search (i.e the run file) should be stored
        Returns: Path to the run file where the results of the search are stored
        """
        # space-separated parameter values allow Anserini to grid over them
        b_values = list2str(config["b"], delimiter=" ")
        k1_values = list2str(config["k1"], delimiter=" ")
        search_params = f"-bm25 -bm25.b {b_values} -bm25.k1 {k1_values} -hits {config['hits']}"
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class BM25Grid(Searcher, AnseriniSearcherMixIn):
    """ BM25 with a grid search for k1 and b. Search is from 0.1 to bmax/k1max in 0.1 increments """

    module_name = "BM25Grid"
    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("k1max", 1.0, "maximum k1 value to include in grid search (starting at 0.1)"),
        ConfigOption("bmax", 1.0, "maximum b value to include in grid search (starting at 0.1)"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run BM25 over a 0.1-step grid of b and k1 values up to bmax/k1max."""
        # enumerate the grids, rounding to one decimal to avoid float drift
        b_grid = np.around(np.arange(0.1, config["bmax"] + 0.1, 0.1), 1)
        k1_grid = np.around(np.arange(0.1, config["k1max"] + 0.1, 0.1), 1)

        search_params = (
            f"-bm25 -bm25.b {list2str(b_grid, delimiter=' ')} "
            f"-bm25.k1 {list2str(k1_grid, delimiter=' ')} -hits {config['hits']}"
        )
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class BM25RM3(Searcher, AnseriniSearcherMixIn):
    """BM25 with RM3 pseudo-relevance-feedback query expansion."""

    module_name = "BM25RM3"
    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"),
        ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"),
        ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"),
        ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"),
        # "unexpanded" (was misspelled "unexpended"): the weight given to the original query
        ConfigOption("originalQueryWeight", [0.5], "the weight of unexpanded query", value_type="floatlist"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run BM25+RM3 search over *topicsfn*; runfiles are written to *output_path*.

        Returns: Path to the run file directory where the search results are stored
        """
        hits = str(config["hits"])
        # each multi-valued option becomes a space-separated list so Anserini grids over it
        anserini_param_str = (
            "-rm3 "
            + " ".join(f"-rm3.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "originalQueryWeight"])
            + " -bm25 "
            + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"])
            + f" -hits {hits}"
        )
        self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
        return output_path
@Searcher.register
class BM25PostProcess(BM25, PostprocessMixin):
    """BM25 searcher whose runs may be post-processed: document filtering and passage de-duplication."""

    module_name = "BM25Postprocess"
    config_spec = [
        ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
        ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("topn", 1000),
        ConfigOption("fields", "title"),
        # key must be "dedup": query_from_file() reads self.config["dedup"], so the
        # previous "dedep" spelling made the option unusable (KeyError at query time)
        ConfigOption("dedup", False),
    ]

    def query_from_file(self, topicsfn, output_path, docs_to_remove=None):
        """Run BM25 search, then optionally filter and de-duplicate the resulting run directory.

        Args:
            topicsfn: Path to a topics file
            output_path: Path where runfiles are written
            docs_to_remove: optional docids to drop (list applied to all queries, or per-qid dict)

        Returns: the (possibly post-processed) run directory
        """
        output_path = super().query_from_file(topicsfn, output_path)  # will call _query_from_file() from BM25

        if docs_to_remove:
            output_path = self.filter(output_path, docs_to_remove=docs_to_remove, topn=self.config["topn"])
        if self.config["dedup"]:
            output_path = self.dedup(output_path, topn=self.config["topn"])

        return output_path
@Searcher.register
class StaticBM25RM3Rob04Yang19(Searcher):
    """ Tuned BM25+RM3 run used by Yang et al. in [1]. This should be used only with a benchmark using the same folds and queries.
    [1] Wei Yang, Kuang Lu, Peilin Yang, and Jimmy Lin. Critically Examining the "Neural Hype": Weak Baselines and the Additivity of Effectiveness Gains from Neural Ranking Models. SIGIR 2019.
    """

    module_name = "bm25staticrob04yang19"

    def _query_from_file(self, topicsfn, output_path, config):
        """Copy the packaged, precomputed run file into *output_path*; no search is performed."""
        import shutil

        os.makedirs(output_path, exist_ok=True)
        destination = os.path.join(output_path, "static.run")
        shutil.copy2(constants["PACKAGE_PATH"] / "data" / "rob04_yang19_rm3.run", destination)
        return output_path

    def query(self, *args, **kwargs):
        raise NotImplementedError("this searcher uses a static run file, so it cannot handle new queries")
@Searcher.register
class BM25PRF(Searcher, AnseriniSearcherMixIn):
    """
    BM25 with PRF
    """

    module_name = "BM25PRF"
    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("k1", [0.65, 0.70, 0.75], "controls term saturation", value_type="floatlist"),
        ConfigOption("b", [0.60, 0.7], "controls document length normalization", value_type="floatlist"),
        ConfigOption("fbTerms", [65, 70, 95, 100], "number of generated terms from feedback", value_type="intlist"),
        ConfigOption("fbDocs", [5, 10, 15], "number of documents used for feedback", value_type="intlist"),
        ConfigOption("newTermWeight", [0.2, 0.25], value_type="floatlist"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run BM25 with pseudo-relevance feedback over *topicsfn*; runfiles go to *output_path*.

        Returns: Path to the run file directory where the search results are stored
        """
        hits = str(config["hits"])
        # multi-valued options become space-separated lists so Anserini grids over them
        anserini_param_str = (
            "-bm25prf "
            + " ".join(f"-bm25prf.{k} {list2str(config[k], ' ')}" for k in ["fbTerms", "fbDocs", "newTermWeight", "k1", "b"])
            + " -bm25 "
            + " ".join(f"-bm25.{k} {list2str(config[k], ' ')}" for k in ["k1", "b"])
            + f" -hits {hits}"
        )
        # use the logger rather than a stray print() so output respects the configured log level
        logger.debug(output_path)
        self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
        return output_path
@Searcher.register
class AxiomaticSemanticMatching(Searcher, AnseriniSearcherMixIn):
    """
    TODO: Add more info on retrieval method
    Also, BM25 is hard-coded to be the scoring model
    """

    module_name = "axiomatic"
    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
        ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
        ConfigOption("r", 20, value_type="intlist"),
        ConfigOption("n", 30, value_type="intlist"),
        ConfigOption("beta", 0.4, value_type="floatlist"),
        ConfigOption("top", 20, value_type="intlist"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run axiomatic semantic-matching expansion (scored with BM25) over *topicsfn*.

        Returns: Path to the run file directory where the search results are stored
        """
        hits = str(config["hits"])
        # removed an unused `conditionals` local that was never referenced
        anserini_param_str = "-axiom -axiom.deterministic -axiom.r {0} -axiom.n {1} -axiom.beta {2} -axiom.top {3}".format(
            *[list2str(config[k], " ") for k in ["r", "n", "beta", "top"]]
        )
        anserini_param_str += " -bm25 -bm25.k1 {0} -bm25.b {1} ".format(*[list2str(config[k], " ") for k in ["k1", "b"]])
        anserini_param_str += f" -hits {hits}"
        self._anserini_query_from_file(topicsfn, anserini_param_str, output_path, config["fields"])
        return output_path
@Searcher.register
class DirichletQL(Searcher, AnseriniSearcherMixIn):
    """ Dirichlet QL with a fixed mu """

    module_name = "DirichletQL"
    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("mu", 1000, "smoothing parameter", value_type="intlist"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """
        Runs Dirichlet QL search. Takes a query from the topic files, and fires it against the index
        Args:
            topicsfn: Path to a topics file
            output_path: Path where the results of the search (i.e the run file) should be stored
        Returns: Path to the run file where the results of the search are stored
        """
        mu_values = list2str(config["mu"], delimiter=" ")
        search_params = f"-qld -qld.mu {mu_values} -hits {config['hits']}"
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class QLJM(Searcher, AnseriniSearcherMixIn):
    """
    QL with Jelinek-Mercer smoothing
    """

    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("lam", 0.1, value_type="floatlist"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run query-likelihood search with Jelinek-Mercer smoothing over *topicsfn*."""
        lam_values = list2str(config["lam"], delimiter=" ")
        search_params = f"-qljm -qljm.lambda {lam_values} -hits {config['hits']}"
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class INL2(Searcher, AnseriniSearcherMixIn):
    """
    I(n)L2 scoring model
    """

    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("c", 0.1),  # array input of this parameter is not support by anserini.SearchCollection
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run I(n)L2 search over *topicsfn*; runfiles are written to *output_path*."""
        search_params = f"-inl2 -inl2.c {config['c']} -hits {config['hits']}"
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class SPL(Searcher, AnseriniSearcherMixIn):
    """
    SPL scoring model
    """

    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("c", 0.1),  # array input of this parameter is not support by anserini.SearchCollection
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run SPL search over *topicsfn*; runfiles are written to *output_path*."""
        search_params = f"-spl -spl.c {config['c']} -hits {config['hits']}"
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class F2Exp(Searcher, AnseriniSearcherMixIn):
    """
    F2Exp scoring model
    """

    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("s", 0.5),  # array input of this parameter is not support by anserini.SearchCollection
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run F2Exp search over *topicsfn*; runfiles are written to *output_path*."""
        search_params = f"-f2exp -f2exp.s {config['s']} -hits {config['hits']}"
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class F2Log(Searcher, AnseriniSearcherMixIn):
    """
    F2Log scoring model
    """

    dependencies = [Dependency(key="index", module="index", name="anserini")]
    config_spec = [
        ConfigOption("s", 0.5),  # array input of this parameter is not support by anserini.SearchCollection
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run F2Log search over *topicsfn*; runfiles are written to *output_path*."""
        search_params = f"-f2log -f2log.s {config['s']} -hits {config['hits']}"
        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path
@Searcher.register
class SDM(Searcher, AnseriniSearcherMixIn):
    """
    Sequential Dependency Model
    The scoring model is hardcoded to be BM25 (TODO: Make it configurable?)
    """

    dependencies = [Dependency(key="index", module="index", name="anserini")]
    # array input of (tw, ow, uw) is not support by anserini.SearchCollection
    config_spec = [
        ConfigOption("k1", 0.9, "controls term saturation", value_type="floatlist"),
        ConfigOption("b", 0.4, "controls document length normalization", value_type="floatlist"),
        ConfigOption("tw", 0.85, "term weight"),
        ConfigOption("ow", 0.15, "ordered window weight"),
        ConfigOption("uw", 0.05, "unordered window weight"),
        ConfigOption("hits", 1000, "number of results to return"),
        ConfigOption("fields", "title"),
    ]

    def _query_from_file(self, topicsfn, output_path, config):
        """Run SDM (scored with BM25) over *topicsfn*; runfiles are written to *output_path*."""
        tw, ow, uw = config["tw"], config["ow"], config["uw"]
        k1_values = list2str(config["k1"], " ")
        b_values = list2str(config["b"], " ")

        search_params = f"-sdm -sdm.tw {tw} -sdm.ow {ow} -sdm.uw {uw}"
        search_params += f" -bm25 -bm25.k1 {k1_values} -bm25.b {b_values}"
        search_params += f" -hits {config['hits']}"

        self._anserini_query_from_file(topicsfn, search_params, output_path, config["fields"])
        return output_path