Source code for gedml.core.collectors.iteration_collectors.dvml_collector

import torch
import torch.nn.functional as F 
import torch.distributions as distributions

from ..base_collector import BaseCollector

class DVMLCollector(BaseCollector):
    r"""
    Paper: `Deep Variational Metric Learning <https://openaccess.thecvf.com/content_ECCV_2018/html/Xudong_Lin_Deep_Variational_Metric_ECCV_2018_paper.html>`_

    Four losses:

    1. loss_kl: KL divergence between the learned distribution and an isotropic multivariate Gaussian
    2. loss_recon: reconstruction loss between the original features and the features generated by the decoder
    3. metric learning loss on the learned intra-class invariance
    4. metric learning loss on the combination of sampled intra-class variance and learned intra-class invariance

    Default parameters recommended in the paper:

    1. lr = 0.0001
    2. T = 20 (for sample generation)
    3. batch_size = 128 (pair-based) or 120 (triplet-based)

    There are two phases during training:

    1. first phase: cut off the back-propagation of the gradients from the decoder network:
       ``lambda_1`` = 1, ``lambda_2`` = 1, ``lambda_3`` = 0.1, ``lambda_4`` = 1
    2. second phase: release the constraint:
       ``lambda_1`` = 0.8, ``lambda_2`` = 1, ``lambda_3`` = 0.2, ``lambda_4`` = 0.8

    Args:
        embedder_mean (torch.nn.Module): multi-layer perceptron
        embedder_std (torch.nn.Module): multi-layer perceptron (predicts the log-variance)
        decoder (torch.nn.Module): multi-layer perceptron
        T (int): default: 20
        phase (int): 1 for ``first phase`` and 2 for ``second phase``
        lambda_1 (float): first phase: 1; second phase: 0.8
        lambda_2 (float): first phase: 1; second phase: 1
        lambda_3 (float): first phase: 0.1; second phase: 0.2
        lambda_4 (float): first phase: 1; second phase: 0.8
    """
    def __init__(
        self,
        embedder_mean,
        embedder_std,
        decoder,
        T=20,
        phase=1,
        lambda_1=None,
        lambda_2=None,
        lambda_3=None,
        lambda_4=None,
        *args,
        **kwargs
    ):
        super(DVMLCollector, self).__init__(*args, **kwargs)
        self.embedder_mean = embedder_mean
        self.embedder_std = embedder_std
        self.decoder = decoder
        self.T = T
        self.phase = phase
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.lambda_3 = lambda_3
        self.lambda_4 = lambda_4
        self._initiate_default_lambda()

    def _initiate_default_lambda(self):
        """
        Set any ``lambda_x`` left as ``None`` to the phase-dependent default.
        """
        if self.phase == 1:
            self.lambda_1 = 1 if self.lambda_1 is None else self.lambda_1
            self.lambda_2 = 1 if self.lambda_2 is None else self.lambda_2
            self.lambda_3 = 0.1 if self.lambda_3 is None else self.lambda_3
            self.lambda_4 = 1 if self.lambda_4 is None else self.lambda_4
        elif self.phase == 2:
            # second-phase defaults recommended in the paper: 0.8, 1, 0.2, 0.8
            self.lambda_1 = 0.8 if self.lambda_1 is None else self.lambda_1
            self.lambda_2 = 1 if self.lambda_2 is None else self.lambda_2
            self.lambda_3 = 0.2 if self.lambda_3 is None else self.lambda_3
            self.lambda_4 = 0.8 if self.lambda_4 is None else self.lambda_4
        else:
            raise KeyError("parameter 'phase' must be 1 or 2!")
    def update(self, trainer):
        pass
    def forward(
        self,
        data,
        embeddings,
        features,
        labels
    ) -> tuple:
        r"""
        Four losses should be computed in function ``collect``:

        :math:`loss_{total} = \lambda_1 \times loss_{kl} + \lambda_2 \times loss_{recon} + \lambda_3 \times loss_{syn} + \lambda_4 \times loss_{invariant}`

        :math:`loss_{kl} = KL(q_{dist} \| p_{dist})`

        :math:`loss_{recon} = mean(\|f_{decode} - f_{ori}\|_2^2)`

        :math:`loss_{syn} = loss_{metric}(matrix_{syn}, labels)`

        :math:`loss_{invariant} = loss_{metric}(matrix_{inv}, labels)`
        """
        embedding_size = embeddings.size(1)
        feature_size = features.size(1)
        is_same_source = True
        # compute mu and std of the latent intra-class-variance distribution
        mu = self.embedder_mean(features)
        log_var = self.embedder_std(features)
        std = torch.exp(log_var * 0.5)
        q_distributions = distributions.Normal(mu, std)
        p_distributions = distributions.Normal(0, 1)
        # to discriminate intra-class invariant features: loss_invariant
        metric_mat_inv = self.metric(embeddings, embeddings)
        row_labels = labels.unsqueeze(1)
        col_labels = labels.unsqueeze(0)
        # loss_kl
        loss_kl = distributions.kl_divergence(q_distributions, p_distributions).sum(-1).mean()
        # two phases
        if self.phase == 1:
            embeddings_var = q_distributions.sample()
            embeddings_syn = embeddings + embeddings_var
            # cut off gradients flowing back from the decoder
            features_decode = self.decoder(embeddings_syn.detach())
        elif self.phase == 2:
            embeddings_var = q_distributions.sample((self.T,))
            embeddings_syn = embeddings.unsqueeze(0) + embeddings_var
            features_decode = self.decoder(
                embeddings_syn.view(-1, embedding_size)
            ).view(self.T, -1, feature_size)
            # use the first group of synthetic embeddings
            embeddings_syn = embeddings_syn[0]
        # loss_recon
        loss_recon = F.mse_loss(features_decode, features)
        # to discriminate synthesized features: loss_syn
        metric_mat_syn = self.metric(embeddings_syn, embeddings_syn)
        return (
            metric_mat_inv,
            row_labels,
            col_labels,
            is_same_source,
            metric_mat_syn,
            loss_kl,
            loss_recon,
            self.lambda_1,
            self.lambda_2,
            self.lambda_3,
            self.lambda_4
        )
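The sketch below is not part of the module above; it is a minimal illustration, reusing the module's ``torch``/``F`` imports, of how the tuple returned by ``forward`` could be folded into the total loss stated in the class docstring. In GeDML the metric matrices and labels are presumably consumed by the framework's own selector and loss modules; ``simple_contrastive_loss`` and ``assemble_dvml_loss`` are illustrative stand-ins (assuming the configured metric yields a distance-like matrix), not library APIs.

# --------------------------------------------------------------------------
# Hedged sketch: turning DVMLCollector.forward(...) outputs into one scalar
#     loss_total = lambda_1 * loss_kl + lambda_2 * loss_recon
#                  + lambda_3 * loss_syn + lambda_4 * loss_invariant
# Note: for Normal(mu, std) vs. Normal(0, 1), kl_divergence reduces to the
# closed form 0.5 * (mu**2 + std**2 - 1) - log(std) per dimension.
# --------------------------------------------------------------------------

def simple_contrastive_loss(metric_mat, row_labels, col_labels, margin=0.5):
    """Toy pair-based loss on a distance-like metric matrix (illustrative only)."""
    pos_mask = (row_labels == col_labels).float()   # (B, B) same-class pairs
    neg_mask = 1.0 - pos_mask
    # pull positive pairs together, push negative pairs beyond the margin
    pos_term = (metric_mat * pos_mask).sum() / pos_mask.sum().clamp(min=1)
    neg_term = (F.relu(margin - metric_mat) * neg_mask).sum() / neg_mask.sum().clamp(min=1)
    return pos_term + neg_term

def assemble_dvml_loss(collector_output, metric_loss_fn=simple_contrastive_loss):
    """Combine the 11-element tuple returned by DVMLCollector.forward."""
    (metric_mat_inv, row_labels, col_labels, _is_same_source,
     metric_mat_syn, loss_kl, loss_recon,
     lambda_1, lambda_2, lambda_3, lambda_4) = collector_output
    loss_invariant = metric_loss_fn(metric_mat_inv, row_labels, col_labels)
    loss_syn = metric_loss_fn(metric_mat_syn, row_labels, col_labels)
    return (lambda_1 * loss_kl + lambda_2 * loss_recon
            + lambda_3 * loss_syn + lambda_4 * loss_invariant)

# Hypothetical usage, given a constructed collector and a training batch:
#     out = collector(data, embeddings, features, labels)
#     loss = assemble_dvml_loss(out)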