Source code for capreolus.trainer.test_trainer

import collections
import numpy as np
import os
import tensorflow as tf
from capreolus.benchmark import DummyBenchmark
from capreolus.extractor import EmbedText
from capreolus.sampler import TrainDataset
from capreolus.trainer import TensorFlowTrainer


[docs]def test_tf_get_tf_dataset(monkeypatch): benchmark = DummyBenchmark() extractor = EmbedText( {"maxdoclen": 4, "maxqlen": 4, "tokenizer": {"keepstops": True}}, provide={"collection": benchmark.collection} ) training_judgments = benchmark.qrels.copy() train_dataset = TrainDataset(training_judgments, training_judgments, extractor) reranker = collections.namedtuple("reranker", "extractor")(extractor=extractor) def mock_id2vec(*args, **kwargs): return { "query": np.array([1, 2, 3, 4], dtype=np.long), "posdoc": np.array([1, 1, 1, 1], dtype=np.long), "negdoc": np.array([2, 2, 2, 2], dtype=np.long), "qid": "1", "posdocid": "posdoc1", "negdocid": "negdoc1", "query_idf": np.array([0.1, 0.1, 0.2, 0.1], dtype=np.float), } monkeypatch.setattr(EmbedText, "id2vec", mock_id2vec) trainer = TensorFlowTrainer( { "name": "tensorflow", "batch": 2, "niters": 2, "itersize": 16, "lr": 0.001, "validatefreq": 1, "usecache": False, "tpuname": None, "tpuzone": None, "storage": None, } ) tf_record_filenames = trainer.convert_to_tf_train_record(reranker, train_dataset) for filename in tf_record_filenames: assert os.path.isfile(filename) tf_record_dataset = trainer.load_tf_records_from_file(reranker, tf_record_filenames, 2) dataset = tf_record_dataset for idx, data_and_label in enumerate(dataset): batch, _ = data_and_label tf.debugging.assert_equal(batch[0], tf.convert_to_tensor(np.array([[1, 1, 1, 1], [1, 1, 1, 1]]), dtype=tf.int64)) tf.debugging.assert_equal(batch[1], tf.convert_to_tensor(np.array([[2, 2, 2, 2], [2, 2, 2, 2]]), dtype=tf.int64)) tf.debugging.assert_equal(batch[2], tf.convert_to_tensor(np.array([[1, 2, 3, 4], [1, 2, 3, 4]]), dtype=tf.int64)) tf.debugging.assert_equal( batch[3], tf.convert_to_tensor(np.array([[0.1, 0.1, 0.2, 0.1], [0.1, 0.1, 0.2, 0.1]]), dtype=tf.float32)
)