metrics

This module defines different kinds of distance (or similarity) functions.

Type          Supported metrics
----------    -----------------
Distance      euclid, snr
Similarity    cosine, moco
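For intuition, the two most common entries in the table can be sketched without any framework. This is a minimal pure-Python sketch. It assumes that euclid denotes the pairwise L2 distance and cosine the pairwise cosine similarity between the rows of two embedding batches. The helpers `euclid`, `cosine` and `pairwise` are hypothetical illustrations, not gedml functions. snr and moco are not covered.

```python
import math
from typing import Callable, List

Vector = List[float]
Matrix = List[List[float]]


def euclid(x: Vector, y: Vector) -> float:
    """Euclidean (L2) distance: smaller means closer (assumed meaning of 'euclid')."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


def cosine(x: Vector, y: Vector) -> float:
    """Cosine similarity: larger means more similar. Assumes non-zero vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    norm_x = math.sqrt(sum(a * a for a in x))
    norm_y = math.sqrt(sum(b * b for b in y))
    return dot / (norm_x * norm_y)


def pairwise(metric: Callable[[Vector, Vector], float],
             xs: List[Vector], ys: List[Vector]) -> Matrix:
    """Return an m x n matrix whose entry [i][j] is metric(xs[i], ys[j])."""
    return [[metric(x, y) for y in ys] for x in xs]


if __name__ == "__main__":
    xs = [[1.0, 0.0], [0.0, 1.0]]
    ys = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]
    print(pairwise(euclid, xs, ys))  # 2 x 3 distance matrix
    print(pairwise(cosine, xs, ys))  # 2 x 3 similarity matrix
```

A distance matrix is read "lower is closer" and a similarity matrix "higher is closer", which is why the table separates the two types.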

Class

MetricFactory

class gedml.core.metrics.metric_factory.MetricFactory(is_normalize, metric_name, addition=None, **kwargs)

Bases: torch.nn.modules.module.Module

Get a metric (distance or similarity) by name.

Parameters
  • is_normalize (bool) – Whether to normalize the embeddings.

  • metric_name (str) – Name of the metric, e.g. ‘euclid’, ‘snr’, ‘cosine’ or ‘moco’.
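The effect of is_normalize can be illustrated with a small sketch. It assumes that normalization means scaling each embedding to unit L2 norm. This is the usual convention but is not confirmed by this page. Under that assumption, squared Euclidean distance and cosine similarity are tied together: for unit vectors, ||u − v||² = 2 − 2·cos(u, v). The helpers `l2_normalize`, `squared_euclid` and `dot` are hypothetical, not part of gedml.

```python
import math
from typing import List

Vector = List[float]


def l2_normalize(x: Vector, eps: float = 1e-12) -> Vector:
    """Scale x to unit L2 norm; eps avoids division by zero for a zero vector."""
    norm = math.sqrt(sum(a * a for a in x))
    return [a / max(norm, eps) for a in x]


def squared_euclid(x: Vector, y: Vector) -> float:
    """Squared L2 distance between two vectors."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def dot(x: Vector, y: Vector) -> float:
    """Dot product; equals cosine similarity when both inputs have unit norm."""
    return sum(a * b for a, b in zip(x, y))


if __name__ == "__main__":
    u = l2_normalize([3.0, 4.0])  # approximately [0.6, 0.8]
    v = l2_normalize([1.0, 0.0])
    # For unit vectors: ||u - v||^2 == 2 - 2 * cos(u, v)
    print(squared_euclid(u, v), 2 - 2 * dot(u, v))
```

Consequently, on normalized embeddings, ranking pairs by increasing euclid distance gives the same order as ranking them by decreasing cosine similarity.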

Example

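Get the euclid metric:

>>> import torch
>>> from gedml.core.metrics.metric_factory import MetricFactory
>>> metric = MetricFactory(is_normalize=False, metric_name="euclid")
>>> data = torch.randn(100, 128)  # 100 embeddings, 128-dim
>>> matrix = metric(data, data)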

forward(*args) → torch.Tensor

Compute the metric matrix.

Parameters

*args (sequence) – Input tensors used to compute the matrix, e.g. metric(data, data).

Returns

The metric matrix.

Return type

torch.Tensor