Source code for capreolus.utils.trec

import gzip
import os
import xml.etree.ElementTree as ET
from collections import defaultdict


def threshold_trec_run(run, fold, k):
    """ Take a trec run, and keep only the top-k docs """
    filtered_run = defaultdict(dict)

    # This is possible because best_search_run is an OrderedDict
    for qid, docs in run.items():
        if qid in fold["predict"]["test"]:
            for idx, (docid, score) in enumerate(docs.items()):
                if idx >= k:
                    break
                filtered_run[qid][docid] = score

    return filtered_run
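

# A minimal usage sketch (not part of the original module), assuming each query's doc dict is
# already ordered by descending score (as Capreolus' best_search_run is); the qids, docids, and
# scores below are invented for illustration.
def _example_threshold_trec_run():
    run = {"301": {"d1": 3.2, "d2": 2.9, "d3": 1.5}, "302": {"d4": 2.0, "d5": 1.1}}
    fold = {"predict": {"test": ["301"]}}
    # keeps only the two highest-scoring docs for qid 301; qid 302 is not in the test fold
    return threshold_trec_run(run, fold, k=2)  # {"301": {"d1": 3.2, "d2": 2.9}}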


def load_ntcir_topics(fn):
    topics = {}

    tree = ET.parse(fn)
    for child in tree.getroot():
        qid = child.find("qid").text.strip()
        query = child.find("content").text.strip()

        assert qid not in topics
        assert len(qid) > 0 and len(query) > 0
        topics[qid] = query

    return {"content": topics}
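

# A minimal usage sketch (not part of the original module): it writes a tiny XML topic file with
# the <qid> and <content> elements the parser above expects and reads it back; the qid and query
# text are invented for illustration.
def _example_load_ntcir_topics():
    import tempfile

    xml = "<topics><topic><qid> 0001 </qid><content> test query </content></topic></topics>"
    with tempfile.NamedTemporaryFile("wt", suffix=".xml", delete=False) as f:
        f.write(xml)
    return load_ntcir_topics(f.name)  # {"content": {"0001": "test query"}}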


def load_trec_topics(queryfn):
    title, desc, narr = defaultdict(list), defaultdict(list), defaultdict(list)

    def clean_line(line, tag_name, unwanted_tokens=None):
        if unwanted_tokens is None:
            unwanted_tokens = []
        elif isinstance(unwanted_tokens, str):
            unwanted_tokens = [unwanted_tokens]

        assert isinstance(unwanted_tokens, list) or isinstance(unwanted_tokens, set)

        line = line.replace(f"<{tag_name}>", "").replace(f"</{tag_name}>", "").strip().split()  # remove_tag
        line = [token for token in line if token not in unwanted_tokens]
        return line

    block = None
    if str(queryfn).endswith(".gz"):
        openf = gzip.open
    else:
        openf = open

    with openf(queryfn, "rt") as f:
        for line in f:
            line = line.strip()

            if line.startswith("<num>"):
                # <num> Number: 700, or
                # <num>700, or
                # <num>700</num>
                qid = line.split()[-1].replace("<num>", "").replace("</num>", "")
                # no longer an int
                # assert qid > 0
                block = None
            elif line.startswith("<title>"):
                # <title> query here, or
                # <title>query here</title>
                block = "title"
                line = clean_line(line, tag_name=block, unwanted_tokens="Topic:")
                title[qid].extend(line)
                # TODO does this sometimes start with Topic: ?
                assert "Topic:" not in line
            elif line.startswith("<desc>"):
                # <desc> description \n description, or
                # <desc>description</desc>
                block = "desc"
                line = clean_line(line, tag_name=block, unwanted_tokens="Description:")
                desc[qid].extend(line)
            elif line.startswith("<narr>"):
                # same format as <desc>
                block = "narr"
                line = clean_line(line, tag_name=block, unwanted_tokens="Narrative:")
                narr[qid].extend(line)
            elif line.startswith("</top>") or line.startswith("<top>"):
                block = None
            elif block == "title":
                title[qid].extend(line.strip().split())
            elif block == "desc":
                desc[qid].extend(line.strip().split())
            elif block == "narr":
                narr[qid].extend(line.strip().split())

    out = {}
    if len(title) > 0:
        out["title"] = {qid: " ".join(terms) for qid, terms in title.items()}
    if len(desc) > 0:
        out["desc"] = {qid: " ".join(terms).replace("Description: ", "") for qid, terms in desc.items()}
    if len(narr) > 0:
        out["narr"] = {qid: " ".join(terms) for qid, terms in narr.items()}

    return out
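

# A minimal usage sketch (not part of the original module): it writes a small TREC-format topic
# to a temporary file and parses it back; the topic number and text are invented for illustration.
def _example_load_trec_topics():
    import tempfile

    topic = (
        "<top>\n"
        "<num> Number: 700\n"
        "<title> hello world\n"
        "<desc> Description:\n"
        "a description line\n"
        "<narr> Narrative:\n"
        "a narrative line\n"
        "</top>\n"
    )
    with tempfile.NamedTemporaryFile("wt", suffix=".txt", delete=False) as f:
        f.write(topic)
    return load_trec_topics(f.name)
    # {"title": {"700": "hello world"}, "desc": {"700": "a description line"},
    #  "narr": {"700": "a narrative line"}}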


def load_qrels(qrelfile, qids=None, include_spam=True):
    labels = defaultdict(dict)

    with open(qrelfile, "rt") as f:
        for line in f:
            line = line.strip()
            if len(line) == 0:
                continue

            cols = line.split()
            qid, docid, label = cols[0], cols[2], int(cols[3])

            if qids is not None and qid not in qids:
                continue
            if label < 0 and not include_spam:
                continue

            labels[qid][docid] = label

    # remove qids with no relevant docs
    for qid in list(labels.keys()):
        if max(labels[qid].values()) <= 0:
            del labels[qid]

    labels.default_factory = None  # behave like normal dict
    return labels
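

# A minimal usage sketch (not part of the original module): it writes three qrels lines to a
# temporary file and loads them back; qid 302 is dropped because it has no relevant docs.
def _example_load_qrels():
    import tempfile

    with tempfile.NamedTemporaryFile("wt", delete=False) as f:
        f.write("301 Q0 d1 1\n301 Q0 d2 0\n302 Q0 d3 0\n")
    return load_qrels(f.name)  # {"301": {"d1": 1, "d2": 0}}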


def write_qrels(labels, qrelfile):
    qreldir = os.path.dirname(qrelfile)
    if qreldir != "":
        os.makedirs(qreldir, exist_ok=True)

    with open(qrelfile, "w") as fout:
        for qid in labels:
            for docid in labels[qid]:
                fout.write(f"{qid} Q0 {docid} {labels[qid][docid]}\n")
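

# A minimal usage sketch (not part of the original module): it round-trips a small label dict
# through write_qrels and load_qrels using a temporary directory; the filename is hypothetical.
def _example_write_qrels():
    import tempfile

    qrelfile = os.path.join(tempfile.mkdtemp(), "qrels.txt")
    write_qrels({"301": {"d1": 1, "d2": 0}}, qrelfile)
    return load_qrels(qrelfile)  # {"301": {"d1": 1, "d2": 0}}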


def document_to_trectxt(docno, txt):
    s = f"<DOC>\n<DOCNO> {docno} </DOCNO>\n"
    s += f"<TEXT>\n{txt}\n</TEXT>\n</DOC>\n"
    return s


def topic_to_trectxt(qno, title, desc=None, narr=None):
    return (
        f"<top>\n\n"
        f"<num> Number: {qno}\n"
        f"<title> {title}\n\n"
        f"<desc> Description:\n{desc or title}\n\n"
        f"<narr> Narrative:\n{narr or title}\n\n"
        f"</top>\n\n\n"
    )


def anserini_index_to_trec_docs(index_dir, output_dir, expected_doc_count):
    from jnius import autoclass

    JFile = autoclass("java.io.File")
    JFSDirectory = autoclass("org.apache.lucene.store.FSDirectory")
    JIndexReaderUtils = autoclass("io.anserini.index.IndexReaderUtils")
    RAW = autoclass("io.anserini.index.IndexArgs").RAW

    index_reader_utils = JIndexReaderUtils()
    fsdir = JFSDirectory.open(JFile(index_dir).toPath())
    reader = autoclass("org.apache.lucene.index.DirectoryReader").open(fsdir)

    docids = set()
    for i in range(expected_doc_count):
        try:
            docid = index_reader_utils.convertLuceneDocidToDocid(reader, i)
            docids.add(docid)
        except:  # lgtm [py/catch-base-exception]
            # we reached the end?
            pass

    if len(docids) != expected_doc_count:
        raise ValueError(
            f"we expected to retrieve {expected_doc_count} documents from the index, but actually found {len(docids)}"
        )

    output_handles = [gzip.open(os.path.join(output_dir, f"{i}.gz"), "wt", encoding="utf-8") for i in range(100, 200)]

    for docidx, docid in enumerate(sorted(docids)):
        # parse documents according to here:
        # https://github.com/castorini/anserini/blob/anserini-0.9.3/src/main/java/io/anserini/index/IndexUtils.java#L345-L352
        doc = index_reader_utils.document(reader, docid).getField(RAW)
        if doc is None:
            raise ValueError(f"{RAW} documents cannot be found in the index.")

        # remove the <TEXT> wrapper if present; str.lstrip/rstrip would strip individual
        # characters from the tag's character set rather than the tag itself
        doc = doc.stringValue().strip()
        if doc.startswith("<TEXT>"):
            doc = doc[len("<TEXT>") :]
        if doc.endswith("</TEXT>"):
            doc = doc[: -len("</TEXT>")]
        doc = doc.strip()

        txt = document_to_trectxt(docid, doc)
        handleidx = docidx % len(output_handles)
        print(txt, file=output_handles[handleidx])

    for handle in output_handles:
        handle.close()
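

# A minimal usage sketch (not part of the original module): it exports an existing Anserini index
# back to gzipped TREC files. Running it requires pyjnius with Anserini on the Java classpath, and
# the index path, output path, and document count below are placeholders rather than real values.
def _example_anserini_index_to_trec_docs():
    index_dir = "/path/to/anserini/index"  # hypothetical index location
    output_dir = "/path/to/output/trec_docs"  # hypothetical output location
    os.makedirs(output_dir, exist_ok=True)
    # expected_doc_count must match the number of documents in the indexed collection
    anserini_index_to_trec_docs(index_dir, output_dir, expected_doc_count=1000)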