Source code for gedml.core.models.resnet50

import pretrainedmodels as ptm
import torch
import torch.nn as nn 
import logging

# Code from: Pytorch Metric BaseLine
[docs]class resnet50(nn.Module): """ Container for ResNet50 s.t. it can be used for metric learning. The Network has been broken down to allow for higher modularity, if one wishes to target specific layers/blocks directly. """ def __init__(self, pretrained=True, list_style=False, no_norm=False): super(resnet50, self).__init__() if pretrained: logging.info('Getting pretrained weights...') self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained='imagenet') else: logging.info('Not utilizing pretrained weights!') self.model = ptm.__dict__['resnet50'](num_classes=1000, pretrained=None) for module in filter(lambda m: type(m) == nn.BatchNorm2d, self.model.modules()): module.eval() module.train = lambda _: None # self.last_linear = self.model.last_linear self.last_linear = self.model.last_linear self.model.last_linear = torch.nn.Linear(self.model.last_linear.in_features, 128) self.layer_blocks = nn.ModuleList([self.model.layer1, self.model.layer2, self.model.layer3, self.model.layer4])
[docs] def forward(self, x, is_init_cluster_generation=False): x = self.model.maxpool(self.model.relu(self.model.bn1(self.model.conv1(x)))) for layerblock in self.layer_blocks: x = layerblock(x) x = self.model.avgpool(x) x = x.view(x.size(0),-1) return x