Source code for capreolus.reranker.KNRM

import matplotlib.pyplot as plt
import torch
from torch import nn

from capreolus import ConfigOption
from capreolus.reranker import Reranker
from capreolus.reranker.common import RbfKernelBank, SimilarityMatrix, create_emb_layer
from capreolus.utils.loginit import get_logger

logger = get_logger(__name__)  # pylint: disable=invalid-name

class KNRM_class(nn.Module):
    # based on CedrKnrmRanker from https://github.com/Georgetown-IR-Lab/cedr/blob/master/modeling.py
    # which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license
    def __init__(self, extractor, config):
        super(KNRM_class, self).__init__()
        self.p = config

        mus = [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
        sigmas = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.001]
        self.kernels = RbfKernelBank(mus, sigmas, dim=1, requires_grad=config["gradkernels"])

        non_trainable = not self.p["finetune"]
        self.embedding = create_emb_layer(extractor.embeddings, non_trainable=non_trainable)
        self.simmat = SimilarityMatrix(self.embedding)

        channels = 1
        if config["singlefc"]:
            combine_steps = [nn.Linear(self.kernels.count() * channels, 1)]
        else:
            combine_steps = [nn.Linear(self.kernels.count() * channels, 30), nn.Tanh(), nn.Linear(30, 1)]
        if config["scoretanh"]:
            combine_steps.append(nn.Tanh())
        self.combine = nn.Sequential(*combine_steps)
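
    # Note on the kernel bank above: the first ten kernels (mu from -0.9 to 0.9,
    # sigma 0.1) tile the cosine-similarity range to count soft matches, while the
    # final kernel (mu=1.0, sigma=0.001) is nearly a delta function that
    # approximately counts exact term matches, following the KNRM paper.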
    def get_embedding(self, toks):
        return self.embedding(toks)
    def forward(self, doctoks, querytoks, query_idf):  # note: query_idf is accepted but unused in this variant
        simmat = self.simmat(querytoks, doctoks)
        kernels = self.kernels(simmat)
        VIEWS = 1
        BATCH, KERNELS, QLEN, DLEN = kernels.shape
        kernels = kernels.reshape(BATCH, KERNELS * VIEWS, QLEN, DLEN)
        simmat = (
            simmat.reshape(BATCH, 1, VIEWS, QLEN, DLEN)
            .expand(BATCH, KERNELS, VIEWS, QLEN, DLEN)
            .reshape(BATCH, KERNELS * VIEWS, QLEN, DLEN)
        )
        result = kernels.sum(dim=3)  # sum over document
        mask = simmat.sum(dim=3) != 0.0  # which query terms are not padding?
        result = torch.where(mask, (result + 1e-6).log(), mask.float())
        result = result.sum(dim=2)  # sum over query terms
        scores = self.combine(result)  # linear combination over kernels
        return scores
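
# For reference, a minimal self-contained sketch of the kernel pooling that
# KNRM_class.forward performs, written in plain PyTorch instead of the
# RbfKernelBank/SimilarityMatrix helpers from capreolus.reranker.common
# (whose internals are assumed here). The helper name _kernel_pool_sketch and
# the assumption that padded query rows of the similarity matrix are all-zero
# are illustrative only.
def _kernel_pool_sketch(simmat, mus, sigmas):
    sim = simmat.unsqueeze(1)  # (batch, 1, qlen, dlen)
    mu = torch.tensor(mus).view(1, -1, 1, 1)  # (1, K, 1, 1)
    sigma = torch.tensor(sigmas).view(1, -1, 1, 1)  # (1, K, 1, 1)
    kernels = torch.exp(-0.5 * (sim - mu) ** 2 / sigma ** 2)  # RBF responses, (batch, K, qlen, dlen)
    soft_tf = kernels.sum(dim=3)  # soft term frequency: sum over document terms
    mask = sim.expand_as(kernels).sum(dim=3) != 0.0  # non-padding query terms
    pooled = torch.where(mask, (soft_tf + 1e-6).log(), mask.float())  # log-scaled, zero where padded
    return pooled.sum(dim=2)  # (batch, K): one feature per kernel, fed to a combine layer
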
@Reranker.register
class KNRM(Reranker):
    """Chenyan Xiong, Zhuyun Dai, Jamie Callan, Zhiyuan Liu, and Russell Power. 2017. End-to-End Neural Ad-hoc Ranking with Kernel Pooling. In SIGIR'17."""

    module_name = "KNRM"
    config_spec = [
        ConfigOption("gradkernels", True, "backprop through mus and sigmas"),
        ConfigOption("scoretanh", False, "use a tanh on the prediction as in paper (True) or do not use a nonlinearity (False)"),
        ConfigOption("singlefc", True, "use single fully connected layer as in paper (True) or 2 fully connected layers (False)"),
        ConfigOption("finetune", False, "fine tune the embedding layer"),  # TODO check save when True
    ]
    def add_summary(self, summary_writer, niter):
        super(KNRM, self).add_summary(summary_writer, niter)
        if self.config["singlefc"]:
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            ax.matshow(self.model.combine[0].weight.data.cpu())
            summary_writer.add_figure("combine_steps weight", fig, niter)
    def build_model(self):
        if not hasattr(self, "model"):
            self.model = KNRM_class(self.extractor, self.config)
        return self.model
    def score(self, d):
        query_idf = d["query_idf"]
        query_sentence = d["query"]
        pos_sentence, neg_sentence = d["posdoc"], d["negdoc"]
        return [
            self.model(pos_sentence, query_sentence, query_idf).view(-1),
            self.model(neg_sentence, query_sentence, query_idf).view(-1),
        ]
    def test(self, d):
        query_idf = d["query_idf"]
        query_sentence = d["query"]
        pos_sentence = d["posdoc"]
        return self.model(pos_sentence, query_sentence, query_idf).view(-1)
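
# Sketch of the batch dict consumed by KNRM.score and KNRM.test. The exact
# shapes are an assumption (batches are built by the extractor), but the keys
# match the lookups above:
#   d["query"]:     (batch, qlen) query token ids
#   d["query_idf"]: (batch, qlen) per-term idf weights
#   d["posdoc"]:    (batch, dlen) token ids of the relevant document
#   d["negdoc"]:    (batch, dlen) token ids of the non-relevant document (training only)
# score returns [pos_scores, neg_scores] for a pairwise loss; test returns only
# the posdoc scores used to rerank at inference time.
#
# Toy demonstration of the kernel pooling sketch above, with a random
# similarity matrix standing in for SimilarityMatrix output:
if __name__ == "__main__":
    toy_simmat = torch.rand(2, 4, 16) * 2 - 1  # (batch=2, qlen=4, dlen=16), values in [-1, 1]
    mus = [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
    sigmas = [0.1] * 10 + [0.001]
    print(_kernel_pool_sketch(toy_simmat, mus, sigmas).shape)  # torch.Size([2, 11])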