Source code for capreolus.benchmark

import os
import json
from copy import deepcopy
from collections import defaultdict

import ir_datasets

from capreolus import ModuleBase
from capreolus.utils.caching import cached_file, TargetFileExists
from capreolus.utils.trec import write_qrels, load_qrels, load_trec_topics
from capreolus.utils.loginit import get_logger


# Module-level logger used by the validation helpers in this module.
logger = get_logger(__name__)
def validate(build_f):
    """Decorate a benchmark ``build`` method so its data files are validated afterwards.

    The returned wrapper calls ``build_f(self)`` and then runs three checks:

    1. the folds file (``self.fold_file``) is a ``.json`` file in which every
       fold contains exactly ``train_qids`` and ``predict``, and ``predict``
       contains exactly ``dev`` and ``test``;
    2. the qrels file (``self.qrel_file``) contains no conflicting duplicate
       judgments; consistent duplicates are dropped by rewriting the file with
       ``write_qrels`` (the original is kept beside it with a
       ``-contain-dup-entries`` suffix);
    3. the query ids found in ``self.topics``, ``self.qrels`` and
       ``self.folds`` are aligned.

    Args:
        build_f: the ``build(self)`` method being wrapped.

    Returns:
        A function taking ``self`` that runs ``build_f`` followed by the checks.
        Like the wrapped ``build``, it returns ``None``.

    Raises:
        ValueError: (when the wrapper runs) if the folds file is malformed, the
            qrels file has conflicting labels for the same (query, document)
            pair, or a fold references a query missing from the topics.
    """
    from functools import wraps

    def validate_folds_file(self):
        """Check the extension and the per-fold structure of the folds file."""
        if not hasattr(self, "fold_file") or self.fold_file is None:
            logger.warning(f"Folds file is not found for Module {self.module_name}")
            return

        if self.fold_file.suffix != ".json":
            raise ValueError("Expect folds file to be in .json format.")

        # Use a context manager so the file handle is closed deterministically.
        with open(self.fold_file) as f:
            raw_folds = json.load(f)

        # Fold names themselves are not checked, only the structure of each fold.
        for fold_sets in raw_folds.values():
            if set(fold_sets) != {"train_qids", "predict"}:
                raise ValueError("Expect each fold to contain ['train_qids', 'predict'] fields.")
            if set(fold_sets["predict"]) != {"dev", "test"}:
                raise ValueError("Expect each fold to contain ['dev', 'test'] fields under 'predict'.")

        logger.info("Folds file validation finishes.")

    def validate_qrels_file(self):
        """Reject conflicting judgments and rewrite the qrels file without duplicates."""
        if not hasattr(self, "qrel_file") or self.qrel_file is None:
            logger.warning(f"Qrel file is not found for Module {self.module_name}")
            return

        n_dup, qrels = 0, defaultdict(dict)
        with open(self.qrel_file) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue
                qid, _, docid, label = fields
                label = int(label)
                if docid in qrels[qid]:
                    n_dup += 1
                    if label != qrels[qid][docid]:
                        raise ValueError(
                            f"Found conflicting label in {self.qrel_file} for query {qid} and document {docid}."
                        )
                qrels[qid][docid] = label

        if n_dup > 0:
            # Keep the original file under a new name, then write the de-duplicated version in its place.
            qrel_file_no_ext, ext = os.path.splitext(self.qrel_file)
            dup_qrel_file = qrel_file_no_ext + "-contain-dup-entries" + ext
            os.rename(self.qrel_file, dup_qrel_file)
            write_qrels(qrels, self.qrel_file)
            logger.warning(
                f"Removed {n_dup} entries from the file {self.qrel_file}. The original version could be found in {dup_qrel_file}."
            )

        logger.info("Qrel file validation finishes.")

    def validate_query_alignment(self):
        """Warn about overlapping splits and misaligned query sets; fail on unknown queries."""
        topic_qids = set(self.topics[self.query_type])
        qrels_qids = set(self.qrels)

        for fold_name, fold_sets in self.folds.items():
            train_qids = set(fold_sets["train_qids"])
            dev_qids = set(fold_sets["predict"]["dev"])
            test_qids = set(fold_sets["predict"]["test"])

            # Check for overlap between each pair of splits; every warning names
            # the pair it is about and reports that pair's own overlap size.
            overlap_checks = (
                ("training", "dev", train_qids & dev_qids),
                ("training", "test", train_qids & test_qids),
                ("dev", "test", dev_qids & test_qids),
            )
            for first, second, shared in overlap_checks:
                if shared:
                    logger.warning(
                        f"Found {len(shared)} overlap queries between {first} and {second} set in fold {fold_name}."
                    )

            # Check that topics, qrels and folds share a reasonable set (if not all) of queries.
            folds_qids = train_qids | dev_qids | test_qids
            n_overlap = len(topic_qids & qrels_qids & folds_qids)
            if not len(topic_qids) == len(qrels_qids) == len(folds_qids) == n_overlap:
                logger.warning(
                    f"Number of queries are not aligned across topics, qrels and folds in fold {fold_name}: {len(topic_qids)} queries in topics file, {len(qrels_qids)} queries in qrels file, {len(folds_qids)} queries in folds file; {n_overlap} overlap queries found among the three."
                )

            # Every query referenced by a fold must exist in the topics.
            for set_name, set_qids in zip(["training", "dev", "test"], [train_qids, dev_qids, test_qids]):
                missing = set_qids - topic_qids
                if missing:
                    raise ValueError(
                        f"{len(missing)} queries in {set_name} set of fold {fold_name} cannot be found in topic file."
                    )

        logger.info("Query Alignment validation finishes.")

    @wraps(build_f)
    def _validate(self):
        """Run ``build_f`` and then validate the folds, qrels and query alignment."""
        build_f(self)
        validate_folds_file(self)
        validate_qrels_file(self)
        validate_query_alignment(self)

    return _validate
