Source code for gedml.core.losses.classifier_based_loss.large_margin_softmax_loss

import torch
import scipy.special
import math

from ..base_loss import BaseLoss

class LargeMarginSoftmaxLoss(BaseLoss):
    """
    Large-Margin Softmax (L-Softmax) loss computed on a cosine metric matrix.

    For each row, the entry at the row's target class, ``cos(theta)``, is
    replaced by ``psi(theta) = (-1)^k * cos(m * theta) - 2k`` with
    ``k = floor(theta / (pi / m))`` and ``m = margin``. Every logit is then
    multiplied by ``scale`` and passed to cross-entropy against the row labels.

    modified from: https://github.com/KevinMusgrave/pytorch-metric-learning
    paper: `Large-Margin Softmax Loss for Convolutional Neural Networks <https://www.jmlr.org/proceedings/papers/v48/liud16.pdf>`_

    Args:
        margin: Angular multiplier ``m``; converted with ``int()`` and must
            be >= 1.
        scale: Constant factor applied to every logit before cross-entropy.
        **kwargs: Forwarded unchanged to ``BaseLoss.__init__``.
    """

    def __init__(
        self,
        margin=4,
        scale=1,
        **kwargs
    ):
        super(LargeMarginSoftmaxLoss, self).__init__(**kwargs)
        self.margin = margin
        self.scale = scale
        self.initiate_margin()

    def required_metric(self):
        """Return the metric names this loss expects ``metric_mat`` to use."""
        return ["cosine"]

    def initiate_margin(self):
        """
        Precompute the coefficients of the trigonometric multiple-angle formula

            cos(m*t) = sum_{n=0}^{m//2} (-1)^n * C(m, 2n)
                       * cos(t)^(m - 2n) * (1 - cos(t)^2)^n

        and store them as 1-D tensors of length ``m // 2 + 1``.

        Raises:
            ValueError: if ``int(self.margin)`` is smaller than 1 (the angle
                bucketing in ``modify_cosine_of_target_classes`` divides by it).
        """
        self.margin = int(self.margin)
        if self.margin < 1:
            raise ValueError(
                "margin must be a positive integer, got {}".format(self.margin)
            )
        self.max_n = self.margin // 2
        # Iterate over plain Python ints so scipy receives scalars rather than
        # 0-dim torch tensors; the resulting coefficient values are unchanged.
        ns = range(self.max_n + 1)
        self.n_range = torch.Tensor(list(ns))
        self.margin_choose_n = torch.Tensor(
            [scipy.special.binom(self.margin, 2 * n) for n in ns]
        )
        self.cos_powers = torch.Tensor([self.margin - 2 * n for n in ns])
        self.alternating = torch.Tensor([(-1) ** n for n in ns])

    def get_cos_with_margin(self, cosine):
        """
        Evaluate cos(m * theta) from cos(theta) using the precomputed expansion.

        Args:
            cosine: 1-D tensor of cos(theta) values.

        Returns:
            1-D tensor of the same length holding cos(m * theta).
        """
        # (N,) -> (N, 1) so it broadcasts against the (max_n + 1,) coefficients.
        cosine = cosine.unsqueeze(1)
        # Keep the cached coefficient tensors on the same device as the input.
        for attr in ["n_range", "margin_choose_n", "cos_powers", "alternating"]:
            setattr(self, attr, getattr(self, attr).to(cosine.device))
        cos_powered = cosine ** self.cos_powers
        # sin(t)^(2n) expressed as (1 - cos(t)^2)^n.
        sin_powered = (1 - cosine ** 2) ** self.n_range
        terms = (
            self.alternating * self.margin_choose_n * cos_powered * sin_powered
        )
        return torch.sum(terms, dim=1)

    def get_target_mask(self, metric_mat, labels):
        """
        Build a mask shaped like ``metric_mat`` with 1 at ``(i, labels[i])``.

        Args:
            metric_mat: 2-D tensor; row ``i`` belongs to sample ``i``.
            labels: Integer class index per row. Any shape with one element
                per row is accepted (it is flattened).

        Returns:
            Tensor with ``metric_mat``'s dtype/device, 0 everywhere except one
            1 per row at the target column.
        """
        # Flatten so a single-sample batch (0-dim after squeeze) still works.
        labels = labels.reshape(-1)
        batch_size = labels.size(0)
        mask = torch.zeros_like(metric_mat)
        mask[torch.arange(batch_size, device=labels.device), labels] = 1
        return mask

    def get_angles(self, cosine_of_target_classes):
        """Return ``acos`` of the cosines, clamped to [-1, 1] to avoid NaN."""
        angles = torch.acos(torch.clamp(cosine_of_target_classes, -1, 1))
        return angles

    def modify_cosine_of_target_classes(self, cosine_of_target_classes):
        """
        Map target cosines cos(theta) to psi(theta) = (-1)^k cos(m theta) - 2k.

        ``k = floor(theta / (pi / m))`` selects the angular segment; it is
        computed without gradient, so gradients flow only through
        cos(m * theta).
        """
        cos_with_margin = self.get_cos_with_margin(cosine_of_target_classes)
        angles = self.get_angles(cosine_of_target_classes)
        with torch.no_grad():
            k = (angles / (math.pi / self.margin)).floor()
        return ((-1) ** k) * cos_with_margin - (2 * k)

    def compute_loss(
        self,
        metric_mat,
        row_labels,
        col_labels,
        indices_tuple=None,
        is_same_source=False,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the L-Softmax cross-entropy loss.

        Args:
            metric_mat: 2-D cosine metric matrix; column ``j`` is scored as
                class ``j`` (presumably sample-to-class-proxy cosines of
                shape (batch, num_classes) -- TODO confirm against callers).
            row_labels: Target class index per row of ``metric_mat``; a
                trailing singleton dimension is tolerated.
            col_labels, indices_tuple, is_same_source, *args, **kwargs:
                Accepted for interface compatibility; unused here.

        Returns:
            Scalar tensor: mean cross-entropy over the rows.
        """
        # reshape(-1) instead of squeeze(): squeeze() collapses a batch of
        # one into a 0-dim tensor, which breaks both indexing and
        # cross_entropy.
        labels = row_labels.reshape(-1)
        mask = self.get_target_mask(metric_mat, labels)
        # The mask has exactly one 1 per row, so boolean indexing returns the
        # target cosines in row order as a 1-D tensor.
        cosine_of_target_classes = metric_mat[mask == 1]
        modified_cosine_of_target_classes = self.modify_cosine_of_target_classes(
            cosine_of_target_classes
        )
        # Replace only the target entries: add (psi - cos) where mask == 1.
        diff = (
            modified_cosine_of_target_classes - cosine_of_target_classes
        ).unsqueeze(1)
        logits = metric_mat + (mask * diff)
        # TODO: no dedicated logit-scaling function; only the constant
        # ``self.scale`` is applied.
        loss = torch.nn.functional.cross_entropy(logits * self.scale, labels)
        return loss