# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import numpy as np
import torch.nn as nn

from ez.utils.format import normalize_state
from ez.utils.format import formalize_obs_lst, DiscreteSupport, allocate_gpu, prepare_obs_lst, symexp

class EfficientZero(nn.Module):
    def __init__(self,
                 representation_model,
                 dynamics_model,
                 reward_prediction_model,
                 value_policy_model,
                 projection_model,
                 projection_head_model,
                 config,
                 **kwargs,
                 ):
        """The basic models in EfficientZero

        Parameters
        ----------
        representation_model: nn.Module
            encode the raw observations into the hidden state
        dynamics_model: nn.Module
            predict the next hidden state given the current state and action
        reward_prediction_model: nn.Module
            predict the reward (or value prefix) given the next state (namely, current state + action)
        value_policy_model: nn.Module
            predict the value and the policy given the state
        projection_model: nn.Module
            project the hidden state for the self-supervised consistency loss
        projection_head_model: nn.Module
            prediction head applied on top of the projection (gradient branch only)
        config: object
            the experiment configuration
        kwargs: dict
            state_norm: bool
                use state normalization for the encoded state
            value_prefix: bool
                predict the value prefix instead of the per-step reward
        """
        super().__init__()

        self.representation_model = representation_model
        self.dynamics_model = dynamics_model
        self.reward_prediction_model = reward_prediction_model
        self.value_policy_model = value_policy_model
        self.projection_model = projection_model
        self.projection_head_model = projection_head_model
        self.config = config
        self.state_norm = kwargs.get('state_norm')
        self.value_prefix = kwargs.get('value_prefix')
        self.v_num = config.train.v_num  # number of value heads in the value ensemble

    def do_representation(self, obs):
        state = self.representation_model(obs)
        if self.state_norm:
            state = normalize_state(state)

        return state

    def do_dynamics(self, state, action):
        next_state = self.dynamics_model(state, action)
        if self.state_norm:
            next_state = normalize_state(next_state)

        return next_state

    def do_reward_prediction(self, next_state, reward_hidden=None):
        # use the predicted state (namely, current state + action) for reward prediction
        if self.value_prefix:
            # the value-prefix head is recurrent, so it also consumes and returns its hidden state
            value_prefix, reward_hidden = self.reward_prediction_model(next_state, reward_hidden)
            return value_prefix, reward_hidden
        else:
            reward = self.reward_prediction_model(next_state)
            return reward, None

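    # When `value_prefix` is enabled, the head predicts the accumulated reward
    # from the start of the unrolled segment (the "value prefix" of
    # EfficientZero) rather than a single per-step reward, which is why a
    # recurrent hidden state is threaded through the call above.
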
    def do_value_policy_prediction(self, state):
        value, policy = self.value_policy_model(state)
        return value, policy

    def do_projection(self, state, with_grad=True):
        # only the proj + pred branch propagates gradients
        proj = self.projection_model(state)

        # with grad, apply the prediction head on top of the projection
        if with_grad:
            proj = self.projection_head_model(proj)
            return proj
        else:
            # stop-gradient branch: detach the target projection
            return proj.detach()

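    # A typical use of the two branches above is a SimSiam-style consistency
    # loss between the predicted next state and the encoded ground-truth next
    # observation. The sketch below is illustrative only; the actual loss lives
    # in the training code and its exact form may differ:
    #
    #   online = model.do_projection(predicted_state, with_grad=True)   # proj + head, keeps gradients
    #   target = model.do_projection(encoded_next_obs, with_grad=False) # proj only, detached
    #   loss = -torch.nn.functional.cosine_similarity(online, target, dim=-1).mean()
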
    def initial_inference(self, obs, training=False):
        state = self.do_representation(obs)
        values, policy = self.do_value_policy_prediction(state)

        if training:
            # during training, return the raw value logits for loss computation
            return state, values, policy

        # at inference, decode the value support and take the min over two value
        # heads (randomly sampled if there are more than two)
        if self.v_num > 2:
            values = values[np.random.choice(self.v_num, 2, replace=False)]
        if self.config.model.value_support.type == 'symlog':
            output_values = symexp(values).min(0)[0]
        else:
            output_values = DiscreteSupport.vector_to_scalar(values, **self.config.model.value_support).min(0)[0]

        if self.config.env.env in ['DMC', 'Gym']:
            output_values = output_values.clip(0, 1e5)

        return state, output_values, policy

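    # Note on value decoding: `symexp` is assumed here to be the usual inverse
    # of the symlog transform, symexp(x) = sign(x) * (exp(|x|) - 1), while
    # DiscreteSupport.vector_to_scalar converts the categorical support
    # representation back into a scalar. Taking the element-wise minimum over
    # two value heads gives a conservative value estimate, in the spirit of
    # clipped double Q-learning.
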
    def recurrent_inference(self, state, action, reward_hidden, training=False):
        next_state = self.do_dynamics(state, action)
        value_prefix, reward_hidden = self.do_reward_prediction(next_state, reward_hidden)
        values, policy = self.do_value_policy_prediction(next_state)
        if training:
            # during training, return the raw value and reward logits for loss computation
            return next_state, value_prefix, values, policy, reward_hidden

        # at inference, decode the value support and take the min over two value
        # heads (randomly sampled if there are more than two)
        if self.v_num > 2:
            values = values[np.random.choice(self.v_num, 2, replace=False)]
        if self.config.model.value_support.type == 'symlog':
            output_values = symexp(values).min(0)[0]
        else:
            output_values = DiscreteSupport.vector_to_scalar(values, **self.config.model.value_support).min(0)[0]

        if self.config.env.env in ['DMC', 'Gym']:
            output_values = output_values.clip(0, 1e5)

        # decode the reward / value-prefix support as well
        if self.config.model.reward_support.type == 'symlog':
            value_prefix = symexp(value_prefix)
        else:
            value_prefix = DiscreteSupport.vector_to_scalar(value_prefix, **self.config.model.reward_support)
        return next_state, value_prefix, output_values, policy, reward_hidden

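    # During search or an unrolled training step, `initial_inference` is called
    # once on the root observation and `recurrent_inference` is then applied
    # repeatedly on latent states. A minimal, hypothetical unroll sketch:
    #
    #   state, value, policy = model.initial_inference(obs)
    #   reward_hidden = None  # or the zero-initialized recurrent state used by the repo
    #   for action in action_sequence:
    #       state, value_prefix, value, policy, reward_hidden = \
    #           model.recurrent_inference(state, action, reward_hidden)
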
    def get_weights(self, part='none'):
        # optionally export only the reward head; otherwise export the full model
        if part == 'reward':
            weights = self.reward_prediction_model.state_dict()
        else:
            weights = self.state_dict()

        return {k: v.cpu() for k, v in weights.items()}

    def set_weights(self, weights):
        self.load_state_dict(weights)

    def get_gradients(self):
        grads = []
        for p in self.parameters():
            grad = None if p.grad is None else p.grad.data.cpu().numpy()
            grads.append(grad)
        return grads

    def set_gradients(self, gradients):
        for g, p in zip(gradients, self.parameters()):
            if g is not None:
                p.grad = torch.from_numpy(g)
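

if __name__ == "__main__":
    # Minimal smoke test with hypothetical stand-in modules. The real
    # sub-networks and config live elsewhere in the repo; the shapes, action
    # count, and config fields below are assumptions for illustration only
    # (training=True is used so no support decoding is required).
    from types import SimpleNamespace

    class _Identity(nn.Module):
        # passes its first input through unchanged, ignoring extra arguments
        def forward(self, x, *args):
            return x

    class _ValuePolicyStub(nn.Module):
        # returns zero value logits for two heads and zero policy logits
        def forward(self, state):
            batch = state.shape[0]
            value = torch.zeros(2, batch, 51)   # assumed: (v_num, batch, support size)
            policy = torch.zeros(batch, 6)      # assumed: 6 discrete actions
            return value, policy

    config = SimpleNamespace(
        train=SimpleNamespace(v_num=2),
        model=SimpleNamespace(),
        env=SimpleNamespace(env='Atari'),
    )
    model = EfficientZero(
        representation_model=_Identity(),
        dynamics_model=_Identity(),
        reward_prediction_model=_Identity(),
        value_policy_model=_ValuePolicyStub(),
        projection_model=_Identity(),
        projection_head_model=_Identity(),
        config=config,
        state_norm=False,
        value_prefix=False,
    )
    obs = torch.zeros(4, 3, 96, 96)
    state, values, policy = model.initial_inference(obs, training=True)
    print(state.shape, values.shape, policy.shape)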