Source code for gedml.core.losses.classifier_based_loss.cross_entropy_loss

import torch
import torch.nn.functional as F 
from ..base_loss import BaseLoss

class CrossEntropyLoss(BaseLoss):
    """
    paper: `Momentum Contrast for Unsupervised Visual Representation Learning <http://openaccess.thecvf.com/content_CVPR_2020/html/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.html>`_

    Cross-entropy loss designed for MoCo.

    Each row of the metric matrix is used as a vector of logits, and the
    target class of a row is the column whose label equals that row's label.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def required_metric(self):
        """Return the list of required metric names: ``["moco"]``."""
        return ["moco"]

    def compute_loss(
        self,
        metric_mat,
        row_labels,
        col_labels,
        is_same_source=False,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute the cross-entropy between each row of ``metric_mat`` and the
        column index of that row's positive entry.

        Args:
            metric_mat (torch.Tensor):
                Logits passed to ``F.cross_entropy``; presumably of shape
                ``(num_rows, num_cols)`` -- TODO confirm against the caller.
            row_labels (torch.Tensor):
                Labels of the rows. Must broadcast against ``col_labels``
                into a mask with at least two dimensions (presumably
                ``(num_rows, 1)`` -- TODO confirm).
            col_labels (torch.Tensor):
                Labels of the columns (presumably ``(1, num_cols)`` --
                TODO confirm).
            is_same_source (bool):
                Accepted for interface compatibility; not used here.

        Returns:
            torch.Tensor: the value returned by ``F.cross_entropy`` with its
            default reduction.

        Raises:
            ValueError: if any row does not have exactly one positive column.
        """
        # A boolean mask is all that torch.where needs; casting it to the
        # dtype of the logits is unnecessary.
        pos_mask = (row_labels == col_labels).to(metric_mat.device)

        # torch.where flattens the column indices of every positive entry.
        # Unless each row contributes exactly one, the targets either have
        # the wrong length or are silently paired with the wrong rows.
        positives_per_row = pos_mask.sum(dim=1)
        if not bool((positives_per_row == 1).all()):
            raise ValueError(
                "CrossEntropyLoss expects exactly one positive column per "
                "row, but got per-row counts between {} and {}.".format(
                    int(positives_per_row.min()),
                    int(positives_per_row.max()),
                )
            )

        pos_index = torch.where(pos_mask)[1]
        return F.cross_entropy(metric_mat, pos_index)