class Benchmark(ModuleBase):
    """Base class for Benchmark modules.

    The purpose of a Benchmark is to provide the data needed to run an experiment,
    such as queries, folds, and relevance judgments.

    Modules should provide:
        - a ``topics`` dict mapping query ids (*qids*) to *queries*
        - a ``qrels`` dict mapping *qids* to *docids* and *relevance labels*
        - a ``folds`` dict mapping a fold name to *training*, *dev* (validation), and *testing* qids
        - if these can be loaded from files in standard formats, they can be specified by setting
          the ``topic_file``, ``qrel_file``, and ``fold_file``, respectively, rather than by setting
          the above attributes directly
    """

    module_type = "benchmark"
    qrel_file = None
    topic_file = None
    fold_file = None
    query_type = None
    relevance_level = 1
    """ Documents with a relevance label >= relevance_level will be considered relevant.
    This corresponds to trec_eval's --level_for_rel (and is passed to pytrec_eval as relevance_level).
    """
    use_train_as_dev = True
    """ Whether to use training set as validate set when there is no training needed,
    e.g. for traditional IR algorithms like BM25 """

    @property
    def qrels(self):
        """Relevance judgments, loaded from ``qrel_file`` on first access and then cached."""
        if not hasattr(self, "_qrels"):
            self._qrels = load_qrels(self.qrel_file)
        return self._qrels

    @property
    def topics(self):
        """Topics, loaded from ``topic_file`` on first access and then cached."""
        if not hasattr(self, "_topics"):
            self._topics = load_trec_topics(self.topic_file)
        return self._topics

    @property
    def folds(self):
        """Folds, loaded from the JSON ``fold_file`` on first access and then cached.

        Integer values in the JSON are parsed as strings (``parse_int=str``).
        """
        if not hasattr(self, "_folds"):
            # Close the file handle deterministically instead of leaking it.
            with open(self.fold_file, "rt") as f:
                self._folds = json.load(f, parse_int=str)
        return self._folds

    @property
    def non_nn_dev(self):
        """Per-fold dev qids, extended with the training qids when ``use_train_as_dev`` is set.

        Returns:
            A dict mapping each fold name to a new list of qids; the lists held in
            ``self.folds`` are not modified.
        """
        dev_per_fold = {fold_name: deepcopy(folds["predict"]["dev"]) for fold_name, folds in self.folds.items()}
        if self.use_train_as_dev:
            for fold_name, folds in self.folds.items():
                dev_per_fold[fold_name].extend(folds["train_qids"])
        return dev_per_fold

    def _qids_for_query_sets(self, query_sets):
        """Collect the qids belonging to the requested splits across all folds.

        ``self.folds`` maps a fold name to ``{"train_qids": ..., "predict": {"dev": ..., "test": ...}}``,
        so the splits are looked up inside every fold and merged.
        """
        folds = self.folds
        # NOTE(review): a flat, single-fold dict (with "train_qids" at the top level) is still
        # accepted because the previous implementation indexed ``self.folds`` that way.
        fold_list = [folds] if "train_qids" in folds else list(folds.values())

        qids = set()
        for fold in fold_list:
            if "train" in query_sets:
                qids.update(fold["train_qids"])
            if "dev" in query_sets:
                qids.update(fold["predict"]["dev"])
            if "test" in query_sets:
                qids.update(fold["predict"]["test"])
        return qids

    def get_topics_file(self, query_sets=None):
        """Returns path to a topics file in TSV format containing queries from query_sets.

        query_sets may contain any combination of 'train', 'dev', and 'test'.
        All are returned if query_sets is None.

        Args:
            query_sets: iterable of split names, or None/empty for all queries.

        Returns:
            Path of the TSV file (``<qid>\\t<query>`` per line) under the module's cache path.

        Raises:
            ValueError: if query_sets contains a name other than 'train', 'dev' or 'test'.
        """
        if query_sets:
            query_sets = set(query_sets)
            invalid = query_sets - {"train", "test", "dev"}
            if invalid:
                raise ValueError(f"query_sets contains invalid fold names: {invalid}")

            # Resolve the qids from the set itself, before it is turned into a file-name label.
            valid_qids = self._qids_for_query_sets(query_sets)
            set_label = "_".join(sorted(query_sets))
        else:
            set_label = "all"
            valid_qids = None

        fn = self.get_cache_path() / f"topics-{set_label}.tsv"
        try:
            with cached_file(fn) as tmp_fn:
                with open(tmp_fn, "wt") as outf:
                    for qid, query in self.topics[self.query_type].items():
                        if valid_qids is None or qid in valid_qids:
                            print(f"{qid}\t{query}", file=outf)
        except TargetFileExists:
            # The topics file is already present in the cache; reuse it.
            pass

        return fn

    @validate
    def build(self):
        """Build hook for subclasses; the ``validate`` decorator checks the data files afterwards."""
        return
