# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import copy
import torch
from torch.cuda.amp import autocast as autocast
from ez.utils.format import DiscreteSupport
import numpy as np


class MCTS:
    def __init__(self, num_actions, **kwargs):
        """
        Base class for Monte-Carlo tree search over a learned model.
        :param num_actions: number of candidate actions the search chooses from.
        :param kwargs: search configuration (num_simulations, num_top_actions,
            exploration constants, discount, model options; see the attributes below).
        """
        self.num_actions = num_actions

        self.num_simulations = kwargs.get('num_simulations')
        self.num_top_actions = kwargs.get('num_top_actions')
        self.c_visit = kwargs.get('c_visit')
        self.c_scale = kwargs.get('c_scale')
        self.c_base = kwargs.get('c_base')
        self.c_init = kwargs.get('c_init')
        self.dirichlet_alpha = kwargs.get('dirichlet_alpha')
        self.explore_frac = kwargs.get('explore_frac')
        self.discount = kwargs.get('discount')
        self.value_minmax_delta = kwargs.get('value_minmax_delta')
        self.value_support = kwargs.get('value_support')
        self.reward_support = kwargs.get('reward_support')
        self.value_prefix = kwargs.get('value_prefix')
        self.lstm_hidden_size = kwargs.get('lstm_hidden_size')
        self.lstm_horizon_len = kwargs.get('lstm_horizon_len')
        self.mpc_horizon = kwargs.get('mpc_horizon')
        self.env = kwargs.get('env')
        self.vis = kwargs.get('vis')  # vis: [log, text, graph]
        self.std_magnification = kwargs.get('std_magnification')

        self.current_num_top_actions = self.num_top_actions  # /2 every phase
        self.current_phase = 0  # current phase index
        self.visit_num_for_next_phase = max(
            np.floor(self.num_simulations / (np.log2(self.num_top_actions) * self.current_num_top_actions)), 1) \
            * self.current_num_top_actions  # how many visit counts for next phase
        self.used_visit_num = 0
        self.verbose = 0
        assert self.num_top_actions <= self.num_actions

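    # Worked example (illustrative; not part of the original code): with
    # num_simulations = 32 and num_top_actions = 4, the first-phase visit budget is
    #   max(floor(32 / (log2(4) * 4)), 1) * 4 = 4 * 4 = 16
    # i.e. 16 of the 32 simulations are spent before the candidate set is
    # halved from 4 to 2 actions ("/2 every phase" above).
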
    def search(self, model, batch_size, root_states, root_values, root_policy_logits, **kwargs):
        # implemented by concrete subclasses
        raise NotImplementedError()

    def sample_mpc_actions(self, policy):
        is_continuous = (self.env in ['DMC', 'Gym'])
        if is_continuous:
            # continuous control: the network output packs distribution parameters
            # per action dimension; the first half is taken as the action mean
            action_dim = policy.shape[-1] // 2
            mean = policy[:, :action_dim]
            return mean
        else:
            # discrete control: take the greedy action index
            return policy.argmax(dim=-1).unsqueeze(1)

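    # Shape example (an assumption for illustration, not from the original file):
    # a 6-dimensional DMC action space would give policy.shape[-1] == 12, so
    # sample_mpc_actions() returns mean = policy[:, :6]; a discrete policy of
    # shape (B, num_actions) instead yields an index tensor of shape (B, 1).
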
    def update_statistics(self, **kwargs):
        if kwargs.get('prediction'):
            # prediction for next states, rewards, values, logits
            model = kwargs.get('model')
            states = kwargs.get('states')
            last_actions = kwargs.get('actions')
            reward_hidden = kwargs.get('reward_hidden')

            next_value_prefixes = 0
            for _ in range(self.mpc_horizon):
                with torch.no_grad():
                    with autocast():
                        states, pred_value_prefixes, next_values, next_logits, reward_hidden = \
                            model.recurrent_inference(states, last_actions, reward_hidden)
                # last_actions = self.sample_mpc_actions(next_logits)
                next_value_prefixes += pred_value_prefixes

            # process outputs
            next_value_prefixes = next_value_prefixes.detach().cpu().numpy()
            next_values = next_values.detach().cpu().numpy()

            self.log('simulate action {}, r = {:.3f}, v = {:.3f}, logits = {}'
                     ''.format(last_actions[0].tolist(), next_value_prefixes[0].item(),
                               next_values[0].item(), next_logits[0].tolist()),
                     verbose=3)
            return states, next_value_prefixes, next_values, next_logits, reward_hidden
        else:
            # env simulation for next states
            env = kwargs.get('env')
            current_states = kwargs.get('states')
            last_actions = kwargs.get('actions')
            states = env.step(last_actions)
            raise NotImplementedError()

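    # Illustrative call pattern (a sketch, not from the original file): a concrete
    # subclass's search() would typically expand nodes along the lines of
    #
    #   states, value_prefixes, values, logits, reward_hidden = \
    #       self.update_statistics(prediction=True, model=model,
    #                              states=parent_states,
    #                              actions=selected_actions,
    #                              reward_hidden=reward_hidden)
    #
    # The keyword names match those read by update_statistics(); `parent_states`
    # and `selected_actions` are hypothetical variable names.
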
    def estimate_value(self, **kwargs):
        # prediction for value in planning
        model = kwargs.get('model')
        current_states = kwargs.get('states')
        actions = kwargs.get('actions')
        reward_hidden = kwargs.get('reward_hidden')

        Value = 0
        discount = 1
        for i in range(actions.shape[0]):
            with torch.no_grad():
                with autocast():
                    next_states, next_value_prefixes, next_values, next_logits, reward_hidden = \
                        model.recurrent_inference(current_states, actions[i], reward_hidden)

            next_value_prefixes = next_value_prefixes.detach()
            next_values = next_values.detach()
            current_states = next_states
            Value += next_value_prefixes * discount
            discount *= self.discount

        Value += discount * next_values

        return Value

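    # The loop in estimate_value() computes an H-step model-based return, with
    # H = actions.shape[0] and gamma = self.discount:
    #
    #   Value = sum_{i=0}^{H-1} gamma^i * r_hat_i  +  gamma^H * v_hat_H
    #
    # where r_hat_i is the i-th predicted value prefix and v_hat_H the predicted
    # value of the final imagined state.
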
    def log(self, string, verbose, iteration_begin=False, iteration_end=False):
        if verbose <= self.verbose:
            if iteration_begin:
                print('>' * 50)
            print(string)
            print('-' * 20)
            if iteration_end:
                print('<' * 50)

    def reset(self):
        self.current_num_top_actions = self.num_top_actions
        self.current_phase = 0
        self.visit_num_for_next_phase = max(
            np.floor(self.num_simulations / (np.log2(self.num_top_actions) * self.current_num_top_actions)), 1) \
            * self.current_num_top_actions  # how many visit counts for next phase
        self.used_visit_num = 0
        self.verbose = 0
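

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not part of the original module).
# The configuration values are hypothetical placeholders (real values come
# from the project's configs), and search() must be provided by a subclass.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    mcts = MCTS(
        num_actions=6,
        num_simulations=32,
        num_top_actions=4,
        c_visit=50,
        c_scale=0.1,
        c_base=19652,
        c_init=1.25,
        dirichlet_alpha=0.3,
        explore_frac=0.25,
        discount=0.997,
        value_minmax_delta=0.01,
        value_support=None,
        reward_support=None,
        value_prefix=True,
        lstm_hidden_size=512,
        lstm_horizon_len=5,
        mpc_horizon=1,
        env='Atari',
        vis=None,
        std_magnification=3,
    )
    mcts.reset()
    # phase budget computed in __init__: max(floor(32 / (log2(4) * 4)), 1) * 4 = 16.0
    print(mcts.visit_num_for_next_phase)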