capreolus.extractor.bertpassage

Module Contents

Classes

BertPassage: Extracts passages from a document for later consumption by a BERT-based model.
capreolus.extractor.bertpassage.logger
class capreolus.extractor.bertpassage.BertPassage(config=None, provide=None, share_dependency_objects=False, build=True)

Bases: capreolus.extractor.Extractor

Extracts passages from a document for later consumption by a BERT-based model. Not all passages are used: the first passage is always used, and the prob config option controls the probability of a passage being selected.

Gotcha: in TensorFlow, the train TFRecords have shape (batch_size, maxseqlen), while the dev TFRecords have shape (batch_size, num_passages, maxseqlen). This is because during inference we want to pool over the scores of all the passages belonging to a doc.
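The extra num_passages axis in the dev records exists so that per-passage scores can be reduced to one score per document. The following is a minimal plain-Python sketch of that reduction, assuming the per-passage scores have already been computed. The function name `pool_passage_scores` and the max/first/mean pooling choices are illustrative assumptions, not part of the capreolus API.

```python
from typing import List


def pool_passage_scores(scores: List[List[float]], mode: str = "max") -> List[float]:
    """Collapse per-passage scores of shape [batch_size, num_passages]
    into one score per document, shape [batch_size].

    Hypothetical helper: the pooling modes are assumptions, not
    necessarily what capreolus uses.
    """
    pooled = []
    for doc_scores in scores:
        if mode == "max":
            # Score the document by its best-matching passage
            pooled.append(max(doc_scores))
        elif mode == "first":
            # Score the document by its first passage only
            pooled.append(doc_scores[0])
        elif mode == "mean":
            # Average over all passages of the document
            pooled.append(sum(doc_scores) / len(doc_scores))
        else:
            raise ValueError(f"unknown pooling mode: {mode}")
    return pooled


# Two documents with three passages each: shape [2, 3]
dev_scores = [[0.1, 0.7, 0.3], [0.9, 0.2, 0.4]]
print(pool_passage_scores(dev_scores))           # [0.7, 0.9]
print(pool_passage_scores(dev_scores, "first"))  # [0.1, 0.9]
```

At training time no such reduction is needed, because each training example is a single passage with shape (batch_size, maxseqlen).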

module_name = bertpassage
dependencies
pad = 0
pad_tok = [PAD]
config_spec
load_state(self, qids, docids)
cache_state(self, qids, docids)
get_tf_feature_description(self)
create_tf_train_feature(self, sample)

Returns features for a doc. Of the num_passages passages present in a document, only a subset is used. params: sample - a dict in which each value has the shape [batch_size, num_passages, maxseqlen].

Returns a list of features. Each feature is a dict, and each value in the dict has the shape [batch_size, maxseqlen]. Yes, the output shape differs from the input shape, because we sample from the passages.
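A minimal sketch of the reshaping described above, using nested lists instead of tensors. It assumes that passage 0 is always kept and every other passage index is kept with probability `prob`, following the class description; the exact selection rule inside capreolus may differ. The function name `sample_train_features` and the dict keys `input_ids` and `mask` are hypothetical.

```python
import random
from typing import Dict, List


def sample_train_features(
    sample: Dict[str, List[List[List[int]]]], prob: float, rng: random.Random
) -> List[Dict[str, List[List[int]]]]:
    """Turn a dict whose values have shape [batch_size, num_passages, maxseqlen]
    into a list of dicts whose values have shape [batch_size, maxseqlen].

    Hypothetical sketch: 'prob' mirrors the extractor's prob config option.
    """
    keys = list(sample)
    num_passages = len(sample[keys[0]][0])
    features = []
    for i in range(num_passages):
        # The first passage is always kept; others are kept with probability prob
        if i == 0 or rng.random() < prob:
            # Slice passage i out of every document in the batch
            features.append({k: [doc[i] for doc in sample[k]] for k in keys})
    return features


# batch_size=2, num_passages=3, maxseqlen=4
batch = {
    "input_ids": [
        [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]],
        [[4, 4, 4, 4], [5, 5, 5, 5], [6, 6, 6, 6]],
    ],
    "mask": [[[1, 1, 0, 0]] * 3, [[1, 1, 1, 0]] * 3],
}
rng = random.Random(0)
print(len(sample_train_features(batch, prob=0.0, rng=rng)))  # 1
print(len(sample_train_features(batch, prob=1.0, rng=rng)))  # 3
```

With prob=0.0 only the first passage survives, so each returned feature holds one passage per document in the batch.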

create_tf_dev_feature(self, sample)

Unlike the train feature, the dev feature uses all passages. Both the input and the output are dicts in which each value has the shape [batch_size, num_passages, maxseqlen].

parse_tf_train_example(self, example_proto)
parse_tf_dev_example(self, example_proto)
get_passages_for_doc(self, doc)

Extract passages from the doc. If there are too many passages, keep the first and the last one and sample from the rest. If there are not enough passages, pad.
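The keep-first-and-last rule can be sketched as below. This operates on an already split list of passages and assumes the sampled middle passages keep their original document order and that padding uses a placeholder passage. The function name `select_passages` and its `pad_passage` parameter are hypothetical, not the capreolus implementation.

```python
import random
from typing import List


def select_passages(
    passages: List[str], num_passages: int, rng: random.Random, pad_passage: str = ""
) -> List[str]:
    """Return exactly num_passages passages.

    Hypothetical sketch: too many -> keep first and last, sample the middle;
    too few -> pad with pad_passage.
    """
    if num_passages < 2:
        raise ValueError("need room for at least the first and last passage")
    if len(passages) > num_passages:
        middle = passages[1:-1]
        # Sample indices rather than items so document order is preserved
        chosen = sorted(rng.sample(range(len(middle)), num_passages - 2))
        return [passages[0]] + [middle[j] for j in chosen] + [passages[-1]]
    # Not enough passages: pad up to num_passages
    return passages + [pad_passage] * (num_passages - len(passages))


rng = random.Random(42)
doc = [f"p{i}" for i in range(10)]
print(select_passages(doc, 4, rng))           # starts with 'p0', ends with 'p9'
print(select_passages(["a", "b"], 4, rng))    # ['a', 'b', '', '']
```

Sampling indices and sorting them keeps the selected passages in reading order, which matters if the model or the pooling step treats passage position as meaningful.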

exist(self)
preprocess(self, qids, docids, topics)
id2vec(self, qid, posid, negid=None, label=None)

See parent class for docstring