Source code for capreolus.trainer

import os
import json

import numpy as np
from capreolus import Dependency, ModuleBase, get_logger

[docs]logger = get_logger(__name__) # pylint: disable=invalid-name
[docs]class Trainer(ModuleBase): """Base class for Trainer modules. The purpose of a Trainer is to train a :class:`~capreolus.reranker.Reranker` module and use it to make predictions. Capreolus provides two trainers: :class:`~capreolus.trainer.pytorch.PytorchTrainer` and :class:`~capreolus.trainer.tensorflow.TensorFlowTrainer` Modules should provide: - a ``train`` method that trains a reranker on training and dev (validation) data - a ``predict`` method that uses a reranker to make predictions on data """
[docs] module_type = "trainer"
[docs] requires_random_seed = True
[docs] dependencies = [Dependency(key="benchmark", module="benchmark", name=None)]
[docs] def load_loss_file(fn): """Loads loss history from fn Args: fn (Path): path to a loss.txt file Returns: a list of losses ordered by iterations """ loss = [] with"rt") as f: for lineidx, line in enumerate(f): line = line.strip() if not line: continue iteridx, iterloss = line.rstrip().split() if int(iteridx) != lineidx: raise IOError(f"malformed loss file {fn} ... did two processes write to it?") loss.append(float(iterloss)) return loss
[docs] def load_metric(fn): with"rt") as f: return json.load(f)
[docs] def load_best_metric(fn, metric): return Trainer.load_metric(fn).get(metric, -np.inf)
[docs] def write_to_loss_file(fn, losses): fn.write_text("\n".join(f"{idx} {loss}" for idx, loss in enumerate(losses)))
[docs] def write_to_metric_file(fn, metrics): assert isinstance(metrics, dict) json.dump(metrics, open(fn, "wt"))
[docs] def exhaust_used_train_data(train_data_generator, n_batch_to_exhaust): for i, batch in enumerate(train_data_generator): if (i + 1) == n_batch_to_exhaust: break
[docs] def n_batch_per_iter(self): return (self.config["itersize"] // self.config["batch"]) or 1
[docs] def get_paths_for_early_stopping(train_output_path, dev_output_path): dev_best_weight_fn = train_output_path / "" weights_output_path = train_output_path / "weights" info_output_path = train_output_path / "info" os.makedirs(dev_output_path, exist_ok=True) os.makedirs(weights_output_path, exist_ok=True) os.makedirs(info_output_path, exist_ok=True) loss_fn = info_output_path / "loss.txt" metrics_fn = dev_output_path / "metrics.json" return dev_best_weight_fn, weights_output_path, info_output_path, loss_fn, metrics_fn
[docs] def change_lr(self, step, lr): """ Apply warm up or decay depending on the current epoch """ return lr * self.lr_multiplier(step)
[docs] def lr_multiplier(self, step): warmup_steps = self.config["warmupiters"] * self.n_batch_per_iter if warmup_steps and step <= warmup_steps: return min((step + 1) / warmup_steps, 1) elif self.config["decaytype"] == "exponential": decay_steps = self.config["decayiters"] * self.n_batch_per_iter return self.config["decay"] ** ((step - warmup_steps) / decay_steps) elif self.config["decaytype"] == "linear": epoch = (step - warmup_steps) / self.n_batch_per_iter return 1 / (1 + self.config["decay"] * epoch) # todo: support endlr return 1
from profane import import_all_modules from .pytorch import PytorchTrainer from .tensorflow import TensorflowTrainer import_all_modules(__file__, __package__)