# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import copy
import torch
import torchrl
import numpy as np
import math

from .base import MCTS
# import sys
# sys.path.append('/workspace/EZ-Codebase')
# # import ez.mcts.ctree.cytree as tree
from ez.mcts.ctree import cytree as tree
from ez.mcts.ori_ctree import cytree as ori_tree
from ez.mcts.ctree_v2 import cytree as tree2
from torch.cuda.amp import autocast as autocast
from ez.utils.format import DiscreteSupport, symexp, pad_and_mask
from ez.utils.distribution import SquashedNormal, TruncatedNormal, ContDist
import colorednoise as cn


class CyMCTS(MCTS):
    def __init__(self, num_actions, **kwargs):
        super().__init__(num_actions, **kwargs)
        self.policy_action_num = kwargs.get('policy_action_num')
        self.random_action_num = kwargs.get('random_action_num')
        self.policy_distribution = kwargs.get('policy_distribution')

    def sample_actions(self, policy, add_noise=True, temperature=1.0, input_noises=None, input_dist=None, input_actions=None, sample_nums=None, states=None):
        n_policy = self.policy_action_num
        n_random = self.random_action_num
        if sample_nums:
            n_policy = math.ceil(sample_nums / 2)
            n_random = sample_nums - n_policy
        std_magnification = self.std_magnification
        action_dim = policy.shape[-1] // 2

        if input_dist is not None:
            n_policy //= 2
            n_random //= 2

        Dist = SquashedNormal
        mean, std = policy[:, :action_dim], policy[:, action_dim:]
        distr = Dist(mean, std)
        sampled_actions = distr.sample(torch.Size([n_policy + n_random]))

        policy_actions = sampled_actions[:n_policy]
        random_actions = sampled_actions[-n_random:]

        random_distr = distr
        if add_noise:
            if input_noises is None:
                random_distr = Dist(mean, std_magnification * std)  # flatter Gaussian policy for exploration
                random_actions = random_distr.sample(torch.Size([n_random]))
            else:
                noises = torch.from_numpy(input_noises).float().cuda()
                random_actions += noises

        if input_dist is not None:
            refined_mean, refined_std = input_dist[:, :action_dim], input_dist[:, action_dim:]
            refined_distr = Dist(refined_mean, refined_std)
            refined_actions = refined_distr.sample(torch.Size([n_policy + n_random]))

            refined_policy_actions = refined_actions[:n_policy]
            refined_random_actions = refined_actions[-n_random:]

            if add_noise:
                if input_noises is None:
                    refined_random_distr = Dist(refined_mean, std_magnification * refined_std)
                    refined_random_actions = refined_random_distr.sample(torch.Size([n_random]))
                else:
                    noises = torch.from_numpy(input_noises).float().cuda()
                    refined_random_actions += noises

        all_actions = torch.cat((policy_actions, random_actions), dim=0)
        if input_actions is not None:
            all_actions = torch.from_numpy(input_actions).float().cuda()
        if input_dist is not None:
            all_actions = torch.cat((all_actions, refined_policy_actions, refined_random_actions), dim=0)
        all_actions = all_actions.clip(-0.999, 0.999)

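        # Importance-style weighting for the sampled candidates: a fraction `ratio` of them were
        # drawn from the policy distribution `distr` and the rest from the broadened `random_distr`,
        # so the score below is log pi(a) minus a linear (in log space, hence approximate) mix of
        # the two proposal log-densities.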
        assert (n_policy + n_random) == (sample_nums if sample_nums is not None else self.num_actions)
        ratio = n_policy / (sample_nums if sample_nums is not None else self.num_actions)
        probs = distr.log_prob(all_actions) - (ratio * distr.log_prob(all_actions) + (1 - ratio) * random_distr.log_prob(all_actions))
        probs = probs.sum(-1).permute(1, 0)
        all_actions = all_actions.permute(1, 0, 2)
        return all_actions, probs

    def inv_softmax(self, dist):
        constant = 100
        return np.log(dist) + constant

    def atanh(self, x):
        return 0.5 * (np.log1p(x) - np.log1p(-x))

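    # Numerically stable softmax with an inverse-temperature style scaling: the logits are
    # multiplied by `temperature` and shifted by their per-row maximum before exponentiation,
    # e.g. [1, 2, 3] with temperature=1.0 maps to roughly [0.09, 0.245, 0.665].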
    def softmax_temperature(self, dist, temperature=1.0):
        soft_dist = temperature * dist
        dist_max = soft_dist.max(-1, keepdims=True)
        scores = np.exp(soft_dist - dist_max)
        return scores / scores.sum(-1, keepdims=True)

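    # One-step Q initialisation for each sampled root action: Q(s, a) ~= r(s, a) + discount * V(s'),
    # computed with the learned dynamics/reward/value heads instead of environment rollouts.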
    def q_init(self, states, sampled_actions, model):
        action_num = sampled_actions.shape[1]
        q_inits = []
        for i in range(action_num):
            _, rewards, next_values, _, _ = self.update_statistics(
                prediction=True,  # use model prediction instead of env simulation
                model=model,  # model
                states=states,  # current states
                actions=sampled_actions[:, i],  # last actions
                reward_hidden=None,  # reward hidden
            )
            q_inits.append(rewards + self.discount * next_values)
        return np.asarray(q_inits).swapaxes(0, 1).tolist()

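    # Gumbel-style MCTS over sampled continuous actions:
    #   1. sample a candidate action set at the root from the policy head (plus exploration noise),
    #   2. run `num_simulations` batched traversals through the cytree backend, expanding one leaf
    #      per simulation with `leaf_num` freshly sampled actions,
    #   3. periodically halve the set of root candidates (sequential halving) until the best
    #      sampled action remains.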
    def search_continuous(self, model, batch_size, root_states, root_values, root_policy_logits,
                          use_gumble_noise=False, temperature=1.0, verbose=0, add_noise=True,
                          input_noises=None, input_dist=None, input_actions=None, prev_mean=None, **kwargs):
        # preparation
        # Node.set_static_attributes(self.discount, self.num_actions)  # set static parameters of MCTS
        # set root nodes for the batch
        # root_sampled_actions, policy_priors = self.sample_actions(root_policy_logits, std_mag=3 if add_noise else 1)
        root_sampled_actions, policy_priors = self.sample_actions(root_policy_logits, add_noise, temperature, input_noises, input_dist=input_dist, input_actions=input_actions)

        sampled_action_num = root_sampled_actions.shape[1]
        uniform_policy = [
            [0.0 for _ in range(sampled_action_num)]
            for _ in range(batch_size)
        ]
        leaf_num = 2
        uniform_policy_non_root = [
            [0.0 for _ in range(leaf_num)] for _ in range(batch_size)
        ]

        # set Gumbel noise (during training)
        if use_gumble_noise:
            gumble_noises = np.random.gumbel(0, 1, (batch_size, self.num_actions))  # * temperature
        else:
            gumble_noises = np.zeros((batch_size, self.num_actions))

        gumble_noises = gumble_noises.tolist()

        roots = tree.Roots(batch_size, self.num_actions, self.num_simulations, self.discount)
        roots.prepare(root_values.tolist(), uniform_policy, leaf_num)
        # save the min and max value of the tree nodes
        value_min_max_lst = tree.MinMaxStatsList(batch_size)
        value_min_max_lst.set_static_val(self.value_minmax_delta, self.c_visit, self.c_scale)

        reward_hidden = (torch.zeros(1, batch_size, self.lstm_hidden_size).cuda().float(),
                         torch.zeros(1, batch_size, self.lstm_hidden_size).cuda().float())

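        # The pools are indexed as pool[ix][iy]: `ix` is the expansion step at which a node was
        # created (hidden_state_index_x) and `iy` is the batch element, matching the indices that
        # tree.batch_traverse returns for each selected leaf.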
        # index of states
        state_pool = [root_states]
        hidden_state_index_x = 0
        # 1 x batch x 64
        reward_hidden_c_pool = [reward_hidden[0]]
        reward_hidden_h_pool = [reward_hidden[1]]

        assert batch_size == len(root_states) == len(root_values)
        # expand the roots and update the statistics

        self.verbose = verbose
        if self.verbose:
            np.set_printoptions(precision=3)
            assert batch_size == 1
            self.log('Gumbel Noise: {}'.format(gumble_noises), verbose=1)

        # search for N iterations
        mcts_info = {}
        actions_pool = [root_sampled_actions]

        for simulation_idx in range(self.num_simulations):
            current_states = []
            hidden_states_c_reward = []
            hidden_states_h_reward = []
            results = tree.ResultsWrapper(batch_size)

            self.log('Iteration {} \t'.format(simulation_idx), verbose=2, iteration_begin=True)
            if self.verbose > 1:
                self.log('Tree:', verbose=2)
                roots.print_tree()

            # select action for the roots
            hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions = \
                tree.batch_traverse(roots, value_min_max_lst, results, self.num_simulations, simulation_idx,
                                    gumble_noises, self.current_num_top_actions)

            search_lens = results.get_search_len()
            selected_actions = []
            ptr = 0
            for ix, iy in zip(hidden_state_index_x_lst, hidden_state_index_y_lst):
                current_states.append(state_pool[ix][iy])
                if self.value_prefix:
                    hidden_states_c_reward.append(reward_hidden_c_pool[ix][0][iy])
                    hidden_states_h_reward.append(reward_hidden_h_pool[ix][0][iy])

                selected_actions.append(actions_pool[ix][iy][last_actions[ptr]])
                ptr += 1

            current_states = torch.stack(current_states)
            if self.value_prefix:
                hidden_states_c_reward = torch.stack(hidden_states_c_reward).unsqueeze(0)
                hidden_states_h_reward = torch.stack(hidden_states_h_reward).unsqueeze(0)
            selected_actions = torch.stack(selected_actions)

            # inference state, reward, value, policy given the current state
            reward_hidden = (hidden_states_c_reward, hidden_states_h_reward)
            mcts_info[simulation_idx] = {
                'states': current_states,
                'actions': last_actions,
                'reward_hidden': reward_hidden,
            }

            next_states, next_value_prefixes, next_values, next_logits, reward_hidden = self.update_statistics(
                prediction=True,  # use model prediction instead of env simulation
                model=model,  # model
                states=current_states,  # current states
                actions=selected_actions,  # last actions
                reward_hidden=reward_hidden,  # reward hidden
            )
            mcts_info[simulation_idx] = {
                'next_states': next_states,
                'next_value_prefixes': next_value_prefixes,
                'next_values': next_values,
                'next_logits': next_logits,
                'next_reward_hidden': reward_hidden
            }
            leaf_sampled_actions, leaf_policy_priors = self.sample_actions(next_logits, sample_nums=leaf_num, add_noise=False)
            # leaf_sampled_actions, leaf_policy_priors = self.sample_actions(next_logits, sample_nums=leaf_num)

            actions_pool.append(leaf_sampled_actions)
            # save to database
            state_pool.append(next_states)
            # change value prefix to reward
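            # The value-prefix LSTM accumulates reward along the current search path; its hidden
            # state is zeroed whenever the path length is a multiple of lstm_horizon_len, and the
            # reset flags are passed to the backup so prefixes can be converted back to rewards.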
            if self.value_prefix:
                reset_idx = (np.array(search_lens) % self.lstm_horizon_len == 0)
                reward_hidden[0][:, reset_idx, :] = 0
                reward_hidden[1][:, reset_idx, :] = 0
                reward_hidden_c_pool.append(reward_hidden[0])
                reward_hidden_h_pool.append(reward_hidden[1])
            else:
                reset_idx = np.asarray([1. for _ in range(batch_size)])

            to_reset_lst = reset_idx.astype(np.int32).tolist()
            hidden_state_index_x += 1

            # expand the leaf node and backward for statistics update
            tree.batch_back_propagate(hidden_state_index_x, next_value_prefixes.squeeze(-1).tolist(),
                                      next_values.squeeze(-1).tolist(), uniform_policy_non_root, value_min_max_lst,
                                      results, to_reset_lst, leaf_num)

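            # Sequential halving: when the visit budget of the current phase is exhausted, halve the
            # set of root candidates (the ranking itself lives in tree.batch_sequential_halving) and
            # spend the remaining simulations on the surviving actions.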
            # sequential halving
            if self.ready_for_next_gumble_phase(simulation_idx):
                tree.batch_sequential_halving(roots, gumble_noises, value_min_max_lst, self.current_phase,
                                              self.current_num_top_actions)
                self.log('change to phase: {}, top m action -> {}'
                         ''.format(self.current_phase, self.current_num_top_actions), verbose=3)

        # obtain the final results and infos
        search_root_values = np.asarray(roots.get_values())
        search_root_policies = np.asarray(roots.get_root_policies(value_min_max_lst))

        search_best_actions = np.asarray(roots.get_best_actions())
        root_sampled_actions = root_sampled_actions.detach().cpu().numpy()
        final_selected_actions = np.asarray(
            [root_sampled_actions[i, best_a] for i, best_a in enumerate(search_best_actions)]
        )

        if self.verbose:
            self.log('Final Tree:', verbose=1)
            roots.print_tree()
            self.log('search root value -> \t\t {} \n'
                     'search root policy -> \t\t {} \n'
                     'search best action -> \t\t {}'
                     ''.format(search_root_values[0], search_root_policies[0], search_best_actions[0]),
                     verbose=1, iteration_end=True)

        return search_root_values, search_root_policies, final_selected_actions, root_sampled_actions, search_best_actions, mcts_info

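    # Visit-count based action selection: sampling probabilities are proportional to N(a)^(1/T),
    # e.g. counts [4, 1] with temperature=1 give probabilities [0.8, 0.2]; argmax is used in
    # deterministic (evaluation) mode.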
    def select_action(self, visit_counts, temperature=1, deterministic=False):
        action_probs = visit_counts ** (1 / temperature)
        total_count = action_probs.sum(-1, keepdims=True)
        action_probs = action_probs / total_count
        if deterministic:
            action_pos = action_probs.argmax(-1)
        else:
            action_pos = []
            for i in range(action_probs.shape[0]):
                action_pos.append(np.random.choice(action_probs.shape[1], p=action_probs[i]))
            action_pos = np.asarray(action_pos)
            # action_pos = torch.nn.functional.gumbel_softmax(torch.from_numpy(action_probs), hard=True, dim=1).argmax(-1)

        return action_pos

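    # Original MuZero/EfficientZero-style search for discrete actions: PUCT-style traversal
    # (c_base / c_init exploration constants) with Dirichlet noise mixed into the root prior,
    # followed by visit-count based action selection.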
    def search_ori_mcts(self, model, batch_size, root_states, root_values, root_policy_logits,
                        use_noise=True, temperature=1.0, verbose=0, is_reanalyze=False, **kwargs):
        # preparation
        # set Dirichlet noise (during training)
        if use_noise:
            noises = np.asarray([np.random.dirichlet([self.dirichlet_alpha] * self.num_actions).astype(np.float32).tolist() for _
                                 in range(batch_size)])
        else:
            noises = np.zeros((batch_size, self.num_actions))
        noises = noises.tolist()
        # Node.set_static_attributes(self.discount, self.num_actions)  # set static parameters of MCTS
        # set root nodes for the batch
        roots = ori_tree.Roots(batch_size, self.num_actions, self.num_simulations)
        roots.prepare(self.explore_frac, noises, [0. for _ in range(batch_size)], root_policy_logits.tolist())
        # save the min and max value of the tree nodes
        value_min_max_lst = ori_tree.MinMaxStatsList(batch_size)
        value_min_max_lst.set_delta(self.value_minmax_delta)

        if self.value_prefix:
            reward_hidden = (torch.zeros(1, batch_size, self.lstm_hidden_size).cuda().float(),
                             torch.zeros(1, batch_size, self.lstm_hidden_size).cuda().float())
        else:
            reward_hidden = None

        # index of states
        state_pool = [root_states]
        hidden_state_index_x = 0
        # 1 x batch x 64 (the pools below assume value_prefix is enabled, i.e. reward_hidden is not None)
        reward_hidden_c_pool = [reward_hidden[0]]
        reward_hidden_h_pool = [reward_hidden[1]]

        assert batch_size == len(root_states) == len(root_values)
        # expand the roots and update the statistics

        self.verbose = verbose
        if self.verbose:
            np.set_printoptions(precision=3)
            assert batch_size == 1
            self.log('Dirichlet Noise: {}'.format(noises), verbose=1)

        # search for N iterations
        mcts_info = {}
        for simulation_idx in range(self.num_simulations):
            current_states = []
            hidden_states_c_reward = []
            hidden_states_h_reward = []
            results = ori_tree.ResultsWrapper(batch_size)

            self.log('Iteration {} \t'.format(simulation_idx), verbose=2, iteration_begin=True)
            if self.verbose > 1:
                self.log('Tree:', verbose=2)
                roots.print_tree()

            # select action for the roots
            hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions = ori_tree.batch_traverse(roots, self.c_base, self.c_init, self.discount, value_min_max_lst, results)
            search_lens = results.get_search_len()

            for ix, iy in zip(hidden_state_index_x_lst, hidden_state_index_y_lst):
                current_states.append(state_pool[ix][iy])
                hidden_states_c_reward.append(reward_hidden_c_pool[ix][0][iy])
                hidden_states_h_reward.append(reward_hidden_h_pool[ix][0][iy])

            current_states = torch.stack(current_states)
            hidden_states_c_reward = torch.stack(hidden_states_c_reward).unsqueeze(0)
            hidden_states_h_reward = torch.stack(hidden_states_h_reward).unsqueeze(0)
            last_actions = torch.from_numpy(np.asarray(last_actions)).cuda().long().unsqueeze(1)

            # inference state, reward, value, policy given the current state
            reward_hidden = (hidden_states_c_reward, hidden_states_h_reward)

            next_states, next_value_prefixes, next_values, next_logits, reward_hidden = self.update_statistics(
                prediction=True,  # use model prediction instead of env simulation
                model=model,  # model
                states=current_states,  # current states
                actions=last_actions,  # last actions
                reward_hidden=reward_hidden,  # reward hidden
            )

            # save to database
            state_pool.append(next_states)
            # change value prefix to reward
            if self.value_prefix:
                reset_idx = (np.array(search_lens) % self.lstm_horizon_len == 0)
                reward_hidden[0][:, reset_idx, :] = 0
                reward_hidden[1][:, reset_idx, :] = 0
                reward_hidden_c_pool.append(reward_hidden[0])
                reward_hidden_h_pool.append(reward_hidden[1])
            else:
                reset_idx = np.asarray([1. for _ in range(batch_size)])
            to_reset_lst = reset_idx.astype(np.int32).tolist()

            hidden_state_index_x += 1

            # expand the leaf node and backward for statistics update
            ori_tree.batch_back_propagate(hidden_state_index_x, self.discount, next_value_prefixes.squeeze(-1).tolist(), next_values.squeeze(-1).tolist(), next_logits.tolist(), value_min_max_lst, results, to_reset_lst)

        # obtain the final results and infos
        search_root_values = np.asarray(roots.get_values())
        search_root_policies = np.asarray(roots.get_distributions())
        if not is_reanalyze:
            search_best_actions = self.select_action(search_root_policies, temperature=temperature, deterministic=not use_noise)
        else:
            search_best_actions = np.zeros(batch_size)

        if self.verbose:
            self.log('Final Tree:', verbose=1)
            roots.print_tree()
            self.log('search root value -> \t\t {} \n'
                     'search root policy -> \t\t {} \n'
                     'search best action -> \t\t {}'
                     ''.format(search_root_values[0], search_root_policies[0], search_best_actions[0]),
                     verbose=1, iteration_end=True)

        search_root_policies = search_root_policies / search_root_policies.sum(-1, keepdims=True)
        return search_root_values, search_root_policies, search_best_actions, mcts_info

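    # Gumbel MCTS over the full discrete action set: every root action is a candidate, the prior
    # comes directly from the policy logits, and sequential halving narrows the candidates down to
    # the final best action.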
    def search(self, model, batch_size, root_states, root_values, root_policy_logits,
               use_gumble_noise=True, temperature=1.0, verbose=0, **kwargs):
        # preparation
        # Node.set_static_attributes(self.discount, self.num_actions)  # set static parameters of MCTS
        # set root nodes for the batch
        roots = tree.Roots(batch_size, self.num_actions, self.num_simulations, self.discount)
        roots.prepare(root_values.tolist(), root_policy_logits.tolist(), self.num_actions)
        # save the min and max value of the tree nodes
        value_min_max_lst = tree.MinMaxStatsList(batch_size)
        value_min_max_lst.set_static_val(self.value_minmax_delta, self.c_visit, self.c_scale)

        if self.value_prefix:
            reward_hidden = (torch.zeros(1, batch_size, self.lstm_hidden_size).cuda().float(),
                             torch.zeros(1, batch_size, self.lstm_hidden_size).cuda().float())
        else:
            reward_hidden = None

        # index of states
        state_pool = [root_states]
        hidden_state_index_x = 0
        # 1 x batch x 64 (the pools below assume value_prefix is enabled, i.e. reward_hidden is not None)
        reward_hidden_c_pool = [reward_hidden[0]]
        reward_hidden_h_pool = [reward_hidden[1]]

        # set Gumbel noise (during training)
        if use_gumble_noise:
            gumble_noises = np.random.gumbel(0, 1, (batch_size, self.num_actions))  # * temperature
        else:
            gumble_noises = np.zeros((batch_size, self.num_actions))
        gumble_noises = gumble_noises.tolist()

        assert batch_size == len(root_states) == len(root_values)
        # expand the roots and update the statistics

        self.verbose = verbose
        if self.verbose:
            np.set_printoptions(precision=3)
            assert batch_size == 1
            self.log('Gumbel Noise: {}'.format(gumble_noises), verbose=1)

        # search for N iterations
        mcts_info = {}
        for simulation_idx in range(self.num_simulations):
            current_states = []
            hidden_states_c_reward = []
            hidden_states_h_reward = []
            results = tree.ResultsWrapper(batch_size)
            # results1 = tree2.ResultsWrapper(roots1.num)

            self.log('Iteration {} \t'.format(simulation_idx), verbose=2, iteration_begin=True)
            if self.verbose > 1:
                self.log('Tree:', verbose=2)
                roots.print_tree()

            # select action for the roots
            hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions = \
                tree.batch_traverse(roots, value_min_max_lst, results, self.num_simulations, simulation_idx,
                                    gumble_noises, self.current_num_top_actions)

            search_lens = results.get_search_len()

            for ix, iy in zip(hidden_state_index_x_lst, hidden_state_index_y_lst):
                current_states.append(state_pool[ix][iy])
                hidden_states_c_reward.append(reward_hidden_c_pool[ix][0][iy])
                hidden_states_h_reward.append(reward_hidden_h_pool[ix][0][iy])

            current_states = torch.stack(current_states)
            hidden_states_c_reward = torch.stack(hidden_states_c_reward).unsqueeze(0)
            hidden_states_h_reward = torch.stack(hidden_states_h_reward).unsqueeze(0)
            last_actions = torch.from_numpy(np.asarray(last_actions)).cuda().long().unsqueeze(1)

            # inference state, reward, value, policy given the current state
            reward_hidden = (hidden_states_c_reward, hidden_states_h_reward)
            mcts_info[simulation_idx] = {
                'states': current_states,
                'actions': last_actions,
                'reward_hidden': reward_hidden,
            }

            next_states, next_value_prefixes, next_values, next_logits, reward_hidden = self.update_statistics(
                prediction=True,  # use model prediction instead of env simulation
                model=model,  # model
                states=current_states,  # current states
                actions=last_actions,  # last actions
                reward_hidden=reward_hidden,  # reward hidden
            )
            mcts_info[simulation_idx] = {
                'next_states': next_states,
                'next_value_prefixes': next_value_prefixes,
                'next_values': next_values,
                'next_logits': next_logits,
                'next_reward_hidden': reward_hidden
            }

            # save to database
            state_pool.append(next_states)
            # change value prefix to reward
            reset_idx = (np.array(search_lens) % self.lstm_horizon_len == 0)
            if self.value_prefix:
                reward_hidden[0][:, reset_idx, :] = 0
                reward_hidden[1][:, reset_idx, :] = 0
            to_reset_lst = reset_idx.astype(np.int32).tolist()
            if not self.value_prefix:
                to_reset_lst = [1 for _ in range(batch_size)]

            reward_hidden_c_pool.append(reward_hidden[0])
            reward_hidden_h_pool.append(reward_hidden[1])
            hidden_state_index_x += 1

            # expand the leaf node and backward for statistics update
            tree.batch_back_propagate(hidden_state_index_x, next_value_prefixes.squeeze(-1).tolist(), next_values.squeeze(-1).tolist(), next_logits.tolist(), value_min_max_lst, results, to_reset_lst, self.num_actions)

            # sequential halving
            if self.ready_for_next_gumble_phase(simulation_idx):
                tree.batch_sequential_halving(roots, gumble_noises, value_min_max_lst, self.current_phase,
                                              self.current_num_top_actions)
                # if self.current_phase == 0:
                #     search_root_values = np.asarray(roots.get_values())

                self.log('change to phase: {}, top m action -> {}'
                         ''.format(self.current_phase, self.current_num_top_actions), verbose=3)

        # assert self.ready_for_next_gumble_phase(self.num_simulations)
        # final selection
        # tree.batch_sequential_halving(roots, gumble_noises, value_min_max_lst, self.current_phase, self.current_num_top_actions)

        # obtain the final results and infos
        search_root_values = np.asarray(roots.get_values())
        search_root_policies = np.asarray(roots.get_root_policies(value_min_max_lst))
        search_best_actions = np.asarray(roots.get_best_actions())

        if self.verbose:
            self.log('Final Tree:', verbose=1)
            roots.print_tree()
            self.log('search root value -> \t\t {} \n'
                     'search root policy -> \t\t {} \n'
                     'search best action -> \t\t {}'
                     ''.format(search_root_values[0], search_root_policies[0], search_best_actions[0]),
                     verbose=1, iteration_end=True)

        return search_root_values, search_root_policies, search_best_actions, mcts_info

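    # Sequential-halving schedule: with budget n = num_simulations and m = num_top_actions initial
    # candidates, each phase spends roughly n / log2(m) simulations split evenly over the surviving
    # candidates, and the final phase (2 candidates left) receives whatever budget remains.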
    def ready_for_next_gumble_phase(self, simulation_idx):
        ready = (simulation_idx + 1) >= self.visit_num_for_next_phase
        if ready:
            # change the current top action num from m -> m / 2
            self.current_phase += 1
            self.current_num_top_actions = self.current_num_top_actions // 2
            assert self.current_num_top_actions == self.num_top_actions // (2 ** self.current_phase)

            # update the total visit num for the next phase
            n = self.num_simulations
            m = self.num_top_actions
            current_m = self.current_num_top_actions
            # each remaining action gets ~ n / (log2(m) * current_m) visits in this phase
            if current_m > 2:
                extra_visit = max(np.floor(n / (np.log2(m) * current_m)), 1) * current_m
            else:
                extra_visit = n - self.used_visit_num
            self.used_visit_num += extra_visit
            self.visit_num_for_next_phase += extra_visit
            self.visit_num_for_next_phase = min(self.visit_num_for_next_phase, self.num_simulations)

            self.log('Be ready for the next Gumbel phase at iteration {}: \n'
                     'current top action num is {}, visit {} times for next phase'
                     ''.format(simulation_idx, current_m, self.visit_num_for_next_phase), verbose=3)
        return ready


"""
legacy code of Gumbel search
"""


class Gumbel_MCTS(object):
    def __init__(self, config):
        self.config = config
        self.value_prefix = self.config.model.value_prefix
        self.num_simulations = self.config.mcts.num_simulations
        self.num_top_actions = self.config.mcts.num_top_actions
        self.c_visit = self.config.mcts.c_visit
        self.c_scale = self.config.mcts.c_scale
        self.discount = self.config.rl.discount
        self.value_minmax_delta = self.config.mcts.value_minmax_delta
        self.lstm_hidden_size = self.config.model.lstm_hidden_size
        self.action_space_size = self.config.env.action_space_size
        try:
            self.policy_distribution = self.config.model.policy_distribution
        except:
            pass

    def update_statistics(self, **kwargs):
        if kwargs.get('prediction'):
            # prediction for next states, rewards, values, logits
            model = kwargs.get('model')
            current_states = kwargs.get('states')
            last_actions = kwargs.get('actions')
            reward_hidden = kwargs.get('reward_hidden')

            with torch.no_grad():
                with autocast():
                    next_states, next_value_prefixes, next_values, next_logits, reward_hidden = \
                        model.recurrent_inference(current_states, last_actions, reward_hidden)

            # process outputs
            next_values = next_values.detach().cpu().numpy().flatten()
            next_value_prefixes = next_value_prefixes.detach().cpu().numpy().flatten()
            # if masks is not None:
            #     next_states = next_states[:, -1]
            return next_states, next_value_prefixes, next_values, next_logits, reward_hidden
        else:
            # env simulation for next states
            env = kwargs.get('env')
            current_states = kwargs.get('states')
            last_actions = kwargs.get('actions')
            states = env.step(last_actions)
            raise NotImplementedError()

    def sample_actions(self, policy, add_noise=True, temperature=1.0, input_noises=None, input_dist=None, input_actions=None):
        batch_size = policy.shape[0]
        n_policy = self.config.model.policy_action_num
        n_random = self.config.model.random_action_num
        std_magnification = self.config.mcts.std_magnification
        action_dim = policy.shape[-1] // 2

        if input_dist is not None:
            n_policy //= 2
            n_random //= 2

        Dist = SquashedNormal
        mean, std = policy[:, :action_dim], policy[:, action_dim:]
        distr = Dist(mean, std)
        sampled_actions = distr.sample(torch.Size([n_policy + n_random]))
        sampled_actions = sampled_actions.permute(1, 0, 2)

        policy_actions = sampled_actions[:, :n_policy]
        random_actions = sampled_actions[:, -n_random:]

        if add_noise:
            if input_noises is None:
                # random_distr = Dist(mean, self.std_magnification * std * temperature)  # more flatten gaussian policy
                random_distr = Dist(mean, std_magnification * std)  # flatter Gaussian policy for exploration
                random_actions = random_distr.sample(torch.Size([n_random]))
                random_actions = random_actions.permute(1, 0, 2)

                # random_actions = torch.rand(batch_size, n_random, action_dim).float().cuda()
                # random_actions = 2 * random_actions - 1

                # Gaussian noise
                # random_actions += torch.randn_like(random_actions)
            else:
                noises = torch.from_numpy(input_noises).float().cuda()
                random_actions += noises

        if input_dist is not None:
            refined_mean, refined_std = input_dist[:, :action_dim], input_dist[:, action_dim:]
            refined_distr = Dist(refined_mean, refined_std)
            refined_actions = refined_distr.sample(torch.Size([n_policy + n_random]))
            refined_actions = refined_actions.permute(1, 0, 2)

            refined_policy_actions = refined_actions[:, :n_policy]
            refined_random_actions = refined_actions[:, -n_random:]

            if add_noise:
                if input_noises is None:
                    refined_random_distr = Dist(refined_mean, std_magnification * refined_std)
                    refined_random_actions = refined_random_distr.sample(torch.Size([n_random]))
                    refined_random_actions = refined_random_actions.permute(1, 0, 2)
                else:
                    noises = torch.from_numpy(input_noises).float().cuda()
                    refined_random_actions += noises

        all_actions = torch.cat((policy_actions, random_actions), dim=1)
        if input_actions is not None:
            all_actions = torch.from_numpy(input_actions).float().cuda()
        if input_dist is not None:
            all_actions = torch.cat((all_actions, refined_policy_actions, refined_random_actions), dim=1)
        # all_actions[:, 0, :] = mean  # add mean as one of candidate
        all_actions = all_actions.clip(-0.999, 0.999)

        return all_actions

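    # Legacy Gumbel search for discrete action spaces (ctree_v2 backend): every action is a root
    # candidate, priors come from the raw policy logits, and the final action is taken from
    # roots.get_actions(...) using the drawn Gumbel noise once the simulation budget is spent.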
    @torch.no_grad()
    def run_multi_discrete(
        self, model, batch_size,
        hidden_state_roots, root_values,
        root_policy_logits, temperature=1.0,
        use_gumbel_noise=True
    ):

        model.eval()

        reward_sum_pool = [0. for _ in range(batch_size)]

        roots = tree2.Roots(
            batch_size, self.action_space_size,
            self.num_simulations
        )
        root_policy_logits = root_policy_logits.detach().cpu().numpy()

        roots.prepare(
            reward_sum_pool, root_policy_logits.tolist(),
            self.num_top_actions, self.num_simulations,
            root_values.tolist()
        )

        reward_hidden_roots = (
            torch.from_numpy(np.zeros((1, batch_size, self.lstm_hidden_size))).float().cuda(),
            torch.from_numpy(np.zeros((1, batch_size, self.lstm_hidden_size))).float().cuda()
        )

        gumbels = np.random.gumbel(
            0, 1, (batch_size, self.action_space_size)
        )  # * temperature
        if not use_gumbel_noise:
            gumbels = np.zeros_like(gumbels)
        gumbels = gumbels.tolist()

        num = roots.num
        c_visit, c_scale, discount = self.c_visit, self.c_scale, self.discount
        hidden_state_pool = [hidden_state_roots]
        # 1 x batch x 64
        reward_hidden_c_pool = [reward_hidden_roots[0]]
        reward_hidden_h_pool = [reward_hidden_roots[1]]
        hidden_state_index_x = 0
        min_max_stats_lst = tree2.MinMaxStatsList(num)
        min_max_stats_lst.set_delta(self.value_minmax_delta)
        horizons = self.config.model.lstm_horizon_len

        for index_simulation in range(self.num_simulations):
            hidden_states = []
            hidden_states_c_reward = []
            hidden_states_h_reward = []

            results = tree2.ResultsWrapper(num)
            hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, _, _, _ = \
                tree2.multi_traverse(
                    roots, c_visit, c_scale, discount,
                    min_max_stats_lst, results,
                    index_simulation, gumbels,
                    # int(self.config.model.dynamic_type == 'Transformer')
                    int(False)
                )
            search_lens = results.get_search_len()

            for ix, iy in zip(hidden_state_index_x_lst, hidden_state_index_y_lst):
                hidden_states.append(hidden_state_pool[ix][iy].unsqueeze(0))
                if self.value_prefix:
                    hidden_states_c_reward.append(reward_hidden_c_pool[ix][0][iy].unsqueeze(0))
                    hidden_states_h_reward.append(reward_hidden_h_pool[ix][0][iy].unsqueeze(0))

            hidden_states = torch.cat(hidden_states, dim=0)
            if self.value_prefix:
                hidden_states_c_reward = torch.cat(hidden_states_c_reward).unsqueeze(0)
                hidden_states_h_reward = torch.cat(hidden_states_h_reward).unsqueeze(0)

            last_actions = torch.from_numpy(
                np.asarray(last_actions)
            ).to('cuda').unsqueeze(1).long()

            hidden_state_nodes, reward_sum_pool, value_pool, policy_logits_pool, reward_hidden_nodes = \
                self.update_statistics(
                    prediction=True,  # use model prediction instead of env simulation
                    model=model,  # model
                    states=hidden_states,  # current states
                    actions=last_actions,  # last actions
                    reward_hidden=(hidden_states_c_reward, hidden_states_h_reward),  # reward hidden
                )

            reward_sum_pool = reward_sum_pool.tolist()
            value_pool = value_pool.tolist()
            policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()

            hidden_state_pool.append(hidden_state_nodes)
            # reset 0
            if self.value_prefix:
                if horizons > 0:
                    reset_idx = (np.array(search_lens) % horizons == 0)
                    assert len(reset_idx) == num
                    reward_hidden_nodes[0][:, reset_idx, :] = 0
                    reward_hidden_nodes[1][:, reset_idx, :] = 0
                    is_reset_lst = reset_idx.astype(np.int32).tolist()
                else:
                    is_reset_lst = [0 for _ in range(num)]
            else:
                is_reset_lst = [1 for _ in range(num)]

            if self.value_prefix:
                reward_hidden_c_pool.append(reward_hidden_nodes[0])
                reward_hidden_h_pool.append(reward_hidden_nodes[1])
            hidden_state_index_x += 1

            tree2.multi_back_propagate(
                hidden_state_index_x, discount,
                reward_sum_pool, value_pool, policy_logits_pool,
                min_max_stats_lst, results, is_reset_lst,
                index_simulation, gumbels, c_visit, c_scale, self.num_simulations
            )

        root_values = np.asarray(roots.get_values())
        pi_primes = np.asarray(roots.get_pi_primes(
            min_max_stats_lst, c_visit, c_scale, discount
        ))
        best_actions = np.asarray(roots.get_actions(
            min_max_stats_lst, c_visit, c_scale, gumbels, discount
        ))
        root_sampled_actions = np.expand_dims(
            np.arange(self.action_space_size), axis=0
        ).repeat(batch_size, axis=0)

        advantages = np.asarray(roots.get_advantages(discount))

        worst_actions = np.asarray(pi_primes).argmin(-1)

        # import ipdb
        # ipdb.set_trace()
        # if best_actions[0] != np.asarray(pi_primes)[0].argmax():
        #     import ipdb
        #     ipdb.set_trace()
        #     print(f'best_actions={best_actions[0]}, largest_i={np.asarray(pi_primes)[0].argmax()}, pi={pi_primes[0]}')

        return root_values, pi_primes, best_actions, \
            min_max_stats_lst.get_min_max(), root_sampled_actions

    def run_multi_continuous(
        self, model, batch_size,
        hidden_state_roots, root_values,
        root_policy_logits, is_reanalyze=False, cnt=-1, temperature=1.0, add_noise=True, use_gumbel_noise=False,
        input_noises=None, input_actions=None
    ):
        with torch.no_grad():
            model.eval()

            reward_sum_pool = [0. for _ in range(batch_size)]
            action_pool = []
            reward_hidden_roots = (
                torch.from_numpy(np.zeros((1, batch_size, self.lstm_hidden_size))).float().cuda(),
                torch.from_numpy(np.zeros((1, batch_size, self.lstm_hidden_size))).float().cuda()
            )
            root_sampled_actions = self.sample_actions(root_policy_logits, add_noise, temperature, input_noises, input_actions=input_actions)
            sampled_action_num = root_sampled_actions.shape[1]

            roots = tree2.Roots(
                batch_size, sampled_action_num, self.num_simulations
            )
            action_pool.append(root_sampled_actions)

            uniform_policy = [
                # [1 / sampled_action_num for _ in range(sampled_action_num)]
                [0.0 for _ in range(sampled_action_num)]
                for _ in range(batch_size)
            ]
            q_inits = uniform_policy
            assert self.num_top_actions == self.config.model.policy_action_num + self.config.model.random_action_num
            # assert np.array(uniform_policy).shape == np.array(eval_policy).shape

            # roots.prepare_q_init(
            #     reward_sum_pool,
            #     uniform_policy,
            #     # eval_policy,
            #     self.num_top_actions,
            #     self.num_simulations,
            #     root_values.tolist(),
            #     # q_inits.tolist()
            #     q_inits
            # )
            roots.prepare(
                reward_sum_pool,
                uniform_policy,
                self.num_top_actions,
                self.num_simulations,
                root_values.tolist()
            )

            gumbels = np.random.gumbel(
                0, 1, (batch_size, sampled_action_num)
            ) * temperature
            if not use_gumbel_noise:
                gumbels = np.zeros_like(gumbels)
            gumbels = gumbels.tolist()

            num = roots.num
            c_visit, c_scale, discount = self.c_visit, self.c_scale, self.discount
            hidden_state_pool = [hidden_state_roots]
            # 1 x batch x 64
            reward_hidden_c_pool = [reward_hidden_roots[0]]
            reward_hidden_h_pool = [reward_hidden_roots[1]]
            hidden_state_index_x = 0
            min_max_stats_lst = tree2.MinMaxStatsList(num)
            min_max_stats_lst.set_delta(self.value_minmax_delta)
            horizons = self.config.model.lstm_horizon_len

            actions_pool = [root_sampled_actions]
            for index_simulation in range(self.num_simulations):
                hidden_states = []
                states_hidden_c = []
                states_hidden_h = []
                hidden_states_c_reward = []
                hidden_states_h_reward = []

                results = tree2.ResultsWrapper(num)
                hidden_state_index_x_lst, hidden_state_index_y_lst, last_actions, _, _, _ = \
                    tree2.multi_traverse(roots, c_visit, c_scale, discount, min_max_stats_lst,
                                         results, index_simulation, gumbels, int(self.config.model.dynamic_type == 'Transformer'))
                search_lens = results.get_search_len()

                ptr = 0
                selected_actions = []
                for ix, iy in zip(hidden_state_index_x_lst, hidden_state_index_y_lst):
                    hidden_states.append(hidden_state_pool[ix][iy].unsqueeze(0))
                    if self.value_prefix:
                        hidden_states_c_reward.append(reward_hidden_c_pool[ix][0][iy].unsqueeze(0))
                        hidden_states_h_reward.append(reward_hidden_h_pool[ix][0][iy].unsqueeze(0))
                    selected_actions.append(
                        actions_pool[ix][iy][last_actions[ptr]].unsqueeze(0)
                    )
                    ptr += 1

                hidden_states = torch.cat(hidden_states, dim=0).float()
                if self.value_prefix:
                    hidden_states_c_reward = torch.cat(hidden_states_c_reward, dim=0).unsqueeze(0)
                    hidden_states_h_reward = torch.cat(hidden_states_h_reward, dim=0).unsqueeze(0)

                selected_actions = torch.cat(selected_actions, dim=0).float()
                hidden_state_nodes, reward_sum_pool, value_pool, policy_logits_pool, reward_hidden_nodes = self.update_statistics(
                    prediction=True,  # use model prediction instead of env simulation
                    model=model,  # model
                    states=hidden_states,  # current states
                    actions=selected_actions,  # last actions
                    reward_hidden=(hidden_states_c_reward, hidden_states_h_reward),  # reward hidden
                )

                leaf_sampled_actions = self.sample_actions(policy_logits_pool, False, input_actions=input_actions)
                actions_pool.append(leaf_sampled_actions)
                reward_sum_pool = reward_sum_pool.tolist()
                value_pool = value_pool.tolist()

                hidden_state_pool.append(hidden_state_nodes)
                # reset 0
                if self.value_prefix:
                    if horizons > 0:
                        reset_idx = (np.array(search_lens) % horizons == 0)
                        assert len(reset_idx) == num
                        reward_hidden_nodes[0][:, reset_idx, :] = 0
                        reward_hidden_nodes[1][:, reset_idx, :] = 0
                        is_reset_lst = reset_idx.astype(np.int32).tolist()
                    else:
                        is_reset_lst = [0 for _ in range(num)]
                else:
                    is_reset_lst = [1 for _ in range(num)]  # TODO: this is a huge bug, previously 0.

                if self.value_prefix:
                    reward_hidden_c_pool.append(reward_hidden_nodes[0])
                    reward_hidden_h_pool.append(reward_hidden_nodes[1])
                hidden_state_index_x += 1

                tree2.multi_back_propagate(
                    hidden_state_index_x, discount,
                    reward_sum_pool, value_pool,
                    uniform_policy,
                    min_max_stats_lst, results, is_reset_lst,
                    index_simulation, gumbels, c_visit, c_scale, self.num_simulations
                )

            best_actions = roots.get_actions(min_max_stats_lst, c_visit, c_scale, gumbels, discount)
            root_sampled_actions = root_sampled_actions.detach().cpu().numpy()
            final_selected_actions = np.asarray(
                [root_sampled_actions[i, best_a] for i, best_a in enumerate(best_actions)]
            )
            advantages = np.asarray(roots.get_advantages(discount))

            # pi_prime = roots.get_pi_primes(min_max_stats_lst, c_visit, c_scale, discount)
            # if best_actions[0] != np.asarray(pi_prime)[0].argmax():
            #     import ipdb
            #     ipdb.set_trace()
            #     print(f'best_actions={best_actions[0]}, largest_i={np.asarray(pi_prime)[0].argmax()}, pi={pi_prime[0]}')

            return np.asarray(roots.get_values()), \
                np.asarray(roots.get_pi_primes(min_max_stats_lst, c_visit, c_scale, discount)), \
                np.asarray(final_selected_actions), min_max_stats_lst.get_min_max(), \
                np.asarray(root_sampled_actions), np.asarray(best_actions)
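

# Illustrative helper, not used by the classes above: it sketches the per-phase visit budget of
# sequential halving, mirroring the formula in CyMCTS.ready_for_next_gumble_phase. The first
# phase's budget is assumed here to follow the same formula; the in-class version updates its
# counters incrementally during the search rather than precomputing the whole schedule.
def _halving_schedule_example(n, m):
    """Return (candidate_count, visits_in_phase) pairs for a budget of n simulations and m initial candidates."""
    schedule, used, current_m = [], 0, m
    while current_m >= 2:
        if current_m > 2:
            # each surviving action gets ~ n / (log2(m) * current_m) visits in this phase
            extra = int(max(np.floor(n / (np.log2(m) * current_m)), 1) * current_m)
        else:
            # the last two candidates share whatever budget is left
            extra = n - used
        schedule.append((current_m, extra))
        used += extra
        current_m //= 2
    return schedule
# e.g. _halving_schedule_example(50, 16) -> [(16, 16), (8, 8), (4, 12), (2, 14)]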