import numpy as np
import tensorflow as tf
from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import RbfKernelBankTF, similarity_matrix_tf
class TFKNRM_Class(tf.keras.layers.Layer):
    """Keras layer that scores query-document pairs with RBF kernel pooling (KNRM).

    Query and document token ids are embedded with a frozen embedding table,
    a query-by-document similarity matrix is built, a bank of RBF kernels is
    applied to it, and the pooled kernel features are combined into a single
    score per example by a dense layer.
    """

    def __init__(self, extractor, config, **kwargs):
        """Build the embedding table, kernel bank and output layer.

        Args:
            extractor: object providing ``stoi`` (its length is the vocabulary
                size), ``embeddings`` (a 2-D matrix whose second dimension is
                the embedding size) and ``pad`` (passed to
                ``similarity_matrix_tf``).
            config: mapping read for the ``"gradkernels"`` key.
            **kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.config = config
        self.extractor = extractor

        # Ten kernels of width 0.1 spread over [-0.9, 0.9], plus one very
        # narrow kernel (sigma=0.001) centred at 1.0.
        mus = [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        sigmas = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.001]

        # NOTE(review): the embedding is always frozen here; the "finetune"
        # option declared by the TFKNRM reranker is not consulted -- confirm
        # whether that is intended.
        self.embedding = tf.keras.layers.Embedding(
            len(self.extractor.stoi),
            self.extractor.embeddings.shape[1],
            weights=[self.extractor.embeddings],
            trainable=False,
        )
        self.kernels = RbfKernelBankTF(mus, sigmas, dim=1, requires_grad=config["gradkernels"])
        self.combine = tf.keras.layers.Dense(1, input_shape=(self.kernels.count(),), dtype=tf.float32)

    def get_score(self, doc_tok, query_tok, query_idf):
        """Compute one relevance score per (query, document) pair.

        Args:
            doc_tok: document token ids; axis 0 is the batch, axis 1 the
                document positions.
            query_tok: query token ids; axis 0 is the batch, axis 1 the
                query positions.
            query_idf: accepted for interface compatibility; not used here.

        Returns:
            A tensor of shape ``(batch_size,)`` holding the scores.
        """
        query = self.embedding(query_tok)
        doc = self.embedding(doc_tok)
        batch_size, qlen, doclen = tf.shape(query)[0], tf.shape(query)[1], tf.shape(doc)[1]

        simmat = similarity_matrix_tf(query, doc, query_tok, doc_tok, self.extractor.pad)
        k = self.kernels(simmat)
        doc_k = tf.reduce_sum(k, axis=3)  # sum over document

        # Replicate the similarity matrix once per kernel so the mask below
        # has the same shape as doc_k.
        reshaped_simmat = tf.broadcast_to(
            tf.reshape(simmat, (batch_size, 1, qlen, doclen)),
            (batch_size, self.kernels.count(), qlen, doclen),
        )
        # True where a query position has a non-zero similarity sum over the
        # document axis; presumably all-zero rows are padded query positions
        # -- confirm against similarity_matrix_tf.
        mask = tf.reduce_sum(reshaped_simmat, axis=3) != 0.0

        # Clip before the log to avoid log(0). Where the mask is False the
        # fallback value tf.cast(mask, ...) is 0.0, so those positions
        # contribute nothing to the sum below.
        # np.inf is used because the np.Inf alias was removed in NumPy 2.0.
        log_k = tf.where(
            mask,
            tf.math.log(tf.clip_by_value(doc_k, clip_value_min=1e-8, clip_value_max=np.inf)),
            tf.cast(mask, doc_k.dtype),
        )
        query_k = tf.reduce_sum(log_k, axis=2)  # sum over query positions
        scores = self.combine(query_k)
        return tf.reshape(scores, [batch_size])

    def call(self, x, **kwargs):
        """Score a batch given ``x = (doc, query, query_idf)``."""
        doc, query, query_idf = x[0], x[1], x[2]
        score = self.get_score(doc, query, query_idf)
        return score

    def predict_step(self, data):
        """Score a batch at prediction time by delegating to ``score``."""
        return self.score(data)

    def score(self, x, **kwargs):
        """Score only the first document of ``x = (posdoc, negdoc, query, query_idf)``.

        ``negdoc`` is unpacked but ignored.
        """
        posdoc, negdoc, query, query_idf = x
        return self.call((posdoc, query, query_idf))

    def score_pair(self, x, **kwargs):
        """Score both documents of ``x = (posdoc, negdoc, query, query_idf)``.

        Returns:
            A ``(pos_score, neg_score)`` tuple, each computed against the
            same query.
        """
        posdoc, negdoc, query, query_idf = x
        pos_score = self.call((posdoc, query, query_idf))
        neg_score = self.call((negdoc, query, query_idf))
        return pos_score, neg_score
@Reranker.register
class TFKNRM(Reranker):
    """TensorFlow implementation of KNRM.

    Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017. End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17.
    """

    # Modules this reranker is wired to: an embedding-text extractor and the
    # TensorFlow trainer.
    dependencies = [
        Dependency(key="extractor", module="extractor", name="slowembedtext"),
        Dependency(key="trainer", module="trainer", name="tensorflow"),
    ]

    # Options exposed through the reranker's config.
    config_spec = [
        ConfigOption("gradkernels", True, "backprop through mus and sigmas"),
        ConfigOption("finetune", False, "fine tune the embedding layer"),  # TODO check save when True
    ]

    def build_model(self):
        """Create the KNRM layer, store it on ``self.model`` and return it.

        The layer is constructed from ``self.extractor`` and ``self.config``.
        """
        self.model = TFKNRM_Class(self.extractor, self.config)
        return self.model