metrics
This module defines different kinds of distance (or similarity) functions.
Type | Supported metrics
---|---
Distance | euclid, snr
Similarity | cosine, moco
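The two rows of the table differ in direction. For a distance, a smaller value means two embeddings are more alike. For a similarity, a larger value means they are more alike. The framework-free sketch below shows this for `euclid` and `cosine` as pairwise matrices over two lists of vectors. It uses only the standard library and is an illustration of the textbook formulas, not gedml's implementation. `euclid_matrix` and `cosine_matrix` are hypothetical helper names, and `snr` and `moco` are not covered.

```python
import math

# Illustrative, framework-free sketch of two metrics from the table above.
# This is NOT gedml's code; helper names are hypothetical.

def euclid_matrix(x, y):
    """Pairwise Euclidean distance: entry [i][j] = ||x_i - y_j||_2 (smaller = more alike)."""
    return [[math.dist(a, b) for b in y] for a in x]

def cosine_matrix(x, y):
    """Pairwise cosine similarity: entry [i][j] = <x_i, y_j> / (||x_i|| * ||y_j||) (larger = more alike)."""
    def norm(v):
        return math.sqrt(sum(t * t for t in v))
    return [
        [sum(p * q for p, q in zip(a, b)) / (norm(a) * norm(b)) for b in y]
        for a in x
    ]

data = [[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]]
print(euclid_matrix(data, data)[0][2])  # distance from (1, 0) to (3, 4), about 4.4721
print(cosine_matrix(data, data)[0][1])  # 0.0: the two vectors are orthogonal
```

The diagonal of a distance matrix is 0, while the diagonal of a cosine similarity matrix is 1.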
Class
MetricFactory
class gedml.core.metrics.metric_factory.MetricFactory(is_normalize, metric_name, addition=None, **kwargs)
Bases: torch.nn.modules.module.Module
Creates the metric (distance or similarity) selected by metric_name.
Parameters
- is_normalize (bool) – Whether to normalize the embeddings before computing the metric.
- metric_name (str) – Name of the metric to use: ‘euclid’, ‘snr’, ‘cosine’, or ‘moco’.
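The docs do not say which normalization `is_normalize` applies. A common choice is scaling each embedding to unit L2 norm, and the sketch below assumes that. Under this assumption, normalization ties the two metric families together: for unit vectors, ||a − b||² = 2 − 2·cos(a, b). Euclidean distance and cosine similarity then rank pairs in the same (reversed) order. `l2_normalize` is a hypothetical helper, not part of gedml.

```python
import math

# Assumption: is_normalize=True scales each embedding to unit L2 norm.
# The gedml docs do not state the exact normalization.
def l2_normalize(v, eps=1e-12):
    """Return v divided by its L2 norm (eps guards against a zero vector)."""
    n = math.sqrt(sum(t * t for t in v))
    return [t / max(n, eps) for t in v]

a = l2_normalize([3.0, 4.0])  # -> [0.6, 0.8]
b = l2_normalize([1.0, 0.0])  # -> [1.0, 0.0]

# For unit vectors: squared Euclidean distance = 2 - 2 * cosine similarity.
squared_euclid = math.dist(a, b) ** 2
cos_ab = sum(p * q for p, q in zip(a, b))
print(squared_euclid, 2 - 2 * cos_ab)  # both about 0.8
```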
Example

Get the euclid metric:

>>> import torch
>>> from gedml.core.metrics.metric_factory import MetricFactory
>>> metric = MetricFactory(is_normalize=False, metric_name="euclid")
>>> data = torch.randn(100, 128)
>>> matrix = metric(data, data)