class IRDBenchmark(Benchmark):
    """Benchmark whose qrels and topics are loaded through ``ir_datasets``.

    Subclasses list the dataset identifiers to load in ``ird_dataset_names``;
    the judgments and queries of all listed datasets are merged.
    """

    ird_dataset_names = []

    @property
    def qrels(self):
        """Relevance judgments, loaded with ``ird_load_qrels`` on first access and then cached."""
        if not hasattr(self, "_qrels"):
            self._qrels = self.ird_load_qrels()
        return self._qrels

    @property
    def topics(self):
        """Topics, loaded with ``ird_load_topics`` on first access and then cached."""
        if not hasattr(self, "_topics"):
            self._topics = self.ird_load_topics()
        return self._topics

    def ird_load_qrels(self):
        """Load and merge the qrels of every dataset in ``ird_dataset_names``.

        Returns:
            A dict mapping query id -> {doc id -> relevance}. When a (query, document)
            pair is judged more than once, the highest relevance is kept.
        """
        qrels = {}
        for name in self.ird_dataset_names:
            dataset = ir_datasets.load(name)
            for qrel in dataset.qrels_iter():
                judged = qrels.setdefault(qrel.query_id, {})
                previous = judged.get(qrel.doc_id)
                # Only take the max against a real earlier judgment; a numeric
                # placeholder would clamp labels below it (e.g. -2) to the placeholder.
                judged[qrel.doc_id] = qrel.relevance if previous is None else max(previous, qrel.relevance)

        return qrels

    def ird_load_topics(self):
        """Load the queries of every dataset in ``ird_dataset_names``.

        The query attribute read is ``self.query_type``, except that ``"desc"``
        reads the ``description`` attribute. Newlines in the text are replaced by spaces.

        Returns:
            ``{self.query_type: {query id -> query text}}``.
        """
        topics = {}
        field = "description" if self.query_type == "desc" else self.query_type

        for name in self.ird_dataset_names:
            dataset = ir_datasets.load(name)
            for query in dataset.queries_iter():
                topics[query.query_id] = getattr(query, field).replace("\n", " ")

        return {self.query_type: topics}
# NOTE(review): these imports sit at the end of the module, presumably so that the
# classes defined above exist before the benchmark submodules are imported — confirm.
from profane import import_all_modules

from .dummy import DummyBenchmark

# Import the modules of this package, located via this file's path and package name.
import_all_modules(__file__, __package__)