Source code for capreolus.reranker.CDSSM

import torch
from torch import nn

from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.reranker.common import create_emb_layer


class CDSSM_class(nn.Module):
    def __init__(self, extractor, config):
        super(CDSSM_class, self).__init__()
        self.p = config

        # frozen word embeddings provided by the extractor
        self.embedding = create_emb_layer(extractor.embeddings, non_trainable=True)
        # convolution applied to each concatenated word window, followed by ReLU and dropout
        self.conv = nn.Sequential(
            nn.Conv1d(1, config["nfilter"], config["nkernel"]), nn.ReLU(), nn.Dropout(config["dropoutrate"])
        )
        # note: this feed-forward layer is defined but never used in forward()
        self.ffw = nn.Linear(config["nkernel"], config["nhiddens"])
        self.output_layer = nn.Sequential(nn.Sigmoid())
    def forward(self, sentence, query):
        # query = query.to(device)  # (B, Q)
        # sentence = sentence.to(device)  # (B, D)
        device = sentence.device
        query = self.embedding(query)  # (B, Q, H)
        sentence = self.embedding(sentence)  # (B, D, H)

        # pad query and document so their lengths are divisible by the window size W
        B, Q, H = query.size()
        W = self.p["windowsize"]

        # pad query
        npad = W - query.size(1) % W
        pads = torch.zeros(B, npad, H).to(device)
        # assert query.size(1) % W == 0
        # assert sentence.size(1) % W == 0
        query = torch.cat([query, pads], dim=1)  # (B, K_q*W, H)

        # pad document
        npad = W - sentence.size(1) % W
        pads = torch.zeros(B, npad, H).to(device)
        sentence = torch.cat([sentence, pads], dim=1)  # (B, K_d*W, H)

        # concatenate each window of W word embeddings into a single vector
        query = query.reshape(B, -1, W * H)  # (B, K_q, W*H)
        sentence = sentence.reshape(B, -1, W * H)  # (B, K_d, W*H)

        # major part of CDSSM: convolve each window separately, then concatenate
        # the per-window outputs; L = W*H - nkernel + 1
        query = torch.cat([self.conv(query[:, i : i + 1, :]) for i in range(query.size(1))], dim=1)  # (B, K_q*nfilter, L)
        sentence = torch.cat(  # (B, K_d*nfilter, L)
            [self.conv(sentence[:, i : i + 1, :]) for i in range(sentence.size(1))], dim=1
        )

        # max pooling over the window dimension
        query, _ = torch.max(query, dim=1)  # (B, L)
        sentence, _ = torch.max(sentence, dim=1)  # (B, L)

        # cosine similarity between the pooled query and document representations
        query_norm = query.norm(dim=-1)[:, None] + 1e-7
        sentence_norm = sentence.norm(dim=-1)[:, None] + 1e-7
        query = query / query_norm
        sentence = sentence / sentence_norm
        cos_x = (query * sentence).sum(dim=-1, keepdim=True)
        score = self.output_layer(cos_x)
        return score
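# --- Usage sketch (added for illustration; not part of the upstream source) ---
# A minimal, hedged example of running CDSSM_class on dummy token ids.
# `_FakeExtractor` is hypothetical: it only assumes the extractor exposes an
# `embeddings` matrix of shape (vocab_size, emb_dim) that create_emb_layer can wrap.
def _cdssm_class_demo():
    import numpy as np

    class _FakeExtractor:
        embeddings = np.random.rand(100, 8).astype(np.float32)  # (vocab_size, emb_dim)

    config = {"nfilter": 1, "nkernel": 3, "nhiddens": 30, "windowsize": 3, "dropoutrate": 0}
    model = CDSSM_class(_FakeExtractor(), config)
    query = torch.randint(0, 100, (2, 4))  # (B, Q) token ids
    doc = torch.randint(0, 100, (2, 10))  # (B, D) token ids
    return model(doc, query)  # (B, 1): sigmoid of the query-document cosine similarity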
dtype = torch.FloatTensor
@Reranker.register
class CDSSM(Reranker):
    """Yelong Shen, Xiaodong He, Jianfeng Gao, Li Deng, and Grégoire Mesnil. 2014. A Latent Semantic Model with Convolutional-Pooling Structure for Information Retrieval. In CIKM'14."""
    module_name = "CDSSM"
    dependencies = [
        Dependency(key="extractor", module="extractor", name="slowembedtext"),
        Dependency(key="trainer", module="trainer", name="pytorch"),
    ]
    config_spec = [
        ConfigOption("nkernel", 3, "kernel dimension in conv"),
        ConfigOption("nfilter", 1, "number of filters in conv"),
        ConfigOption("nhiddens", 30, "hidden layer dimension for ffw layer"),
        ConfigOption("windowsize", 3, "number of query/document words to concatenate before conv"),
        ConfigOption("dropoutrate", 0, "dropout rate for conv"),
    ]
    def build_model(self):
        if not hasattr(self, "model"):
            self.model = CDSSM_class(self.extractor, self.config)
        return self.model
    def score(self, d):
        query_sentence = d["query"]
        pos_sentence, neg_sentence = d["posdoc"], d["negdoc"]
        return [self.model(pos_sentence, query_sentence).view(-1), self.model(neg_sentence, query_sentence).view(-1)]
    def test(self, data):
        query_sentence, pos_sentence = data["query"], data["posdoc"]
        return self.model(pos_sentence, query_sentence).view(-1)
    def zero_grad(self, *args, **kwargs):
        self.model.zero_grad(*args, **kwargs)
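# --- Usage sketch (added for illustration; not part of the upstream source) ---
# In practice batches come from the slowembedtext extractor; the exact batch layout is
# assumed here, but from the code above "query", "posdoc", and "negdoc" must be integer
# token-id tensors that CDSSM_class can embed directly. Token-id ranges are hypothetical.
def _cdssm_reranker_demo(reranker):
    reranker.build_model()  # cached on the reranker after the first call
    batch = {
        "query": torch.randint(0, 100, (2, 4)),
        "posdoc": torch.randint(0, 100, (2, 10)),
        "negdoc": torch.randint(0, 100, (2, 10)),
    }
    pos_scores, neg_scores = reranker.score(batch)  # pairwise training: one score per document
    test_scores = reranker.test(batch)  # inference: scores for the "posdoc" tensors only
    return pos_scores, neg_scores, test_scores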