Module Contents


TensorflowTrainer: Trains (optionally) on the TPU.
class capreolus.trainer.tensorflow.TensorflowTrainer(config=None, provide=None, share_dependency_objects=False, build=True)

Bases: capreolus.trainer.Trainer

Trains (optionally) on the TPU. Uses two optimizers with different learning rates: one for the BERT layers and another for the classifier layers. Warmup and decay are configurable for bertlr.

WARNING: The optimizers depend on specific layer names (see train()). If your reranker does not have layers with ‘bert’ in the name, the normal learning rate is applied to all of its layers instead of the value supplied through the bertlr ConfigOption.
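The name-based split described above can be illustrated with a minimal, framework-free sketch. It only shows how a substring match on variable names decides which learning rate applies. The helper names and the variable names are hypothetical, and the real train() method may group variables differently.

```python
def split_by_layer_name(variable_names, marker="bert"):
    """Partition variable names into (names containing marker, all other names)."""
    bert_vars = [name for name in variable_names if marker in name]
    other_vars = [name for name in variable_names if marker not in name]
    return bert_vars, other_vars


def assign_learning_rates(variable_names, lr, bertlr):
    """Map each variable name to the learning rate its group would receive."""
    bert_vars, other_vars = split_by_layer_name(variable_names)
    rates = {name: bertlr for name in bert_vars}
    rates.update({name: lr for name in other_vars})
    return rates


# Hypothetical variable names, for illustration only.
names = ["tf_bert_model/encoder/layer_0/kernel", "classifier/dense/kernel"]
rates = assign_learning_rates(names, lr=1e-3, bertlr=2e-5)
```

Note that a reranker whose variables never contain the marker ends up entirely in the second group, which is the situation the warning describes.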

module_name = tensorflow
config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)
predict(self, reranker, pred_data, pred_fn)
form_tf_record_cache_path(self, dataset)

Get the path to the directory where tf records are written. If using TPUs, this will be a GCS path.

find_cached_tf_records(self, dataset, required_sample_count)

Looks for a cached tf record for the passed dataset that has at least the specified number of samples.

get_tf_train_records(self, reranker, dataset)
  1. Returns tf records from cache (disk) if applicable
  2. Else, converts the dataset into tf records, writes them to disk, and returns them
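
The two steps above follow a common cache-or-convert pattern. The sketch below shows that pattern with the standard library only. The function names, the `.tfrecord` suffix check, and the placeholder file contents are assumptions for illustration, not the trainer's actual logic.

```python
import os
import tempfile

conversion_calls = []  # records each time the (expensive) conversion runs


def fake_convert(cache_dir):
    """Hypothetical stand-in for converting a dataset and writing records to disk."""
    conversion_calls.append(cache_dir)
    path = os.path.join(cache_dir, "part-0.tfrecord")
    with open(path, "wb") as f:
        f.write(b"placeholder")
    return [path]


def get_records(cache_dir, convert_fn):
    """Return cached record files if present; otherwise convert, write, and return them."""
    os.makedirs(cache_dir, exist_ok=True)
    cached = sorted(
        os.path.join(cache_dir, name)
        for name in os.listdir(cache_dir)
        if name.endswith(".tfrecord")
    )
    if cached:
        return cached
    return convert_fn(cache_dir)


with tempfile.TemporaryDirectory() as tmp:
    first = get_records(tmp, fake_convert)   # cache miss: converts and writes
    second = get_records(tmp, fake_convert)  # cache hit: reads the same files back
```

The second call returns the same paths without running the conversion again.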
load_tf_train_records_from_file(self, reranker, filenames, batch_size)
convert_to_tf_train_record(self, reranker, dataset)

TensorFlow works better if the input data is fed in as tfrecords. Takes in a dataset, iterates through it, and creates multiple tf records from it. Creates exactly niters * itersize * batch_size samples. The exact structure of the tfrecords is defined by reranker.extractor. For example, see BertPassage.get_tf_train_feature().

params:
  reranker - A capreolus.reranker.Reranker instance
  dataset - A capreolus.sampler.Sampler instance
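As a quick worked example of the sample count, the sketch below computes niters * itersize * batch_size and, under the assumption that samples are split into fixed-size files, how many record files that would take. The config values and the per-file size are hypothetical.

```python
def total_train_samples(niters, itersize, batch_size):
    """Number of training samples implied by the three settings."""
    return niters * itersize * batch_size


def record_files_needed(total_samples, samples_per_file):
    """Ceiling division: files needed if each holds at most samples_per_file samples."""
    return -(-total_samples // samples_per_file)


# Hypothetical settings, for illustration only.
total = total_train_samples(niters=10, itersize=512, batch_size=16)
n_files = record_files_needed(total, samples_per_file=10000)
```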

get_tf_dev_records(self, reranker, dataset)
  1. Returns tf records from cache (disk) if applicable
  2. Else, converts the dataset into tf records, writes them to disk, and returns them
load_tf_dev_records_from_file(self, reranker, filenames, batch_size)
convert_to_tf_dev_record(self, reranker, dataset)
write_tf_record_to_file(self, dir_name, tf_features)

Writes the tf record to file. The destination can also be a GCS bucket. TODO: Use generators to optimize memory usage.

static get_preds_in_trec_format(predictions, dev_data)

Takes in a list of predictions and returns a dict that can be fed into pytrec_eval. As a side effect, also writes the predictions into a file in the TREC format.
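For orientation, pytrec_eval consumes runs shaped as a nested dict of query id to document id to score, and a TREC run file has one `qid Q0 docid rank score tag` line per result. The sketch below builds both from a flat score list. The helper names, the way scores are paired with (qid, docid) tuples, and the run tag are assumptions, and the real method's handling of dev_data may differ.

```python
from collections import defaultdict


def preds_to_run(predictions, qid_docid_pairs):
    """Build {qid: {docid: score}} from scores aligned with (qid, docid) pairs."""
    run = defaultdict(dict)
    for score, (qid, docid) in zip(predictions, qid_docid_pairs):
        run[qid][docid] = float(score)
    return dict(run)


def run_to_trec_lines(run, tag="sketch"):
    """Render a run as TREC-format lines, ranked by descending score per query."""
    lines = []
    for qid in sorted(run):
        ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
        for rank, (docid, score) in enumerate(ranked, start=1):
            lines.append(f"{qid} Q0 {docid} {rank} {score} {tag}")
    return lines


# Hypothetical predictions and identifiers, for illustration only.
pairs = [("q1", "d1"), ("q1", "d2"), ("q2", "d3")]
run = preds_to_run([0.2, 0.9, 0.5], pairs)
lines = run_to_trec_lines(run)
```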

change_lr(self, epoch, lr)

Applies warmup or decay to the learning rate, depending on the current epoch.
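One plausible shape for such a schedule is linear warmup followed by exponential decay. The sketch below is an illustration under that assumption. The function name, the parameter names `warmup_epochs` and `decay`, and the exact formulas are hypothetical and are not taken from the trainer's configuration.

```python
def change_lr_sketch(epoch, lr, warmup_epochs=2, decay=0.5):
    """Scale lr up linearly during warmup, then decay it exponentially afterwards."""
    if warmup_epochs and epoch < warmup_epochs:
        # Linear warmup: reaches the full lr on the last warmup epoch.
        return lr * (epoch + 1) / warmup_epochs
    # Exponential decay, counted from the end of warmup.
    return lr * decay ** (epoch - warmup_epochs)


schedule = [change_lr_sketch(epoch, lr=1.0) for epoch in range(4)]
```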

get_loss(self, loss_name)
get_wrapped_model(self, model)

We need a wrapped model because the logic changes slightly depending on whether the input is pointwise or pairwise:
  1. In case of pointwise input, there’s no “negative document”, so we just have to execute the model’s call() method.
  2. In case of pairwise input, we need to execute the model’s call() method twice (once for the positive doc and then again for the negative doc), and then stack the results together before passing them to a pairwise loss function.

The alternative was to let the user manually configure everything (for example: loss=crossentropy reranker.trainer.input=pairwise …), but we already have too many ConfigOptions.
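
The pointwise versus pairwise distinction can be sketched without any framework. Here a plain scoring function stands in for the model's call() method, and the wrapper decides whether to invoke it once or twice. The batch keys `posdoc` and `negdoc`, the helper names, and the toy scorer are hypothetical.

```python
def wrap_model(score_fn, pairwise):
    """Return a callable that runs score_fn once (pointwise) or twice (pairwise)."""

    def pointwise_call(batch):
        # No negative document: a single forward pass is enough.
        return score_fn(batch["posdoc"])

    def pairwise_call(batch):
        # Two forward passes, then stack the scores as (positive, negative) pairs.
        pos_scores = score_fn(batch["posdoc"])
        neg_scores = score_fn(batch["negdoc"])
        return list(zip(pos_scores, neg_scores))

    return pairwise_call if pairwise else pointwise_call


def toy_score(docs):
    """Toy scorer standing in for a model: the score is the text length."""
    return [float(len(doc)) for doc in docs]


batch = {"posdoc": ["relevant text", "match"], "negdoc": ["no", "nope"]}
pointwise_out = wrap_model(toy_score, pairwise=False)(batch)
pairwise_out = wrap_model(toy_score, pairwise=True)(batch)
```

The stacked pairs are what a pairwise loss would consume, while the pointwise output can go straight to a pointwise loss.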

fastforward_training(self, reranker, weights_path, loss_fn)
load_best_model(self, reranker, train_output_path)