# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import torch
import math
import torch.nn as nn
import numpy as np
from .layer import ResidualBlock, conv3x3, mlp


# Down-sample observations before the representation network (see the paper appendix, "Network Architecture")
class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels // 2)
        self.resblocks1 = nn.ModuleList(
            [ResidualBlock(out_channels // 2, out_channels // 2) for _ in range(1)]
        )
        self.conv2 = nn.Conv2d(
            out_channels // 2,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )
        self.downsample_block = ResidualBlock(out_channels // 2, out_channels, downsample=self.conv2, stride=2)
        self.resblocks2 = nn.ModuleList(
            [ResidualBlock(out_channels, out_channels) for _ in range(1)]
        )
        self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        self.resblocks3 = nn.ModuleList(
            [ResidualBlock(out_channels, out_channels) for _ in range(1)]
        )
        self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        for block in self.resblocks1:
            x = block(x)
        x = self.downsample_block(x)
        for block in self.resblocks2:
            x = block(x)
        x = self.pooling1(x)
        for block in self.resblocks3:
            x = block(x)
        x = self.pooling2(x)
        return x
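
# A hedged shape check for DownSample (the 96x96 observation size and the
# channel counts below are illustrative assumptions, not values read from a
# config in this repo): conv1 and the downsampling residual block each halve
# the spatial size, and the two average poolings halve it twice more, for an
# overall stride of 16.
#
#   net = DownSample(in_channels=12, out_channels=64)
#   x = torch.zeros(1, 12, 96, 96)
#   print(net(x).shape)  # expected: torch.Size([1, 64, 6, 6])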


# Encode the observations into hidden states
class RepresentationNetwork(nn.Module):
    def __init__(self, observation_shape, num_blocks, num_channels, downsample):
        """
        Representation network
        :param observation_shape: tuple or list, shape of observations: [C, W, H]
        :param num_blocks: int, number of res blocks
        :param num_channels: int, channels of hidden states
        :param downsample: bool, True -> downsample the observations (not needed for board games)
        """
        super().__init__()
        self.downsample = downsample
        if self.downsample:
            self.downsample_net = DownSample(
                observation_shape[0],
                num_channels,
            )
        else:
            self.conv = conv3x3(
                observation_shape[0],
                num_channels,
            )
            self.bn = nn.BatchNorm2d(num_channels)
        self.resblocks = nn.ModuleList(
            [ResidualBlock(num_channels, num_channels) for _ in range(num_blocks)]
        )

    def forward(self, x):
        if self.downsample:
            x = self.downsample_net(x)
        else:
            x = self.conv(x)
            x = self.bn(x)
            x = nn.functional.relu(x)
        for block in self.resblocks:
            x = block(x)
        return x
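
# Hedged usage sketch for RepresentationNetwork. The observation shape and
# hyper-parameters below are placeholder assumptions for an Atari-style pixel
# input, not values taken from this repo's configs:
#
#   repr_net = RepresentationNetwork(
#       observation_shape=(12, 96, 96), num_blocks=1,
#       num_channels=64, downsample=True)
#   hidden = repr_net(torch.zeros(8, 12, 96, 96))  # -> (8, 64, 6, 6) hidden state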


# Predict next hidden states given current states and actions
class DynamicsNetwork(nn.Module):
    def __init__(self, num_blocks, num_channels, action_space_size, is_continuous=False, action_embedding=False, action_embedding_dim=32):
        """
        Dynamics network
        :param num_blocks: int, number of res blocks
        :param num_channels: int, channels of hidden states
        :param action_space_size: int, action space size
        :param is_continuous: bool, True -> continuous action space
        :param action_embedding: bool, True -> embed the action planes with a 1x1 conv before concatenation
        :param action_embedding_dim: int, channels of the action embedding
        """
        super().__init__()
        self.is_continuous = is_continuous
        self.action_embedding = action_embedding
        self.action_embedding_dim = action_embedding_dim
        self.num_channels = num_channels
        self.action_space_size = action_space_size
        if action_embedding:
            self.conv1x1 = nn.Conv2d(action_space_size if is_continuous else 1, self.action_embedding_dim, 1)
            self.ln = nn.LayerNorm([action_embedding_dim, 6, 6])
            self.conv = conv3x3(num_channels + self.action_embedding_dim, num_channels)
        else:
            self.conv = conv3x3((num_channels + action_space_size) if is_continuous else (num_channels + 1), num_channels)
        self.bn = nn.BatchNorm2d(num_channels)
        self.resblocks = nn.ModuleList(
            [ResidualBlock(num_channels, num_channels) for _ in range(num_blocks)]
        )

    def forward(self, state, action):
        # encode the action as feature planes with the same spatial size as the state
        if not self.is_continuous:
            # discrete action: a single plane filled with action / action_space_size
            action_place = torch.ones((
                state.shape[0],
                1,
                state.shape[2],
                state.shape[3],
            ), device=state.device).float()
            action_place = (
                action[:, :, None, None] * action_place / self.action_space_size
            )
        else:
            # continuous action: tile each action dimension over the spatial grid
            action_place = action.reshape(*action.shape, 1, 1).repeat(1, 1, state.shape[-2], state.shape[-1])
        if self.action_embedding:
            action_place = self.conv1x1(action_place)
            action_place = self.ln(action_place)
            action_place = nn.functional.relu(action_place)
        x = torch.cat((state, action_place), dim=1)
        x = self.conv(x)
        x = self.bn(x)
        x += state
        x = nn.functional.relu(x)
        for block in self.resblocks:
            x = block(x)
        state = x
        return state
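
# Hedged sketch of how actions enter DynamicsNetwork (batch size, spatial size
# and action dimensionality are illustrative assumptions; note that the
# LayerNorm in the embedding path assumes a 6x6 spatial grid):
#
#   dyn = DynamicsNetwork(num_blocks=1, num_channels=64, action_space_size=6,
#                         is_continuous=True, action_embedding=True)
#   state = torch.zeros(8, 64, 6, 6)
#   action = torch.zeros(8, 6)             # one continuous action per sample
#   next_state = dyn(state, action)        # -> (8, 64, 6, 6)
#
# For a discrete action space, `action` would instead be an integer index of
# shape (8, 1); it is broadcast to a single (8, 1, 6, 6) plane and scaled by
# 1 / action_space_size before concatenation.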


class ValuePolicyNetwork(nn.Module):
    def __init__(self, num_blocks, num_channels, reduced_channels, flatten_size, fc_layers, value_output_size,
                 policy_output_size, init_zero, is_continuous=False, policy_distribution='beta', **kwargs):
        super().__init__()
        self.v_num = kwargs.get('v_num')
        self.resblocks = nn.ModuleList(
            [ResidualBlock(num_channels, num_channels) for _ in range(num_blocks)]
        )
        self.conv1x1_values = nn.ModuleList([nn.Conv2d(num_channels, reduced_channels, 1) for _ in range(self.v_num)])
        self.conv1x1_policy = nn.Conv2d(num_channels, reduced_channels, 1)
        self.bn_values = nn.ModuleList([nn.BatchNorm2d(reduced_channels) for _ in range(self.v_num)])
        self.bn_policy = nn.BatchNorm2d(reduced_channels)
        self.block_output_size_value = flatten_size
        self.block_output_size_policy = flatten_size
        self.fc_values = nn.ModuleList([mlp(self.block_output_size_value, fc_layers, value_output_size,
                                            init_zero=False if is_continuous else init_zero) for _ in range(self.v_num)])
        self.fc_policy = mlp(self.block_output_size_policy, fc_layers if not is_continuous else [64],
                             policy_output_size, init_zero=init_zero)
        self.is_continuous = is_continuous
        self.init_std = 1.0
        self.min_std = 0.1

    def forward(self, x):
        for block in self.resblocks:
            x = block(x)
        values = []
        for i in range(self.v_num):
            value = self.conv1x1_values[i](x)
            value = self.bn_values[i](value)
            value = nn.functional.relu(value)
            value = value.reshape(-1, self.block_output_size_value)
            value = self.fc_values[i](value)
            values.append(value)
        policy = self.conv1x1_policy(x)
        policy = self.bn_policy(policy)
        policy = nn.functional.relu(policy)
        policy = policy.reshape(-1, self.block_output_size_policy)
        policy = self.fc_policy(policy)
        if self.is_continuous:
            action_space_size = policy.shape[-1] // 2
            policy[:, :action_space_size] = 5 * torch.tanh(policy[:, :action_space_size] / 5)  # soft clamp mu
            policy[:, action_space_size:] = (torch.nn.functional.softplus(policy[:, action_space_size:] + self.init_std) + self.min_std)  # .clip(0, 5)  # same as Dreamer-v3
        return torch.stack(values), policy
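
# Hedged note on ValuePolicyNetwork outputs: `values` stacks the v_num value
# heads, and for continuous control `policy` concatenates a soft-clamped mean
# with a softplus standard deviation. A minimal sketch of splitting the policy
# output (`vp_net` and `hidden` are hypothetical names, and the distribution
# actually used elsewhere in this repo may differ):
#
#   values, policy = vp_net(hidden)        # values: (v_num, B, value_output_size)
#   mu, std = policy.chunk(2, dim=-1)      # each (B, action_space_size)
#   dist = torch.distributions.Normal(mu, std)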


class SupportNetwork(nn.Module):
    def __init__(self, num_blocks, num_channels, reduced_channels, flatten_size, fc_layers, output_support_size, init_zero):
        super().__init__()
        self.flatten_size = flatten_size
        self.conv1x1 = nn.Conv2d(num_channels, reduced_channels, 1)
        self.bn = nn.BatchNorm2d(reduced_channels)
        self.fc = mlp(flatten_size, fc_layers, output_support_size, init_zero=init_zero)

    def forward(self, x):
        x = self.conv1x1(x)
        x = self.bn(x)
        x = nn.functional.relu(x)
        x = x.reshape(-1, self.flatten_size)
        x = self.fc(x)
        return x
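
# SupportNetwork maps a hidden state to `output_support_size` logits,
# presumably the categorical (scalar-to-support) representation of values or
# rewards as in MuZero; that reading is an assumption. A hedged decoding
# sketch, assuming a symmetric integer support centered at zero:
#
#   logits = support_net(hidden)                           # (B, output_support_size)
#   probs = torch.softmax(logits, dim=-1)
#   bins = torch.arange(logits.shape[-1]) - logits.shape[-1] // 2
#   scalar = (probs * bins).sum(-1)                        # expected scalar value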


class SupportLSTMNetwork(nn.Module):
    def __init__(self, num_blocks, num_channels, reduced_channels, flatten_size, fc_layers, output_support_size, lstm_hidden_size, init_zero):
        super().__init__()
        self.flatten_size = flatten_size
        self.conv1x1_reward = nn.Conv2d(num_channels, reduced_channels, 1)
        self.bn_reward = nn.BatchNorm2d(reduced_channels)
        self.lstm = nn.LSTM(input_size=flatten_size, hidden_size=lstm_hidden_size)
        self.bn_reward_sum = nn.BatchNorm1d(lstm_hidden_size)
        self.fc = mlp(lstm_hidden_size, fc_layers, output_support_size, init_zero=init_zero)

    def forward(self, x, hidden):
        x = self.conv1x1_reward(x)
        x = self.bn_reward(x)
        x = nn.functional.relu(x)
        x = x.reshape(-1, self.flatten_size).unsqueeze(0)
        x, hidden = self.lstm(x, hidden)
        x = x.squeeze(0)
        x = self.bn_reward_sum(x)
        x = nn.functional.relu(x)
        x = self.fc(x)
        return x, hidden
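
# Hedged usage sketch for SupportLSTMNetwork: the single-layer nn.LSTM expects
# an (h_0, c_0) tuple shaped (num_layers=1, batch, lstm_hidden_size). The
# batch size and hidden size below are illustrative assumptions:
#
#   hidden = (torch.zeros(1, 8, 512), torch.zeros(1, 8, 512))
#   logits, hidden = lstm_net(latent, hidden)   # logits: (8, output_support_size)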


class ProjectionNetwork(nn.Module):
    def __init__(self, input_dim, hid_dim, out_dim):
        super().__init__()
        self.input_dim = input_dim
        self.layer = nn.Sequential(
            nn.Linear(input_dim, hid_dim),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, hid_dim),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, out_dim),
            nn.BatchNorm1d(out_dim)
        )

    def forward(self, x):
        x = x.reshape(-1, self.input_dim)
        return self.layer(x)


class ProjectionHeadNetwork(nn.Module):
    def __init__(self, input_dim, hid_dim, out_dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(input_dim, hid_dim),
            nn.BatchNorm1d(hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, out_dim),
        )

    def forward(self, x):
        return self.layer(x)
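
# ProjectionNetwork and ProjectionHeadNetwork look like a SimSiam-style
# projector / predictor pair for a self-supervised consistency loss (as in
# EfficientZero); this pairing is an assumption. A hedged sketch with
# illustrative dimensions and hypothetical tensor names:
#
#   proj = ProjectionNetwork(input_dim=64 * 6 * 6, hid_dim=1024, out_dim=1024)
#   head = ProjectionHeadNetwork(input_dim=1024, hid_dim=256, out_dim=1024)
#   z_pred = head(proj(predicted_state))        # online branch
#   z_tgt = proj(observed_state).detach()       # target branch, gradient stopped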