EfficientZeroV2/ez/agents/ez_dmc_image.py
# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import time
import copy
import math
from ez.agents.base import Agent
from omegaconf import open_dict
from ez.envs import make_dmc
from ez.utils.format import DiscreteSupport
from ez.agents.models import EfficientZero
from ez.agents.models.base_model import *

class EZDMCImageAgent(Agent):
    """Agent for image-based DeepMind Control (DMC) tasks using EfficientZero."""

    def __init__(self, config):
        super().__init__(config)
        self.update_config()

        # Architecture hyper-parameters read from the model config.
        self.num_blocks = config.model.num_blocks
        self.num_channels = config.model.num_channels
        self.reduced_channels = config.model.reduced_channels
        self.fc_layers = config.model.fc_layers
        self.down_sample = config.model.down_sample
        self.state_norm = config.model.state_norm
        self.value_prefix = config.model.value_prefix
        self.init_zero = config.model.init_zero
        self.action_embedding = config.model.action_embedding
        self.action_embedding_dim = config.model.action_embedding_dim

    def update_config(self):
        """Derive runtime fields (action space, observation shape, supports) and write them back into the config."""
        assert not self._update

        # Build one environment instance to query its action space.
        env = make_dmc(self.config.env.game, seed=0, save_path=None, **self.config.env)
        action_space_size = env.action_space.shape[0]
        obs_channel = 1 if self.config.env.gray_scale else 3

        # Discrete supports used by the categorical reward and value heads.
        reward_support = DiscreteSupport(self.config)
        reward_size = reward_support.size
        self.reward_support = reward_support
        value_support = DiscreteSupport(self.config)
        value_size = value_support.size
        self.value_support = value_support

        localtime = time.strftime('%Y-%m-%d %H:%M:%S')
        tag = '{}-seed={}-{}/'.format(self.config.tag, self.config.env.base_seed, localtime)

        # open_dict temporarily unlocks the struct-mode OmegaConf config so new keys can be written.
        with open_dict(self.config):
            self.config.env.action_space_size = action_space_size
            self.config.env.obs_shape[0] = obs_channel
            self.config.rl.discount **= self.config.env.n_skip
            self.config.model.reward_support.size = reward_size
            self.config.model.value_support.size = value_size
            self.config.save_path += tag

        self.obs_shape = copy.deepcopy(self.config.env.obs_shape)
        self.input_shape = copy.deepcopy(self.config.env.obs_shape)
        self.input_shape[0] *= self.config.env.n_stack  # frame stacking multiplies the channel dimension
        self.action_space_size = self.config.env.action_space_size
        self._update = True
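
    # Worked example (hypothetical numbers, not taken from this repo's configs): with
    # config.rl.discount = 0.99 and config.env.n_skip = 2, the effective per-step discount
    # becomes 0.99 ** 2 = 0.9801, since one agent step covers two skipped frames.
    # Likewise, with gray_scale = False (3 channels) and n_stack = 4, input_shape[0]
    # becomes 3 * 4 = 12 stacked image channels.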

    def build_model(self):
        """Assemble the EfficientZero networks from the configured hyper-parameters."""
        if self.down_sample:
            # The representation network downsamples the image by a factor of 16 in each spatial dimension.
            state_shape = (self.num_channels, math.ceil(self.obs_shape[1] / 16), math.ceil(self.obs_shape[2] / 16))
        else:
            state_shape = (self.num_channels, self.obs_shape[1], self.obs_shape[2])
        state_dim = state_shape[0] * state_shape[1] * state_shape[2]
        flatten_size = self.reduced_channels * state_shape[1] * state_shape[2]
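        # For instance (hypothetical observation size): a 96x96 image with down_sample=True
        # yields spatial dims ceil(96 / 16) = 6, so state_shape is (num_channels, 6, 6) and
        # flatten_size is reduced_channels * 36.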

        # Encoder: maps stacked image observations to the latent state.
        representation_model = RepresentationNetwork(self.input_shape, self.num_blocks, self.num_channels, self.down_sample)

        is_continuous = (self.config.env.env == "DMC")
        value_output_size = self.config.model.value_support.size if self.config.model.value_support.type != 'symlog' else 1

        # Dynamics: predicts the next latent state from the current state and action.
        dynamics_model = DynamicsNetwork(self.num_blocks, self.num_channels, self.action_space_size, is_continuous,
                                         action_embedding=self.config.model.action_embedding,
                                         action_embedding_dim=self.action_embedding_dim)

        # Prediction heads for value and policy; the policy head outputs 2 * action_space_size
        # parameters for the continuous action distribution.
        value_policy_model = ValuePolicyNetwork(self.num_blocks, self.num_channels, self.reduced_channels, flatten_size,
                                                self.fc_layers, value_output_size,
                                                self.action_space_size * 2, self.init_zero, is_continuous,
                                                policy_distribution=self.config.model.policy_distribution,
                                                v_num=self.config.train.v_num)

        # Reward head: an LSTM-based value-prefix predictor, or a plain per-step reward network.
        reward_output_size = self.config.model.reward_support.size if self.config.model.reward_support.type != 'symlog' else 1
        if self.value_prefix:
            reward_prediction_model = SupportLSTMNetwork(0, self.num_channels, self.reduced_channels,
                                                         flatten_size, self.fc_layers, reward_output_size,
                                                         self.config.model.lstm_hidden_size, self.init_zero)
        else:
            reward_prediction_model = SupportNetwork(self.num_blocks, self.num_channels, self.reduced_channels,
                                                     flatten_size, self.fc_layers, reward_output_size,
                                                     self.init_zero)

        # Projection networks for the self-supervised consistency loss.
        projection_layers = self.config.model.projection_layers
        head_layers = self.config.model.prjection_head_layers
        assert projection_layers[1] == head_layers[1]
        projection_model = ProjectionNetwork(state_dim, projection_layers[0], projection_layers[1])
        projection_head_model = ProjectionHeadNetwork(projection_layers[1], head_layers[0], head_layers[1])

        ez_model = EfficientZero(representation_model, dynamics_model, reward_prediction_model, value_policy_model,
                                 projection_model, projection_head_model, self.config,
                                 state_norm=self.state_norm, value_prefix=self.value_prefix)
        return ez_model
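

# A minimal usage sketch (not part of the original module). It assumes the training entry
# point supplies a Hydra/OmegaConf config that already defines the `env`, `model`, `rl`,
# and `train` sections read above; the config path below is hypothetical and only for
# illustration.
if __name__ == '__main__':
    from omegaconf import OmegaConf

    cfg = OmegaConf.load('path/to/dmc_image_config.yaml')  # hypothetical path
    agent = EZDMCImageAgent(cfg)    # __init__ calls update_config() exactly once
    ez_model = agent.build_model()  # assembles the EfficientZero networks described above
    print(ez_model)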