selectors

This module provides methods for selecting the samples (pairs or triplets) that carry the most information for training.

Method                      Description

BaseSelector                Base class.
DefaultSelector             Do nothing.
DenseTripletSelector        Select all triplets.
DensePairSelector           Select all pairs.
DistanceWeightedSelector    Distance-weighted sampling.
SemiHardSelector            Semi-hard sampling.
RandomTripletSelector       Randomly select triplets.
HardSelector                Select the hardest samples.

Classes

BaseSelector

class gedml.core.selectors.base_selector.BaseSelector(**kwargs)

Bases: gedml.core.modules.with_recorder.WithRecorder

Base class of selectors.

forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple
Parameters
  • metric_mat (torch.Tensor) – Metric matrix.

  • row_labels (torch.Tensor) – Labels of rows.

  • col_labels (torch.Tensor) – Labels of columns.

  • is_same_source (bool) – Whether the two data streams are from the same source.

Returns

A tuple of six elements:

  1. metric_mat (torch.Tensor): Metric matrix.

  2. labels_row (torch.Tensor): Labels of rows.

  3. labels_col (torch.Tensor): Labels of columns.

  4. is_same_source (bool): Whether the two tensors are from the same source.

  5. indices_tuple (dict): Dict with two keys, “tuples” and “flags”.

  6. weights (torch.Tensor): Weights.

Return type

tuple

DefaultSelector

class gedml.core.selectors.default_selector.DefaultSelector(**kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

A selector that does nothing.

forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple

Do nothing.

DenseTripletSelector

class gedml.core.selectors.dense_triplet_selector.DenseTripletSelector(**kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

Select all triplets.

forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple

Select all triplets.
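A minimal pure-Python sketch of what "all triplets" usually means, assuming the same-source case (is_same_source=True), where rows and columns index the same batch so a sample is never its own positive. The function dense_triplets is hypothetical and returns plain index tuples, not gedml's indices_tuple dict.

```python
from typing import List, Sequence, Tuple


def dense_triplets(labels: Sequence[int]) -> List[Tuple[int, int, int]]:
    """Enumerate every (anchor, positive, negative) index triplet in one batch."""
    triplets = []
    for a, la in enumerate(labels):
        for p, lp in enumerate(labels):
            if p == a or lp != la:
                continue  # a positive shares the anchor's label but is a different sample
            for n, ln in enumerate(labels):
                if ln != la:  # a negative has a different label
                    triplets.append((a, p, n))
    return triplets


print(dense_triplets([0, 0, 1]))  # [(0, 1, 2), (1, 0, 2)]
```

The number of triplets grows roughly cubically with batch size, which is why the sampling-based selectors in this module exist.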

DensePairSelector

class gedml.core.selectors.dense_pair_selector.DensePairSelector(**kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

Select all pairs.

forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple

Select all pairs.
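A minimal pure-Python sketch of selecting all pairs, again assuming the same-source case so self-pairs are skipped. The function dense_pairs is hypothetical; it splits the ordered pairs into positives (same label) and negatives (different label), which is one plausible reading of "all pairs" rather than gedml's actual output format.

```python
from typing import List, Sequence, Tuple

Pair = Tuple[int, int]


def dense_pairs(labels: Sequence[int]) -> Tuple[List[Pair], List[Pair]]:
    """Return (positive_pairs, negative_pairs) covering every ordered pair i != j."""
    positives: List[Pair] = []
    negatives: List[Pair] = []
    for i, li in enumerate(labels):
        for j, lj in enumerate(labels):
            if i == j:
                continue  # skip self-pairs (same-source assumption)
            (positives if li == lj else negatives).append((i, j))
    return positives, negatives


pos, neg = dense_pairs([0, 0, 1])
print(pos)  # [(0, 1), (1, 0)]
print(neg)  # [(0, 2), (1, 2), (2, 0), (2, 1)]
```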

DistanceWeightedSelector

class gedml.core.selectors.distance_weighted_selector.DistanceWeightedSelector(lower_cutoff=0.5, upper_cutoff=1.4, embedding_dim=512, **kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

Distance-weighted sampling method. A Euclidean distance metric is required.

Paper: Sampling Matters in Deep Embedding Learning (Wu et al., ICCV 2017).

forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple

For each anchor sample, randomly select a positive sample, then select a negative sample according to the distance-weighted probability.

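In the paper, a negative at distance \(d\) from the anchor is sampled with probability proportional to \(q(d)^{-1}\) (with the weight clipped at a constant), where \(q(d)\) is the density of pairwise distances between points uniformly distributed on the unit hypersphere in \(n\)-dimensional space: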

\(q(d) \propto d^{n-2} [1 - \frac{1}{4} d^2 ]^{\frac{n-3}{2}}\)
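A minimal pure-Python sketch of the negative-sampling step for one anchor. It assumes \(n\) is embedding_dim, that distances below lower_cutoff are clipped up to it (which caps the weight), and that negatives at or beyond upper_cutoff get zero weight; this mirrors the common implementation of the paper, but whether gedml applies its parameters exactly this way is an assumption. The helpers distance_weights and sample_negative are hypothetical. Weights are computed in log space because \(d^{n-2}\) underflows for n = 512.

```python
import math
import random
from typing import List, Optional, Sequence


def distance_weights(
    distances: Sequence[float],
    embedding_dim: int = 512,
    lower_cutoff: float = 0.5,
    upper_cutoff: float = 1.4,
) -> List[float]:
    """Weights proportional to 1 / q(d), scaled so the largest weight is 1."""
    if not 0.0 < lower_cutoff < upper_cutoff <= 2.0:
        raise ValueError("expected 0 < lower_cutoff < upper_cutoff <= 2")
    n = embedding_dim
    log_w: List[Optional[float]] = []
    for d in distances:
        if d >= upper_cutoff:
            log_w.append(None)  # too far away: never sampled
            continue
        d = max(d, lower_cutoff)  # clip small distances so weights stay bounded
        # log(1 / q(d)) = -(n - 2) log d - (n - 3) / 2 * log(1 - d^2 / 4)
        log_w.append(
            -(n - 2) * math.log(d) - (n - 3) / 2 * math.log(1.0 - 0.25 * d * d)
        )
    valid = [w for w in log_w if w is not None]
    if not valid:
        return [0.0] * len(log_w)
    top = max(valid)  # subtract the maximum before exp() to avoid overflow
    return [0.0 if w is None else math.exp(w - top) for w in log_w]


def sample_negative(
    distances: Sequence[float],
    embedding_dim: int = 512,
    lower_cutoff: float = 0.5,
    upper_cutoff: float = 1.4,
    rng: Optional[random.Random] = None,
) -> int:
    """Return the index of one negative, drawn with distance_weights()."""
    rng = rng or random.Random()
    weights = distance_weights(distances, embedding_dim, lower_cutoff, upper_cutoff)
    if sum(weights) == 0.0:
        # Every candidate was masked out: fall back to uniform sampling.
        return rng.randrange(len(distances))
    return rng.choices(range(len(distances)), weights=weights, k=1)[0]
```

With the default embedding_dim=512, a negative at distance 0.6 receives a far larger weight than one at 1.0, and one at 1.5 receives none, so over this range closer negatives are favored while the lower cutoff caps how large any single weight can get.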

SemiHardSelector

class gedml.core.selectors.semi_hard_selector.SemiHardSelector(margin=0.2, **kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

Semi-hard sampling method. A Euclidean distance metric is required.

forward(metric_mat, row_labels, col_labels, is_same_source=False)

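Randomly select a positive sample, then select a negative sample that satisfies the following, where \(d_p\) is the anchor-positive distance and \(d_n\) is the anchor-negative distance:

\(d_p < d_n < d_p + \text{margin}\)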
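A minimal pure-Python sketch of the semi-hard rule, assuming a square Euclidean distance matrix over a single batch (the same-source case). The function semi_hard_triplets is hypothetical, and skipping anchors that have no semi-hard negative is an assumption about how such anchors are handled.

```python
import random
from typing import List, Optional, Sequence, Tuple


def semi_hard_triplets(
    dist: Sequence[Sequence[float]],
    labels: Sequence[int],
    margin: float = 0.2,
    rng: Optional[random.Random] = None,
) -> List[Tuple[int, int, int]]:
    """For each anchor, pick a random positive, then a random semi-hard negative.

    Anchors with no positive, or with no negative satisfying
    d_p < d_n < d_p + margin, are skipped.
    """
    rng = rng or random.Random()
    triplets = []
    for a, la in enumerate(labels):
        positives = [p for p, lp in enumerate(labels) if lp == la and p != a]
        if not positives:
            continue
        p = rng.choice(positives)
        d_p = dist[a][p]
        candidates = [
            n for n, ln in enumerate(labels)
            if ln != la and d_p < dist[a][n] < d_p + margin
        ]
        if candidates:
            triplets.append((a, p, rng.choice(candidates)))
    return triplets


D = [
    [0.0, 0.5, 0.6, 1.0],
    [0.5, 0.0, 0.9, 0.65],
    [0.6, 0.9, 0.0, 0.3],
    [1.0, 0.65, 0.3, 0.0],
]
print(semi_hard_triplets(D, [0, 0, 1, 1], rng=random.Random(0)))  # [(0, 1, 2), (1, 0, 3)]
```

Anchors 2 and 3 are dropped here: their negatives are all farther than \(d_p + \text{margin}\), so none is semi-hard.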

RandomTripletSelector

class gedml.core.selectors.random_triplet_selector.RandomTripletSelector(**kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

Random triplet sampling method.

forward(metric_mat, row_labels, col_labels, is_same_source=False)

Randomly select a positive sample and a negative sample for each anchor sample.
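A minimal pure-Python sketch of random triplet selection, assuming the same-source case. Unlike the distance-based selectors, it needs only labels. The function random_triplets is hypothetical, and skipping anchors that cannot form a triplet is an assumption.

```python
import random
from typing import List, Optional, Sequence, Tuple


def random_triplets(
    labels: Sequence[int], rng: Optional[random.Random] = None
) -> List[Tuple[int, int, int]]:
    """For each anchor, draw one positive and one negative uniformly at random."""
    rng = rng or random.Random()
    triplets = []
    for a, la in enumerate(labels):
        positives = [p for p, lp in enumerate(labels) if lp == la and p != a]
        negatives = [n for n, ln in enumerate(labels) if ln != la]
        if positives and negatives:  # skip anchors that cannot form a triplet
            triplets.append((a, rng.choice(positives), rng.choice(negatives)))
    return triplets


# Anchor 4 is the only sample of its class, so it produces no triplet.
print(random_triplets([0, 0, 1, 1, 2], rng=random.Random(0)))
```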

HardSelector

class gedml.core.selectors.hard_selector.HardSelector(hardneg_cutoff=0.5, is_similarity=False, **kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

A custom selector based on the distance-weighted sampling method.

forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple
Parameters
  • metric_mat (torch.Tensor) – Metric matrix.

  • row_labels (torch.Tensor) – Labels of rows.

  • col_labels (torch.Tensor) – Labels of columns.

  • is_same_source (bool) – Whether the two data streams are from the same source.

Returns

A tuple of six elements:

  1. metric_mat (torch.Tensor): Metric matrix.

  2. labels_row (torch.Tensor): Labels of rows.

  3. labels_col (torch.Tensor): Labels of columns.

  4. is_same_source (bool): Whether the two tensors are from the same source.

  5. indices_tuple (dict): Dict with two keys, “tuples” and “flags”.

  6. weights (torch.Tensor): Weights.

Return type

tuple
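The summary table describes HardSelector as selecting the hardest samples. A common form of this is batch-hard mining: for each anchor, take the farthest positive and the closest negative. The sketch below assumes that reading, and assumes is_similarity=True means larger values indicate more similar pairs, so the ordering flips. It does not model hardneg_cutoff, whose exact role is not documented here. The function batch_hard_triplets is hypothetical.

```python
from typing import List, Sequence, Tuple


def batch_hard_triplets(
    metric: Sequence[Sequence[float]],
    labels: Sequence[int],
    is_similarity: bool = False,
) -> List[Tuple[int, int, int]]:
    """Pick the hardest positive and the hardest negative for each anchor.

    With distances, the hardest positive is the farthest and the hardest
    negative is the closest; with similarities (assumed meaning of
    is_similarity) both comparisons flip.
    """
    triplets = []
    for a, la in enumerate(labels):
        positives = [p for p, lp in enumerate(labels) if lp == la and p != a]
        negatives = [n for n, ln in enumerate(labels) if ln != la]
        if not positives or not negatives:
            continue
        row = metric[a]
        if is_similarity:
            p = min(positives, key=lambda j: row[j])  # least similar positive
            n = max(negatives, key=lambda j: row[j])  # most similar negative
        else:
            p = max(positives, key=lambda j: row[j])  # farthest positive
            n = min(negatives, key=lambda j: row[j])  # closest negative
        triplets.append((a, p, n))
    return triplets


D = [
    [0.0, 0.3, 0.5, 0.9],
    [0.3, 0.0, 0.8, 0.4],
    [0.5, 0.8, 0.0, 0.2],
    [0.9, 0.4, 0.2, 0.0],
]
print(batch_hard_triplets(D, [0, 0, 1, 1]))  # [(0, 1, 2), (1, 0, 3), (2, 3, 0), (3, 2, 1)]
```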

HardPairSelector

class gedml.core.selectors.hard_pair_selector.HardPairSelector(hardneg_cutoff=0.5, is_similarity=False, **kwargs)

Bases: gedml.core.selectors.base_selector.BaseSelector

A custom selector based on the distance-weighted sampling method.

forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple
Parameters
  • metric_mat (torch.Tensor) – Metric matrix.

  • row_labels (torch.Tensor) – Labels of rows.

  • col_labels (torch.Tensor) – Labels of columns.

  • is_same_source (bool) – Whether the two data streams are from the same source.

Returns

A tuple of six elements:

  1. metric_mat (torch.Tensor): Metric matrix.

  2. labels_row (torch.Tensor): Labels of rows.

  3. labels_col (torch.Tensor): Labels of columns.

  4. is_same_source (bool): Whether the two tensors are from the same source.

  5. indices_tuple (dict): Dict with two keys, “tuples” and “flags”.

  6. weights (torch.Tensor): Weights.

Return type

tuple