Module Contents



Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer

class capreolus.trainer.pytorch.PytorchTrainer(config=None, provide=None, share_dependency_objects=False, build=True)[source]

Bases: capreolus.trainer.Trainer

Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer

Modules should provide:
  • a train method that trains a reranker on training and dev (validation) data

  • a predict method that uses a reranker to make predictions on data

module_name = pytorch[source]
config_keys_not_in_path = ['boardname'][source]
single_train_iteration(self, reranker, train_dataloader)[source]

Train model for one iteration using instances from train_dataloader.

Parameters

  • reranker (Reranker) – a PyTorch Reranker

  • train_dataloader (DataLoader) – a PyTorch DataLoader that iterates over training instances


Returns

average loss over the iteration

Return type

float

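The contract above can be sketched in plain Python. The StubReranker and list-based dataloader below are hypothetical stand-ins, not Capreolus classes; a real implementation would also perform the backward pass and optimizer step.

```python
# Minimal sketch of single_train_iteration's contract: run every batch
# through the reranker, accumulate the loss, and return the average.

def single_train_iteration(reranker, train_dataloader):
    batch_losses = []
    for batch in train_dataloader:
        loss = reranker.compute_loss(batch)  # stand-in for forward pass + loss
        # a real trainer would also call loss.backward() and optimizer.step()
        batch_losses.append(loss)
    return sum(batch_losses) / len(batch_losses)  # average loss over the iteration

class StubReranker:
    """Illustrative stand-in for a PyTorch Reranker."""
    def compute_loss(self, batch):
        return float(sum(batch)) / len(batch)

avg = single_train_iteration(StubReranker(), [[1, 2, 3], [4, 5, 6]])
print(avg)  # 3.5
```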
fastforward_training(self, reranker, weights_path, loss_fn, best_metric_fn)[source]

Skip to the last training iteration whose weights were saved.

If saved model and optimizer weights are available, this method will load those weights into model and optimizer, and then return the next iteration to be run. For example, if weights are available for iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and this method will return 11.

If an error or inconsistency is encountered when checking for weights, this method returns 0.

This method checks several files to determine if weights “are available”. First, loss_fn is read to determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.) Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer. If this is successful, the method returns 1 + last recorded iteration. If not, it returns 0. (We consider loss_fn because it is written at the end of every training iteration.)

Parameters

  • reranker (Reranker) – a PyTorch Reranker whose state should be loaded

  • weights_path (Path) – directory containing model and optimizer weights

  • loss_fn (Path) – file containing loss history


Returns

the next training iteration after fastforwarding. If successful, this is > 0. If no weights are available or they cannot be loaded, 0 is returned.

Return type

int

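The fastforwarding logic described above can be sketched as follows. The one-loss-per-line format of loss_fn and the "<iteration>.p" weight-file naming are assumptions for illustration, not the actual Capreolus file layout.

```python
import os

def fastforward_training(weights_path, loss_fn):
    # Sketch of the described logic. Assumes loss_fn records one loss value
    # per line (one line per completed iteration) and that weights for
    # iteration i live at "<i>.p" under weights_path -- both naming choices
    # are hypothetical.
    try:
        with open(loss_fn) as f:
            losses = [float(line) for line in f if line.strip()]
    except (OSError, ValueError):
        return 0  # missing path or malformed loss file

    if not losses:
        return 0

    last_iter = len(losses) - 1  # last recorded (zero-indexed) iteration
    weights_file = os.path.join(weights_path, f"{last_iter}.p")
    if not os.path.exists(weights_file):  # stand-in for loading model/optimizer state
        return 0
    return last_iter + 1  # next iteration to run
```

For example, with losses recorded for iterations 0 through 10 and weights saved for iteration 10, this returns 11.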

get_validation_schedule_msg(self, initial_iter=0)[source]

Describe the validation schedule, taking niters and validatefreq into account.

Parameters

initial_iter (int) – the starting iteration, as determined by the train method

Example

Assuming self.config["niters"] = 20 and self.config["validatefreq"] = 3, this method will return:

  Validation is scheduled on iterations: [3, 6, 9, 12, 15, 18] given initial_iter in [0, 1, 2]
  Validation is scheduled on iterations: [6, 9, 12, 15, 18] given initial_iter == 3
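The schedule in the example above can be reproduced with a short helper; this is a hypothetical sketch of the scheduling rule, not the actual Capreolus implementation.

```python
def validation_iterations(niters, validatefreq, initial_iter=0):
    # Iterations on which validation runs: multiples of validatefreq,
    # up to and including niters, that come after initial_iter.
    return [i for i in range(1, niters + 1)
            if i % validatefreq == 0 and i > initial_iter]

print(validation_iterations(20, 3, initial_iter=0))  # [3, 6, 9, 12, 15, 18]
print(validation_iterations(20, 3, initial_iter=3))  # [6, 9, 12, 15, 18]
```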

train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)[source]

Train a model following the trainer’s config (specifying batch size, number of iterations, etc).

Parameters

  • train_dataset (IterableDataset) – training dataset

  • train_output_path (Path) – directory under which train_dataset runs and training loss will be saved

  • dev_data (IterableDataset) – dev dataset

  • dev_output_path (Path) – directory where dev_data runs and metrics will be saved

load_best_model(self, reranker, train_output_path)[source]
predict(self, reranker, pred_data, pred_fn)[source]

Predict query-document scores on pred_data using reranker, and write a corresponding run file to pred_fn.

Parameters

  • reranker (Reranker) – a PyTorch Reranker

  • pred_data (IterableDataset) – data to predict on

  • pred_fn (Path) – path to write the prediction run file to


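A run file maps each query to its ranked documents. A minimal sketch of writing one is shown below; the TREC-style line format and the "capreolus" tag field are assumptions for illustration, not necessarily the exact format predict produces.

```python
def write_run_file(preds, pred_fn):
    # Sketch: preds maps query id -> {doc id: score}. For each query, write
    # its documents in descending score order, one line per (query, doc)
    # pair, in the standard TREC run format:
    #   <qid> Q0 <docid> <rank> <score> <tag>
    with open(pred_fn, "w") as f:
        for qid in sorted(preds):
            ranked = sorted(preds[qid].items(), key=lambda kv: kv[1], reverse=True)
            for rank, (docid, score) in enumerate(ranked, start=1):
                f.write(f"{qid} Q0 {docid} {rank} {score} capreolus\n")
```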

fill_incomplete_batch(self, batch, batch_size=None)[source]

If a batch is incomplete (i.e. shorter than the desired batch size), this method fills in the batch with some data.

How the data is chosen:
  • If the values are a simple list, use the first element of the list to pad the batch
  • If the values are tensors/numpy arrays, use repeat() along the batch dimension
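The list case can be sketched in plain Python; this is a hypothetical illustration, and the real method additionally handles tensor/numpy values via repeat() along the batch dimension.

```python
def fill_incomplete_batch(batch, batch_size):
    # Sketch of the list case only: pad each field of the batch dict up to
    # batch_size by repeating its first element. (Tensor/numpy values would
    # instead be padded with repeat() along the batch dimension.)
    padded = {}
    for key, values in batch.items():
        missing = batch_size - len(values)
        padded[key] = values + [values[0]] * max(missing, 0)
    return padded

batch = {"qid": ["q1", "q2"], "score": [0.5, 0.1]}
print(fill_incomplete_batch(batch, 4))
# {'qid': ['q1', 'q2', 'q1', 'q1'], 'score': [0.5, 0.1, 0.5, 0.5]}
```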