# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import time
import copy
import math

from omegaconf import open_dict

from ez.agents.base import Agent
from ez.envs import make_dmc
from ez.utils.format import DiscreteSupport
from ez.agents.models import EfficientZero
from ez.agents.models.base_model import *

class EZDMCImageAgent(Agent):
    """EfficientZero agent for image-based DeepMind Control (DMC) tasks."""

    def __init__(self, config):
        super().__init__(config)

        # Resolve env-dependent fields (action space, obs shape, supports, ...)
        # before caching the model hyper-parameters below.
        self.update_config()

        self.num_blocks = config.model.num_blocks
        self.num_channels = config.model.num_channels
        self.reduced_channels = config.model.reduced_channels
        self.fc_layers = config.model.fc_layers
        self.down_sample = config.model.down_sample
        self.state_norm = config.model.state_norm
        self.value_prefix = config.model.value_prefix
        self.init_zero = config.model.init_zero
        self.action_embedding = config.model.action_embedding
        self.action_embedding_dim = config.model.action_embedding_dim

    def update_config(self):
        """Fill in env-dependent settings and write them back into the config."""
        assert not self._update

        # Build a throwaway env instance to read the continuous action dimension.
        env = make_dmc(self.config.env.game, seed=0, save_path=None, **self.config.env)
        action_space_size = env.action_space.shape[0]

        obs_channel = 1 if self.config.env.gray_scale else 3

        reward_support = DiscreteSupport(self.config)
        reward_size = reward_support.size
        self.reward_support = reward_support

        value_support = DiscreteSupport(self.config)
        value_size = value_support.size
        self.value_support = value_support

        localtime = time.strftime('%Y-%m-%d %H:%M:%S')
        tag = '{}-seed={}-{}/'.format(self.config.tag, self.config.env.base_seed, localtime)

        with open_dict(self.config):
            self.config.env.action_space_size = action_space_size
            self.config.env.obs_shape[0] = obs_channel
            # One agent step spans n_skip env steps (action repeat), so raise the discount accordingly.
            self.config.rl.discount **= self.config.env.n_skip
            self.config.model.reward_support.size = reward_size
            self.config.model.value_support.size = value_size

            self.config.save_path += tag

        self.obs_shape = copy.deepcopy(self.config.env.obs_shape)
        self.input_shape = copy.deepcopy(self.config.env.obs_shape)
        self.input_shape[0] *= self.config.env.n_stack  # frame stacking enlarges the channel dimension
        self.action_space_size = self.config.env.action_space_size

        self._update = True

    def build_model(self):
        """Assemble the EfficientZero model from its sub-networks."""
        if self.down_sample:
            # The representation network downsamples the observation by a factor of 16.
            state_shape = (self.num_channels, math.ceil(self.obs_shape[1] / 16), math.ceil(self.obs_shape[2] / 16))
        else:
            state_shape = (self.num_channels, self.obs_shape[1], self.obs_shape[2])

        state_dim = state_shape[0] * state_shape[1] * state_shape[2]
        flatten_size = self.reduced_channels * state_shape[1] * state_shape[2]

        representation_model = RepresentationNetwork(self.input_shape, self.num_blocks, self.num_channels, self.down_sample)
        is_continuous = (self.config.env.env == 'DMC')
        # With a symlog transform the value head regresses a scalar instead of a categorical support.
        value_output_size = self.config.model.value_support.size if self.config.model.value_support.type != 'symlog' else 1
        dynamics_model = DynamicsNetwork(self.num_blocks, self.num_channels, self.action_space_size, is_continuous,
                                         action_embedding=self.config.model.action_embedding,
                                         action_embedding_dim=self.action_embedding_dim)
        # The policy head outputs two parameters per action dimension (e.g. mean and scale),
        # hence action_space_size * 2.
        value_policy_model = ValuePolicyNetwork(self.num_blocks, self.num_channels, self.reduced_channels, flatten_size,
                                                self.fc_layers, value_output_size,
                                                self.action_space_size * 2, self.init_zero, is_continuous,
                                                policy_distribution=self.config.model.policy_distribution,
                                                v_num=self.config.train.v_num)

        reward_output_size = self.config.model.reward_support.size if self.config.model.reward_support.type != 'symlog' else 1
        if self.value_prefix:
            # Predict the value prefix (cumulative reward over the unrolled steps) with an LSTM head.
            reward_prediction_model = SupportLSTMNetwork(0, self.num_channels, self.reduced_channels,
                                                         flatten_size, self.fc_layers, reward_output_size,
                                                         self.config.model.lstm_hidden_size, self.init_zero)
        else:
            reward_prediction_model = SupportNetwork(self.num_blocks, self.num_channels, self.reduced_channels,
                                                     flatten_size, self.fc_layers, reward_output_size,
                                                     self.init_zero)

        # Projection networks for the SimSiam-style self-supervised consistency loss.
        projection_layers = self.config.model.projection_layers
        head_layers = self.config.model.prjection_head_layers
        assert projection_layers[1] == head_layers[1]

        projection_model = ProjectionNetwork(state_dim, projection_layers[0], projection_layers[1])
        projection_head_model = ProjectionHeadNetwork(projection_layers[1], head_layers[0], head_layers[1])

        ez_model = EfficientZero(representation_model, dynamics_model, reward_prediction_model, value_policy_model,
                                 projection_model, projection_head_model, self.config,
                                 state_norm=self.state_norm, value_prefix=self.value_prefix)

        return ez_model
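

# Minimal usage sketch (illustrative only): constructs the agent and assembles its model.
# The config path below is a hypothetical placeholder -- any OmegaConf config exposing the
# fields read above (config.model.*, config.env.*, config.rl.*, config.train.*,
# config.tag, config.save_path) would work.
if __name__ == '__main__':
    from omegaconf import OmegaConf

    config = OmegaConf.load('config/dmc_image.yaml')  # hypothetical config file
    agent = EZDMCImageAgent(config)  # __init__ resolves env-dependent fields via update_config()
    model = agent.build_model()      # assembled EfficientZero module
    print(model)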