Module minder_utils.util.decorators

Expand source code
from .file_func import *
from .plot_func import *
from .pytorch_func import *

# Public API of minder_utils.util.decorators: the names exported by
# ``from minder_utils.util.decorators import *``. Each is expected to be
# provided by one of the wildcard sub-module imports of this package.
__all__ = ['load_save', 'formatting_plots', 'pytorch_train']

Sub-modules

minder_utils.util.decorators.file_func
minder_utils.util.decorators.plot_func
minder_utils.util.decorators.pytorch_func

Functions

def formatting_plots(title=None, save_path=None, rotation=90, legend=True)
Expand source code
def formatting_plots(title=None, save_path=None, rotation=90, legend=True):
    """Build a decorator that applies shared matplotlib formatting to a plotting function.

    Each call of the decorated function does the following:

    1. Clears the current pyplot figure.
    2. Runs the function.
    3. Rotates the x tick labels and sets a figure-level title (``suptitle``).
    4. Optionally adds a legend in the upper-right corner.
    5. Applies ``tight_layout``.
    6. Optionally saves the figure as ``<title>.png``.
    7. Calls ``plt.show()``.

    Args:
        title: Figure suptitle, also used as the PNG file stem when saving.
            Defaults to the decorated function's ``__name__``.
        save_path: Directory to save the PNG into. ``save_mkdir(save_path)`` is
            called first (presumably creating the directory). ``None`` disables
            saving.
        rotation: Rotation passed to ``plt.xticks`` for the x tick labels.
        legend: If true, call ``plt.legend(loc='upper right')``.

    Returns:
        A decorator. The wrapped function returns whatever the decorated
        function returns.
    """
    def plot_decorator(func):
        figure_title = func.__name__ if title is None else title

        @wraps(func)
        def wrapped_functions(*args, **kwargs):
            plt.clf()
            # Keep the plotting function's return value (e.g. axes or handles)
            # instead of silently handing callers ``None``.
            result = func(*args, **kwargs)
            plt.xticks(rotation=rotation)
            plt.suptitle(figure_title)
            if legend:
                plt.legend(loc='upper right')
            plt.tight_layout()
            if save_path is not None:
                save_mkdir(save_path)
                # Save before show() so the file captures the formatted figure.
                plt.savefig(os.path.join(save_path, figure_title + '.png'))
            plt.show()
            return result

        return wrapped_functions

    return plot_decorator

Classes

class load_save (save_path, save_name=None, verbose=True, refresh=False)
Expand source code
class load_save:
    """Decorator that caches a function's result in a file.

    On each call, the wrapper first tries ``load_file(save_path, name)``. If
    that raises ``FileNotFoundError``, it runs the decorated function and
    stores the result with ``save_file``. With ``refresh=True`` the stored file
    is never read: the function always runs and its result is re-saved.

    Note: the file name is fixed per decorated function. The call arguments are
    not part of it, so calls with different arguments share one cached result.

    Args:
        save_path: Directory for the cached file (passed through ``reformat_path``).
        save_name: Name handed to ``load_file``/``save_file``. Defaults to the
            decorated function's ``__name__``.
        verbose: If true, print a status line (function name + message) per call.
        refresh: If true, always recompute and overwrite the stored file.
    """

    def __init__(self, save_path, save_name=None, verbose=True, refresh=False):
        self.save_path = reformat_path(save_path)
        self.file_name = save_name
        self.verbose = verbose
        self.refresh = refresh

    def __call__(self, func):
        # Resolve the cache name per decorated function rather than overwriting
        # self.file_name. Otherwise reusing one instance on several functions
        # would make them all share the first function's cache file.
        file_name = func.__name__ if self.file_name is None else self.file_name

        @wraps(func)
        def wrapped_function(*args, **kwargs):
            if self.refresh:
                self.print_func(func, 'start to refresh the data')
            else:
                try:
                    data = load_file(self.save_path, file_name)
                except FileNotFoundError:
                    self.print_func(func, 'processing the data')
                else:
                    self.print_func(func, 'loading processed data')
                    return data
            data = func(*args, **kwargs)
            # Make sure the target directory exists on BOTH the refresh and the
            # cache-miss paths before writing (previously only on cache miss).
            save_mkdir(self.save_path)
            save_file(data, self.save_path, file_name)
            return data

        return wrapped_function

    def print_func(self, func, message):
        """Print ``message`` after ``func``'s name (padded to 20 chars) if verbose."""
        if self.verbose:
            print(str(func.__name__).ljust(20, ' '), message)

Methods

def print_func(self, func, message)
Expand source code
def print_func(self, func, message):
    """Emit a status line for ``func`` when the owner is verbose.

    The line consists of ``func.__name__`` left-justified to a 20-character
    column, a single space, then ``message``. Nothing is printed when
    ``self.verbose`` is falsy.
    """
    if not self.verbose:
        return
    label = str(func.__name__).ljust(20, ' ')
    print(label, message)
class pytorch_train
Expand source code
class pytorch_train:
    """Intended decorator for a PyTorch training routine; currently an unfinished stub.

    NOTE(review): as visible here, this class cannot work yet:
      * ``__call__`` reads ``self.model``, which ``__init__`` never assigns, so
        it raises ``AttributeError`` unless a subclass/caller sets ``model``.
      * ``train_loader`` is not defined in this block. It is a NameError unless
        some module-level name provides it (TODO confirm).
      * ``optimizer`` and ``scheduler`` are created but never used.
      * ``wrapped_function`` ignores ``func`` and its arguments.
    Whether ``__call__`` goes on to return ``wrapped_function`` is not visible
    in this listing (TODO confirm).
    """

    def __init__(self):
        pass

    def __call__(self, func):
        # Adam optimizer over self.model's parameters; 3e-4 is the learning rate
        # (second positional argument of torch.optim.Adam).
        optimizer = torch.optim.Adam(self.model.parameters(), 3e-4)
        # Cosine-annealing LR schedule decaying towards eta_min=0 over
        # T_max=len(train_loader) scheduler steps. Presumably one step per batch
        # of an epoch — TODO confirm what train_loader is meant to be.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(train_loader), eta_min=0,
                                                               last_epoch=-1)
        @wraps(func)
        def wrapped_function(*args, **kwargs):
            # TODO(review): stub — does not call func or run any training.
            pass