Module minder_utils.models.feature_extractors.partial_order.partial_order
Expand source code
from .basic import Partial_Order_Models
from .loss import Ranking
from minder_utils.models.utils import Feature_extractor
from minder_utils.dataloader import Partial_Order_Loader
class Partial_Order(Feature_extractor):
    """Feature extractor pairing a ``Partial_Order_Models`` network with a ``Ranking`` criterion.

    Both components are constructed from sections of ``self.config``.
    """

    def __init__(self):
        """Build the model and the criterion from the ``"model"`` and ``"loss"`` config sections."""
        # NOTE(review): ``self.config`` is never assigned in this class;
        # presumably ``Feature_extractor.__init__`` provides it -- confirm.
        super().__init__()
        self.model = Partial_Order_Models(**self.config["model"])
        self.criterion = Ranking(**self.config["loss"])

    def _custom_loader(self, data):
        """Wrap an ``(X, y)`` pair in a ``Partial_Order_Loader``.

        Additional loader keyword arguments are taken from the ``'loader'``
        config section.
        """
        X, y = data
        return Partial_Order_Loader(X, y, **self.config['loader'])

    def step(self, data):
        """Sum the criterion over consecutive entries on both sides of the anchor.

        ``data`` unpacks into ``(pre_anchor, anchor, post_anchor)``. For every
        index ``i`` up to ``len(post_anchor) - 2`` the loss of the pair
        ``(post_anchor[i], post_anchor[i + 1])`` and of the pair
        ``(pre_anchor[i], pre_anchor[i + 1])`` is added, each evaluated
        against ``anchor``.

        Returns the accumulated loss; this is the plain integer ``0`` when
        ``post_anchor`` has fewer than two entries.
        """
        pre_anchor, anchor, post_anchor = data
        loss = 0
        # NOTE(review): the bound comes from ``post_anchor`` only, so
        # ``pre_anchor`` is assumed to be at least as long -- confirm with the loader.
        for idx_day in range(len(post_anchor) - 1):
            loss += self._step(post_anchor[idx_day], post_anchor[idx_day + 1], anchor)
            loss += self._step(pre_anchor[idx_day], pre_anchor[idx_day + 1], anchor)
        return loss

    def _step(self, xi, xj, anchor):
        """Run the model on ``xi``, ``xj`` and ``anchor`` and score them with the criterion.

        The model returns a pair per input; only the second element of each
        pair is passed to the criterion, the first is discarded.
        """
        _, zis = self.model(xi)
        _, zjs = self.model(xj)
        _, zas = self.model(anchor)
        return self.criterion(zis, zjs, zas)

    @staticmethod
    def which_data(data):
        """Return the first element of ``data``."""
        return data[0]

Classes
- class Partial_Order
- 
Helper class that provides a standard way to create an ABC using inheritance.
Expand source code
class Partial_Order(Feature_extractor):
    def __init__(self):
        super(Partial_Order, self).__init__()
        self.model = Partial_Order_Models(**self.config["model"])
        self.criterion = Ranking(**self.config["loss"])
    def _custom_loader(self, data):
        X, y = data
        return Partial_Order_Loader(X, y, **self.config['loader'])
    def step(self, data):
        pre_anchor, anchor, post_anchor = data
        loss = 0
        for idx_day in range(len(post_anchor) - 1):
            loss += self._step(post_anchor[idx_day], post_anchor[idx_day + 1], anchor)
            loss += self._step(pre_anchor[idx_day], pre_anchor[idx_day + 1], anchor)
        return loss
    def _step(self, xi, xj, anchor):
        ris, zis = self.model(xi)
        rjs, zjs = self.model(xj)
        ras, zas = self.model(anchor)
        return self.criterion(zis, zjs, zas)
    @staticmethod
    def which_data(data):
        return data[0]
Ancestors
- Feature_extractor
- abc.ABC
- torch.nn.modules.module.Module
Class variables
- var dump_patches : bool
- var training : bool
Static methods
- def which_data(data)
- 
Expand source code
@staticmethod
def which_data(data):
    return data[0]
Methods
- def step(self, data)
- 
Expand source code
def step(self, data):
    pre_anchor, anchor, post_anchor = data
    loss = 0
    for idx_day in range(len(post_anchor) - 1):
        loss += self._step(post_anchor[idx_day], post_anchor[idx_day + 1], anchor)
        loss += self._step(pre_anchor[idx_day], pre_anchor[idx_day + 1], anchor)
    return loss
 Inherited members