Source code for capreolus.reranker.PACRR

import torch
from torch import nn
from torch.nn import functional as F

from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker

# TODO add shuffle, cascade, disambig?
from capreolus.reranker.common import SimilarityMatrix, create_emb_layer

[docs]class PACRR_class(nn.Module): # based on CedrPacrrRanker from # which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license def __init__(self, extractor, config): super(PACRR_class, self).__init__() p = config self.p = p self.extractor = extractor self.embedding_dim = extractor.embeddings.shape[1] self.embedding = create_emb_layer(extractor.embeddings, non_trainable=True) self.simmat = SimilarityMatrix(padding=extractor.pad) self.ngrams = nn.ModuleList() for ng in range(p["mingram"], p["maxgram"] + 1): self.ngrams.append(PACRRConvMax2dModule(ng, p["nfilters"], k=p["kmax"], channels=1)) qterm_size = len(self.ngrams) * p["kmax"] + (1 if p["idf"] else 0) self.linear1 = torch.nn.Linear(extractor.config["maxqlen"] * qterm_size, p["combine"]) self.linear2 = torch.nn.Linear(p["combine"], p["combine"]) self.linear3 = torch.nn.Linear(p["combine"], 1) if p["nonlinearity"] == "none": nonlinearity = torch.nn.Identity elif p["nonlinearity"] == "relu": nonlinearity = torch.nn.ReLU elif p["nonlinearity"] == "tanh": nonlinearity = torch.nn.Tanh self.combine = torch.nn.Sequential(self.linear1, nonlinearity(), self.linear2, nonlinearity(), self.linear3)
[docs] def forward(self, sentence, query_sentence, query_idf): doc = self.embedding(sentence) query = self.embedding(query_sentence) simmat = self.simmat(query, doc, query_sentence, sentence) scores = [ng(simmat) for ng in self.ngrams] if self.p["idf"]: scores.append( F.softmax(query_idf.reshape(query_idf.shape, 1).float(), dim=1).view(-1, self.extractor.config["maxqlen"], 1) ) scores =, dim=2) scores = scores.reshape(scores.shape[0], scores.shape[1] * scores.shape[2]) rel = self.combine(scores) return rel
[docs]class PACRRConvMax2dModule(torch.nn.Module): # based on PACRRConvMax2dModule from # which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license def __init__(self, shape, n_filters, k, channels): super().__init__() self.shape = shape if shape != 1: self.pad = torch.nn.ConstantPad2d((0, shape - 1, 0, shape - 1), 0) else: self.pad = None self.conv = torch.nn.Conv2d(channels, n_filters, shape) self.activation = torch.nn.ReLU() self.k = k self.shape = shape self.channels = channels
[docs] def forward(self, simmat): BATCH, CHANNELS, QLEN, DLEN = simmat.shape if self.pad: simmat = self.pad(simmat) conv = self.activation(self.conv(simmat)) top_filters, _ = conv.max(dim=1) top_toks, _ = top_filters.topk(self.k, dim=2) result = top_toks.reshape(BATCH, QLEN, self.k) return result
[docs]@Reranker.register class PACRR(Reranker): """Kai Hui, Andrew Yates, Klaus Berberich, and Gerard de Melo. 2017. PACRR: A Position-Aware Neural IR Model for Relevance Matching. EMNLP 2017. """
[docs] module_name = "PACRR"
[docs] config_spec = [ ConfigOption("mingram", 1, "minimum length of ngram used"), ConfigOption("maxgram", 3, "maximum length of ngram used"), ConfigOption("nfilters", 32, "number of filters in convolution layer"), ConfigOption("idf", True, "concatenate idf signals to combine relevance score from individual query terms"), ConfigOption("kmax", 2, "value of kmax pooling used"), ConfigOption("combine", 32, "size of combination layers"), ConfigOption("nonlinearity", "relu", "nonlinearity in combination layer: none, relu, or tanh"),
[docs] def build_model(self): if not hasattr(self, "model"): self.model = PACRR_class(self.extractor, self.config) return self.model
[docs] def score(self, d): query_idf = d["query_idf"] query_sentence = d["query"] pos_sentence, neg_sentence = d["posdoc"], d["negdoc"] return [ self.model(pos_sentence, query_sentence, query_idf).view(-1), self.model(neg_sentence, query_sentence, query_idf).view(-1),
[docs] def test(self, d): query_idf = d["query_idf"] query_sentence = d["query"] pos_sentence = d["posdoc"] return self.model(pos_sentence, query_sentence, query_idf).view(-1)