capreolus.reranker.common

Module Contents

Classes

KerasPairModel

KerasTripletModel

KerasLCEModel

TFPairwiseHingeLoss

TFCategoricalCrossEntropyLoss

TFLCELoss

SimilarityMatrix

StackedSimilarityMatrix

RbfKernel

RbfKernelBank

RbfKernelBankTF

RbfKernelTF

NewRbfKernelBankTF

Functions

pair_softmax_loss(pos_neg_scores, *args, **kwargs)

pair_hinge_loss(pos_neg_scores, *args, **kwargs)

new_similarity_matrix_tf(query_embed, doc_embed, ...)

similarity_matrix_tf(query_embed, doc_embed, ...)

Original TF similarity matrix. May have issues with mixed precision; use new_similarity_matrix_tf instead.

create_emb_layer(weights[, non_trainable])

class capreolus.reranker.common.KerasPairModel(model, *args, **kwargs)

Bases: tensorflow.keras.Model

call(x, **kwargs)
predict_step(data)

class capreolus.reranker.common.KerasTripletModel(model, *args, **kwargs)

Bases: tensorflow.keras.Model

call(x, **kwargs)
predict_step(data)

class capreolus.reranker.common.KerasLCEModel(model, *args, **kwargs)

Bases: tensorflow.keras.Model

call(x, **kwargs)
predict_step(data)

class capreolus.reranker.common.TFPairwiseHingeLoss

Bases: tensorflow_ranking.python.keras.losses.PairwiseHingeLoss

call(y_true, y_pred)

class capreolus.reranker.common.TFCategoricalCrossEntropyLoss

Bases: tensorflow.python.keras.losses.CategoricalCrossentropy

call(ytrue, ypred)

Shape: (batch_size, 2)

class capreolus.reranker.common.TFLCELoss

Bases: tensorflow.python.keras.losses.CategoricalCrossentropy

call(ytrue, ypred)

capreolus.reranker.common.pair_softmax_loss(pos_neg_scores, *args, **kwargs)

capreolus.reranker.common.pair_hinge_loss(pos_neg_scores, *args, **kwargs)
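
The two pairwise loss functions above are listed without descriptions. As a rough illustration only, the sketch below shows the conventional definitions of a pairwise hinge loss and a pairwise softmax loss for one (positive, negative) score pair. The function names, the `margin` parameter, and the per-pair scalar interface are assumptions made for this example; the library functions take a `pos_neg_scores` argument whose layout is not documented here, and their exact formulas may differ.

```python
import math


def hinge_loss_sketch(pos_score, neg_score, margin=1.0):
    """Hypothetical helper: margin-based pairwise loss.

    The loss is zero once the positive score exceeds the negative
    score by at least `margin` (an assumed default of 1.0).
    """
    return max(0.0, margin - (pos_score - neg_score))


def softmax_loss_sketch(pos_score, neg_score):
    """Hypothetical helper: negative log-probability of the positive
    document under a two-way softmax over (pos_score, neg_score)."""
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(pos_score, neg_score)
    log_z = m + math.log(math.exp(pos_score - m) + math.exp(neg_score - m))
    return log_z - pos_score


# Two example pairs: one ranked correctly, one ranked incorrectly.
pairs = [(2.0, 0.5), (0.2, 0.9)]
hinge = [hinge_loss_sketch(p, n) for p, n in pairs]
softmax = [softmax_loss_sketch(p, n) for p, n in pairs]
```

In this sketch the correctly ranked pair incurs no hinge loss because its score gap exceeds the margin, while the softmax variant still assigns it a small positive loss.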
capreolus.reranker.common.new_similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding)

capreolus.reranker.common.similarity_matrix_tf(query_embed, doc_embed, query_tok, doc_tok, padding)

Original TF similarity matrix. May have issues with mixed precision; use new_similarity_matrix_tf instead.
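
Both TF similarity functions share the signature (query_embed, doc_embed, query_tok, doc_tok, padding). The sketch below illustrates what a similarity matrix of this kind typically computes, assuming cosine similarity between every query and document token embedding, with cells that involve a padding token set to zero. It is written in plain Python for a single unbatched pair; the function name, the list-based inputs, and the zeroing behavior are assumptions for illustration, not the library's TensorFlow implementation.

```python
import math


def cosine_similarity_matrix_sketch(query_embed, doc_embed, query_tok, doc_tok, padding=0):
    """Hypothetical helper: len(query_tok) x len(doc_tok) cosine similarities.

    Cells where either token id equals `padding` are set to 0.0
    (an assumed convention for this sketch).
    """
    def norm(vec):
        return math.sqrt(sum(x * x for x in vec))

    sim = []
    for q_vec, q_id in zip(query_embed, query_tok):
        row = []
        for d_vec, d_id in zip(doc_embed, doc_tok):
            if q_id == padding or d_id == padding:
                row.append(0.0)
                continue
            denom = norm(q_vec) * norm(d_vec)
            dot = sum(a * b for a, b in zip(q_vec, d_vec))
            # Guard against zero-length vectors.
            row.append(dot / denom if denom > 0 else 0.0)
        sim.append(row)
    return sim


# Toy inputs: two query tokens, three document tokens (the last is padding).
query_tok = [5, 7]
doc_tok = [5, 9, 0]
query_embed = [[1.0, 0.0], [0.0, 1.0]]
doc_embed = [[1.0, 0.0], [1.0, 1.0], [0.0, 0.0]]
sim = cosine_similarity_matrix_sketch(query_embed, doc_embed, query_tok, doc_tok)
```

The result has one row per query token and one column per document token; the padding column stays at zero regardless of the embedding values.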

class capreolus.reranker.common.SimilarityMatrix(embedding)

Bases: torch.nn.Module

remove_padding(sim, query_tok, doc_tok, BAT, A, B)
exact_match_matrix(query_tok, doc_tok, BAT, A, B)
cosine_similarity_matrix(query_tok, doc_tok, BAT, A, B)
forward(query_tok, doc_tok)

class capreolus.reranker.common.StackedSimilarityMatrix(padding=0)

Bases: torch.nn.Module

forward(query_embed, doc_embed, query_tok, doc_tok)

class capreolus.reranker.common.RbfKernel(initial_mu, initial_sigma, requires_grad=True)

Bases: torch.nn.Module

forward(data)

class capreolus.reranker.common.RbfKernelBank(mus=None, sigmas=None, dim=1, requires_grad=True)

Bases: torch.nn.Module

count()
forward(data)
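
The RBF kernel classes are parameterized by a mean (mu) and width (sigma), and a kernel bank holds one kernel per (mu, sigma) pair. As a minimal sketch, assuming the standard Gaussian RBF form exp(-(x - mu)^2 / (2 * sigma^2)), the plain-Python functions below apply one kernel, and then a bank of kernels, to a list of similarity values. The function names and list-based interface are hypothetical; the library classes operate on tensors and may differ in detail.

```python
import math


def rbf_kernel_sketch(values, mu, sigma):
    """Hypothetical helper: apply one Gaussian RBF kernel elementwise."""
    return [math.exp(-((v - mu) ** 2) / (2.0 * sigma ** 2)) for v in values]


def rbf_kernel_bank_sketch(values, mus, sigmas):
    """Hypothetical helper: apply each (mu, sigma) kernel to the same
    values, returning one output row per kernel."""
    return [rbf_kernel_sketch(values, mu, sigma) for mu, sigma in zip(mus, sigmas)]


# Toy similarity scores passed through a bank of two kernels.
sims = [1.0, 0.5, -1.0]
bank = rbf_kernel_bank_sketch(sims, mus=[1.0, 0.0], sigmas=[0.1, 0.5])
```

Each kernel responds most strongly to values near its mu, so a bank with several mus acts as a soft histogram over the similarity scores.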
class capreolus.reranker.common.RbfKernelBankTF(mus, sigmas, dim=1, requires_grad=True, **kwargs)

Bases: tensorflow.keras.layers.Layer

count()
call(data, **kwargs)

class capreolus.reranker.common.RbfKernelTF(initial_mu, initial_sigma, requires_grad=True, **kwargs)

Bases: tensorflow.keras.layers.Layer

call(data, *kwargs)

capreolus.reranker.common.create_emb_layer(weights, non_trainable=True)

class capreolus.reranker.common.NewRbfKernelBankTF(mus, sigmas, dim=1, requires_grad=True, **kwargs)

Bases: tensorflow.keras.layers.Layer

count()
call(data, **kwargs)