capreolus.trainer

Package Contents

Classes

Trainer()
PytorchTrainer()
TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, validate_freq=1, *args, **kwargs) A callback that runs after every epoch and calculates pytrec_eval style metrics for the dev dataset.
TensorFlowTrainer()
capreolus.trainer.logger
capreolus.trainer.RESULTS_BASE_PATH
class capreolus.trainer.Trainer

Bases: profane.ModuleBase

module_type = trainer
requires_random_seed = True
get_paths_for_early_stopping(self, train_output_path, dev_output_path)
class capreolus.trainer.PytorchTrainer

Bases: capreolus.trainer.Trainer

module_name = pytorch
config_spec
config_keys_not_in_path = ['fastforward', 'boardname']
build(self)
single_train_iteration(self, reranker, train_dataloader)

Train the model for one iteration using instances from train_dataloader.

Parameters:
  • reranker (Reranker) – a PyTorch Reranker
  • train_dataloader (DataLoader) – a PyTorch DataLoader that iterates over training instances
Returns:

average loss over the iteration

Return type:

float
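The sketch below illustrates the "average loss over the iteration" return value without depending on PyTorch. It assumes, hypothetically, that one iteration consumes a fixed number of batches (`batches_per_iteration`), and `train_step` is a hypothetical stand-in for the forward pass, backward pass, and optimizer update on one batch; neither name is part of the Capreolus API.

```python
from itertools import islice
from typing import Any, Callable, Iterable


def single_train_iteration_sketch(
    batches: Iterable[Any],
    train_step: Callable[[Any], float],
    batches_per_iteration: int,
) -> float:
    """Run up to `batches_per_iteration` batches and return their mean loss.

    `train_step` is a hypothetical callable: it should update the model on
    one batch and return that batch's loss as a float.
    """
    total_loss = 0.0
    n_batches = 0
    # Take at most `batches_per_iteration` batches from the (possibly endless) stream.
    for batch in islice(batches, batches_per_iteration):
        total_loss += train_step(batch)
        n_batches += 1
    if n_batches == 0:
        raise ValueError("no training batches were available")
    return total_loss / n_batches


# Toy usage: each "batch" is its own loss value; only the first four are consumed.
avg_loss = single_train_iteration_sketch(
    iter([4.0, 2.0, 3.0, 1.0, 100.0]), lambda b: b, batches_per_iteration=4
)
print(avg_loss)  # 2.5
```

Averaging per batch (rather than summing) keeps the reported loss comparable across iterations even if the number of batches per iteration changes.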

load_loss_file(self, fn)

Loads the loss history from fn.

Parameters: fn (Path) – path to a loss.txt file
Returns: a list of losses ordered by iteration
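The on-disk layout of loss.txt is not specified here. The sketch below assumes one loss value per line, written in iteration order, and shows how such a file could be read back into an ordered list; `load_loss_file_sketch` is a hypothetical name, not the Capreolus method.

```python
import tempfile
from pathlib import Path
from typing import List


def load_loss_file_sketch(fn: Path) -> List[float]:
    """Read a loss history, assuming one float per line in iteration order."""
    losses = []
    with open(fn, "rt") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines, e.g. a trailing newline
                losses.append(float(line))
    return losses


# Usage: write a small loss file to a temporary directory and read it back.
with tempfile.TemporaryDirectory() as tmp:
    loss_path = Path(tmp) / "loss.txt"
    loss_path.write_text("0.75\n0.5\n0.25\n")
    history = load_loss_file_sketch(loss_path)
print(history)  # [0.75, 0.5, 0.25]
```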
fastforward_training(self, reranker, weights_path, loss_fn)

Skip to the last training iteration whose weights were saved.

If saved model and optimizer weights are available, this method loads those weights into the model and optimizer and then returns the next iteration to be run. For example, if weights are available for iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and this method will return 11.

If an error or inconsistency is encountered when checking for weights, this method returns 0.

This method checks several files to determine if weights “are available”. First, loss_fn is read to determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.) Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer. If this is successful, the method returns 1 + last recorded iteration. If not, it returns 0. (We consider loss_fn because it is written at the end of every training iteration.)

Parameters:
  • reranker (Reranker) – a PyTorch Reranker whose state should be loaded
  • weights_path (Path) – directory containing model and optimizer weights
  • loss_fn (Path) – file containing loss history
Returns:

the next training iteration after fastforwarding. If successful, this is > 0.

If no weights are available or they cannot be loaded, 0 is returned.

Return type:

int
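The decision procedure described above can be sketched in plain Python. This is an illustration under stated assumptions, not the Capreolus implementation: it assumes loss_fn holds one loss per line (so the number of lines gives the number of recorded iterations), and `load_weights` is a hypothetical callable that restores model and optimizer state for a given iteration and raises on failure.

```python
import tempfile
from pathlib import Path
from typing import Callable


def fastforward_sketch(
    weights_path: Path,
    loss_fn: Path,
    load_weights: Callable[[Path, int], None],
) -> int:
    """Return the next iteration to run, or 0 if training must start over.

    `load_weights(weights_path, iteration)` is hypothetical: it should load the
    model and optimizer state saved for `iteration` and raise on any failure.
    """
    # Step 1: read the loss history; a missing or malformed file means "start over".
    try:
        lines = [ln.strip() for ln in loss_fn.read_text().splitlines() if ln.strip()]
        losses = [float(ln) for ln in lines]
    except (OSError, ValueError):
        return 0
    if not losses or not weights_path.exists():
        return 0

    # Assumption: one line per completed iteration, so the last index is len - 1.
    last_iteration = len(losses) - 1

    # Step 2: try to restore the weights saved for that iteration.
    try:
        load_weights(weights_path, last_iteration)
    except Exception:
        return 0
    return last_iteration + 1


with tempfile.TemporaryDirectory() as tmp:
    weights_dir = Path(tmp) / "weights"
    weights_dir.mkdir()
    loss_file = Path(tmp) / "loss.txt"
    # 11 recorded losses = iterations 0-10, matching the example in the docstring.
    loss_file.write_text("\n".join(str(1.0 / (i + 1)) for i in range(11)) + "\n")

    loaded = []
    next_iter = fastforward_sketch(weights_dir, loss_file, lambda path, it: loaded.append(it))

    def failing_loader(path: Path, it: int) -> None:
        raise FileNotFoundError(f"no weights for iteration {it}")

    fallback = fastforward_sketch(weights_dir, loss_file, failing_loader)
    missing = fastforward_sketch(weights_dir, Path(tmp) / "missing.txt", lambda p, i: None)

print(next_iter, loaded, fallback, missing)  # 11 [10] 0 0
```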

train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric)

Train a model following the trainer’s config (specifying batch size, number of iterations, etc.).

Parameters:
  • train_dataset (IterableDataset) – training dataset
  • train_output_path (Path) – directory under which train_dataset runs and training loss will be saved
  • dev_data (IterableDataset) – dev dataset
  • dev_output_path (Path) – directory where dev_data runs and metrics will be saved
load_best_model(self, reranker, train_output_path)
predict(self, reranker, pred_data, pred_fn)

Predict query-document scores on pred_data using reranker and write a corresponding run file to pred_fn.

Parameters:
  • reranker (Reranker) – a PyTorch Reranker
  • pred_data (IterableDataset) – data to predict on
  • pred_fn (Path) – path to write the prediction run file to
Returns:

TREC Run
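A TREC run is commonly held in memory as a mapping from query ID to document scores. The sketch below assumes that shape (`{qid: {docid: score}}`) and a hypothetical run tag, and writes the run in the standard six-column TREC format (`qid Q0 docid rank score tag`); it illustrates the file format only and is not Capreolus's own writer.

```python
import tempfile
from pathlib import Path
from typing import Dict


def write_trec_run_sketch(run: Dict[str, Dict[str, float]], path: Path, tag: str = "sketch") -> None:
    """Write {qid: {docid: score}} as TREC run lines, ranked by descending score."""
    with open(path, "wt") as f:
        for qid in sorted(run):
            ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
            # Ranks are 1-based within each query.
            for rank, (docid, score) in enumerate(ranked, start=1):
                f.write(f"{qid} Q0 {docid} {rank} {score} {tag}\n")


with tempfile.TemporaryDirectory() as tmp:
    run_path = Path(tmp) / "dev.run"
    write_trec_run_sketch({"q1": {"d1": 0.25, "d2": 1.5}}, run_path)
    run_lines = run_path.read_text().splitlines()
print(run_lines)  # ['q1 Q0 d2 1 1.5 sketch', 'q1 Q0 d1 2 0.25 sketch']
```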

fill_incomplete_batch(self, batch)

If a batch is incomplete (i.e., shorter than the desired batch size), this method pads it to the desired size. The padding data is chosen as follows:
  • If the values are a simple list, the first element of the list is used to pad the batch.
  • If the values are tensors or NumPy arrays, repeat() is applied along the batch dimension.
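For the plain-list case, the padding rule can be sketched as follows. The dict-of-lists batch layout and the explicit `batch_size` argument are assumptions made for illustration; the tensor/array branch (repeating along the batch dimension) is intentionally left out.

```python
from typing import Any, Dict, List


def fill_incomplete_batch_sketch(batch: Dict[str, List[Any]], batch_size: int) -> Dict[str, List[Any]]:
    """Pad every list in `batch` to `batch_size` by repeating its first element.

    Only the plain-list case is shown; tensors/arrays would instead be
    repeated along the batch dimension.
    """
    padded = {}
    for key, values in batch.items():
        if not values:
            raise ValueError(f"cannot pad empty field {key!r}")
        missing = batch_size - len(values)
        # Lists already at (or above) the batch size are left unchanged.
        padded[key] = values + [values[0]] * max(missing, 0)
    return padded


padded = fill_incomplete_batch_sketch({"qid": ["q1", "q2"], "label": [1, 0]}, batch_size=4)
print(padded)  # {'qid': ['q1', 'q2', 'q1', 'q1'], 'label': [1, 0, 1, 1]}
```

Padding with a real example (rather than zeros) keeps every field well-formed for the model; the padded rows are presumably ignored when predictions are collected.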

class capreolus.trainer.TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, validate_freq=1, *args, **kwargs)

Bases: tensorflow.keras.callbacks.Callback

A callback that runs after every epoch and calculates pytrec_eval-style metrics for the dev dataset. It also saves the best model to disk. See TensorFlowTrainer.train() for how it is invoked.
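The checkpointing logic such a callback needs can be sketched without TensorFlow. The class below only borrows the hook name on_epoch_end; it does not subclass tensorflow.keras.callbacks.Callback. The callables `evaluate` and `save`, the reading of validate_freq as "validate every N epochs", and the "higher metric is better" rule are all assumptions for illustration.

```python
from typing import Callable, Optional


class BestModelTracker:
    """Framework-free sketch of periodic dev validation with best-model saving.

    `evaluate` (returns a dev metric, higher is better) and `save` are
    hypothetical stand-ins for pytrec_eval scoring and model saving.
    """

    def __init__(self, evaluate: Callable[[], float], save: Callable[[], None], validate_freq: int = 1):
        self.evaluate = evaluate
        self.save = save
        self.validate_freq = validate_freq
        self.best_metric: Optional[float] = None

    def on_epoch_end(self, epoch: int) -> None:
        # Assumption: validate every `validate_freq` epochs (epochs are zero-indexed).
        if (epoch + 1) % self.validate_freq != 0:
            return
        metric = self.evaluate()
        # Save only when the dev metric improves on the best seen so far.
        if self.best_metric is None or metric > self.best_metric:
            self.best_metric = metric
            self.save()


dev_scores = iter([0.2, 0.5, 0.4])
saved = []
tracker = BestModelTracker(evaluate=lambda: next(dev_scores), save=lambda: saved.append(tracker.best_metric))
for epoch in range(3):
    tracker.on_epoch_end(epoch)
print(saved, tracker.best_metric)  # [0.2, 0.5] 0.5
```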

save_model(self)
on_epoch_begin(self, epoch, logs=None)
on_epoch_end(self, epoch, logs=None)
static get_preds_in_trec_format(predictions, dev_data)

Takes a list of predictions and returns a dict that can be fed into pytrec_eval. As a side effect, it also writes the predictions to a file in TREC format.
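pytrec_eval consumes runs as nested dicts of the form {qid: {docid: score}}. The sketch below builds such a dict; it assumes, hypothetically, that dev_data can be reduced to a sequence of (qid, docid) pairs aligned one-to-one with the predictions, and it omits the file-writing side effect.

```python
from typing import Dict, Iterable, Sequence, Tuple


def preds_to_run_sketch(
    predictions: Sequence[float], pairs: Iterable[Tuple[str, str]]
) -> Dict[str, Dict[str, float]]:
    """Zip scores with (qid, docid) pairs into a {qid: {docid: score}} run dict."""
    pair_list = list(pairs)
    if len(pair_list) != len(predictions):
        raise ValueError("predictions and (qid, docid) pairs must have the same length")
    run: Dict[str, Dict[str, float]] = {}
    for score, (qid, docid) in zip(predictions, pair_list):
        # Cast to a plain float, since pytrec_eval expects float scores.
        run.setdefault(qid, {})[docid] = float(score)
    return run


run = preds_to_run_sketch([0.5, 2, -1.0], [("q1", "d1"), ("q1", "d2"), ("q2", "d9")])
print(run)  # {'q1': {'d1': 0.5, 'd2': 2.0}, 'q2': {'d9': -1.0}}
```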

class capreolus.trainer.TensorFlowTrainer

Bases: capreolus.trainer.Trainer

module_name = tensorflow
config_spec
config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']
build(self)
validate(self)
get_optimizer(self)
fastforward_training(self, reranker, weights_path, loss_fn)
load_best_model(self, reranker, train_output_path)
apply_gradients(self, weights, grads)
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric)
create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc)

Creates a single tf.train.Feature instance (i.e., a single sample).

write_tf_record_to_file(self, dir_name, tf_features)

Writes the TF records to file. The destination can also be a GCS bucket. TODO: use generators to optimize memory usage.

convert_to_tf_dev_record(self, reranker, dataset)

Similar to self.convert_to_tf_train_record(), but does not split the output across multiple files.

convert_to_tf_train_record(self, reranker, dataset)

TensorFlow works better when input data is fed in as TFRecords. This method takes a dataset, iterates through it, and creates multiple TF records from it. The exact structure of the records is defined by reranker.extractor; for an example, see EmbedText.get_tf_feature().
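One common way to turn a single dataset into multiple record files is to split the stream of samples into fixed-size shards and write one file per shard. The generic helper below sketches only that splitting step; the shard size and the one-file-per-shard layout are assumptions, not documented Capreolus behavior, and the actual serialization (for example with tf.io.TFRecordWriter) is omitted. Because it is a generator, it also avoids holding the whole dataset in memory.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shard(samples: Iterable[T], shard_size: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most `shard_size` samples."""
    if shard_size <= 0:
        raise ValueError("shard_size must be positive")
    it = iter(samples)
    while True:
        chunk = list(islice(it, shard_size))
        if not chunk:
            return
        yield chunk


shards = list(shard(range(5), 2))
print(shards)  # [[0, 1], [2, 3], [4]]
```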

get_tf_record_cache_path(self, dataset)

Returns the path to the directory where TF records are written. When using TPUs, this is a GCS path.

cache_exists(self, dataset)
load_tf_records_from_file(self, reranker, filenames, batch_size)
load_cached_tf_records(self, reranker, dataset, batch_size)
get_tf_dev_records(self, reranker, dataset)
  1. Returns tf records from cache (disk) if applicable
  2. Else, converts the dataset into tf records, writes them to disk, and returns them
get_tf_train_records(self, reranker, dataset)
  1. Returns tf records from cache (disk) if applicable
  2. Else, converts the dataset into tf records, writes them to disk, and returns them
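Both get_tf_dev_records() and get_tf_train_records() follow the same cache-or-convert pattern, sketched generically below. The three callables are hypothetical stand-ins for cache_exists(), load_cached_tf_records(), and the convert-then-write path; what "if applicable" means (for example, whether caching is enabled) is not specified here and is left out of the sketch.

```python
from typing import Callable, List, TypeVar

R = TypeVar("R")


def get_records_sketch(
    cache_exists: Callable[[], bool],
    load_cached: Callable[[], R],
    convert_and_write: Callable[[], R],
) -> R:
    """Return cached records if present; otherwise build, persist, and return them."""
    if cache_exists():
        # 1. Cache hit: read the previously written records from disk.
        return load_cached()
    # 2. Cache miss: convert the dataset, write the records, and return them.
    return convert_and_write()


# Usage with an in-memory dict standing in for the on-disk cache.
cache = {}
conversions: List[int] = []


def convert_and_write() -> List[str]:
    conversions.append(1)  # count how often the expensive path runs
    cache["records"] = ["rec-0", "rec-1"]
    return cache["records"]


first = get_records_sketch(lambda: "records" in cache, lambda: cache["records"], convert_and_write)
second = get_records_sketch(lambda: "records" in cache, lambda: cache["records"], convert_and_write)
print(first == second, len(conversions))  # True 1
```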
predict(self, reranker, pred_data, pred_fn)

Predict query-document scores on pred_data using reranker and write a corresponding run file to pred_fn.

Parameters:
  • reranker (Reranker) – a TensorFlow Reranker
  • pred_data (IterableDataset) – data to predict on
  • pred_fn (Path) – path to write the prediction run file to
Returns:

TREC Run