# Source code for gedml.core.losses.pair_based_loss.margin_loss

import torch

from ...misc import loss_function as l_f 
from ..base_loss import BaseLoss

class MarginLoss(BaseLoss):
    """
    paper: `Sampling Matters in Deep Embedding Learning <http://openaccess.thecvf.com/content_iccv_2017/html/Wu_Sampling_Matters_in_ICCV_2017_paper.html>`_

    Margin-based pair loss with a (optionally learnable, per-class) boundary
    ``beta``.

    Distance mode (``is_similarity=False``): a positive pair is penalized when
    its value exceeds ``beta - alpha``; a negative pair is penalized when its
    value is below ``beta + alpha``.

    Similarity mode (``is_similarity=True``): a positive pair is penalized
    when its value is below ``beta + alpha``; a negative pair is penalized
    when its value exceeds ``beta - alpha``.

    Args:
        alpha: Margin added on either side of the boundary ``beta``.
        beta: Boundary value. Used as a fixed scalar when ``beta_constant``
            is True; otherwise it initialises a learnable
            ``torch.nn.Parameter`` of shape ``(num_classes,)``.
        nu: Stored as ``self.nu`` but not read by ``compute_loss``.
            NOTE(review): in the paper this weights a regularizer on beta;
            confirm whether anything else consumes it.
        num_classes: Length of the learnable beta vector (only used when
            ``beta_constant`` is False).
        beta_constant: Whether ``beta`` is a fixed scalar (True) or a
            learnable per-class parameter (False).
        is_similarity: Whether larger metric values mean "more similar".
        **kwargs: Forwarded to ``BaseLoss.__init__``.
    """

    def __init__(
        self,
        alpha=0.2,
        beta=1.2,
        nu=0,
        num_classes=100,
        beta_constant=False,
        is_similarity=False,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.alpha = alpha
        self.nu = nu
        self.num_classes = num_classes
        self.is_similarity = is_similarity
        self.beta_constant = beta_constant
        self.beta = (
            beta
            if beta_constant
            else torch.nn.Parameter(torch.ones(num_classes) * beta)
        )

    def required_metric(self):
        """
        Return the list of metric names this loss expects in ``metric_mat``.

        NOTE(review): always ``["euclid"]``, even when ``is_similarity`` is
        True; confirm that is intended.
        """
        return ["euclid"]

    def compute_loss(
        self,
        metric_mat,
        row_labels,
        col_labels,
        indices_tuple,
        is_same_source=False,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Average the non-zero margin violations over positive and negative pairs.

        Args:
            metric_mat: Pairwise metric matrix. Pair values are gathered from
                it via ``l_f.indices_to_pairs``.
            row_labels: Labels indexed by the first and third entries of
                ``indices_tuple`` to select per-class betas.
            col_labels: Not used by this loss.
            indices_tuple: Pair indices, split by ``l_f.split_indices`` into
                ``(a1, p, a2, n)``. Presumably these are (anchor, positive)
                and (anchor, negative) pairs; TODO confirm.
            is_same_source: Not used by this loss.

        Returns:
            A 0-dim tensor: the sum of the non-zero hinge terms divided by
            their count. It is zero when no pair violates its margin (this
            previously produced NaN from 0/0).
        """
        a1, _p, a2, _n = l_f.split_indices(indices_tuple)
        pos_pair, neg_pair = l_f.indices_to_pairs(metric_mat, indices_tuple)

        if self.beta_constant:
            beta_pos = beta_neg = self.beta
        else:
            # Per-class boundary, looked up by the label of row a1 / a2.
            beta_pos = self.beta[row_labels[a1].flatten()]
            beta_neg = self.beta[row_labels[a2].flatten()]

        relu = torch.nn.functional.relu
        if not self.is_similarity:
            pos_loss = relu(self.alpha + pos_pair - beta_pos)
            neg_loss = relu(self.alpha - neg_pair + beta_neg)
        else:
            pos_loss = relu(beta_pos + self.alpha - pos_pair)
            neg_loss = relu(neg_pair - beta_neg + self.alpha)

        # Keep only the pairs that actually violate their margin, so the mean
        # is taken over "active" pairs rather than all pairs.
        pos_loss = pos_loss[torch.where(pos_loss)[0]]
        neg_loss = neg_loss[torch.where(neg_loss)[0]]

        numerator = torch.sum(pos_loss) + torch.sum(neg_loss)
        denominator = len(pos_loss) + len(neg_loss)
        # With no violating pairs, both counts are 0 and 0/0 would yield NaN.
        # The numerator is already an exact 0 then, so dividing by 1 returns a
        # clean zero loss derived from the same tensors.
        return numerator / max(denominator, 1)