# Source code for gedml.core.losses.pair_based_loss.circle_loss

import torch

from ...misc import loss_function as l_f 
from ..base_loss import BaseLoss

class CircleLoss(BaseLoss):
    """
    modified from: https://github.com/KevinMusgrave/pytorch-metric-learning
    paper: `Circle Loss: A Unified Perspective of Pair Similarity Optimization <http://openaccess.thecvf.com/content_CVPR_2020/html/Sun_Circle_Loss_A_Unified_Perspective_of_Pair_Similarity_Optimization_CVPR_2020_paper.html>`_

    Args:
        m (float): relaxation margin ``m``. It sets the optima
            ``op = 1 + m`` / ``on = -m`` and the margins
            ``delta_p = 1 - m`` / ``delta_n = m``.
        gamma (float): scale applied to the weighted similarities before
            the log-sum-exp.
        **kwargs: forwarded unchanged to ``BaseLoss.__init__``.
    """

    def __init__(self, m=0.4, gamma=80, **kwargs):
        super(CircleLoss, self).__init__(**kwargs)
        self.m = m
        self.gamma = gamma
        self.op = 1 + m
        self.on = -m
        self.delta_p = 1 - m
        self.delta_n = m

    def required_metric(self):
        """Return the metric name(s) ``metric_mat`` is expected to hold."""
        return ["cosine"]

    @staticmethod
    def _masked_logsumexp(x, keep_mask, dim=1):
        """Log-sum-exp of ``x`` along ``dim`` over entries where ``keep_mask``
        is True; rows with no kept entry yield 0.

        Excluded entries are filled with the most negative *finite* value of
        the dtype rather than ``-inf``: an all-excluded row then produces a
        finite value and finite gradients instead of NaN.
        """
        x = x.masked_fill(~keep_mask, torch.finfo(x.dtype).min)
        out = torch.logsumexp(x, dim=dim)
        # Rows with nothing to sum over carry no information; pin them to 0.
        return out.masked_fill(~keep_mask.any(dim=dim), 0.0)

    def compute_loss(
        self,
        metric_mat,
        row_labels,
        col_labels,
        indices_tuple,
        weights=None,
        is_same_source=False,
    ) -> torch.Tensor:
        """Compute the Circle loss averaged over valid rows of ``metric_mat``.

        Args:
            metric_mat: pairwise similarity matrix (expected to be cosine
                similarity, per ``required_metric``). NOTE(review): assumed
                2-D (rows x cols) since ``fill_diagonal_`` and ``dim=1``
                reductions are applied — confirm against callers.
            row_labels, col_labels: labels compared with ``==`` to build the
                positive-pair mask; presumably shaped to broadcast to
                ``metric_mat`` (e.g. (N, 1) and (1, M)) — TODO confirm.
            indices_tuple: accepted for interface compatibility; unused here.
            weights: accepted for interface compatibility; unused here.
            is_same_source: if True, the diagonal (self-pairs) is excluded
                from the positive pairs.

        Returns:
            Scalar tensor: mean of ``log(1 + sum_p exp(l_p) * sum_n exp(l_n))``
            over rows that have at least one positive and one negative, or a
            graph-connected zero when no such row exists.
        """
        pos_mask = row_labels == col_labels
        # Built before the diagonal is cleared, so self-pairs end up in
        # neither the positive nor the negative set.
        neg_mask = ~pos_mask
        if is_same_source:
            pos_mask.fill_diagonal_(False)

        # Pair weights relu(op - s) / relu(s - on) use detached similarities,
        # so they act as constants during backprop (as in the original).
        sim_detached = metric_mat.detach()
        pos_logits = (
            -self.gamma
            * torch.relu(self.op - sim_detached)
            * (metric_mat - self.delta_p)
        )
        neg_logits = (
            self.gamma
            * torch.relu(sim_detached - self.on)
            * (metric_mat - self.delta_n)
        )

        lse_pos = self._masked_logsumexp(pos_logits, pos_mask, dim=1)
        lse_neg = self._masked_logsumexp(neg_logits, neg_mask, dim=1)
        # log(1 + sum_pos * sum_neg) == softplus(lse_pos + lse_neg), computed
        # without materialising exp(logits); with gamma=80 those exponentials
        # overflow float32 for ordinary similarity values.
        loss = torch.nn.functional.softplus(lse_pos + lse_neg)

        # Keep only rows that have at least one positive AND one negative.
        valid_rows = torch.where(
            (torch.sum(pos_mask, dim=1) != 0) & (torch.sum(neg_mask, dim=1) != 0)
        )[0]
        loss = loss[valid_rows]
        if loss.numel() == 0:
            # Zero that stays attached to the autograd graph.
            return torch.sum(metric_mat * 0)
        return torch.mean(loss)