Module minder_utils.models.classifiers.torch_classifiers
Expand source code
from minder_utils.models.utils import EarlyStopping
import torch.nn as nn
import torch
class Classifiers:
    '''
    This class contains multiple classifiers based on pytorch.

    The concrete ``torch.nn`` module is selected by name (``model_type``) and
    stored in ``self.model``; ``parameters()`` and calling the instance are
    forwarded to it.
    '''

    # Builder methods that ``model_type`` may name. Dispatch is restricted to
    # these so that names like 'reset' (infinite recursion) or 'parameters'
    # (obscure AttributeError) are rejected with a clear message instead.
    # Subclasses that add a builder method should extend this tuple.
    _MODEL_TYPES = ('lr', 'nn')

    def __init__(self, model_type, num_features, initial_manually=False, num_outputs=2):
        '''
        Initialise the classifier

        Parameters
        ----------
        model_type: str, 'lr' (single linear layer) or 'nn' (one hidden layer of 256 ReLU units)
        num_features: int, input dim
        initial_manually: bool, if True set every weight and bias to one instead of
            using the default random initialisation
        num_outputs: int, output dim, default = 2

        Raises
        ------
        ValueError: if ``model_type`` is not one of the supported builder names.
        '''
        self.model_type = model_type
        self.early_stop = EarlyStopping()
        self.num_features = num_features
        self.num_outputs = num_outputs
        # Must be set before building, because _build_model reads it.
        self.initial_manually = initial_manually
        self.model = self._build_model()

    def _build_model(self):
        '''Create a fresh module for ``self.model_type``, ones-initialised if requested.'''
        if self.model_type not in self._MODEL_TYPES:
            raise ValueError(
                'Unsupported model_type {!r}; expected one of {}'.format(
                    self.model_type, self._MODEL_TYPES))
        model = getattr(self, self.model_type)()
        if self.initial_manually:
            for param in model.parameters():
                # In-place fill under no_grad; avoids wrapping a new Parameter
                # only to unwrap it again through ``.data``.
                nn.init.ones_(param)
        return model

    def reset(self):
        '''Discard the current model and build a new one with the same settings.'''
        self.model = self._build_model()

    def nn(self):
        '''MLP: num_features -> 256 -> ReLU -> num_outputs (raw scores, no output activation).'''
        # ``nn`` below is the torch.nn module, not this method: class attributes
        # are not in scope inside method bodies.
        return nn.Sequential(
            nn.Linear(self.num_features, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_outputs)
        )

    def lr(self):
        '''Single linear layer: num_features -> num_outputs (raw scores, no activation).'''
        return nn.Linear(self.num_features, self.num_outputs)

    def parameters(self):
        '''Return the wrapped model's parameters (e.g. to pass to an optimiser).'''
        return self.model.parameters()

    def __call__(self, X):
        '''Run a forward pass of the wrapped model on ``X``.'''
        return self.model(X)
# Classes
- class Classifiers (model_type, num_features, initial_manually=False, num_outputs=2)
- 
This class contains multiple classifiers based on PyTorch.

Initialise the classifier.

Parameters
 - model_type: str, 'lr' or 'nn'
 - num_features: int, input dim
 - initial_manually: bool, initialise all weights and biases to ones
 - num_outputs: int, output dim, default = 2
 # Expand source code
class Classifiers:
    '''
    This class contains multiple classifiers based on pytorch.

    The concrete ``torch.nn`` module is selected by name (``model_type``) and
    stored in ``self.model``; ``parameters()`` and calling the instance are
    forwarded to it.
    '''

    # Builder methods that ``model_type`` may name. Dispatch is restricted to
    # these so that names like 'reset' (infinite recursion) or 'parameters'
    # (obscure AttributeError) are rejected with a clear message instead.
    # Subclasses that add a builder method should extend this tuple.
    _MODEL_TYPES = ('lr', 'nn')

    def __init__(self, model_type, num_features, initial_manually=False, num_outputs=2):
        '''
        Initialise the classifier

        Parameters
        ----------
        model_type: str, 'lr' (single linear layer) or 'nn' (one hidden layer of 256 ReLU units)
        num_features: int, input dim
        initial_manually: bool, if True set every weight and bias to one instead of
            using the default random initialisation
        num_outputs: int, output dim, default = 2

        Raises
        ------
        ValueError: if ``model_type`` is not one of the supported builder names.
        '''
        self.model_type = model_type
        self.early_stop = EarlyStopping()
        self.num_features = num_features
        self.num_outputs = num_outputs
        # Must be set before building, because _build_model reads it.
        self.initial_manually = initial_manually
        self.model = self._build_model()

    def _build_model(self):
        '''Create a fresh module for ``self.model_type``, ones-initialised if requested.'''
        if self.model_type not in self._MODEL_TYPES:
            raise ValueError(
                'Unsupported model_type {!r}; expected one of {}'.format(
                    self.model_type, self._MODEL_TYPES))
        model = getattr(self, self.model_type)()
        if self.initial_manually:
            for param in model.parameters():
                # In-place fill under no_grad; avoids wrapping a new Parameter
                # only to unwrap it again through ``.data``.
                nn.init.ones_(param)
        return model

    def reset(self):
        '''Discard the current model and build a new one with the same settings.'''
        self.model = self._build_model()

    def nn(self):
        '''MLP: num_features -> 256 -> ReLU -> num_outputs (raw scores, no output activation).'''
        # ``nn`` below is the torch.nn module, not this method: class attributes
        # are not in scope inside method bodies.
        return nn.Sequential(
            nn.Linear(self.num_features, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_outputs)
        )

    def lr(self):
        '''Single linear layer: num_features -> num_outputs (raw scores, no activation).'''
        return nn.Linear(self.num_features, self.num_outputs)

    def parameters(self):
        '''Return the wrapped model's parameters (e.g. to pass to an optimiser).'''
        return self.model.parameters()

    def __call__(self, X):
        '''Run a forward pass of the wrapped model on ``X``.'''
        return self.model(X)
# Methods
# - def lr(self)
- 
# Expand source code
def lr(self):
    '''Build one linear layer mapping ``num_features`` inputs to ``num_outputs`` raw scores.'''
    in_dim, out_dim = self.num_features, self.num_outputs
    return nn.Linear(in_features=in_dim, out_features=out_dim)
- def nn(self)
- 
# Expand source code
def nn(self):
    '''Build an MLP: num_features -> 256 -> ReLU -> num_outputs (raw scores).'''
    # Use the fully qualified ``torch.nn``: a function named ``nn`` rebinds the
    # short alias wherever it is defined at module level, so ``nn.Linear``
    # would otherwise resolve to this function and raise AttributeError.
    return torch.nn.Sequential(
        torch.nn.Linear(self.num_features, 256),
        torch.nn.ReLU(),
        torch.nn.Linear(256, self.num_outputs)
    )
- def parameters(self)
- 
# Expand source code
def parameters(self):
    '''Expose the parameters of the wrapped ``self.model`` (e.g. for an optimiser).'''
    wrapped = self.model
    return wrapped.parameters()
- def reset(self)
- 
# Expand source code
def reset(self):
    '''
    Rebuild ``self.model`` with the builder named by ``self.model_type``.

    If ``self.initial_manually`` is true, every parameter (weights and biases)
    of the new model is set to one.
    '''
    self.model = getattr(self, self.model_type)()
    if self.initial_manually:
        for param in self.model.parameters():
            # In-place, autograd-safe fill; replaces assigning a freshly
            # wrapped Parameter to ``param.data``.
            nn.init.ones_(param)