Source code for gedml.core.selectors.dense_triplet_selector

import numpy as np
import torch

from .base_selector import BaseSelector
from ...config.setting.core_setting import (
    INDICES_TUPLE,
    INDICES_FLAG
)

def get_all_triplets_indices(row_labels, col_labels=None, is_same_source=False):
    """
    Enumerate every (anchor, positive, negative) index triplet.

    With ``M = (row_labels == col_labels)``, a triplet ``(a, p, n)`` is
    emitted whenever ``M[a, p]`` is true and ``M[a, n]`` is false.

    Args:
        row_labels: Label tensor compared (with broadcasting) against
            ``col_labels``. NOTE(review): presumably a column vector of
            shape (N, 1) -- the comparison result must be 2-D; confirm
            against callers.
        col_labels: Label tensor for the candidate positives/negatives.
            Defaults to ``row_labels.t()`` when ``None``.
        is_same_source: When true, the main diagonal of the match matrix
            is cleared so that entry ``(i, i)`` is never used as an
            anchor-positive pair.

    Returns:
        Integer tensor with one row per triplet and one column per index
        (anchor, positive, negative), in the order produced by
        ``torch.where``.
    """
    if col_labels is None:
        col_labels = row_labels.t()

    is_positive = row_labels == col_labels
    # Negatives come from the raw label comparison, so the diagonal masking
    # applied to the positives below never leaks into them.
    is_negative = row_labels != col_labels

    if is_same_source:
        is_positive.fill_diagonal_(False)

    # Broadcast (anchor, positive, 1) against (anchor, 1, negative).
    valid = is_positive.unsqueeze(2) & is_negative.unsqueeze(1)
    coords = torch.where(valid)
    return torch.stack(coords, dim=1)


class DenseTripletSelector(BaseSelector):
    """
    Selector that keeps every triplet returned by
    ``get_all_triplets_indices`` (no mining or filtering is applied).
    """

    def forward(
        self,
        metric_mat,
        row_labels,
        col_labels,
        is_same_source=False
    ) -> tuple:
        """
        Select all triplets.

        Args:
            metric_mat: Passed through unchanged; not read by this selector.
            row_labels: Forwarded to ``get_all_triplets_indices`` and
                returned unchanged.
            col_labels: Forwarded to ``get_all_triplets_indices`` and
                returned unchanged.
            is_same_source: Forwarded to ``get_all_triplets_indices`` and
                returned unchanged.

        Returns:
            tuple: ``(metric_mat, row_labels, col_labels, is_same_source,
            indices_tuple, weight)`` where ``indices_tuple`` is a dict
            mapping ``INDICES_TUPLE`` to the triplet index tensor and
            ``INDICES_FLAG`` to ``None``, and ``weight`` is always ``None``.
        """
        all_triplets = get_all_triplets_indices(
            row_labels, col_labels, is_same_source
        )
        indices_tuple = {
            INDICES_TUPLE: all_triplets,
            INDICES_FLAG: None,
        }
        # This selector attaches no weighting, hence the trailing None.
        return (
            metric_mat,
            row_labels,
            col_labels,
            is_same_source,
            indices_tuple,
            None,
        )