Source code for capreolus.task.rerank

import random
import os
from pathlib import Path

import numpy as np
import torch
from profane import ModuleBase, Dependency, ConfigOption, constants

from capreolus.sampler import TrainDataset, PredDataset
from capreolus.searcher import Searcher
from capreolus.task import Task
from capreolus import evaluator
from capreolus.utils.loginit import get_logger

logger = get_logger(__name__)


@Task.register
class RerankTask(Task):
    module_name = "rerank"
    config_spec = [
        ConfigOption("fold", "s1", "fold to run"),
        ConfigOption("optimize", "map", "metric to maximize on the dev set"),  # affects train(), which saves the weights that do best on this metric
    ]
    dependencies = [
        Dependency(
            key="benchmark", module="benchmark", name="wsdm20demo", provide_this=True, provide_children=["collection"]
        ),
        Dependency(key="rank", module="task", name="rank"),
        Dependency(key="reranker", module="reranker", name="KNRM"),
    ]
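    # Note added for exposition (not in the original source): the "rank" task dependency
    # produces the first-stage search run that train() and evaluate() rerank, and the
    # "reranker" dependency (KNRM by default) is the model used to rescore it.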
    commands = ["train", "evaluate", "traineval"] + Task.help_commands
    default_command = "describe"

    def traineval(self):
        self.train()
        self.evaluate()

    def train(self):
        fold = self.config["fold"]

        # run the first-stage ranker and load its best run for this fold
        self.rank.search()
        rank_results = self.rank.evaluate()
        best_search_run_path = rank_results["path"][fold]
        best_search_run = Searcher.load_trec_run(best_search_run_path)

        return self.rerank_run(best_search_run, self.get_results_path())
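
    # Note added for exposition (not in the original source): as used above and in
    # rerank_run() below, a "run" is the nested dict returned by Searcher.load_trec_run,
    # which maps each query id to a {docid: score} dict. rerank_run() splits this dict
    # into train/dev/test subsets using the qids listed in self.benchmark.folds[fold].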

    def rerank_run(self, best_search_run, train_output_path, include_train=False):
        if not isinstance(train_output_path, Path):
            train_output_path = Path(train_output_path)

        fold = self.config["fold"]
        dev_output_path = train_output_path / "pred" / "dev"
        logger.debug("results path: %s", train_output_path)

        # prepare the extractor and reranker for the documents in the first-stage run
        docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
        self.reranker.extractor.preprocess(
            qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
        )
        self.reranker.build_model()
        self.reranker.searcher_scores = best_search_run

        # split the run into this fold's train and dev queries
        train_run = {qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["train_qids"]}
        dev_run = {qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["predict"]["dev"]}

        train_dataset = TrainDataset(qid_docid_to_rank=train_run, qrels=self.benchmark.qrels, extractor=self.reranker.extractor)
        dev_dataset = PredDataset(qid_docid_to_rank=dev_run, extractor=self.reranker.extractor)

        self.reranker.trainer.train(
            self.reranker,
            train_dataset,
            train_output_path,
            dev_dataset,
            dev_output_path,
            self.benchmark.qrels,
            self.config["optimize"],
        )

        # predict with the weights that performed best on the dev set
        self.reranker.trainer.load_best_model(self.reranker, train_output_path)
        dev_output_path = train_output_path / "pred" / "dev" / "best"
        dev_preds = self.reranker.trainer.predict(self.reranker, dev_dataset, dev_output_path)

        test_run = {qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["predict"]["test"]}
        test_dataset = PredDataset(qid_docid_to_rank=test_run, extractor=self.reranker.extractor)
        test_output_path = train_output_path / "pred" / "test" / "best"
        test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)

        preds = {"dev": dev_preds, "test": test_preds}

        if include_train:
            train_dataset = PredDataset(qid_docid_to_rank=train_run, extractor=self.reranker.extractor)
            train_output_path = train_output_path / "pred" / "train" / "best"
            train_preds = self.reranker.trainer.predict(self.reranker, train_dataset, train_output_path)
            preds["train"] = train_preds

        return preds

    def evaluate(self):
        fold = self.config["fold"]
        train_output_path = self.get_results_path()
        test_output_path = train_output_path / "pred" / "test" / "best"
        logger.debug("results path: %s", train_output_path)

        if os.path.exists(test_output_path):
            test_preds = Searcher.load_trec_run(test_output_path)
        else:
            # no saved test predictions for this fold, so rerun the first-stage ranker
            # and rerank its test queries with the best saved reranker weights
            self.rank.search()
            rank_results = self.rank.evaluate()
            best_search_run_path = rank_results["path"][fold]
            best_search_run = Searcher.load_trec_run(best_search_run_path)

            docids = set(docid for querydocs in best_search_run.values() for docid in querydocs)
            self.reranker.extractor.preprocess(
                qids=best_search_run.keys(), docids=docids, topics=self.benchmark.topics[self.benchmark.query_type]
            )
            self.reranker.build_model()
            self.reranker.searcher_scores = best_search_run
            self.reranker.trainer.load_best_model(self.reranker, train_output_path)

            test_run = {
                qid: docs for qid, docs in best_search_run.items() if qid in self.benchmark.folds[fold]["predict"]["test"]
            }
            test_dataset = PredDataset(qid_docid_to_rank=test_run, extractor=self.reranker.extractor)
            test_preds = self.reranker.trainer.predict(self.reranker, test_dataset, test_output_path)

        metrics = evaluator.eval_runs(test_preds, self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level)
        logger.info("rerank: fold=%s test metrics: %s", fold, metrics)

        print("\ncomputing metrics across all folds")
        avg = {}
        found = 0
        for fold in self.benchmark.folds:
            # TODO fix by using multiple Tasks
            pred_path = Path(test_output_path.as_posix().replace("fold-" + self.config["fold"], "fold-" + fold))
            if not os.path.exists(pred_path):
                print("\tfold=%s results are missing and will not be included" % fold)
                continue

            found += 1
            preds = Searcher.load_trec_run(pred_path)
            metrics = evaluator.eval_runs(preds, self.benchmark.qrels, evaluator.DEFAULT_METRICS, self.benchmark.relevance_level)
            for metric, val in metrics.items():
                avg.setdefault(metric, []).append(val)

        avg = {k: np.mean(v) for k, v in avg.items()}
        logger.info("rerank: average cross-validated metrics when choosing iteration based on '%s':", self.config["optimize"])
        for metric, score in sorted(avg.items()):
            logger.info("%15s: %0.4f", metric, score)
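

# ---------------------------------------------------------------------------
# Illustrative sketch (added for exposition, not part of the original module):
# evaluate() above averages evaluator.DEFAULT_METRICS over the per-fold TREC
# runs it finds on disk. The helper below shows that same aggregation step in
# isolation, using only names already imported in this module; the function
# name and its arguments are hypothetical.


def average_fold_metrics(pred_paths, qrels, relevance_level=1):
    """Average evaluator.DEFAULT_METRICS over the TREC run files in pred_paths.

    Mirrors the cross-fold loop in RerankTask.evaluate(): each run is scored
    with evaluator.eval_runs, and the per-metric scores are averaged with np.mean.
    """

    per_metric = {}
    for path in pred_paths:
        preds = Searcher.load_trec_run(path)
        metrics = evaluator.eval_runs(preds, qrels, evaluator.DEFAULT_METRICS, relevance_level)
        for metric, val in metrics.items():
            per_metric.setdefault(metric, []).append(val)

    return {metric: np.mean(vals) for metric, vals in per_metric.items()}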