# Source code for capreolus.reranker.base

import os
import pickle

from profane import ConfigOption, Dependency, ModuleBase

class Reranker(ModuleBase):
    """Base class for capreolus rerankers.

    Subclasses are expected to set ``self.model`` before the methods below are
    used. It must provide ``named_parameters()``, ``state_dict()`` and
    ``load_state_dict()``; presumably it is a ``torch.nn.Module`` -- confirm in
    subclasses.
    """

    module_type = "reranker"
    # Declared module dependencies; ``name`` presumably selects the default
    # implementation (embedtext extractor, pytorch trainer) -- see profane's
    # Dependency for the exact semantics.
    dependencies = [
        Dependency(key="extractor", module="extractor", name="embedtext"),
        Dependency(key="trainer", module="trainer", name="pytorch"),
    ]

    @staticmethod
    def _is_saved_key(key):
        """Return True if the state_dict entry *key* is persisted by save/load_weights."""
        # Keys containing "embedding.weight" or "_nosave_" are never written and
        # never required on load (presumably embeddings are rebuilt elsewhere).
        return "embedding.weight" not in key and "_nosave_" not in key

    @staticmethod
    def _optimizer_path(weights_fn):
        """Return the path of the optimizer-state file that accompanies *weights_fn*."""
        # os.fspath accepts both str and os.PathLike (e.g. pathlib.Path);
        # ``weights_fn.as_posix()`` only worked for pathlib paths.
        return os.fspath(weights_fn) + ".optimizer"

    def add_summary(self, summary_writer, niter):
        """Write reranker-specific visualizations/data to *summary_writer*.

        Logs one histogram per named model parameter, tagged with the
        parameter's name.

        Args:
            summary_writer: object exposing ``add_histogram(tag, values, step)``
                (assumed to be a TensorBoard ``SummaryWriter`` -- confirm in trainer).
            niter: iteration/step number the histograms are recorded under.
        """
        for name, weight in self.model.named_parameters():
            # Detach from autograd and copy to CPU before handing the values to the writer.
            summary_writer.add_histogram(name, weight.detach().cpu(), niter)
            # Gradient histograms could be logged the same way via ``weight.grad`` (disabled).

    def save_weights(self, weights_fn, optimizer):
        """Pickle the model weights to *weights_fn* and the optimizer state beside it.

        Entries rejected by ``_is_saved_key`` (embedding weights, ``_nosave_``
        keys) are omitted. The optimizer's ``state_dict()`` is written to
        ``<weights_fn>.optimizer``. Missing parent directories are created.

        Args:
            weights_fn: destination path (str or os.PathLike).
            optimizer: object with a ``state_dict()`` method (presumably a torch optimizer).
        """
        dirname = os.path.dirname(weights_fn)
        # An empty dirname means "current directory"; os.makedirs("") would raise.
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        weights = {k: v for k, v in self.model.state_dict().items() if self._is_saved_key(k)}
        with open(weights_fn, "wb") as outf:
            pickle.dump(weights, outf, protocol=pickle.HIGHEST_PROTOCOL)

        with open(self._optimizer_path(weights_fn), "wb") as outf:
            pickle.dump(optimizer.state_dict(), outf, protocol=pickle.HIGHEST_PROTOCOL)

    def load_weights(self, weights_fn, optimizer):
        """Restore model weights from *weights_fn* and optimizer state from ``<weights_fn>.optimizer``.

        Every current state_dict key accepted by ``_is_saved_key`` must be
        present in the file. Excluded keys (e.g. embedding weights) keep their
        current values because loading is non-strict.

        Args:
            weights_fn: path written by ``save_weights`` (str or os.PathLike).
            optimizer: object with ``load_state_dict()`` (presumably a torch optimizer).

        Raises:
            RuntimeError: if the file lacks any key the current model expects.
        """
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only load weight files from trusted sources.
        with open(weights_fn, "rb") as f:
            weights = pickle.load(f)

        cur_keys = {k for k in self.model.state_dict() if self._is_saved_key(k)}
        missing = cur_keys - set(weights)
        if missing:
            raise RuntimeError("loading state_dict with keys that do not match current model: %s" % missing)

        # strict=False: excluded keys are absent from the file by design.
        self.model.load_state_dict(weights, strict=False)

        with open(self._optimizer_path(weights_fn), "rb") as f:
            optimizer.load_state_dict(pickle.load(f))