Source code for capreolus.extractor.bertpassage

import pickle
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm


from capreolus.extractor import Extractor
from capreolus import Dependency, ConfigOption, get_logger
from capreolus.utils.common import padlist
from capreolus.utils.exceptions import MissingDocError
from capreolus.tokenizer.punkt import PunktTokenizer

from .common import SingleTrainingPassagesMixin

logger = get_logger(__name__)

@Extractor.register
class BertPassage(Extractor, SingleTrainingPassagesMixin):
    """
    Extracts passages from the document to be later consumed by a BERT-based model.

    Does NOT use all the passages. The first passage is always used; use the `prob` config to
    control the probability of each remaining passage being selected.

    Gotcha: In TensorFlow the train tfrecords have shape (batch_size, maxseqlen) while the dev
    tfrecords have shape (batch_size, num_passages, maxseqlen). This is because during inference
    we want to pool over the scores of the passages belonging to a doc.
    """

    module_name = "bertpassage"
    dependencies = [
        Dependency(key="benchmark", module="benchmark", name=None),
        Dependency(
            key="index", module="index", name="anserini", default_config_overrides={"indexstops": True, "stemmer": "none"}
        ),
        Dependency(key="tokenizer", module="tokenizer", name="berttokenizer"),
    ]

    config_spec = [
        ConfigOption("maxseqlen", 256, "Maximum input length (query+document)"),
        ConfigOption("maxqlen", 20, "Maximum query length"),
        ConfigOption("padq", False, "Always pad queries to maxqlen"),
        ConfigOption("usecache", False, "Should the extracted features be cached?"),
        ConfigOption("passagelen", 150, "Length of the extracted passage"),
        ConfigOption("stride", 100, "Stride"),
        ConfigOption("sentences", False, "Use a sentence tokenizer to form passages"),
        ConfigOption("numpassages", 16, "Number of passages per document"),
        ConfigOption(
            "prob",
            0.1,
            "The probability that a passage from the document will be used for training "
            "(the first passage is always used)",
        ),
    ]

    config_keys_not_in_path = ["usecache"]

    def build(self):
        self.pad = self.tokenizer.bert_tokenizer.pad_token_id
        self.cls = self.tokenizer.bert_tokenizer.cls_token_id
        self.sep = self.tokenizer.bert_tokenizer.sep_token_id

        self.pad_tok = self.tokenizer.bert_tokenizer.pad_token
        self.cls_tok = self.tokenizer.bert_tokenizer.cls_token
        self.sep_tok = self.tokenizer.bert_tokenizer.sep_token
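    # Illustrative sketch (not part of the original module): the special tokens above come from
    # the Hugging Face tokenizer wrapped by the berttokenizer dependency. Assuming it wraps
    # "bert-base-uncased", they would typically resolve as follows:
    #
    #     from transformers import BertTokenizerFast
    #     t = BertTokenizerFast.from_pretrained("bert-base-uncased")
    #     t.pad_token_id, t.cls_token_id, t.sep_token_id  # -> (0, 101, 102)
    #     t.pad_token, t.cls_token, t.sep_token           # -> ("[PAD]", "[CLS]", "[SEP]")
    #
    # Exact ids depend on the vocabulary actually configured for the tokenizer.
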
    def load_state(self, qids, docids):
        cache_fn = self.get_state_cache_file_path(qids, docids)
        logger.debug("loading state from: %s", cache_fn)
        with open(cache_fn, "rb") as f:
            state_dict = pickle.load(f)
            self.qid2toks = state_dict["qid2toks"]

    def cache_state(self, qids, docids):
        os.makedirs(self.get_cache_path(), exist_ok=True)
        with open(self.get_state_cache_file_path(qids, docids), "wb") as f:
            state_dict = {"qid2toks": self.qid2toks}
            pickle.dump(state_dict, f, protocol=-1)

    def get_tf_feature_description(self):
        feature_description = {
            "pos_bert_input": tf.io.FixedLenFeature([], tf.string),
            "pos_mask": tf.io.FixedLenFeature([], tf.string),
            "pos_seg": tf.io.FixedLenFeature([], tf.string),
            "neg_bert_input": tf.io.FixedLenFeature([], tf.string),
            "neg_mask": tf.io.FixedLenFeature([], tf.string),
            "neg_seg": tf.io.FixedLenFeature([], tf.string),
            "label": tf.io.FixedLenFeature([], tf.string),
        }

        return feature_description
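    # Illustrative sketch (not part of the original module): every feature is declared as a
    # scalar tf.string because each int64/float32 tensor is serialized with
    # tf.io.serialize_tensor before being written and recovered with tf.io.parse_tensor when
    # read back, e.g.:
    #
    #     x = tf.constant([[1, 2], [3, 4]], dtype=tf.int64)
    #     s = tf.io.serialize_tensor(x)      # scalar tf.string
    #     tf.io.parse_tensor(s, tf.int64)    # -> the original 2x2 int64 tensor
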
    def create_tf_dev_feature(self, sample):
        """
        Unlike the train feature, the dev set uses all passages. Both the input and the output are dicts
        with the shape [batch_size, num_passages, maxseqlen].
        """
        posdoc, negdoc, negdoc_id = sample["pos_bert_input"], sample["neg_bert_input"], sample["negdocid"]
        posdoc_mask, posdoc_seg, negdoc_mask, negdoc_seg = (
            sample["pos_mask"],
            sample["pos_seg"],
            sample["neg_mask"],
            sample["neg_seg"],
        )
        label = sample["label"]

        def _bytes_feature(value):
            """Returns a bytes_list from a string / byte."""
            if isinstance(value, type(tf.constant(0))):  # if value is a tensor
                value = value.numpy()  # get the value of the tensor

            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

        feature = {
            "pos_bert_input": _bytes_feature(tf.io.serialize_tensor(posdoc)),
            "pos_mask": _bytes_feature(tf.io.serialize_tensor(posdoc_mask)),
            "pos_seg": _bytes_feature(tf.io.serialize_tensor(posdoc_seg)),
            "neg_bert_input": _bytes_feature(tf.io.serialize_tensor(negdoc)),
            "neg_mask": _bytes_feature(tf.io.serialize_tensor(negdoc_mask)),
            "neg_seg": _bytes_feature(tf.io.serialize_tensor(negdoc_seg)),
            "label": _bytes_feature(tf.io.serialize_tensor(label)),
        }

        return [feature]
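    # Illustrative sketch (not part of the original module): the feature dicts returned above are
    # what a consumer would wrap in tf.train.Example records and write to a TFRecord file. The
    # `extractor`, `sample`, and file name below are placeholders:
    #
    #     with tf.io.TFRecordWriter("dev.tfrecord") as writer:
    #         for feature in extractor.create_tf_dev_feature(sample):
    #             example = tf.train.Example(features=tf.train.Features(feature=feature))
    #             writer.write(example.SerializeToString())
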
    def parse_tf_dev_example(self, example_proto):
        feature_description = self.get_tf_feature_description()
        parsed_example = tf.io.parse_example(example_proto, feature_description)

        def parse_tensor_as_int(x):
            parsed_tensor = tf.io.parse_tensor(x, tf.int64)
            parsed_tensor.set_shape([self.config["numpassages"], self.config["maxseqlen"]])

            return parsed_tensor

        def parse_label_tensor(x):
            parsed_tensor = tf.io.parse_tensor(x, tf.float32)
            parsed_tensor.set_shape([2])

            return parsed_tensor

        pos_bert_input = tf.map_fn(parse_tensor_as_int, parsed_example["pos_bert_input"], dtype=tf.int64)
        pos_mask = tf.map_fn(parse_tensor_as_int, parsed_example["pos_mask"], dtype=tf.int64)
        pos_seg = tf.map_fn(parse_tensor_as_int, parsed_example["pos_seg"], dtype=tf.int64)
        neg_bert_input = tf.map_fn(parse_tensor_as_int, parsed_example["neg_bert_input"], dtype=tf.int64)
        neg_mask = tf.map_fn(parse_tensor_as_int, parsed_example["neg_mask"], dtype=tf.int64)
        neg_seg = tf.map_fn(parse_tensor_as_int, parsed_example["neg_seg"], dtype=tf.int64)
        label = tf.map_fn(parse_label_tensor, parsed_example["label"], dtype=tf.float32)

        return (pos_bert_input, pos_mask, pos_seg, neg_bert_input, neg_mask, neg_seg), label
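    # Illustrative sketch (not part of the original module): reading the dev records back into a
    # batched tf.data pipeline with the parser above. The file name and batch size are
    # placeholders; note that batching happens before the map, since tf.io.parse_example and
    # tf.map_fn operate over a batch of serialized examples:
    #
    #     dataset = (
    #         tf.data.TFRecordDataset("dev.tfrecord")
    #         .batch(8)
    #         .map(extractor.parse_tf_dev_example)
    #     )
    #     for (pos_inp, pos_mask, pos_seg, neg_inp, neg_mask, neg_seg), label in dataset:
    #         pass  # pos_inp has shape (batch_size, numpassages, maxseqlen)
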
    def _filter_inputs(self, bert_inputs, bert_masks, bert_segs, n_valid_psg):
        """Preserve only one passage from all available passages."""
        assert n_valid_psg <= len(
            bert_inputs
        ), f"Passages only have {len(bert_inputs)} entries, but got {n_valid_psg} valid passages."
        valid_indexes = list(range(0, n_valid_psg))
        if len(valid_indexes) == 0:
            valid_indexes = [0]
        random_i = self.rng.choice(valid_indexes)

        return list(map(lambda arr: arr[random_i], [bert_inputs, bert_masks, bert_segs]))

    def _encode_inputs(self, query_toks, passages):
        """Convert the query and passages into BERT inputs, masks, and segments."""
        bert_inputs, bert_masks, bert_segs = [], [], []
        n_valid_psg = 0
        for tokenized_passage in passages:
            if tokenized_passage != [self.pad_tok]:  # end of the passage
                n_valid_psg += 1

            inp, mask, seg = self._prepare_bert_input(query_toks, tokenized_passage)
            bert_inputs.append(inp)
            bert_masks.append(mask)
            bert_segs.append(seg)

        return bert_inputs, bert_masks, bert_segs, n_valid_psg

    def _get_passages(self, docid):
        doc = self.index.get_doc(docid)
        if not self.config["sentences"]:
            return self._get_sliding_window_passages(doc)
        else:
            return self._get_sent_passages(doc)

    def _get_sent_passages(self, doc):
        passages = []
        punkt = PunktTokenizer()
        numpassages = self.config["numpassages"]

        for sentence in punkt.tokenize(doc):
            if len(passages) >= numpassages:
                break
            passages.extend(self._chunk_sent(sentence, self.config["passagelen"]))

        if numpassages != 0:
            passages = passages[:numpassages]

        n_actual_passages = len(passages)
        for _ in range(numpassages - n_actual_passages):
            # rather than randomly re-using one of the previous passages when the document is
            # exhausted, append empty passages
            passages.append([""])

        assert len(passages) == numpassages or numpassages == 0

        return sorted(passages, key=len)

    def _get_sliding_window_passages(self, doc):
        """
        Extract passages from the doc. If there are too many passages, keep the first and the last one and
        sample from the rest. If there are not enough passages, pad.
        """
        passages = []
        numpassages = self.config["numpassages"]

        doc = self.tokenizer.tokenize(doc)

        for i in range(0, len(doc), self.config["stride"]):
            if i >= len(doc):
                assert len(passages) > 0, f"no passage can be built from empty document {doc}"
                break
            passages.append(doc[i : i + self.config["passagelen"]])

        n_actual_passages = len(passages)
        # If we have more passages than required, keep the first and last, and sample from the rest
        if n_actual_passages > numpassages:
            if numpassages > 1:
                # passages = [passages[0]] + list(self.rng.choice(passages[1:-1], numpassages - 2, replace=False)) + [passages[-1]]
                passages = passages[:numpassages]
            else:
                passages = [passages[0]]
        else:
            # Pad until we have the required number of passages
            passages.extend([[self.pad_tok] for _ in range(numpassages - n_actual_passages)])

        assert len(passages) == numpassages

        return passages

    # from https://github.com/castorini/birch/blob/2dd0401ebb388a1c96f8f3357a064164a5db3f0e/src/utils/doc_utils.py#L73
    def _chunk_sent(self, sent, max_len):
        words = self.tokenizer.tokenize(sent)

        if len(words) <= max_len:
            return [words]

        chunked_sents = []
        size = int(len(words) / max_len)
        for i in range(0, size):
            seq = words[i * max_len : (i + 1) * max_len]
            chunked_sents.append(seq)

        return chunked_sents

    def _build_vocab(self, qids, docids, topics):
        """Only build a vocab for queries, since the size of docid2document would be large for some of the
        document retrieval collections."""
        if self.is_state_cached(qids, docids) and self.config["usecache"]:
            logger.info("Vocabulary loaded from cache")
            self.load_state(qids, docids)
        else:
            logger.info("Building BertPassage vocabulary")
            self.qid2toks = {qid: self.tokenizer.tokenize(topics[qid]) for qid in tqdm(qids, desc="querytoks")}
            self.cache_state(qids, docids)
    def exist(self):
        return hasattr(self, "qid2toks") and len(self.qid2toks)

    def preprocess(self, qids, docids, topics):
        if self.exist():
            return

        self.index.create_index()
        self._build_vocab(qids, docids, topics)

    def _prepare_bert_input(self, query_toks, psg_toks):
        maxseqlen, maxqlen = self.config["maxseqlen"], self.config["maxqlen"]
        if len(query_toks) > maxqlen:
            logger.warning(f"Truncating query from {len(query_toks)} to {maxqlen}")
            query_toks = query_toks[:maxqlen]
        else:  # when len(query_toks) <= maxqlen, pad the query only if padq is set
            if self.config["padq"]:
                query_toks = padlist(query_toks, padlen=maxqlen, pad_token=self.pad_tok)
        psg_toks = psg_toks[: maxseqlen - len(query_toks) - 3]

        psg_toks = " ".join(psg_toks).split()  # in case psg_toks is an np.array
        input_line = [self.cls_tok] + query_toks + [self.sep_tok] + psg_toks + [self.sep_tok]
        padded_input_line = padlist(input_line, padlen=maxseqlen, pad_token=self.pad_tok)
        inp = self.tokenizer.convert_tokens_to_ids(padded_input_line)
        mask = [1 if tok != self.pad_tok else 0 for tok in input_line] + [0] * (len(padded_input_line) - len(input_line))
        seg = [0] * (len(query_toks) + 2) + [1] * (len(padded_input_line) - len(query_toks) - 2)

        return inp, mask, seg
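    # Illustrative sketch (not part of the original module) of the layout produced above,
    # assuming a 3-token query, a 4-token passage, and maxseqlen=12:
    #
    #     tokens: [CLS] q1 q2 q3 [SEP] p1 p2 p3 p4 [SEP] [PAD] [PAD]
    #     mask:     1   1  1  1    1   1  1  1  1    1     0     0
    #     seg:      0   0  0  0    0   1  1  1  1    1     1     1
    #
    # i.e. the mask marks real tokens vs. padding, and the segment ids separate the query side
    # ([CLS] + query + first [SEP]) from the passage side of the pair.
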
    def id2vec(self, qid, posid, negid=None, label=None, *args, **kwargs):
        """
        See parent class for docstring
        """
        training = kwargs.get("training", True)  # defaults to training
        assert label is not None
        maxseqlen = self.config["maxseqlen"]
        numpassages = self.config["numpassages"]

        query_toks = self.qid2toks[qid]

        # N.B: The passages in self.docid2passages are not bert tokenized
        pos_passages = self._get_passages(posid)

        pos_bert_inputs, pos_bert_masks, pos_bert_segs, n_valid_psg = self._encode_inputs(query_toks, pos_passages)
        if training:
            pos_bert_inputs, pos_bert_masks, pos_bert_segs = self._filter_inputs(
                pos_bert_inputs, pos_bert_masks, pos_bert_segs, n_valid_psg
            )
        else:
            assert len(pos_bert_inputs) == numpassages

        pos_bert_inputs, pos_bert_masks, pos_bert_segs = map(
            lambda lst: np.array(lst, dtype=np.long), [pos_bert_inputs, pos_bert_masks, pos_bert_segs]
        )

        # TODO: Rename the posdoc key in the below dict to 'pos_bert_input'
        data = {
            "qid": qid,
            "posdocid": posid,
            "pos_bert_input": pos_bert_inputs,
            "pos_mask": pos_bert_masks,
            "pos_seg": pos_bert_segs,
            "negdocid": "",
            "neg_bert_input": np.zeros_like(pos_bert_inputs, dtype=np.long),
            "neg_mask": np.zeros_like(pos_bert_masks, dtype=np.long),
            "neg_seg": np.zeros_like(pos_bert_segs, dtype=np.long),
            "label": np.array(label, dtype=np.float32),
            # ^^^ do not change the shape of the label, as it is only needed during training
        }

        if not negid:
            return data

        neg_passages = self._get_passages(negid)
        neg_bert_inputs, neg_bert_masks, neg_bert_segs, n_valid_psg = self._encode_inputs(query_toks, neg_passages)
        if training:
            neg_bert_inputs, neg_bert_masks, neg_bert_segs = self._filter_inputs(
                neg_bert_inputs, neg_bert_masks, neg_bert_segs, n_valid_psg
            )
        else:
            assert len(neg_bert_inputs) == numpassages

        if not neg_bert_inputs:
            raise MissingDocError(qid, negid)

        data["negdocid"] = negid
        data["neg_bert_input"] = np.array(neg_bert_inputs, dtype=np.long)
        data["neg_mask"] = np.array(neg_bert_masks, dtype=np.long)
        data["neg_seg"] = np.array(neg_bert_segs, dtype=np.long)

        return data
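    # Illustrative sketch (not part of the original module) of what a consumer of id2vec() sees,
    # assuming the default config (maxseqlen=256, numpassages=16) and a hypothetical [1, 0]
    # relevance label:
    #
    #     train = extractor.id2vec(qid, posid, negid, label=[1, 0], training=True)
    #     train["pos_bert_input"].shape  # -> (256,): one passage sampled per document
    #
    #     dev = extractor.id2vec(qid, posid, negid, label=[1, 0], training=False)
    #     dev["pos_bert_input"].shape    # -> (16, 256): all passages kept for score pooling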