Source code for gedml.core.losses.proxy_based_loss.proxy_anchor_loss

import torch

from ...misc import loss_function as l_f 
from ..base_loss import BaseLoss

[docs]class ProxyAnchorLoss(BaseLoss): """ paper: `Proxy Anchor Loss for Deep Metric Learning <http://openaccess.thecvf.com/content_CVPR_2020/html/Kim_Proxy_Anchor_Loss_for_Deep_Metric_Learning_CVPR_2020_paper.html>`_ """ def __init__( self, margin=0.1, alpha=32, **kwargs ): super(ProxyAnchorLoss, self).__init__(**kwargs) self.alpha = alpha self.margin = margin def required_metric(self): return ["cosine"]
[docs] def compute_loss( self, metric_mat, row_labels, col_labels, indices_tuple=None, weights=None, is_same_source=False, ) -> torch.Tensor: pos_mask = row_labels == col_labels neg_mask = ~ pos_mask with_pos_proxies = torch.where(torch.sum(pos_mask, dim=0) != 0)[0] pos_term = l_f.logsumexp( - self.alpha * (metric_mat - self.margin), keep_mask=pos_mask, add_one=True, dim=0 ).squeeze() neg_term = l_f.logsumexp( self.alpha * (metric_mat + self.margin), keep_mask=neg_mask, add_one=True, dim=0 ).squeeze() if len(with_pos_proxies) == 0: pos_loss = torch.sum(metric_mat * 0) else: pos_loss = torch.mean(pos_term[with_pos_proxies]) neg_loss = torch.mean(neg_term) return pos_loss + neg_loss