"""
This module computes evaluation metrics for the MS MARCO dataset on the ranking task.
Creation Date : 06/12/2018
Last Modified : 1/21/2019
Authors : Daniel Campos <dacamp@microsoft.com>, Rutger van Haasteren <ruvanh@microsoft.com>
"""
import sys
from collections import Counter

# Cutoff rank used when computing MRR; the reported metric is "MRR @10".
MaxMRRRank = 10


def qrel_trec_to_msmarco(qrels):
    """Convert TREC-style qrels (qid -> docid -> relevance) into qid -> list of relevant docids."""
    return {qid: list(qrels[qid]) for qid in qrels}


def runs_trec_to_msmarco(runs):
    """Convert TREC-style runs (qid -> docid -> score) into qid -> docids sorted by descending score."""
    return {qid: [docid for docid, score in sorted(runs[qid].items(), key=lambda kv: float(kv[1]), reverse=True)] for qid in runs}
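
# Illustrative example of the expected shapes (the ids and scores below are hypothetical):
#   trec_qrels = {"q1": {"d3": 1, "d7": 1}}
#   trec_runs  = {"q1": {"d7": "12.3", "d3": "9.8", "d1": "1.2"}}
#   qrel_trec_to_msmarco(trec_qrels) -> {"q1": ["d3", "d7"]}
#   runs_trec_to_msmarco(trec_runs)  -> {"q1": ["d7", "d3", "d1"]}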


def quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Perform quality checks on the dictionaries.

    Args:
        qids_to_relevant_passageids (dict): dictionary mapping query ids to relevant passage ids,
            as produced by qrel_trec_to_msmarco
        qids_to_ranked_candidate_passages (dict): dictionary mapping query ids to ranked candidate passage ids,
            as produced by runs_trec_to_msmarco
    Returns:
        bool, str: whether the candidates are allowed, and a message to be shown in case of a problem
    """
    message = ""
    allowed = True
    # Check that no passage is ranked more than once for the same query.
    for qid in qids_to_ranked_candidate_passages:
        # Collect passage ids that occur more than once; pid 0 is treated as padding and is ignored below.
        duplicate_pids = set([item for item, count in Counter(qids_to_ranked_candidate_passages[qid]).items() if count > 1])
        if len(duplicate_pids - set([0])) > 0:
            message = "Cannot rank a passage multiple times for a single query. QID={qid}, PID={pid}".format(
                qid=qid, pid=list(duplicate_pids)[0]
            )
            allowed = False
    return allowed, message
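
# For example (hypothetical ids), {"q1": ["d3", "d3", "d1"]} is rejected because "d3" is
# ranked twice for q1, while a repeated padding pid of 0 does not trigger a failure.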


def compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages):
    """Compute the MRR metric.

    Args:
        qids_to_relevant_passageids (dict): dictionary mapping query ids to relevant passage ids,
            as produced by qrel_trec_to_msmarco
        qids_to_ranked_candidate_passages (dict): dictionary mapping query ids to ranked candidate passage ids,
            as produced by runs_trec_to_msmarco
    Returns:
        dict: dictionary of metrics {'MRR @10': <MRR score>, 'QueriesRanked': <number of ranked queries>}
    """
    all_scores = {}
    MRR = 0
    ranking = []
    for qid in qids_to_ranked_candidate_passages:
        if qid in qids_to_relevant_passageids:
            ranking.append(0)
            target_pid = qids_to_relevant_passageids[qid]
            candidate_pid = qids_to_ranked_candidate_passages[qid]
            # Scan the top MaxMRRRank candidates for the first relevant passage.
            for i in range(0, MaxMRRRank):
                if i >= len(candidate_pid):
                    break
                if candidate_pid[i] in target_pid:
                    MRR += 1 / (i + 1)
                    ranking.pop()
                    ranking.append(i + 1)
                    break
    if len(ranking) == 0:
        raise IOError("No matching QIDs found. Are you sure you are scoring the evaluation set?")
    MRR = MRR / len(qids_to_relevant_passageids)
    all_scores["MRR @10"] = MRR
    all_scores["QueriesRanked"] = len(qids_to_ranked_candidate_passages)
    return all_scores
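
# Worked example (hypothetical ids): with qids_to_relevant_passageids = {"q1": ["d3"]} and
# qids_to_ranked_candidate_passages = {"q1": ["d7", "d3", "d1"]}, the first relevant passage
# for q1 appears at rank 2, so its reciprocal rank is 1/2 and "MRR @10" = 0.5.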


def compute_metrics_from_files(trec_qrels, trec_runs, perform_checks=True):
    """Compute MRR from TREC-style qrels and runs dicts (qid -> docid -> relevance / score),
    optionally running quality_checks_qids on the candidates first.
    """
    qids_to_relevant_passageids = qrel_trec_to_msmarco(trec_qrels)
    qids_to_ranked_candidate_passages = runs_trec_to_msmarco(trec_runs)
    if perform_checks:
        allowed, message = quality_checks_qids(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
        if message != "":
            print(message)
    return compute_metrics(qids_to_relevant_passageids, qids_to_ranked_candidate_passages)
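
# Illustrative end-to-end usage (hypothetical ids and scores):
#   trec_qrels = {"q1": {"d3": 1}}
#   trec_runs = {"q1": {"d7": "12.3", "d3": "9.8", "d1": "1.2"}}
#   compute_metrics_from_files(trec_qrels, trec_runs)  # -> {"MRR @10": 0.5, "QueriesRanked": 1}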


def main():
    """Command line:
    python msmarco_eval_ranking.py <path_to_reference_file> <path_to_candidate_file>
    """
    if len(sys.argv) == 3:
        path_to_reference = sys.argv[1]
        path_to_candidate = sys.argv[2]
        metrics = compute_metrics_from_files(path_to_reference, path_to_candidate)
        print("#####################")
        for metric in sorted(metrics):
            print("{}: {}".format(metric, metrics[metric]))
        print("#####################")
    else:
        print("Usage: msmarco_eval_ranking.py <reference ranking> <candidate ranking>")
        sys.exit()
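

# Entry point when the module is run as a script (inferred from the usage note in main's docstring).
if __name__ == "__main__":
    main()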