capreolus.trainer.tensorflow
Module Contents
Classes
TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs) | A callback that runs after every epoch and calculates pytrec_eval style metrics for the dev dataset.
TensorFlowTrainer(config=None, provide=None, share_dependency_objects=False, build=True) | Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer.
-
class capreolus.trainer.tensorflow.TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs)
Bases: tensorflow.keras.callbacks.Callback
A callback that runs after every epoch and calculates pytrec_eval style metrics for the dev dataset. It also saves the best model to disk. See TensorFlowTrainer.train() for the invocation.
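The bookkeeping described above (evaluate on the dev set, keep the best model) can be sketched without TensorFlow. This is only an illustration under stated assumptions, not the Capreolus implementation: the class name BestCheckpointTracker and the evaluate_fn and save_fn hooks are hypothetical, and it assumes that validate_freq means "validate every N epochs" and that a higher metric is better.

```python
class BestCheckpointTracker:
    """Plain-Python stand-in for the callback's bookkeeping (hypothetical name)."""

    def __init__(self, evaluate_fn, save_fn, validate_freq=1):
        self.evaluate_fn = evaluate_fn  # hypothetical hook: returns the dev metric
        self.save_fn = save_fn  # hypothetical hook: persists the current model
        self.validate_freq = validate_freq
        self.best_metric = float("-inf")
        self.best_epoch = None

    def on_epoch_end(self, epoch):
        # Assumption: validate only every `validate_freq` epochs (0-indexed epochs).
        if (epoch + 1) % self.validate_freq != 0:
            return None
        metric = self.evaluate_fn(epoch)
        # Assumption: higher is better; save only when the dev metric improves.
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_epoch = epoch
            self.save_fn(epoch)
        return metric


dev_scores = [0.10, 0.30, 0.25, 0.40, 0.35, 0.20]  # made-up dev metric per epoch
saved_epochs = []
tracker = BestCheckpointTracker(
    evaluate_fn=lambda epoch: dev_scores[epoch],
    save_fn=saved_epochs.append,
    validate_freq=2,
)
for epoch in range(len(dev_scores)):
    tracker.on_epoch_end(epoch)
```

In a real Keras callback the same logic would live in on_epoch_end(self, epoch, logs=None) of a tensorflow.keras.callbacks.Callback subclass, with the metric computed from dev-set predictions rather than looked up in a list.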
-
class capreolus.trainer.tensorflow.TensorFlowTrainer(config=None, provide=None, share_dependency_objects=False, build=True)
Bases: capreolus.trainer.Trainer
Base class for Trainer modules. The purpose of a Trainer is to train a Reranker module and use it to make predictions. Capreolus provides two trainers: PytorchTrainer and TensorFlowTrainer.
Modules should provide:
- a train method that trains a reranker on training and dev (validation) data
- a predict method that uses a reranker to make predictions on data
-
config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']
-
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)
-
create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc)
Creates a single tf.train.Feature instance (i.e., a single sample).
-
write_tf_record_to_file(self, dir_name, tf_features)
Writes the tf record to file. The destination can also be a GCS bucket. TODO: Use generators to optimize memory usage.
-
convert_to_tf_dev_record(self, reranker, dataset)
Similar to self.convert_to_tf_train_record(), but won’t result in multiple files.
-
convert_to_tf_train_record(self, reranker, dataset)
TensorFlow works better if the input data is fed in as tfrecords. Takes in a dataset, iterates through it, and creates multiple tf records from it. The exact structure of the tfrecords is defined by reranker.extractor. For example, see EmbedText.get_tf_feature().
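One way to picture "creates multiple tf records" is as splitting the stream of samples into fixed-size groups, with one output file per group. The sketch below shows only that grouping step and rests on assumptions: the helper name shard_samples and the shard size are hypothetical, the source does not say how Capreolus decides where one record file ends, and the actual serialization (for example with tf.io.TFRecordWriter) is left out so the snippet runs without TensorFlow.

```python
from itertools import islice


def shard_samples(samples, shard_size):
    """Yield lists of at most `shard_size` samples, preserving order.

    Hypothetical helper: each yielded list would become one tfrecord file.
    """
    iterator = iter(samples)
    while True:
        shard = list(islice(iterator, shard_size))
        if not shard:
            return
        yield shard


# Seven toy samples split into shards of three; the last shard is smaller.
shards = list(shard_samples(range(7), shard_size=3))
```

Because the helper consumes an iterator lazily, it never needs the whole dataset in memory, which is the kind of generator-based approach the TODO in write_tf_record_to_file mentions.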
-
get_tf_record_cache_path(self, dataset)
Get the path to the directory where tf records are written to. If using TPUs, this will be a GCS path.
-
get_tf_dev_records(self, reranker, dataset)
- Returns tf records from cache (disk) if applicable
- Else, converts the dataset into tf records, writes them to disk, and returns them
-
get_tf_train_records(self, reranker, dataset)
- Returns tf records from cache (disk) if applicable
- Else, converts the dataset into tf records, writes them to disk, and returns them
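The cache-or-build behaviour shared by get_tf_dev_records and get_tf_train_records can be sketched as follows. This is a minimal illustration, not the Capreolus code: the function name get_or_build_records, the build_fn hook, the file naming scheme, and the "*.tfrecord" glob are all assumptions, and the payloads are placeholder bytes rather than real TFRecord data.

```python
import tempfile
from pathlib import Path


def get_or_build_records(cache_dir, build_fn):
    """Return cached record files if any exist; otherwise build, write, return."""
    cache_dir = Path(cache_dir)
    # Cache hit: reuse whatever record files are already on disk.
    cached = sorted(cache_dir.glob("*.tfrecord")) if cache_dir.is_dir() else []
    if cached:
        return cached
    # Cache miss: convert the dataset (via the hypothetical build_fn) and write it.
    cache_dir.mkdir(parents=True, exist_ok=True)
    paths = []
    for i, payload in enumerate(build_fn()):
        path = cache_dir / f"part-{i:05d}.tfrecord"
        path.write_bytes(payload)
        paths.append(path)
    return paths


build_calls = []


def fake_build():
    # Stand-in for converting a dataset; records how often it is invoked.
    build_calls.append(1)
    return [b"shard-0", b"shard-1"]


with tempfile.TemporaryDirectory() as tmp:
    first = get_or_build_records(Path(tmp) / "records", fake_build)
    second = get_or_build_records(Path(tmp) / "records", fake_build)
    same_files = [p.name for p in first] == [p.name for p in second]
```

The second call returns the files written by the first without invoking the conversion again, which is the point of the cache.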
-
predict(self, reranker, pred_data, pred_fn)
Predict query-document scores on pred_data using the reranker and write a corresponding run file to pred_fn.
Parameters:
- reranker (Reranker) – the Reranker used to score pred_data
- pred_data (IterableDataset) – data to predict on
- pred_fn (Path) – path to write the prediction run file to
Returns: TREC Run
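For readers unfamiliar with run files, the sketch below writes scores in the standard six-column TREC run format (query id, the literal Q0, document id, rank, score, run tag). It is an illustration only: the function write_trec_run, the "sketch" tag, and the assumed in-memory shape of a run ({qid: {docid: score}}) are not taken from the Capreolus source.

```python
import io


def write_trec_run(scores, out, tag="sketch"):
    """Write {qid: {docid: score}} as TREC run lines, best-scoring document first."""
    for qid in sorted(scores):
        ranked = sorted(scores[qid].items(), key=lambda kv: kv[1], reverse=True)
        for rank, (docid, score) in enumerate(ranked, start=1):
            # Columns: qid, literal "Q0", docid, rank, score, run tag
            out.write(f"{qid} Q0 {docid} {rank} {score} {tag}\n")


buffer = io.StringIO()
write_trec_run({"q1": {"d1": 0.2, "d2": 0.9}}, buffer)
run_text = buffer.getvalue()
```

Passing an open file handle instead of the StringIO buffer would write the same lines to a path such as pred_fn.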