Source code for gedml.core.collectors.iteration_collectors.default_collector

from ..base_collector import BaseCollector

[docs]class DefaultCollector(BaseCollector): """ This is the default collector which directly computes metric matrix using embeddings. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs)
[docs] def forward(self, data, embeddings, labels) -> tuple: """ Do nothing. Copy embeddings as proxies and copy labels as proxies labels. """ proxies = embeddings proxies_labels = labels metric_mat = self.metric(embeddings, proxies) is_same_source = True return ( metric_mat, labels.unsqueeze(-1), proxies_labels.unsqueeze(0), is_same_source )