Module minder_utils.models.feature_extractors.simclr.simclr
Expand source code
from minder_utils.models.feature_extractors.simclr.basic import ResNetSimCLR
from minder_utils.models.feature_extractors.simclr.loss import NTXentLoss
import torch.nn.functional as F
from minder_utils.models.utils import Feature_extractor
class SimCLR(Feature_extractor):
    """SimCLR contrastive feature extractor.

    Wraps a ``ResNetSimCLR`` network and trains it with ``NTXentLoss``
    computed between the projections of two views (``xis``, ``xjs``) of the
    same batch.  ``self.device`` and ``self.config`` are read but not set
    here; presumably they are initialised by ``Feature_extractor.__init__``
    (TODO confirm).
    """

    def __init__(self):
        super().__init__()
        loss_cfg = self.config['loss']
        # Loss hyper-parameters come from the 'loss' section of the config.
        self.nt_xent_criterion = NTXentLoss(
            self.device,
            loss_cfg['temperature'],
            loss_cfg['use_cosine_similarity'],
        )
        # Every entry of the 'model' config section is forwarded as a kwarg.
        self.model = ResNetSimCLR(**self.config["model"]).to(self.device)

    def step(self, data):
        """Compute the contrastive loss of ``self.model`` on one batch.

        Args:
            data: A ``((xis, xjs), other)`` pair.  ``xis`` and ``xjs`` are the
                two views fed to the model; ``other`` (presumably labels) is
                ignored.

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the batch.
        """
        (xis, xjs), _ = data
        return self._step(self.model, xis, xjs)

    def _step(self, model, xis, xjs, n_iter=None):
        """Compute the NT-Xent loss of ``model`` on two views of a batch.

        Args:
            model: Callable returning a ``(representation, projection)`` pair
                for an input batch.
            xis: First view of the batch.
            xjs: Second view of the batch.
            n_iter: Unused; accepted only so existing callers keep working.

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the two
            normalised projections.
        """
        # Only the projections feed the loss; the representations are dropped.
        _, zis = model(xis)  # [N,C]
        _, zjs = model(xjs)  # [N,C]
        # L2-normalise each projection vector along dim 1.
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)
        return self.nt_xent_criterion(zis, zjs)

    @staticmethod
    def which_data(data):
        """Return ``data[0]``.

        For batches shaped like those ``step`` consumes, this is presumably
        the ``(xis, xjs)`` view pair (TODO confirm against callers).
        """
        return data[0]
Classes
class SimCLR
-
SimCLR feature extractor: trains a ResNetSimCLR network with the NT-Xent contrastive loss computed between the normalised projections of two views (xis, xjs) of each batch.
Expand source code
class SimCLR(Feature_extractor):
    """SimCLR contrastive feature extractor.

    Wraps a ``ResNetSimCLR`` network and trains it with ``NTXentLoss``
    computed between the projections of two views (``xis``, ``xjs``) of the
    same batch.  ``self.device`` and ``self.config`` are read but not set
    here; presumably they are initialised by ``Feature_extractor.__init__``
    (TODO confirm).
    """

    def __init__(self):
        super().__init__()
        loss_cfg = self.config['loss']
        # Loss hyper-parameters come from the 'loss' section of the config.
        self.nt_xent_criterion = NTXentLoss(
            self.device,
            loss_cfg['temperature'],
            loss_cfg['use_cosine_similarity'],
        )
        # Every entry of the 'model' config section is forwarded as a kwarg.
        self.model = ResNetSimCLR(**self.config["model"]).to(self.device)

    def step(self, data):
        """Compute the contrastive loss of ``self.model`` on one batch.

        Args:
            data: A ``((xis, xjs), other)`` pair.  ``xis`` and ``xjs`` are the
                two views fed to the model; ``other`` (presumably labels) is
                ignored.

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the batch.
        """
        (xis, xjs), _ = data
        return self._step(self.model, xis, xjs)

    def _step(self, model, xis, xjs, n_iter=None):
        """Compute the NT-Xent loss of ``model`` on two views of a batch.

        Args:
            model: Callable returning a ``(representation, projection)`` pair
                for an input batch.
            xis: First view of the batch.
            xjs: Second view of the batch.
            n_iter: Unused; accepted only so existing callers keep working.

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the two
            normalised projections.
        """
        # Only the projections feed the loss; the representations are dropped.
        _, zis = model(xis)  # [N,C]
        _, zjs = model(xjs)  # [N,C]
        # L2-normalise each projection vector along dim 1.
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)
        return self.nt_xent_criterion(zis, zjs)

    @staticmethod
    def which_data(data):
        """Return ``data[0]``.

        For batches shaped like those ``step`` consumes, this is presumably
        the ``(xis, xjs)`` view pair (TODO confirm against callers).
        """
        return data[0]
Ancestors
- Feature_extractor
- abc.ABC
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Static methods
def which_data(data)
-
Expand source code
@staticmethod
def which_data(data):
    """Return ``data[0]``.

    For batches shaped like those ``step`` consumes, this is presumably
    the ``(xis, xjs)`` view pair (TODO confirm against callers).
    """
    return data[0]
Methods
def step(self, data)
-
Expand source code
def step(self, data):
    """Compute the contrastive loss of ``self.model`` on one batch.

    Args:
        data: A ``((xis, xjs), other)`` pair.  ``xis`` and ``xjs`` are the
            two views fed to the model; ``other`` (presumably labels) is
            ignored.

    Returns:
        Whatever ``self._step`` returns for ``self.model`` and the two views.
    """
    (xis, xjs), _ = data
    # Delegate to the shared loss computation instead of duplicating it.
    return self._step(self.model, xis, xjs)
Inherited members