# Source code for capreolus.extractor.pooled_bertpassage

import tensorflow as tf
import numpy as np

from capreolus import Dependency, ConfigOption, get_logger
from capreolus.utils.exceptions import MissingDocError
from . import Extractor
from .bertpassage import BertPassage

# Module-level logger for this extractor.
logger = get_logger(__name__)
@Extractor.register
class PooledBertPassage(BertPassage):
    """
    Extracts passages from the document to be later consumed by a BERT based model.

    Different from BertPassage in the sense that all the passages from a document "stick together" during
    training - every serialized feature holds all of a document's passages, parsed back with shape
    [numpassages, maxseqlen] per example - and this allows the reranker to pool over passages from the same
    document during training.
    """

    module_name = "pooledbertpassage"
    dependencies = [
        Dependency(key="benchmark", module="benchmark", name=None),
        Dependency(
            key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"}
        ),
        Dependency(key="tokenizer", module="tokenizer", name="berttokenizer"),
    ]
    config_spec = [
        ConfigOption("maxseqlen", 256, "Maximum input length (query+document)"),
        ConfigOption("maxqlen", 20, "Maximum query length"),
        ConfigOption("usecache", False, "Should the extracted features be cached?"),
        ConfigOption("passagelen", 150, "Length of the extracted passage"),
        ConfigOption("stride", 100, "Stride"),
        ConfigOption("sentences", False, "Use a sentence tokenizer to form passages"),
        ConfigOption("numpassages", 16, "Number of passages per document"),
        # TODO remove prob here. unused.
        ConfigOption(
            "prob",
            0.1,
            "The probability that a passage from the document will be used for training (the first passage is always used)",
        ),
    ]

    def create_tf_train_feature(self, sample):
        """
        Build the tf.train features for a training sample.

        No passage sampling is done here: this simply delegates to ``create_tf_dev_feature``, so every passage
        of the document is kept and the reranker can pool over all of them.

        :param sample: dict produced by ``id2vec``
        :return: a single-element list holding a dict of bytes ``tf.train.Feature`` values
        """
        return self.create_tf_dev_feature(sample)

    def create_tf_dev_feature(self, sample):
        """
        Serialize every tensor of ``sample`` into a bytes ``tf.train.Feature``.

        Unlike a sampled train feature, the dev set uses all passages; each serialized tensor keeps the full
        shape it has in ``sample``.

        :param sample: dict produced by ``id2vec`` with the keys "pos_bert_input", "pos_mask", "pos_seg",
            "neg_bert_input", "neg_mask", "neg_seg" and "label"
        :return: a single-element list holding a dict mapping those keys to ``tf.train.Feature`` values
        """

        def _bytes_feature(value):
            """Returns a bytes_list from a string / byte."""
            if isinstance(value, type(tf.constant(0))):  # if value is a tensor
                value = value.numpy()  # get value of tensor
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

        # The feature names match the keys id2vec uses in ``sample``; order is kept for readability.
        feature_keys = ("pos_bert_input", "pos_mask", "pos_seg", "neg_bert_input", "neg_mask", "neg_seg", "label")
        feature = {key: _bytes_feature(tf.io.serialize_tensor(sample[key])) for key in feature_keys}

        return [feature]

    def parse_tf_train_example(self, example_proto):
        """Parse a batch of serialized training examples (identical to ``parse_tf_dev_example``)."""
        return self.parse_tf_dev_example(example_proto)

    def parse_tf_dev_example(self, example_proto):
        """
        Parse a batch of serialized examples written by ``create_tf_dev_feature``.

        :param example_proto: batch of serialized ``tf.train.Example`` protos
        :return: ``((pos_bert_input, pos_mask, pos_seg, neg_bert_input, neg_mask, neg_seg), label)``, where each
            int64 tensor element has the static shape [numpassages, maxseqlen] and each label has shape [2]
        """
        feature_description = self.get_tf_feature_description()
        parsed_example = tf.io.parse_example(example_proto, feature_description)
        numpassages, maxseqlen = self.config["numpassages"], self.config["maxseqlen"]

        def parse_tensor_as_int(x):
            parsed_tensor = tf.io.parse_tensor(x, tf.int64)
            # parse_tensor loses static shape information; restore it for downstream layers.
            parsed_tensor.set_shape([numpassages, maxseqlen])
            return parsed_tensor

        def parse_label_tensor(x):
            parsed_tensor = tf.io.parse_tensor(x, tf.float32)
            parsed_tensor.set_shape([2])
            return parsed_tensor

        int_keys = ("pos_bert_input", "pos_mask", "pos_seg", "neg_bert_input", "neg_mask", "neg_seg")
        inputs = tuple(tf.map_fn(parse_tensor_as_int, parsed_example[key], dtype=tf.int64) for key in int_keys)
        label = tf.map_fn(parse_label_tensor, parsed_example["label"], dtype=tf.float32)

        return inputs, label

    def _encode_passages(self, query_toks, docid):
        """Return (input ids, masks, segment ids) lists, one entry per passage of ``docid`` paired with the query."""
        inputs, masks, segs = [], [], []
        # N.B: The passages returned by _get_passages are not bert tokenized
        for tokenized_passage in self._get_passages(docid):
            inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
            inputs.append(inp)
            masks.append(mask)
            segs.append(seg)
        return inputs, masks, segs

    def id2vec(self, qid, posid, negid=None, label=None):
        """
        Build the BERT inputs for all passages of ``posid`` (and ``negid`` if given). See parent class for details.

        :param qid: query id, used to look up the query tokens in ``self.qid2toks``
        :param posid: id of the positive document
        :param negid: optional id of the negative document; when falsy, the "neg_*" entries are all-zero
            placeholders of shape (numpassages, maxseqlen) and "negdocid" is ""
        :param label: label stored as a float32 array; must not be None
        :return: dict with the ids plus int64 "pos_*"/"neg_*" arrays and the float32 "label"
        :raises MissingDocError: if ``negid`` is given but yields no passages
        """
        assert label is not None
        maxseqlen = self.config["maxseqlen"]
        numpassages = self.config["numpassages"]
        query_toks = self.qid2toks[qid]

        pos_inputs, pos_masks, pos_segs = self._encode_passages(query_toks, posid)

        # np.int64 (not the removed np.long alias) so serialized tensors match the tf.int64 parse dtype.
        data = {
            "qid": qid,
            "posdocid": posid,
            "pos_bert_input": np.array(pos_inputs, dtype=np.int64),
            "pos_mask": np.array(pos_masks, dtype=np.int64),
            "pos_seg": np.array(pos_segs, dtype=np.int64),
            "negdocid": "",
            "neg_bert_input": np.zeros((numpassages, maxseqlen), dtype=np.int64),
            "neg_mask": np.zeros((numpassages, maxseqlen), dtype=np.int64),
            "neg_seg": np.zeros((numpassages, maxseqlen), dtype=np.int64),
            "label": np.array(label, dtype=np.float32),
        }

        if not negid:
            return data

        neg_inputs, neg_masks, neg_segs = self._encode_passages(query_toks, negid)
        if not neg_inputs:
            raise MissingDocError(qid, negid)

        data["negdocid"] = negid
        data["neg_bert_input"] = np.array(neg_inputs, dtype=np.int64)
        data["neg_mask"] = np.array(neg_masks, dtype=np.int64)
        data["neg_seg"] = np.array(neg_segs, dtype=np.int64)

        return data