import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import logging
import pandas as pd
import traceback
from ...core import models
from ..misc import utils
[docs]class BaseManager:
"""
Manager all modules and computation devices. Support three kinds of computation:
1. DataParallel (single machine)
2. DistributedDataParallel (single machine)
3. DistributedDataParallel (multi machines)
"""
def __init__(
self,
trainer,
tester,
recorder,
objects_dict,
device=None,
schedulers=None,
gradclipper=None,
samplers=None,
collatefns=None,
is_resume=False,
is_distributed=False,
device_wrapper_type="DP",
dist_port=23456,
world_size=None,
phase="train",
primary_metric=["test", "recall_at_1"],
to_device_list=["models", "collectors"],
to_wrap_list=["models"],
patience=10,
):
self.trainer = trainer
self.tester = tester
self.recorder = recorder
self.objects_dict = objects_dict
self.device = device
self.schedulers = schedulers
self.gradclipper = gradclipper
self.samplers = samplers
self.collatefns = collatefns
self.epochs = 0
self.is_resume = is_resume
self.is_distributed = is_distributed
self.device_wrapper_type = device_wrapper_type
self.dist_port = dist_port
self.world_size = world_size
self.phase = phase
self.primary_metric = primary_metric
self.to_device_list = to_device_list
self.to_wrap_list = to_wrap_list
self.patience = patience
self.best_metric = -1
self.patience_counts = 0
self.is_best = False
self.assert_phase()
self.assert_device()
self.assert_required_member()
self.assert_resume_folder_exist()
self.initiate_objects_dict()
self.initiate_members()
@property
def _required_member(self):
return [
"metrics",
"collectors",
"selectors",
"models",
"losses",
"evaluators",
"optimizers",
"transforms",
"datasets",
]
def assert_phase(self):
assert self.phase in ["train", "evaluate"]
def assert_device(self):
assert self.device_wrapper_type in ["DP", "DDP"]
if self.is_distributed:
assert self.device_wrapper_type == "DDP"
def assert_required_member(self):
object_dict_keys = list(self.objects_dict.keys())
assert all(
[item in object_dict_keys
for item in self._required_member]
)
def assert_resume_folder_exist(self):
if self.is_resume:
assert not self.recorder.delete_old_folder
def initiate_objects_dict(self):
for k, v in self.objects_dict.items():
setattr(self, k, v)
del self.objects_dict
def initiate_members(self):
self.initiate_device()
self.initiate_models()
self.initiate_collectors()
self.initiate_selectors()
self.initiate_losses()
self.initiate_schedulers()
# for distributed training
if self.is_distributed:
self.initiate_distributed_trainers()
self.initiate_distributed_testers()
self.initiate_addition_items()
def initiate_addition_items(self):
pass
def initiate_device(self):
if self.device_wrapper_type == "DDP" and not self.is_distributed:
torch.distributed.init_process_group(
backend='nccl',
init_method='tcp://localhost:{}'.format(self.dist_port),
rank=0,
world_size=1
)
if self.is_distributed:
self.world_size = (
dist.get_world_size()
if self.world_size is None
else self.world_size
)
self.main_device_id, self.device_ids = None, None
self.multi_gpu = False
if self.device is None:
self.main_device_id = 0
self.device_ids = [0]
elif isinstance(self.device, int):
self.main_device_id = self.device
self.device_ids = [self.device]
elif isinstance(self.device, list):
self.main_device_id = self.device[0]
self.device_ids = self.device
self.multi_gpu = (
True if len(self.device_ids) > 1
else False
)
else:
raise TypeError(
"Device type error!"
)
# initiate self.device
self.device = torch.device(
"cuda:{}".format(self.main_device_id)
if torch.cuda.is_available()
else "cpu"
)
def initiate_models(self):
# to device
is_to_device = "models" in self.to_device_list
is_to_wrap = "models" in self.to_wrap_list
if is_to_device:
self._members_to_device("models", to_warp=is_to_wrap)
def initiate_collectors(self):
# to device
is_to_device = "collectors" in self.to_device_list
is_to_wrap = "collectors" in self.to_wrap_list
if is_to_device:
self._members_to_device("collectors", to_warp=is_to_wrap)
def initiate_selectors(self):
# to device
is_to_device = "selectors" in self.to_device_list
is_to_wrap = "selectors" in self.to_wrap_list
if is_to_device:
self._members_to_device("selectors", to_warp=is_to_wrap)
def initiate_losses(self):
# to device
is_to_device = "losses" in self.to_device_list
is_to_wrap = "losses" in self.to_wrap_list
if is_to_device:
self._members_to_device("losses", to_warp=is_to_wrap)
def initiate_distributed_trainers(self):
total_batch_size = self.trainer.batch_size
assert (total_batch_size % self.world_size) == 0
sub_batch_size = int(total_batch_size // self.world_size)
self.trainer.set_distributed(True)
self.trainer.set_batch_size(sub_batch_size)
def initiate_distributed_testers(self):
self.tester.set_distributed(True)
def initiate_schedulers(self):
if self.schedulers is None:
self.schedulers = {}
def _members_to_device(self, module_name: str, to_warp=True):
members = getattr(self, module_name)
# to device
if not self.is_distributed:
# single-device
if self.multi_gpu:
for k, v in members.items():
members[k] = members[k].to(self.device)
if to_warp:
if self.device_wrapper_type == "DP":
members[k] = torch.nn.DataParallel(
v,
device_ids=self.device_ids
)
else:
try:
members[k] = DDP(
v,
device_ids=self.device_ids,
find_unused_parameters=True
)
except:
trace = traceback.format_exc()
logging.warning("{}".format(trace))
else:
for k, v in members.items():
members[k] = v.to(self.device)
else:
# multi-device
for k, v in members.items():
members[k] = members[k].to(self.device)
try:
members[k] = DDP(
members[k],
device_ids=self.device_ids,
find_unused_parameters=True
)
except:
trace = traceback.format_exc()
logging.warning("{}".format(trace))
"""
Run
"""
def run(self, phase="train", start_epoch=0, total_epochs=61, is_test=True, is_save=True, interval=1, warm_up=2, warm_up_list=None):
self.phase = phase
self.assert_phase()
self.prepare()
self.maybe_resume()
if self.phase == "train":
for i in range(start_epoch, total_epochs):
self.epochs = i
if i < warm_up:
logging.info("Warm up with {}".format(warm_up_list))
self.trainer.set_activated_optims(warm_up_list)
else:
self.trainer.set_activated_optims()
self.train(epochs=self.epochs)
self.release_memory()
if is_test:
if (i % interval) == 0:
self.test()
self.display_metrics()
self.save_metrics()
self.release_memory()
if is_save:
self.save_models()
# early stop
if self.patience_counts >= self.patience:
logging.info("Training terminated!")
break
elif self.phase == "evaluate":
self.test()
self.display_metrics()
def prepare(self):
# prepare trainer
utils.func_params_mediator(
[self],
self.trainer.prepare
)
# prepare tester
utils.func_params_mediator(
[
{"recorders": self.recorder},
self,
],
self.tester.prepare
)
def maybe_resume(self):
if self.is_resume:
logging.info("Resume objects...")
self.recorder.load_models(
obj=self.trainer,
device=self.device
)
def meta_test(self):
self.epochs = -1
self.test()
self.save_metrics()
self.display_metrics()
def save_metrics(self):
for k, v in self.metrics.items():
data, _ = self.recorder.get_data({k:v})
self.recorder.update(data, self.epochs)
def display_metrics(self):
# best metric check
cur_metric = self.metrics[self.primary_metric[0]][self.primary_metric[1]]
if cur_metric > self.best_metric:
self.best_metric = cur_metric
self.is_best = True
logging.info("NEW BEST METRIC!!!")
self.patience_counts = 0
else:
self.is_best = False
self.patience_counts += 1
self.metrics[self.primary_metric[0]]["BEST_" + self.primary_metric[1]] = self.best_metric
# display
for k, v in self.metrics.items():
logging.info("{} Metrics ---".format(k.upper()))
print(pd.DataFrame([v]))
def save_models(self):
self.recorder.save_models(self.trainer, step=self.epochs, best=self.is_best)
def train(self, epochs=None):
self.trainer.train(epochs=epochs)
def test(self):
self.metrics = self.tester.test()
def release_memory(self):
torch.cuda.empty_cache()