# Source code for gedml.core.selectors.semi_hard_selector

import numpy as np 
import torch

from .base_selector import BaseSelector
from ...config.setting.core_setting import (
    INDICES_TUPLE,
    INDICES_FLAG
)

class SemiHardSelector(BaseSelector):
    """
    Semi-hard sampling method, euclidean distance metric is required.

    For every anchor (row of the metric matrix) with at least one positive
    column, one positive is drawn uniformly at random. Then one negative is
    drawn uniformly at random among the columns with a different label whose
    distance satisfies ``d_p < d_n < d_p + margin``. Anchors with no such
    negative are dropped.

    Args:
        margin (float): width of the semi-hard band above the anchor-positive
            distance. Default: ``0.2``.
        **kwargs: forwarded to the base class constructor.
    """

    def __init__(self, margin=0.2, **kwargs):
        # Zero-argument super(): ``super(BaseSelector, self)`` would start the
        # MRO lookup *after* BaseSelector and silently skip its __init__.
        super().__init__(**kwargs)
        self.margin = margin

    def forward(
        self,
        metric_mat,
        row_labels,
        col_labels,
        is_same_source=False
    ):
        """
        Randomly select a positive sample and select a negative sample holds:
        :math:`d_p < d_n < d_p + margin`

        Args:
            metric_mat (torch.Tensor): pairwise distance matrix. Row ``i``
                holds the distances from row sample ``i`` to every column
                sample.
            row_labels (torch.Tensor): labels of the row samples.
            col_labels (torch.Tensor): labels of the column samples.
                Presumably shaped so that ``row_labels == col_labels``
                broadcasts to ``metric_mat``'s shape (e.g. ``(B, 1)`` vs
                ``(1, B)``). TODO: confirm against callers.
            is_same_source (bool): whether rows and columns are the same
                samples. If True, a sample is never picked as its own
                positive.

        Returns:
            tuple: ``(metric_mat, row_labels, col_labels, is_same_source,
            indices_tuple, weight)``.
            ``indices_tuple[INDICES_TUPLE]`` is a ``(K, 3)`` int64 tensor of
            ``(anchor, positive, negative)`` indices. ``K`` may be 0 when
            no semi-hard triplet exists.
            ``indices_tuple[INDICES_FLAG]`` is ``None``, and ``weight`` is
            ``None``.
        """
        # positive / negative masks
        matches = (row_labels == col_labels).byte()
        # Built before the diagonal is cleared, so that for same-source
        # inputs a sample is not counted as its own negative.
        diffs = matches ^ 1
        if is_same_source:
            # a sample must not be chosen as its own positive
            matches.fill_diagonal_(0)

        indices_tuple = {
            INDICES_TUPLE: self._sample_triplets(metric_mat, matches, diffs),
            INDICES_FLAG: None
        }
        weight = None
        return (
            metric_mat,
            row_labels,
            col_labels,
            is_same_source,
            indices_tuple,
            weight,
        )

    def _sample_triplets(self, metric_mat, matches, diffs):
        """
        Draw one (anchor, positive, semi-hard negative) triplet per usable row.

        Args:
            metric_mat (torch.Tensor): pairwise distance matrix.
            matches (torch.Tensor): uint8 positive mask, same shape as
                ``metric_mat``.
            diffs (torch.Tensor): uint8 negative mask, same shape as
                ``metric_mat``.

        Returns:
            torch.Tensor: ``(K, 3)`` int64 tensor on ``metric_mat.device``.
            ``K`` is 0 when no anchor has both a positive and a semi-hard
            negative.
        """
        empty = torch.empty((0, 3), dtype=torch.long, device=metric_mat.device)

        # Anchors are the rows with at least one positive. ``torch.where``
        # already yields those row indices on the mask's device. This avoids
        # indexing a CPU ``arange`` with a (possibly CUDA) index tensor.
        a_ids = torch.where(torch.sum(matches, dim=-1) > 0)[0]
        if a_ids.numel() == 0:
            # nothing to sample; also keeps multinomial off a 0-row input
            return empty

        # select positive samples: uniform over each anchor's positives
        p_ids = torch.multinomial(
            input=matches[a_ids].float(),
            num_samples=1,
            replacement=True
        ).flatten()
        ap_dist = metric_mat[a_ids, p_ids].unsqueeze(1)

        # select negative samples inside the semi-hard band
        anchor_dists = metric_mat[a_ids]
        semi_hard_mask = (
            (anchor_dists > ap_dist)
            & (anchor_dists < self.margin + ap_dist)
        ).byte() * diffs[a_ids]
        keep = torch.where(torch.sum(semi_hard_mask, dim=-1) > 0)[0]
        if keep.numel() == 0:
            # no anchor has a semi-hard negative in this batch
            return empty

        n_ids = torch.multinomial(
            input=semi_hard_mask[keep].float(),
            num_samples=1,
            replacement=True
        ).flatten()
        return torch.stack([a_ids[keep], p_ids[keep], n_ids], dim=1)