capreolus.trainer
Submodules
Package Contents¶
Classes¶
- Trainer(config=None, provide=None, share_dependency_objects=False, build=True) – base class for Capreolus trainers (a profane module).
- PytorchTrainer(config=None, provide=None, share_dependency_objects=False, build=True) – trainer for PyTorch rerankers.
- TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs) – a callback that runs after every epoch and calculates pytrec_eval-style metrics for the dev dataset.
- TensorFlowTrainer(config=None, provide=None, share_dependency_objects=False, build=True) – trainer for TensorFlow rerankers.
-
class capreolus.trainer.Trainer(config=None, provide=None, share_dependency_objects=False, build=True)
Bases: profane.ModuleBase
Base class for profane modules. Module construction proceeds as follows:
1) Any config options not present in config are filled in with their default values. Config options and their defaults are specified in the config_spec class attribute.
2) Any dependencies declared in the dependencies class attribute are recursively instantiated. If the dependency object is present in provide, this object is used instead of instantiating a new object for the dependency.
3) The module object’s config variable is updated to reflect the configs of its dependencies and then frozen.
After construction is complete, each dependency is available as an instance variable named after its dependency key (self.<dependency key>).
Parameters:
- config – dictionary containing a config to apply to this module and its dependencies
- provide – dictionary mapping dependency keys to module objects
- share_dependency_objects – if True, dependencies are cached in the registry based on their configs and reused. See the share_objects argument of ModuleBase.create.
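The three construction steps above can be illustrated with a small, self-contained sketch. It is not profane's implementation: MiniModule, Optimizer, ToyTrainer and the way configs are merged are hypothetical simplifications; only the roles of config_spec, dependencies and provide come from the description above.

```python
from types import MappingProxyType


class MiniModule:
    """Hypothetical stand-in for profane.ModuleBase (illustration only)."""

    config_spec = {}   # option name -> default value
    dependencies = {}  # dependency key -> MiniModule subclass

    def __init__(self, config=None, provide=None):
        config = dict(config or {})
        provide = provide or {}

        # 1) Fill in defaults for options missing from config.
        own = {name: config.get(name, default) for name, default in self.config_spec.items()}

        # 2) Recursively instantiate dependencies, reusing any object in provide.
        for key, cls in self.dependencies.items():
            dep = provide[key] if key in provide else cls(config=config.get(key))
            setattr(self, key, dep)  # reachable afterwards as self.<key>
            own[key] = dep.config    # 3a) reflect the dependency's config

        # 3b) Freeze the config by exposing only a read-only view of it.
        self.config = MappingProxyType(own)


class Optimizer(MiniModule):
    config_spec = {"lr": 0.001}


class ToyTrainer(MiniModule):
    config_spec = {"batch": 32, "niters": 10}
    dependencies = {"optimizer": Optimizer}


trainer = ToyTrainer(config={"batch": 16, "optimizer": {"lr": 0.01}})
print(trainer.config["batch"], trainer.config["niters"], trainer.optimizer.config["lr"])  # 16 10 0.01
```

Passing provide={"optimizer": some_optimizer} makes the sketch reuse that object instead of building a new one, mirroring step 2.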
-
class capreolus.trainer.PytorchTrainer(config=None, provide=None, share_dependency_objects=False, build=True)
Bases: capreolus.trainer.Trainer
Inherits module construction from profane.ModuleBase; the construction steps and the config, provide and share_dependency_objects parameters are as described for Trainer.
-
single_train_iteration(self, reranker, train_dataloader)
Train the model for one iteration using instances from train_dataloader.
Parameters:
- reranker (Reranker) – a PyTorch Reranker
- train_dataloader (DataLoader) – a PyTorch DataLoader that iterates over training instances
Returns: average loss over the iteration
Return type: float
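The shape of one training iteration can be sketched without PyTorch: consume a fixed number of batches, take one update per batch and return the mean loss. This is a sketch under assumptions, not Capreolus code: train_step and batches_per_iteration are hypothetical stand-ins. In PyTorch, train_step would typically call optimizer.zero_grad(), loss.backward() and optimizer.step() for one batch.

```python
from itertools import islice
from typing import Callable, Iterator, TypeVar

Batch = TypeVar("Batch")


def single_train_iteration_sketch(
    train_step: Callable[[Batch], float],
    batches: Iterator[Batch],
    batches_per_iteration: int,
) -> float:
    """Run batches_per_iteration training steps and return their average loss.

    train_step is a hypothetical callable that does the forward pass, backward
    pass and optimizer update for one batch and returns that batch's loss.
    """
    # islice consumes from the shared iterator, so the next call resumes
    # where this iteration stopped instead of restarting the data.
    losses = [train_step(batch) for batch in islice(batches, batches_per_iteration)]
    if not losses:
        raise ValueError("the data iterator produced no batches")
    return sum(losses) / len(losses)


batches = iter([[1.0], [2.0], [3.0], [4.0]])


def fake_step(batch):
    return batch[0]  # pretend the loss equals the batch's only value


print(single_train_iteration_sketch(fake_step, batches, 2))  # 1.5
print(single_train_iteration_sketch(fake_step, batches, 2))  # 3.5
```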
-
load_loss_file(self, fn)
Loads the loss history from fn.
Parameters: fn (Path) – path to a loss.txt file
Returns: a list of losses ordered by iteration
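A minimal sketch of reading a loss history. It assumes, without confirmation from this page, that loss.txt stores one loss value per line in iteration order; load_loss_file_sketch is a hypothetical name.

```python
import tempfile
from pathlib import Path
from typing import List


def load_loss_file_sketch(fn: Path) -> List[float]:
    """Return the losses in iteration order, assuming one float per line."""
    with open(fn, "rt") as f:
        # float() tolerates the trailing newline; blank lines are skipped.
        return [float(line) for line in f if line.strip()]


with tempfile.TemporaryDirectory() as tmp:
    loss_path = Path(tmp) / "loss.txt"
    loss_path.write_text("0.9\n0.7\n0.5\n")
    print(load_loss_file_sketch(loss_path))  # [0.9, 0.7, 0.5]
```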
-
fastforward_training(self, reranker, weights_path, loss_fn)
Skip to the last training iteration whose weights were saved.
If saved model and optimizer weights are available, this method loads those weights into the model and optimizer, and then returns the next iteration to be run. For example, if weights are available for iterations 0-10 (11 zero-indexed iterations), the weights from iteration index 10 are loaded, and this method returns 11.
If an error or inconsistency is encountered when checking for weights, this method returns 0.
This method checks several files to determine if weights “are available”. First, loss_fn is read to determine the last recorded iteration. (If a path is missing or loss_fn is malformed, 0 is returned.) Second, the weights from the last recorded iteration in loss_fn are loaded into the model and optimizer. If this is successful, the method returns 1 + last recorded iteration. If not, it returns 0. (We consider loss_fn because it is written at the end of every training iteration.)
Parameters:
- reranker (Reranker) – a PyTorch Reranker whose state should be loaded
- weights_path (Path) – directory containing model and optimizer weights
- loss_fn (Path) – file containing the loss history
Returns: the next training iteration after fast-forwarding. If successful, this is > 0. If no weights are available or they cannot be loaded, 0 is returned.
Return type: int
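The decision logic described above can be sketched as follows. It assumes the loss file holds one loss per line, one line per finished iteration, and replaces the actual weight loading with a hypothetical load_weights(i) callable that raises if the weights for iteration i cannot be restored. It is an illustration, not the Capreolus implementation.

```python
import tempfile
from pathlib import Path
from typing import Callable


def fastforward_sketch(loss_fn: Path, load_weights: Callable[[int], None]) -> int:
    """Return the next iteration to run, or 0 if training must start over.

    Assumes loss_fn has one loss per line, one line per finished iteration,
    and that load_weights(i) (hypothetical) restores the model and optimizer
    weights saved after iteration i, raising an exception if it cannot.
    """
    try:
        with open(loss_fn, "rt") as f:
            losses = [float(line) for line in f if line.strip()]
    except (OSError, ValueError):
        return 0  # loss file missing or malformed
    if not losses:
        return 0  # no iteration has been recorded yet
    last_iteration = len(losses) - 1
    try:
        load_weights(last_iteration)
    except Exception:
        return 0  # weights missing or inconsistent with the loss history
    return last_iteration + 1


with tempfile.TemporaryDirectory() as tmp:
    loss_path = Path(tmp) / "loss.txt"
    loss_path.write_text("0.5\n" * 11)  # iterations 0-10 have finished
    loaded = []
    print(fastforward_sketch(loss_path, loaded.append), loaded)  # 11 [10]
```

Reading the loss file first and treating it as the source of truth matches the rationale given above: it is written at the end of every iteration, so it bounds which weights can exist.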
-
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)
Train a model following the trainer’s config (specifying batch size, number of iterations, etc.).
Parameters:
- train_dataset (IterableDataset) – training dataset
- train_output_path (Path) – directory under which train_dataset runs and training loss will be saved
- dev_data (IterableDataset) – dev dataset
- dev_output_path (Path) – directory where dev_data runs and metrics will be saved
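One piece of the training loop that ties train to fastforward_training is the per-iteration loss bookkeeping. The sketch below shows only that part, under assumptions: the one-loss-per-line file format, train_sketch and train_one_iteration are hypothetical, and weight saving, validation and logging are left out.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def train_sketch(loss_fn: Path, start_iter: int, niters: int,
                 train_one_iteration: Callable[[int], float]) -> List[float]:
    """Run iterations start_iter..niters-1 and record each loss in loss_fn.

    Weight saving, validation and logging are omitted from this sketch.
    """
    history: List[float] = []
    if start_iter > 0:
        # Keep only the recorded losses of iterations that precede start_iter.
        history = [float(x) for x in loss_fn.read_text().split()][:start_iter]
    for it in range(start_iter, niters):
        history.append(train_one_iteration(it))
        # Rewrite the file at the end of every iteration so it always lists
        # exactly the iterations that have finished.
        loss_fn.write_text("".join(f"{loss}\n" for loss in history))
    return history


with tempfile.TemporaryDirectory() as tmp:
    loss_path = Path(tmp) / "loss.txt"
    train_sketch(loss_path, 0, 3, lambda it: 1.0 / (it + 1))            # fresh run
    resumed = train_sketch(loss_path, 3, 5, lambda it: 1.0 / (it + 1))  # resumed run
    print(len(resumed), loss_path.read_text().count("\n"))  # 5 5
```

In a resumed run, start_iter would be the value returned by fastforward_training.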
-
predict(self, reranker, pred_data, pred_fn)
Predict query-document scores on pred_data using the reranker’s model and write a corresponding run file to pred_fn.
Parameters:
- reranker (Reranker) – a PyTorch Reranker
- pred_data (IterableDataset) – data to predict on
- pred_fn (Path) – path to write the prediction run file to
Returns: TREC run
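TREC run files use one line per query-document pair with six whitespace-separated fields: query id, the literal Q0, document id, rank, score and a run tag. The sketch below writes a score dictionary in that format; the nested {qid: {docid: score}} layout, the function name and the default tag are assumptions, not Capreolus's internal representation.

```python
import tempfile
from pathlib import Path
from typing import Dict


def write_trec_run(run: Dict[str, Dict[str, float]], fn: Path, tag: str = "sketch") -> None:
    """Write {qid: {docid: score}} in TREC run format, highest score first."""
    with open(fn, "wt") as f:
        for qid in sorted(run):
            ranked = sorted(run[qid].items(), key=lambda item: item[1], reverse=True)
            for rank, (docid, score) in enumerate(ranked, start=1):
                # Fields: query id, literal "Q0", doc id, rank, score, run tag.
                f.write(f"{qid} Q0 {docid} {rank} {score} {tag}\n")


with tempfile.TemporaryDirectory() as tmp:
    run_path = Path(tmp) / "dev.run"
    write_trec_run({"q1": {"d1": 0.2, "d2": 0.9}}, run_path)
    print(run_path.read_text(), end="")
    # q1 Q0 d2 1 0.9 sketch
    # q1 Q0 d1 2 0.2 sketch
```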
-
fill_incomplete_batch(self, batch)
If a batch is incomplete (i.e., shorter than the desired batch size), this method fills in the batch with extra data, chosen as follows:
- if the values are simple lists, the first element of each list is used to pad the batch;
- if the values are tensors or numpy arrays, repeat() is used along the batch dimension.
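A sketch of the list case only. It assumes a batch is a dict mapping field names to per-example lists and takes the target size as a batch_size argument; both the layout and the function are hypothetical, and the tensor/array case is not shown.

```python
from typing import Dict, List


def fill_incomplete_batch_sketch(batch: Dict[str, List], batch_size: int) -> Dict[str, List]:
    """Pad each list in batch to batch_size by repeating its first element."""
    filled = {}
    for key, values in batch.items():
        missing = max(batch_size - len(values), 0)
        filled[key] = list(values) + [values[0]] * missing
    return filled


batch = {"qid": ["q1", "q2"], "label": [1, 0]}
print(fill_incomplete_batch_sketch(batch, 4))
# {'qid': ['q1', 'q2', 'q1', 'q1'], 'label': [1, 0, 1, 1]}
```

Padding every field with the same example keeps the fields aligned, so the padded positions still form valid (duplicate) instances.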
-
class capreolus.trainer.TrecCheckpointCallback(qrels, dev_data, dev_records, output_path, metric, validate_freq, relevance_level, *args, **kwargs)
Bases: tensorflow.keras.callbacks.Callback
A callback that runs after every epoch and calculates pytrec_eval-style metrics for the dev dataset. It also saves the best model to disk. See TensorFlowTrainer.train() for how it is invoked.
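The validate-and-keep-best pattern of such a callback can be sketched without TensorFlow. Keras calls on_epoch_end(epoch, logs) on every registered callback; in this framework-free version the class does not subclass tensorflow.keras.callbacks.Callback, and evaluate and save are hypothetical callables standing in for dev-set evaluation and model saving. Validating only when (epoch + 1) is a multiple of validate_freq is likewise an assumption about how validate_freq is used.

```python
from typing import Callable, Optional


class BestCheckpointSketch:
    """Framework-free stand-in for a Keras callback that validates every
    validate_freq epochs and saves the model when the dev metric improves."""

    def __init__(self, validate_freq: int, evaluate: Callable[[], float],
                 save: Callable[[int], None]):
        self.validate_freq = validate_freq
        self.evaluate = evaluate  # hypothetical: returns the dev metric
        self.save = save          # hypothetical: saves weights for an epoch
        self.best_metric = float("-inf")
        self.best_epoch: Optional[int] = None

    def on_epoch_end(self, epoch: int, logs=None) -> None:
        if (epoch + 1) % self.validate_freq != 0:
            return  # only validate every validate_freq epochs
        metric = self.evaluate()
        if metric > self.best_metric:
            self.best_metric, self.best_epoch = metric, epoch
            self.save(epoch)


scores = iter([0.3, 0.5, 0.4])  # dev metrics at epochs 1, 3 and 5
saved = []
cb = BestCheckpointSketch(2, lambda: next(scores), saved.append)
for epoch in range(6):
    cb.on_epoch_end(epoch)
print(saved, cb.best_epoch, cb.best_metric)  # [1, 3] 3 0.5
```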
-
class capreolus.trainer.TensorFlowTrainer(config=None, provide=None, share_dependency_objects=False, build=True)
Bases: capreolus.trainer.Trainer
Inherits module construction from profane.ModuleBase; the construction steps and the config, provide and share_dependency_objects parameters are as described for Trainer.
-
config_keys_not_in_path = ['fastforward', 'boardname', 'usecache', 'tpuname', 'tpuzone', 'storage']
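Judging by its name, this attribute lists config options that are left out of the config-derived output path, so that changing options such as boardname or tpuname would not change where results are stored. The sketch below illustrates that idea with a hypothetical path-building function; the key-value joining scheme is an assumption, not Capreolus's path format.

```python
from typing import Dict

CONFIG_KEYS_NOT_IN_PATH = ["fastforward", "boardname", "usecache", "tpuname", "tpuzone", "storage"]


def config_to_path_component(config: Dict[str, object]) -> str:
    """Join config entries into a path string, skipping keys that must not affect it."""
    return "_".join(
        f"{key}-{config[key]}" for key in sorted(config) if key not in CONFIG_KEYS_NOT_IN_PATH
    )


a = config_to_path_component({"batch": 16, "niters": 20, "boardname": "exp1"})
b = config_to_path_component({"batch": 16, "niters": 20, "boardname": "exp2"})
print(a, a == b)  # batch-16_niters-20 True
```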
-
train(self, reranker, train_dataset, train_output_path, dev_data, dev_output_path, qrels, metric, relevance_level=1)
-
create_tf_feature(self, qid, query, query_idf, posdoc_id, posdoc, negdoc_id, negdoc)
Creates a single tf.train.Feature instance (i.e., a single sample).
-
write_tf_record_to_file(self, dir_name, tf_features)
Writes the tf records to file. The destination can also be a GCS bucket. TODO: use generators to reduce memory usage.
-
convert_to_tf_dev_record(self, reranker, dataset)
Similar to self.convert_to_tf_train_record(), but does not split the output across multiple files.
-
convert_to_tf_train_record(self, reranker, dataset)
TensorFlow works better if the input data is fed in as tfrecords. This method iterates through the dataset and creates multiple tf records from it. The exact structure of the tfrecords is defined by reranker.extractor; for example, see EmbedText.get_tf_feature().
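Splitting one dataset into several record files amounts to sharding a stream of serialized samples. The sketch below groups samples into shards of bounded size; shard_records, the shard size and the file-name pattern are hypothetical, and real code would write each shard with tf.io.TFRecordWriter rather than just returning it.

```python
from itertools import islice
from typing import Iterable, Iterator, List, Tuple


def shard_records(records: Iterable[bytes], shard_size: int,
                  prefix: str = "train") -> Iterator[Tuple[str, List[bytes]]]:
    """Yield (file_name, records) pairs with at most shard_size records each."""
    it = iter(records)
    shard_idx = 0
    while True:
        shard = list(islice(it, shard_size))  # only one shard held in memory
        if not shard:
            return
        yield f"{prefix}_{shard_idx:05d}.tfrecord", shard
        shard_idx += 1


for name, shard in shard_records([b"a", b"b", b"c", b"d", b"e"], shard_size=2):
    print(name, shard)
# train_00000.tfrecord [b'a', b'b']
# train_00001.tfrecord [b'c', b'd']
# train_00002.tfrecord [b'e']
```

Because shard_records is a generator, a caller that writes each shard as it arrives never holds more than one shard in memory.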
-
get_tf_record_cache_path(self, dataset)
Get the path to the directory where tf records are written. If using TPUs, this will be a GCS path.
-
get_tf_dev_records(self, reranker, dataset)
- Returns tf records from the cache (disk) if applicable
- Otherwise, converts the dataset into tf records, writes them to disk, and returns them
-
get_tf_train_records(self, reranker, dataset)
- Returns tf records from the cache (disk) if applicable
- Otherwise, converts the dataset into tf records, writes them to disk, and returns them
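Both get_tf_dev_records and get_tf_train_records follow a cache-or-build pattern. The sketch below shows that pattern on the local filesystem only. The *.tfrecord glob, the use_cache flag (loosely modeled on the usecache option listed in config_keys_not_in_path, whose exact semantics are assumed here) and the convert callable are hypothetical. A GCS-backed cache, as in the TPU case, would need a GCS-aware file API such as tf.io.gfile instead of pathlib.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def get_records_sketch(cache_dir: Path, use_cache: bool,
                       convert: Callable[[Path], None]) -> List[Path]:
    """Return cached *.tfrecord files, (re)building them with convert if needed."""
    cached = sorted(cache_dir.glob("*.tfrecord")) if cache_dir.exists() else []
    if use_cache and cached:
        return cached  # cache hit: reuse the records already on disk
    for stale in cached:
        stale.unlink()  # drop old records before rebuilding
    cache_dir.mkdir(parents=True, exist_ok=True)
    convert(cache_dir)  # hypothetical: writes *.tfrecord files into cache_dir
    return sorted(cache_dir.glob("*.tfrecord"))


calls = []


def fake_convert(out_dir: Path) -> None:
    calls.append(out_dir)
    (out_dir / "part_00000.tfrecord").write_bytes(b"")


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "records"
    first = get_records_sketch(cache, True, fake_convert)   # miss: converts
    second = get_records_sketch(cache, True, fake_convert)  # hit: no conversion
    print(len(calls), [p.name for p in second])  # 1 ['part_00000.tfrecord']
```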
-
predict(self, reranker, pred_data, pred_fn)
Predict query-document scores on pred_data using the reranker’s model and write a corresponding run file to pred_fn.
Parameters:
- reranker (Reranker) – a Reranker with a TensorFlow model
- pred_data (IterableDataset) – data to predict on
- pred_fn (Path) – path to write the prediction run file to
Returns: TREC run