:mod:`capreolus.trainer`
========================

.. py:module:: capreolus.trainer


Submodules
----------
.. toctree::
   :titlesonly:
   :maxdepth: 1

   test_trainer/index.rst


Package Contents
----------------

Classes
~~~~~~~

.. autoapisummary::

   capreolus.trainer.Trainer
   capreolus.trainer.PytorchTrainer
   capreolus.trainer.TrecCheckpointCallback
   capreolus.trainer.TensorFlowTrainer


.. data:: logger


.. data:: RESULTS_BASE_PATH


.. py:class:: Trainer

   Bases: :class:`profane.ModuleBase`

   .. attribute:: module_type
      :annotation: = trainer

   .. attribute:: requires_random_seed
      :annotation: = True

   .. method:: get_paths_for_early_stopping(self, train_output_path, dev_output_path)


.. py:class:: PytorchTrainer

   Bases: :class:`capreolus.trainer.Trainer`

   .. attribute:: module_name
      :annotation: = pytorch

   .. attribute:: config_spec

   .. attribute:: config_keys_not_in_path
      :annotation: = ['fastforward', 'boardname']

   .. method:: build(self)

   .. method:: single_train_iteration(self, reranker, train_dataloader)

      Train the model for one iteration using instances from ``train_dataloader``.

      :param reranker: a PyTorch Reranker
      :type reranker: Reranker
      :param train_dataloader: a PyTorch DataLoader that iterates over training instances
      :type train_dataloader: DataLoader
      :returns: average loss over the iteration
      :rtype: float

   .. method:: load_loss_file(self, fn)

      Loads the loss history from ``fn``.

      :param fn: path to a loss.txt file
      :type fn: Path
      :returns: a list of losses ordered by iteration

   .. method:: fastforward_training(self, reranker, weights_path, loss_fn)

      Skip to the last training iteration whose weights were saved. If saved model and optimizer weights are available, this method loads those weights into the model and optimizer and then returns the next iteration to be run. For example, if weights are available for iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and this method will return 11.
      If an error or inconsistency is encountered when checking for weights, this method returns 0.

      This method checks several files to determine whether weights are available. First, ``loss_fn`` is read to determine the last recorded iteration. (If a path is missing or ``loss_fn`` is malformed, 0 is returned.) Second, the weights from the last recorded iteration in ``loss_fn`` are loaded into the model and optimizer. If this succeeds, the method returns ``1 + last recorded iteration``; otherwise, it returns 0. (We consider ``loss_fn`` because it is written at the end of every training iteration.)

      :param reranker: a PyTorch Reranker whose state should be loaded
      :type reranker: Reranker
      :param weights_path: directory containing model and optimizer weights
      :type weights_path: Path
      :param loss_fn: file containing the loss history
      :type loss_fn: Path
      :returns: the next training iteration after fastforwarding. If successful, this is > 0. If no weights are available or they cannot be loaded, 0 is returned.
      :rtype: int

   .. method:: train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric)

      Train a model following the trainer's config (which specifies the batch size, number of iterations, etc.).

      :param train_dataset: training dataset
      :type train_dataset: IterableDataset
      :param train_output_path: directory under which ``train_dataset`` runs and the training loss will be saved
      :type train_output_path: Path
      :param dev_data: dev dataset
      :type dev_data: IterableDataset
      :param dev_output_path: directory where ``dev_data`` runs and metrics will be saved
      :type dev_output_path: Path

   .. method:: load_best_model(self, reranker, train_output_path)

   .. method:: predict(self, reranker, pred_data, pred_fn)

      Predict query-document scores on ``pred_data`` using ``reranker`` and write a corresponding run file to ``pred_fn``.

      :param reranker: a PyTorch Reranker
      :type reranker: Reranker
      :param pred_data: data to predict on
      :type pred_data: IterableDataset
      :param pred_fn: path to write the prediction run file to
      :type pred_fn: Path
      :returns: TREC Run

   .. method:: fill_incomplete_batch(self, batch)

      If a batch is incomplete (i.e., shorter than the desired batch size), this method pads it to the desired batch size. The padding is chosen as follows:

      - If a value is a simple list, the first element of the list is used to pad the batch.
      - If a value is a tensor or NumPy array, ``repeat()`` is applied along the batch dimension.


.. py:class:: TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, validate_freq=1, *args, **kwargs)

   Bases: :class:`tensorflow.keras.callbacks.Callback`

   A callback that runs after every epoch, calculates pytrec_eval-style metrics for the dev dataset, and saves the best model to disk. See ``TensorFlowTrainer.train()`` for how it is invoked.

   .. method:: save_model(self)

   .. method:: on_epoch_begin(self, epoch, logs=None)

   .. method:: on_epoch_end(self, epoch, logs=None)

   .. staticmethod:: get_preds_in_trec_format(predictions, dev_data)

      Takes a list of predictions and returns a dict that can be passed to pytrec_eval. As a side effect, it also writes the predictions to a file in TREC run format.


.. py:class:: TensorFlowTrainer

   Bases: :class:`capreolus.trainer.Trainer`

   .. attribute:: module_name
      :annotation: = tensorflow

   .. attribute:: config_spec

   .. attribute:: config_keys_not_in_path
      :annotation: = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']

   .. method:: build(self)

   .. method:: validate(self)

   .. method:: get_optimizer(self)

   .. method:: fastforward_training(self, reranker, weights_path, loss_fn)

   .. method:: load_best_model(self, reranker, train_output_path)

   .. method:: apply_gradients(self, weights, grads)

   .. method:: train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric)

   .. method:: create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc)

      Creates a single ``tf.train.Feature`` instance (i.e., a single sample).

   .. method:: write_tf_record_to_file(self, dir_name, tf_features)

      Write the TF records to file. The destination can also be a GCS bucket.

      TODO: Use generators to optimize memory usage.

   .. method:: convert_to_tf_dev_record(self, reranker, dataset)

      Similar to ``convert_to_tf_train_record()``, but does not produce multiple files.

   .. method:: convert_to_tf_train_record(self, reranker, dataset)

      TensorFlow works better if the input data is fed in as TFRecords. This method iterates through a dataset and creates multiple TF records from it. The exact structure of the records is defined by ``reranker.extractor``; for example, see ``EmbedText.get_tf_feature()``.

   .. method:: get_tf_record_cache_path(self, dataset)

      Get the path to the directory where TF records are written. When using TPUs, this is a GCS path.

   .. method:: cache_exists(self, dataset)

   .. method:: load_tf_records_from_file(self, reranker, filenames, batch_size)

   .. method:: load_cached_tf_records(self, reranker, dataset, batch_size)

   .. method:: get_tf_dev_records(self, reranker, dataset)

      1. Returns TF records from the cache (disk), if applicable.
      2. Otherwise, converts the dataset into TF records, writes them to disk, and returns them.

   .. method:: get_tf_train_records(self, reranker, dataset)

      1. Returns TF records from the cache (disk), if applicable.
      2. Otherwise, converts the dataset into TF records, writes them to disk, and returns them.

   .. method:: predict(self, reranker, pred_data, pred_fn)

      Predict query-document scores on ``pred_data`` using ``reranker`` and write a corresponding run file to ``pred_fn``.

      :param reranker: a Reranker with a TensorFlow model
      :type reranker: Reranker
      :param pred_data: data to predict on
      :type pred_data: IterableDataset
      :param pred_fn: path to write the prediction run file to
      :type pred_fn: Path
      :returns: TREC Run
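
The resume logic described for ``PytorchTrainer.fastforward_training`` can be illustrated with a minimal, framework-free sketch. It assumes the loss file stores one loss per line in iteration order, and ``load_weights`` is a hypothetical callback standing in for restoring the model and optimizer weights saved after a given iteration; the helper names are illustrative and not taken from Capreolus's implementation.

.. code-block:: python

   from pathlib import Path
   from typing import Callable, List, Optional


   def parse_losses(text: str) -> Optional[List[float]]:
       """Parse one loss per non-empty line; return None if any line is malformed."""
       try:
           return [float(line) for line in text.splitlines() if line.strip()]
       except ValueError:
           return None


   def read_losses(loss_path: Path) -> Optional[List[float]]:
       """Read a loss file; a missing file (or any other OSError) yields None."""
       try:
           return parse_losses(loss_path.read_text())
       except OSError:
           return None


   def next_iteration(
       losses: Optional[List[float]], load_weights: Callable[[int], bool]
   ) -> int:
       """Return the iteration to resume from, or 0 if training must start over."""
       if not losses:
           # Missing, malformed, or empty loss history: start from scratch.
           return 0
       last_recorded = len(losses) - 1  # zero-indexed iteration of the final loss
       # Hypothetical callback: restores model and optimizer weights, reports success.
       if not load_weights(last_recorded):
           return 0
       return last_recorded + 1

With 11 recorded losses (iterations 0-10) and a successful ``load_weights(10)``, ``next_iteration`` returns 11, matching the example in the docstring; a missing or malformed loss file yields 0.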
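
Both ``predict()`` methods and ``TrecCheckpointCallback.get_preds_in_trec_format()`` handle two representations of the same scores: the nested dict that pytrec_eval evaluates (``{qid: {docid: score}}``) and a TREC run file, whose lines have the form ``qid Q0 docid rank score tag``. The sketch below converts predictions into both. It assumes the scores arrive in the same order as their ``(qid, docid)`` pairs; the helper names and the ``tag`` value are hypothetical, not part of Capreolus.

.. code-block:: python

   from typing import Dict, Iterable, List, Tuple

   Run = Dict[str, Dict[str, float]]


   def preds_to_run(pairs: Iterable[Tuple[str, str]], scores: Iterable[float]) -> Run:
       """Group (qid, docid) pairs and their scores into a pytrec_eval-style run dict."""
       run: Run = {}
       for (qid, docid), score in zip(pairs, scores):
           run.setdefault(qid, {})[docid] = float(score)
       return run


   def format_trec_run(run: Run, tag: str = "sketch") -> List[str]:
       """Format a run dict as TREC run lines, ranking each query's docs by score."""
       lines = []
       for qid in sorted(run):
           ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
           for rank, (docid, score) in enumerate(ranked, start=1):
               lines.append(f"{qid} Q0 {docid} {rank} {score} {tag}")
       return lines

Writing ``"\n".join(lines) + "\n"`` to a path such as ``pred_fn`` would produce a run file, while the dict itself can be scored with ``pytrec_eval.RelevanceEvaluator(qrels, {"map"}).evaluate(run)``.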
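
The padding rule described for ``PytorchTrainer.fill_incomplete_batch`` can be sketched with NumPy. This is an illustration under assumptions, not Capreolus's code: the batch is taken to be a dict whose fields all share the same length, ``batch_size`` is passed explicitly rather than read from the trainer's config, and arrays are padded by repeating their first row along axis 0 (the batch dimension).

.. code-block:: python

   from typing import Any, Dict

   import numpy as np


   def fill_incomplete_batch(batch: Dict[str, Any], batch_size: int) -> Dict[str, Any]:
       """Pad every field of ``batch`` to ``batch_size`` entries (illustrative sketch)."""
       filled = {}
       for key, value in batch.items():
           missing = batch_size - len(value)
           if missing <= 0:
               # Already full: keep the field unchanged.
               filled[key] = value
           elif isinstance(value, list):
               # Simple lists are padded with copies of their first element.
               filled[key] = value + [value[0]] * missing
           elif isinstance(value, np.ndarray):
               # Arrays repeat their first row along the batch dimension (axis 0).
               padding = np.repeat(value[:1], missing, axis=0)
               filled[key] = np.concatenate([value, padding], axis=0)
           else:
               raise TypeError(f"cannot pad field {key!r} of type {type(value).__name__}")
       return filled

For example, a two-item batch padded to four entries repeats the first ``qid`` and the first document row twice.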