Module minder_utils.models.feature_extractors.partial_order.loss
Expand source code
import torch


class Ranking(torch.nn.Module):
    def __init__(self, delta, use_cosine_similarity):
        super(Ranking, self).__init__()
        self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
        self.measure_similarity = self._get_similarity_function(use_cosine_similarity)
        self.delta = delta
        self.criterion = torch.nn.MSELoss(reduction='sum')
        if not use_cosine_similarity:
            dim = 64
            self.projector = torch.nn.Linear(dim, dim, bias=False)

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._metrics_similarity

    def _metrics_similarity(self, x, y):
        # Squared distance between the two projected embeddings, per sample.
        return torch.sum(torch.square(self.projector(x) - self.projector(y)), dim=1)

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs, z_anchor):
        """
        :param zis: embeddings of samples similar to the anchor
        :param zjs: embeddings of samples dissimilar to the anchor
        :param z_anchor: anchor embeddings
        :return: scalar ranking loss for the batch
        """
        s1 = self.measure_similarity(zis, z_anchor)
        s2 = self.measure_similarity(zjs, z_anchor)
        # loss = - torch.mean(torch.log(torch.mean(torch.clamp(s2 - s1 + self.delta, min=0, max=1.), dim=-1)))
        margin = torch.clamp(s2 - s1 + self.delta, min=0, max=1.)
        loss = self.criterion(margin, torch.zeros_like(margin))
        return loss
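The class implements a margin-based ranking loss: the anchor should be more similar to zis than to zjs by at least delta, and any violation of that margin is penalised. Below is a minimal usage sketch for the cosine-similarity branch; the batch size, embedding dimension, and tensor names are illustrative assumptions, not part of the module.

import torch
from minder_utils.models.feature_extractors.partial_order.loss import Ranking

# Minimal usage sketch (cosine-similarity branch); shapes and values are illustrative.
criterion = Ranking(delta=0.5, use_cosine_similarity=True)

N, C = 8, 64                                      # example batch size and embedding dimension
z_anchor = torch.randn(N, C, requires_grad=True)  # anchor embeddings, e.g. encoder outputs
zis = torch.randn(N, C, requires_grad=True)       # embeddings expected to be close to the anchor
zjs = torch.randn(N, C, requires_grad=True)       # embeddings expected to be far from the anchor

loss = criterion(zis, zjs, z_anchor)              # scalar: sum of squared clamped margins
loss.backward()                                   # gradients flow back into the embeddings
print(loss.item())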
Classes
class Ranking (delta, use_cosine_similarity)
-
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing to nest them in a tree structure. You can assign the submodules as regular attributes:

    import torch.nn as nn
    import torch.nn.functional as F

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.conv1 = nn.Conv2d(1, 20, 5)
            self.conv2 = nn.Conv2d(20, 20, 5)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will have their parameters converted too when you call to(), etc.

:ivar training: Boolean represents whether this module is in training or evaluation mode.
:vartype training: bool
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Expand source code
class Ranking(torch.nn.Module):
    def __init__(self, delta, use_cosine_similarity):
        super(Ranking, self).__init__()
        self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
        self.measure_similarity = self._get_similarity_function(use_cosine_similarity)
        self.delta = delta
        self.criterion = torch.nn.MSELoss(reduction='sum')
        if not use_cosine_similarity:
            dim = 64
            self.projector = torch.nn.Linear(dim, dim, bias=False)

    def _get_similarity_function(self, use_cosine_similarity):
        if use_cosine_similarity:
            self._cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
            return self._cosine_simililarity
        else:
            return self._metrics_similarity

    def _metrics_similarity(self, x, y):
        # Squared distance between the two projected embeddings, per sample.
        return torch.sum(torch.square(self.projector(x) - self.projector(y)), dim=1)

    def _cosine_simililarity(self, x, y):
        # x shape: (N, 1, C)
        # y shape: (1, 2N, C)
        # v shape: (N, 2N)
        v = self._cosine_similarity(x.unsqueeze(1), y.unsqueeze(0))
        return v

    def forward(self, zis, zjs, z_anchor):
        """
        :param zis: embeddings of samples similar to the anchor
        :param zjs: embeddings of samples dissimilar to the anchor
        :param z_anchor: anchor embeddings
        :return: scalar ranking loss for the batch
        """
        s1 = self.measure_similarity(zis, z_anchor)
        s2 = self.measure_similarity(zjs, z_anchor)
        # loss = - torch.mean(torch.log(torch.mean(torch.clamp(s2 - s1 + self.delta, min=0, max=1.), dim=-1)))
        margin = torch.clamp(s2 - s1 + self.delta, min=0, max=1.)
        loss = self.criterion(margin, torch.zeros_like(margin))
        return loss
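When use_cosine_similarity is False, the criterion instead compares squared distances computed through a trainable 64x64 linear projector, so the loss module itself carries learnable parameters and the embeddings must be 64-dimensional (the projector size is hard-coded). Note that in this branch measure_similarity returns a squared projected distance rather than a cosine similarity. The sketch below is illustrative; names and values are assumptions.

import torch
from minder_utils.models.feature_extractors.partial_order.loss import Ranking

# Metric-learning branch: a trainable 64x64 projector defines the distance.
criterion = Ranking(delta=0.5, use_cosine_similarity=False)

N, C = 8, 64                                   # C must match the hard-coded projector size (64)
z_anchor, zis, zjs = torch.randn(N, C), torch.randn(N, C), torch.randn(N, C)

loss = criterion(zis, zjs, z_anchor)           # scalar loss over the batch
loss.backward()                                # gradients reach the projector weights
print(criterion.projector.weight.grad.shape)   # torch.Size([64, 64])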
Ancestors
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Methods
def forward(self, zis, zjs, z_anchor)
-
:param zis: embeddings of samples similar to the anchor
:param zjs: embeddings of samples dissimilar to the anchor
:param z_anchor: anchor embeddings
:return: scalar ranking loss for the batch
Expand source code
def forward(self, zis, zjs, z_anchor):
    """
    :param zis: embeddings of samples similar to the anchor
    :param zjs: embeddings of samples dissimilar to the anchor
    :param z_anchor: anchor embeddings
    :return: scalar ranking loss for the batch
    """
    s1 = self.measure_similarity(zis, z_anchor)
    s2 = self.measure_similarity(zjs, z_anchor)
    # loss = - torch.mean(torch.log(torch.mean(torch.clamp(s2 - s1 + self.delta, min=0, max=1.), dim=-1)))
    margin = torch.clamp(s2 - s1 + self.delta, min=0, max=1.)
    loss = self.criterion(margin, torch.zeros_like(margin))
    return loss
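To make the margin term concrete, the following sketch reproduces the core of forward on hand-picked similarity values (the numbers are purely illustrative): a pair is penalised only when the similar sample does not beat the dissimilar one by at least delta.

import torch

# Illustrative cosine-similarity values for two anchor/pair triples.
s1 = torch.tensor([0.9, 0.2])                          # similarity(zis, z_anchor)
s2 = torch.tensor([0.1, 0.6])                          # similarity(zjs, z_anchor)
delta = 0.5

# First pair:  0.1 - 0.9 + 0.5 = -0.3 -> clamped to 0, no penalty.
# Second pair: 0.6 - 0.2 + 0.5 =  0.9 -> margin violated, penalised.
margin = torch.clamp(s2 - s1 + delta, min=0, max=1.)   # tensor([0.0, 0.9])
loss = torch.nn.MSELoss(reduction='sum')(margin, torch.zeros_like(margin))
print(loss)                                            # tensor(0.8100) = 0.9 ** 2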