Source code for gedml.core.losses.proxy_based_loss.soft_triple_loss

import torch

from ...misc import loss_function as l_f 
from ..base_loss import BaseLoss

[docs]class SoftTripleLoss(BaseLoss): """ paper: `SoftTriple Loss: Deep Metric Learning Without Triplet Sampling <https://openaccess.thecvf.com/content_ICCV_2019/html/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.html>`_ """ def __init__( self, centers_per_class=10, num_classes=100, la=20, gamma=0.1, margin=0.01, **kwargs ): super(SoftTripleLoss, self).__init__(**kwargs) self.centers_per_class = centers_per_class self.num_classes = num_classes self.la = la self.gamma = gamma self.margin = margin def required_metric(self): return ["cosine"]
[docs] def compute_loss( self, metric_mat, row_labels, col_labels, indices_tuple=None, weights=None, is_same_source=False, ) -> torch.Tensor: metric_to_centers = metric_mat.view( -1, self.num_classes, self.centers_per_class ) prob = torch.nn.functional.softmax( metric_to_centers * self.gamma, dim=2 ) metric_to_classes = torch.sum( prob * metric_to_centers, dim=2 ) margin = torch.zeros_like(metric_to_classes) margin[torch.arange(margin.shape[0]), row_labels.squeeze()] = self.margin loss = torch.nn.functional.cross_entropy( self.la * (metric_to_classes - margin), row_labels.squeeze() ) return loss