capreolus.trainer.tensorflow
Module Contents
Classes
TensorflowTrainer: Trains a reranker, optionally on a TPU.
-
class capreolus.trainer.tensorflow.TensorflowTrainer(config=None, provide=None, share_dependency_objects=False, build=True)
Bases: capreolus.trainer.Trainer
Trains a reranker, optionally on a TPU. Uses two optimizers with different learning rates: one for the BERT layers and another for the classifier layers. Warmup and decay for bertlr are configurable.
WARNING: The optimizers depend on specific layer names (see train()). If your reranker does not have layers with ‘bert’ in the name, the normal learning rate will be applied to them instead of the value supplied through the bertlr ConfigOption.
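A minimal sketch of the name-based split described above, in plain Python rather than TensorFlow. The helper names, the linear warmup/decay shape, and the parameters warmup_steps and decay_steps are assumptions for illustration, not the trainer's actual implementation; the point is only how a substring match on ‘bert’ decides which learning rate a layer receives.

```python
from typing import Dict, List


def bert_lr_at_step(step: int, bertlr: float, warmup_steps: int, decay_steps: int) -> float:
    """Hypothetical schedule: linear warmup up to bertlr, then linear decay towards 0."""
    if warmup_steps > 0 and step < warmup_steps:
        return bertlr * (step + 1) / warmup_steps
    if decay_steps > 0:
        progress = min(step - warmup_steps, decay_steps) / decay_steps
        return bertlr * (1.0 - progress)
    return bertlr


def assign_learning_rates(
    layer_names: List[str],
    lr: float,
    bertlr: float,
    step: int = 0,
    warmup_steps: int = 0,
    decay_steps: int = 0,
) -> Dict[str, float]:
    """Map each layer name to a learning rate; only names containing 'bert' get the bertlr schedule."""
    rates: Dict[str, float] = {}
    for name in layer_names:
        if "bert" in name:
            rates[name] = bert_lr_at_step(step, bertlr, warmup_steps, decay_steps)
        else:
            # Layers without 'bert' in their name silently fall back to the normal lr.
            rates[name] = lr
    return rates
```

Calling assign_learning_rates on a model whose layers contain no ‘bert’ substring returns lr for every layer, which is exactly the situation the WARNING describes.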
-
config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']
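As the attribute's name suggests, these options are kept out of the path that identifies a trainer's output, presumably because they control runtime behavior rather than the results. The helper below is a hypothetical sketch of that idea, not Capreolus's own path-building code; the key-value joining scheme is an assumption.

```python
from typing import Any, Dict, Iterable

CONFIG_KEYS_NOT_IN_PATH = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"]


def config_to_path_component(config: Dict[str, Any], excluded: Iterable[str] = CONFIG_KEYS_NOT_IN_PATH) -> str:
    """Hypothetical: build a deterministic path fragment from a config, skipping excluded keys."""
    skip = set(excluded)
    # Sort keys so that the same config always maps to the same fragment.
    parts = [f"{key}-{config[key]}" for key in sorted(config) if key not in skip]
    return "_".join(parts)
```

With this scheme, changing tpuname or usecache leaves the fragment unchanged, so two runs that differ only in where they execute would share one output path.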
-
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)
-
form_tf_record_cache_path(self, dataset)
Returns the path to the directory where TF records are written. When using TPUs, this is a GCS path.
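A hypothetical sketch of the local-versus-GCS choice. The function name and its parameters (local_cache, bucket, relative_path, use_tpu) are assumptions for illustration and do not reflect the method's real signature.

```python
import os


def tf_record_cache_dir(local_cache: str, bucket: str, relative_path: str, use_tpu: bool) -> str:
    """Return a gs:// URI when training on a TPU, otherwise a local directory (hypothetical helper)."""
    if use_tpu:
        # Assumption: TPU training reads its input records from Google Cloud Storage.
        return f"gs://{bucket.strip('/')}/{relative_path.strip('/')}"
    return os.path.join(local_cache, relative_path)
```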
-
find_cached_tf_records(self, dataset, required_sample_count)
Looks for TF records for the given dataset that contain at least the specified number of samples.
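One way such a lookup could work, sketched under an assumed cache layout in which each subdirectory of the cache is named after the number of samples it holds. The layout and both function names are hypothetical. Any set with at least the required count would do; picking the smallest one is an arbitrary choice here.

```python
import os
from typing import Iterable, Optional


def choose_cached_count(available_counts: Iterable[int], required_sample_count: int) -> Optional[int]:
    """Pick the smallest cached sample count that still covers the requirement."""
    eligible = [count for count in available_counts if count >= required_sample_count]
    return min(eligible) if eligible else None


def find_cached_record_dir(cache_dir: str, required_sample_count: int) -> Optional[str]:
    """Hypothetical layout: cache_dir/<sample_count>/ holds one set of TF record files."""
    if not os.path.isdir(cache_dir):
        return None
    # Only purely numeric directory names are treated as cached record sets.
    counts = [int(name) for name in os.listdir(cache_dir) if name.isdigit()]
    chosen = choose_cached_count(counts, required_sample_count)
    return None if chosen is None else os.path.join(cache_dir, str(chosen))
```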
-
get_tf_train_records(self, reranker, dataset)
- Returns TF records from the cache (disk) if applicable.
- Otherwise, converts the dataset into TF records, writes them to disk, and returns them.
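The cache-or-convert flow can be sketched with the three steps passed in as callables. The function name, the callables, and the use_cache flag are hypothetical stand-ins, not the trainer's real signature.

```python
from typing import Callable, List, Optional


def get_records(
    find_cached: Callable[[], Optional[List[str]]],
    convert: Callable[[], List[bytes]],
    write: Callable[[List[bytes]], List[str]],
    use_cache: bool = True,
) -> List[str]:
    """Return cached record files if available; otherwise convert, write, and return the new files."""
    if use_cache:
        cached = find_cached()
        if cached is not None:
            # Cache hit: skip the (potentially slow) conversion entirely.
            return cached
    records = convert()
    return write(records)
```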
-
convert_to_tf_train_record(self, reranker, dataset)
TensorFlow works better when the input data is fed in as TFRecords. Takes in a dataset, iterates through it, and creates multiple TF records from it. Creates exactly niters * itersize * batch_size samples. The exact structure of the TF records is defined by reranker.extractor; for an example, see BertPassage.get_tf_train_feature().
Parameters:
  reranker: a capreolus.reranker.Reranker instance
  dataset: a capreolus.sampler.Sampler instance
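A sketch of how exactly niters * itersize * batch_size samples can be drawn, assuming the sampler behaves like an iterable that may yield samples indefinitely, so itertools.islice caps it at the exact count. The function name and the error raised on a short sampler are assumptions; turning each sample into a TF record (which the extractor defines) is left out.

```python
import itertools
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def take_training_samples(sampler: Iterable[T], niters: int, itersize: int, batch_size: int) -> List[T]:
    """Draw exactly niters * itersize * batch_size samples from a (possibly infinite) sampler."""
    total = niters * itersize * batch_size
    samples = list(itertools.islice(sampler, total))
    if len(samples) < total:
        raise ValueError(f"sampler produced {len(samples)} samples, expected {total}")
    return samples
```

For example, niters=2, itersize=3 and batch_size=4 yield 24 samples.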
-
get_tf_dev_records(self, reranker, dataset)
- Returns TF records from the cache (disk) if applicable.
- Otherwise, converts the dataset into TF records, writes them to disk, and returns them.
-
write_tf_record_to_file(self, dir_name, tf_features)
Writes the TF records to a file. The destination can also be a GCS bucket.
TODO: Use generators to optimize memory usage.
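The generator idea from the TODO can be sketched without TensorFlow: records are produced and written one at a time, so the full list never has to sit in memory. The length-prefixed format here is a stand-in; real TFRecord files also carry CRC checksums and would be written with tf.io.TFRecordWriter. All function names are hypothetical.

```python
import io
import struct
from typing import BinaryIO, Iterable, Iterator


def serialize_samples(samples: Iterable[str]) -> Iterator[bytes]:
    """Lazily encode samples one at a time instead of building a full list in memory."""
    for sample in samples:
        yield sample.encode("utf-8")


def write_records(stream: BinaryIO, records: Iterable[bytes]) -> int:
    """Write each record with a little-endian uint64 length prefix; return how many were written."""
    count = 0
    for record in records:
        stream.write(struct.pack("<Q", len(record)))
        stream.write(record)
        count += 1
    return count


def read_records(stream: BinaryIO) -> Iterator[bytes]:
    """Read back records written by write_records until the stream is exhausted."""
    while True:
        header = stream.read(8)
        if len(header) < 8:
            return
        (length,) = struct.unpack("<Q", header)
        yield stream.read(length)


if __name__ == "__main__":
    buffer = io.BytesIO()
    write_records(buffer, serialize_samples(["query-doc-1", "query-doc-2"]))
    buffer.seek(0)
    print(list(read_records(buffer)))
```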
-
static
get_preds_in_trec_format(predictions, dev_data)
Takes in a list of predictions and returns a dict that can be fed into pytrec_eval. As a side effect, it also writes the predictions to a file in TREC format.
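A sketch of both outputs, assuming the predictions arrive as flat scores aligned with parallel lists of query and document IDs (that alignment with dev_data is an assumption). The run dict has the {qid: {docid: score}} shape that pytrec_eval expects, and each TREC run line has the form "qid Q0 docid rank score tag". The function names and the "capreolus" run tag are hypothetical.

```python
from typing import Dict, List, Sequence


def preds_to_run(qids: Sequence[str], docids: Sequence[str], scores: Sequence[float]) -> Dict[str, Dict[str, float]]:
    """Group flat (qid, docid, score) predictions into pytrec_eval's {qid: {docid: score}} shape."""
    run: Dict[str, Dict[str, float]] = {}
    for qid, docid, score in zip(qids, docids, scores):
        run.setdefault(qid, {})[docid] = float(score)
    return run


def trec_run_lines(run: Dict[str, Dict[str, float]], tag: str = "capreolus") -> List[str]:
    """Format 'qid Q0 docid rank score tag' lines, ranking each query's docs by descending score."""
    lines: List[str] = []
    for qid in sorted(run):
        ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
        for rank, (docid, score) in enumerate(ranked, start=1):
            lines.append(f"{qid} Q0 {docid} {rank} {score} {tag}")
    return lines
```

The dict can then be passed to pytrec_eval.RelevanceEvaluator(qrels, measures).evaluate(run), and the lines can be joined with newlines and written to the run file.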
-
get_wrapped_model(self, model)
A wrapped model is needed because the logic differs slightly depending on whether the input is pointwise or pairwise:
1. For pointwise input, there is no “negative document”, so we only need to execute the model’s call() method.
2. For pairwise input, we need to execute the model’s call() method twice (once for the positive document and again for the negative document), and then stack the results together before passing them to a pairwise loss function.
The alternative would have been to let the user configure everything manually, for example: loss=crossentropy reranker.trainer.input=pairwise … However, there are already too many ConfigOptions.
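A framework-free sketch of the two cases, where a plain Python callable stands in for the model's call() method and "stacking" is a list of [positive, negative] score pairs. The names wrapped_scores and pairwise_softmax_loss are hypothetical, and the softmax cross-entropy over each pair is only one example of a pairwise loss, not necessarily the one Capreolus uses.

```python
import math
from typing import Callable, List, Optional, Sequence

Doc = Sequence[float]
ScoreFn = Callable[[Doc], float]


def wrapped_scores(score: ScoreFn, pos_docs: Sequence[Doc], neg_docs: Optional[Sequence[Doc]] = None):
    """Pointwise: one call per example. Pairwise: two calls per example, stacked as [pos, neg]."""
    if neg_docs is None:
        return [score(doc) for doc in pos_docs]
    return [[score(pos), score(neg)] for pos, neg in zip(pos_docs, neg_docs)]


def pairwise_softmax_loss(stacked: List[List[float]]) -> float:
    """Mean of -log(softmax(scores)[0]) over [pos, neg] pairs, computed with a stable log-sum-exp."""
    total = 0.0
    for pos, neg in stacked:
        m = max(pos, neg)
        log_sum_exp = m + math.log(math.exp(pos - m) + math.exp(neg - m))
        total += log_sum_exp - pos
    return total / len(stacked)
```

With equal positive and negative scores the loss is log 2, and it approaches 0 as the positive score pulls ahead of the negative one.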
-