Source code for capreolus.reranker.common

import tensorflow as tf
import torch
from tensorflow.keras.layers import Layer
from tensorflow.python.keras.losses import CategoricalCrossentropy
from tensorflow_ranking.python.keras.losses import PairwiseHingeLoss

# margin-based pairwise ranking loss shared by pair_hinge_loss below
_hinge_loss = torch.nn.MarginRankingLoss(margin=1, reduction="mean")


class KerasPairModel(tf.keras.Model):
    def __init__(self, model, *args, **kwargs):
        super(KerasPairModel, self).__init__(*args, **kwargs)
        self.model = model

    def call(self, x, **kwargs):
        score = self.model.score(x, **kwargs)
        return score

    def predict_step(self, data):
        return self.model.predict_step(data)
class KerasTripletModel(tf.keras.Model):
    def __init__(self, model, *args, **kwargs):
        super(KerasTripletModel, self).__init__(*args, **kwargs)
        self.model = model

    def call(self, x, **kwargs):
        pos_score, neg_score = self.model.score_pair(x, **kwargs)
        stacked_score = tf.stack([pos_score, neg_score], axis=1)
        return stacked_score

    def predict_step(self, data):
        return self.model.predict_step(data)
class TFPairwiseHingeLoss(PairwiseHingeLoss):
    def call(self, y_true, y_pred):
        y_true = tf.reshape(y_true, [-1, 2])
        y_pred = tf.reshape(y_pred, [-1, 2])
        return super(TFPairwiseHingeLoss, self).call(y_true, y_pred)
class TFCategoricalCrossEntropyLoss(CategoricalCrossentropy):
    def call(self, ytrue, ypred):
        tf.debugging.assert_equal(tf.shape(ytrue), tf.shape(ypred))
        return super(TFCategoricalCrossEntropyLoss, self).call(ytrue, ypred)
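# Illustrative usage sketch (not part of the original capreolus module): wrapping a TF
# reranker in KerasTripletModel and training it with TFPairwiseHingeLoss. The
# `reranker_model` and `train_dataset` arguments are hypothetical placeholders; the
# reranker is assumed to expose score_pair()/predict_step(), and the dataset to yield
# (features, labels) batches compatible with it.
def _example_compile_triplet_model(reranker_model, train_dataset):
    keras_model = KerasTripletModel(reranker_model)
    keras_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        loss=TFPairwiseHingeLoss(),
    )
    keras_model.fit(train_dataset, epochs=1)
    return keras_model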
def pair_softmax_loss(pos_neg_scores, *args, **kwargs):
    scores = torch.stack(pos_neg_scores, dim=1)
    return torch.mean(1.0 - scores.softmax(dim=1)[:, 0])
def pair_hinge_loss(pos_neg_scores, *args, **kwargs):
    label = torch.ones_like(pos_neg_scores[0])
    return _hinge_loss(pos_neg_scores[0], pos_neg_scores[1], label)
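# Illustrative usage sketch (not part of the original capreolus module): the PyTorch
# pairwise losses above take a (positive_scores, negative_scores) pair of 1-D tensors with
# one entry per training pair. The score values below are made-up example numbers.
def _example_pair_losses():
    pos_scores = torch.tensor([2.0, 0.5, 1.2])
    neg_scores = torch.tensor([1.0, 1.5, 0.2])
    hinge = pair_hinge_loss((pos_scores, neg_scores))      # penalizes pairs where neg outranks pos by < margin
    softmax = pair_softmax_loss((pos_scores, neg_scores))  # 1 - softmax probability of the positive doc
    return hinge, softmax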
def similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding):
    batch_size, qlen, doclen = tf.shape(query_embed)[0], tf.shape(query_embed)[1], tf.shape(doc_embed)[1]
    q_denom = (
        tf.broadcast_to(tf.reshape(tf.norm(query_embed, axis=2), (batch_size, qlen, 1)), (batch_size, qlen, doclen)) + 1e-9
    )
    doc_denom = (
        tf.broadcast_to(tf.reshape(tf.norm(doc_embed, axis=2), (batch_size, 1, doclen)), (batch_size, qlen, doclen)) + 1e-9
    )

    # Why perm?
    # let query have shape (32, 8, 300)
    # let doc have shape (32, 800, 300)
    # Our similarity matrix should have the shape (32, 8, 800)
    # The perm is required so that the result of matmul will have this shape
    perm = tf.transpose(doc_embed, perm=[0, 2, 1])
    sim = tf.matmul(query_embed, perm) / (q_denom * doc_denom)

    # zero out similarities involving padding tokens in the query or doc
    nul = tf.zeros_like(sim)
    sim = tf.where(tf.broadcast_to(tf.reshape(query_tok, (batch_size, qlen, 1)), (batch_size, qlen, doclen)) == padding, nul, sim)
    sim = tf.where(tf.broadcast_to(tf.reshape(doc_tok, (batch_size, 1, doclen)), (batch_size, qlen, doclen)) == padding, nul, sim)

    # TODO: Add support for handling list inputs (eg: for CEDR). See the pytorch implementation of simmat
    return sim
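# Illustrative usage sketch (not part of the original capreolus module): calling
# similarity_matrix_tf with random embeddings. The shapes mirror the comment above (queries
# of length 8, documents of length 800, 300-dim embeddings); the token ids are made up, with
# a trailing id 0 per query standing in for padding.
def _example_similarity_matrix_tf():
    batch_size, qlen, doclen, emb_dim = 2, 8, 800, 300
    query_embed = tf.random.normal((batch_size, qlen, emb_dim))
    doc_embed = tf.random.normal((batch_size, doclen, emb_dim))
    query_tok = tf.concat(
        [tf.ones((batch_size, qlen - 1), dtype=tf.int32), tf.zeros((batch_size, 1), dtype=tf.int32)], axis=1
    )
    doc_tok = tf.ones((batch_size, doclen), dtype=tf.int32)
    sim = similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding=0)
    return sim  # shape (2, 8, 800); rows for the padded query position are all zero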
class SimilarityMatrix(torch.nn.Module):
    def __init__(self, embedding):
        super().__init__()
        self.embedding = embedding
        self.padding = 0

    def remove_padding(self, sim, query_tok, doc_tok, BAT, A, B):
        nul = torch.zeros_like(sim)
        sim = torch.where(query_tok.reshape(BAT, A, 1).expand(BAT, A, B) == self.padding, nul, sim)
        sim = torch.where(doc_tok.reshape(BAT, 1, B).expand(BAT, A, B) == self.padding, nul, sim)
        return sim

    def exact_match_matrix(self, query_tok, doc_tok, BAT, A, B):
        sim = (query_tok.reshape(BAT, A, 1).expand(BAT, A, B) == doc_tok.reshape(BAT, 1, B).expand(BAT, A, B)).float()
        sim = self.remove_padding(sim, query_tok, doc_tok, BAT, A, B)
        return sim

    def cosine_similarity_matrix(self, query_tok, doc_tok, BAT, A, B):
        a_emb, b_emb = self.embedding(query_tok), self.embedding(doc_tok)
        a_denom = a_emb.norm(p=2, dim=2).reshape(BAT, A, 1).expand(BAT, A, B) + 1e-9  # avoid 0div
        b_denom = b_emb.norm(p=2, dim=2).reshape(BAT, 1, B).expand(BAT, A, B) + 1e-9  # avoid 0div
        perm = b_emb.permute(0, 2, 1)
        sim = a_emb.bmm(perm) / (a_denom * b_denom)
        sim = self.remove_padding(sim, query_tok, doc_tok, BAT, A, B)
        return sim

    # query_tok and doc_tok should contain integers
    def forward(self, query_tok, doc_tok):
        BAT, A, B = query_tok.shape[0], query_tok.shape[1], doc_tok.shape[1]
        assert doc_tok.shape[0] == BAT

        # note: all OOV terms are given negative indices and 0 is padding
        # approach:
        # 1. we calculate an exact match matrix on OOV terms only, and set padding to 0
        # 2. we calculate a cosine sim matrix on in-vocab terms only, and set padding to 0
        # 3. we sum the two matrices
        exact_match = self.exact_match_matrix(query_tok.clamp(max=0), doc_tok.clamp(max=0), BAT, A, B)
        cos_matrix = self.cosine_similarity_matrix(query_tok.clamp(min=0), doc_tok.clamp(min=0), BAT, A, B)
        simmat = exact_match + cos_matrix

        return simmat
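# Illustrative usage sketch (not part of the original capreolus module): building a
# SimilarityMatrix from a hypothetical pretrained embedding matrix via create_emb_layer
# (defined later in this file). The token ids follow the convention documented in forward():
# 0 is padding, positive ids are in-vocabulary, negative ids are OOV. The `pretrained_weights`
# argument is a placeholder and is assumed to be a (vocab_size, emb_dim) array with
# vocab_size >= 8 so the example ids are valid.
def _example_similarity_matrix(pretrained_weights):
    embedding = create_emb_layer(pretrained_weights, non_trainable=True)
    simmat_module = SimilarityMatrix(embedding)
    query_tok = torch.tensor([[3, 7, -1, 0]])       # two in-vocab terms, one OOV term, one pad
    doc_tok = torch.tensor([[3, 5, 5, -1, 0, 0]])   # three in-vocab terms, one OOV term, two pads
    return simmat_module(query_tok, doc_tok)        # shape (1, 4, 6)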
class StackedSimilarityMatrix(torch.nn.Module):
    # based on SimmatModule from https://github.com/Georgetown-IR-Lab/cedr/blob/master/modeling_util.py
    # which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license
    def __init__(self, padding=0):
        super().__init__()
        self.padding = padding

    # query_embed and doc_embed can be a list (eg for CEDR)
    def forward(self, query_embed, doc_embed, query_tok, doc_tok):
        simmat = []

        assert type(query_embed) == type(doc_embed)
        if not isinstance(query_embed, list):
            query_embed, doc_embed = [query_embed], [doc_embed]

        # shapes are taken from the token tensors so the exact-match branch also works when an
        # entry in the embedding list is None
        BAT, A, B = query_tok.shape[0], query_tok.shape[1], doc_tok.shape[1]

        for a_emb, b_emb in zip(query_embed, doc_embed):
            if a_emb is None and b_emb is None:
                # exact match matrix
                sim = (query_tok.reshape(BAT, A, 1).expand(BAT, A, B) == doc_tok.reshape(BAT, 1, B).expand(BAT, A, B)).float()
            else:
                # cosine similarity matrix
                a_denom = a_emb.norm(p=2, dim=2).reshape(BAT, A, 1).expand(BAT, A, B) + 1e-9  # avoid 0div
                b_denom = b_emb.norm(p=2, dim=2).reshape(BAT, 1, B).expand(BAT, A, B) + 1e-9  # avoid 0div
                perm = b_emb.permute(0, 2, 1)
                sim = a_emb.bmm(perm) / (a_denom * b_denom)

            # set similarity values to 0 for <pad> tokens in query and doc (indicated by self.padding)
            nul = torch.zeros_like(sim)
            sim = torch.where(query_tok.reshape(BAT, A, 1).expand(BAT, A, B) == self.padding, nul, sim)
            sim = torch.where(doc_tok.reshape(BAT, 1, B).expand(BAT, A, B) == self.padding, nul, sim)

            simmat.append(sim)

        return torch.stack(simmat, dim=1)
class RbfKernel(torch.nn.Module):
    # based on KNRMRbfKernel from https://github.com/Georgetown-IR-Lab/cedr/blob/master/modeling_util.py
    # which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license
    def __init__(self, initial_mu, initial_sigma, requires_grad=True):
        super().__init__()
        self.mu = torch.nn.Parameter(torch.tensor(initial_mu), requires_grad=requires_grad)
        self.sigma = torch.nn.Parameter(torch.tensor(initial_sigma), requires_grad=requires_grad)

    def forward(self, data):
        adj = data - self.mu
        return torch.exp(-0.5 * adj * adj / self.sigma / self.sigma)
class RbfKernelBank(torch.nn.Module):
    # based on KNRMRbfKernelBank from https://github.com/Georgetown-IR-Lab/cedr/blob/master/modeling_util.py
    # which is copyright (c) 2019 Georgetown Information Retrieval Lab, MIT license
    def __init__(self, mus=None, sigmas=None, dim=1, requires_grad=True):
        super().__init__()
        self.dim = dim
        kernels = [RbfKernel(m, s, requires_grad=requires_grad) for m, s in zip(mus, sigmas)]
        self.kernels = torch.nn.ModuleList(kernels)

    def count(self):
        return len(self.kernels)

    def forward(self, data):
        return torch.stack([k(data) for k in self.kernels], dim=self.dim)
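# Illustrative usage sketch (not part of the original capreolus module): a KNRM-style kernel
# bank applied to a similarity matrix. The mu/sigma values and tensor shapes below are typical
# made-up choices, not values prescribed by this file.
def _example_rbf_kernel_bank():
    mus = [-0.8, -0.4, 0.0, 0.4, 0.8, 1.0]
    sigmas = [0.1] * 5 + [0.001]
    kernels = RbfKernelBank(mus, sigmas, dim=1, requires_grad=True)
    simmat = torch.rand(2, 8, 800) * 2 - 1  # fake (batch, qlen, doclen) cosine similarities in [-1, 1]
    return kernels(simmat)                  # shape (2, kernels.count(), 8, 800)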
class RbfKernelBankTF(Layer):
    def __init__(self, mus, sigmas, dim=1, requires_grad=True, **kwargs):
        super(RbfKernelBankTF, self).__init__(**kwargs)
        self.dim = dim
        self.kernel_list = [RbfKernelTF(m, s, requires_grad=requires_grad) for m, s in zip(mus, sigmas)]

    def count(self):
        return len(self.kernel_list)

    def call(self, data, **kwargs):
        return tf.stack([kernel(data) for kernel in self.kernel_list], axis=self.dim)
class RbfKernelTF(Layer):
    def __init__(self, initial_mu, initial_sigma, requires_grad=True, **kwargs):
        super(RbfKernelTF, self).__init__(**kwargs)
        self.mu = tf.Variable(initial_mu, trainable=requires_grad, name="mus", dtype=tf.float32)
        self.sigma = tf.Variable(initial_sigma, trainable=requires_grad, name="sigmas", dtype=tf.float32)

    def call(self, data, **kwargs):
        adj = data - self.mu
        return tf.exp(-0.5 * adj * adj / self.sigma / self.sigma)
def create_emb_layer(weights, non_trainable=True):
    layer = torch.nn.Embedding(*weights.shape)
    layer.load_state_dict({"weight": torch.tensor(weights)})
    layer.weight.requires_grad = not non_trainable
    return layer
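# Illustrative usage sketch (not part of the original capreolus module): turning a made-up
# pretrained embedding matrix into a frozen torch.nn.Embedding. The random weights and token
# ids are placeholders; row 0 plays the role of the padding vector here.
def _example_create_emb_layer():
    import numpy as np

    weights = np.random.rand(100, 300).astype("float32")  # hypothetical (vocab_size, emb_dim) matrix
    emb = create_emb_layer(weights, non_trainable=True)
    tokens = torch.tensor([[1, 2, 0]])                     # batch of token ids, 0 = padding
    return emb(tokens)                                     # shape (1, 3, 300)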