capreolus.trainer.tensorflow
Module Contents
Classes
TensorflowTrainer | Trains (optionally) on the TPU.
Functions
Attributes
- class capreolus.trainer.tensorflow.TensorflowTrainer(config=None, provide=None, share_dependency_objects=False, build=True)
Bases: capreolus.trainer.Trainer
Trains (optionally) on the TPU. Uses two optimizers with different learning rates: one for the BERT layers and another for the classifier layers. The bertlr learning rate supports configurable warmup and decay. WARNING: The optimizers depend on specific layer names (see train()). If your reranker does not have layers with ‘bert’ in the name, the normal learning rate will be applied to all of its layers instead of the value supplied through the bertlr ConfigOption.
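The learning-rate split described above can be sketched in plain Python. This is a minimal sketch, not the trainer's actual code: it assumes variables are identified by name strings, and the helper names (split_by_bert, lr_for_variable, bert_lr_at), the parameter lr for the normal learning rate, and the schedule itself (linear warmup followed by exponential decay) are hypothetical stand-ins for whatever train() really configures.

```python
from typing import List, Tuple


def split_by_bert(var_names: List[str]) -> Tuple[List[str], List[str]]:
    """Partition variable names into (bert, other) by a substring test on 'bert'."""
    bert = [name for name in var_names if "bert" in name]
    other = [name for name in var_names if "bert" not in name]
    return bert, other


def lr_for_variable(name: str, lr: float, bertlr: float) -> float:
    """Hypothetical rule: only names containing 'bert' get bertlr; everything else gets lr."""
    return bertlr if "bert" in name else lr


def bert_lr_at(step: int, bertlr: float, warmup_steps: int,
               decay_rate: float, decay_steps: int) -> float:
    """Assumed schedule: linear warmup to bertlr, then exponential decay."""
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp up linearly from 0 towards bertlr during warmup.
        return bertlr * step / warmup_steps
    if decay_steps <= 0:
        # No decay configured: hold bertlr constant after warmup.
        return bertlr
    # Smoothly multiply by decay_rate once per decay_steps steps after warmup.
    return bertlr * decay_rate ** ((step - warmup_steps) / decay_steps)
```

The WARNING above falls straight out of lr_for_variable: a layer named, say, "encoder/kernel" silently receives lr rather than bertlr.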
- config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']
- train(reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)
- form_tf_record_cache_path(dataset)
Returns the path to the directory where TF records are written. When training on TPUs, this is a GCS path.
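One way such a path might be chosen is sketched below, reusing the tpuname and storage names from config_keys_not_in_path. The directory layout, the dataset_key argument, and the rule "a TPU is configured, so storage must be a gs:// bucket" are all assumptions for illustration, not the method's real logic.

```python
from pathlib import PurePosixPath
from typing import Optional


def form_cache_path(dataset_key: str, cache_root: str,
                    tpuname: Optional[str] = None,
                    storage: Optional[str] = None) -> str:
    """Return the TFRecord cache directory (hypothetical layout)."""
    if tpuname:
        # Assumption: TPU workers read input from Cloud Storage, so use a gs:// bucket.
        if not storage or not storage.startswith("gs://"):
            raise ValueError("TPU training requires a gs:// storage location")
        return f"{storage.rstrip('/')}/tfrecords/{dataset_key}"
    # Local training: a plain directory under the cache root.
    return str(PurePosixPath(cache_root) / "tfrecords" / dataset_key)
```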
- find_cached_tf_records(dataset, required_sample_count)
Looks for cached TF records for the given dataset that contain at least required_sample_count samples.
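The "at least N samples" lookup can be sketched as choosing the smallest cache that is large enough; a larger cache can still be used by reading only its first N samples (for example with tf.data.Dataset.take). The on-disk layout here, one subdirectory per sample count, is a hypothetical convention and not necessarily how capreolus names its cache.

```python
import os
from typing import Iterable, Optional


def smallest_sufficient_count(entries: Iterable[str], required_sample_count: int) -> Optional[int]:
    """Pick the smallest numeric entry that is >= required_sample_count, or None."""
    counts = [int(e) for e in entries if e.isdecimal() and int(e) >= required_sample_count]
    return min(counts) if counts else None


def find_cached_tf_records(cache_dir: str, required_sample_count: int) -> Optional[str]:
    """Hypothetical layout: cache_dir/<sample_count>/ holds records with that many samples."""
    if not os.path.isdir(cache_dir):
        return None
    subdirs = [e for e in os.listdir(cache_dir) if os.path.isdir(os.path.join(cache_dir, e))]
    count = smallest_sufficient_count(subdirs, required_sample_count)
    return None if count is None else os.path.join(cache_dir, str(count))
```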
- get_tf_train_records(reranker, dataset)
Returns TF records from the on-disk cache if available. Otherwise, converts the dataset into TF records, writes them to disk, and returns them.
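The control flow shared by this method and get_tf_dev_records is an ordinary cache-or-build pattern. The sketch passes in three hypothetical callables standing in for the cache lookup, the conversion step, and the writer; it illustrates the flow only and is not the method's actual signature.

```python
from typing import Callable, List, Optional


def get_or_build_records(find_cached: Callable[[], Optional[List[str]]],
                         convert: Callable[[], List[str]],
                         write: Callable[[List[str]], None]) -> List[str]:
    """Cache-or-build: reuse cached records, else convert, persist, and return."""
    cached = find_cached()
    if cached is not None:
        # Cache hit: skip the (expensive) conversion entirely.
        return cached
    # Cache miss: convert the dataset, write it to disk, then hand it back.
    records = convert()
    write(records)
    return records
```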
- convert_to_tf_train_record(reranker, dataset)
TensorFlow input pipelines work best when the input data is fed in as TFRecords. This method takes a dataset, iterates through it, and creates multiple TF records from it, producing exactly niters * itersize samples. The exact structure of the records is defined by reranker.extractor; for an example, see BertPassage.get_tf_train_feature().
- Parameters
reranker (Reranker) – a capreolus.reranker.Reranker instance
dataset (Sampler) – a capreolus.sampler.Sampler instance
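The "exactly niters * itersize samples" contract can be sketched with itertools.islice over the sampler. Converting each sample into a TF feature is the extractor's job and is omitted here; take_exact, shard, and shard_size are hypothetical names, and the real method may split records across files differently.

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def take_exact(samples: Iterable[T], niters: int, itersize: int) -> List[T]:
    """Draw exactly niters * itersize samples; fail loudly if the sampler runs dry."""
    total = niters * itersize
    drawn = list(islice(samples, total))
    if len(drawn) < total:
        raise ValueError(f"sampler produced {len(drawn)} samples, expected {total}")
    return drawn


def shard(items: List[T], shard_size: int) -> Iterator[List[T]]:
    """Split samples into consecutive chunks, one chunk per record file (hypothetical)."""
    for start in range(0, len(items), shard_size):
        yield items[start:start + shard_size]
```

Drawing a fixed count rather than exhausting the sampler matters because training samplers are often infinite generators.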
- get_tf_dev_records(reranker, dataset)
Returns TF records from the on-disk cache if available. Otherwise, converts the dataset into TF records, writes them to disk, and returns them.
- write_tf_record_to_file(dir_name, tf_features, file_name=None)
Performs the actual write of the TF records to a file. The destination can also be a GCS bucket. TODO: Use generators to optimize memory usage.
- static get_preds_in_trec_format(predictions, dev_data)
Takes a list of predictions and returns a dict that can be fed into pytrec_eval. As a side effect, it also writes the predictions to a file in TREC format.
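The two outputs can be sketched as follows: the nested {qid: {docid: score}} dict that pytrec_eval's RelevanceEvaluator.evaluate() accepts, and TREC run lines of the form "qid Q0 docid rank score tag". This assumes dev_data can be reduced to an ordered list of (qid, docid) pairs aligned with the predictions; that pairing and the default tag "capreolus" are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def preds_to_run(predictions: Sequence[float],
                 pairs: Sequence[Tuple[str, str]]) -> Dict[str, Dict[str, float]]:
    """Build the {qid: {docid: score}} dict used as a pytrec_eval run."""
    if len(predictions) != len(pairs):
        raise ValueError("need exactly one prediction per (qid, docid) pair")
    run: Dict[str, Dict[str, float]] = {}
    for score, (qid, docid) in zip(predictions, pairs):
        run.setdefault(qid, {})[docid] = float(score)
    return run


def run_to_trec_lines(run: Dict[str, Dict[str, float]], tag: str = "capreolus") -> List[str]:
    """Format a run as TREC lines: qid Q0 docid rank score tag (rank by descending score)."""
    lines = []
    for qid in sorted(run):
        ranked = sorted(run[qid].items(), key=lambda kv: kv[1], reverse=True)
        for rank, (docid, score) in enumerate(ranked, start=1):
            lines.append(f"{qid} Q0 {docid} {rank} {score} {tag}")
    return lines
```

Writing the file is then a single "\n".join(lines) to the output path.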
- get_wrapped_model(model)
A wrapped model is needed because the logic differs slightly depending on whether the input is pointwise or pairwise:
1. For pointwise input, there is no “negative document”, so the wrapper only has to execute the model’s call() method.
2. For pairwise input, the wrapper executes the model’s call() method twice (once for the positive document and again for the negative document) and stacks the results before passing them to a pairwise loss function.
The alternative would have been to let the user configure everything manually, for example loss=crossentropy reranker.trainer.input=pairwise …, but there are already too many ConfigOptions.
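The pointwise/pairwise distinction can be sketched with plain Python callables standing in for the Keras model's call(); in TensorFlow the stacking step would typically be tf.stack([pos, neg], axis=1). The names wrap_model, Scorer, and pairwise_hinge are hypothetical, and the hinge loss is only one example of a pairwise loss, not necessarily the one capreolus uses.

```python
from typing import Callable, List, Sequence, Tuple

# Stand-in for the reranker's call(): maps (query, doc) to a relevance score.
Scorer = Callable[[str, str], float]


def wrap_model(score: Scorer, pairwise: bool):
    """Return a callable that handles pointwise or pairwise batches."""

    def pointwise_call(batch: Sequence[Tuple[str, str]]) -> List[float]:
        # One call per (query, doc): no negative document involved.
        return [score(q, d) for q, d in batch]

    def pairwise_call(batch: Sequence[Tuple[str, str, str]]) -> List[Tuple[float, float]]:
        # Score positives, then negatives, then stack into (pos, neg) rows.
        pos = [score(q, p) for q, p, _ in batch]
        neg = [score(q, n) for q, _, n in batch]
        return list(zip(pos, neg))

    return pairwise_call if pairwise else pointwise_call


def pairwise_hinge(rows: Sequence[Tuple[float, float]], margin: float = 1.0) -> float:
    """Mean hinge loss max(0, margin - pos + neg) over the stacked rows."""
    return sum(max(0.0, margin - pos + neg) for pos, neg in rows) / len(rows)
```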
- fastforward_training(model, weights_path, loss_fn, best_metric_fn)
Skip to the last training iteration whose weights were saved.
If saved model and optimizer weights are available, this method will load those weights into model and optimizer, and then return the next iteration to be run. For example, if weights are available for iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 will be loaded, and this method will return 11.
If an error or inconsistency is encountered when checking for weights, this method returns 0.
This method checks several files to determine if weights “are available”. First, loss_fn is read to determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.) Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer. If this is successful, the method returns 1 + last recorded iteration. If not, it returns 0. (We consider loss_fn because it is written at the end of every training iteration.)
- Parameters
model (Reranker) – a Reranker whose state should be loaded
weights_path (Path) – directory containing model and optimizer weights
loss_fn (Path) – file containing loss history
- Returns
- The next training iteration after fast-forwarding. If successful, this is > 0. If no weights are available or they cannot be loaded, 0 is returned.
- Return type
int
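The fast-forward procedure described above can be sketched as follows. The loss-file format (one float per completed iteration), the checkpoint naming iter_<n>, and the load_weights callable are hypothetical; the return values follow the documented contract: 0 on any missing or malformed input, otherwise one past the last recorded iteration.

```python
import os
from typing import Callable, Iterable


def last_recorded_iteration(lines: Iterable[str]) -> int:
    """Assumed loss-file format: one float per completed iteration. Returns -1 if empty."""
    losses = [float(line) for line in lines if line.strip()]  # ValueError if malformed
    return len(losses) - 1


def fastforward(weights_path: str, loss_fn: str,
                load_weights: Callable[[str], None]) -> int:
    """Return the next iteration to run, or 0 if anything is missing or inconsistent."""
    try:
        with open(loss_fn) as f:
            last = last_recorded_iteration(f)
    except (OSError, ValueError):
        return 0  # missing or malformed loss file
    if last < 0:
        return 0  # no iteration recorded yet
    ckpt = os.path.join(weights_path, f"iter_{last}")  # hypothetical checkpoint naming
    if not os.path.exists(ckpt):
        return 0
    try:
        load_weights(ckpt)  # would restore both model and optimizer state
    except Exception:
        return 0
    return last + 1
```

With 11 recorded losses (iterations 0 to 10), last_recorded_iteration returns 10 and a successful load yields 11, matching the example in the description.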