selectors
This module implements strategies for selecting the training samples (pairs or triplets) that provide the most useful information for training.
method | description
---|---
BaseSelector | Base class.
DefaultSelector | Do nothing.
DenseTripletSelector | Select all triplets.
DensePairSelector | Select all pairs.
DistanceWeightedSelector | Distance-weighted selector.
SemiHardSelector | Semi-hard selector.
RandomTripletSelector | Randomly select triplets.
HardSelector | Select the hardest samples.
Classes
BaseSelector
- class gedml.core.selectors.base_selector.BaseSelector(**kwargs)
Bases: gedml.core.modules.with_recorder.WithRecorder
Base class of all selectors.
- forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple
- Parameters
metric_mat (torch.Tensor) – Metric matrix.
row_labels (torch.Tensor) – Labels of rows.
col_labels (torch.Tensor) – Labels of columns.
is_same_source (bool) – Whether the two data streams are from the same source.
- Returns
A tuple of six elements:
metric_mat (torch.Tensor): Metric matrix.
labels_row (torch.Tensor): Labels of rows.
labels_col (torch.Tensor): Labels of columns.
is_same_source (bool): Whether the two tensors are from the same source.
indices_tuple (dict): Dict with two keys: “tuples” and “flags”.
weights (torch.Tensor): Weights.
- Return type
tuple
DefaultSelector
- class gedml.core.selectors.default_selector.DefaultSelector(**kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
A selector that does nothing.
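Every selector's forward() returns the same six elements in a fixed order, so a selector that performs no selection is the simplest way to see that contract. The sketch below is a hypothetical stand-in, not gedml's DefaultSelector: plain lists replace torch.Tensor, and the assumption that "doing nothing" means no mined tuples and no weights (None values) is ours.

```python
from typing import Any, Dict, Optional


class PassThroughSelector:
    """Hypothetical stand-in that mirrors the documented forward() return order.

    Not gedml's DefaultSelector; plain lists stand in for torch.Tensor.
    """

    def forward(self, metric_mat, row_labels, col_labels, is_same_source=False):
        # Assumption: a selector that "does nothing" mines no tuples and sets no weights.
        indices_tuple: Dict[str, Optional[Any]] = {"tuples": None, "flags": None}
        weights = None
        return (metric_mat, row_labels, col_labels,
                is_same_source, indices_tuple, weights)


# Unpack the six elements in the documented order.
metric_mat, rows, cols, same_source, indices_tuple, weights = PassThroughSelector().forward(
    [[0.0, 1.0], [1.0, 0.0]], [0, 1], [0, 1], is_same_source=True
)
```

Downstream code can rely on the positional order alone, which is why each selector below documents the same return type.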
DenseTripletSelector
- class gedml.core.selectors.dense_triplet_selector.DenseTripletSelector(**kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
Select all triplets.
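"Select all triplets" means enumerating every valid (anchor, positive, negative) combination. A minimal sketch, assuming anchors index the rows of metric_mat, positives and negatives index its columns, and that a sample is never its own positive when is_same_source is true; the function name is hypothetical and gedml's “tuples”/“flags” encoding is not reproduced.

```python
from typing import List, Sequence, Tuple


def dense_triplets(row_labels: Sequence[int], col_labels: Sequence[int],
                   is_same_source: bool = False) -> List[Tuple[int, int, int]]:
    """Enumerate every (anchor, positive, negative) index triplet."""
    triplets = []
    for a, la in enumerate(row_labels):
        for p, lp in enumerate(col_labels):
            # Positives share the anchor's label; when rows and columns come
            # from the same batch, skip pairing a sample with itself.
            if lp != la or (is_same_source and p == a):
                continue
            for n, ln in enumerate(col_labels):
                if ln != la:  # negatives carry a different label
                    triplets.append((a, p, n))
    return triplets


print(dense_triplets([0, 0, 1], [0, 0, 1], is_same_source=True))
# [(0, 1, 2), (1, 0, 2)]
```

The number of triplets grows roughly cubically with batch size, which is why sparser selectors such as the ones below exist.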
DensePairSelector
- class gedml.core.selectors.dense_pair_selector.DensePairSelector(**kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
Select all pairs.
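Selecting all pairs amounts to keeping every (row, column) entry of the metric matrix and marking whether it is a positive or a negative pair. The sketch assumes a flag of 1 for same-label pairs and 0 otherwise, and drops the diagonal when is_same_source is true; whether gedml's “flags” entry uses this encoding is an assumption, and the function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def dense_pairs(row_labels: Sequence[int], col_labels: Sequence[int],
                is_same_source: bool = False) -> Tuple[List[Tuple[int, int]], List[int]]:
    """Return every (row, col) index pair plus a flag: 1 = positive, 0 = negative."""
    pairs, flags = [], []
    for i, li in enumerate(row_labels):
        for j, lj in enumerate(col_labels):
            if is_same_source and i == j:
                continue  # a sample paired with itself carries no signal
            pairs.append((i, j))
            flags.append(1 if li == lj else 0)
    return pairs, flags


pairs, flags = dense_pairs([0, 0, 1], [0, 0, 1], is_same_source=True)
print(pairs)  # [(0, 1), (0, 2), (1, 0), (1, 2), (2, 0), (2, 1)]
print(flags)  # [1, 0, 1, 0, 0, 0]
```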
DistanceWeightedSelector
- class gedml.core.selectors.distance_weighted_selector.DistanceWeightedSelector(lower_cutoff=0.5, upper_cutoff=1.4, embedding_dim=512, **kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
Distance-weighted sampling method. Requires the Euclidean distance metric.
Paper: Sampling Matters in Deep Embedding Learning.
- forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple
For each anchor, randomly selects a positive sample and selects a negative sample according to the distance-weighted probability.
For points uniformly distributed on the unit hypersphere in \(n\) dimensions, the pairwise distance \(d\) follows:
\(q(d) \propto d^{n-2} \left[1 - \frac{1}{4} d^2 \right]^{\frac{n-3}{2}}\)
where \(n\) is the embedding dimension (presumably embedding_dim).
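In the cited paper, a negative is drawn with probability proportional to the inverse of q(d), which counteracts the concentration of distances on high-dimensional spheres. The sketch below implements that idea in plain Python. Two details are assumptions modeled on common open-source implementations and are not confirmed by this page: distances are clipped from below at lower_cutoff, and negatives at or beyond upper_cutoff are masked out. The function names are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def log_q(d: float, dim: int) -> float:
    """Log of the pairwise-distance density on the unit hypersphere, up to a constant."""
    inner = max(1.0 - 0.25 * d * d, 1e-8)  # guard log(0) as d approaches 2
    return (dim - 2) * math.log(d) + 0.5 * (dim - 3) * math.log(inner)


def sample_distance_weighted(dist: Sequence[Sequence[float]], labels: Sequence[int],
                             lower_cutoff: float = 0.5, upper_cutoff: float = 1.4,
                             embedding_dim: int = 512,
                             rng: Optional[random.Random] = None) -> List[Tuple[int, int, int]]:
    rng = rng or random.Random()
    triplets = []
    for a, la in enumerate(labels):
        positives = [p for p, lp in enumerate(labels) if lp == la and p != a]
        candidates, log_w = [], []
        for n, ln in enumerate(labels):
            d = dist[a][n]
            if ln == la or d >= upper_cutoff:
                continue  # assumption: negatives beyond upper_cutoff are masked
            candidates.append(n)
            # Weight is proportional to 1 / q(d), with d clipped at lower_cutoff.
            log_w.append(-log_q(max(d, lower_cutoff), embedding_dim))
        if not positives or not candidates:
            continue  # no valid triplet for this anchor
        m = max(log_w)
        weights = [math.exp(w - m) for w in log_w]  # subtract the max for stability
        p = rng.choice(positives)  # positive chosen uniformly at random
        n = rng.choices(candidates, weights=weights, k=1)[0]
        triplets.append((a, p, n))
    return triplets


dist = [[0.0, 0.3, 0.8, 1.0],
        [0.3, 0.0, 0.9, 1.2],
        [0.8, 0.9, 0.0, 0.4],
        [1.0, 1.2, 0.4, 0.0]]
print(sample_distance_weighted(dist, [0, 0, 1, 1], upper_cutoff=0.85))
# [(0, 1, 2), (2, 3, 0)]
```

Working in log space matters here: with embedding_dim=512 the raw values of q(d) underflow for any realistic distance.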
SemiHardSelector
- class gedml.core.selectors.semi_hard_selector.SemiHardSelector(margin=0.2, **kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
Semi-hard sampling method. Requires the Euclidean distance metric.
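A common definition of a semi-hard negative (from FaceNet) is one that lies farther from the anchor than the positive but still inside the margin: d(a, p) < d(a, n) < d(a, p) + margin. The sketch below assumes that definition and keeps every qualifying triplet; whether gedml keeps all of them or one per anchor is not stated here, and the function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def semi_hard_triplets(dist: Sequence[Sequence[float]], labels: Sequence[int],
                       margin: float = 0.2) -> List[Tuple[int, int, int]]:
    triplets = []
    for a, la in enumerate(labels):
        for p, lp in enumerate(labels):
            if p == a or lp != la:
                continue  # p must be a distinct sample with the anchor's label
            d_ap = dist[a][p]
            for n, ln in enumerate(labels):
                # Semi-hard: farther than the positive, but within the margin.
                if ln != la and d_ap < dist[a][n] < d_ap + margin:
                    triplets.append((a, p, n))
    return triplets


dist = [[0.0, 0.3, 0.8, 1.0],
        [0.3, 0.0, 0.9, 1.2],
        [0.8, 0.9, 0.0, 0.4],
        [1.0, 1.2, 0.4, 0.0]]
print(semi_hard_triplets(dist, [0, 0, 1, 1], margin=0.6))
# [(0, 1, 2), (2, 3, 0), (2, 3, 1)]
```

With the default margin of 0.2 the same matrix yields no triplets, which shows how strongly the margin controls the number of selected samples.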
RandomTripletSelector
- class gedml.core.selectors.random_triplet_selector.RandomTripletSelector(**kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
Randomly selects triplets.
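A minimal sketch of random triplet selection, assuming one triplet per anchor with the positive and the negative each drawn uniformly at random; anchors without any valid positive or negative are skipped. The function name and the one-per-anchor policy are assumptions, not gedml's documented behavior.

```python
import random
from typing import List, Optional, Sequence, Tuple


def random_triplets(labels: Sequence[int],
                    rng: Optional[random.Random] = None) -> List[Tuple[int, int, int]]:
    rng = rng or random.Random()
    triplets = []
    for a, la in enumerate(labels):
        positives = [p for p, lp in enumerate(labels) if lp == la and p != a]
        negatives = [n for n, ln in enumerate(labels) if ln != la]
        if positives and negatives:
            # Uniform draws: no use is made of the metric matrix.
            triplets.append((a, rng.choice(positives), rng.choice(negatives)))
    return triplets


print(random_triplets([0, 0, 1], rng=random.Random(0)))
```

The last sample in this example has no positive partner, so only two triplets come back.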
HardSelector
- class gedml.core.selectors.hard_selector.HardSelector(hardneg_cutoff=0.5, is_similarity=False, **kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
A custom selector based on the distance-weighted sampling method.
- forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple
- Parameters
metric_mat (torch.Tensor) – Metric matrix.
row_labels (torch.Tensor) – Labels of rows.
col_labels (torch.Tensor) – Labels of columns.
is_same_source (bool) – Whether the two data streams are from the same source.
- Returns
A tuple of six elements:
metric_mat (torch.Tensor): Metric matrix.
labels_row (torch.Tensor): Labels of rows.
labels_col (torch.Tensor): Labels of columns.
is_same_source (bool): Whether the two tensors are from the same source.
indices_tuple (dict): Dict with two keys: “tuples” and “flags”.
weights (torch.Tensor): Weights.
- Return type
tuple
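The summary table describes HardSelector as selecting the hardest samples, and the is_similarity flag suggests that metric_mat may hold either distances or similarities. The sketch below swaps in the well-known batch-hard strategy (hardest positive and hardest negative per anchor) to show how the direction of "hard" flips with is_similarity. It is not gedml's algorithm: hardneg_cutoff is not modeled, and the function name is hypothetical.

```python
from typing import List, Sequence, Tuple


def hardest_triplets(metric_mat: Sequence[Sequence[float]], labels: Sequence[int],
                     is_similarity: bool = False) -> List[Tuple[int, int, int]]:
    triplets = []
    for a, la in enumerate(labels):
        row = metric_mat[a]
        pos = [p for p, lp in enumerate(labels) if lp == la and p != a]
        neg = [n for n, ln in enumerate(labels) if ln != la]
        if not pos or not neg:
            continue  # anchor has no valid positive or negative
        if is_similarity:
            # Similarities: the hardest positive is the least similar,
            # the hardest negative is the most similar.
            p, n = min(pos, key=row.__getitem__), max(neg, key=row.__getitem__)
        else:
            # Distances: the hardest positive is the farthest,
            # the hardest negative is the closest.
            p, n = max(pos, key=row.__getitem__), min(neg, key=row.__getitem__)
        triplets.append((a, p, n))
    return triplets


dist = [[0.0, 0.3, 0.8, 1.0],
        [0.3, 0.0, 0.9, 1.2],
        [0.8, 0.9, 0.0, 0.4],
        [1.0, 1.2, 0.4, 0.0]]
print(hardest_triplets(dist, [0, 0, 1, 1]))
# [(0, 1, 2), (1, 0, 2), (2, 3, 0), (3, 2, 0)]
```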
HardPairSelector
- class gedml.core.selectors.hard_pair_selector.HardPairSelector(hardneg_cutoff=0.5, is_similarity=False, **kwargs)
Bases: gedml.core.selectors.base_selector.BaseSelector
A custom selector based on the distance-weighted sampling method.
- forward(metric_mat, row_labels, col_labels, is_same_source=False) → tuple
- Parameters
metric_mat (torch.Tensor) – Metric matrix.
row_labels (torch.Tensor) – Labels of rows.
col_labels (torch.Tensor) – Labels of columns.
is_same_source (bool) – Whether the two data streams are from the same source.
- Returns
A tuple of six elements:
metric_mat (torch.Tensor): Metric matrix.
labels_row (torch.Tensor): Labels of rows.
labels_col (torch.Tensor): Labels of columns.
is_same_source (bool): Whether the two tensors are from the same source.
indices_tuple (dict): Dict with two keys: “tuples” and “flags”.
weights (torch.Tensor): Weights.
- Return type
tuple
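The page does not define hardneg_cutoff. One plausible reading, used here purely as an assumption, is a threshold that decides which negative pairs count as hard: a distance below the cutoff, or a similarity above it when is_similarity is true. The hypothetical sketch below keeps every positive pair and only those hard negative pairs, flagging positives with 1 and negatives with 0; it is an illustration of the idea, not HardPairSelector's actual logic.

```python
from typing import List, Sequence, Tuple


def hard_pairs(metric_mat: Sequence[Sequence[float]], labels: Sequence[int],
               hardneg_cutoff: float = 0.5,
               is_similarity: bool = False) -> Tuple[List[Tuple[int, int]], List[int]]:
    pairs, flags = [], []
    for i, li in enumerate(labels):
        for j, lj in enumerate(labels):
            if i == j:
                continue  # skip self-pairs
            if li == lj:
                pairs.append((i, j))  # assumption: keep every positive pair
                flags.append(1)
                continue
            v = metric_mat[i][j]
            # Assumption: a negative is "hard" if it looks too close to the anchor.
            is_hard = v > hardneg_cutoff if is_similarity else v < hardneg_cutoff
            if is_hard:
                pairs.append((i, j))
                flags.append(0)
    return pairs, flags


dist = [[0.0, 0.3, 0.8, 1.0],
        [0.3, 0.0, 0.9, 1.2],
        [0.8, 0.9, 0.0, 0.4],
        [1.0, 1.2, 0.4, 0.0]]
pairs, flags = hard_pairs(dist, [0, 0, 1, 1], hardneg_cutoff=0.85)
print(pairs)  # [(0, 1), (0, 2), (1, 0), (2, 0), (2, 3), (3, 2)]
print(flags)  # [1, 0, 1, 0, 1, 1]
```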