Module minder_utils.util.decorators
Expand source code
from .file_func import *
from .plot_func import *
from .pytorch_func import *
__all__ = ['load_save', 'formatting_plots', 'pytorch_train']
Sub-modules
- minder_utils.util.decorators.file_func
- minder_utils.util.decorators.plot_func
- minder_utils.util.decorators.pytorch_func
Functions
- def formatting_plots(title=None, save_path=None, rotation=90, legend=True)
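- Decorator factory for matplotlib plotting functions. It clears the current figure and runs the wrapped function. It then rotates the x tick labels, adds a suptitle (defaulting to the function name), optionally adds an upper-right legend, and applies a tight layout. If `save_path` is given it saves the figure as `<save_path>/<title>.png`, and finally it shows the figure.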
def formatting_plots(title=None, save_path=None, rotation=90, legend=True):
    """Build a decorator that applies common formatting to a pyplot figure.

    The decorated function is expected to draw onto the current pyplot
    figure. The wrapper does the following, in order:

    1. Clears the figure.
    2. Calls the decorated function.
    3. Applies the shared formatting.
    4. Optionally saves the figure.
    5. Shows the figure.

    Args:
        title: Figure suptitle. It is also the base name of the saved file.
            Defaults to the decorated function's ``__name__``.
        save_path: Directory passed to ``save_mkdir`` (presumably created if
            missing). The figure is written there as ``<title>.png``.
            Nothing is saved when ``None``.
        rotation: Rotation, in degrees, applied to the x tick labels.
        legend: Whether to draw a legend in the upper-right corner.

    Returns:
        A decorator. Calling the wrapped function returns whatever the
        decorated plotting function returns.
    """
    def plot_decorator(func):
        figure_title = func.__name__ if title is None else title

        @wraps(func)
        def wrapped_functions(*args, **kwargs):
            plt.clf()
            result = func(*args, **kwargs)
            plt.xticks(rotation=rotation)
            plt.suptitle(figure_title)
            if legend:
                plt.legend(loc='upper right')
            plt.tight_layout()
            if save_path is not None:
                save_mkdir(save_path)
                plt.savefig(os.path.join(save_path, figure_title + '.png'))
            plt.show()
            # Propagate the plotting function's return value so callers can
            # still use whatever it computed; it was previously dropped.
            return result

        return wrapped_functions

    return plot_decorator
Classes
- class load_save (save_path, save_name=None, verbose=True, refresh=False)
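- Caching decorator. When the decorated function is called, it returns the data loaded from `save_path` under `save_name` (default: the function's name). If loading raises `FileNotFoundError`, it calls the function and saves the result. With `refresh=True` it always recomputes and overwrites the stored data. The cache ignores the call arguments.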
class load_save:
    """Decorator that caches a function's return value on disk.

    When the decorated function is called, the wrapper first tries
    ``load_file(save_path, file_name)`` and returns that data. If loading
    raises ``FileNotFoundError``, the function is called and its result is
    written with ``save_file`` before being returned. With ``refresh=True``
    the function is always called and the stored result is overwritten.

    NOTE: the cache is keyed only by ``save_path``/``file_name``, never by the
    call arguments, so calls with different arguments share one stored value.

    Args:
        save_path: Directory for the cached file, normalised by ``reformat_path``.
        save_name: Cache file name. Defaults to the decorated function's ``__name__``.
        verbose: If true, progress messages are printed via ``print_func``.
        refresh: If true, ignore any stored value and recompute it.
    """

    def __init__(self, save_path, save_name=None, verbose=True, refresh=False):
        self.save_path = reformat_path(save_path)
        self.file_name = save_name
        self.verbose = verbose
        self.refresh = refresh

    def __call__(self, func):
        # Without an explicit save_name, the first decorated function fixes the
        # file name for this instance.
        if self.file_name is None:
            self.file_name = func.__name__

        @wraps(func)
        def wrapped_function(*args, **kwargs):
            if self.refresh:
                self.print_func(func, 'start to refresh the data')
                return self._compute_and_save(func, args, kwargs)
            try:
                data = load_file(self.save_path, self.file_name)
            except FileNotFoundError:
                self.print_func(func, 'processing the data')
                return self._compute_and_save(func, args, kwargs)
            self.print_func(func, 'loading processed data')
            return data

        return wrapped_function

    def _compute_and_save(self, func, args, kwargs):
        """Call ``func``, persist its result under ``save_path``, and return it."""
        # Create the directory on every write path. The refresh path used to
        # skip this, so refresh=True on a new directory could fail in
        # save_file. save_mkdir is already reached when the directory exists
        # (file missing), so it is presumably idempotent.
        save_mkdir(self.save_path)
        data = func(*args, **kwargs)
        save_file(data, self.save_path, self.file_name)
        return data

    def print_func(self, func, message):
        """Print ``func``'s name padded to 20 characters, then ``message``, when verbose."""
        if self.verbose:
            print(str(func.__name__).ljust(20, ' '), message)
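- Prints the function's name, left-justified to 20 characters, followed by `message`, but only when `verbose` is enabled.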
def print_func(self, func, message):
    """Emit a progress line for ``func`` when ``self.verbose`` is truthy.

    The function name is left-justified in a 20-character column so that
    successive messages line up; nothing is printed when not verbose.
    """
    if not self.verbose:
        return
    label = str(func.__name__).ljust(20, ' ')
    print(label, message)
 
- class pytorch_train
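- Decorator for PyTorch training functions, intended to set up an Adam optimizer (learning rate 3e-4) and a cosine-annealing learning-rate scheduler. It is currently a work in progress.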
class pytorch_train:
    """Decorator scaffold for PyTorch training functions (work in progress).

    When both ``model`` and ``train_loader`` are supplied, decorating a
    function builds two objects and exposes them as ``self.optimizer`` and
    ``self.scheduler``:

    - an Adam optimizer over ``model.parameters()``;
    - a ``CosineAnnealingLR`` scheduler with ``T_max = len(train_loader)``.

    The returned wrapper currently just forwards the call to the decorated
    function and returns its result.

    NOTE(review): the training loop itself is not implemented yet.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
            Optional.
        train_loader: Sized iterable of training batches. Its ``len()`` sets
            the scheduler's ``T_max``. Optional.
        lr: Learning rate for Adam (default ``3e-4``).
    """

    def __init__(self, model=None, train_loader=None, lr=3e-4):
        self.model = model
        self.train_loader = train_loader
        self.lr = lr
        self.optimizer = None
        self.scheduler = None

    def __call__(self, func):
        # Build the optimizer/scheduler only when their inputs were provided.
        # The previous version read an unset self.model and an undefined
        # global train_loader, so decorating always raised.
        if self.model is not None and self.train_loader is not None:
            self.optimizer = torch.optim.Adam(self.model.parameters(), self.lr)
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, T_max=len(self.train_loader), eta_min=0, last_epoch=-1
            )

        @wraps(func)
        def wrapped_function(*args, **kwargs):
            return func(*args, **kwargs)

        # A decorator must return the replacement callable; returning None
        # would rebind the decorated name to None.
        return wrapped_function