Source code for capreolus.reranker.TFKNRM

import tensorflow as tf
from profane import ConfigOption, Dependency

from capreolus.reranker import Reranker
from capreolus.reranker.common import RbfKernelBankTF, similarity_matrix_tf


class TFKNRM_Class(tf.keras.Model):
    def __init__(self, extractor, config, **kwargs):
        super(TFKNRM_Class, self).__init__(**kwargs)
        self.config = config
        self.extractor = extractor

        # Standard KNRM kernel configuration: ten soft-match kernels spanning
        # [-0.9, 0.9] plus a narrow exact-match kernel centered at 1.0
        mus = [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        sigmas = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.001]

        self.embedding = tf.keras.layers.Embedding(
            len(self.extractor.stoi), self.extractor.embeddings.shape[1], weights=[self.extractor.embeddings], trainable=False
        )
        self.kernels = RbfKernelBankTF(mus, sigmas, dim=1, requires_grad=config["gradkernels"])
        # Linear layer that combines the per-kernel features into a single relevance score
        self.combine = tf.keras.layers.Dense(1, input_shape=(self.kernels.count(),))
    def get_score(self, doc_tok, query_tok, query_idf):
        query = self.embedding(query_tok)
        doc = self.embedding(doc_tok)
        batch_size, qlen, doclen = tf.shape(query)[0], tf.shape(query)[1], tf.shape(doc)[1]

        simmat = similarity_matrix_tf(query, doc, query_tok, doc_tok, self.extractor.pad)
        k = self.kernels(simmat)
        doc_k = tf.reduce_sum(k, axis=3)  # sum over document terms ("soft-TF" per kernel)

        # Mask out query positions whose similarity row is all zeros (i.e., padding),
        # so they contribute 0 rather than log(1e-6) to the per-kernel features
        reshaped_simmat = tf.broadcast_to(
            tf.reshape(simmat, (batch_size, 1, qlen, doclen)), (batch_size, self.kernels.count(), qlen, doclen)
        )
        mask = tf.reduce_sum(reshaped_simmat, axis=3) != 0.0
        log_k = tf.where(mask, tf.math.log(doc_k + 1e-6), tf.cast(mask, tf.float32))
        query_k = tf.reduce_sum(log_k, axis=2)  # sum over query terms
        scores = self.combine(query_k)

        return tf.reshape(scores, [batch_size])
    def call(self, x, **kwargs):
        """
        During training, both posdoc and negdoc are passed.
        During eval, both posdoc and negdoc are passed, but negdoc is a dummy zero tensor.
        Whether negdoc is a real doc tensor or a dummy zero tensor is determined by which sampler is used
        (e.g., sampler.TrainDataset) as well as the extractor (e.g., EmbedText).

        Unlike the PyTorch KNRM model, TFKNRM accepts both the positive and negative document in its forward
        pass. It scores them separately and returns both scores stacked along axis 1,
        i.e., [posdoc_score, negdoc_score].
        """
        posdoc, negdoc, query, query_idf = x[0], x[1], x[2], x[3]
        posdoc_score, negdoc_score = self.get_score(posdoc, query, query_idf), self.get_score(negdoc, query, query_idf)
        # During eval, negdoc_score should be a zero tensor
        # TODO: Verify that negdoc_score is indeed always zero whenever a zero negdoc tensor is passed into it
        return tf.stack([posdoc_score, negdoc_score], axis=1)
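
# The heart of get_score() above is KNRM-style kernel pooling. Below is a
# self-contained sketch of that computation, assuming RbfKernelBankTF applies
# one RBF kernel exp(-(s - mu)^2 / (2 * sigma^2)) per (mu, sigma) pair; the
# helper name _rbf_kernel_pool_sketch is invented for illustration, and the
# sketch omits the padding mask that get_score() applies.
def _rbf_kernel_pool_sketch(simmat, mus, sigmas):
    # simmat: (batch, qlen, doclen) cosine similarities between query and doc terms
    x = tf.expand_dims(simmat, 1)  # (batch, 1, qlen, doclen)
    mu = tf.reshape(tf.constant(mus), (1, -1, 1, 1))
    sigma = tf.reshape(tf.constant(sigmas), (1, -1, 1, 1))
    k = tf.exp(-0.5 * ((x - mu) / sigma) ** 2)  # (batch, n_kernels, qlen, doclen)
    doc_k = tf.reduce_sum(k, axis=3)  # soft-TF of each query term under each kernel
    log_k = tf.math.log(doc_k + 1e-6)
    return tf.reduce_sum(log_k, axis=2)  # (batch, n_kernels) features fed to combine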
@Reranker.register
class TFKNRM(Reranker):
    module_name = "TFKNRM"

    dependencies = [
        Dependency(key="extractor", module="extractor", name="embedtext"),
        Dependency(key="trainer", module="trainer", name="tensorflow"),
    ]
    config_spec = [
        ConfigOption("gradkernels", True, "backprop through mus and sigmas"),
        ConfigOption("finetune", False, "fine tune the embedding layer"),  # TODO check save when True
    ]

    def build_model(self):
        self.model = TFKNRM_Class(self.extractor, self.config)
        return self.model
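
# A minimal smoke test of TFKNRM_Class, assuming a toy stand-in for the
# EmbedText extractor (which normally supplies stoi, embeddings, and pad).
# ToyExtractor, its vocabulary, and the tensor shapes are hypothetical and
# chosen only for illustration.
if __name__ == "__main__":
    import numpy as np

    class ToyExtractor:
        stoi = {"<pad>": 0, "cat": 1, "dog": 2, "fish": 3}
        embeddings = np.random.rand(4, 8).astype("float32")
        pad = 0

    model = TFKNRM_Class(ToyExtractor(), {"gradkernels": True, "finetune": False})
    query = tf.constant([[1, 2, 0, 0]])  # one query, padded to length 4
    query_idf = tf.ones((1, 4), dtype=tf.float32)
    posdoc = tf.constant([[1, 2, 3, 0, 0, 0]])  # one doc, padded to length 6
    negdoc = tf.zeros_like(posdoc)  # dummy negdoc, as at eval time

    scores = model((posdoc, negdoc, query, query_idf))
    print(scores.shape)  # (1, 2): [posdoc_score, negdoc_score] per query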