Module minder_utils.models.feature_extractors.partial_order.partial_order
Expand source code
from .basic import Partial_Order_Models
from .loss import Ranking
from minder_utils.models.utils import Feature_extractor
from minder_utils.dataloader import Partial_Order_Loader
class Partial_Order(Feature_extractor):
    """Feature extractor pairing ``Partial_Order_Models`` with a ``Ranking`` loss.

    Both components, and the data loader, are configured from
    ``self.config`` (keys ``"model"``, ``"loss"`` and ``'loader'``).
    NOTE(review): ``self.config`` is presumably populated by
    ``Feature_extractor.__init__`` -- confirm against the base class.
    """

    def __init__(self):
        """Build the network and the criterion from the configuration."""
        super().__init__()
        model_kwargs = self.config["model"]
        self.model = Partial_Order_Models(**model_kwargs)
        loss_kwargs = self.config["loss"]
        self.criterion = Ranking(**loss_kwargs)

    def _custom_loader(self, data):
        """Wrap an ``(X, y)`` pair in a ``Partial_Order_Loader``.

        Args:
            data: Iterable that unpacks into exactly two items, ``X`` and ``y``.

        Returns:
            A ``Partial_Order_Loader`` built with the ``'loader'`` settings.
        """
        features, targets = data
        loader_kwargs = self.config['loader']
        return Partial_Order_Loader(features, targets, **loader_kwargs)

    def step(self, data):
        """Accumulate the criterion over adjacent entries on both sides of the anchor.

        Args:
            data: Triple ``(pre_anchor, anchor, post_anchor)``. ``pre_anchor``
                is indexed with positions derived from ``len(post_anchor)``,
                so it must hold at least as many entries as ``post_anchor``.

        Returns:
            The summed loss, or the integer ``0`` when ``post_anchor`` holds
            fewer than two entries.
        """
        before, reference, after = data
        total = 0
        for current in range(len(after) - 1):
            following = current + 1
            # Keep the original order: the post-anchor pair first, then the
            # pre-anchor pair at the same positions.
            for sequence in (after, before):
                total += self._step(sequence[current], sequence[following], reference)
        return total

    def _step(self, xi, xj, anchor):
        """Run the three inputs through the model and score them with the criterion.

        Only the second value returned by the model is used; the first one is
        discarded.
        """
        projections = []
        for sample in (xi, xj, anchor):
            _, projected = self.model(sample)
            projections.append(projected)
        return self.criterion(*projections)

    @staticmethod
    def which_data(data):
        """Return the first element of ``data``."""
        first_item = data[0]
        return first_item
Classes
class Partial_Order
-
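Feature extractor that combines a `Partial_Order_Models` network with a `Ranking` loss and is trained on `(pre_anchor, anchor, post_anchor)` samples. (The auto-generated summary for this class, "Helper class that provides a standard way to create an ABC using inheritance.", is the docstring inherited from `abc.ABC` and does not describe `Partial_Order` itself.)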
Expand source code
class Partial_Order(Feature_extractor):
    """Feature extractor pairing ``Partial_Order_Models`` with a ``Ranking`` loss.

    Both components, and the data loader, are configured from
    ``self.config`` (keys ``"model"``, ``"loss"`` and ``'loader'``).
    NOTE(review): ``self.config`` is presumably populated by
    ``Feature_extractor.__init__`` -- confirm against the base class.
    """

    def __init__(self):
        """Build the network and the criterion from the configuration."""
        super().__init__()
        model_kwargs = self.config["model"]
        self.model = Partial_Order_Models(**model_kwargs)
        loss_kwargs = self.config["loss"]
        self.criterion = Ranking(**loss_kwargs)

    def _custom_loader(self, data):
        """Wrap an ``(X, y)`` pair in a ``Partial_Order_Loader``.

        Args:
            data: Iterable that unpacks into exactly two items, ``X`` and ``y``.

        Returns:
            A ``Partial_Order_Loader`` built with the ``'loader'`` settings.
        """
        features, targets = data
        loader_kwargs = self.config['loader']
        return Partial_Order_Loader(features, targets, **loader_kwargs)

    def step(self, data):
        """Accumulate the criterion over adjacent entries on both sides of the anchor.

        Args:
            data: Triple ``(pre_anchor, anchor, post_anchor)``. ``pre_anchor``
                is indexed with positions derived from ``len(post_anchor)``,
                so it must hold at least as many entries as ``post_anchor``.

        Returns:
            The summed loss, or the integer ``0`` when ``post_anchor`` holds
            fewer than two entries.
        """
        before, reference, after = data
        total = 0
        for current in range(len(after) - 1):
            following = current + 1
            # Keep the original order: the post-anchor pair first, then the
            # pre-anchor pair at the same positions.
            for sequence in (after, before):
                total += self._step(sequence[current], sequence[following], reference)
        return total

    def _step(self, xi, xj, anchor):
        """Run the three inputs through the model and score them with the criterion.

        Only the second value returned by the model is used; the first one is
        discarded.
        """
        projections = []
        for sample in (xi, xj, anchor):
            _, projected = self.model(sample)
            projections.append(projected)
        return self.criterion(*projections)

    @staticmethod
    def which_data(data):
        """Return the first element of ``data``."""
        first_item = data[0]
        return first_item
Ancestors
- Feature_extractor
- abc.ABC
- torch.nn.modules.module.Module
Class variables
var dump_patches : bool
var training : bool
Static methods
def which_data(data)
-
Expand source code
@staticmethod
def which_data(data):
    """Return the first element of ``data``."""
    first_item = data[0]
    return first_item
Methods
def step(self, data)
-
Expand source code
def step(self, data):
    """Accumulate the criterion over adjacent entries on both sides of the anchor.

    Args:
        data: Triple ``(pre_anchor, anchor, post_anchor)``. ``pre_anchor`` is
            indexed with positions derived from ``len(post_anchor)``, so it
            must hold at least as many entries as ``post_anchor``.

    Returns:
        The summed loss, or the integer ``0`` when ``post_anchor`` holds fewer
        than two entries.
    """
    before, reference, after = data
    total = 0
    for current in range(len(after) - 1):
        following = current + 1
        # Keep the original order: the post-anchor pair first, then the
        # pre-anchor pair at the same positions.
        for sequence in (after, before):
            total += self._step(sequence[current], sequence[following], reference)
    return total
Inherited members