Module minder_utils.models.feature_extractors.simclr.simclr

Expand source code
from minder_utils.models.feature_extractors.simclr.basic import ResNetSimCLR
from minder_utils.models.feature_extractors.simclr.loss import NTXentLoss
import torch.nn.functional as F
from minder_utils.models.utils import Feature_extractor


class SimCLR(Feature_extractor):
    """SimCLR contrastive feature extractor.

    Holds a ``ResNetSimCLR`` encoder and an ``NTXentLoss`` criterion. Both are
    built from ``self.config`` and placed on ``self.device``; those two
    attributes are presumably provided by ``Feature_extractor.__init__``
    (TODO confirm).
    """

    def __init__(self):
        super(SimCLR, self).__init__()
        loss_config = self.config['loss']
        self.nt_xent_criterion = NTXentLoss(self.device, loss_config['temperature'],
                                            loss_config['use_cosine_similarity'])
        self.model = ResNetSimCLR(**self.config["model"]).to(self.device)

    def step(self, data):
        """Compute the contrastive loss for one batch using ``self.model``.

        Args:
            data: a pair ``((xis, xjs), _)``. The two views ``xis`` and ``xjs``
                are encoded; the second element of the pair is ignored.

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the normalized
            projections of the two views.
        """
        (xis, xjs), _ = data
        return self._contrastive_loss(self.model, xis, xjs)

    def _step(self, model, xis, xjs, n_iter=None):
        """Compute the contrastive loss for two views with an explicit ``model``.

        ``n_iter`` is accepted for interface compatibility but is not used.
        """
        return self._contrastive_loss(model, xis, xjs)

    def _contrastive_loss(self, model, xis, xjs):
        """Encode both views, L2-normalize the projections, and apply NT-Xent."""
        # model returns (representation, projection); only the projection feeds the loss
        _, zis = model(xis)  # [N,C]
        _, zjs = model(xjs)  # [N,C]

        # normalize projection feature vectors along the feature dimension
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        return self.nt_xent_criterion(zis, zjs)

    @staticmethod
    def which_data(data):
        """Return ``data[0]``, the element of ``data`` this extractor consumes."""
        return data[0]

Classes

class SimCLR

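SimCLR feature extractor. It trains a ResNetSimCLR encoder with the NT-Xent contrastive loss, computed on the L2-normalized projections of two views (xis, xjs) of each sample. (The generic "standard way to create an ABC using inheritance" text previously shown here was inherited from abc.ABC and does not describe this class.)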

Expand source code
class SimCLR(Feature_extractor):
    """SimCLR contrastive feature extractor.

    Holds a ``ResNetSimCLR`` encoder and an ``NTXentLoss`` criterion. Both are
    built from ``self.config`` and placed on ``self.device``; those two
    attributes are presumably provided by ``Feature_extractor.__init__``
    (TODO confirm).
    """

    def __init__(self):
        super(SimCLR, self).__init__()
        loss_config = self.config['loss']
        self.nt_xent_criterion = NTXentLoss(self.device, loss_config['temperature'],
                                            loss_config['use_cosine_similarity'])
        self.model = ResNetSimCLR(**self.config["model"]).to(self.device)

    def step(self, data):
        """Compute the contrastive loss for one batch using ``self.model``.

        Args:
            data: a pair ``((xis, xjs), _)``. The two views ``xis`` and ``xjs``
                are encoded; the second element of the pair is ignored.

        Returns:
            Whatever ``self.nt_xent_criterion`` returns for the normalized
            projections of the two views.
        """
        (xis, xjs), _ = data
        return self._contrastive_loss(self.model, xis, xjs)

    def _step(self, model, xis, xjs, n_iter=None):
        """Compute the contrastive loss for two views with an explicit ``model``.

        ``n_iter`` is accepted for interface compatibility but is not used.
        """
        return self._contrastive_loss(model, xis, xjs)

    def _contrastive_loss(self, model, xis, xjs):
        """Encode both views, L2-normalize the projections, and apply NT-Xent."""
        # model returns (representation, projection); only the projection feeds the loss
        _, zis = model(xis)  # [N,C]
        _, zjs = model(xjs)  # [N,C]

        # normalize projection feature vectors along the feature dimension
        zis = F.normalize(zis, dim=1)
        zjs = F.normalize(zjs, dim=1)

        return self.nt_xent_criterion(zis, zjs)

    @staticmethod
    def which_data(data):
        """Return ``data[0]``, the element of ``data`` this extractor consumes."""
        return data[0]

Ancestors

- minder_utils.models.utils.Feature_extractor

Class variables

var dump_patches : bool
var training : bool

Static methods

def which_data(data)
Expand source code
@staticmethod
def which_data(data):
    """Select and return the leading element of ``data``."""
    first_item = data[0]
    return first_item

Methods

def step(self, data)
Expand source code
def step(self, data):
    """Compute the contrastive loss for one batch using ``self.model``.

    Args:
        data: a pair ``((xis, xjs), _)``. The two views ``xis`` and ``xjs``
            are encoded; the second element of the pair is ignored.

    Returns:
        Whatever ``self.nt_xent_criterion`` returns for the L2-normalized
        projections of the two views.
    """
    (xis, xjs), _ = data

    # get the representations and the projections; only projections are used
    _, zis = self.model(xis)  # [N,C]
    _, zjs = self.model(xjs)  # [N,C]

    # normalize projection feature vectors along the feature dimension
    zis = F.normalize(zis, dim=1)
    zjs = F.normalize(zjs, dim=1)

    return self.nt_xent_criterion(zis, zjs)

Inherited members