trainers
This module takes charge of training.
Class
BaseTrainer
class gedml.launcher.trainers.base_trainer.BaseTrainer(batch_size, wrapper_params, freeze_trunk_batchnorm=False, dataset_num_workers=8, is_distributed=False)
Bases: object
BaseTrainer takes charge of training. storage holds all the intermediate variables, such as data, embeddings, etc. All other modules, such as collectors, selectors, etc., read their own parameter lists to get the corresponding parameters. loss_handler computes the weighted loss.
Parameters
batch_size (int) – Batch size.
freeze_trunk_batchnorm (bool) – Whether to freeze the batch normalization layers of the trunk.
dataset_num_workers (int) – Number of processes used to load data.
is_distributed (bool) – Whether to use distributed mode.
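The following is a minimal, framework-free sketch of the design described above. It assumes that storage behaves like a shared name-to-value table, that each module declares which entries it reads and writes, and that the weighted loss is a plain sum of each loss term multiplied by its weight. The names Storage, Module and weighted_loss are hypothetical illustrations and are not part of gedml's API.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Storage:
    """Hypothetical stand-in for the trainer's storage: a shared table of
    intermediate variables such as data and embeddings."""
    values: Dict[str, object] = field(default_factory=dict)

    def put(self, name: str, value: object) -> None:
        self.values[name] = value

    def get(self, names: List[str]) -> List[object]:
        # Each module looks up only the entries named in its own parameter list.
        return [self.values[n] for n in names]


@dataclass
class Module:
    """Hypothetical module (e.g. a collector or selector) that declares the
    storage entries it reads and the single entry it writes back."""
    inputs: List[str]
    output: str
    fn: Callable[..., object]

    def run(self, storage: Storage) -> None:
        storage.put(self.output, self.fn(*storage.get(self.inputs)))


def weighted_loss(losses: Dict[str, float], weights: Dict[str, float]) -> float:
    """Hypothetical analogue of loss_handler: sum of each loss times its
    weight; a loss with no explicit weight defaults to 1.0."""
    return sum(weights.get(name, 1.0) * value for name, value in losses.items())


if __name__ == "__main__":
    storage = Storage()
    storage.put("embeddings", [[1.0, 0.0], [0.0, 1.0]])

    # A toy module that reads "embeddings" and writes squared norms.
    sq_norm = Module(
        inputs=["embeddings"],
        output="sq_norms",
        fn=lambda emb: [sum(x * x for x in e) for e in emb],
    )
    sq_norm.run(storage)
    print(storage.values["sq_norms"])  # [1.0, 1.0]

    total = weighted_loss({"metric": 0.5, "cls": 2.0}, {"metric": 1.0, "cls": 0.25})
    print(total)  # 1.0
```

Routing every intermediate result through one shared table keeps the modules decoupled: a module only needs to know the names of its inputs and output, not which other module produced them.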