capreolus.trainer

Submodules

Package Contents

Classes

Trainer

Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer.

Attributes

logger

capreolus.trainer.logger

class capreolus.trainer.Trainer(config=None, provide=None, share_dependency_objects=False, build=True)

Bases: capreolus.ModuleBase

Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer.

Modules should provide:
  • a train method that trains a reranker on training and dev (validation) data

  • a predict method that uses a reranker to make predictions on data
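To make the train/predict contract concrete, here is a minimal, framework-free sketch. It is not Capreolus code: `SketchTrainer`, `GradientTrainer`, `ToyReranker`, the tuple-based data format, and the method signatures are all hypothetical. The toy reranker scores a document as one learned weight times the number of query terms it shares with the document, and the trainer fits that weight by gradient descent on squared error.

```python
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

# Hypothetical data format: (qid, docid, query text, document text, relevance label)
Pair = Tuple[str, str, str, str, float]


def overlap(query: str, doc: str) -> float:
    """Number of distinct query terms that also appear in the document."""
    return float(len(set(query.split()) & set(doc.split())))


class ToyReranker:
    """Hypothetical one-parameter reranker: score = weight * overlap."""

    def __init__(self) -> None:
        self.weight = 0.0

    def score(self, query: str, doc: str) -> float:
        return self.weight * overlap(query, doc)


class SketchTrainer(ABC):
    """Hypothetical stand-in for the Trainer contract: train(), then predict()."""

    @abstractmethod
    def train(self, reranker: ToyReranker, train_data: List[Pair], dev_data: List[Pair]) -> List[float]:
        """Fit the reranker; return the training loss recorded after each iteration."""

    @abstractmethod
    def predict(self, reranker: ToyReranker, data: List[Pair]) -> Dict[str, Dict[str, float]]:
        """Return run-style scores as {qid: {docid: score}}."""


class GradientTrainer(SketchTrainer):
    def __init__(self, niters: int = 20, lr: float = 0.1) -> None:
        self.niters = niters
        self.lr = lr

    def train(self, reranker, train_data, dev_data):
        losses = []
        for _ in range(self.niters):
            # Gradient of the mean squared error between scores and labels w.r.t. the weight.
            grad = 0.0
            for _, _, q, d, label in train_data:
                x = overlap(q, d)
                grad += 2 * (reranker.weight * x - label) * x
            reranker.weight -= self.lr * grad / len(train_data)
            loss = sum((reranker.score(q, d) - y) ** 2 for _, _, q, d, y in train_data) / len(train_data)
            losses.append(loss)
        # dev_data is unused here; a real trainer would validate on it (e.g. for early stopping).
        return losses

    def predict(self, reranker, data):
        run: Dict[str, Dict[str, float]] = {}
        for qid, docid, q, d, _ in data:
            run.setdefault(qid, {})[docid] = reranker.score(q, d)
        return run
```

Separating the two methods lets the same fitted reranker be reused to score any split (dev, test, or a fold) without retraining.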

module_type = trainer
requires_random_seed = True
dependencies
static load_loss_file(fn)

Loads loss history from fn

Parameters

fn (Path) – path to a loss.txt file

Returns

a list of losses ordered by iteration
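The on-disk layout of loss.txt is not documented here. The sketch below assumes, purely for illustration, one plain-text loss value per line in iteration order; `write_losses` and `read_losses` are hypothetical helpers, not the library's `write_to_loss_file` or `load_loss_file`.

```python
from pathlib import Path
from typing import List


def write_losses(fn: Path, losses: List[float]) -> None:
    """Hypothetical writer: one loss value per line, in iteration order."""
    fn.write_text("".join(f"{loss}\n" for loss in losses))


def read_losses(fn: Path) -> List[float]:
    """Hypothetical reader: returns losses ordered by iteration, skipping blank lines."""
    return [float(line) for line in fn.read_text().splitlines() if line.strip()]
```

Under this assumed format, the length of the returned list equals the number of iterations whose loss was recorded.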

static load_metric(fn)
static load_best_metric(fn, metric)
static write_to_loss_file(fn, losses)
static write_to_metric_file(fn, metrics)
static exhaust_used_train_data(train_data_generator, n_batch_to_exhaust)
property n_batch_per_iter(self)
static get_paths_for_early_stopping(train_output_path, dev_output_path)
change_lr(self, step, lr)

Applies warmup or decay to the learning rate lr, depending on the current training step.

lr_multiplier(self, step)
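The pairing of change_lr and lr_multiplier suggests a common pattern: compute a per-step multiplier and scale the base learning rate by it. The sketch below shows one such schedule, linear warmup followed by exponential decay. The schedule shape and the parameter names `warmup_steps`, `decay_rate`, and `decay_steps` are assumptions for illustration; they are not taken from the Trainer's configuration.

```python
def lr_multiplier(step: int, warmup_steps: int, decay_rate: float, decay_steps: int) -> float:
    """Hypothetical schedule: linear warmup, then exponential decay."""
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp linearly from 1/warmup_steps up to 1.0 over the warmup period.
        return (step + 1) / warmup_steps
    if decay_steps > 0:
        # After warmup, shrink continuously by a factor of decay_rate every decay_steps steps.
        return decay_rate ** ((step - warmup_steps) / decay_steps)
    # No warmup in progress and no decay configured: keep the base rate.
    return 1.0


def change_lr(step: int, lr: float, **schedule) -> float:
    """Scale the base learning rate by the schedule's multiplier for this step."""
    return lr * lr_multiplier(step, **schedule)
```

For example, with `warmup_steps=4, decay_rate=0.5, decay_steps=10`, the multiplier is 0.25 at step 0, reaches 1.0 at step 3, and halves to 0.5 by step 14. Keeping the multiplier separate from the base rate means the same schedule can be applied to any optimizer's learning rate.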