import torch
import torch.nn.functional as F
import torch.distributions as distributions
from ..base_collector import BaseCollector
class DVMLCollector(BaseCollector):
    r"""
    Paper: `Deep Variational Metric Learning <https://openaccess.thecvf.com/content_ECCV_2018/html/Xudong_Lin_Deep_Variational_Metric_ECCV_2018_paper.html>`_

    Four losses:

    1. loss_kl: KL divergence between the learned distribution and an isotropic multivariate Gaussian
    2. loss_recon: reconstruction loss between the original features and the features generated by the decoder
    3. metric learning loss of the learned intra-class invariance
    4. metric learning loss of the combination of sampled intra-class variance and learned intra-class invariance

    Default parameters recommended in the paper:

    1. lr = 0.0001
    2. T = 20 (for sample generation)
    3. batch_size = 128 (pair-based) or 120 (triplet-based)

    There are two phases during training:

    1. first phase: cut off the back-propagation of the gradients from the decoder network:
       ``lambda_1`` = 1, ``lambda_2`` = 1, ``lambda_3`` = 0.1, ``lambda_4`` = 1
    2. second phase: release the constraint:
       ``lambda_1`` = 0.8, ``lambda_2`` = 1, ``lambda_3`` = 0.2, ``lambda_4`` = 0.8

    Any ``lambda_x`` left as ``None`` is filled with the default of the chosen phase.

    Args:
        embedder_mean (torch.nn.Module): multi-layer perceptron producing the mean
            of the sampled intra-class variance distribution from ``features``.
        embedder_std (torch.nn.Module): multi-layer perceptron whose output is
            interpreted as the *log-variance* of that distribution
            (the standard deviation used is ``exp(0.5 * output)``).
        decoder (torch.nn.Module): multi-layer perceptron mapping synthesized
            embeddings back to feature space.
        T (int): number of samples drawn per embedding in the second phase
            (the first phase draws a single sample). default: 20
        phase (int): 1 for ``first phase`` and 2 for ``second phase``
        lambda_1 (float): first phase: 1; second phase: 0.8
        lambda_2 (float): first phase: 1; second phase: 1
        lambda_3 (float): first phase: 0.1; second phase: 0.2
        lambda_4 (float): first phase: 1; second phase: 0.8

    Raises:
        KeyError: if ``phase`` is not 1 or 2.
    """
    def __init__(
        self,
        embedder_mean,
        embedder_std,
        decoder,
        T=20,
        phase=1,
        lambda_1=None,
        lambda_2=None,
        lambda_3=None,
        lambda_4=None,
        *args,
        **kwargs
    ):
        super(DVMLCollector, self).__init__(*args, **kwargs)
        self.embedder_mean = embedder_mean
        self.embedder_std = embedder_std
        self.decoder = decoder
        self.T = T
        self.phase = phase
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2
        self.lambda_3 = lambda_3
        self.lambda_4 = lambda_4
        self._initiate_default_lambda()

    def _initiate_default_lambda(self):
        """
        Fill every ``lambda_x`` that is still ``None`` with the default for ``self.phase``.

        Explicitly supplied values are kept unchanged.

        Raises:
            KeyError: if ``self.phase`` is not 1 or 2.
        """
        # Paper-recommended (lambda_1, lambda_2, lambda_3, lambda_4) per phase.
        if self.phase == 1:
            defaults = (1, 1, 0.1, 1)
        elif self.phase == 2:
            defaults = (0.8, 1, 0.2, 0.8)
        else:
            raise KeyError("parameter 'phase' must be 1 or 2!")
        default_1, default_2, default_3, default_4 = defaults
        self.lambda_1 = default_1 if self.lambda_1 is None else self.lambda_1
        self.lambda_2 = default_2 if self.lambda_2 is None else self.lambda_2
        self.lambda_3 = default_3 if self.lambda_3 is None else self.lambda_3
        self.lambda_4 = default_4 if self.lambda_4 is None else self.lambda_4

    def update(self, trainer):
        """
        Intentionally does nothing: no attribute of this collector is refreshed from ``trainer``.
        """
        pass

    def forward(
        self,
        data,
        embeddings,
        features,
        labels
    ) -> tuple:
        r"""
        Compute the DVML loss ingredients for one batch.

        The KL and reconstruction losses are computed here. The two metric
        learning losses are not computed here; this method returns the metric
        matrices they are built from (presumably consumed by the loss function
        that receives this tuple). The weighted objective is

        :math:`loss_{total} = \lambda_1 \times loss_{kl} + \lambda_2 \times loss_{recon} + \lambda_3 \times loss_{syn} + \lambda_4 \times loss_{invariant}`

        :math:`loss_{kl} = KL(q_{dist}, p_{dist})`

        :math:`loss_{recon} = mean(|f_{decode} - f_{ori}|^2_2)`

        :math:`loss_{syn} = loss_{metric}(matrix_{syn}, labels)`

        :math:`loss_{invariant} = loss_{metric}(matrix_{inv}, labels)`

        Args:
            data: not used by this method.
            embeddings: learned intra-class invariant embeddings; ``size(1)`` is
                taken as the embedding dimension (presumably shape
                ``(batch, embedding_dim)`` -- TODO confirm).
            features: input to both embedders and the reconstruction target;
                ``size(1)`` is taken as the feature dimension.
            labels: class labels, expanded into row/column views for pairwise
                comparison (presumably 1-D of length ``batch``).

        Returns:
            tuple: ``(metric_mat_inv, row_labels, col_labels, is_same_source,
            metric_mat_syn, loss_kl, loss_recon, lambda_1, lambda_2, lambda_3,
            lambda_4)``.

        Raises:
            KeyError: if ``self.phase`` is not 1 or 2.
        """
        embedding_size = embeddings.size(1)
        feature_size = features.size(1)
        is_same_source = True

        # embedder_std predicts the log-variance; convert it to a standard deviation.
        mu = self.embedder_mean(features)
        log_var = self.embedder_std(features)
        std = torch.exp(log_var * 0.5)
        q_distributions = distributions.Normal(mu, std)
        # Standard-normal prior created on mu's device/dtype rather than from Python scalars.
        p_distributions = distributions.Normal(
            torch.zeros_like(mu), torch.ones_like(std)
        )

        # For discriminating intra-class invariant features: loss_invariant.
        # (self.metric is not defined in this class; presumably set up by BaseCollector.)
        metric_mat_inv = self.metric(embeddings, embeddings)
        row_labels = labels.unsqueeze(1)
        col_labels = labels.unsqueeze(0)

        # loss_kl: summed over the last dimension, then averaged over the rest.
        loss_kl = distributions.kl_divergence(q_distributions, p_distributions).sum(-1).mean()

        # rsample() (reparameterization trick) rather than sample(): sample() is
        # non-differentiable, which left the mean/log-variance heads trained by
        # the KL term alone.
        if self.phase == 1:
            embeddings_var = q_distributions.rsample()
            embeddings_syn = embeddings + embeddings_var
            # Phase 1: detach so the decoder's gradients do not flow back upstream.
            features_decode = self.decoder(embeddings_syn.detach())
            recon_target = features
        elif self.phase == 2:
            embeddings_var = q_distributions.rsample((self.T,))
            embeddings_syn = embeddings.unsqueeze(0) + embeddings_var
            features_decode = self.decoder(
                embeddings_syn.reshape(-1, embedding_size)
            ).reshape(self.T, -1, feature_size)
            # Every one of the T reconstructions is compared with the same
            # features; expanding explicitly avoids mse_loss's implicit-broadcast
            # warning without changing the loss value.
            recon_target = features.unsqueeze(0).expand_as(features_decode)
            # Only the first synthetic embedding feeds the synthesized metric matrix.
            embeddings_syn = embeddings_syn[0]
        else:
            raise KeyError("parameter 'phase' must be 1 or 2!")

        # loss_recon
        loss_recon = F.mse_loss(features_decode, recon_target)

        # For discriminating synthesized features: loss_syn.
        metric_mat_syn = self.metric(embeddings_syn, embeddings_syn)

        # NOTE(review): the class docstring lists the invariant loss third and the
        # synthesized loss fourth, while the formula in this method pairs lambda_3
        # with loss_syn and lambda_4 with loss_invariant -- confirm against the
        # consumer of this tuple.
        return (
            metric_mat_inv,
            row_labels,
            col_labels,
            is_same_source,
            metric_mat_syn,
            loss_kl,
            loss_recon,
            self.lambda_1,
            self.lambda_2,
            self.lambda_3,
            self.lambda_4
        )