Source code for gedml.core.losses.pair_based_loss.contrastive_loss

import torch

from ...misc import loss_function as l_f 
from ..base_loss import BaseLoss

[docs]class ContrastiveLoss(BaseLoss): """ paper: `Learning a Similarity Metric Discriminatively, with Application to Face Verification <https://ieeexplore.ieee.org/abstract/document/1467314/>`_ """ def __init__( self, pos_margin=0, neg_margin=1, **kwargs ): super().__init__(**kwargs) self.pos_margin = pos_margin self.neg_margin = neg_margin to_record_list = [ "mean_pos_pair_dist", "mean_neg_pair_dist", "mean_pos_loss", "mean_neg_loss", "nonzero_mean_pos_loss", "nonzero_mean_neg_loss" ] for item in to_record_list: self.add_recordable_attr(name=item) def required_metric(self): return ["euclid"]
[docs] def compute_loss( self, metric_mat, row_labels, col_labels, indices_tuple, is_same_source=False, *args, **kwargs, ) -> torch.Tensor: pos_pair, neg_pair = l_f.indices_to_pairs(metric_mat, indices_tuple) self.mean_pos_pair_dist = torch.mean(pos_pair) self.mean_neg_pair_dist = torch.mean(neg_pair) pos_loss = torch.nn.functional.relu(pos_pair - self.pos_margin) neg_loss = torch.nn.functional.relu(self.neg_margin - neg_pair) self.mean_pos_loss = torch.mean(pos_loss) self.mean_neg_loss = torch.mean(neg_loss) pos_loss = pos_loss[torch.where(pos_loss)[0]] neg_loss = neg_loss[torch.where(neg_loss)[0]] self.nonzero_mean_pos_loss = torch.mean(pos_loss) self.nonzero_mean_neg_loss = torch.mean(neg_loss) numerator = torch.sum(pos_loss) + torch.sum(neg_loss) denominator = len(pos_loss) + len(neg_loss) loss = numerator / denominator return loss