trainers

This module is responsible for the training loop.

Class

BaseTrainer

class gedml.launcher.trainers.base_trainer.BaseTrainer(batch_size, wrapper_params, freeze_trunk_batchnorm=False, dataset_num_workers=8, is_distributed=False)[source]

Bases: object

BaseTrainer manages the training loop.

storage holds all the intermediate variables, such as data, embeddings, etc. All other modules, such as collectors, selectors, etc., access their own parameter lists to get the corresponding parameters.

loss_handler computes the weighted loss.
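The weighted-loss idea can be illustrated with a minimal sketch. This is not the real gedml loss_handler; the function name and the loss names below are hypothetical, chosen only to show how per-loss weights combine into a single scalar:

```python
# Hypothetical sketch (not gedml's actual API): combine named loss values
# with per-loss weights into one scalar training objective.

def weighted_loss(losses, weights):
    """Sum each named loss scaled by its weight (default weight 1.0)."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())

total = weighted_loss(
    {"metric": 0.5, "classification": 2.0},   # illustrative loss values
    {"metric": 1.0, "classification": 0.25},  # illustrative weights
)
# total == 0.5 * 1.0 + 2.0 * 0.25 == 1.0
```

In the real trainer, the loss values would come from the loss modules registered via prepare(), and the weights from the trainer's configuration.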

Parameters
  • batch_size (int) – Batch size.

  • freeze_trunk_batchnorm (bool) – Whether to freeze the batch-normalization layers of the trunk.

  • dataset_num_workers (int) – Number of worker processes used to load data.

  • is_distributed (bool) – Whether to use distributed mode.

prepare(collectors, selectors, losses, models, optimizers, datasets, schedulers=None, gradclipper=None, samplers=None, collatefns=None, device=None, recorders=None)[source]

Load modules to prepare training.

step_schedulers(**kwargs)[source]

Currently, all schedulers step at the end of each epoch.

train(epochs=None)[source]

Start training.

Parameters

epochs (int) – Epoch to start from.
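The prepare-then-train split above can be sketched with a self-contained toy class. This is an illustration of the pattern only, not gedml code: the class name MiniTrainer, its attributes, and the two-epoch demo loop are all hypothetical stand-ins:

```python
# Hypothetical sketch of the BaseTrainer pattern: modules are loaded with
# prepare(), schedulers step once per epoch, and train() runs the loop.

class MiniTrainer:
    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.storage = {}          # shared intermediate variables (data, embeddings, ...)
        self.prepared = False

    def prepare(self, models=None, losses=None, optimizers=None, schedulers=None):
        """Load modules before training, mirroring BaseTrainer.prepare()."""
        self.models = models or {}
        self.losses = losses or {}
        self.optimizers = optimizers or {}
        self.schedulers = schedulers or []
        self.prepared = True

    def step_schedulers(self):
        """Step every scheduler at the end of each epoch."""
        for scheduler in self.schedulers:
            scheduler.step()

    def train(self, epochs=0):
        """Run two demo epochs starting from the given epoch index."""
        assert self.prepared, "call prepare() before train()"
        seen = []
        for epoch in range(epochs, epochs + 2):
            # ... one epoch of batching, forward/backward, optimizer steps ...
            seen.append(epoch)
            self.step_schedulers()
        return seen

trainer = MiniTrainer(batch_size=32)
trainer.prepare()
trainer.train(epochs=0)  # runs epochs 0 and 1
```

The real BaseTrainer.prepare() additionally accepts collectors, selectors, datasets, samplers, and related modules; the sketch keeps only enough to show the lifecycle.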