capreolus.trainer.pytorch

Module Contents

Classes

PytorchTrainer Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer
capreolus.trainer.pytorch.logger[source]
capreolus.trainer.pytorch.RESULTS_BASE_PATH[source]
class capreolus.trainer.pytorch.PytorchTrainer(config=None, provide=None, share_dependency_objects=False, build=True)[source]

Bases: capreolus.trainer.Trainer

Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer

Modules should provide:
  • a train method that trains a reranker on training and dev (validation) data
  • a predict method that uses a reranker to make predictions on data
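
For orientation, a minimal sketch of this contract, using the signatures documented below for PytorchTrainer; the class name and the ellipsis bodies are illustrative, not part of Capreolus:

    from capreolus.trainer import Trainer

    class MyTrainer(Trainer):
        module_name = "mytrainer"

        def train(self, reranker, train_dataset, train_output_path, dev_data,
                  dev_output_path, qrels, metric, relevance_level=1):
            # iterate over train_dataset, update the reranker's weights, and
            # periodically evaluate on dev_data to select the best checkpoint
            ...

        def predict(self, reranker, pred_data, pred_fn):
            # score every query-document pair in pred_data and write a run file to pred_fn
            ...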
module_name = pytorch[source]
config_spec[source]
config_keys_not_in_path = ['fastforward', 'boardname'][source]
build(self)[source]
single_train_iteration(self, reranker, train_dataloader)[source]

Train model for one iteration using instances from train_dataloader.

Parameters:
  • reranker (Reranker) – a PyTorch Reranker
  • train_dataloader (DataLoader) – a PyTorch DataLoader that iterates over training instances
Returns:

average loss over the iteration

Return type:

float
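
The exact loop depends on the trainer's config (loss function, gradient accumulation, etc.). The following is a hedged sketch of one such iteration; the reranker.score() call returning positive/negative document scores is an assumption, not the documented Reranker API:

    import torch
    import torch.nn.functional as F

    def single_train_iteration_sketch(reranker, optimizer, train_dataloader,
                                      batches_per_iter=32, device="cpu"):
        iter_losses = []
        for bi, batch in enumerate(train_dataloader):
            if bi >= batches_per_iter:
                break
            # move tensor fields of the batch onto the training device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                     for k, v in batch.items()}
            # assumption: score() yields scores for the positive and negative document
            pos_scores, neg_scores = reranker.score(batch)
            loss = F.margin_ranking_loss(pos_scores, neg_scores,
                                         torch.ones_like(pos_scores), margin=1.0)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            iter_losses.append(loss.item())
        # average loss over the iteration, matching the documented return value
        return sum(iter_losses) / len(iter_losses)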

load_loss_file(self, fn)[source]

Loads loss history from fn

Parameters: fn (Path) – path to a loss.txt file
Returns: a list of losses ordered by iteration
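
A hedged sketch of the parsing, assuming each line of loss.txt holds a whitespace-separated "<iteration> <loss>" pair (the exact file format is an assumption):

    from pathlib import Path

    def load_loss_file_sketch(fn: Path):
        losses = []
        with fn.open() as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                _iteration, loss = line.split()
                losses.append(float(loss))
        # one loss per training iteration, ordered by iteration
        return losses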
fastforward_training(self, reranker, weights_path, loss_fn)[source]

Skip to the last training iteration whose weights were saved.

If saved model and optimizer weights are available, this method will load those weights into model and optimizer, and then return the next iteration to be run. For example, if weights are available for iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and this method will return 11.

If an error or inconsistency is encountered when checking for weights, this method returns 0.

This method checks several files to determine if weights “are available”. First, loss_fn is read to determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.) Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer. If this is successful, the method returns 1 + last recorded iteration. If not, it returns 0. (We consider loss_fn because it is written at the end of every training iteration.)

Parameters:
  • reranker (Reranker) – a PyTorch Reranker whose state should be loaded
  • weights_path (Path) – directory containing model and optimizer weights
  • loss_fn (Path) – file containing loss history
Returns:

the next training iteration after fastforwarding. If successful, this is > 0.

If no weights are available or they cannot be loaded, 0 is returned.

Return type:

int
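
Combining the pieces above, a hedged sketch of this fastforward logic; the load_weights() call and the per-iteration weight file name are assumptions, and the real method's consistency checks are more thorough:

    def fastforward_training_sketch(reranker, optimizer, weights_path, loss_fn):
        if not (weights_path.exists() and loss_fn.exists()):
            return 0
        try:
            losses = load_loss_file_sketch(loss_fn)  # see load_loss_file above
            last_iteration = len(losses) - 1
            if last_iteration < 0:
                return 0
            # assumption: weights for iteration N are stored under weights_path / str(N)
            reranker.load_weights(weights_path / str(last_iteration), optimizer)
            return last_iteration + 1
        except (OSError, ValueError):
            # malformed loss file or missing/incompatible weights: restart from 0
            return 0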

train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)[source]

Train a model following the trainer’s config (specifying batch size, number of iterations, etc.).

Parameters:
  • train_dataset (IterableDataset) – training dataset
  • train_output_path (Path) – directory under which train_dataset runs and training loss will be saved
  • dev_data (IterableDataset) – dev dataset
  • dev_output_path (Path) – directory where dev_data runs and metrics will be saved
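
A hedged usage sketch; the trainer, reranker, datasets, and qrels are assumed to have been constructed elsewhere (e.g. by a Capreolus task), and the metric name is illustrative:

    from pathlib import Path

    trainer.train(
        reranker=reranker,
        train_dataset=train_dataset,
        train_output_path=Path("experiment/train"),
        dev_data=dev_data,
        dev_output_path=Path("experiment/dev"),
        qrels=qrels,
        metric="map",
        relevance_level=1,
    )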
load_best_model(self, reranker, train_output_path)[source]
predict(self, reranker, pred_data, pred_fn)[source]

Predict query-document scores on pred_data using the reranker and write a corresponding run file to pred_fn

Parameters:
  • reranker (Reranker) – a PyTorch Reranker
  • pred_data (IterableDataset) – data to predict on
  • pred_fn (Path) – path to write the prediction run file to
Returns:

TREC Run
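
A hedged usage sketch; the returned run is assumed to be the usual nested mapping of query IDs to document IDs to scores (the IDs below are illustrative):

    from pathlib import Path

    preds = trainer.predict(reranker, pred_data, Path("experiment/dev/best.run"))

    # assumed structure: {"qid": {"docid": score, ...}, ...}
    for qid, doc_scores in preds.items():
        best_doc = max(doc_scores, key=doc_scores.get)
        print(qid, best_doc, doc_scores[best_doc])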

fill_incomplete_batch(self, batch)[source]

If a batch is incomplete (i.e., shorter than the desired batch size), this method fills in the batch with additional data. How the padding data is chosen:
  • If the values are a plain list, the first element of the list is used to pad the batch
  • If the values are tensors/numpy arrays, repeat() is used along the batch dimension
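
A hedged sketch of that padding logic (the batch keys and value types depend on the extractor; this is not the actual implementation):

    import numpy as np
    import torch

    def fill_incomplete_batch_sketch(batch, batch_size):
        padded = {}
        for key, values in batch.items():
            if isinstance(values, (torch.Tensor, np.ndarray)):
                # tile along the batch dimension, then trim back to batch_size
                repeats = -(-batch_size // len(values))  # ceiling division
                if isinstance(values, torch.Tensor):
                    padded[key] = values.repeat(repeats, *[1] * (values.dim() - 1))[:batch_size]
                else:
                    padded[key] = np.tile(values, (repeats,) + (1,) * (values.ndim - 1))[:batch_size]
            else:
                # plain list: pad with copies of its first element
                padded[key] = list(values) + [values[0]] * (batch_size - len(values))
        return padded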