# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F


# Post-activated residual block: ReLU is applied after the skip connection.
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=None, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Project the identity when the block changes spatial size or channels.
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = F.relu(out)
        return out
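
# Usage sketch (illustrative, not from the original file): when stride > 1 or
# the channel count changes, `downsample` must project the identity to match
# `out`. A 1x1 convolution + BatchNorm, as in standard ResNets, is one common
# choice; the sizes below are arbitrary examples.
#
#     downsample = nn.Sequential(
#         nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
#         nn.BatchNorm2d(128),
#     )
#     block = ResidualBlock(64, 128, downsample=downsample, stride=2)
#     y = block(torch.randn(8, 64, 32, 32))  # shape: (8, 128, 16, 16)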


# Fully connected residual block
class FCResidualBlock(nn.Module):
    def __init__(self, input_shape, hidden_shape):
        super().__init__()
        self.linear1 = nn.Linear(input_shape, hidden_shape)
        self.bn1 = nn.BatchNorm1d(hidden_shape)
        self.linear2 = nn.Linear(hidden_shape, input_shape)
        self.bn2 = nn.BatchNorm1d(input_shape)

    def forward(self, x):
        identity = x

        out = self.linear1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.linear2(out)
        out = self.bn2(out)

        out += identity
        out = F.relu(out)
        return out
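
# Usage sketch (illustrative): the block maps back to `input_shape`, so copies
# of it can be stacked freely. Note that BatchNorm1d expects a (batch, features)
# input and needs a batch size > 1 in training mode.
#
#     block = FCResidualBlock(input_shape=256, hidden_shape=512)
#     y = block(torch.randn(32, 256))  # shape: (32, 256)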


def mlp(
    input_size,
    hidden_sizes,
    output_size,
    output_activation=nn.Identity,
    activation=nn.ELU,
    init_zero=False,
):
    """
    Build an MLP of Linear -> BatchNorm1d -> activation blocks, finished by a
    Linear layer and the output activation.

    :param input_size: int, dimension of the input.
    :param hidden_sizes: list of int, dimensions of the hidden layers.
    :param output_size: int, dimension of the output.
    :param output_activation: activation class applied after the final layer.
    :param activation: activation class applied after each hidden layer.
    :param init_zero: bool, zero initialization for the last layer (both weight
        and bias). This can provide stable zero outputs in the beginning.
    :return: nn.Sequential stacking all layers.
    """
    sizes = [input_size] + hidden_sizes + [output_size]
    layers = []
    for i in range(len(sizes) - 1):
        if i < len(sizes) - 2:
            act = activation
            layers += [
                nn.Linear(sizes[i], sizes[i + 1]),
                nn.BatchNorm1d(sizes[i + 1]),
                act(),
            ]
        else:
            act = output_activation
            layers += [nn.Linear(sizes[i], sizes[i + 1]), act()]

    if init_zero:
        # layers[-2] is the final nn.Linear; layers[-1] is the output activation.
        layers[-2].weight.data.fill_(0)
        layers[-2].bias.data.fill_(0)

    return nn.Sequential(*layers)
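
# Usage sketch (illustrative): a two-hidden-layer MLP whose final Linear layer
# is zero-initialized, so it outputs exactly zero before any training step.
#
#     net = mlp(input_size=64, hidden_sizes=[128, 128], output_size=10,
#               init_zero=True)
#     assert torch.allclose(net(torch.randn(4, 64)), torch.zeros(4, 10))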


def conv3x3(in_channels, out_channels, stride=1):
    """3x3 convolution with padding 1; bias is omitted since BatchNorm follows."""
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )
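

# Minimal smoke test (added for illustration; the shapes and sizes here are
# arbitrary assumptions, not values from the original repo).
if __name__ == "__main__":
    block = ResidualBlock(16, 16)
    print(block(torch.randn(2, 16, 8, 8)).shape)  # torch.Size([2, 16, 8, 8])

    fc_block = FCResidualBlock(input_shape=32, hidden_shape=64)
    print(fc_block(torch.randn(2, 32)).shape)  # torch.Size([2, 32])

    net = mlp(input_size=32, hidden_sizes=[64], output_size=4, init_zero=True)
    print(net(torch.randn(2, 32)).abs().max())  # exactly 0 thanks to init_zero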