# Source code for capreolus.extractor.pooled_bertpassage

import tensorflow as tf
import numpy as np

from capreolus import Dependency, ConfigOption, get_logger
from capreolus.utils.exceptions import MissingDocError
from . import Extractor
from .bertpassage import BertPassage

logger = get_logger(__name__)
class PooledBertPassage(BertPassage):
    """
    Extracts passages from the document to be later consumed by a BERT based model.

    Different from BertPassage in the sense that all the passages from a document "stick together" during
    training - the resulting feature always has the shape (batch, num_passages, maxseqlen) - and this allows
    the reranker to pool over passages from the same document during training.
    """

    module_name = "pooledbertpassage"
    dependencies = [
        Dependency(key="benchmark", module="benchmark", name=None),
        Dependency(
            key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"}
        ),
        Dependency(key="tokenizer", module="tokenizer", name="berttokenizer"),
    ]
    config_spec = [
        ConfigOption("maxseqlen", 256, "Maximum input length (query+document)"),
        ConfigOption("maxqlen", 20, "Maximum query length"),
        ConfigOption("padq", False, "Always pad queries to maxqlen"),
        ConfigOption("usecache", False, "Should the extracted features be cached?"),
        ConfigOption("passagelen", 150, "Length of the extracted passage"),
        ConfigOption("stride", 100, "Stride"),
        ConfigOption("sentences", False, "Use a sentence tokenizer to form passages"),
        ConfigOption("numpassages", 16, "Number of passages per document"),
        # TODO remove prob here. unused.
        ConfigOption(
            "prob",
            0.1,
            "The probability that a passage from the document will be used for training (the first passage is always used)",
        ),
    ]

    # Keys of the integer tensors that are serialized into, and parsed back from, every record.
    # The order matters: parse_tf_dev_example returns the parsed tensors in exactly this order.
    _INT_FEATURE_KEYS = ("pos_bert_input", "pos_mask", "pos_seg", "neg_bert_input", "neg_mask", "neg_seg")

    def create_tf_train_feature(self, sample):
        """
        Builds the training feature for a sample.

        No passage sampling happens here: this simply delegates to create_tf_dev_feature, so every passage of
        the document is kept and the train and dev records share the same layout.
        params:
            sample - A dict as accepted by create_tf_dev_feature
        Returns a list containing a single feature dict.
        """
        return self.create_tf_dev_feature(sample)

    def create_tf_dev_feature(self, sample):
        """
        Serializes the tensors of a sample into tf.train.Feature objects, using all passages.

        params:
            sample - A dict holding the keys "pos_bert_input", "pos_mask", "pos_seg", "neg_bert_input",
                     "neg_mask", "neg_seg" and "label" (the dict built by id2vec has these keys)
        Returns a list containing a single dict that maps each of those keys to a bytes feature wrapping the
        serialized tensor.
        """

        def _bytes_feature(value):
            """Returns a bytes_list from a string / byte."""
            if isinstance(value, type(tf.constant(0))):  # if value is a tensor
                value = value.numpy()  # get value of tensor
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

        # Each array is stored as a serialized tensor; parse_tf_dev_example decodes it with tf.io.parse_tensor.
        feature = {
            key: _bytes_feature(tf.io.serialize_tensor(sample[key])) for key in (*self._INT_FEATURE_KEYS, "label")
        }

        return [feature]

    def parse_tf_train_example(self, example_proto):
        """
        Parses a serialized training example. Train and dev records share one layout, so this delegates to
        parse_tf_dev_example.
        """
        return self.parse_tf_dev_example(example_proto)

    def parse_tf_dev_example(self, example_proto):
        """
        Decodes serialized examples back into tensors.

        params:
            example_proto - serialized example(s), parsed with the description returned by
                            self.get_tf_feature_description()
        Returns a tuple ((pos_bert_input, pos_mask, pos_seg, neg_bert_input, neg_mask, neg_seg), label).
        Every element of the integer tensors is given the static shape [numpassages, maxseqlen] and every
        label element the static shape [2].
        """
        feature_description = self.get_tf_feature_description()
        # NOTE(review): tf.map_fn below iterates over the leading axis of each parsed value, which is why the
        # batched tf.io.parse_example is used here - confirm that callers batch before mapping this function.
        parsed_example = tf.io.parse_example(example_proto, feature_description)

        def parse_tensor_as_int(x):
            parsed_tensor = tf.io.parse_tensor(x, tf.int64)
            # parse_tensor yields an unknown static shape; restore it so downstream layers can be built
            parsed_tensor.set_shape([self.config["numpassages"], self.config["maxseqlen"]])
            return parsed_tensor

        def parse_label_tensor(x):
            parsed_tensor = tf.io.parse_tensor(x, tf.float32)
            parsed_tensor.set_shape([2])
            return parsed_tensor

        int_tensors = tuple(
            tf.map_fn(parse_tensor_as_int, parsed_example[key], dtype=tf.int64) for key in self._INT_FEATURE_KEYS
        )
        label = tf.map_fn(parse_label_tensor, parsed_example["label"], dtype=tf.float32)

        return int_tensors, label

    def _encode_passages(self, query_toks, passages):
        """Runs self._prepare_bert_input on every passage and returns the (inputs, masks, segs) lists."""
        inputs, masks, segs = [], [], []
        for tokenized_passage in passages:
            inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
            inputs.append(inp)
            masks.append(mask)
            segs.append(seg)

        return inputs, masks, segs

    def id2vec(self, qid, posid, negid=None, label=None, *args, **kwargs):
        """
        Builds the feature dict for a query and a positive document, plus an optional negative document.

        params:
            qid - query id, looked up in self.qid2toks
            posid - id of the positive document
            negid - id of the negative document. When falsy, the negative entries stay as all-zero
                    placeholders of shape (numpassages, maxseqlen) and "negdocid" is ""
            label - required; stored under "label" as a float32 array
        Returns a dict with the keys qid, posdocid, pos_bert_input, pos_mask, pos_seg, negdocid,
        neg_bert_input, neg_mask, neg_seg and label.
        Raises MissingDocError if negid is given but yields no passages.
        """
        assert label is not None
        maxseqlen = self.config["maxseqlen"]
        numpassages = self.config["numpassages"]

        query_toks = self.qid2toks[qid]

        # N.B: The passages in self.docid2passages are not bert tokenized
        # NOTE(review): unlike the negative document below, an empty passage list for the positive document
        # is not rejected here - presumably _get_passages already guards against that; confirm.
        pos_bert_inputs, pos_bert_masks, pos_bert_segs = self._encode_passages(query_toks, self._get_passages(posid))

        # int64 is spelled out explicitly: parse_tf_dev_example decodes these tensors as tf.int64, and
        # np.long is either unavailable or platform dependent depending on the NumPy version.
        data = {
            "qid": qid,
            "posdocid": posid,
            "pos_bert_input": np.array(pos_bert_inputs, dtype=np.int64),
            "pos_mask": np.array(pos_bert_masks, dtype=np.int64),
            "pos_seg": np.array(pos_bert_segs, dtype=np.int64),
            "negdocid": "",
            "neg_bert_input": np.zeros((numpassages, maxseqlen), dtype=np.int64),
            "neg_mask": np.zeros((numpassages, maxseqlen), dtype=np.int64),
            "neg_seg": np.zeros((numpassages, maxseqlen), dtype=np.int64),
            "label": np.array(label, dtype=np.float32),
        }

        if not negid:
            return data

        neg_bert_inputs, neg_bert_masks, neg_bert_segs = self._encode_passages(query_toks, self._get_passages(negid))

        if not neg_bert_inputs:
            raise MissingDocError(qid, negid)

        data["negdocid"] = negid
        data["neg_bert_input"] = np.array(neg_bert_inputs, dtype=np.int64)
        data["neg_mask"] = np.array(neg_bert_masks, dtype=np.int64)
        data["neg_seg"] = np.array(neg_bert_segs, dtype=np.int64)

        return data