import os
import gdown
from pathlib import Path
from collections import defaultdict
from capreolus import ConfigOption, Dependency
from capreolus.utils.loginit import get_logger
from . import Searcher
from .anserini import BM25
logger = get_logger(__name__)
SUPPORTED_TRIPLE_FILE = ["small", "large.v1", "large.v2"]
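# Valid values for the "tripleversion" config option used by the searchers below;
# "small" is fetched from Google Drive, while the "large" variants come from the
# official MS MARCO blob storage (see MsmarcoPsgSearcherMixin.get_url).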
class MsmarcoPsgSearcherMixin:
@staticmethod
    def convert_to_trec_runs(msmarco_top1k_fn, style="eval"):
logger.info(f"Converting file {msmarco_top1k_fn} (with style {style}) into trec format")
runs = defaultdict(dict)
with open(msmarco_top1k_fn, "r", encoding="utf-8") as f:
for line in f:
if style == "triple":
qid, pos_pid, neg_pid = line.strip().split("\t")
                    runs[qid][pos_pid] = len(runs[qid])
                    runs[qid][neg_pid] = len(runs[qid])
elif style == "eval":
qid, pid, _, _ = line.strip().split("\t")
                    runs[qid][pid] = len(runs[qid])
else:
raise ValueError(f"Unexpected style {style}, should be either 'triple' or 'eval'")
return runs
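
    # Illustrative example: a "triple"-style line "q1\tp7\tp9" yields
    # runs == {"q1": {"p7": 0, "p9": 1}}; each doc's insertion order within the
    # query doubles as its pseudo-score in the fake runfile.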
@staticmethod
    def get_fn_from_url(url):
return url.split("/")[-1].replace(".gz", "").replace(".tar", "")
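        # e.g. ".../top1000.dev.tar.gz" -> "top1000.dev"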
    def get_url(self):
tripleversion = self.config["tripleversion"]
if tripleversion == "large.v1":
return "https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.tsv.gz"
if tripleversion == "large.v2":
return "https://msmarco.blob.core.windows.net/msmarcoranking/qidpidtriples.train.full.2.tsv.gz"
if tripleversion == "small":
return "https://drive.google.com/uc?id=1LCQ-85fx61_5gQgljyok8olf6GadZUeP"
raise ValueError("Unknown version for triplet large" % self.config["tripleversion"])
    def download_and_prepare_train_set(self, tmp_dir):
tmp_dir.mkdir(exist_ok=True, parents=True)
triple_version = self.config["tripleversion"]
url = self.get_url()
if triple_version.startswith("large"):
extract_file_name = self.get_fn_from_url(url)
extract_dir = self.benchmark.collection.download_and_extract(url, tmp_dir, expected_fns=extract_file_name)
triple_fn = extract_dir / extract_file_name
elif triple_version == "small":
triple_fn = tmp_dir / "triples.train.small.idversion.tsv"
if not os.path.exists(triple_fn):
gdown.download(url, triple_fn.as_posix(), quiet=False)
else:
raise ValueError(f"Unknown version for triplet: {triple_version}")
return self.convert_to_trec_runs(triple_fn, style="triple")
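
    # The downloaded triple files are expected to hold one id-based triple per line
    # ("qid\tpositive_pid\tnegative_pid"), matching the "triple" style handled above.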
@Searcher.register
class MsmarcoPsg(Searcher, MsmarcoPsgSearcherMixin):
"""
    Skip searching on the training set by converting the official training triples into a "fake" runfile.
    Use the official runfiles for the development and test sets.
"""
    module_name = "msmarcopsg"
    dependencies = [Dependency(key="benchmark", module="benchmark", name="msmarcopsg")]
    config_spec = [
        ConfigOption("tripleversion", "small", "version of the triple id file: small, large.v1, or large.v2"),
]
def _query_from_file(self, topicsfn, output_path, cfg):
"""only query results in dev and test set are saved"""
final_runfn = Path(output_path) / "searcher"
final_donefn = Path(output_path) / "done"
if os.path.exists(final_donefn):
return output_path
tmp_dir = self.get_cache_path() / "tmp"
tmp_dir.mkdir(exist_ok=True, parents=True)
output_path.mkdir(exist_ok=True, parents=True)
# train
train_run = self.download_and_prepare_train_set(tmp_dir=tmp_dir)
self.write_trec_run(preds=train_run, outfn=final_runfn, mode="wt")
# dev and test
dev_test_urls = [
"https://msmarco.blob.core.windows.net/msmarcoranking/top1000.dev.tar.gz",
"https://msmarco.blob.core.windows.net/msmarcoranking/top1000.eval.tar.gz",
]
runs = {}
for url in dev_test_urls:
extract_file_name = self.get_fn_from_url(url)
extract_dir = self.benchmark.collection.download_and_extract(url, tmp_dir, expected_fns=extract_file_name)
runs.update(self.convert_to_trec_runs(extract_dir / extract_file_name, style="eval"))
self.write_trec_run(preds=runs, outfn=final_runfn, mode="a")
with open(final_donefn, "wt") as f:
print("done", file=f)
return output_path
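
    # The combined "searcher" file holds standard TREC run lines for the train, dev,
    # and eval queries; the "done" marker makes repeated calls return immediately.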
@Searcher.register
class MsmarcoPsgBm25(BM25, MsmarcoPsgSearcherMixin):
"""
    Skip searching on the training set by converting the official training triples into a "fake" runfile.
    Conduct a configurable BM25 search on the development and test sets.
"""
    module_name = "msmarcopsgbm25"
    dependencies = [
Dependency(key="benchmark", module="benchmark", name="msmarcopsg"),
Dependency(key="index", module="index", name="anserini"),
]
    config_spec = BM25.config_spec + [
        ConfigOption("tripleversion", "small", "version of the triple id file: small, large.v1, or large.v2"),
]
def _query_from_file(self, topicsfn, output_path, config):
final_runfn = os.path.join(output_path, "searcher")
final_donefn = os.path.join(output_path, "done")
if os.path.exists(final_donefn):
return output_path
output_path.mkdir(exist_ok=True, parents=True)
tmp_dir = self.get_cache_path() / "tmp"
tmp_topicsfn = tmp_dir / os.path.basename(topicsfn)
tmp_output_dir = tmp_dir / "BM25_results"
tmp_output_dir.mkdir(exist_ok=True, parents=True)
train_runs = self.download_and_prepare_train_set(tmp_dir=tmp_dir)
if not os.path.exists(tmp_topicsfn):
with open(tmp_topicsfn, "wt") as fout:
with open(topicsfn) as f:
for line in f:
qid, title = line.strip().split("\t")
if qid not in self.benchmark.folds["s1"]["train_qids"]:
fout.write(line)
super()._query_from_file(topicsfn=tmp_topicsfn, output_path=tmp_output_dir, config=config)
dev_test_runfile = tmp_output_dir / "searcher"
assert os.path.exists(dev_test_runfile)
# write train and dev, test runs into final searcher file
Searcher.write_trec_run(train_runs, final_runfn)
with open(dev_test_runfile) as fin, open(final_runfn, "a") as fout:
for line in fin:
fout.write(line)
with open(final_donefn, "w") as f:
f.write("done")
return output_path
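
    # Final runfile layout: the fake train runs (from the triples) come first,
    # followed by the BM25 results for the non-train queries kept in tmp_topicsfn.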
# TODO: make this another type of "Module" (e.g. DPR Module)
@Searcher.register
class StaticTctColBertDev(Searcher, MsmarcoPsgSearcherMixin):
"""
    Skip searching on the training set by converting the official training triples into a "fake" runfile.
    Use a runfile pre-prepared with TCT-ColBERT (https://cs.uwaterloo.ca/~jimmylin/publications/Lin_etal_2021_RepL4NLP.pdf).
"""
    module_name = "static_tct_colbert"
    dependencies = [Dependency(key="benchmark", module="benchmark", name="msmarcopsg")]
    config_spec = [
        ConfigOption("tripleversion", "small", "version of the triple id file: small, large.v1, or large.v2"),
]
def _query_from_file(self, topicsfn, output_path, cfg):
outfn = Path(output_path) / "static.run"
done_fn = Path(output_path) / "done"
if done_fn.exists():
return outfn
tmp_dir = self.get_cache_path() / "tmp"
output_path.mkdir(exist_ok=True, parents=True)
# train
train_runs = self.download_and_prepare_train_set(tmp_dir=tmp_dir)
self.write_trec_run(preds=train_runs, outfn=outfn, mode="wt")
logger.info(f"prepared runs from train set")
# dev
tmp_dev = tmp_dir / "tct_colbert_v1_wo_neg.tsv"
if not tmp_dev.exists():
tmp_dir.mkdir(exist_ok=True, parents=True)
url = "http://drive.google.com/uc?id=1jOVL3DIya6qDiwM_Dnqc81FT5ZB43csP"
gdown.download(url, tmp_dev.as_posix(), quiet=False)
assert tmp_dev.exists()
with open(tmp_dev, "rt") as f, open(outfn, "at") as fout:
for line in f:
qid, docid, rank, score = line.strip().split("\t")
fout.write(f"{qid} Q0 {docid} {rank} {score} tct_colbert\n")
with open(done_fn, "wt") as f:
print("done", file=f)
return outfn
@Searcher.register
class MsmarcoPsgTop200(Searcher, MsmarcoPsgSearcherMixin):
"""
    Skip searching on the training set by converting the official training triples into a "fake" runfile.
    Use pre-prepared top-200 runfiles from TCT-ColBERT (https://cs.uwaterloo.ca/~jimmylin/publications/Lin_etal_2021_RepL4NLP.pdf) and/or BM25 as the first stage.
"""
    module_name = "msptop200"
    dependencies = [Dependency(key="benchmark", module="benchmark", name="msmarcopsg")]
    config_spec = [
ConfigOption(
"firststage",
"tct",
"Options: tct, bm25, tct>bm25, bm25>tct. where config before > stands for training set source, and that after > stands for dev and test source.",
)
]
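
    # e.g. firststage="tct>bm25" takes the TCT-ColBERT top-200 run for the training
    # set and the BM25 top-200 run for the dev set; the test run is fetched only
    # when firststage is exactly "tct" (see _query_from_file below).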
    def get_train_url(self):
train_first_stage = self.config["firststage"].split(">")[0]
url_template = "https://drive.google.com/uc?id="
assert train_first_stage in {"bm25", "tct"}
file_id = "10VjzcDUtZwJWoWUlVnjtyI4j5K6c-882" if train_first_stage == "tct" else "1ZgrxqdbV3-YbF9PnOVtSIx04RqG-YOMW"
return url_template + file_id
    def get_dev_url(self):
dev_first_stage = self.config["firststage"]
if ">" in dev_first_stage:
dev_first_stage = dev_first_stage.split(">")[1]
url_template = "https://drive.google.com/uc?id="
assert dev_first_stage in {"bm25", "tct"}
file_id = "1WBUashNhtJKNsKYBzeR4IxcMzbjqiqg6" if dev_first_stage == "tct" else "1PWuDcr8c4EIB-mxdFY7-KkTezJ7aN0Fq"
return url_template + file_id
    def get_test_url(self):
        test_first_stage = self.config["firststage"]
        if ">" in test_first_stage:
            test_first_stage = test_first_stage.split(">")[1]
        url_template = "https://drive.google.com/uc?id="
        assert test_first_stage in {"tct"}, "Only inference on the tct test set is supported for now"
file_id = "1U4DBP_3HBXC8EJNbI_wFUVoZnt7FiPbe"
return url_template + file_id
def _query_from_file(self, topicsfn, output_path, cfg):
outfn = Path(output_path) / "static.run"
done_fn = Path(output_path) / "done"
if done_fn.exists():
assert outfn.exists()
return outfn
tmp_dir = self.get_cache_path() / "tmp"
        tmp_dir.mkdir(exist_ok=True, parents=True)
output_path.mkdir(exist_ok=True, parents=True)
tag = self.config["firststage"]
        url_lists = [self.get_train_url(), self.get_dev_url()]
        if tag == "tct":
            url_lists.append(self.get_test_url())
        with open(outfn, "wt") as fout:
            for set_name, url in zip(["train", "dev", "test"], url_lists):
                if set_name == "test":
                    assert tag == "tct"
                basename = f"{tag}-{set_name}"
                tmp_fn = tmp_dir / basename
                # download the file
                if not os.path.exists(tmp_fn):
                    gdown.download(url, tmp_fn.as_posix(), quiet=False)
                # convert into trec format and append to the combined runfile
                with open(tmp_fn, "rt") as f:
                    for line in f:
                        try:
                            qid, docid, rank = line.strip().split()
                        except ValueError:
                            raise ValueError("This line cannot be parsed: " + line)
                        score = 1000 - int(rank)  # higher rank position -> lower score
                        fout.write(f"{qid} Q0 {docid} {rank} {score} {tag}\n")
with open(done_fn, "wt") as f:
print("done", file=f)
return outfn
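
# All searchers above ultimately emit TREC run lines of the form
#   <qid> Q0 <docid> <rank> <score> <tag>
# (see the f-strings above) and drop a "done" marker file so that repeated
# calls can return the cached result.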