Source code for gedml.core.collectors.base_collector

import torch
from abc import ABCMeta, abstractmethod

from ..modules import WithRecorder

class BaseCollector(WithRecorder, metaclass=ABCMeta):
    """
    Base class of collector modules.

    It defines the main collector interface through ``forward`` (called
    ``collect`` in the original documentation) and ``update``. The original
    documentation also mentions default parameters declared in
    ``output_list``, ``input_list`` and ``_default_next_module``; none of
    them is defined in this class, so they are presumably supplied by
    ``WithRecorder`` or by concrete subclasses (TODO confirm).

    Args:
        metric (metric instance):
            metric to compute matrix (e.g. euclidean or cosine)
        **kwargs:
            forwarded unchanged to the parent class constructor.

    Example:
        >>> metric = MetricFactory(is_normalize=True, metric_name="cosine")
        >>> data = torch.randn(10, 3, 227, 227)
        >>> embeddings = torch.randn(10, 128)
        >>> labels = torch.randint(0, 3, size=(10,))
        >>> collector = DefaultCollector(metric=metric)
        >>> # collector forward
        >>> output_dict = collector(data, embeddings, labels)
    """

    def __init__(self, metric, **kwargs):
        super().__init__(**kwargs)
        # Kept on the instance for subclasses; this base class never calls it.
        self.metric = metric

    @property
    def is_global_collector(self):
        """bool: always ``False`` for the base collector."""
        return False

    def update(self, *args, **kwargs):
        """
        Define the interface through which a collector can update itself
        when given specific information (does nothing by default).
        """
        pass

    def forward(
        self,
        data,
        embeddings,
        labels
    ) -> tuple:
        """
        Placeholder implementation of the collecting step.

        In a collector's ``forward``, three kinds of operation may be done:

        1. maintain sets of parameters about collecting (or synthesizing) samples
        2. compute metric matrix and pass to next module
        3. compute some regularization term using embeddings

        This base implementation performs none of them: it ignores its
        arguments and returns ``(None, None, None, False)``.

        Args:
            data (torch.Tensor):
                Images with RGB channels. size: :math:`B \\times C \\times H \\times W`
            embeddings (torch.Tensor):
                Embedding. size: :math:`B \\times dim`
            labels (torch.Tensor):
                Ground truth of dataset. size: :math:`B \\times 1`

        Returns:
            tuple: include metric matrix, labels etc according to function ``output_list``.

            Let :math:`B_{row}` be the length of rows and :math:`B_{col}` be the length of columns, typical output type is listed below:

            1. metric matrix (torch.Tensor): size: :math:`B_{row} \\times B_{col}`
            2. labels of rows (torch.Tensor): size: :math:`B_{row} \\times 1` or :math:`B_{row} \\times B_{col}`
            3. labels of columns (torch.Tensor): size: :math:`1 \\times B_{col}` or :math:`B_{row} \\times B_{col}`
            4. is_from_same_source (bool): indicate whether row vectors and column vectors are from the same data
        """
        # Distinct local names so the ``labels`` argument is not shadowed;
        # the returned placeholder values are the same as before.
        metric_mat, row_labels, col_labels, is_same_source = None, None, None, False
        return (
            metric_mat,
            row_labels,
            col_labels,
            is_same_source
        )