Source code for gedml.core.models.mlp

import torch
import torch.nn as nn 

from ..modules import WithRecorder

"""
Normal multi-layer perceptron
"""

class MLP(WithRecorder):
    '''
    Plain multi-layer perceptron: a stack of ``nn.Linear`` layers with a ReLU
    between consecutive linear layers.

    Args:
        layer_size_list (list):
            Number of neurons in each layer, input size first. ``N + 1``
            entries produce ``N`` linear layers. The sequence is copied, so
            the caller's object is never modified.
        first_relu (bool):
            Whether to put a ReLU in front of the first linear layer. That
            ReLU is created with ``inplace=True``, so it overwrites the
            tensor handed to :meth:`forward`.
        last_relu (bool):
            Whether to append a ReLU after the last linear layer.
        output_dim (int, optional):
            When not ``None``, replaces the last entry of ``layer_size_list``
            (i.e. the output size of the final linear layer).

    Example:
        >>> model = MLP(layer_size_list=[512, 100], first_relu=False)
    '''

    def __init__(self, layer_size_list, first_relu=True, last_relu=False, output_dim=None):
        super(MLP, self).__init__()

        # Work on a private copy: overriding the output size must not leak
        # back into the list (often a shared config object) the caller owns.
        sizes = list(layer_size_list)
        if output_dim is not None:
            sizes[-1] = int(output_dim)
        # check data type -- coerce AFTER the override so that a placeholder
        # in the last slot is replaced before conversion is attempted.
        sizes = [int(item) for item in sizes]

        self.layer_size_list = sizes
        self.first_relu = first_relu
        self.last_relu = last_relu

        # construct MLP
        layer_list = [nn.ReLU(inplace=True)] if first_relu else []
        last_idx = len(sizes) - 2  # index of the final linear layer
        for idx in range(len(sizes) - 1):
            # Build from the coerced sizes, not from the raw argument, so
            # float / string sizes never reach nn.Linear.
            layer_list.append(nn.Linear(sizes[idx], sizes[idx + 1]))
            if idx != last_idx:
                layer_list.append(nn.ReLU(inplace=True))
        if last_relu:
            layer_list.append(nn.ReLU(inplace=True))

        # compose
        self.net = nn.Sequential(*layer_list)
        # NOTE(review): this is the final module of ``net``; with
        # ``last_relu=True`` that is the trailing ReLU rather than a Linear.
        # Kept as-is because re-pointing it would change state_dict keys.
        self.last_linear = self.net[-1]

    def forward(self, features):
        """Run ``features`` through the layer stack and return the result.

        Presumably ``features`` has ``layer_size_list[0]`` as its last
        dimension -- confirm against callers.
        """
        return self.net(features)
""" multi-layer perceptron with batch normalization """ class BatchNormLayer(WithRecorder): def __init__(self, in_dim, out_dim, is_bn=True, is_relu=True, **kwargs): super(BatchNormLayer, self).__init__(**kwargs) layers_list = [] # add linear layers_list.append(nn.Linear(in_dim, out_dim)) # add batch norm if is_bn: layers_list.append(nn.BatchNorm1d(out_dim)) # add ReLU if is_relu: layers_list.append(nn.ReLU(inplace=True)) # compose self.net = nn.Sequential(*layers_list) def forward(self, x): return self.net(x)
class BatchNormMLP(WithRecorder):
    """
    Multi-layer perceptron built from ``BatchNormLayer`` units.

    Args:
        layer_size_list (list):
            Number of neurons in each layer, input size first
            (``N + 1`` entries for ``N`` layers).
        relu_list (list):
            Per-layer flag: whether layer ``i`` ends with a ReLU. Its length
            determines how many layers are built.
        bn_list (list):
            Per-layer flag: whether layer ``i`` contains batch normalization.
        first_bn (bool):
            Whether to put a ``nn.BatchNorm1d`` over the input in front of
            the first layer.
        **kwargs:
            Forwarded unchanged to the ``WithRecorder`` constructor.

    Example:
        >>> model = BatchNormMLP(
                layer_size_list=[512, 512, 1024],
                relu_list=[True, False],
                bn_list=[True, False],
                first_bn=False
            )
    """

    def __init__(self, layer_size_list, relu_list, bn_list, first_bn=False, **kwargs):
        super().__init__(**kwargs)

        self.layer_size_list = [int(size) for size in layer_size_list]
        self.relu_list = relu_list
        self.bn_list = bn_list
        self.first_bn = first_bn
        # The ReLU flags decide the depth of the network.
        self.layers_num = len(self.relu_list)

        # Optional input normalization comes first.
        stack = [nn.BatchNorm1d(self.layer_size_list[0])] if self.first_bn else []

        # Index-based access (not zip) on purpose: a too-short list must
        # fail here exactly as before instead of being silently truncated.
        stack.extend(
            BatchNormLayer(
                in_dim=self.layer_size_list[pos],
                out_dim=self.layer_size_list[pos + 1],
                is_bn=self.bn_list[pos],
                is_relu=self.relu_list[pos],
            )
            for pos in range(self.layers_num)
        )

        self.net = nn.Sequential(*stack)

    def forward(self, features):
        """Run ``features`` through every layer in order and return the result."""
        return self.net(features)
# Alternative manual check for the batch-norm variant (kept disabled):
#
# if __name__ == "__main__":
#     model = BatchNormMLP(
#         layer_size_list=[512, 512, 1024],
#         relu_list=[True, False],
#         bn_list=[True, False],
#         first_bn=False
#     )
#     print(model)

if __name__ == '__main__':
    # Manual smoke check: build a single-layer MLP and print its structure.
    print(MLP(layer_size_list=[512, 100], first_relu=False))