capreolus.trainer.tensorflow

Module Contents

Classes

TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs) A callback that runs after every epoch and calculates pytrec_eval style metrics for the dev dataset.
TensorFlowTrainer(config=None, provide=None, share_dependency_objects=False, build=True) Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer.
capreolus.trainer.tensorflow.logger[source]
capreolus.trainer.tensorflow.RESULTS_BASE_PATH[source]
class capreolus.trainer.tensorflow.TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs)[source]

Bases: tensorflow.keras.callbacks.Callback

A callback that runs after every epoch and calculates pytrec_eval style metrics for the dev dataset. See TensorFlowTrainer.train() for the invocation. Also saves the best model to disk.
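The bookkeeping described above (evaluate at the end of an epoch, keep the best model) might look roughly like the following framework-free sketch. The class name `BestMetricTracker` is hypothetical, and the reading of `validate_freq` as "validate every N epochs" is an assumption, not something this page confirms. The real callback subclasses tensorflow.keras.callbacks.Callback and computes the metric itself.

```python
class BestMetricTracker:
    """Hypothetical stand-in for the callback's bookkeeping; no Keras dependency."""

    def __init__(self, validate_freq=1):
        self.validate_freq = validate_freq
        self.best_metric = float("-inf")
        self.best_epoch = None

    def on_epoch_end(self, epoch, metric_value):
        # Assumption: validation runs only every validate_freq epochs (epochs are 0-based).
        if (epoch + 1) % self.validate_freq != 0:
            return False
        if metric_value > self.best_metric:
            self.best_metric = metric_value
            self.best_epoch = epoch
            return True  # the caller would save the model at this point
        return False


tracker = BestMetricTracker(validate_freq=2)
saved = [tracker.on_epoch_end(epoch, value) for epoch, value in enumerate([0.9, 0.3, 0.8, 0.5])]
```

With `validate_freq=2`, the values at epochs 0 and 2 are never inspected, so the tracker reports an improvement only at epochs 1 and 3.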

save_model(self)[source]
on_epoch_begin(self, epoch, logs=None)[source]
on_epoch_end(self, epoch, logs=None)[source]
static get_preds_in_trec_format(predictions, dev_data)[source]

Takes in a list of predictions and returns a dict that can be fed into pytrec_eval. As a side effect, also writes the predictions to a file in the TREC format.
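As a rough illustration of the shapes involved, the sketch below builds the nested `{qid: {docid: score}}` dict that pytrec_eval accepts as a run and formats the same scores as TREC run lines (`qid Q0 docid rank score tag`). The helper name `to_trec_run`, the `(qid, docid, score)` input triples, and the `capreolus` run tag are assumptions made for the example. The actual method recovers the IDs from dev_data and may differ in detail.

```python
from collections import defaultdict


def to_trec_run(scored_pairs, tag="capreolus"):
    """Hypothetical helper: group (qid, docid, score) triples into a run dict and TREC lines."""
    run = defaultdict(dict)
    for qid, docid, score in scored_pairs:
        run[qid][docid] = float(score)

    lines = []
    for qid in sorted(run):
        # Rank each query's documents by descending score, starting at rank 1.
        ranked = sorted(run[qid].items(), key=lambda kv: kv[1], reverse=True)
        for rank, (docid, score) in enumerate(ranked, start=1):
            lines.append(f"{qid} Q0 {docid} {rank} {score:.6f} {tag}")
    return dict(run), lines


run, lines = to_trec_run([("q1", "d1", 0.2), ("q1", "d2", 0.9), ("q2", "d3", 0.5)])
```

Here `run` is the dict form and `lines` holds the text that would be written to the run file, one line per query-document pair.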

class capreolus.trainer.tensorflow.TensorFlowTrainer(config=None, provide=None, share_dependency_objects=False, build=True)[source]

Bases: capreolus.trainer.Trainer

Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer.

Modules should provide:
  • a train method that trains a reranker on training and dev (validation) data
  • a predict method that uses a reranker to make predictions on data
module_name = tensorflow[source]
config_spec[source]
config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage'][source]
build(self)[source]
validate(self)[source]
get_optimizer(self)[source]
fastforward_training(self, reranker, weights_path, loss_fn)[source]
load_best_model(self, reranker, train_output_path)[source]
apply_gradients(self, weights, grads)[source]
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)[source]
create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc)[source]

Creates a single tf.train.Feature instance (i.e., a single sample)

write_tf_record_to_file(self, dir_name, tf_features)[source]

Writes the tf records to a file. The destination can also be a GCS bucket. TODO: Use generators to optimize memory usage.

convert_to_tf_dev_record(self, reranker, dataset)[source]

Similar to self.convert_to_tf_train_record(), but won’t result in multiple files

convert_to_tf_train_record(self, reranker, dataset)[source]

TensorFlow works better when the input data is fed in as tfrecords. This method takes in a dataset, iterates through it, and creates multiple tf records from it. The exact structure of the tfrecords is defined by reranker.extractor. For example, see EmbedText.get_tf_feature().
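The idea of splitting one dataset into several record files can be illustrated without TensorFlow. The sketch below only groups samples into fixed-size shards. The helper name `shard_samples` and the shard size are assumptions, and the page does not say how the real method decides where one file ends and the next begins.

```python
from itertools import islice


def shard_samples(samples, shard_size):
    """Hypothetical helper: yield lists of at most shard_size samples."""
    if shard_size < 1:
        raise ValueError("shard_size must be positive")
    iterator = iter(samples)
    while True:
        # Pull the next shard_size items; an empty list means the input is exhausted.
        shard = list(islice(iterator, shard_size))
        if not shard:
            return
        yield shard


shards = list(shard_samples(range(7), shard_size=3))
```

In a TensorFlow setting, each shard would then typically be serialized to its own file, for example with tf.io.TFRecordWriter.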

get_tf_record_cache_path(self, dataset)[source]

Get the path to the directory where tf records are written to. If using TPUs, this will be a gcs path.

cache_exists(self, dataset)[source]
load_tf_records_from_file(self, reranker, filenames, batch_size)[source]
load_cached_tf_records(self, reranker, dataset, batch_size)[source]
get_tf_dev_records(self, reranker, dataset)[source]
  1. Returns tf records from cache (disk) if applicable
  2. Else, converts the dataset into tf records, writes them to disk, and returns them
get_tf_train_records(self, reranker, dataset)[source]
  1. Returns tf records from cache (disk) if applicable
  2. Else, converts the dataset into tf records, writes them to disk, and returns them
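The two-step behaviour shared by get_tf_dev_records and get_tf_train_records is a cache-or-convert pattern. A minimal sketch for a local directory follows. The helper `get_or_create_records`, the `*.tfrecord` file pattern, and the `fake_convert` stand-in are all hypothetical, and a GCS cache path would need a different file API than pathlib.

```python
import tempfile
from pathlib import Path


def get_or_create_records(cache_dir, convert_fn):
    """Hypothetical helper: reuse cached record files, or create them via convert_fn."""
    cache_dir = Path(cache_dir)
    cached = sorted(cache_dir.glob("*.tfrecord")) if cache_dir.exists() else []
    if cached:
        return cached, True  # cache hit: nothing is regenerated
    cache_dir.mkdir(parents=True, exist_ok=True)
    convert_fn(cache_dir)  # cache miss: the callback writes the files
    return sorted(cache_dir.glob("*.tfrecord")), False


def fake_convert(target_dir):
    # Stand-in for the real conversion; writes one placeholder file.
    (target_dir / "part-0.tfrecord").write_bytes(b"placeholder")


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "records"
    first_files, first_hit = get_or_create_records(cache, fake_convert)
    second_files, second_hit = get_or_create_records(cache, fake_convert)
    first_names = [p.name for p in first_files]
    second_names = [p.name for p in second_files]
```

The first call misses the cache and runs the conversion. The second call finds the file already on disk and returns it without converting again.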
predict(self, reranker, pred_data, pred_fn)[source]

Predict query-document scores on pred_data using reranker and write a corresponding run file to pred_fn

Parameters:
  • reranker (Reranker) – the Reranker used to score pred_data
  • pred_data (IterableDataset) – data to predict on
  • pred_fn (Path) – path to write the prediction run file to
Returns:

TREC Run