Source code for gedml.core.losses.pair_based_loss.pos_pair_loss

import torch

from ...misc import loss_function as l_f 
from ..base_loss import BaseLoss

[docs]class PosPairLoss(BaseLoss): """ Designed for SimSiam. paper: `Exploring Simple Siamese Representation Learning <https://arxiv.org/abs/2011.10566>`_ """ def __init__( self, **kwargs ): super(PosPairLoss, self).__init__(**kwargs) def required_metric(self): return ["cosine"]
[docs] def compute_loss( self, metric_mat, row_labels, col_labels, indices_tuple, weights=None, is_same_source=False, ) -> torch.Tensor: pos_mask = (row_labels == col_labels) pos_pair = metric_mat[pos_mask] loss = torch.mean( - pos_pair) return loss