Source code for capreolus.reranker.DeepTileBar

import copy

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.utils.loginit import get_logger

logger = get_logger(__name__)  # pylint: disable=invalid-name
class DeepTileBar_nn(nn.Module):
    def __init__(self, p, batch_size, number_filter, lstm_hidden_dim, linear_hidden_dim1, linear_hidden_dim2):
        super(DeepTileBar_nn, self).__init__()
        self.p = p
        # each tile carries up to three feature channels; the extractor's tfchannel option drops one
        self.tilechannels = 3
        if not self.p["tfchannel"]:
            self.tilechannels -= 1
        self.batch_size = batch_size
        self.number_filter = number_filter
        self.lstm_hidden_dim = lstm_hidden_dim
        # conv{k} slides a (maxqlen x k) window over the query-by-passage tile matrix, k = 1..10
        self.conv1 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 1), stride=1)
        self.conv2 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 2), stride=1)
        self.conv3 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 3), stride=1)
        self.conv4 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 4), stride=1)
        self.conv5 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 5), stride=1)
        self.conv6 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 6), stride=1)
        self.conv7 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 7), stride=1)
        self.conv8 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 8), stride=1)
        self.conv9 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 9), stride=1)
        self.conv10 = nn.Conv2d(self.tilechannels, number_filter, (p["maxqlen"], 10), stride=1)
        # one LSTM per filter width; input_size is hard-coded to 3, so number_filter must be 3
        self.lstm1 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm2 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm3 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm4 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm5 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm6 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm7 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm8 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm9 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        self.lstm10 = nn.LSTM(input_size=3, hidden_size=lstm_hidden_dim)
        # three-layer MLP mapping the concatenated final LSTM outputs to a single relevance score
        self.W1 = nn.Linear(10 * lstm_hidden_dim, linear_hidden_dim1, bias=True)
        self.W2 = nn.Linear(linear_hidden_dim1, linear_hidden_dim2, bias=True)
        self.W3 = nn.Linear(linear_hidden_dim2, 1, bias=True)
        [
            self.hidden1,
            self.hidden2,
            self.hidden3,
            self.hidden4,
            self.hidden5,
            self.hidden6,
            self.hidden7,
            self.hidden8,
            self.hidden9,
            self.hidden10,
        ] = self.init_hidden()

    def init_hidden(self):
        # returns ten (h, c) pairs of zeros, one per LSTM: first is the hidden state h, second the cell state c
        # (torch.autograd.Variable is a legacy no-op wrapper in modern PyTorch)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        l = []
        for j in range(10):
            l.append(
                (
                    Variable(torch.zeros(1, self.batch_size, self.lstm_hidden_dim).to(device)),
                    Variable(torch.zeros(1, self.batch_size, self.lstm_hidden_dim).to(device)),
                )
            )
        return l

    def reset_hidden(self):
        [
            self.hidden1,
            self.hidden2,
            self.hidden3,
            self.hidden4,
            self.hidden5,
            self.hidden6,
            self.hidden7,
            self.hidden8,
            self.hidden9,
            self.hidden10,
        ] = self.init_hidden()

    def forward(self, tile_matrix1):
        # reshape to (batch, channels, maxqlen, passagelen) for the convolutions
        tile_matrix2 = torch.transpose(
            torch.transpose(tile_matrix1.view(self.batch_size, self.p["maxqlen"], self.p["passagelen"], -1), 1, 3), 2, 3
        )
        # each x{k} has shape (passagelen - k + 1, batch, number_filter), i.e. an LSTM input sequence
        x1 = torch.transpose(torch.transpose(self.conv1(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x2 = torch.transpose(torch.transpose(self.conv2(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x3 = torch.transpose(torch.transpose(self.conv3(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x4 = torch.transpose(torch.transpose(self.conv4(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x5 = torch.transpose(torch.transpose(self.conv5(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x6 = torch.transpose(torch.transpose(self.conv6(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x7 = torch.transpose(torch.transpose(self.conv7(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x8 = torch.transpose(torch.transpose(self.conv8(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x9 = torch.transpose(torch.transpose(self.conv9(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2)
        x10 = torch.transpose(
            torch.transpose(self.conv10(tile_matrix2).view(self.batch_size, self.number_filter, -1), 0, 2), 1, 2
        )
        lstm_out1, self.hidden1 = self.lstm1(x1, self.hidden1)
        lstm_out2, self.hidden2 = self.lstm2(x2, self.hidden2)
        lstm_out3, self.hidden3 = self.lstm3(x3, self.hidden3)
        lstm_out4, self.hidden4 = self.lstm4(x4, self.hidden4)
        lstm_out5, self.hidden5 = self.lstm5(x5, self.hidden5)
        lstm_out6, self.hidden6 = self.lstm6(x6, self.hidden6)
        lstm_out7, self.hidden7 = self.lstm7(x7, self.hidden7)
        lstm_out8, self.hidden8 = self.lstm8(x8, self.hidden8)
        lstm_out9, self.hidden9 = self.lstm9(x9, self.hidden9)
        lstm_out10, self.hidden10 = self.lstm10(x10, self.hidden10)
        # concatenate the last output of every LSTM: (batch, 10 * lstm_hidden_dim)
        input_x = torch.cat(
            [
                lstm_out1[-1],
                lstm_out2[-1],
                lstm_out3[-1],
                lstm_out4[-1],
                lstm_out5[-1],
                lstm_out6[-1],
                lstm_out7[-1],
                lstm_out8[-1],
                lstm_out9[-1],
                lstm_out10[-1],
            ],
            1,
        )
        input_x1 = F.relu(self.W1(input_x))
        input_x2 = F.relu(self.W2(input_x1))
        input_x3 = self.W3(input_x2)
        return input_x3.view(-1)
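
# A minimal shape sketch (not part of the original module): running
# DeepTileBar_nn on random input. The hyperparameters mirror the defaults in
# config_spec below; "tfchannel" and "maxqlen" come from the deeptiles
# extractor, and the values chosen here are illustrative assumptions.
#
#     p = {"tfchannel": True, "maxqlen": 4, "passagelen": 30}
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     net = DeepTileBar_nn(p, batch_size=2, number_filter=3, lstm_hidden_dim=3,
#                          linear_hidden_dim1=32, linear_hidden_dim2=16).to(device)
#     tiles = torch.rand(2, p["maxqlen"], p["passagelen"], 3).to(device)  # 3 tile channels
#     scores = net(tiles)  # shape (2,): one relevance score per query-document pair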

class DeepTileBar_class(nn.Module):
    def __init__(self, extractor, config):
        super(DeepTileBar_class, self).__init__()
        batch_size = config["batch"]
        number_filter = config["numberfilter"]
        lstm_hidden_dim = config["lstmhiddendim"]
        linear_hidden_dim1 = config["linearhiddendim1"]
        linear_hidden_dim2 = config["linearhiddendim2"]
        # merge in the extractor config, which supplies maxqlen and tfchannel
        config = dict(config)
        config.update(dict(extractor.config))
        self.DeepTileBar1 = DeepTileBar_nn(
            config, batch_size, number_filter, lstm_hidden_dim, linear_hidden_dim1, linear_hidden_dim2
        )

    def init_hidden(self):
        return self.DeepTileBar1.init_hidden()

    def reset_hidden(self):
        self.DeepTileBar1.reset_hidden()

    def forward(self, pos_tile_matrix, neg_tile_matrix):
        # pairwise training: score the positive and the negative document, resetting the LSTM state for each
        self.reset_hidden()
        pos_tag_scores = self.DeepTileBar1(pos_tile_matrix)
        self.reset_hidden()
        neg_tag_scores = self.DeepTileBar1(neg_tile_matrix)
        return [pos_tag_scores, neg_tag_scores]

    def test_forward(self, pos_tile_matrix):
        self.reset_hidden()
        pos_tag_scores = self.DeepTileBar1(pos_tile_matrix)
        return pos_tag_scores
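
# Hedged sketch (not part of the original module): a pairwise forward pass
# through DeepTileBar_class, as Reranker.score does below. StubExtractor is
# hypothetical; only its config mapping is used by the constructor.
#
#     class StubExtractor:
#         config = {"tfchannel": True, "maxqlen": 4, "passagelen": 30}
#
#     cfg = {"batch": 2, "numberfilter": 3, "lstmhiddendim": 3,
#            "linearhiddendim1": 32, "linearhiddendim2": 16}
#     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#     model = DeepTileBar_class(StubExtractor(), cfg).to(device)
#     pos = torch.rand(2, 4, 30, 3).to(device)
#     neg = torch.rand(2, 4, 30, 3).to(device)
#     pos_scores, neg_scores = model(pos, neg)  # each has shape (2,)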

@Reranker.register
class DeepTileBar(Reranker):
    """Zhiwen Tang and Grace Hui Yang. 2019. DeepTileBars: Visualizing Term Distribution for Neural Information Retrieval. In AAAI'19."""

    module_name = "DeepTileBar"
    dependencies = [
        Dependency(key="extractor", module="extractor", name="deeptiles"),
        Dependency(key="trainer", module="trainer", name="pytorch"),
    ]
    config_spec = [
        ConfigOption("passagelen", 30),
        ConfigOption("numberfilter", 3),
        ConfigOption("lstmhiddendim", 3),
        ConfigOption("linearhiddendim1", 32),
        ConfigOption("linearhiddendim2", 16),
    ]

    def build_model(self):
        # build lazily, taking the batch size from the pytorch trainer's config
        if not hasattr(self, "model"):
            config = copy.copy(dict(self.config))
            config["batch"] = self.trainer.config["batch"]
            self.model = DeepTileBar_class(self.extractor, config)
        return self.model

    def score(self, d):
        # stack the per-query tile matrices into one batch each for posdoc and negdoc
        pos_tile_matrix = torch.cat([d["posdoc"][i] for i in range(len(d["qid"]))])
        neg_tile_matrix = torch.cat([d["negdoc"][i] for i in range(len(d["qid"]))])
        return self.model(pos_tile_matrix, neg_tile_matrix)

    def test(self, d):
        qids = d["qid"]
        pos_sentence = d["posdoc"]
        pos_tile_matrix = torch.cat([pos_sentence[i] for i in range(len(qids))])
        return self.model.test_forward(pos_tile_matrix)

    def zero_grad(self, *args, **kwargs):
        self.model.zero_grad(*args, **kwargs)
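
# Hedged sketch (not part of the original module): the batch dict format that
# score() and test() expect. In the real pipeline the deeptiles extractor and
# pytorch trainer assemble this dict; the tensor shapes below are assumptions
# consistent with DeepTileBar_nn.forward above, and `reranker` stands for a
# built DeepTileBar instance.
#
#     d = {
#         "qid": ["q1", "q2"],
#         "posdoc": [torch.rand(1, 4, 30, 3) for _ in range(2)],
#         "negdoc": [torch.rand(1, 4, 30, 3) for _ in range(2)],
#     }
#     pos_scores, neg_scores = reranker.score(d)  # pairwise training scores
#     test_scores = reranker.test(d)              # inference uses posdoc only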