Source code for capreolus.reranker.TK

import math

import torch
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import StackedSimilarityMatrix, create_emb_layer
from capreolus.utils.loginit import get_logger

logger = get_logger(__name__)  # pylint: disable=invalid-name
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)
    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
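
# The example below is an illustrative sketch, not part of the original capreolus module:
# it shows how PositionalEncoding is applied to the (seq_len, batch, embed_dim) tensor
# layout that nn.TransformerEncoder expects. The function name and shapes are hypothetical.
def _example_positional_encoding():
    d_model = 300                        # embedding dimension (assumed)
    encoder = PositionalEncoding(d_model)
    x = torch.zeros(20, 2, d_model)      # (seq_len=20, batch=2, embed_dim)
    out = encoder(x)                     # sinusoidal offsets added per position, then dropout
    assert out.shape == (20, 2, d_model)
    return out
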
class TK_class(nn.Module):
    """
    Adapted from https://github.com/sebastian-hofstaetter/transformer-kernel-ranking/blob/master/matchmaker/models/tk.py
    TK is a neural IR model - a fusion between transformer contextualization & kernel-based scoring
    -> uses 1 transformer block to contextualize embeddings
    -> soft-histogram kernels to score interactions
    """

    def __init__(self, extractor, config):
        super(TK_class, self).__init__()
        self.embeddim = extractor.embeddings.shape[1]
        self.p = config
        self.mus = torch.tensor([-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0], dtype=torch.float)
        self.mu_matrix = self.get_mu_matrix(extractor)
        self.sigma = torch.tensor(0.1, requires_grad=False)
        dropout = 0
        non_trainable = not self.p["finetune"]
        self.embedding = create_emb_layer(extractor.embeddings, non_trainable=non_trainable)
        self.cosine_module = StackedSimilarityMatrix(padding=extractor.pad)
        self.position_encoder = PositionalEncoding(self.embeddim)
        self.mixer = nn.Parameter(torch.full([1, 1, 1], 0.9, dtype=torch.float32, requires_grad=True))
        encoder_layers = TransformerEncoderLayer(self.embeddim, config["numattheads"], config["ffdim"], dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, config["numlayers"])
        self.s_log_fcc = nn.Linear(len(self.mus), 1, bias=False)
        self.s_len_fcc = nn.Linear(len(self.mus), 1, bias=False)
        self.comb_fcc = nn.Linear(2, 1, bias=False)

        # init with small weights, otherwise the dense output is way too high for the tanh -> resulting in loss == 1 all the time
        torch.nn.init.uniform_(self.s_log_fcc.weight, -0.014, 0.014)  # inits taken from matchzoo
        torch.nn.init.uniform_(self.s_len_fcc.weight, -0.014, 0.014)  # inits taken from matchzoo
        torch.nn.init.uniform_(self.comb_fcc.weight, -0.014, 0.014)  # inits taken from matchzoo
    def get_mu_matrix(self, extractor):
        """
        Returns a matrix of mu values that can be directly subtracted from the cosine matrix.
        This is the matrix mu in equation 5 in the paper (https://arxiv.org/pdf/2002.01854.pdf)
        """
        qlen = extractor.config["maxqlen"]
        doclen = extractor.config["maxdoclen"]

        mu_matrix = torch.zeros(len(self.mus), qlen, doclen, requires_grad=False)
        for i, mu in enumerate(self.mus):
            mu_matrix[i] = torch.full((qlen, doclen), mu)

        return mu_matrix
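
    # Illustrative note (not from the original source): each slice mu_matrix[k] is a
    # (qlen, doclen) plane filled with mu_k, so forward() can subtract it from the expanded
    # cosine matrix in one broadcasted operation. Per kernel k, the paper's Gaussian kernel is
    #
    #     K_k(i, j) = exp(-(s_ij - mu_k)^2 / (2 * sigma^2))
    #
    # where s_ij is the cosine similarity of query term i and document term j. An equivalent
    # sketch using broadcasting instead of a precomputed matrix (hypothetical variable names):
    #
    #     mu = self.mus.view(-1, 1, 1)                     # (K, 1, 1)
    #     kernels = torch.exp(-(cosine - mu) ** 2 / (2 * self.sigma ** 2))
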
    def get_mask(self, embedding):
        """
        Gets a mask of shape (seq_len, seq_len). This is an additive mask, hence masked elements should be -inf
        """
        batch_size = embedding.shape[0]
        seq_len = embedding.shape[1]
        # Get a normal mask of shape (batch_size, seq_len). Entry would be 0 if a seq element should be masked
        mask = ((embedding != torch.zeros(self.embeddim).to(embedding.device)).to(dtype=embedding.dtype).sum(-1) != 0).to(
            dtype=embedding.dtype
        )
        # The square attention mask
        encoder_mask = torch.zeros(batch_size, seq_len, seq_len).to(embedding.device)
        # Set -inf on all rows corresponding to a pad token
        encoder_mask[mask == 0] = float("-inf")
        # Set -inf on all columns corresponding to a pad token (the tricky bit)
        col_mask = mask.reshape(batch_size, 1, seq_len).expand(batch_size, seq_len, seq_len)
        encoder_mask[col_mask == 0] = float("-inf")

        return torch.cat([encoder_mask] * self.p["numattheads"])
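
    # Illustrative note (not from the original source): for a single sequence of length 3
    # whose last position is padding (mask == [1, 1, 0]), the additive mask built above is
    #
    #     [[  0,    0,  -inf],
    #      [  0,    0,  -inf],
    #      [-inf, -inf, -inf]]
    #
    # i.e. padded positions can neither attend nor be attended to. The per-batch masks are
    # then tiled numattheads times along dim 0, matching the (batch_size * num_heads,
    # seq_len, seq_len) layout that nn.TransformerEncoder accepts for a 3-D attention mask.
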
    def get_embedding(self, toks):
        """
        Overrides KNRM_Class's get_embedding to return contextualized word embeddings
        """
        embedding = self.embedding(toks)
        # Transformer layers expect input in shape (L, N, E), where L is sequence len, N is batch, E is embed dims
        reshaped_embedding = embedding.permute(1, 0, 2)
        position_encoded_embedding = self.position_encoder(reshaped_embedding)
        # TODO: Mask should be additive
        mask = self.get_mask(embedding) if self.p["usemask"] else None
        contextual_embedding = self.transformer_encoder(position_encoded_embedding, mask).permute(1, 0, 2)

        if self.p["usemixer"]:
            return self.mixer * embedding + (1 - self.mixer) * contextual_embedding
        else:
            return self.p["alpha"] * embedding + (1 - self.p["alpha"]) * contextual_embedding
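
    # Illustrative note (not from the original source): the returned embedding is a convex
    # combination of the static and contextualized representations,
    #
    #     out = w * embedding + (1 - w) * contextual_embedding
    #
    # where w is either the learned scalar self.mixer (usemixer=True) or the fixed config
    # value alpha. With w close to 1 the model behaves like plain kernel matching over
    # static embeddings; with w close to 0 it relies mostly on the transformer output.
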
    def forward(self, doctoks, querytoks, query_idf):
        batches = doctoks.shape[0]
        qlen = querytoks.shape[1]
        doclen = doctoks.shape[1]
        doc = self.get_embedding(doctoks)
        device = doc.device
        query = self.get_embedding(querytoks)

        # cosine_matrix = self.cosine_module.forward(query, doc)
        cosine_matrix = self.cosine_module.forward(query, doc, querytoks, doctoks)
        # cosine_matrix = cosine_matrix.reshape(batches, 1, qlen, doclen)
        cosine_matrix = cosine_matrix.expand(batches, len(self.mus), qlen, doclen)
        kernel_matrix = torch.exp(-torch.pow(cosine_matrix - self.mu_matrix.to(device), 2)) / (2 * torch.pow(self.sigma, 2))
        condensed_kernel_matrix = kernel_matrix.sum(3)
        s_log_k = torch.log2(condensed_kernel_matrix).sum(2)
        s_len_k = condensed_kernel_matrix.sum(2) / doclen
        s_log = self.s_log_fcc(s_log_k)
        s_len = self.s_len_fcc(s_len_k)
        score = self.comb_fcc(torch.cat([s_log, s_len], dim=1))

        return score
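
# The sketch below is illustrative and not part of the original capreolus module: it walks
# through the kernel-pooling step of TK_class.forward() on a random cosine matrix with
# hypothetical shapes, using the standard Gaussian kernel exp(-(s - mu)^2 / (2 * sigma^2)).
# Note that forward() above applies the 2 * sigma^2 divisor outside the exponential, which
# differs from this textbook form.
def _example_kernel_pooling(batches=2, qlen=4, doclen=8):
    mus = torch.tensor([-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0])
    sigma = 0.1
    # Random stand-in for the query x document cosine-similarity matrix
    cosine = torch.rand(batches, 1, qlen, doclen) * 2 - 1
    cosine = cosine.expand(batches, len(mus), qlen, doclen)
    mu = mus.view(1, -1, 1, 1)                                   # broadcast over (batch, K, qlen, doclen)
    kernels = torch.exp(-((cosine - mu) ** 2) / (2 * sigma ** 2))
    per_query_term = kernels.sum(3)                              # soft match counts per query term, (batch, K, qlen)
    s_log = torch.log2(per_query_term.clamp(min=1e-10)).sum(2)   # log-normalized pooling, (batch, K)
    s_len = per_query_term.sum(2) / doclen                       # length-normalized pooling, (batch, K)
    return s_log, s_len
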
@Reranker.register
class TK(Reranker):
    """Sebastian Hofstätter, Markus Zlabinger, and Allan Hanbury. 2019. TU Wien @ TREC Deep Learning '19 -- Simple Contextualization for Re-ranking. In TREC '19."""
    module_name = "TK"
    dependencies = [
        Dependency(key="extractor", module="extractor", name="slowembedtext"),
        Dependency(key="trainer", module="trainer", name="pytorch"),
    ]
    config_spec = [
        ConfigOption("gradkernels", True, "backprop through mus and sigmas"),
        ConfigOption("scoretanh", False, "use a tanh on the prediction as in paper (True) or do not use a nonlinearity (False)"),
        ConfigOption("singlefc", True, "use single fully connected layer as in paper (True) or 2 fully connected layers (False)"),
        ConfigOption("projdim", 32),
        ConfigOption("ffdim", 100),
        ConfigOption("numlayers", 2),
        ConfigOption("numattheads", 10),
        ConfigOption("alpha", 0.5),
        ConfigOption("usemask", False),
        ConfigOption("usemixer", False),
        ConfigOption("finetune", False, "fine tune the embedding layer"),  # TODO check save when True
    ]
    def build_model(self):
        if not hasattr(self, "model"):
            self.model = TK_class(self.extractor, self.config)
        return self.model
    def score(self, d):
        query_idf = d["query_idf"]
        query_sentence = d["query"]
        pos_sentence, neg_sentence = d["posdoc"], d["negdoc"]
        return [
            self.model(pos_sentence, query_sentence, query_idf).view(-1),
            self.model(neg_sentence, query_sentence, query_idf).view(-1),
        ]
    def test(self, d):
        query_idf = d["query_idf"]
        query_sentence = d["query"]
        pos_sentence = d["posdoc"]
        return self.model(pos_sentence, query_sentence, query_idf).view(-1)
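
# Illustrative note (not from the original source): score() and test() expect a batch dict
# produced by the extractor. Based on the keys read above, it looks roughly like the
# following; shapes and dtypes are assumptions, with maxqlen/maxdoclen coming from the
# extractor config:
#
#     d = {
#         "query": ...,        # (batch, maxqlen) token-id tensor
#         "posdoc": ...,       # (batch, maxdoclen) token-id tensor
#         "negdoc": ...,       # (batch, maxdoclen) token-id tensor, training only
#         "query_idf": ...,    # (batch, maxqlen) idf weights (passed through but not used in TK_class.forward above)
#     }
#
# score(d) returns [pos_scores, neg_scores] for pairwise training, while test(d) returns
# only the positive-document scores for reranking at inference time.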