Module Contents



TensorflowTrainer: Trains, optionally on a TPU.



class capreolus.trainer.tensorflow.TensorflowTrainer(config=None, provide=None, share_dependency_objects=False, build=True)

Bases: capreolus.trainer.Trainer

Trains, optionally on a TPU. Uses two optimizers with different learning rates: one for the BERT layers and another for the classifier layers. The warmup and decay of bertlr are configurable. WARNING: The optimizers depend on specific layer names (see train()). If your reranker does not have layers with ‘bert’ in the name, the normal learning rate is applied to them instead of the value supplied through the bertlr ConfigOption.
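The name-based optimizer split and the bertlr schedule can be sketched in plain Python. This is a minimal sketch, not the trainer's code: it assumes a case-sensitive substring match on variable names and a linear warmup followed by linear decay to zero. The helper names, the warmup_steps and total_steps parameters, and the exact schedule shape are all hypothetical.

```python
from typing import Dict, List, Sequence


def split_variables_by_name(variable_names: Sequence[str]) -> Dict[str, List[str]]:
    """Partition variable names into a 'bert' group (gets bertlr) and an 'other' group.

    Assumption: a plain, case-sensitive substring test on 'bert'.
    """
    groups: Dict[str, List[str]] = {"bert": [], "other": []}
    for name in variable_names:
        key = "bert" if "bert" in name else "other"
        groups[key].append(name)
    return groups


def bert_learning_rate(step: int, bertlr: float, warmup_steps: int, total_steps: int) -> float:
    """Hypothetical schedule: linear warmup up to bertlr, then linear decay to 0."""
    if warmup_steps > 0 and step < warmup_steps:
        # Warmup: ramp from bertlr / warmup_steps up to bertlr
        return bertlr * (step + 1) / warmup_steps
    # Decay: shrink linearly over the remaining steps, never below 0
    remaining = max(total_steps - step, 0)
    decay_span = max(total_steps - warmup_steps, 1)
    return bertlr * remaining / decay_span
```

A variable named "BERT_upper" would land in the "other" group under this sketch, which is exactly the pitfall the WARNING describes: anything not matched by name silently falls back to the normal learning rate.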

module_name = tensorflow
config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)
predict(self, reranker, pred_data, pred_fn)
form_tf_record_cache_path(self, dataset)

Returns the path to the directory that TF records are written to. When using TPUs, this is a GCS path.

find_cached_tf_records(self, dataset, required_sample_count)

Looks for cached TF records for the given dataset that contain at least required_sample_count samples.
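The lookup can be sketched as follows, assuming a hypothetical in-memory cache index that maps the sample count of each cached run to its record files (the real method inspects the cache location instead). The sketch returns the smallest cached run that still has enough samples; that preference is a design choice of this sketch, not necessarily the method's.

```python
from typing import Dict, List, Optional


def find_cached_tf_records(
    cache_index: Dict[int, List[str]], required_sample_count: int
) -> Optional[List[str]]:
    """Return the record files of the smallest cached run that is large enough.

    cache_index is hypothetical: {sample_count: [tfrecord paths for that run]}.
    Returns None when no cached run has enough samples.
    """
    # Visit runs in ascending size so the smallest sufficient cache wins
    for sample_count in sorted(cache_index):
        files = cache_index[sample_count]
        if sample_count >= required_sample_count and files:
            return files
    return None
```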

get_tf_train_records(self, reranker, dataset)
  1. Returns tf records from cache (disk) if applicable

  2. Else, converts the dataset into tf records, writes them to disk, and returns them
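Both get_tf_train_records() and get_tf_dev_records() follow this cache-or-convert pattern. A minimal sketch, with the three steps passed in as hypothetical callables (find_cached, convert, write) rather than the trainer's real methods:

```python
from typing import Callable, List, Optional


def get_tf_records(
    find_cached: Callable[[], Optional[List[str]]],
    convert: Callable[[], list],
    write: Callable[[list], List[str]],
) -> List[str]:
    """Return cached TF record paths if present; otherwise convert, write, and return them."""
    cached = find_cached()
    if cached:
        return cached  # step 1: reuse records already on disk
    features = convert()  # step 2: convert the dataset into features...
    return write(features)  # ...write them to disk and return the new paths
```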

load_tf_train_records_from_file(self, reranker, filenames, batch_size)
convert_to_tf_train_record(self, reranker, dataset)

TensorFlow works better when the input data is fed in as TFRecords. This method takes in a dataset, iterates through it, and creates multiple TF records from it, producing exactly niters * itersize samples. The exact structure of the TF records is defined by reranker.extractor; for an example, see BertPassage.get_tf_train_feature().

Parameters

  • reranker – a capreolus.reranker.Reranker instance

  • dataset – a capreolus.sampler.Sampler instance
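The "exactly niters * itersize samples" rule amounts to truncating the sampler's stream at a fixed length. A minimal sketch, assuming the sampler is an iterable that yields at least that many samples; to_feature is a hypothetical stand-in for the extractor-specific conversion, and the sketch returns plain Python objects rather than serialized TF records:

```python
import itertools
from typing import Callable, Iterable, List, TypeVar

T = TypeVar("T")
F = TypeVar("F")


def convert_to_train_features(
    sampler: Iterable[T], niters: int, itersize: int, to_feature: Callable[[T], F]
) -> List[F]:
    """Take exactly niters * itersize samples and convert each one to a feature.

    to_feature is hypothetical; in the trainer the extractor defines the structure.
    """
    total = niters * itersize
    # islice stops after `total` items, so an endless sampler is safe here
    features = [to_feature(sample) for sample in itertools.islice(sampler, total)]
    if len(features) < total:
        raise ValueError(f"sampler yielded {len(features)} samples, expected {total}")
    return features
```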

get_tf_dev_records(self, reranker, dataset)
  1. Returns tf records from cache (disk) if applicable

  2. Else, converts the dataset into tf records, writes them to disk, and returns them

load_tf_dev_records_from_file(self, reranker, filenames, batch_size)
convert_to_tf_dev_record(self, reranker, dataset)
write_tf_record_to_file(self, dir_name, tf_features, file_name=None)

Writes the TF records to a file. The destination can also be a GCS bucket. TODO: Use generators to reduce memory usage.

static get_preds_in_trec_format(predictions, dev_data)

Takes in a list of predictions and returns a dict that can be fed into pytrec_eval. As a side effect, also writes the predictions to a file in TREC format.
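A sketch of this conversion, assuming the predictions are scores aligned with a hypothetical list of (qid, docid) pairs drawn from dev_data. The returned dict uses the {qid: {docid: score}} run layout that pytrec_eval accepts, and the written lines use the six-column TREC run format (qid Q0 docid rank score tag); the run tag "capreolus" is a placeholder.

```python
import io
from typing import Dict, Sequence, TextIO, Tuple


def preds_to_trec(
    predictions: Sequence[float],
    pairs: Sequence[Tuple[str, str]],
    out: TextIO,
    tag: str = "capreolus",
) -> Dict[str, Dict[str, float]]:
    """Build a pytrec_eval-style run dict and write it to `out` as a TREC run.

    `pairs` is a hypothetical list of (qid, docid) aligned with `predictions`.
    """
    if len(predictions) != len(pairs):
        raise ValueError("predictions and pairs must have the same length")
    run: Dict[str, Dict[str, float]] = {}
    for score, (qid, docid) in zip(predictions, pairs):
        run.setdefault(qid, {})[docid] = float(score)
    # TREC run lines: qid Q0 docid rank score tag, ranked by descending score
    for qid in sorted(run):
        ranked = sorted(run[qid].items(), key=lambda kv: kv[1], reverse=True)
        for rank, (docid, score) in enumerate(ranked, start=1):
            out.write(f"{qid} Q0 {docid} {rank} {score} {tag}\n")
    return run


# Usage: write to an in-memory buffer instead of a real file
buf = io.StringIO()
run = preds_to_trec([0.2, 0.9, 0.5], [("q1", "d1"), ("q1", "d2"), ("q2", "d3")], buf)
```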

get_loss(self, loss_name)
get_wrapped_model(self, model)

We need a wrapped model because the logic changes slightly depending on whether the input is pointwise or pairwise:

  1. For pointwise input, there is no “negative document”, so we just execute the model’s call() method.

  2. For pairwise input, we execute the model’s call() method twice (once for the positive doc and again for the negative doc), then stack the results together before passing them to a pairwise loss function.

The alternative was to let the user configure everything manually, for example: loss=crossentropy reranker.trainer.input=pairwise …, but we already have too many ConfigOptions :shrug:
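The pointwise/pairwise distinction can be sketched without TensorFlow. This is a minimal sketch that assumes a hypothetical scorer taking a (query, doc) pair and returning a float; the real wrapper calls the Keras model on batches and stacks tensors. The hinge loss is one example of a pairwise loss, chosen here only for illustration.

```python
from typing import Callable, List, Sequence, Tuple

Scorer = Callable[[str, str], float]  # hypothetical: (query, doc) -> score


def wrap_model(model: Scorer, pairwise: bool):
    """Return a callable whose behavior depends on the input mode."""
    if not pairwise:
        def pointwise_call(batch: Sequence[Tuple[str, str]]) -> List[float]:
            # Pointwise: one call per (query, doc); there is no negative document
            return [model(q, d) for q, d in batch]
        return pointwise_call

    def pairwise_call(batch: Sequence[Tuple[str, str, str]]) -> List[Tuple[float, float]]:
        # Pairwise: call the model twice, then stack (positive, negative) scores
        return [(model(q, pos), model(q, neg)) for q, pos, neg in batch]
    return pairwise_call


def pairwise_hinge_loss(scores: Sequence[Tuple[float, float]], margin: float = 1.0) -> float:
    """Mean hinge loss over stacked (positive, negative) score pairs."""
    return sum(max(0.0, margin - pos + neg) for pos, neg in scores) / len(scores)
```

Hiding the two call patterns behind one wrapper means the training loop and loss selection do not need a separate ConfigOption describing the input mode.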

fastforward_training(self, model, weights_path, loss_fn, best_metric_fn)

Skip to the last training iteration whose weights were saved.

If saved model and optimizer weights are available, this method will load those weights into model and optimizer, and then return the next iteration to be run. For example, if weights are available for iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and this method will return 11.

If an error or inconsistency is encountered when checking for weights, this method returns 0.

This method checks several files to determine if weights “are available”. First, loss_fn is read to determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.) Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer. If this is successful, the method returns 1 + last recorded iteration. If not, it returns 0. (We consider loss_fn because it is written at the end of every training iteration.)

Parameters

  • model (Reranker) – a Reranker whose state should be loaded

  • weights_path (Path) – directory containing model and optimizer weights

  • loss_fn (Path) – file containing loss history

Returns

The next training iteration after fast-forwarding. If successful, this is > 0.

If no weights are available or they cannot be loaded, 0 is returned.

Return type

int

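The fast-forward logic can be sketched as below. It assumes hypothetical on-disk formats (one loss value per line in loss_fn, weights for iteration i stored under weights_path / str(i)) and a hypothetical load_weights callable that raises on failure; it omits best_metric_fn. With losses recorded for iterations 0 to 10, it returns 11, matching the example in the description.

```python
from pathlib import Path
from typing import Callable


def fastforward(weights_path: Path, loss_fn: Path, load_weights: Callable[[Path], None]) -> int:
    """Return the next iteration to run, or 0 if nothing can be restored.

    Hypothetical formats: loss_fn holds one loss per completed iteration (one per
    line); weights for iteration i live in weights_path / str(i).
    """
    if not weights_path.exists() or not loss_fn.exists():
        return 0  # a required path is missing
    try:
        losses = [float(line) for line in loss_fn.read_text().splitlines() if line.strip()]
    except ValueError:
        return 0  # malformed loss file
    if not losses:
        return 0
    last_iteration = len(losses) - 1  # iterations are zero-indexed
    try:
        load_weights(weights_path / str(last_iteration))
    except Exception:
        return 0  # weights missing or unloadable
    return last_iteration + 1
```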
load_best_model(self, reranker, train_output_path)