# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
# Post-activated residual block: ReLU is applied after the skip connection,
# following the original ResNet formulation.
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample=None, stride=1):
        super().__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Optional module that projects the identity branch when the spatial
        # size or channel count changes (e.g. stride > 1).
        self.downsample = downsample

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = F.relu(out)
        return out
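
# A minimal usage sketch (illustrative; not part of the original module).
# With stride=1 and matching channels the block preserves the input shape;
# when stride > 1 or the channel count changes, pass a matching `downsample`
# projection so the identity branch has the same shape as the main branch:
#
#   block = ResidualBlock(64, 64)
#   y = block(torch.randn(8, 64, 32, 32))          # -> (8, 64, 32, 32)
#
#   down = nn.Sequential(conv3x3(64, 128, stride=2), nn.BatchNorm2d(128))
#   block = ResidualBlock(64, 128, downsample=down, stride=2)
#   y = block(torch.randn(8, 64, 32, 32))          # -> (8, 128, 16, 16)
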
# Fully connected residual block (same post-activation pattern, applied to
# 1-D feature vectors).
class FCResidualBlock(nn.Module):
    def __init__(self, input_shape, hidden_shape):
        super().__init__()
        self.linear1 = nn.Linear(input_shape, hidden_shape)
        self.bn1 = nn.BatchNorm1d(hidden_shape)
        # The second layer maps back to input_shape so the skip connection
        # can be added without a projection.
        self.linear2 = nn.Linear(hidden_shape, input_shape)
        self.bn2 = nn.BatchNorm1d(input_shape)

    def forward(self, x):
        identity = x
        out = self.linear1(x)
        out = self.bn1(out)
        out = F.relu(out)
        out = self.linear2(out)
        out = self.bn2(out)
        out += identity
        out = F.relu(out)
        return out
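
# Usage sketch (illustrative). Note that BatchNorm1d requires more than one
# sample per batch in training mode:
#
#   block = FCResidualBlock(input_shape=128, hidden_shape=256)
#   y = block(torch.randn(16, 128))                # -> (16, 128)
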
def mlp(
    input_size,
    hidden_sizes,
    output_size,
    output_activation=nn.Identity,
    activation=nn.ELU,
    init_zero=False,
):
    """
    Build an MLP as an nn.Sequential.

    :param input_size: int, dimension of the input features.
    :param hidden_sizes: list of int, sizes of the hidden layers; each hidden
        layer is followed by BatchNorm1d and the activation.
    :param output_size: int, dimension of the output.
    :param output_activation: activation class applied after the output layer.
    :param activation: activation class applied after each hidden layer.
    :param init_zero: bool, zero initialization for the last layer (both
        weight and bias). This can provide stable zero outputs in the
        beginning.
    :return: nn.Sequential implementing the MLP.
    """
    sizes = [input_size] + hidden_sizes + [output_size]
    layers = []
    for i in range(len(sizes) - 1):
        if i < len(sizes) - 2:
            # Hidden layer: Linear -> BatchNorm1d -> activation.
            act = activation
            layers += [nn.Linear(sizes[i], sizes[i + 1]),
                       nn.BatchNorm1d(sizes[i + 1]),
                       act()]
        else:
            # Output layer: Linear -> output activation (no BatchNorm).
            act = output_activation
            layers += [nn.Linear(sizes[i], sizes[i + 1]),
                       act()]
    if init_zero:
        # layers[-2] is the final nn.Linear (layers[-1] is its activation).
        layers[-2].weight.data.fill_(0)
        layers[-2].bias.data.fill_(0)
    return nn.Sequential(*layers)
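
# Usage sketch (illustrative). With init_zero=True the final Linear has zero
# weight and bias, so the network outputs exactly zero at initialization:
#
#   net = mlp(input_size=4, hidden_sizes=[64, 64], output_size=2,
#             init_zero=True)
#   out = net(torch.randn(32, 4))                  # -> (32, 2), all zeros
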
def conv3x3(in_channels, out_channels, stride=1):
    # 3x3 convolution with padding=1, so stride-1 convolutions preserve the
    # spatial size; bias is omitted because a BatchNorm layer follows in the
    # residual blocks above.
    return nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )
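
# Quick shape check (illustrative): padding=1 keeps the spatial size at
# stride=1 and halves it (rounding up) at stride=2.
#
#   conv3x3(3, 16)(torch.randn(1, 3, 28, 28)).shape            # (1, 16, 28, 28)
#   conv3x3(3, 16, stride=2)(torch.randn(1, 3, 28, 28)).shape  # (1, 16, 14, 14)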