# Source code for gedml.recorder.base_recorder

import logging
import shutil
from sys import path
import pandas as pd 
import numpy as np 
import os

from gedml.launcher.misc import utils
from . import utils as utils_recorder

from .base_info_getter import BaseInfoGetter
from .base_info_writer import BaseInfoWriter

class BaseRecorder:
    """
    Facade that pairs an info *getter* with an info *writer* for one experiment group.

    Recorder can search and save two forms of information:

    1. Attributes of a module, if the module inherits from ``WithRecorder`` or
       maintains a **to-record** list named ``to_record_list``.
    2. A dictionary with this specific format:

    .. code:: python

        {
            table_name: {
                item_name: value,
                ...
            }
        }

    Args:
        project_name (str): Project name forwarded to ``BaseInfoWriter``.
        root (str): Path where to save the statistics and checkpoints.
        group_name (str): Name of the sub-folder (under ``root``) used by the
            writer; it is also appended to the experiment name.
        hint_if_exist (bool): Whether to hint if the folder already exists
            (in case the user deletes the folder by mistake).
        delete_old_folder (bool): Whether to delete the old folder if the
            folder already exists.
        is_resume (bool): Resume flag forwarded to the folder helper and to
            ``BaseInfoWriter``.
        use_wandb (bool): Flag forwarded to ``BaseInfoWriter``.
        initiate_create (bool): When true, the experiment root and the group
            folder objects are created inside ``__init__``. When false, the
            caller must invoke :meth:`create_exp_root` and
            :meth:`create_group_folders` before any method that relies on
            ``self.getter`` or ``self.writer``.
    """

    def __init__(
        self,
        project_name="test",
        root="./",
        group_name="main",
        hint_if_exist=True,
        delete_old_folder=True,
        is_resume=False,
        use_wandb=True,
        initiate_create=True,
    ):
        self.project_name = project_name
        self.root = root
        self.group_name = group_name
        # NOTE: ``os.path.basename`` returns an empty string for a path that
        # ends with a separator (including the default "./"), so ``exp_name``
        # is "" in that case. The name is derived before ``create_exp_root``
        # possibly replaces ``self.root`` and is not refreshed afterwards.
        self.exp_name = os.path.basename(self.root)
        self.hint_if_exist = hint_if_exist
        self.delete_old_folder = delete_old_folder
        self.is_resume = is_resume
        self.use_wandb = use_wandb

        if initiate_create:
            self.create_exp_root()
            self.create_group_folders(group_name)

    def create_exp_root(self):
        """
        Create the experiment root folder.

        ``self.root`` is replaced by whatever path the folder helper returns.
        """
        self.root = utils_recorder.create_folder(
            path=self.root,
            is_resume=self.is_resume,
            hint_if_exist=self.hint_if_exist,
            delete_old_folders=self.delete_old_folder,
        )

    def create_group_folders(self, group_name):
        """
        Build the getter and the writer bound to one group sub-folder.

        Args:
            group_name (str): Sub-folder name under ``self.root``; it is also
                appended (after an underscore) to the experiment name handed
                to the writer.
        """
        sub_root_name = os.path.join(self.root, group_name)
        self.getter = BaseInfoGetter()
        self.writer = BaseInfoWriter(
            project_name=self.project_name,
            root=sub_root_name,
            exp_name=self.exp_name + "_{}".format(group_name),
            hint_if_exist=self.hint_if_exist,
            delete_old_folder=self.delete_old_folder,
            is_resume=self.is_resume,
            use_wandb=self.use_wandb,
        )

    def delete_folders(self, group_name):
        """
        Recursively delete a group sub-folder if it exists.

        Nothing happens (and nothing is logged) when the folder is absent.

        Args:
            group_name (str): Sub-folder name under ``self.root``.
        """
        sub_root_name = os.path.join(self.root, group_name)
        if os.path.exists(sub_root_name):
            shutil.rmtree(sub_root_name)
            logging.info("Delete group folder: {}".format(sub_root_name))

    def log_config(self, config_dict):
        """
        Forward a configuration dictionary to the writer.

        Args:
            config_dict (dict): Configuration to be logged.
        """
        self.writer.log_config(config_dict)

    def get_data(self, *args, **kwargs):
        """
        Extract the information from a module.

        All arguments are forwarded unchanged to ``self.getter.get_data``.

        Args:
            data (dict or obj): Dictionary or object with a
                **TO_RECORD_LIST** attribute.
            name (str): Specific name.

        Returns:
            tuple: (dict, int)

            1. output_dict (dict): information dictionary to be saved.
            2. step (int): step counted by ``recorder``.
        """
        return self.getter.get_data(*args, **kwargs)

    def update(self, *args, **kwargs):
        """
        Update record.

        All arguments are forwarded unchanged to ``self.writer.update``.

        Args:
            data (dict): Dictionary to be recorded.
            step (int): Step (optional).
        """
        self.writer.update(*args, **kwargs)

    def save_models(self, *args, **kwargs):
        """
        Save models.

        All arguments are forwarded unchanged to ``self.writer.save_models``.

        Args:
            obj (module): Modules with **TO_SAVE_LIST**.
            step (int): Step (or epoch).
            best (bool): Whether to be the best checkpoint.
            delete_old (bool): Whether to delete the old checkpoint.
        """
        self.writer.save_models(*args, **kwargs)

    def load_models(self, *args, **kwargs):
        """
        Load models.

        All arguments are forwarded unchanged to ``self.writer.load_models``.

        Args:
            obj (module): Modules with **TO_SAVE_LIST**.
            step (int): Step (or epoch).
            device (device): The computing device to load.
        """
        self.writer.load_models(*args, **kwargs)

    def search_best_epoch(self, table_name, primary_key, maximum=True):
        """
        Search the best epoch according to a given evaluation indicator.

        The table is read from ``self.writer.csv_path`` and must contain a
        ``step`` column in addition to ``primary_key``.

        Args:
            table_name (str): The csv file name.
            primary_key (str): The column to be searched.
            maximum (bool): Whether to regard the maximum as the best.

        Returns:
            tuple: (best step, best metric)
        """
        table_path = os.path.join(self.writer.csv_path, table_name)
        df = pd.read_csv(table_path)

        # Locate the row holding the extreme value of the primary metric.
        metric = df[primary_key]
        best_idx = metric.idxmax() if maximum else metric.idxmin()
        return df.loc[best_idx, "step"], df.loc[best_idx, primary_key]