Module minder_utils.models.classifiers.torch_classifiers

Expand source code
from minder_utils.models.utils import EarlyStopping
import torch.nn as nn
import torch


class Classifiers:
    '''
    Thin wrapper that builds and holds one of several small pytorch classifiers.

    The concrete network is produced by the method whose name equals
    ``model_type`` (``lr`` or ``nn``) and is stored in ``self.model``.
    Calling the wrapper forwards the input to that network.
    '''
    def __init__(self, model_type: str, num_features: int,
                 initial_manually: bool = False, num_outputs: int = 2) -> None:
        '''
        Initialise the classifier

        Parameters
        ----------
        model_type: str, 'lr' or 'nn'; name of the builder method to call
        num_features: int, input dim
        initial_manually: bool, initialise every model parameter
            (weights and biases) to ones
        num_outputs: int, output dim, default = 2
        '''
        self.model_type = model_type
        self.early_stop = EarlyStopping()
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.initial_manually = initial_manually
        self.model = self._build_model()

    def reset(self):
        '''
        Discard the current network and build a fresh one of the same type,
        re-applying the all-ones initialisation when it was requested.
        '''
        self.model = self._build_model()

    def _build_model(self):
        '''
        Create the network named by ``self.model_type``.

        Shared by ``__init__`` and ``reset`` so both follow exactly the same
        construction and initialisation path.
        '''
        # Dispatch by method name: 'lr' -> self.lr(), 'nn' -> self.nn().
        model = getattr(self, self.model_type)()
        if self.initial_manually:
            for param in model.parameters():
                # In-place fill performed without autograd tracking; the
                # parameter object itself (and its requires_grad flag) is kept.
                nn.init.ones_(param)
        return model

    def nn(self):
        '''
        Build a network with one hidden layer of 256 units and a ReLU:
        num_features -> 256 -> num_outputs.
        '''
        # Inside this body ``nn`` is the module-level ``torch.nn`` alias;
        # the method name only shadows it in the class namespace.
        return nn.Sequential(
            nn.Linear(self.num_features, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_outputs)
        )

    def lr(self):
        '''
        Build a single linear layer mapping num_features -> num_outputs.
        '''
        return nn.Linear(self.num_features, self.num_outputs)

    def parameters(self):
        '''
        Return the parameters of the wrapped network (e.g. for an optimiser).
        '''
        return self.model.parameters()

    def __call__(self, X):
        '''
        Run the wrapped network on ``X`` and return its output.
        '''
        return self.model(X)

Classes

class Classifiers (model_type, num_features, initial_manually=False, num_outputs=2)

This class contains multiple classifiers based on pytorch

Initialise the classifier. Parameters:


model_type : str, 'lr' or 'nn'
 
num_features : int, input dim
 
initial_manually : bool, initialise every model parameter (weights and biases) to ones
 
num_outputs : int, output dim, default = 2
 
Expand source code
class Classifiers:
    '''
    Thin wrapper that builds and holds one of several small pytorch classifiers.

    The concrete network is produced by the method whose name equals
    ``model_type`` (``lr`` or ``nn``) and is stored in ``self.model``.
    Calling the wrapper forwards the input to that network.
    '''
    def __init__(self, model_type: str, num_features: int,
                 initial_manually: bool = False, num_outputs: int = 2) -> None:
        '''
        Initialise the classifier

        Parameters
        ----------
        model_type: str, 'lr' or 'nn'; name of the builder method to call
        num_features: int, input dim
        initial_manually: bool, initialise every model parameter
            (weights and biases) to ones
        num_outputs: int, output dim, default = 2
        '''
        self.model_type = model_type
        self.early_stop = EarlyStopping()
        self.num_features = num_features
        self.num_outputs = num_outputs
        self.initial_manually = initial_manually
        self.model = self._build_model()

    def reset(self):
        '''
        Discard the current network and build a fresh one of the same type,
        re-applying the all-ones initialisation when it was requested.
        '''
        self.model = self._build_model()

    def _build_model(self):
        '''
        Create the network named by ``self.model_type``.

        Shared by ``__init__`` and ``reset`` so both follow exactly the same
        construction and initialisation path.
        '''
        # Dispatch by method name: 'lr' -> self.lr(), 'nn' -> self.nn().
        model = getattr(self, self.model_type)()
        if self.initial_manually:
            for param in model.parameters():
                # In-place fill performed without autograd tracking; the
                # parameter object itself (and its requires_grad flag) is kept.
                nn.init.ones_(param)
        return model

    def nn(self):
        '''
        Build a network with one hidden layer of 256 units and a ReLU:
        num_features -> 256 -> num_outputs.
        '''
        # Inside this body ``nn`` is the module-level ``torch.nn`` alias;
        # the method name only shadows it in the class namespace.
        return nn.Sequential(
            nn.Linear(self.num_features, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_outputs)
        )

    def lr(self):
        '''
        Build a single linear layer mapping num_features -> num_outputs.
        '''
        return nn.Linear(self.num_features, self.num_outputs)

    def parameters(self):
        '''
        Return the parameters of the wrapped network (e.g. for an optimiser).
        '''
        return self.model.parameters()

    def __call__(self, X):
        '''
        Run the wrapped network on ``X`` and return its output.
        '''
        return self.model(X)

Methods

def lr(self)
Expand source code
def lr(self):
    '''
    Build a single linear layer mapping num_features -> num_outputs.
    '''
    input_dim = self.num_features
    output_dim = self.num_outputs
    return nn.Linear(in_features=input_dim, out_features=output_dim)
def nn(self)
Expand source code
def nn(self):
    '''
    Build a network with one hidden layer of 256 units and a ReLU:
    num_features -> 256 -> num_outputs.
    '''
    # This function is itself called ``nn``, which hides the ``torch.nn``
    # alias wherever the two share a namespace, so the layers are looked up
    # through ``torch.nn`` explicitly.
    return torch.nn.Sequential(
        torch.nn.Linear(self.num_features, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, self.num_outputs)
    )
def parameters(self)
Expand source code
def parameters(self):
    '''
    Return the parameters of the wrapped network (e.g. for an optimiser).
    '''
    wrapped = self.model
    return wrapped.parameters()
def reset(self)
Expand source code
def reset(self):
    '''
    Discard the current network and build a fresh one of the same type.

    The builder method named by ``self.model_type`` is called again; when
    ``self.initial_manually`` is set, every parameter of the new network
    (weights and biases) is filled with ones.
    '''
    fresh_model = getattr(self, self.model_type)()
    if self.initial_manually:
        for param in fresh_model.parameters():
            # In-place fill performed without autograd tracking; avoids
            # allocating a throwaway Parameter just to overwrite ``.data``.
            nn.init.ones_(param)
    # Publish the network only once it is fully initialised.
    self.model = fresh_model