# Source code for gedml.core.losses.proxy_based_loss.proxy_loss

import torch
from ..base_loss import BaseLoss

class ProxyLoss(BaseLoss):
    """
    Proxy-based softmax loss.

    paper: `No Fuss Distance Metric Learning Using Proxies <https://openaccess.thecvf.com/content_iccv_2017/html/Movshovitz-Attias_No_Fuss_Distance_ICCV_2017_paper.html>`_

    Args:
        gamma: Scalar multiplied onto the metric matrix before the softmax
            over the last dimension.
            NOTE(review): the required metric is a Euclidean distance, so
            gamma presumably has to be negative for closer proxies to get a
            higher probability -- confirm against the configs in use.
        **kwargs: Forwarded unchanged to ``BaseLoss.__init__``.
    """

    def __init__(self, gamma, **kwargs):
        super().__init__(**kwargs)
        self.gamma = gamma

    def required_metric(self):
        """Return the list of metric names required by this loss (``["euclid"]``)."""
        return ["euclid"]

    def compute_loss(
        self,
        metric_mat,
        row_labels,
        col_labels,
        is_same_source=False,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute the mean negative log of the softmax mass on positive entries.

        For every row, ``gamma * metric_mat`` is soft-maxed over the last
        dimension and the probabilities of the entries whose row label equals
        the column label are summed; the loss is the mean of ``-log`` of that
        sum over all rows that have at least one positive entry.

        The computation is done in log space
        (``logsumexp(all) - logsumexp(positives)``), which is mathematically
        the same quantity as ``-log(sum(softmax * mask))`` but does not
        underflow. Rows are skipped only when they have no positive entry,
        never because a positive probability rounded to zero.

        Args:
            metric_mat: Metric matrix; the softmax runs over its last
                dimension.
            row_labels: Labels compared (with broadcasting) against
                ``col_labels`` to mark positive entries. Presumably shaped
                ``(N, 1)`` -- confirm against the caller.
            col_labels: Labels of the columns. Presumably shaped ``(1, M)``
                -- confirm against the caller.
            is_same_source: Accepted for interface compatibility; not used
                by this loss.

        Returns:
            Scalar loss tensor. When no row has a positive entry the result
            is a zero that is still attached to the autograd graph (instead
            of the NaN produced by averaging an empty tensor).
        """
        logits = self.gamma * metric_mat

        # Boolean mask of positive (same-label) entries, aligned with logits.
        pos_mask = (row_labels == col_labels).to(logits.device)
        pos_mask = torch.broadcast_to(pos_mask, logits.shape)

        # Keep only rows that own at least one positive entry. Filtering
        # before the masked logsumexp guarantees every remaining row has a
        # finite entry, so neither the value nor the gradient can become NaN.
        has_positive = pos_mask.any(dim=-1)
        if not bool(has_positive.any()):
            # Nothing to average: return a graph-connected zero.
            return logits.sum() * 0

        logits = logits[has_positive]
        pos_mask = pos_mask[has_positive]

        log_norm = torch.logsumexp(logits, dim=-1)
        log_pos = torch.logsumexp(
            logits.masked_fill(~pos_mask, float("-inf")), dim=-1
        )

        # -log(sum_pos softmax) == logsumexp(all) - logsumexp(positives)
        return torch.mean(log_norm - log_pos)