Source code for capreolus.reranker.birch

import contextlib
import numpy as np
import torch
from torch import nn
from transformers import BertForNextSentencePrediction

from capreolus import ConfigOption, Dependency
from capreolus.reranker import Reranker
from capreolus.utils.loginit import get_logger

logger = get_logger(__name__)

# official weights converted with:
# def convert(name):
#     from transformers import BertTokenizer, BertForNextSentencePrediction, TFBertForNextSentencePrediction
#     tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
#     state = torch.load(f"/GW/NeuralIR/nobackup/birch-emnlp_bert4ir_v2/models/saved.{name}_1", map_location="cpu")
#     model = BertForNextSentencePrediction.from_pretrained("bert-large-uncased")
#     model.load_state_dict(state["model"].state_dict())
#     output = f"/GW/NeuralIR/nobackup/birch-emnlp_bert4ir_v2/models/export/birch-bert-large-{name}"
#     os.makedirs(output, exist_ok=True)
#     model.save_pretrained(output)
#     tokenizer.save_pretrained(output)
#
#     # tf2 support
#     tf_model = TFBertForNextSentencePrediction.from_pretrained(output, from_pt=True)
#     tf_model.save_pretrained(output)
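
# A hedged sketch (not part of the original module): Birch_Class below loads the exported
# checkpoints from the Hugging Face Hub under names following the pattern
# f"Capreolus/birch-bert-large-{pretrained}", with pretrained one of "mb", "msmarco_mb", or "car_mb".
# The helper name and standalone usage here are illustrative assumptions only.
def load_exported_birch_bert(pretrained="msmarco_mb"):
    from transformers import BertTokenizer

    # the exported checkpoints reuse the stock bert-large-uncased vocabulary (see the conversion snippet above)
    tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
    model = BertForNextSentencePrediction.from_pretrained(f"Capreolus/birch-bert-large-{pretrained}")
    return tokenizer, model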


class Birch_Class(nn.Module):
    def __init__(self, extractor, config):
        super().__init__()
        self.config = config

        if config["hidden"] == 0:
            self.combine = nn.Linear(config["topk"], 1, bias=False)
            with torch.no_grad():
                self.combine.weight = nn.Parameter(torch.ones_like(self.combine.weight) / config["topk"])
        else:
            assert config["hidden"] > 0
            self.combine = nn.Sequential(
                nn.Linear(config["topk"], config["hidden"]), nn.ReLU(), nn.Linear(config["hidden"], 1)
            )

        # original model file (requires apex):
        # state = torch.load("/GW/NeuralIR/nobackup/birch-emnlp_bert4ir_v2/models/saved.msmarco_mb_1", map_location="cpu")
        # self.bert = state["model"]

        # saved.msmarco_mb_1 weights exported from the official apex model:
        # self.bert = BertForNextSentencePrediction.from_pretrained("bert-large-uncased")
        # self.bert.load_state_dict(torch.load("/GW/NeuralIR/nobackup/birch-emnlp_bert4ir_v2/models/converted"))  # converted_weights.msmarco_mb

        # kevin's base model:
        # self.bert = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
        # saved_bert = torch.load("/GW/NeuralIR/nobackup/birch/models/saved.tmp_1")["model"]
        # self.bert.load_state_dict(saved_bert.state_dict())
        # also /GW/NeuralIR/nobackup/birch-emnlp_bert4ir_v2/models/export/birch-bert-base-kevin

        self.bert = BertForNextSentencePrediction.from_pretrained(f"Capreolus/birch-bert-large-{config['pretrained']}")

        if not config["finetune"]:
            self.bert.requires_grad = False
            self.bert_context = torch.no_grad
        else:
            self.bert_context = contextlib.nullcontext

    def forward(self, doc, seg, mask):
        batch = doc.shape[0]

        with self.bert_context():
            bi_scores = [self.score_passages(doc[bi], seg[bi], mask[bi], batch) for bi in range(batch)]
            scores = torch.stack(bi_scores)

        assert scores.shape == (batch, self.config["extractor"]["numpassages"], 2)
        scores = scores[:, :, 1]  # take second output

        topk, _ = torch.topk(scores, dim=1, k=self.config["topk"])
        doc_score = self.combine(topk)

        return doc_score

    def score_passages(self, doc, seg, mask, batch):
        needed_passages = doc.shape[0]
        maxlen = doc.shape[-1]

        # find instances that contain a document (segment B)
        # for unmasked tokens in seg B, seg+mask=2
        # there are always two SEPs in segment B, so the document is not empty if there are >= 3 tokens where seg+mask=2
        valid = ((seg + mask) == 2).sum(dim=1) > 2
        if not any(valid):
            valid[0] = True
        doc, seg, mask = doc[valid], seg[valid], mask[valid]

        out = []
        batches = np.ceil(doc.shape[0] / batch).astype(int)
        for bi in range(batches):
            start = bi * batch
            stop = (bi + 1) * batch
            sb_doc, sb_seg, sb_mask = doc[start:stop], seg[start:stop], mask[start:stop]

            # find first non-padding token and shorten batch to this length
            idx = (sb_seg + sb_mask).argmax(dim=1).max()
            sb_doc = sb_doc[:, : idx + 1]
            sb_seg = sb_seg[:, : idx + 1]
            sb_mask = sb_mask[:, : idx + 1]

            # for idx in reversed(range(maxlen)):
            #     if any(sb_mask[:, idx]):
            #         sb_doc = sb_doc[:, : idx + 1]
            #         sb_seg = sb_seg[:, : idx + 1]
            #         sb_mask = sb_mask[:, : idx + 1]
            #         break

            sb_scores = self.bert(input_ids=sb_doc, token_type_ids=sb_seg, attention_mask=sb_mask)
            sb_scores = sb_scores[0]  # for new bert output
            out.append(sb_scores)

        real_out = torch.cat(out, dim=0)
        found_passages = real_out.shape[0]
        if found_passages < needed_passages:
            pad_out = torch.min(real_out, dim=0)[0].repeat(needed_passages - found_passages, 1)
            return torch.cat((real_out, pad_out), dim=0)
        else:
            return real_out
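

# A minimal sketch (not part of the original module) of the aggregation step in
# Birch_Class.forward above: per-passage relevance scores are reduced to a single document
# score by taking the top-k scores and passing them through the "combine" layer, which is
# initialized as a uniform 1/k weighted sum when hidden == 0. The function name and the
# dummy scores are illustrative assumptions.
def _toy_topk_aggregation(topk_k=3):
    combine = nn.Linear(topk_k, 1, bias=False)
    with torch.no_grad():
        combine.weight = nn.Parameter(torch.ones_like(combine.weight) / topk_k)

    # fake per-passage scores for a batch of two documents with five passages each
    passage_scores = torch.tensor([[0.1, 0.9, 0.4, 0.7, 0.2], [0.3, 0.3, 0.8, 0.1, 0.6]])
    topk, _ = torch.topk(passage_scores, dim=1, k=topk_k)
    return combine(topk)  # shape (2, 1); initially the mean of each document's top-k scores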


@Reranker.register
class Birch(Reranker):
[docs] module_name = "birch"

    config_spec = [
        ConfigOption("topk", 3, "top k scores to use"),
        ConfigOption("hidden", 0, "size of hidden layer or 0 to take the weighted sum of the topk"),
        ConfigOption("finetune", False, "fine-tune the BERT model"),
        ConfigOption("pretrained", "msmarco_mb", "pretrained Birch model to load: mb, msmarco_mb, or car_mb"),
    ]
    dependencies = [
        Dependency(
            key="extractor",
            module="extractor",
            name="bertpassage",
            default_config_overrides={"tokenizer": {"pretrained": "bert-large-uncased"}},
        ),
        Dependency(key="trainer", module="trainer", name="pytorch"),
    ]

    def build_model(self):
        self.model = Birch_Class(self.extractor, self.config)
        return self.model

    def score(self, d):
        return [
            self.model(d["pos_bert_input"], d["pos_seg"], d["pos_mask"]).view(-1),
            self.model(d["neg_bert_input"], d["neg_seg"], d["neg_mask"]).view(-1),
        ]

    def test(self, d):
        return self.model(d["pos_bert_input"], d["pos_seg"], d["pos_mask"]).view(-1)
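

# A minimal usage sketch (not part of the original module). After build_model(), score() and
# test() expect batches from the bertpassage extractor with (batch, numpassages, maxlen)-shaped
# tensors under the keys used above; the shapes and helper name here are illustrative assumptions.
def _toy_birch_batch(batch=2, numpassages=4, maxlen=32):
    ids = torch.randint(0, 30000, (batch, numpassages, maxlen))
    seg = torch.zeros_like(ids)
    seg[:, :, maxlen // 2 :] = 1  # pretend the second half of each passage input is segment B
    mask = torch.ones_like(ids)
    return {
        "pos_bert_input": ids,
        "pos_seg": seg,
        "pos_mask": mask,
        "neg_bert_input": ids.clone(),
        "neg_seg": seg.clone(),
        "neg_mask": mask.clone(),
    }
    # e.g. reranker.test(_toy_birch_batch()) would then return one score per document, shape (batch,)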