Source code for gedml.launcher.trainers.base_trainer

import torch
import torch.distributed as dist
import logging
from tqdm import tqdm

from ..misc import utils, Storage
from .loss_handler import LossHandler
from ...config.setting.recorder_setting import (
    TO_RECORD_LIST,
    TO_SAVE_LIST
)

class BaseTrainer:
    """
    ``BaseTrainer`` takes charge of training.

    ``storage`` holds all the intermediate variables such as ``data``,
    ``embeddings``, etc. All other modules, such as ``collectors``,
    ``selectors``, etc., access their own parameter lists to get the
    corresponding parameters. ``loss_handler`` computes the weighted loss.

    Args:
        batch_size (int): Batch size.
        wrapper_params (dict): Parameters used to initialize ``storage`` (see ``Storage``).
        freeze_trunk_batchnorm (bool): Whether to freeze the batch-normalization layers of the trunk.
        dataset_num_workers (int): Number of worker processes used to load data.
        is_distributed (bool): Whether to use distributed mode.
    """
    def __init__(
        self,
        batch_size,
        wrapper_params,
        freeze_trunk_batchnorm=False,
        dataset_num_workers=8,
        is_distributed=False,
    ):
        self.batch_size = batch_size
        self.freeze_trunk_batchnorm = freeze_trunk_batchnorm
        self.dataset_num_workers = dataset_num_workers
        self.is_distributed = is_distributed
        self.epochs = 0
        self.storage = Storage(wrapper_params)
        self.initiate_property()

    """
    Initialization
    """

    def initiate_property(self):
        # to-save list (for recorders)
        setattr(
            self,
            TO_SAVE_LIST,
            ["models", "collectors", "losses"]
        )
        # recordable list
        self.recordable_object_list = ["models"]
        # trainable list
        self.trainable_object_list = [
            "models",
            "collectors"
        ]

    @property
    def models_forward_order(self):
        return ["trunk", "embedder"]

    """
    Set and Get
    """

    def set_batch_size(self, batch_size):
        self.batch_size = batch_size

    def set_distributed(self, flag=True):
        self.is_distributed = flag

    def set_activated_optims(self, optim_list=None):
        self.optim_list = (
            list(self.optimizers.keys())
            if optim_list is None
            else optim_list
        )
        logging.info("Set activated optims: {}".format(self.optim_list))

    """
    Train
    """
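    # Typical call sequence (a sketch; the concrete module dicts are normally
    # assembled by the launcher rather than by hand):
    #
    #   trainer = BaseTrainer(batch_size=128, wrapper_params=wrapper_params)
    #   trainer.prepare(collectors, selectors, losses, models,
    #                   optimizers, datasets, device=device)
    #   for epoch in range(num_epochs):
    #       trainer.train(epochs=epoch)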
    def prepare(
        self,
        collectors,
        selectors,
        losses,
        models,
        optimizers,
        datasets,
        schedulers=None,
        gradclipper=None,
        samplers=None,
        collatefns=None,
        device=None,
        recorders=None
    ):
        """
        Load modules to prepare for training.
        """
        # pass parameters
        self.collectors = collectors
        self.selectors = selectors
        self.losses = losses
        self.models = models
        self.optimizers = optimizers
        self.datasets = datasets
        self.schedulers = schedulers
        self.gradclipper = gradclipper
        self.samplers = samplers
        self.collatefns = collatefns
        self.device = device
        self.loss_handler = LossHandler()
        self.recorders = utils.get_default(recorders, "recorders")
        self.optim_list = list(self.optimizers.keys())
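    # The module arguments are dictionaries keyed by name. A hypothetical
    # minimal call (``forward_models`` expects ``models`` to contain "trunk"
    # and "embedder"; ``initiate_dataloader`` expects a "train" dataset):
    #
    #   trainer.prepare(
    #       collectors={"default": collector},
    #       selectors={"default": selector},
    #       losses={"default": loss},
    #       models={"trunk": trunk, "embedder": embedder},
    #       optimizers={"trunk": trunk_opt, "embedder": embedder_opt},
    #       datasets={"train": train_dataset},
    #       device=torch.device("cuda"),
    #   )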
    def train(self, epochs=None):
        """
        Run one epoch of training.

        Args:
            epochs (int): Epoch number to start from. If given, overrides the
                internal epoch counter.
        """
        # start to train
        if epochs is not None:
            self.epochs = epochs
        self.initiate_dataloader()
        logging.info("TRAIN EPOCH: {}".format(self.epochs))
        self.pipeline()
        self.epochs += 1
        # self.release_memory()
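    # NOTE: ``initiate_dataloader`` is re-run on every call to ``train``, so a
    # new batch size set through ``set_batch_size`` takes effect at the start
    # of the next epoch.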
    def pipeline(self):
        # use global collector
        self.use_global_collector()
        self.set_to_train()
        self.show_lr()
        # start training
        self.pbar = tqdm(range(self.iterations_per_epoch))
        for self.iteration in self.pbar:
            self.prepare_forward()
            self.forward_pipeline()
            self.backward_and_update()
            self.update_record(self.recorders)
            self.pbar.set_description(
                "Loss={:.4f}".format(
                    self.loss_handler.get_total(),
                )
            )
        self.loss_handler.average_losses()
        self.step_schedulers(metrics=self.loss_handler.get_total())

    def show_lr(self):
        default_idx = 0
        for k, v in self.optimizers.items():
            lr = v.param_groups[default_idx]["lr"]
            logging.info("{} optimizer's lr: {}".format(k, lr))

    def forward_pipeline(self):
        self.forward_models()
        self.forward_collectors()
        self.forward_selectors()
        self.forward_losses()

    def initiate_dataloader(self):
        logging.info(
            "{}: Initiating dataloader".format(
                self.__class__.__name__
            )
        )
        # more initialization
        self.sampler, self.collate_fn = None, None
        self._initiate_dataloader_sampler_collatefn()
        # get dataloader
        self.dataloader = torch.utils.data.DataLoader(
            dataset=self.datasets["train"],
            batch_size=int(self.batch_size),
            sampler=self.sampler,
            drop_last=True,
            pin_memory=False,
            shuffle=self.samplers is None,
            num_workers=self.dataset_num_workers,
            collate_fn=self.collate_fn
        )
        self.iterations_per_epoch = len(self.dataloader)
        self.dataloader_iter = iter(self.dataloader)

    def _initiate_dataloader_sampler_collatefn(self):
        # extract the sampler
        if self.is_distributed:
            if self.samplers is not None:
                logging.warning("class samplers can't be used in distributed training mode!")
            self.samplers = {}
            self.samplers["train"] = torch.utils.data.distributed.DistributedSampler(
                dataset=self.datasets["train"],
                shuffle=True,
            )
            self.sampler = self.samplers["train"]
            logging.info("Get distributed sampler {}".format(
                self.sampler.__class__.__name__,
            ))
        else:
            self.sampler = (
                self.samplers["train"]
                if self.samplers is not None
                else None
            )
        # extract the collate_fn
        self.collate_fn = (
            self.collatefns["train"]
            if self.collatefns is not None
            else None
        )

    def set_to_train(self):
        for trainable_name in self.trainable_object_list:
            trainable_object = getattr(self, trainable_name, None)
            if trainable_object is None:
                logging.warning(
                    "{} is not a member of trainer".format(
                        trainable_name
                    )
                )
            else:
                for v in trainable_object.values():
                    v.train()
        # maybe some parameters should be frozen
        self._maybe_freeze_batchnorm()

    def _maybe_freeze_batchnorm(self):
        if self.freeze_trunk_batchnorm:
            self.models["trunk"].apply(
                utils.set_layers_to_eval("BatchNorm")
            )

    def use_global_collector(self):
        for collector in self.collectors.values():
            if collector.is_global_collector:
                logging.info("Global collector updating...")
                collector.global_update(self)
                self.release_memory()

    def prepare_forward(self):
        # set sampler
        self._prepare_forward_set_sampler()
        # zero loss-values
        self.loss_handler.zero_losses()
        # zero grad
        for v in self.models.values():
            v.zero_grad()
        for v in self.optimizers.values():
            v.zero_grad()
        # get batch
        self._prepare_forward_get_batch()

    def _prepare_forward_set_sampler(self):
        if self.is_distributed:
            self.sampler.set_epoch(self.epochs)

    @property
    def info_dict_to_device(self):
        return [
            "data",
            "labels"
        ]

    def _prepare_forward_get_batch(self):
        info_dict = next(self.dataloader_iter)
        for key in info_dict.keys():
            setattr(self.storage, key, info_dict[key])
        self.storage.tensors_to_device(self.info_dict_to_device, self.device)

    def forward_models(self):
        # get data and labels
        data = self.storage.get("data")
        labels = self.storage.get("labels")
        # forward backbone (trunk model)
        features = self.models["trunk"](data)
        # if distributed, gather features and labels from all processes
        if self.is_distributed:
            features, labels = utils.distributed_gather_objects(
                features, labels
            )
        self.storage.features = features
        self.storage.labels = labels
        self.storage.indices_dict["models"] = {"embedder": {"": {}}}
        self.storage.update(self.models["embedder"], cur_module="models")

    def forward_collectors(self):
        # update collector
        self.update_collectors()
        # forward collector
        self.storage.update(self.collectors, cur_module="collectors")

    def update_collectors(self):
        for collector in self.collectors.values():
            if isinstance(collector, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
                collector.module.update(self)
            else:
                collector.update(self)

    def forward_selectors(self):
        self.storage.update(self.selectors, cur_module="selectors")

    def forward_losses(self):
        self.storage.update(self.losses, cur_module="losses")
        # update loss_values
        self.loss_handler.update_losses(self.storage.return_loss_dict())

    def backward_and_update(self):
        self._backward_and_update_preprocess()
        self.loss_handler.backward()
        # clip gradients
        if self.gradclipper is not None:
            for v in self.gradclipper.values():
                v()
        # step optimizers
        for k in self.optim_list:
            self.optimizers[k].step()

    def _backward_and_update_preprocess(self):
        # for distributed: sum the total loss across processes before backward
        if self.is_distributed:
            # dist.barrier()
            dist.all_reduce(
                tensor=self.loss_values["total_loss"],
                op=dist.ReduceOp.SUM
            )
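    # Each ``gradclipper`` entry is invoked with no arguments in
    # ``backward_and_update``; a hypothetical entry built on plain PyTorch:
    #
    #   gradclipper = {
    #       "trunk": lambda: torch.nn.utils.clip_grad_norm_(
    #           models["trunk"].parameters(), max_norm=1.0
    #       ),
    #   }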
    def step_schedulers(self, **kwargs):
        """
        All schedulers step at the end of each epoch for the moment.
        """
        if self.schedulers is not None:
            for k in self.optim_list:
                v = self.schedulers[k]
                if isinstance(v, torch.optim.lr_scheduler.ReduceLROnPlateau):
                    v.step(**kwargs)
                else:
                    v.step()
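    # Both scheduler kinds are supported; a hypothetical setup:
    #
    #   schedulers = {
    #       "trunk": torch.optim.lr_scheduler.ReduceLROnPlateau(trunk_opt),
    #       "embedder": torch.optim.lr_scheduler.StepLR(embedder_opt, step_size=10),
    #   }
    #
    # ``pipeline`` calls this with ``metrics=self.loss_handler.get_total()``,
    # which only ``ReduceLROnPlateau`` consumes.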
    def update_record(self, recorders=None):
        if recorders is not None:
            total_iterations = (
                self.epochs * self.iterations_per_epoch + self.iteration
            )
            # update loss
            self.loss_handler.record_losses(recorders, total_iterations)
            # update other statistics
            for recordable_name in self.recordable_object_list:
                recordable_object = getattr(self, recordable_name)
                for k, v in recordable_object.items():
                    if k == "trunk":
                        # TODO:
                        continue
                    to_record_obj = (
                        v.module
                        if isinstance(v, torch.nn.DataParallel)
                        or isinstance(v, torch.nn.parallel.DistributedDataParallel)
                        else v
                    )
                    data, _ = recorders.get_data(to_record_obj, k)
                    recorders.update(data, total_iterations)

    def release_memory(self):
        torch.cuda.empty_cache()
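# A minimal sketch of the BatchNorm-freezing pattern used by
# ``_maybe_freeze_batchnorm``. ``utils.set_layers_to_eval`` is assumed to
# return a callable that matches modules by class name, roughly:
#
#   def set_layers_to_eval(layer_name):
#       def _fn(module):
#           if layer_name in module.__class__.__name__:
#               module.eval()
#       return _fn
#
#   models["trunk"].apply(set_layers_to_eval("BatchNorm"))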