Source code for capreolus.reranker.HINT

import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
class GRUCell2d(nn.Module):
    """2D (spatial) GRU cell: updates the state at grid position (i, j) from the
    left (i, j-1), top (i-1, j), and diagonal (i-1, j-1) neighbor states."""

    def __init__(self, input_size, hidden_size, bias=True):
        super(GRUCell2d, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Produces 3 reset gates (rl, rt, rd) and 4 interpolation gates (zi, zl, zt, zd)
        self.Wrz = nn.Linear(3 * hidden_size + input_size, 7 * hidden_size, bias=bias)
        self.W = nn.Linear(input_size, hidden_size, bias=bias)
        self.U = nn.Linear(3 * hidden_size, hidden_size, bias=bias)

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def softmax_by_row(self, zi, zl, zt, zd):
        # Normalize the four interpolation gates against each other at every hidden unit.
        # (The original hardcoded the hidden size as 2; self.hidden_size is used here.)
        h = self.hidden_size
        zi, zl, zt, zd = zi.view(-1, 1, h), zl.view(-1, 1, h), zt.view(-1, 1, h), zd.view(-1, 1, h)  # each: (B, 1, hidden)
        ppp = torch.cat((zi, zl, zt, zd), dim=1)  # (B, 4, hidden)
        pt = F.softmax(ppp, dim=1)  # (B, 4, hidden)
        zi, zl, zt, zd = pt.unbind(dim=1)
        zi, zl, zt, zd = zi.view(-1, h), zl.view(-1, h), zt.view(-1, h), zd.view(-1, h)
        return zi, zl, zt, zd

    def forward(self, x, hidden_i1_j1, hidden_i1_j, hidden_i_j1):
        q = torch.cat([hidden_i1_j, hidden_i_j1, hidden_i1_j1, x], dim=-1)  # (B, 3*hidden + input)
        r_z = self.Wrz(q)
        rl, rt, rd, zi, zl, zt, zd = r_z.chunk(7, 1)  # each: (B, hidden)
        rl, rt, rd = torch.sigmoid(rl), torch.sigmoid(rt), torch.sigmoid(rd)
        zi, zl, zt, zd = self.softmax_by_row(zi, zl, zt, zd)
        r = torch.cat([rl, rt, rd], dim=-1)
        t11 = torch.cat([hidden_i1_j, hidden_i_j1, hidden_i1_j1], dim=-1)
        h1 = torch.tanh(self.W(x) + self.U(r * t11))
        h = (zl * hidden_i_j1) + (zt * hidden_i1_j) + (zd * hidden_i1_j1) + (zi * h1)
        return h
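
# A minimal usage sketch (not part of the original module): one GRUCell2d step
# on toy tensors. The cell reads the diagonal, top, and left neighbor states
# plus the interaction features at (i, j); all names and sizes below are illustrative.
def _demo_grucell2d_step():
    batch, input_size, hidden = 4, 3, 2
    cell = GRUCell2d(input_size, hidden)
    x = torch.randn(batch, input_size)  # interaction features at grid cell (i, j)
    h_diag = torch.zeros(batch, hidden)  # state at (i-1, j-1)
    h_top = torch.zeros(batch, hidden)  # state at (i-1, j)
    h_left = torch.zeros(batch, hidden)  # state at (i, j-1)
    h = cell(x, h_diag, h_top, h_left)
    assert h.shape == (batch, hidden)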
class GRUModel2d(nn.Module):
    """Scans a (B, T1, T2, H) grid with a GRUCell2d, left-to-right and
    top-to-bottom, and returns the bottom-right hidden state."""

    def __init__(self, input_dim, hidden_dim, bias=True):
        super(GRUModel2d, self).__init__()
        self.hidden_dim = hidden_dim
        self.gru_cell = GRUCell2d(input_dim, hidden_dim).to(device)

    def forward(self, x):
        B, T1, T2, H = x.size()
        # last_outs holds the previous row's states, plus a zero state for the left border
        last_outs = [torch.zeros(B, self.hidden_dim).to(device) for _ in range(T2 + 1)]
        for seq in range(T1):
            outs_row = [torch.zeros(B, self.hidden_dim).to(device)]
            for seq1 in range(1, T2 + 1):
                hn = last_outs[seq1 - 1]  # diagonal neighbor (i-1, j-1)
                hn_top = last_outs[seq1]  # top neighbor (i-1, j)
                hn_left = outs_row[seq1 - 1]  # left neighbor (i, j-1)
                hn1 = self.gru_cell(x[:, seq, seq1 - 1, :], hn, hn_top, hn_left)
                outs_row.append(hn1)
            last_outs = outs_row
        out = last_outs[-1]
        return out
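
# A minimal usage sketch (not part of the original module): GRUModel2d scans a
# toy (batch, rows, cols, features) interaction grid and returns the hidden
# state at the bottom-right corner; sizes here are illustrative.
def _demo_grumodel2d():
    batch, rows, cols, feats, hidden = 4, 5, 6, 3, 2
    model = GRUModel2d(feats, hidden)
    grid = torch.randn(batch, rows, cols, feats).to(device)
    out = model(grid)  # (batch, hidden) after the full top-left to bottom-right scan
    assert out.shape == (batch, hidden)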
class HiNT(nn.Module):
    def __init__(self, weights_matrix, p):
        super(HiNT, self).__init__()
        self.p = p
        self.passagelen = int(p["maxdoclen"] / 100)  # 100 is the passage window size
        Ws_dim = 1  # fixed to 1, since GRUModel2d is created with input size 2 * Ws_dim + 1 = 3
        embedding_dim = weights_matrix.shape[1]
        self.batch_size, self.lstm_hidden_dim = p["batch"], self.p["LSTMdim"]
        self.embedding = create_emb_layer(weights_matrix, non_trainable=True)
        self.Ws = nn.Linear(embedding_dim, Ws_dim)
        self.GRU2d1 = GRUModel2d(3, self.p["spatialGRU"]).to(device)
        self.GRU2d3 = GRUModel2d(3, self.p["spatialGRU"]).to(device)
        self.lstm = nn.LSTM(input_size=(4 * self.p["spatialGRU"]), hidden_size=self.lstm_hidden_dim, bidirectional=True)
        self.Wv = nn.Linear((4 * self.p["spatialGRU"]), self.lstm_hidden_dim, bias=True)
        self.fc = nn.Linear(self.lstm_hidden_dim * self.p["kmax"], 1)
        self.hidden = self.init_hidden()
    def init_hidden(self):
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        return (
            Variable(torch.zeros(2, self.batch_size, self.lstm_hidden_dim).to(device)),
            Variable(torch.zeros(2, self.batch_size, self.lstm_hidden_dim).to(device)),
        )
    def matrix_inv(self, A):
        # Despite the name, this is not a matrix inverse: it reverses the query and
        # window axes so the second 2D GRU can scan from the bottom-right corner.
        # Equivalent to the original element-by-element copy loop over (maxqlen, 100).
        return torch.flip(A, dims=[1, 2])
    def forward(self, sentence, query_sentence, M_XOR, M_cos, masks):
        """
        M_XOR or M_cos: (B, Q, D)
        masks: (B, Q, D); 0 on non-pad positions, 1 on pad positions
        """
        sentence, query_sentence = sentence.to(device), query_sentence.to(device)
        x, query_x = self.embedding(sentence), self.embedding(query_sentence)
        X_i = self.Ws(query_x).view(self.batch_size, -1)
        Y_j = self.Ws(x).view(self.batch_size, -1)
        # (The original hardcoded 8 = 4 * spatialGRU for the default spatialGRU=2.)
        total_passage_level = torch.randn(self.passagelen, self.batch_size, 4 * self.p["spatialGRU"]).type(torch.FloatTensor).to(device)
        M_cos_passage = torch.randn(self.passagelen, self.batch_size, self.p["maxqlen"], 100).type(torch.FloatTensor).to(device)
        M_XOR_passage = torch.randn(self.passagelen, self.batch_size, self.p["maxqlen"], 100).type(torch.FloatTensor).to(device)
        Y_j_passage = torch.randn(self.passagelen, self.batch_size, 100).type(torch.FloatTensor).to(device)
        X_i_passage = torch.randn(self.passagelen, self.batch_size, self.p["maxqlen"]).type(torch.FloatTensor).to(device)
        mask_passage = torch.randn(self.passagelen, self.batch_size, self.p["maxqlen"], 100).type(torch.FloatTensor).to(device)

        # Split the document-level matrices into non-overlapping 100-term passages
        for number_window in range(self.passagelen):
            mask_passage[number_window] = masks[:, :, (number_window * 100) : ((number_window + 1) * 100)]  # (P, BAT, Q, 100)
            M_cos_passage[number_window] = M_cos[:, :, (number_window * 100) : ((number_window + 1) * 100)]  # (P, BAT, Q, 100)
            M_XOR_passage[number_window] = M_XOR[:, :, (number_window * 100) : ((number_window + 1) * 100)]  # (P, BAT, Q, 100)
            Y_j_passage[number_window] = Y_j[:, (number_window * 100) : ((number_window + 1) * 100)]  # (P, BAT, 100)
            X_i_passage[number_window] = X_i  # (P, BAT, Q)

        # Build the 3-channel interaction tensors S = [query weight, doc weight, match signal]
        S_cos = torch.randn(self.passagelen, self.batch_size, self.p["maxqlen"], 100, 3).type(torch.FloatTensor).to(device)  # (P, BAT, Q, 100, 3)
        S_xor = torch.randn(self.passagelen, self.batch_size, self.p["maxqlen"], 100, 3).type(torch.FloatTensor).to(device)
        S_cos[:, :, :, :, 0] = X_i_passage.reshape(self.passagelen, self.batch_size, self.p["maxqlen"], 1).expand(
            self.passagelen, self.batch_size, self.p["maxqlen"], 100
        )
        S_cos[:, :, :, :, 1] = Y_j_passage.reshape(self.passagelen, self.batch_size, 1, 100).expand(
            self.passagelen, self.batch_size, self.p["maxqlen"], 100
        )
        S_cos[:, :, :, :, 2] = M_cos_passage
        S_xor[:, :, :, :, 0] = X_i_passage.reshape(self.passagelen, self.batch_size, self.p["maxqlen"], 1).expand(
            self.passagelen, self.batch_size, self.p["maxqlen"], 100
        )
        S_xor[:, :, :, :, 1] = Y_j_passage.reshape(self.passagelen, self.batch_size, 1, 100).expand(
            self.passagelen, self.batch_size, self.p["maxqlen"], 100
        )
        S_xor[:, :, :, :, 2] = M_XOR_passage

        # Zero out pad positions (the mask is applied to all three channels)
        S_cos = S_cos * (1 - mask_passage).unsqueeze(-1)
        S_xor = S_xor * (1 - mask_passage).unsqueeze(-1)

        # S_xor, S_cos: (P, B, Q, 100, 3) -> (P*B, Q, 100, 3)
        S_xor1 = S_xor.view(self.passagelen * self.batch_size, self.p["maxqlen"], 100, 3)
        S_cos1 = S_cos.view(self.passagelen * self.batch_size, self.p["maxqlen"], 100, 3)
        S_xor1_cos1 = torch.cat([S_xor1, S_cos1], dim=0)  # (2*P*B, Q, 100, 3)
        H_xor_cos = self.GRU2d1(S_xor1_cos1)  # (2*P*B, spatialGRU)
        H_xor = H_xor_cos[: (self.passagelen * self.batch_size)]  # (P*B, spatialGRU)
        H_cos = H_xor_cos[(self.passagelen * self.batch_size) :]  # (P*B, spatialGRU)
        e = torch.cat([H_xor, H_cos], dim=-1)  # (P*B, 2*spatialGRU)

        # Repeat the scan from the opposite corner on the flipped tensors
        S_xor_inv = self.matrix_inv(S_xor1)  # (P*B, Q, 100, 3)
        S_cos_inv = self.matrix_inv(S_cos1)
        S_xor_cos_inv = torch.cat([S_xor_inv, S_cos_inv], dim=0)
        H_xor_cos_inv = self.GRU2d3(S_xor_cos_inv)
        H_xor_inv = H_xor_cos_inv[: (self.passagelen * self.batch_size)]
        H_cos_inv = H_xor_cos_inv[(self.passagelen * self.batch_size) :]
        e_inv = torch.cat([H_xor_inv, H_cos_inv], dim=-1)  # (P*B, 2*spatialGRU)

        passage_level_e = torch.cat([e, e_inv], dim=-1)  # (P*B, 4*spatialGRU)
        for number_window in range(self.passagelen):  # -> (P, B, 4*spatialGRU)
            total_passage_level[number_window] = passage_level_e[
                (self.batch_size * number_window) : (self.batch_size * (number_window + 1))
            ]

        lstm_out, self.hidden = self.lstm(total_passage_level, self.hidden)
        # lstm_out: (P, B, 2 * LSTMdim), where P is the timestep dimension
        lstm_out_forward = lstm_out[:, :, : self.lstm_hidden_dim]  # (P, B, LSTMdim); the original hardcoded 6 here
        lstm_out_backward = lstm_out[:, :, self.lstm_hidden_dim :]  # (P, B, LSTMdim)
        lstm_out = lstm_out_forward + lstm_out_backward
        vt = torch.tanh(self.Wv(total_passage_level))  # (P, B, 4*spatialGRU) -> (P, B, LSTMdim)
        # evidence = torch.cat((vt, lstm_out_forward, lstm_out_backward), 0)  # (3P, B, LSTMdim)
        evidence = torch.cat((vt, lstm_out), 0)  # (2P, B, LSTMdim)
        evidence1 = torch.transpose(evidence, 0, 2)  # (LSTMdim, B, 2P)
        anss, _ = torch.topk(evidence1, self.p["kmax"], largest=True, sorted=True, dim=2)  # (LSTMdim, B, kmax)
        ans = torch.transpose(anss, 0, 1)  # (B, LSTMdim, kmax)
        ans1 = ans.contiguous().view(self.batch_size, -1)  # (B, LSTMdim * kmax)
        score = self.fc(ans1)
        return score.view(-1)
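
# An isolated sketch (not part of the original module) of the evidence pooling
# at the end of HiNT.forward: independent passage signals (vt) and the
# LSTM-accumulated signals are stacked along the passage axis, then the kmax
# largest values per hidden unit are kept with torch.topk; sizes are illustrative.
def _demo_kmax_evidence():
    passages, batch, hidden, kmax = 5, 4, 6, 3
    vt = torch.randn(passages, batch, hidden)
    lstm_out = torch.randn(passages, batch, hidden)
    evidence = torch.cat((vt, lstm_out), 0)  # (2P, B, hidden)
    evidence = torch.transpose(evidence, 0, 2)  # (hidden, B, 2P)
    top, _ = torch.topk(evidence, kmax, largest=True, sorted=True, dim=2)  # (hidden, B, kmax)
    flat = torch.transpose(top, 0, 1).contiguous().view(batch, -1)  # (B, hidden * kmax)
    assert flat.shape == (batch, hidden * kmax)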
class HiNT_main(nn.Module):
    def __init__(self, extractor, config):
        super(HiNT_main, self).__init__()
        self.HiNT1 = HiNT(extractor.embeddings, config).to(device)
        self.batch_size = config["batch"]
    def init_hidden(self):
        return self.HiNT1.init_hidden()
    def forward(self, query_sentence, query_idf, pos_sentence, neg_sentence):
        self.HiNT1.hidden = self.HiNT1.init_hidden()
        x = self.HiNT1.embedding(pos_sentence)
        query_x = self.HiNT1.embedding(query_sentence)
        BAT = query_sentence.shape[0]
        A = query_sentence.shape[1]
        B = pos_sentence.shape[1]
        x1 = x.norm(dim=2)[:, :, None] + 1e-7
        query_x1 = query_x.norm(dim=2)[:, :, None] + 1e-7
        x_norm = x / x1  # (BAT, B, H)
        query_x_norm = query_x / query_x1  # (BAT, A, H)
        M_cos_pos = torch.matmul(query_x_norm, torch.transpose(x_norm, 1, 2))  # (BAT, A, B)
        nul = torch.zeros_like(M_cos_pos)
        one = torch.ones_like(M_cos_pos)
        XOR_matrix_pos = torch.where(
            query_sentence.reshape(BAT, A, 1).expand(BAT, A, B) == pos_sentence.reshape(BAT, 1, B).expand(BAT, A, B), one, nul
        )
        # build padding masks for both matrices
        query_masks, sentence_masks = (query_sentence == 0), (pos_sentence == 0)  # (B, Q), (B, D)
        pos_masks = query_masks[:, :, None] * sentence_masks[:, None, :]  # (B, Q, D)
        # XOR_matrix_pos = torch.where(pos_masks, nul, XOR_matrix_pos)
        # M_cos_pos = torch.where(pos_masks, nul, M_cos_pos)
        # the mask is applied inside HiNT1
        pos_scores = self.HiNT1(pos_sentence, query_sentence, XOR_matrix_pos, M_cos_pos, pos_masks)

        self.HiNT1.hidden = self.HiNT1.init_hidden()
        x = self.HiNT1.embedding(neg_sentence)
        query_x = self.HiNT1.embedding(query_sentence)
        BAT = query_sentence.shape[0]
        A = query_sentence.shape[1]
        B = neg_sentence.shape[1]
        x1 = x.norm(dim=2)[:, :, None] + 1e-7
        query_x1 = query_x.norm(dim=2)[:, :, None] + 1e-7
        x_norm = x / x1
        query_x_norm = query_x / query_x1
        M_cos_neg = torch.matmul(query_x_norm, torch.transpose(x_norm, 1, 2))
        nul = torch.zeros_like(M_cos_neg)  # the original sized these from M_cos_pos
        one = torch.ones_like(M_cos_neg)
        XOR_matrix_neg = torch.where(
            query_sentence.reshape(BAT, A, 1).expand(BAT, A, B) == neg_sentence.reshape(BAT, 1, B).expand(BAT, A, B), one, nul
        )
        # build padding masks for both matrices
        query_masks, sentence_masks = (query_sentence == 0), (neg_sentence == 0)  # (B, Q), (B, D)
        neg_masks = query_masks[:, :, None] * sentence_masks[:, None, :]  # (B, Q, D)
        # XOR_matrix_neg = torch.where(neg_masks, nul, XOR_matrix_neg)
        # M_cos_neg = torch.where(neg_masks, nul, M_cos_neg)
        # the mask is applied inside HiNT1
        neg_scores = self.HiNT1(neg_sentence, query_sentence, XOR_matrix_neg, M_cos_neg, neg_masks)
        return [pos_scores, neg_scores]
    def test_forward(self, query_sentence, query_idf, pos_sentence):
        self.HiNT1.hidden = self.HiNT1.init_hidden()
        x = self.HiNT1.embedding(pos_sentence)
        query_x = self.HiNT1.embedding(query_sentence)
        BAT = query_sentence.shape[0]
        A = query_sentence.shape[1]
        B = pos_sentence.shape[1]
        x1 = x.norm(dim=2)[:, :, None] + 1e-7
        query_x1 = query_x.norm(dim=2)[:, :, None] + 1e-7
        x_norm = x / x1
        query_x_norm = query_x / query_x1
        M_cos_pos = torch.matmul(query_x_norm, torch.transpose(x_norm, 1, 2))
        nul = torch.zeros_like(M_cos_pos)
        one = torch.ones_like(M_cos_pos)
        XOR_matrix_pos = torch.where(
            query_sentence.reshape(BAT, A, 1).expand(BAT, A, B) == pos_sentence.reshape(BAT, 1, B).expand(BAT, A, B), one, nul
        )
        query_masks, sentence_masks = (query_sentence == 0), (pos_sentence == 0)  # (B, Q), (B, D)
        pos_masks = query_masks[:, :, None] * sentence_masks[:, None, :]  # (B, Q, D)
        # XOR_matrix_pos = torch.where(pos_masks, nul, XOR_matrix_pos)
        # M_cos_pos = torch.where(pos_masks, nul, M_cos_pos)
        # the mask is applied inside HiNT1
        pos_scores = self.HiNT1(pos_sentence, query_sentence, XOR_matrix_pos, M_cos_pos, pos_masks)
        return pos_scores
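
# An isolated sketch (not part of the original module) of the two match
# matrices built in HiNT_main.forward: a cosine matrix over length-normalized
# embeddings and an exact-match (XOR) indicator over token ids. The embedding
# table and all sizes here are stand-ins.
def _demo_match_matrices():
    batch, qlen, dlen, dim = 2, 3, 5, 8
    query_ids = torch.randint(1, 10, (batch, qlen))
    doc_ids = torch.randint(1, 10, (batch, dlen))
    emb = nn.Embedding(10, dim)
    q, d = emb(query_ids), emb(doc_ids)
    q = q / (q.norm(dim=2, keepdim=True) + 1e-7)
    d = d / (d.norm(dim=2, keepdim=True) + 1e-7)
    m_cos = torch.matmul(q, d.transpose(1, 2))  # (batch, qlen, dlen)
    m_xor = (query_ids[:, :, None] == doc_ids[:, None, :]).float()  # 1 where tokens match exactly
    assert m_cos.shape == m_xor.shape == (batch, qlen, dlen)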
@Reranker.register
class HINT(Reranker):
    """Yixing Fan, Jiafeng Guo, Yanyan Lan, Jun Xu, Chengxiang Zhai, and Xueqi Cheng. 2018.
    Modeling Diverse Relevance Patterns in Ad-hoc Retrieval. In SIGIR'18."""
    module_name = "HINT"
    dependencies = [
        Dependency(key="extractor", module="extractor", name="slowembedtext"),
        Dependency(key="trainer", module="trainer", name="pytorch"),
    ]
    config_spec = [ConfigOption("spatialGRU", 2), ConfigOption("LSTMdim", 6), ConfigOption("kmax", 10)]
    def score(self, d):
        query_idf = d["query_idf"]
        query_sentence = d["query"]
        pos_sentence, neg_sentence = d["posdoc"], d["negdoc"]
        return self.model(query_sentence, query_idf, pos_sentence, neg_sentence)

    def test(self, d):
        # (The original also defined an earlier test(query_sentence, query_idf, pos_sentence, ...)
        # with the same name, which this definition shadowed; only this one is kept.)
        query_idf = d["query_idf"]
        query_sentence = d["query"]
        pos_sentence = d["posdoc"]
        return self.model.test_forward(query_sentence, query_idf, pos_sentence)
    def zero_grad(self, *args, **kwargs):
        self.model.zero_grad(*args, **kwargs)
    def build_model(self):
        if not hasattr(self, "model"):
            config = dict(self.config)
            config.update(self.extractor.config)
            config["batch"] = self.trainer.config["batch"]
            self.model = HiNT_main(self.extractor, config)
        return self.model
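
# A sanity sketch (not part of the original module) of how the default config
# options determine layer sizes: each directional 2D GRU scan yields a
# spatialGRU-dim state, the four scan outputs per passage concatenate to
# 4 * spatialGRU features, and the final linear layer sees LSTMdim * kmax values.
def _demo_default_dims():
    spatialGRU, LSTMdim, kmax = 2, 6, 10  # the config_spec defaults above
    passage_features = 4 * spatialGRU  # [H_xor, H_cos, H_xor_inv, H_cos_inv]
    fc_in = LSTMdim * kmax
    assert passage_features == 8 and fc_in == 60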