Source code for capreolus.eval.msmarco_eval

"""
This module computes evaluation metrics for MSMARCO dataset on the ranking task.

Creation Date : 06/12/2018
Last Modified : 1/21/2019
Authors : Daniel Campos <dacamp@microsoft.com>, Rutger van Haasteren <ruvanh@microsoft.com>
"""
import sys

from collections import Counter

MaxMRRRank = 10

def qrel_trec_to_msmarco(qrels):
    """Convert TREC-style qrels ({qid: {docid: label}}) into MSMARCO-style lists of judged doc ids per query."""
    return {qid: list(qrels[qid]) for qid in qrels}

def runs_trec_to_msmarco(runs):
    """Convert TREC-style runs ({qid: {docid: score}}) into MSMARCO-style ranked lists of doc ids, sorted by descending score."""
    return {
        qid: [docid for docid, score in sorted(runs[qid].items(), key=lambda kv: float(kv[1]), reverse=True)]
        for qid in runs
    }

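# Illustrative sketch (assumed toy data, not part of the original script): the two
# helpers above map TREC-style nested dicts to the flat MSMARCO format, roughly:
#
#   qrel_trec_to_msmarco({"q1": {"p3": 1, "p7": 0}})      ->  {"q1": ["p3", "p7"]}
#   runs_trec_to_msmarco({"q1": {"p7": 2.1, "p3": 9.4}})  ->  {"q1": ["p3", "p7"]}
#
# i.e. qrels become the list of judged doc ids per query (the inner-dict keys), and
# runs become doc ids ordered by descending score.
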
def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Perform quality checks on the dictionaries

    Args:
        qids_to_relevant_passageids (dict): dictionary of query-passage mapping
            Dict as read in with load_reference or load_reference_from_stream
        qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates

    Returns:
        bool, str: Boolean whether allowed, message to be shown in case of a problem
    """
    message = ""
    allowed = True

    # Check that we do not have multiple passages per query
    for qid in qids_to_ranked_candidate_passages:
        # Remove all zeros from the candidates
        duplicate_pids = set(
            [item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1]
        )

        if len(duplicate_pids - set([0])) > 0:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=list(duplicate_pids)[0]
            )
            allowed = False

    return allowed, message

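# Illustrative sketch (assumed toy data, not part of the original script): a run that
# ranks the same non-zero passage id twice for one query fails the check, e.g.
#
#   quality_checks_qids({}, {"q1": ["p3", "p3", "p5"]})
#   # -> (False, "Cannot rank a passage multiple times for a single query. QID=q1, PID=p3")
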
def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute MRR metric

    Args:
        qids_to_relevant_passageids (dict): dictionary of query-passage mapping
            Dict as read in with load_reference or load_reference_from_stream
        qids_to_ranked_candidate_passages (dict): dictionary of query-passage candidates

    Returns:
        dict: dictionary of metrics {'MRR @10': <MRR score>, 'QueriesRanked': <number of ranked queries>}
    """
    all_scores = {}
    MRR = 0
    ranking = []
    for qid in qids_to_ranked_candidate_passages:
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            for i in range(0, MaxMRRRank):
                if i >= len(candidate_pid):
                    break
                if candidate_pid[i] in target_pid:
                    MRR += 1 / (i + 1)
                    ranking.pop()
                    ranking.append(i + 1)
                    break

    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")

    MRR = MRR / len(qids_to_relevant_passageids)
    all_scores["MRR @10"] = MRR
    all_scores["QueriesRanked"] = len(qids_to_ranked_candidate_passages)
    return all_scores

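# Worked example (assumed toy data, not part of the original script): with MaxMRRRank = 10,
# a query whose first relevant passage appears at rank 3 contributes 1/3, a query with no
# relevant passage in the top 10 contributes 0, and the sum is divided by the number of
# judged queries. Two judged queries with first relevant hits at ranks 3 and 1 give
# MRR@10 = (1/3 + 1/1) / 2 = 0.6667.
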
def compute_metrics_from_files(trec_qrels, trec_runs, perform_checks=True):
    """Compute MRR@10 given TREC-style qrels and run dicts (despite the name, this takes dicts rather than file paths)."""
    qids_to_relevant_passageids = qrel_trec_to_msmarco(trec_qrels)
    qids_to_ranked_candidate_passages = runs_trec_to_msmarco(trec_runs)

    if perform_checks:
        allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
        if message != "":
            print(message)

    return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)

def main():
    """Command line:
    python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
    """
    if len(sys.argv) == 3:
        path_to_reference = sys.argv[1]
        path_to_candidate = sys.argv[2]
        # NOTE: compute_metrics_from_files expects TREC-style dicts, so these paths
        # would need to be loaded into that form before being scored.
        metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
        print("#####################")
        for metric in sorted(metrics):
            print("{}: {}".format(metric, metrics[metric]))
        print("#####################")
    else:
        print("Usage: msmarco_eval_ranking.py <reference ranking> <candidate ranking>")
        exit()

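# Minimal usage sketch (hypothetical helper, not part of the original module):
# scores a single query whose relevant passage is ranked second by the run.
def _example_usage():
    qrels = {"q1": {"p9": 1}}                           # TREC-style qrels: qid -> {docid: label}
    runs = {"q1": {"p3": 4.2, "p9": 3.7, "p5": 1.0}}    # TREC-style run: qid -> {docid: score}
    return compute_metrics_from_files(qrels, runs)      # -> {'MRR @10': 0.5, 'QueriesRanked': 1}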