# Source code for gedml.core.losses.pair_based_loss.fast_ap_loss

import torch

from ...misc import loss_function as l_f 
from ..base_loss import BaseLoss

class FastAPLoss(BaseLoss):
    """FastAP loss: a histogram-based, differentiable approximation of AP.

    Each row of the metric matrix acts as a query that ranks every column by
    distance. Distances are soft-assigned to ``num_bins + 1`` evenly spaced
    bin centres on ``[0, histogram_max]`` with a triangular kernel, and
    average precision is approximated from the cumulative positive / total
    histograms. The returned loss is ``1 - FastAP`` averaged over the rows
    that have at least one positive.

    modified from: https://github.com/KevinMusgrave/pytorch-metric-learning

    paper: `Deep Metric Learning to Rank <https://openaccess.thecvf.com/content_CVPR_2019/html/Cakir_Deep_Metric_Learning_to_Rank_CVPR_2019_paper.html>`_

    Args:
        num_bins: number of histogram intervals; must be >= 1.
        histogram_max: upper end of the distance range covered by the
            histogram; must be > 0. Distances lying more than one bin width
            outside ``[0, histogram_max]`` fall into no bin and are ignored.
            The default 4 matches the range of the squared Euclidean distance
            between unit-norm vectors. NOTE(review): if the
            ``"euclid_normalized"`` metric is the non-squared distance its
            range is [0, 2] — confirm against the metric implementation.
        **kwargs: forwarded unchanged to ``BaseLoss``.

    Raises:
        ValueError: if ``num_bins < 1`` or ``histogram_max <= 0``.
    """

    def __init__(self, num_bins=10, histogram_max=4, **kwargs):
        num_bins = int(num_bins)
        histogram_max = float(histogram_max)
        # Validate before touching the base class so a bad config fails
        # without leaving a half-initialised loss behind.
        if num_bins < 1:
            raise ValueError(f"num_bins must be >= 1, got {num_bins}")
        if histogram_max <= 0:
            raise ValueError(f"histogram_max must be > 0, got {histogram_max}")
        super().__init__(**kwargs)
        self.num_bins = num_bins
        # num_bins intervals are delimited by num_bins + 1 bin centres.
        self.num_edges = self.num_bins + 1
        self.histogram_max = histogram_max
        # Bin width; also the half-width of each triangular kernel.
        self.histogram_delta = self.histogram_max / self.num_bins

    def required_metric(self):
        # NOTE(review): presumably tells the framework which metric produces
        # the ``metric_mat`` passed to compute_loss — confirm in BaseLoss.
        return ["euclid_normalized"]

    def compute_loss(
        self,
        metric_mat,
        row_labels,
        col_labels,
        indices_tuple=None,
        weights=None,
        is_same_source=False,
    ) -> torch.Tensor:
        """Return the mean ``1 - FastAP`` over rows with at least one positive.

        Args:
            metric_mat: pairwise distance matrix; presumably shaped
                (num_rows, num_cols) — TODO confirm with the caller.
            row_labels: labels of the rows; must broadcast against
                ``col_labels`` to ``metric_mat``'s shape (e.g. (num_rows, 1)).
            col_labels: labels of the columns (e.g. (1, num_cols)). A pair is
                positive when the labels are equal, negative otherwise.
            indices_tuple: unused by this implementation.
            weights: unused by this implementation.
            is_same_source: when True, diagonal pairs (an item paired with
                itself) are excluded from both positives and negatives.

        Returns:
            A scalar loss tensor. If no row has a positive, or no histogram
            bin holds a positive, a zero that is still attached to
            ``metric_mat``'s autograd graph is returned.
        """
        dtype, device = metric_mat.dtype, metric_mat.device

        pos_mask = row_labels == col_labels
        # Built before the diagonal is cleared: where labels match (the
        # diagonal when is_same_source) it is already False, so self-pairs
        # end up in neither mask.
        neg_mask = ~pos_mask
        if is_same_source:
            pos_mask.fill_diagonal_(False)
        pos_mask = pos_mask.type(dtype)
        neg_mask = neg_mask.type(dtype)

        num_pos = torch.sum(pos_mask, dim=1)
        has_pos = num_pos > 0
        if not has_pos.any():
            # `* 0` keeps the zero connected to metric_mat's autograd graph.
            return torch.sum(metric_mat * 0)

        # Soft histogram: each distance is spread over the neighbouring bin
        # centres with a triangular kernel of half-width histogram_delta.
        mid_points = (
            torch.linspace(0.0, self.histogram_max, steps=self.num_edges)
            .view(-1, 1, 1)
            .to(device=device, dtype=dtype)
        )
        pulse = torch.relu(
            1 - torch.abs(metric_mat - mid_points) / self.histogram_delta
        )
        # Sum over columns, then transpose (num_edges, rows) -> (rows, num_edges).
        pos_hist = torch.sum(pulse * pos_mask, dim=2).t()
        neg_hist = torch.sum(pulse * neg_mask, dim=2).t()

        # Bins are ordered by increasing distance, so the cumulative sums count
        # the positives / all items ranked at or before each bin.
        total_pos_hist = torch.cumsum(pos_hist, dim=1)
        total_hist = torch.cumsum(pos_hist + neg_hist, dim=1)

        h_pos_product = pos_hist * total_pos_hist
        safe_bins = (h_pos_product > 0) & (total_hist > 0)
        if not safe_bins.any():
            return torch.sum(metric_mat * 0)

        # Divide only where the denominator is positive; other bins add 0.
        fast_ap = torch.zeros_like(pos_hist)
        fast_ap[safe_bins] = h_pos_product[safe_bins] / total_hist[safe_bins]
        fast_ap = torch.sum(fast_ap, dim=1)[has_pos] / num_pos[has_pos]
        return torch.mean(1 - fast_ap)