EfficientZeroV2/ez/agents/ez_dmc_state.py

# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import time
import copy
import math
import torch
import torchrl
import torch.nn as nn
from ez.agents.base import Agent
from omegaconf import open_dict
import numpy as np
from ez.envs import make_dmc
from ez.utils.format import DiscreteSupport
from ez.agents.models import EfficientZero
def mlp(
input_size,
layer_sizes,
output_size,
output_activation=nn.Identity,
activation=nn.ReLU,
init_zero=False,
use_bn=True,
p_norm=False,
noisy=False
):
"""MLP layers
Parameters
----------
input_size: int
dim of inputs
layer_sizes: list
dim of hidden layers
output_size: int
dim of outputs
init_zero: bool
zero initialization for the last layer (including w and b).
This can provide stable zero outputs in the beginning.
"""
sizes = [input_size] + layer_sizes + [output_size]
layers = []
for i in range(len(sizes) - 1):
if i < len(sizes) - 2:
act = activation
if use_bn:
layers += [
torchrl.modules.NoisyLinear(sizes[i], sizes[i + 1], std_init=0.5) if noisy else nn.Linear(sizes[i], sizes[i + 1]),
nn.BatchNorm1d(sizes[i + 1]),
act()
]
else:
layers += [torchrl.modules.NoisyLinear(sizes[i], sizes[i + 1], std_init=0.5) if noisy else nn.Linear(sizes[i], sizes[i + 1]),
act()]
else:
            if p_norm:
layers += [PNorm()]
act = output_activation
layers += [torchrl.modules.NoisyLinear(sizes[i], sizes[i + 1], std_init=0.5) if noisy else nn.Linear(sizes[i], sizes[i + 1]),
act()]
if init_zero:
if noisy:
layers[-2].reset_parameters()
else:
layers[-2].weight.data.fill_(0)
layers[-2].bias.data.fill_(0)
return nn.Sequential(*layers)
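# Usage sketch (hypothetical sizes, not taken from the config): a two-hidden-layer head over a
# 100-d latent producing 51 categorical logits; init_zero=True zeroes the final layer so the head
# starts from stable zero outputs. Note that BatchNorm1d needs a batch size > 1 in train mode.
#   head = mlp(input_size=100, layer_sizes=[128, 128], output_size=51, init_zero=True, use_bn=True)
#   logits = head(torch.randn(8, 100))   # -> shape (8, 51)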
# improved residual block from Pre-LN Transformer
# ref: http://proceedings.mlr.press/v119/xiong20b/xiong20b.pdf
class ImproveResidualBlock(nn.Module):
def __init__(self, input_shape, hidden_shape):
super(ImproveResidualBlock, self).__init__()
self.ln1 = nn.LayerNorm(input_shape)
self.linear1 = nn.Linear(input_shape, hidden_shape)
self.linear2 = nn.Linear(hidden_shape, input_shape)
def forward(self, x):
identity = x
out = self.ln1(x)
out = self.linear1(out)
out = nn.functional.relu(out)
out = self.linear2(out)
out += identity
return out
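# Usage sketch (hypothetical sizes): the pre-LN residual block preserves the input width, so blocks
# can be stacked freely; here a 100-d state passes through a block with a 256-d hidden layer.
#   block = ImproveResidualBlock(input_shape=100, hidden_shape=256)
#   out = block(torch.randn(8, 100))   # -> shape (8, 100), residual added back onto the input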
# L2 norm layer on the next to last layer
class PNorm(nn.Module):
def __init__(self, eps=1e-10):
super().__init__()
self.eps = eps
def forward(self, x):
assert len(x.shape) == 2
return nn.functional.normalize(x, dim=1, eps=self.eps)
class RunningMeanStd(nn.Module):
def __init__(self, shape, epsilon=1e-5, momentum=0.1):
super(RunningMeanStd, self).__init__()
self.epsilon = epsilon
self.momentum = momentum
self.count = 1e3
self.register_buffer('running_mean', torch.zeros(shape))
self.register_buffer('running_var', torch.ones(shape))
def forward(self, x):
if self.training:
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
batch_count = x.shape[0]
self.running_mean, self.running_var, self.count = self.update_mean_var_count_from_moments(self.running_mean, self.running_var, self.count, mean, var, batch_count)
global_mean = self.running_mean
global_var = self.running_var
else:
global_mean = self.running_mean
global_var = self.running_var
x = (x - global_mean) / torch.sqrt(global_var + self.epsilon)
return x
def update_mean_var_count_from_moments(self, mean, var, count, batch_mean, batch_var, batch_count):
"""Updates the mean, var and count using the previous mean, var, count and batch values."""
delta = batch_mean - mean
tot_count = count + batch_count
new_mean = mean + delta * batch_count / tot_count
m_a = var * count
m_b = batch_var * batch_count
M2 = m_a + m_b + torch.square(delta) * count * batch_count / tot_count
new_var = M2 / tot_count
new_count = tot_count
return new_mean, new_var, new_count
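# Usage sketch (hypothetical sizes): in train mode the module folds the batch statistics into its
# running estimates (the standard parallel mean/variance combination) before normalizing; in eval
# mode it normalizes with the frozen running mean/var only.
#   rms = RunningMeanStd(shape=17)
#   y = rms(torch.randn(32, 17))   # train mode: updates running_mean / running_var, then normalizes
#   rms.eval()
#   y = rms(torch.randn(32, 17))   # eval mode: normalizes with the stored statistics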
# Encode the observations into hidden states
class RepresentationNetwork(nn.Module):
def __init__(
self,
observation_shape,
n_stacked_obs,
num_blocks,
rep_net_shape,
hidden_shape,
use_bn=True,
):
"""Representation network
Parameters
----------
observation_shape: tuple or list
shape of observations: [C, W, H]
n_stacked_obs: int
number of stacked observation
num_blocks: int
number of res blocks
rep_net_shape: int
shape of hidden layers
hidden_shape:
dim of output hidden state
use_bn: bool
True -> Batch normalization
"""
super().__init__()
self.running_mean_std = RunningMeanStd(observation_shape * n_stacked_obs)
self.mlp = nn.Linear(observation_shape * n_stacked_obs, hidden_shape)
self.ln = nn.LayerNorm(hidden_shape)
self.Rep_resblocks = nn.ModuleList(
[ImproveResidualBlock(hidden_shape, rep_net_shape) for _ in range(num_blocks)]
)
def forward(self, x):
x = self.running_mean_std(x)
x = self.mlp(x)
x = self.ln(x)
x = torch.tanh(x)
# res block
for block in self.Rep_resblocks:
x = block(x)
return x
def get_param_mean(self):
mean = []
for name, param in self.named_parameters():
mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist()
mean = sum(mean) / len(mean)
return mean
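# Usage sketch (hypothetical sizes): encode a batch of flattened, stacked DMC state observations.
# The input width is observation_shape * n_stacked_obs, i.e. the stacked observations are expected
# to be concatenated along the feature axis before the call.
#   rep = RepresentationNetwork(observation_shape=24, n_stacked_obs=4, num_blocks=2,
#                               rep_net_shape=256, hidden_shape=100)
#   h = rep(torch.randn(8, 24 * 4))   # -> shape (8, 100)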
# Predict next hidden states given current states and actions
class DynamicsNetwork(nn.Module):
def __init__(
self,
hidden_shape,
action_shape,
num_blocks,
dyn_shape,
act_embed_shape,
rew_net_shape,
reward_support_size,
init_zero=False,
use_bn=True,
):
"""Dynamics network
Parameters
----------
hidden_shape: int
dim of input hidden state
action_shape: int
dim of action
num_blocks: int
number of res blocks
dyn_shape: int
number of nodes of hidden layer
act_embed_shape: int
dim of action embedding
rew_net_shape: list
hidden layers of the reward prediction head (MLP head)
reward_support_size: int
dim of reward output
init_zero: bool
True -> zero initialization for the last layer of reward mlp
use_bn: bool
True -> Batch normalization
"""
super().__init__()
self.hidden_shape = hidden_shape
self.act_linear1 = nn.Linear(action_shape, act_embed_shape)
self.act_ln1 = nn.LayerNorm(act_embed_shape)
self.dyn_ln_1 = nn.LayerNorm(hidden_shape + act_embed_shape)
self.dyn_net_1 = nn.Linear(hidden_shape + act_embed_shape, dyn_shape)
self.dyn_ln_2 = nn.LayerNorm(dyn_shape)
self.dyn_net_2 = nn.Linear(dyn_shape, hidden_shape)
if num_blocks > 0:
self.dyn_resblocks = nn.ModuleList(
[ImproveResidualBlock(hidden_shape, dyn_shape) for _ in range(num_blocks)]
)
else:
self.dyn_resblocks = nn.ModuleList([])
def forward(self, hidden, action, reward_hidden=None):
# action embedding
act_emb = self.act_linear1(action)
act_emb = self.act_ln1(act_emb)
act_emb = nn.functional.relu(act_emb)
# act_emb = nn.functional.tanh(act_emb)
        # improved res block (1st)
x = self.dyn_ln_1(torch.cat((hidden, act_emb), dim=-1))
x = self.dyn_net_1(x)
x = nn.functional.relu(x)
x = self.dyn_net_2(x)
state = hidden + x
# residual tower for dynamic model (2nd -> num blocks)
for block in self.dyn_resblocks:
state = block(state)
# return state
return state
def get_dynamic_mean(self):
mean = []
for name, param in self.dyn_net_1.named_parameters():
mean += np.abs(param.detach().cpu().numpy().reshape(-1)).tolist()
mean = sum(mean) / len(mean)
return mean
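# Usage sketch (hypothetical sizes): roll the latent state forward by one step. The action is
# embedded, concatenated with the state, pushed through a pre-LN block, and added back onto the
# input state as a residual; any extra residual blocks then refine the result.
#   dyn = DynamicsNetwork(hidden_shape=100, action_shape=6, num_blocks=1, dyn_shape=256,
#                         act_embed_shape=64, rew_net_shape=[128], reward_support_size=51)
#   next_h = dyn(torch.randn(8, 100), torch.randn(8, 6))   # -> shape (8, 100)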
class RewardNetwork(nn.Module):
def __init__(
self,
hidden_shape,
rew_net_shape,
reward_support_size,
init_zero=False,
use_bn=True,
):
super().__init__()
self.hidden_shape = hidden_shape
self.rew_net_shape = rew_net_shape
self.reward_support_size = reward_support_size
self.rew_resblock = ImproveResidualBlock(self.hidden_shape, self.hidden_shape)
self.ln = nn.LayerNorm(self.hidden_shape)
self.rew_net = mlp(self.hidden_shape, self.rew_net_shape, self.reward_support_size,
init_zero=init_zero,
use_bn=use_bn)
def forward(self, next_state):
next_state = self.rew_resblock(next_state)
next_state = self.ln(next_state)
reward = self.rew_net(next_state)
return reward
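# Usage sketch (hypothetical sizes): predict categorical reward logits from the next latent state;
# this head is used when value_prefix is disabled.
#   rew = RewardNetwork(hidden_shape=100, rew_net_shape=[128], reward_support_size=51)
#   r_logits = rew(torch.randn(8, 100))   # -> shape (8, 51)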
class RewardNetworkLSTM(nn.Module):
def __init__(
self,
hidden_shape,
rew_net_shape,
reward_support_size,
lstm_hidden_size,
init_zero=False,
use_bn=True,
):
super().__init__()
self.hidden_shape = hidden_shape
self.rew_net_shape = rew_net_shape
self.reward_support_size = reward_support_size
self.rew_resblock = ImproveResidualBlock(self.hidden_shape, self.hidden_shape)
self.ln = nn.LayerNorm(self.hidden_shape)
self.lstm = nn.LSTM(input_size=self.hidden_shape, hidden_size=lstm_hidden_size)
self.rew_net = mlp(lstm_hidden_size, self.rew_net_shape, self.reward_support_size,
init_zero=init_zero,
use_bn=use_bn)
def forward(self, next_state, hidden):
next_state = self.rew_resblock(next_state)
next_state = self.ln(next_state)
next_state = next_state.unsqueeze(0)
reward, hidden = self.lstm(next_state, hidden)
reward = reward.squeeze(0)
reward = self.rew_net(reward)
return reward, hidden
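# Usage sketch (hypothetical sizes): with value_prefix enabled, the reward head is an LSTM that
# accumulates the value prefix across unrolled steps. `hidden` is the usual (h, c) LSTM state,
# each of shape (1, batch, lstm_hidden_size), threaded through consecutive calls during the unroll.
#   rew = RewardNetworkLSTM(hidden_shape=100, rew_net_shape=[128], reward_support_size=51,
#                           lstm_hidden_size=512)
#   hc = (torch.zeros(1, 8, 512), torch.zeros(1, 8, 512))
#   r_logits, hc = rew(torch.randn(8, 100), hc)   # -> r_logits shape (8, 51)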
# Predict the value and policy given hidden states
class ValuePolicyNetwork(nn.Module):
def __init__(
self,
hidden_shape,
val_net_shape,
pi_net_shape,
action_shape,
full_support_size,
init_zero=False,
use_bn=True,
p_norm=False,
policy_distr='squashed_gaussian',
noisy=False,
value_support=None,
**kwargs
):
super().__init__()
self.v_num = kwargs.get('v_num')
self.hidden_shape = hidden_shape
self.val_net_shape = val_net_shape
self.action_shape = action_shape
self.pi_net_shape = pi_net_shape
self.action_space_size = action_shape
self.init_std = 1.0
self.min_std = 0.1
self.policy_distr = policy_distr
self.val_resblock = ImproveResidualBlock(hidden_shape, hidden_shape)
self.pi_resblock = ImproveResidualBlock(hidden_shape, hidden_shape)
self.val_ln = nn.LayerNorm(hidden_shape)
self.pi_ln = nn.LayerNorm(hidden_shape)
self.val_nets = nn.ModuleList([
mlp(self.hidden_shape, self.val_net_shape, full_support_size,
# init_zero=init_zero,
use_bn=use_bn,
)
for _ in range(self.v_num)])
self.pi_net = mlp(self.hidden_shape, self.pi_net_shape, self.action_shape * 2,
init_zero=init_zero,
use_bn=use_bn,
p_norm=p_norm,
noisy=noisy)
self.noisy = noisy
self.value_support = value_support
def reset_noise(self):
if self.noisy:
for layer in self.pi_net:
                # only NoisyLinear layers expose reset_noise(); skip plain layers
                if hasattr(layer, 'reset_noise'):
                    layer.reset_noise()
def forward(self, x):
value = self.val_resblock(x)
value = self.val_ln(value)
values = []
for val_net in self.val_nets:
values.append(val_net(value))
values = torch.stack(values)
policy = self.pi_resblock(x)
policy = self.pi_ln(policy)
policy = self.pi_net(policy)
action_space_size = policy.shape[-1] // 2
if self.policy_distr == 'squashed_gaussian':
policy[:, :action_space_size] = 5 * torch.tanh(policy[:, :action_space_size] / 5) # soft clamp mu
policy[:, action_space_size:] = torch.nn.functional.softplus(
policy[:, action_space_size:] + self.init_std) + self.min_std # same as Dreamer-v3
return values, policy
def log_std(self, x, low, dif):
return low + 0.5 * dif * (torch.tanh(x) + 1)
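# Usage sketch (hypothetical sizes): the head returns an ensemble of value logits (one set per value
# network) plus the policy parameters. For the squashed-Gaussian policy the first half of the policy
# vector is a soft-clamped mean and the second half a softplus-shifted std, as in DreamerV3.
#   vp = ValuePolicyNetwork(hidden_shape=100, val_net_shape=[128], pi_net_shape=[128],
#                           action_shape=6, full_support_size=51, v_num=2)
#   values, policy = vp(torch.randn(8, 100))   # values: (2, 8, 51), policy: (8, 12)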
class EZDMCStateAgent(Agent):
def __init__(self, config):
super().__init__(config)
self.update_config()
self.num_blocks = config.model.num_blocks
self.fc_layers = config.model.fc_layers
self.down_sample = config.model.down_sample
self.state_norm = config.model.state_norm
self.value_prefix = config.model.value_prefix
self.init_zero = config.model.init_zero
self.value_ensumble = config.model.value_ensumble
self.value_policy_detach = config.train.value_policy_detach
self.v_num = config.train.v_num
def update_config(self):
assert not self._update
env = make_dmc(self.config.env.game, seed=0, save_path=None, **self.config.env)
obs = env.reset()
obs_shape = obs.shape[0]
action_space_size = env.action_space.shape[0]
reward_support = DiscreteSupport(self.config)
reward_size = reward_support.size
self.reward_support = reward_support
value_support = DiscreteSupport(self.config)
value_size = value_support.size
self.value_support = value_support
localtime = time.strftime('%Y-%m-%d %H:%M:%S')
tag = '{}-seed={}-{}/'.format(self.config.tag, self.config.env.base_seed, localtime)
with open_dict(self.config):
self.config.env.action_space_size = action_space_size
self.config.env.obs_shape = obs_shape
self.config.rl.discount **= self.config.env.n_skip
self.config.model.reward_support.size = reward_size
self.config.model.value_support.size = value_size
self.config.save_path += tag
self.action_space_size = self.config.env.action_space_size
self.obs_shape = self.config.env.obs_shape
self.n_stack = self.config.env.n_stack
self.rep_net_shape = self.config.model.rep_net_shape
self.hidden_shape = self.config.model.hidden_shape
self.dyn_shape = self.config.model.dyn_shape
self.act_embed_shape = self.config.model.act_embed_shape
self.rew_net_shape = self.config.model.rew_net_shape
self.val_net_shape = self.config.model.val_net_shape
self.pi_net_shape = self.config.model.pi_net_shape
self.proj_hid_shape = self.config.model.proj_hid_shape
self.pred_hid_shape = self.config.model.pred_hid_shape
self.proj_shape = self.config.model.proj_shape
self.pred_shape = self.config.model.pred_shape
self._update = True
self.use_bn = self.config.model.use_bn
self.use_p_norm = self.config.model.use_p_norm
self.noisy_net = self.config.model.noisy_net
def build_model(self):
representation_model = RepresentationNetwork(self.obs_shape, self.n_stack, self.num_blocks,
self.rep_net_shape, self.hidden_shape,
use_bn=self.use_bn)
value_output_size = self.config.model.value_support.size if self.config.model.value_support.type != 'symlog' else 1
reward_output_size = self.config.model.reward_support.size if self.config.model.reward_support.type != 'symlog' else 1
dynamics_model = DynamicsNetwork(self.hidden_shape, self.action_space_size, self.num_blocks, self.dyn_shape,
self.act_embed_shape, self.rew_net_shape, reward_output_size,
use_bn=self.use_bn)
value_policy_model = ValuePolicyNetwork(self.hidden_shape, self.val_net_shape, self.pi_net_shape,
self.action_space_size, value_output_size,
init_zero=self.init_zero, use_bn=self.use_bn, p_norm=self.use_p_norm,
policy_distr=self.config.model.policy_distribution, noisy=self.noisy_net,
value_support=self.config.model.value_support,
v_num=self.v_num)
if self.config.model.value_prefix:
reward_prediction_model = RewardNetworkLSTM(self.hidden_shape, self.rew_net_shape, reward_output_size, self.config.model.lstm_hidden_size,
init_zero=self.init_zero, use_bn=self.use_bn)
else:
reward_prediction_model = RewardNetwork(self.hidden_shape, self.rew_net_shape, reward_output_size,
init_zero=self.init_zero, use_bn=self.use_bn)
projection_model = nn.Sequential(
nn.Linear(self.hidden_shape, self.proj_hid_shape),
nn.LayerNorm(self.proj_hid_shape),
nn.ReLU(),
nn.Linear(self.proj_hid_shape, self.proj_hid_shape),
nn.LayerNorm(self.proj_hid_shape),
nn.ReLU(),
nn.Linear(self.proj_hid_shape, self.proj_shape),
nn.LayerNorm(self.proj_shape)
)
projection_head_model = nn.Sequential(
nn.Linear(self.proj_shape, self.pred_hid_shape),
nn.LayerNorm(self.pred_hid_shape),
nn.ReLU(),
nn.Linear(self.pred_hid_shape, self.pred_shape),
)
ez_model = EfficientZero(representation_model, dynamics_model, reward_prediction_model, value_policy_model,
projection_model, projection_head_model, self.config,
state_norm=self.state_norm, value_prefix=self.value_prefix)
return ez_model
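# Usage sketch (assumes a complete Hydra/OmegaConf config as consumed by the EfficientZeroV2 entry
# scripts; no field values are shown here because they come from the experiment config):
#   agent = EZDMCStateAgent(cfg)   # update_config() probes a DMC env for obs/action dims and supports
#   model = agent.build_model()    # returns an EfficientZero model wired for state-based inputs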