Source code for gedml.core.losses.base_loss

import torch
import logging
import torch.nn.functional as F 
from abc import abstractmethod, ABCMeta

from ..modules import WithRecorder

class BaseLoss(WithRecorder, metaclass=ABCMeta):
    """
    Abstract base class for loss modules.

    Concrete losses implement :meth:`compute_loss` and
    :meth:`required_metric`; :meth:`forward` simply delegates to
    :meth:`compute_loss`.

    By the framework's convention the output of this module is wrapped with
    the "FINISH" flag, which indicates the output doesn't need to be further
    processed. That wrapping is not performed inside this class.
    """

    def __init__(self, **kwargs):
        # Every keyword argument is handed on unchanged to the parent class.
        super().__init__(**kwargs)

    def forward(
        self,
        metric_mat,
        row_labels,
        col_labels,
        indices_tuple=None,
        weights=None,
        is_same_source=False,
    ) -> torch.Tensor:
        """
        Delegate to :meth:`compute_loss` and return its result unchanged.

        Args:
            metric_mat (torch.Tensor):
                Metric matrix.
            row_labels (torch.Tensor):
                Labels of matrix rows.
            col_labels (torch.Tensor):
                Labels of matrix columns.
            indices_tuple (dict):
                Dict that has two keys: "tuples" and "flags".
            weights (torch.Tensor):
                Can be element-wised, tuple-wised etc.
            is_same_source (bool):
                Forwarded as-is to :meth:`compute_loss`.

        Returns:
            torch.Tensor: Whatever :meth:`compute_loss` returns.
        """
        # The first three arguments go positionally, the optional ones by
        # keyword.
        return self.compute_loss(
            metric_mat,
            row_labels,
            col_labels,
            indices_tuple=indices_tuple,
            weights=weights,
            is_same_source=is_same_source,
        )

    @abstractmethod
    def compute_loss(
        self,
        metric_mat,
        row_labels,
        col_labels,
        indices_tuple=None,
        weights=None,
        is_same_source=False,
    ) -> torch.Tensor:
        """
        Compute the loss value. Must be overridden by subclasses.

        Args:
            metric_mat (torch.Tensor):
                Metric matrix.
            row_labels (torch.Tensor):
                Labels of matrix rows.
            col_labels (torch.Tensor):
                Labels of matrix columns.
            indices_tuple (dict):
                Dict that has two keys: "tuples" and "flags".
            weights (torch.Tensor):
                Can be element-wised, tuple-wised etc.
            is_same_source (bool):
                Not described by the original documentation.

        Returns:
            torch.Tensor: Final loss value (a tensor value). The base
            implementation is only a placeholder and returns the plain
            integer ``0``.
        """
        return 0

    @abstractmethod
    def required_metric(self) -> list:
        """
        Return the list of metrics this loss requires.

        Must be overridden by subclasses; the base implementation returns an
        empty list.
        """
        return list()