Source code for gedml.core.collectors.iteration_collectors.simsiam_collector

import torch
import torch.nn as nn
import math
from copy import deepcopy
import logging
import torch.distributed as dist

from ..base_collector import BaseCollector
from ...misc import utils
from ...models import BatchNormMLP

class SimSiamCollector(BaseCollector):
    """
    Paper: `Exploring Simple Siamese Representation Learning <https://arxiv.org/abs/2011.10566>`_

    This method uses none of the following to learn meaningful representations:

    1. negative sample pairs;
    2. large batches;
    3. momentum encoders.

    Instead, a stop-gradient operation plays an essential role in preventing collapsing.
    """
    def __init__(self, *args, **kwargs):
        super(SimSiamCollector, self).__init__(*args, **kwargs)
        # SimSiam prediction head: bottleneck MLP (2048 -> 512 -> 2048)
        self.predictor = BatchNormMLP(
            layer_size_list=[2048, 512, 2048],
            relu_list=[True, False],
            bn_list=[True, False]
        )
    def forward(self, data, embeddings, labels) -> tuple:
        """
        For simplicity, the two data streams are concatenated and passed
        together through the ``embeddings`` parameter; they are split back
        into two halves in this function (first half for the first stream,
        second half for the second stream).

        Args:
            data (torch.Tensor):
                A batch of key images (**not used**). size: :math:`B \\times C \\times H \\times W`
            embeddings (torch.Tensor):
                A batch of query embeddings. size: :math:`2B \\times dim`
            labels (torch.Tensor):
                Labels of the input. size: :math:`2B \\times 1`
        """
        # split the two augmented streams
        N = embeddings.size(0)
        assert N % 2 == 0
        z1, z2 = embeddings[:N // 2], embeddings[N // 2:]
        labels = labels[:N // 2]
        # prediction-head outputs for each stream
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)
        # symmetrized metric with stop-gradient on the target branch
        metric_mat = 0.5 * (
            self.metric(p1, z2.detach())
            + self.metric(p2, z1.detach())
        )
        return (
            metric_mat,
            labels.unsqueeze(1),
            labels.unsqueeze(0),
            False
        )
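# --- Illustrative sketch, not part of the original module ---
# The predictor built in ``__init__`` is a two-layer bottleneck MLP
# (2048 -> 512 -> 2048). Assuming ``BatchNormMLP`` attaches BatchNorm/ReLU
# to each layer according to ``bn_list``/``relu_list`` (an assumption about
# its internals, not taken from this file), an equivalent plain PyTorch
# head would be:
_example_predictor = nn.Sequential(
    nn.Linear(2048, 512),
    nn.BatchNorm1d(512),    # bn_list[0] = True
    nn.ReLU(inplace=True),  # relu_list[0] = True
    nn.Linear(512, 2048),   # output layer: no BN, no ReLU
)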
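# --- Illustrative sketch, not part of the original module ---
# ``self.metric`` is injected by the framework, and ``forward`` returns an
# unreduced metric matrix for the downstream selector/loss. In the SimSiam
# paper the comparison is negative cosine similarity; assuming that metric,
# a minimal scalar version of the symmetrized stop-gradient loss is:
def _simsiam_loss_sketch(p1, p2, z1, z2):
    """Symmetrized negative cosine similarity with stop-gradient
    on the target branch, reduced to a scalar."""
    def _neg_cos(p, z):
        # ``z.detach()`` is the stop-gradient of the paper
        return -torch.nn.functional.cosine_similarity(p, z.detach(), dim=-1).mean()
    return 0.5 * (_neg_cos(p1, z2) + _neg_cos(p2, z1))

# Example usage (hypothetical shapes): two streams of B = 4 embeddings,
# stacked along the batch dimension as ``forward`` expects:
# _z = torch.randn(8, 2048)             # 2B x dim
# _z1, _z2 = _z[:4], _z[4:]
# _loss = _simsiam_loss_sketch(_example_predictor(_z1),
#                              _example_predictor(_z2), _z1, _z2)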