import hashlib
import random
import torch.utils.data
from capreolus.utils.exceptions import MissingDocError
from capreolus.utils.loginit import get_logger
# Module-level logger, keyed to this module's name (capreolus convention).
logger = get_logger(__name__)
class TrainDataset(torch.utils.data.IterableDataset):
    """
    Samples training data. Intended to be used with a pytorch DataLoader
    """

    def __init__(self, qid_docid_to_rank, qrels, extractor, relevance_level=1):
        """
        Args:
            qid_docid_to_rank: dict mapping each qid to the docids retrieved for it
                (iterated for its docids; may be a dict of docid->rank or a list)
            qrels: dict mapping qid -> {docid: relevance label}
            extractor: provides ``module_name`` and ``id2vec(qid, posdocid, negdocid)``
            relevance_level: minimum qrels label for a doc to count as relevant
        """
        self.extractor = extractor

        # remove qids from qid_docid_to_rank that do not have relevance labels in the qrels
        # (work on a copy so the caller's dict is not mutated)
        qid_docid_to_rank = qid_docid_to_rank.copy()
        for qid in list(qid_docid_to_rank.keys()):
            if qid not in qrels:
                logger.warning("skipping training qid=%s that was missing from the qrels", qid)
                del qid_docid_to_rank[qid]

        self.qid_docid_to_rank = qid_docid_to_rank
        # docids with a qrels label >= relevance_level; unjudged docs default to label 0
        self.qid_to_reldocs = {
            qid: [docid for docid in docids if qrels[qid].get(docid, 0) >= relevance_level]
            for qid, docids in qid_docid_to_rank.items()
        }
        # TODO option to include only negdocs in a top k
        self.qid_to_negdocs = {
            qid: [docid for docid in docids if qrels[qid].get(docid, 0) < relevance_level]
            for qid, docids in qid_docid_to_rank.items()
        }

        # remove any qids that do not have both relevant and non-relevant documents for training
        # NOTE(review): the counter starts at 1, so it overstates the number of unique
        # (qid, posdoc, negdoc) triples by one; preserved in case callers rely on it
        # being nonzero — confirm before changing
        total_samples = 1
        for qid in qid_docid_to_rank:
            posdocs = len(self.qid_to_reldocs[qid])
            negdocs = len(self.qid_to_negdocs[qid])
            total_samples += posdocs * negdocs
            if posdocs == 0 or negdocs == 0:
                logger.debug("removing training qid=%s with %s positive docs and %s negative docs", qid, posdocs, negdocs)
                del self.qid_to_reldocs[qid]
                del self.qid_to_negdocs[qid]
        self.total_samples = total_samples

    def __hash__(self):
        # __hash__ must return an int; get_hash() returns a string cache key,
        # so hash the key rather than returning it directly (the original
        # `return self.get_hash()` raised TypeError under hash())
        return hash(self.get_hash())

    def get_hash(self):
        """Return a deterministic string key identifying this dataset (extractor + qid/docid ranking)."""
        sorted_rep = sorted([(qid, docids) for qid, docids in self.qid_docid_to_rank.items()])
        key_content = "{0}{1}".format(self.extractor.module_name, str(sorted_rep))
        key = hashlib.md5(key_content.encode("utf-8")).hexdigest()
        return "train_{0}".format(key)

    def get_total_samples(self):
        """Return the (approximate) number of unique training triples; see note in __init__."""
        return self.total_samples

    def generator_func(self):
        """Endlessly yield training triples sampled uniformly per query.

        Each pass shuffles the qids, then draws one random positive and one
        random negative docid per qid and converts them to features via the
        extractor. Never terminates on its own; pairs with missing features
        are skipped with a warning.
        """
        # Convert each query and doc id to the corresponding feature/embedding and yield
        while True:
            all_qids = sorted(self.qid_to_reldocs)
            if len(all_qids) == 0:
                raise RuntimeError("TrainDataset has no valid qids")
            random.shuffle(all_qids)
            for qid in all_qids:
                posdocid = random.choice(self.qid_to_reldocs[qid])
                negdocid = random.choice(self.qid_to_negdocs[qid])
                try:
                    yield self.extractor.id2vec(qid, posdocid, negdocid)
                except MissingDocError:
                    # at training time we warn but ignore on missing docs
                    logger.warning(
                        "skipping training pair with missing features: qid=%s posid=%s negid=%s", qid, posdocid, negdocid
                    )

    def __iter__(self):
        """
        Returns: Triplets of the form (query_feature, posdoc_feature, negdoc_feature)
        """
        return iter(self.generator_func())
class PredDataset(torch.utils.data.IterableDataset):
    """
    Creates a Dataset for evaluation (test) data to be used with a pytorch DataLoader
    """

    def __init__(self, qid_docid_to_rank, extractor):
        """
        Args:
            qid_docid_to_rank: dict mapping each qid to the docids retrieved for it
            extractor: provides ``module_name`` and ``id2vec(qid, docid)``
        """
        self.qid_docid_to_rank = qid_docid_to_rank
        self.extractor = extractor

        def genf():
            # yield features for every (qid, docid) pair, in ranking order
            for qid, docids in qid_docid_to_rank.items():
                for docid in docids:
                    try:
                        yield extractor.id2vec(qid, docid)
                    except MissingDocError:
                        # when predicting we raise an exception on missing docs, as this may invalidate results
                        logger.error("got none features for prediction: qid=%s posid=%s", qid, docid)
                        raise

        self.generator_func = genf

    def __hash__(self):
        # __hash__ must return an int; get_hash() returns a string cache key,
        # so hash the key rather than returning it directly (the original
        # `return self.get_hash()` raised TypeError under hash())
        return hash(self.get_hash())

    def get_hash(self):
        """Return a deterministic string key identifying this dataset (extractor + qid/docid ranking)."""
        sorted_rep = sorted([(qid, docids) for qid, docids in self.qid_docid_to_rank.items()])
        key_content = "{0}{1}".format(self.extractor.module_name, str(sorted_rep))
        key = hashlib.md5(key_content.encode("utf-8")).hexdigest()
        return "dev_{0}".format(key)

    def __iter__(self):
        """
        Returns: Tuples of the form (query_feature, posdoc_feature)
        """
        return iter(self.generator_func())

    def get_qid_docid_pairs(self):
        """
        Returns a generator for the (qid, docid) pairs. Useful if you want to sequentially access the pred pairs without
        extracting the actual content
        """
        for qid in self.qid_docid_to_rank:
            for docid in self.qid_docid_to_rank[qid]:
                yield qid, docid