capreolus.reranker.common

Module Contents

Functions

pair_softmax_loss(pos_neg_scores, *args, **kwargs)

pair_hinge_loss(pos_neg_scores, *args, **kwargs)

new_similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding)

similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding)

The original TensorFlow similarity matrix. It may have issues under mixed precision; use new_similarity_matrix_tf instead.

create_emb_layer(weights, non_trainable=True)

class capreolus.reranker.common.KerasPairModel(model, *args, **kwargs)

Bases: tensorflow.keras.Model

call(self, x, **kwargs)
predict_step(self, data)
class capreolus.reranker.common.KerasTripletModel(model, *args, **kwargs)

Bases: tensorflow.keras.Model

call(self, x, **kwargs)
predict_step(self, data)
class capreolus.reranker.common.TFPairwiseHingeLoss

Bases: tensorflow_ranking.python.keras.losses.PairwiseHingeLoss

call(self, y_true, y_pred)
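The base class is TF-Ranking's pairwise hinge loss, which, for each list, sums max(0, 1 - (s_i - s_j)) over every pair of items where item i has a higher label than item j. The pure-Python sketch below reproduces only that core formula for a single list. It omits TF-Ranking's optional lambda weights, sample weights and batch reduction, and it does not model whatever the `call` override above changes; `pairwise_hinge_loss_sketch` is a hypothetical name.

```python
import itertools


def pairwise_hinge_loss_sketch(labels, scores):
    """Unweighted pairwise hinge loss for one ranked list (illustrative sketch).

    For every ordered pair (i, j) with labels[i] > labels[j], add
    max(0, 1 - (scores[i] - scores[j])). No weights or batch reduction.
    """
    if len(labels) != len(scores):
        raise ValueError("labels and scores must have the same length")
    total = 0.0
    for i, j in itertools.permutations(range(len(scores)), 2):
        if labels[i] > labels[j]:
            total += max(0.0, 1.0 - (scores[i] - scores[j]))
    return total


# The relevant item (label 1) outscores the other by less than the margin of 1.
print(pairwise_hinge_loss_sketch([1, 0], [0.5, 0.0]))
```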
class capreolus.reranker.common.TFCategoricalCrossEntropyLoss

Bases: tensorflow.python.keras.losses.CategoricalCrossentropy

call(self, ytrue, ypred)
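The Keras base class `CategoricalCrossentropy`, with its default `from_logits=False`, rescales each prediction vector to sum to 1, clips it away from 0 and 1, and returns -sum(y_true * log(y_pred)). A pure-Python sketch of that base computation for a single example follows. The epsilon value and the function name are assumptions, and any preprocessing done by this subclass's `call(ytrue, ypred)` is not reproduced.

```python
import math


def categorical_cross_entropy_sketch(y_true, y_pred, eps=1e-7):
    """-sum(y_true * log(p)) for one example (illustrative sketch).

    p is y_pred rescaled to sum to 1 and clipped to [eps, 1 - eps],
    mirroring the probability (non-logit) mode of the Keras loss.
    """
    total = sum(y_pred)
    if total <= 0:
        raise ValueError("y_pred must have a positive sum")
    probs = [min(max(p / total, eps), 1.0 - eps) for p in y_pred]
    return -sum(t * math.log(p) for t, p in zip(y_true, probs))


# One-hot target on the first class, uniform prediction: loss is log(2).
print(categorical_cross_entropy_sketch([1.0, 0.0], [0.5, 0.5]))
```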
capreolus.reranker.common.pair_softmax_loss(pos_neg_scores, *args, **kwargs)
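A pairwise softmax loss treats each (positive, negative) score pair as a two-way classification and penalizes a low probability on the positive document. The sketch below assumes `pos_neg_scores` amounts to a batch of positive scores and a batch of negative scores, and that the loss averages 1 - softmax([pos, neg])[0] over the batch. That formulation (rather than, for example, -log of the same probability) is an assumption, as are the helper names.

```python
import math


def _prob_first(pos, neg):
    """softmax([pos, neg])[0], computed stably as sigmoid(pos - neg)."""
    d = pos - neg
    if d >= 0:
        return 1.0 / (1.0 + math.exp(-d))
    e = math.exp(d)
    return e / (1.0 + e)


def pair_softmax_loss_sketch(pos_scores, neg_scores):
    """Mean over the batch of 1 - P(positive) under a two-way softmax (assumed form)."""
    if len(pos_scores) != len(neg_scores) or not pos_scores:
        raise ValueError("need equal-length, non-empty score lists")
    losses = [1.0 - _prob_first(p, n) for p, n in zip(pos_scores, neg_scores)]
    return sum(losses) / len(losses)


# Tied scores give probability 0.5 to the positive, so the loss is 0.5.
print(pair_softmax_loss_sketch([0.0], [0.0]))
```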
capreolus.reranker.common.pair_hinge_loss(pos_neg_scores, *args, **kwargs)
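A pairwise hinge loss asks the positive score to exceed the negative score by a margin. Assuming a margin of 1 and mean reduction (both assumptions), it is mean(max(0, 1 - (pos - neg))), which is what PyTorch's `margin_ranking_loss` computes with a target of 1 and `margin=1`. The function name below is hypothetical.

```python
def pair_hinge_loss_sketch(pos_scores, neg_scores, margin=1.0):
    """Mean over the batch of max(0, margin - (pos - neg)) (illustrative sketch)."""
    if len(pos_scores) != len(neg_scores) or not pos_scores:
        raise ValueError("need equal-length, non-empty score lists")
    losses = [max(0.0, margin - (p - n)) for p, n in zip(pos_scores, neg_scores)]
    return sum(losses) / len(losses)


# First pair is inside the margin (loss 0.5), second clears it (loss 0.0).
print(pair_hinge_loss_sketch([0.5, 3.0], [0.0, 0.0]))
```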
capreolus.reranker.common.new_similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding)
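Judging from the `cosine_similarity_matrix` and `remove_padding` methods of `SimilarityMatrix` in this module, a similarity matrix here holds the cosine similarity between every query token and every document token, with padded positions blanked out. The pure-Python sketch below shows that computation for a single query/document pair (no batch dimension). It assumes `padding` is the token id that marks padding and that padded cells are set to 0; the TensorFlow function may differ in masking, normalization and dtype handling, and the function name is hypothetical.

```python
import math


def cosine_similarity_matrix_sketch(query_embed, doc_embed, query_tok, doc_tok, padding=0):
    """A x B matrix of cosine similarities; cells touching a padding token are 0.0."""

    def norm(vec):
        return math.sqrt(sum(x * x for x in vec))

    matrix = []
    for q_vec, q_id in zip(query_embed, query_tok):
        row = []
        for d_vec, d_id in zip(doc_embed, doc_tok):
            if q_id == padding or d_id == padding:
                row.append(0.0)  # assumed masking rule for padded positions
                continue
            denom = norm(q_vec) * norm(d_vec)
            dot = sum(a * b for a, b in zip(q_vec, d_vec))
            row.append(dot / denom if denom > 0 else 0.0)  # guard zero vectors
        matrix.append(row)
    return matrix


# Two query tokens against three document tokens, the last of which is padding.
print(cosine_similarity_matrix_sketch(
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, 2.0], [5.0, 5.0]],
    [7, 8],
    [7, 9, 0],
))
```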
capreolus.reranker.common.similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding)

The original TensorFlow similarity matrix. It may have issues under mixed precision; use new_similarity_matrix_tf instead.

class capreolus.reranker.common.SimilarityMatrix(embedding)

Bases: torch.nn.Module

remove_padding(self, sim, query_tok, doc_tok, BAT, A, B)
exact_match_matrix(self, query_tok, doc_tok, BAT, A, B)
cosine_similarity_matrix(self, query_tok, doc_tok, BAT, A, B)
forward(self, query_tok, doc_tok)
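An exact-match matrix marks the (query position, document position) cells whose token ids are identical. The `BAT`, `A`, `B` arguments above appear to be the batch size and the query and document lengths; the pure-Python sketch drops the batch dimension. It assumes padding id 0 (the default used by `StackedSimilarityMatrix` here) and that padded cells are zeroed, in the spirit of `remove_padding`; the function name is hypothetical.

```python
def exact_match_matrix_sketch(query_tok, doc_tok, padding=0):
    """1.0 where query and document token ids match, 0.0 otherwise.

    Cells involving the padding id are forced to 0.0 (assumed behavior).
    """
    return [
        [1.0 if q == d and q != padding else 0.0 for d in doc_tok]
        for q in query_tok
    ]


# The trailing padding tokens match each other but are masked out.
print(exact_match_matrix_sketch([5, 3, 0], [3, 5, 0]))
```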
class capreolus.reranker.common.StackedSimilarityMatrix(padding=0)

Bases: torch.nn.Module

forward(self, query_embed, doc_embed, query_tok, doc_tok)
class capreolus.reranker.common.RbfKernel(initial_mu, initial_sigma, requires_grad=True)

Bases: torch.nn.Module

forward(self, data)
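An RBF (Gaussian) kernel, as used in K-NRM-style rerankers, maps each similarity value x to exp(-(x - mu)^2 / (2 sigma^2)), so values near `mu` score close to 1 and values far away decay toward 0. The sketch below assumes that form for `RbfKernel`, with `initial_mu` and `initial_sigma` as mu and sigma. The `requires_grad` flag presumably makes mu and sigma trainable, which a pure-Python sketch does not model; the function names are hypothetical.

```python
import math


def rbf_kernel_sketch(value, mu, sigma):
    """Gaussian kernel: exp(-(value - mu)^2 / (2 * sigma^2))."""
    diff = value - mu
    return math.exp(-0.5 * diff * diff / (sigma * sigma))


def rbf_kernel_matrix_sketch(matrix, mu, sigma):
    """Apply the kernel elementwise to a 2-D similarity matrix."""
    return [[rbf_kernel_sketch(x, mu, sigma) for x in row] for row in matrix]


# A similarity exactly at mu maps to 1.0; one sigma away maps to exp(-0.5).
print(rbf_kernel_matrix_sketch([[0.5, 0.6]], mu=0.5, sigma=0.1))
```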
class capreolus.reranker.common.RbfKernelBank(mus=None, sigmas=None, dim=1, requires_grad=True)

Bases: torch.nn.Module

count(self)
forward(self, data)
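A kernel bank applies several RBF kernels (one per `mus`/`sigmas` pair) to the same similarity matrix, and `count()` presumably reports how many kernels it holds. The sketch below is a hypothetical stand-in, `RbfKernelBankSketch`, that requires explicit `mus` and `sigmas` (the real defaults for `None` are not documented here) and, with no batch dimension, stacks the kernel outputs along the first axis instead of along `dim=1`. The example values mu=1.0, sigma=0.001 for an exact-match kernel follow the K-NRM convention and are not claimed to be this module's defaults.

```python
import math


class RbfKernelBankSketch:
    """Fixed set of RBF kernels applied to one similarity matrix (illustrative)."""

    def __init__(self, mus, sigmas):
        if len(mus) != len(sigmas):
            raise ValueError("mus and sigmas must have the same length")
        self.mus = list(mus)
        self.sigmas = list(sigmas)

    def count(self):
        """Number of kernels in the bank."""
        return len(self.mus)

    def forward(self, matrix):
        """One kernel-transformed copy of `matrix` per kernel (kernel axis first)."""
        return [
            [[math.exp(-0.5 * (x - mu) ** 2 / sigma ** 2) for x in row] for row in matrix]
            for mu, sigma in zip(self.mus, self.sigmas)
        ]


bank = RbfKernelBankSketch(mus=[1.0, 0.0], sigmas=[0.001, 0.1])
print(bank.count(), bank.forward([[1.0, 0.0]]))
```

In K-NRM, each kernel's output is then summed over document positions, log-transformed and summed over query terms; whether that pooling happens in this class or in the reranker that uses it is not stated here.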
class capreolus.reranker.common.RbfKernelBankTF(mus, sigmas, dim=1, requires_grad=True, **kwargs)

Bases: tensorflow.keras.layers.Layer

count(self)
call(self, data, **kwargs)
class capreolus.reranker.common.RbfKernelTF(initial_mu, initial_sigma, requires_grad=True, **kwargs)

Bases: tensorflow.keras.layers.Layer

call(self, data, *kwargs)
capreolus.reranker.common.create_emb_layer(weights, non_trainable=True)
class capreolus.reranker.common.NewRbfKernelBankTF(mus, sigmas, dim=1, requires_grad=True, **kwargs)

Bases: tensorflow.keras.layers.Layer

count(self)
call(self, data, **kwargs)