Source code for capreolus.reranker.DRMMTKS

import torch
import torch.nn as nn
import torch.nn.functional as F

from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer
from capreolus.utils.loginit import get_logger

[docs]logger = get_logger(__name__) # pylint: disable=invalid-name
[docs]class DRMMTKS_class(nn.Module): def __init__(self, extractor, config): super(DRMMTKS_class, self).__init__() self.topk = config["topk"] self.gate_type = config["gateType"] self.embedding = create_emb_layer(extractor.embeddings, non_trainable=config["freezeemb"]) self.ffw = nn.Sequential(nn.Linear(self.topk, 1), nn.Tanh()) gate_inp_dim = 1 if self.gate_type == "IDF" else self.embedding.weight.size(-1) self.gates = nn.Linear(gate_inp_dim, 1, bias=False) # for idf scalar self.output_layer = nn.Linear(1, 1) # initialize FC and gate weight in the same way as MatchZoo nn.init.uniform_(self.ffw[0].weight, -0.1, 0.1) nn.init.uniform_(self.gates.weight, -0.01, 0.01) def _sim(self, queries, documents): """ Args: queries: (B, Tq, H) documents: (B, Td, H) Returns: (B, Tq, 6) """ q_norm = torch.sqrt((queries * queries).sum(dim=-1) + 1e-7)[:, :, None] + 1e-7 d_norm = torch.sqrt((documents * documents).sum(dim=-1) + 1e-7)[:, None, :] + 1e-7 sim_matrix = torch.bmm(queries, documents.transpose(2, 1)) # (B, Q, D) sim_matrix = sim_matrix / q_norm sim_matrix = sim_matrix / d_norm return sim_matrix def _term_gate(self, queries, query_idf, q_masks): """ : param queries: (B, Tq, H) : param query_idf: (B, Tq) : param q_masks: (B, Tq) """ atten_mask = (1 - q_masks) * -1e7 if self.gate_type == "IDF": gate_prob = self.gates(query_idf[:, :, None]).squeeze() + atten_mask # (B, 1) elif self.gate_type == "TV": gate_prob = self.gates(queries).squeeze() + atten_mask # (B, Tq) else: raise ValueError("Invalid value for histType: histType should be 'CH', 'NH', or 'LCH'") gate_prob = F.softmax(gate_prob, dim=1) # (B, Tq) return gate_prob
[docs] def forward(self, doc, query, query_idf): query_sent_mask = (query != 0).float() # (B, Q) sent_pad_pos = (doc == 0).float() # (B, D) query_idf = query_idf.float() query, doc = self.embedding(query), self.embedding(doc) cos_mat = self._sim(query, doc) # (B, Q, D) cos_mat = cos_mat - 1e5 * (sent_pad_pos[:, None, :]) topk, _ = torch.topk(cos_mat, k=self.topk, dim=-1) # (B, Q, k) ffw_vec = self.ffw(topk).squeeze() # (B, Q) w = self._term_gate(query, query_idf, query_sent_mask) # (B, Q) doc = (w * ffw_vec).sum(dim=-1, keepdim=True) # (B, 1) score = self.output_layer(doc) return score
[docs]dtype = torch.FloatTensor
[docs]@Reranker.register class DRMMTKS(Reranker): """Jiafeng Guo, Yixing Fan, Qingyao Ai, and W. Bruce Croft. 2016. A Deep Relevance Matching Model for Ad-hoc Retrieval. In CIKM'16.""" # reference: https://github.com/NTMC-Community/MatchZoo-py/blob/master/matchzoo/models/drmmtks.py
[docs] module_name = "DRMMTKS"
[docs] config_spec = [ ConfigOption("topk", 10, "number of bins in matching histogram"), ConfigOption("gateType", "IDF", "term gate type: TV or IDF"), ConfigOption("freezeemb", True, "term gate type: TV or IDF"),
]
[docs] def build_model(self): if not hasattr(self, "model"): self.model = DRMMTKS_class(self.extractor, self.config) return self.model
[docs] def score(self, d): query_idf = d["query_idf"] query_sentence = d["query"] pos_sentence, neg_sentence = d["posdoc"], d["negdoc"] return [ self.model(pos_sentence, query_sentence, query_idf).view(-1), self.model(neg_sentence, query_sentence, query_idf).view(-1),
]
[docs] def test(self, d): query_idf = d["query_idf"] query_sentence = d["query"] pos_sentence = d["posdoc"] return self.model(pos_sentence, query_sentence, query_idf).view(-1)