Module minder_utils.models.feature_extractors.simclr.simclr
Expand source code
from minder_utils.models.feature_extractors.simclr.basic import ResNetSimCLR
from minder_utils.models.feature_extractors.simclr.loss import NTXentLoss
import torch.nn.functional as F
from minder_utils.models.utils import Feature_extractor
class SimCLR(Feature_extractor):
    """SimCLR feature extractor optimised with an NT-Xent contrastive loss.

    ``self.device`` and ``self.config`` are not assigned in this class; they are
    read right after ``Feature_extractor.__init__`` runs, so presumably the base
    class provides them (TODO confirm). The config entries read here are
    ``config['loss']['temperature']``, ``config['loss']['use_cosine_similarity']``
    (both forwarded to ``NTXentLoss``) and ``config['model']`` (expanded as
    keyword arguments into ``ResNetSimCLR``).
    """

    def __init__(self):
        super().__init__()
        loss_cfg = self.config['loss']
        self.nt_xent_criterion = NTXentLoss(self.device, loss_cfg['temperature'],
                                            loss_cfg['use_cosine_similarity'])
        # The backbone is moved to the same device the loss was built for.
        self.model = ResNetSimCLR(**self.config["model"]).to(self.device)

    def step(self, data):
        """Compute the contrastive loss for one batch using ``self.model``.

        Args:
            data: unpacked as ``((xis, xjs), _)`` -- a pair of inputs plus a
                second element that is ignored (presumably labels; confirm
                against the caller).

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the two views.
        """
        (xis, xjs), _ = data
        # Delegate so the loss computation lives in exactly one place.
        return self._step(self.model, xis, xjs)

    def _step(self, model, xis, xjs, n_iter=None):
        """Compute the NT-Xent loss between the projections of two inputs.

        Args:
            model: callable whose output is unpacked as
                ``(representation, projection)``; only the projection is used.
            xis: first input batch (presumably one augmented view -- confirm).
            xjs: second input batch (presumably the paired augmented view).
            n_iter: unused; retained so existing callers that pass it still work.

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the normalised
            projections.
        """
        # Representations are discarded; only the projections feed the loss.
        _, zis = model(xis)  # [N, C] per the original annotation -- TODO confirm
        _, zjs = model(xjs)  # [N, C]
        # L2-normalise each projection along dim 1 (F.normalize defaults to p=2).
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)
        return self.nt_xent_criterion(zis, zjs)

    @staticmethod
    def which_data(data):
        """Return ``data[0]``.

        For batches shaped like the ones ``step`` unpacks, this is the
        ``(xis, xjs)`` input pair -- presumably what the base class expects
        here; confirm against ``Feature_extractor``.
        """
        return data[0]
- class SimCLR
- 
No class docstring is defined for SimCLR; the sentence pdoc rendered here ("Helper class that provides a standard way to create an ABC using inheritance.") is the docstring inherited from its ancestor abc.ABC, not a description of SimCLR.
Expand source code
class SimCLR(Feature_extractor): def __init__(self): super(SimCLR, self).__init__() self.nt_xent_criterion = NTXentLoss(self.device, self.config['loss']['temperature'], self.config['loss']['use_cosine_similarity']) self.model = ResNetSimCLR(**self.config["model"]).to(self.device) def step(self, data): (xis, xjs), _ = data ris, zis = self.model(xis) # [N,C] # get the representations and the projections rjs, zjs = self.model(xjs) # [N,C] # normalize projection feature vectors zis = F.normalize(zis, dim=1) zjs = F.normalize(zjs, dim=1) loss = self.nt_xent_criterion(zis, zjs) return loss def _step(self, model, xis, xjs, n_iter): # get the representations and the projections ris, zis = model(xis) # [N,C] # get the representations and the projections rjs, zjs = model(xjs) # [N,C] # normalize projection feature vectors zis = F.normalize(zis, dim=1) zjs = F.normalize(zjs, dim=1) loss = self.nt_xent_criterion(zis, zjs) return loss @staticmethod def which_data(data): return data[0]
Ancestors
- Feature_extractor
- abc.ABC
- torch.nn.modules.module.Module
Class variables
- var dump_patches : bool
- var training : bool
Static methods
- def which_data(data)
- 
Expand source code
@staticmethod def which_data(data): return data[0]
Methods
- def step(self, data)
- 
Expand source code
def step(self, data): (xis, xjs), _ = data ris, zis = self.model(xis) # [N,C] # get the representations and the projections rjs, zjs = self.model(xjs) # [N,C] # normalize projection feature vectors zis = F.normalize(zis, dim=1) zjs = F.normalize(zjs, dim=1) loss = self.nt_xent_criterion(zis, zjs) return loss
 Inherited members