Source code for gedml.core.collectors.epoch_collectors.global_proxy_collector

import torch
import logging
from tqdm import tqdm
import torch.nn.functional as F 

from .. import ProxyCollector
from ._default_global_collector import _DefaultGlobalCollector

class GlobalProxyCollector(ProxyCollector, _DefaultGlobalCollector):
    """
    Compute the global proxies before updating other parameters.
    """
    def __init__(
        self,
        optimizer_name="Adam",
        optimizer_param={"lr": 0.001},
        dataloader_param={"batch_size": 120, "drop_last": False, "shuffle": True, "num_workers": 8},
        max_iter=50000,
        error_bound=1e-3,
        total_patience=10,
        auth_weight=1.0,
        repre_weight=1.0,
        disc_weight=1.0,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.optimizer_name = optimizer_name
        self.optimizer_param = optimizer_param
        self.dataloader_param = dataloader_param
        self.max_iter = max_iter
        self.error_bound = error_bound
        self.total_patience = total_patience
        self.auth_weight = auth_weight
        self.repre_weight = repre_weight
        self.disc_weight = disc_weight
        self.datasets, self.models = None, None

    def global_update(self, trainer):
        logging.info("Start to compute global embeddings")
        self.prepare(trainer)
        self.compute_global_embedding()
        self.compute_global_proxy()

    def prepare(self, trainer):
        _DefaultGlobalCollector.prepare(self, trainer)
        # get optimizer; pass the proxies directly instead of writing
        # into self.optimizer_param, which would mutate the shared
        # default dict across instances
        self.optimizer = getattr(
            torch.optim,
            self.optimizer_name,
        )([self.proxies], **self.optimizer_param)

    def compute_global_proxy(self):
        pre_loss_value = -100
        patience_count = 0
        pbar = tqdm()
        index = 0
        while True:
            # update proxies
            index += 1
            self.optimizer.zero_grad()
            cur_loss = self.loss_func(self.embeddings, self.labels)
            cur_loss.backward()
            self.optimizer.step()
            # raise an error when max_iter is reached
            if index > self.max_iter:
                logging.warning("Reached MAX_ITER!")
                raise RuntimeError("Proxy optimization did not converge within max_iter")
            # termination condition: stop once the loss change stays
            # below error_bound for more than total_patience iterations
            cur_error = abs(cur_loss.item() - pre_loss_value)
            pbar.set_description("CurrError={}, index={}".format(
                cur_error, index
            ))
            pre_loss_value = cur_loss.item()
            if cur_error < self.error_bound:
                patience_count += 1
                if patience_count > self.total_patience:
                    pbar.close()
                    logging.info("Have reached the solution of proxies!")
                    break
            else:
                patience_count = 0

    def loss_func(self, embeddings, labels):
        dtype = embeddings.dtype
        # normalize
        embeddings = F.normalize(embeddings, dim=-1, p=2)
        proxies = F.normalize(self.proxies, dim=-1, p=2)
        # compute cosine-similarity matrix
        metric_mat = torch.matmul(embeddings, proxies.t())
        pos_mask = (labels.unsqueeze(1) == self.proxy_labels.unsqueeze(0)).byte()
        # compute loss-auth: each proxy should be close to at least one
        # embedding of its class (max over the sample dimension)
        masked_metric_mat = metric_mat * pos_mask
        masked_metric_mat[pos_mask == 0] = torch.finfo(dtype).min
        loss_auth = torch.mean(torch.max(masked_metric_mat, dim=0)[0])
        # compute loss-repre: each embedding should be close to at least
        # one proxy of its class (max over the proxy dimension)
        loss_repre = torch.mean(torch.max(masked_metric_mat, dim=1)[0])
        if self.centers_per_class > 1 and self.disc_weight > 0:
            # compute loss-disc: pull same-class proxies together and
            # push different-class proxies apart
            proxy_metric_mat = torch.matmul(proxies, proxies.t())
            proxy_pos_mask = (self.proxy_labels.unsqueeze(1) == self.proxy_labels.unsqueeze(0)).byte()
            proxy_neg_mask = proxy_pos_mask ^ 1
            proxy_pos_mask.fill_diagonal_(0)
            ## compute pos metric
            pos_sum_exp = torch.sum(
                torch.exp(-proxy_metric_mat) * proxy_pos_mask,
                dim=-1
            )
            ## compute neg metric
            neg_sum_exp = torch.sum(
                torch.exp(proxy_metric_mat) * proxy_neg_mask,
                dim=-1
            )
            loss_disc = torch.mean(
                torch.log(1 + pos_sum_exp * neg_sum_exp)
            )
        else:
            loss_disc = 0
        # loss_auth and loss_repre are maximized (hence the minus signs);
        # loss_disc is minimized
        return (
            - self.auth_weight * loss_auth
            - self.repre_weight * loss_repre
            + self.disc_weight * loss_disc
        )
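The stopping rule in compute_global_proxy can be read in isolation: optimization ends once the absolute change in loss stays below error_bound for more than total_patience consecutive iterations, and any larger change resets the count. A minimal standalone sketch of that criterion (the helper name converged is hypothetical and not part of the library):

def converged(loss_trace, error_bound=1e-3, total_patience=10):
    # Hypothetical helper mirroring the loop above: count consecutive
    # iterations whose loss change is below error_bound, resetting the
    # count whenever the change is larger.
    patience_count = 0
    pre_loss_value = None
    for cur_loss_value in loss_trace:
        if pre_loss_value is not None and abs(cur_loss_value - pre_loss_value) < error_bound:
            patience_count += 1
            if patience_count > total_patience:
                return True
        else:
            patience_count = 0
        pre_loss_value = cur_loss_value
    return False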
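For reference, the three terms computed by loss_func can be reproduced on toy tensors. The sketch below is not part of the library; the sizes (32 samples, 4 classes, 2 centers per class, 8 dimensions) and the random data are assumptions made purely for illustration, and the variable names mirror the attributes used above:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, centers_per_class, dim = 4, 2, 8

# toy stand-ins for self.embeddings / self.labels / self.proxies / self.proxy_labels
embeddings = F.normalize(torch.randn(32, dim), dim=-1)
labels = torch.randint(0, num_classes, (32,))
proxies = F.normalize(torch.randn(num_classes * centers_per_class, dim), dim=-1)
proxy_labels = torch.arange(num_classes).repeat_interleave(centers_per_class)

metric_mat = embeddings @ proxies.t()                 # (32, 8) cosine similarities
pos_mask = labels.unsqueeze(1) == proxy_labels.unsqueeze(0)
masked = metric_mat.masked_fill(~pos_mask, torch.finfo(metric_mat.dtype).min)
loss_auth = masked.max(dim=0).values.mean()           # best sample per proxy
loss_repre = masked.max(dim=1).values.mean()          # best proxy per sample

# discriminative term over the proxy-to-proxy similarity matrix
proxy_metric_mat = proxies @ proxies.t()
proxy_pos_mask = proxy_labels.unsqueeze(1) == proxy_labels.unsqueeze(0)
proxy_neg_mask = ~proxy_pos_mask
proxy_pos_mask.fill_diagonal_(False)
pos_sum_exp = (torch.exp(-proxy_metric_mat) * proxy_pos_mask).sum(dim=-1)
neg_sum_exp = (torch.exp(proxy_metric_mat) * proxy_neg_mask).sum(dim=-1)
loss_disc = torch.log1p(pos_sum_exp * neg_sum_exp).mean()

print(loss_auth.item(), loss_repre.item(), loss_disc.item())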