# EfficientZeroV2/ez/worker/batch_worker.py

# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import os
import time
import ray
import torch
import copy
import gym
import imageio
from PIL import Image, ImageDraw
import numpy as np
from torch.cuda.amp import autocast
import torch.nn.functional as F
from .base import Worker
from ez import mcts
from ez.envs import make_envs
from ez.utils.distribution import SquashedNormal, TruncatedNormal, ContDist
from ez.utils.format import formalize_obs_lst, DiscreteSupport, LinearSchedule, prepare_obs_lst, allocate_gpu, profile, symexp
from ez.data.trajectory import GameTrajectory
from ez.mcts.cy_mcts import Gumbel_MCTS
@ray.remote(num_gpus=0.03)
# @ray.remote(num_gpus=0.14)
class BatchWorker(Worker):
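    """
    Reanalyze worker: samples transitions from the replay buffer, refreshes
    value and policy targets with a recent model (re-running MCTS where
    needed), and pushes ready-to-train batches into the batch storage.
    """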
def __init__(self, rank, agent, replay_buffer, storage, batch_storage, config):
super().__init__(rank, agent, replay_buffer, storage, config)
self.model_update_interval = config.train.reanalyze_update_interval
self.batch_storage = batch_storage
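        # anneal the PER importance-sampling exponent beta towards 1.0 over training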
self.beta_schedule = LinearSchedule(self.total_steps, initial_p=config.priority.priority_prob_beta, final_p=1.0)
self.total_transitions = self.config.data.total_transitions
self.auto_td_steps = self.config.rl.auto_td_steps
self.td_steps = self.config.rl.td_steps
self.unroll_steps = self.config.rl.unroll_steps
self.n_stack = self.config.env.n_stack
self.discount = self.config.rl.discount
self.value_support = self.config.model.value_support
self.action_space_size = self.config.env.action_space_size
self.batch_size = self.config.train.batch_size
self.PER_alpha = self.config.priority.priority_prob_alpha
self.env = self.config.env.env
self.image_based = self.config.env.image_based
self.reanalyze_ratio = self.config.train.reanalyze_ratio
self.value_target = self.config.train.value_target
self.value_target_type = self.config.model.value_target
self.GAE_max_steps = self.config.model.GAE_max_steps
self.episodic = self.config.env.episodic
self.value_prefix = self.config.model.value_prefix
self.lstm_horizon_len = self.config.model.lstm_horizon_len
self.training_steps = self.config.train.training_steps
self.td_lambda = self.config.rl.td_lambda
self.gray_scale = self.config.env.gray_scale
self.obs_shape = self.config.env.obs_shape
self.trajectory_size = self.config.data.trajectory_size
self.mixed_value_threshold = self.config.train.mixed_value_threshold
self.lstm_hidden_size = self.config.model.lstm_hidden_size
        self.cnt = 0
        self.last_log_index = -1  # last value-analysis logging bucket (see prepare_policy_reanalyze)
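    # Rebuild GameTrajectory objects from the raw per-field arrays returned by
    # the replay buffer (observations, rewards, policies, actions, and the
    # predicted / searched / bootstrapped values).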
def concat_trajs(self, items):
obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts, \
bootstrapped_value_lsts = items
traj_lst = []
for obs_lst, reward_lst, policy_lst, action_lst, pred_value_lst, search_value_lst, bootstrapped_value_lst in \
zip(obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts, bootstrapped_value_lsts):
# traj = GameTrajectory(**self.config.env, **self.config.rl, **self.config.model)
traj = GameTrajectory(
n_stack=self.n_stack, discount=self.discount, gray_scale=self.gray_scale, unroll_steps=self.unroll_steps,
td_steps=self.td_steps, td_lambda=self.td_lambda, obs_shape=self.obs_shape, max_size=self.trajectory_size,
image_based=self.image_based, episodic=self.episodic, GAE_max_steps=self.GAE_max_steps
)
traj.obs_lst = obs_lst
traj.reward_lst = reward_lst
traj.policy_lst = policy_lst
traj.action_lst = action_lst
traj.pred_value_lst = pred_value_lst
traj.search_value_lst = search_value_lst
traj.bootstrapped_value_lst = bootstrapped_value_lst
traj_lst.append(traj)
return traj_lst
def run(self):
trained_steps = 0
        # create the models used for reanalysis: one with recent reanalyze weights, one with the latest weights
self.model = self.agent.build_model()
self.latest_model = self.agent.build_model()
if self.config.eval.analysis_value:
weights = torch.load(self.config.eval.model_path)
self.model.load_state_dict(weights)
print('analysis begin')
self.model.cuda()
self.latest_model.cuda()
        if int(torch.__version__.split('.')[0]) >= 2:
self.model = torch.compile(self.model)
self.latest_model = torch.compile(self.latest_model)
self.model.eval()
self.latest_model.eval()
self.resume_model()
        # wait for the start-training signal
while not ray.get(self.storage.get_start_signal.remote()):
time.sleep(0.5)
        # main loop: keep building batches until training finishes
prev_trained_steps = -10
while not self.is_finished(trained_steps):
trained_steps = ray.get(self.storage.get_counter.remote())
if self.config.ray.single_process:
if trained_steps <= prev_trained_steps:
time.sleep(0.1)
continue
prev_trained_steps = trained_steps
print(f'reanalyze[{self.rank}] makes batch at step {trained_steps}')
# get the fresh model weights
self.get_recent_model(trained_steps, 'reanalyze')
self.get_latest_model(trained_steps, 'latest')
# if self.config.model.noisy_net:
# self.model.value_policy_model.reset_noise()
start_time = time.time()
ray_time = self.make_batch(trained_steps, self.cnt)
self.cnt += 1
end_time = time.time()
# if self.cnt % 100 == 0:
# print(f'make batch time={end_time-start_time:.3f}s, ray get time={ray_time:.3f}s')
# @torch.no_grad()
# @profile
def make_batch(self, trained_steps, cnt, real_time=False):
beta = self.beta_schedule.value(trained_steps)
batch_size = self.batch_size
# obtain the batch context from replay buffer
x = time.time()
batch_context = ray.get(
self.replay_buffer.prepare_batch_context.remote(batch_size=batch_size,
alpha=self.PER_alpha,
beta=beta,
rank=self.rank,
cnt=cnt)
)
batch_context, validation_flag = batch_context
ray_time = time.time() - x
traj_lst, transition_pos_lst, indices_lst, weights_lst, make_time_lst, transition_num, prior_lst = batch_context
traj_lst = self.concat_trajs(traj_lst)
        # decide how many policy targets will be reanalyzed
reanalyze_batch_size = batch_size if self.env in ['DMC', 'Gym'] \
else int(batch_size * self.config.train.reanalyze_ratio)
assert 0 <= reanalyze_batch_size <= batch_size
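        # continuous-control envs (DMC/Gym) always reanalyze the full batch;
        # for Atari only a reanalyze_ratio fraction of targets is refreshed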
# ==============================================================================================================
# make inputs
# ==============================================================================================================
collected_transitions = ray.get(self.replay_buffer.get_transition_num.remote())
        # make observations, actions and masks; mask = 0 marks unroll steps that
        # fall beyond the end of the trajectory (their actions are padding)
obs_lst, action_lst, mask_lst = [], [], []
top_new_masks = []
# prepare the inputs of a batch
for i in range(batch_size):
traj = traj_lst[i]
state_index = transition_pos_lst[i]
sample_idx = indices_lst[i]
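            # flag samples drawn from the newest mixed_value_threshold
            # transitions; these are eligible for mixed value targets downstream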
top_new_masks.append(int(sample_idx > collected_transitions - self.mixed_value_threshold))
if self.env in ['DMC', 'Gym']:
_actions = traj.action_lst[state_index:state_index + self.unroll_steps]
_unroll_actions = traj.action_lst[state_index + 1:state_index + 1 + self.unroll_steps]
# _unroll_actions = traj.action_lst[state_index:state_index + self.unroll_steps]
_mask = [1. for _ in range(_unroll_actions.shape[0])]
_mask += [0. for _ in range(self.unroll_steps - len(_mask))]
_rand_actions = np.zeros((self.unroll_steps - _actions.shape[0], self.action_space_size))
_actions = np.concatenate((_actions, _rand_actions), axis=0)
else:
_actions = traj.action_lst[state_index:state_index + self.unroll_steps].tolist()
_mask = [1. for _ in range(len(_actions))]
_mask += [0. for _ in range(self.unroll_steps - len(_mask))]
_actions += [np.random.randint(0, self.action_space_size) for _ in range(self.unroll_steps - len(_actions))]
# obtain the input observations
obs_lst.append(traj.get_index_stacked_obs(state_index, padding=True))
action_lst.append(_actions)
mask_lst.append(_mask)
obs_lst = prepare_obs_lst(obs_lst, self.image_based)
inputs_batch = [obs_lst, action_lst, mask_lst, indices_lst, weights_lst, make_time_lst, prior_lst]
for i in range(len(inputs_batch)):
inputs_batch[i] = np.asarray(inputs_batch[i])
# ==============================================================================================================
# make targets
# ==============================================================================================================
if self.value_target in ['sarsa', 'mixed', 'max']:
if self.value_target_type == 'GAE':
prepare_func = self.prepare_reward_value_gae
elif self.value_target_type == 'bootstrapped':
prepare_func = self.prepare_reward_value
else:
raise NotImplementedError
elif self.value_target == 'search':
prepare_func = self.prepare_reward
else:
raise NotImplementedError
# obtain the value prefix (reward), and the value
batch_value_prefixes, batch_values, td_steps, pre_calc, value_masks = \
prepare_func(traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps)
        # obtain the reanalyzed policy
if reanalyze_batch_size > 0:
batch_policies_re, sampled_actions, best_actions, reanalyzed_values, pre_lst, policy_masks = \
self.prepare_policy_reanalyze(
trained_steps, traj_lst[:reanalyze_batch_size], transition_pos_lst[:reanalyze_batch_size],
indices_lst[:reanalyze_batch_size],
state_lst=pre_calc[0], value_lst=pre_calc[1], policy_lst=pre_calc[2], policy_mask=pre_calc[3]
)
else:
batch_policies_re = []
        # obtain the non-reanalyzed policy (stored self-play search policy)
if batch_size - reanalyze_batch_size > 0:
batch_policies_non_re = self.prepare_policy_non_reanalyze(traj_lst[reanalyze_batch_size:],
transition_pos_lst[reanalyze_batch_size:])
else:
batch_policies_non_re = []
        # concatenate reanalyzed and non-reanalyzed target policies
        batch_policies = batch_policies_re + batch_policies_non_re
if self.env in ['DMC', 'Gym']:
batch_best_actions = best_actions.reshape(batch_size, self.unroll_steps + 1,
self.action_space_size)
else:
batch_best_actions = np.asarray(best_actions).reshape(batch_size,
self.unroll_steps + 1)
# target value prefix (reward), value, policy
if self.env not in ['DMC', 'Gym']:
batch_actions = np.ones_like(batch_policies)
else:
batch_actions = sampled_actions.reshape(
batch_size, self.unroll_steps + 1, -1, self.action_space_size
)
targets_batch = [batch_value_prefixes, batch_values, batch_actions, batch_policies, batch_best_actions, top_new_masks, policy_masks, reanalyzed_values]
for i in range(len(targets_batch)):
targets_batch[i] = np.asarray(targets_batch[i])
# ==============================================================================================================
# push batch into batch queue
# ==============================================================================================================
        # full batch data: [obs_lst, other inputs, targets]
# batch = [inputs_batch[0], inputs_batch[1:], targets_batch]
batch = [inputs_batch, targets_batch]
# log
self.storage.add_log_scalar.remote({
'batch_worker/td_step': np.mean(td_steps)
})
if real_time:
return batch
else:
# push into batch storage
self.batch_storage.push(batch)
return ray_time
# @profile
def prepare_reward_value_gae_faster(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps):
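        # single-pass variant of prepare_reward_value_gae: one inference call
        # covers both the bootstrap values and the current policies (this
        # variant is not referenced elsewhere in this file)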
# value prefix (or reward), value
batch_value_prefixes, batch_values = [], []
extra = max(0, min(int(1 / (1 - self.td_lambda)), self.GAE_max_steps) - self.unroll_steps - 1)
# init
value_obs_lst, td_steps_lst, value_mask, policy_mask = [], [], [], [] # mask: 0 -> out of traj
# policy_obs_lst, policy_mask = [], []
zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)
# get obs_{t+k}
td_steps = 1
for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
traj_len = len(traj)
# prepare the corresponding observations for bootstrapped values o_{t+k}
traj_obs = traj.get_index_stacked_obs(state_index, extra=extra + td_steps)
for current_index in range(state_index, state_index + self.unroll_steps + 1 + extra + td_steps):
bootstrap_index = current_index
if not self.episodic:
if bootstrap_index <= traj_len:
beg_index = bootstrap_index - state_index
end_index = beg_index + self.n_stack
obs = traj_obs[beg_index:end_index]
value_mask.append(1)
if bootstrap_index < traj_len:
policy_mask.append(1)
else:
policy_mask.append(0)
else:
value_mask.append(0)
policy_mask.append(0)
obs = np.asarray(zero_obs)
else:
if bootstrap_index < traj_len:
beg_index = bootstrap_index - state_index
end_index = beg_index + self.n_stack
obs = traj_obs[beg_index:end_index]
value_mask.append(1)
policy_mask.append(1)
else:
value_mask.append(0)
policy_mask.append(0)
obs = np.asarray(zero_obs)
value_obs_lst.append(obs)
td_steps_lst.append(td_steps)
# reanalyze the bootstrapped value v_{t+k}
state_lst, value_lst, policy_lst = self.efficient_inference(value_obs_lst, only_value=False)
# v_{t+k}
batch_size = len(value_lst)
value_lst = value_lst.reshape(-1) * (np.array([self.discount for _ in range(batch_size)]) ** td_steps_lst)
value_lst = value_lst * np.array(value_mask)
        # value_lst = np.zeros_like(value_lst)  # unit-test aid: keep commented out when training
td_value_lst = copy.deepcopy(value_lst)
value_lst = value_lst.tolist()
td_value_lst = td_value_lst.tolist()
re_state_lst, re_value_lst, re_policy_lst, re_policy_mask = [], [], [], []
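        # each sample contributed (unroll_steps + extra + 1 + td_steps) stacked
        # observations; keep only the first (unroll_steps + 1) per sample,
        # which are the positions the policy-reanalysis stage needs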
for i in range(len(state_lst)):
if i % (self.unroll_steps + extra + 1 + td_steps) < self.unroll_steps + 1:
re_state_lst.append(state_lst[i].unsqueeze(0))
re_value_lst.append(value_lst[i])
re_policy_lst.append(policy_lst[i].unsqueeze(0))
re_policy_mask.append(policy_mask[i])
re_state_lst = torch.cat(re_state_lst, dim=0)
re_value_lst = np.asarray(re_value_lst)
re_policy_lst = torch.cat(re_policy_lst, dim=0)
# v_{t} = r + ... + gamma ^ k * v_{t+k}
value_index = 0
td_lambdas = []
for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
traj_len = len(traj)
target_values = []
target_value_prefixs = []
delta_lambda = 0.1 * (collected_transitions - idx) / self.auto_td_steps
if self.value_target in ['mixed', 'max']:
delta_lambda = 0.0
td_lambda = self.td_lambda - delta_lambda
td_lambda = np.clip(td_lambda, 0.65, self.td_lambda)
td_lambdas.append(td_lambda)
delta = np.zeros(self.unroll_steps + 1 + extra)
advantage = np.zeros(self.unroll_steps + 1 + extra + 1)
index = self.unroll_steps + extra
for current_index in reversed(range(state_index, state_index + self.unroll_steps + 1 + extra)):
bootstrap_index = current_index + td_steps_lst[value_index + index]
for i, reward in enumerate(traj.reward_lst[current_index:bootstrap_index]):
td_value_lst[value_index + index + td_steps] += reward * self.discount ** i
delta[index] = td_value_lst[value_index + index + td_steps] - value_lst[value_index + index]
advantage[index] = delta[index] + self.discount * td_lambda * advantage[index + 1]
index -= 1
target_values_tmp = advantage[:self.unroll_steps + 1] + np.asarray(value_lst)[value_index:value_index + self.unroll_steps + 1]
horizon_id = 0
value_prefix = 0.0
for i, current_index in enumerate(range(state_index, state_index + self.unroll_steps + 1)):
# reset every lstm_horizon_len
if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
value_prefix = 0.0
horizon_id += 1
if current_index < traj_len:
                    # since the horizon is short and the discount is close to 1,
                    # the plain reward sum approximates the discounted value prefix
if self.value_prefix:
value_prefix += traj.reward_lst[current_index]
else:
value_prefix = traj.reward_lst[current_index]
target_value_prefixs.append(value_prefix)
else:
target_value_prefixs.append(value_prefix)
if self.episodic:
if current_index < traj_len:
target_values.append(target_values_tmp[i])
else:
target_values.append(0)
else:
if current_index <= traj_len:
target_values.append(target_values_tmp[i])
else:
target_values.append(0)
value_index += (self.unroll_steps + 1 + extra + td_steps)
batch_value_prefixes.append(target_value_prefixs)
batch_values.append(target_values)
if self.rank == 0 and self.cnt % 20 == 0:
print(f'--------------- lambda={np.asarray(td_lambdas).mean():.3f} -------------------')
self.storage.add_log_scalar.remote({
'batch_worker/td_lambda': np.asarray(td_lambdas).mean()
})
value_index = 0
value_masks, policy_masks = [], []
for i, idx in enumerate(indices_lst):
value_masks.append(int(idx > collected_transitions - self.mixed_value_threshold))
value_index += (self.unroll_steps + 1 + extra)
value_masks = np.asarray(value_masks)
return np.asarray(batch_value_prefixes), np.asarray(batch_values), np.asarray(td_steps_lst).flatten(), \
(re_state_lst, re_value_lst, re_policy_lst, re_policy_mask), value_masks
# @profile
def prepare_reward_value_gae(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps):
# value prefix (or reward), value
batch_value_prefixes, batch_values = [], []
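        # extend the window beyond the unroll steps so the lambda-return sees
        # enough future: e.g. td_lambda = 0.95 implies an effective horizon of
        # int(1 / (1 - 0.95)) = 20 steps, capped by GAE_max_steps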
extra = max(0, min(int(1 / (1 - self.td_lambda)), self.GAE_max_steps) - self.unroll_steps - 1)
# init
value_obs_lst, td_steps_lst, value_mask = [], [], [] # mask: 0 -> out of traj
policy_obs_lst, policy_mask = [], []
zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)
# get obs_{t+k}
for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
traj_len = len(traj)
td_steps = 1
# prepare the corresponding observations for bootstrapped values o_{t+k}
traj_obs = traj.get_index_stacked_obs(state_index + td_steps, extra=extra)
game_obs = traj.get_index_stacked_obs(state_index, extra=extra)
for current_index in range(state_index, state_index + self.unroll_steps + 1 + extra):
bootstrap_index = current_index + td_steps
if not self.episodic:
if bootstrap_index <= traj_len:
value_mask.append(1)
beg_index = bootstrap_index - (state_index + td_steps)
end_index = beg_index + self.n_stack
obs = traj_obs[beg_index:end_index]
else:
value_mask.append(0)
obs = np.asarray(zero_obs)
else:
if bootstrap_index < traj_len:
value_mask.append(1)
beg_index = bootstrap_index - (state_index + td_steps)
end_index = beg_index + self.n_stack
obs = traj_obs[beg_index:end_index]
else:
value_mask.append(0)
obs = np.asarray(zero_obs)
value_obs_lst.append(obs)
td_steps_lst.append(td_steps)
if current_index < traj_len:
policy_mask.append(1)
beg_index = current_index - state_index
end_index = beg_index + self.n_stack
obs = game_obs[beg_index:end_index]
else:
policy_mask.append(0)
obs = np.asarray(zero_obs)
policy_obs_lst.append(obs)
# reanalyze the bootstrapped value v_{t+k}
_, value_lst, _ = self.efficient_inference(value_obs_lst, only_value=True)
state_lst, ori_cur_value_lst, policy_lst = self.efficient_inference(policy_obs_lst, only_value=False)
# v_{t+k}
batch_size = len(value_lst)
value_lst = value_lst.reshape(-1) * (np.array([self.discount for _ in range(batch_size)]) ** td_steps_lst)
value_lst = value_lst * np.array(value_mask)
        # value_lst = np.zeros_like(value_lst)  # unit-test aid: keep commented out when training
value_lst = value_lst.tolist()
cur_value_lst = ori_cur_value_lst.reshape(-1) * np.array(policy_mask)
        # cur_value_lst = np.zeros_like(cur_value_lst)  # unit-test aid: keep commented out when training
cur_value_lst = cur_value_lst.tolist()
state_lst_cut, ori_cur_value_lst_cut, policy_lst_cut, policy_mask_cut = [], [], [], []
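        # keep only the first (unroll_steps + 1) entries of each sample's
        # (unroll_steps + extra + 1)-long window; reanalysis only needs the
        # unrolled positions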
for i in range(len(state_lst)):
if i % (self.unroll_steps + extra + 1) < self.unroll_steps + 1:
state_lst_cut.append(state_lst[i].unsqueeze(0))
ori_cur_value_lst_cut.append(ori_cur_value_lst[i])
policy_lst_cut.append(policy_lst[i].unsqueeze(0))
policy_mask_cut.append(policy_mask[i])
state_lst_cut = torch.cat(state_lst_cut, dim=0)
ori_cur_value_lst_cut = np.asarray(ori_cur_value_lst_cut)
policy_lst_cut = torch.cat(policy_lst_cut, dim=0)
# v_{t} = r + ... + gamma ^ k * v_{t+k}
value_index = 0
td_lambdas = []
for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
traj_len = len(traj)
target_values = []
target_value_prefixs = []
delta_lambda = 0.1 * (collected_transitions - idx) / self.auto_td_steps
if self.value_target in ['mixed', 'max']:
delta_lambda = 0.0
td_lambda = self.td_lambda - delta_lambda
td_lambda = np.clip(td_lambda, 0.65, self.td_lambda)
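            # staler samples get a smaller lambda (shorter effective horizon):
            # e.g. a sample 2 * auto_td_steps transitions old gives
            # delta_lambda = 0.2, so lambda = clip(td_lambda - 0.2, 0.65, td_lambda)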
td_lambdas.append(td_lambda)
delta = np.zeros(self.unroll_steps + 1 + extra)
advantage = np.zeros(self.unroll_steps + 1 + extra + 1)
index = self.unroll_steps + extra
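            # backward GAE over the extended window (here td_steps = 1):
            #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
            #   A_t     = delta_t + gamma * lambda * A_{t+1}
            # the value target is the lambda-return A_t + V(s_t)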
for current_index in reversed(range(state_index, state_index + self.unroll_steps + 1 + extra)):
bootstrap_index = current_index + td_steps_lst[value_index + index]
for i, reward in enumerate(traj.reward_lst[current_index:bootstrap_index]):
value_lst[value_index + index] += reward * self.discount ** i
delta[index] = value_lst[value_index + index] - cur_value_lst[value_index + index]
advantage[index] = delta[index] + self.discount * td_lambda * advantage[index + 1]
index -= 1
target_values_tmp = advantage[:self.unroll_steps + 1] + np.asarray(cur_value_lst)[value_index:value_index + self.unroll_steps + 1]
horizon_id = 0
value_prefix = 0.0
for i, current_index in enumerate(range(state_index, state_index + self.unroll_steps + 1)):
# reset every lstm_horizon_len
if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
value_prefix = 0.0
horizon_id += 1
if current_index < traj_len:
                    # since the horizon is short and the discount is close to 1,
                    # the plain reward sum approximates the discounted value prefix
if self.value_prefix:
value_prefix += traj.reward_lst[current_index]
else:
value_prefix = traj.reward_lst[current_index]
target_value_prefixs.append(value_prefix)
else:
target_value_prefixs.append(value_prefix)
if self.episodic:
if current_index < traj_len:
target_values.append(target_values_tmp[i])
else:
target_values.append(0)
else:
if current_index <= traj_len:
target_values.append(target_values_tmp[i])
else:
target_values.append(0)
value_index += (self.unroll_steps + 1 + extra)
batch_value_prefixes.append(target_value_prefixs)
batch_values.append(target_values)
if self.rank == 0 and self.cnt % 20 == 0:
print(f'--------------- lambda={np.asarray(td_lambdas).mean():.3f} -------------------')
self.storage.add_log_scalar.remote({
'batch_worker/td_lambda': np.asarray(td_lambdas).mean()
})
value_index = 0
value_masks, policy_masks = [], []
for i, idx in enumerate(indices_lst):
value_masks.append(int(idx > collected_transitions - self.mixed_value_threshold))
value_index += (self.unroll_steps + 1 + extra)
value_masks = np.asarray(value_masks)
return np.asarray(batch_value_prefixes), np.asarray(batch_values), np.asarray(td_steps_lst).flatten(), \
(state_lst_cut, ori_cur_value_lst_cut, policy_lst_cut, policy_mask_cut), value_masks
def prepare_reward(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps):
# value prefix (or reward), value
batch_value_prefixes = []
# v_{t} = r + ... + gamma ^ k * v_{t+k}
value_index = 0
top_value_masks = []
for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
traj_len = len(traj)
target_value_prefixs = []
horizon_id = 0
value_prefix = 0.0
top_value_masks.append(int(idx > collected_transitions - self.config.train.start_use_mix_training_steps))
for current_index in range(state_index, state_index + self.unroll_steps + 1):
# reset every lstm_horizon_len
if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
value_prefix = 0.0
horizon_id += 1
if current_index < traj_len:
                    # since the horizon is short and the discount is close to 1,
                    # the plain reward sum approximates the discounted value prefix
if self.value_prefix:
value_prefix += traj.reward_lst[current_index]
else:
value_prefix = traj.reward_lst[current_index]
target_value_prefixs.append(value_prefix)
else:
target_value_prefixs.append(value_prefix)
value_index += 1
batch_value_prefixes.append(target_value_prefixs)
value_masks = np.asarray(top_value_masks)
batch_value_prefixes = np.asarray(batch_value_prefixes)
batch_values = np.zeros_like(batch_value_prefixes)
td_steps_lst = np.ones_like(batch_value_prefixes)
return batch_value_prefixes, np.asarray(batch_values), td_steps_lst.flatten(), \
(None, None, None, None), value_masks
def prepare_reward_value(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps):
# value prefix (or reward), value
batch_value_prefixes, batch_values = [], []
# search_values = []
# init
value_obs_lst, td_steps_lst, value_mask = [], [], [] # mask: 0 -> out of traj
zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)
# get obs_{t+k}
for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
traj_len = len(traj)
# off-policy correction: shorter horizon of td steps
delta_td = (collected_transitions - idx) // self.auto_td_steps
if self.value_target in ['mixed', 'max']:
delta_td = 0
td_steps = self.td_steps - delta_td
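            # e.g. a sample that is 3 * auto_td_steps transitions old gets
            # delta_td = 3, so its bootstrap horizon shrinks by 3 steps:
            # staler, more off-policy data bootstraps sooner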
# td_steps = self.td_steps # for test off-policy issue
if not self.episodic:
td_steps = min(traj_len - state_index, td_steps)
td_steps = np.clip(td_steps, 1, self.td_steps).astype(np.int32)
obs_idx = state_index + td_steps
# prepare the corresponding observations for bootstrapped values o_{t+k}
traj_obs = traj.get_index_stacked_obs(state_index + td_steps)
for current_index in range(state_index, state_index + self.unroll_steps + 1):
if not self.episodic:
td_steps = min(traj_len - current_index, td_steps)
td_steps = max(td_steps, 1)
bootstrap_index = current_index + td_steps
if not self.episodic:
if bootstrap_index <= traj_len:
value_mask.append(1)
beg_index = bootstrap_index - obs_idx
end_index = beg_index + self.n_stack
obs = traj_obs[beg_index:end_index]
else:
value_mask.append(0)
obs = zero_obs
else:
if bootstrap_index < traj_len:
value_mask.append(1)
beg_index = bootstrap_index - (state_index + td_steps)
end_index = beg_index + self.n_stack
obs = traj_obs[beg_index:end_index]
else:
value_mask.append(0)
obs = zero_obs
value_obs_lst.append(obs)
td_steps_lst.append(td_steps)
# reanalyze the bootstrapped value v_{t+k}
state_lst, value_lst, policy_lst = self.efficient_inference(value_obs_lst, only_value=True)
batch_size = len(value_lst)
value_lst = value_lst.reshape(-1) * (np.array([self.discount for _ in range(batch_size)]) ** td_steps_lst)
value_lst = value_lst * np.array(value_mask)
        # value_lst = np.zeros_like(value_lst)  # unit-test aid: keep commented out when training
value_lst = value_lst.tolist()
# v_{t} = r + ... + gamma ^ k * v_{t+k}
value_index = 0
top_value_masks = []
for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
traj_len = len(traj)
target_values = []
target_value_prefixs = []
horizon_id = 0
value_prefix = 0.0
top_value_masks.append(int(idx > collected_transitions - self.mixed_value_threshold))
for current_index in range(state_index, state_index + self.unroll_steps + 1):
bootstrap_index = current_index + td_steps_lst[value_index]
for i, reward in enumerate(traj.reward_lst[current_index:bootstrap_index]):
value_lst[value_index] += reward * self.discount ** i
# reset every lstm_horizon_len
if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
value_prefix = 0.0
horizon_id += 1
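                # e.g. with lstm_horizon_len = 5 the prefix accumulates rewards
                # within each 5-step LSTM segment and resets at segment starts,
                # mirroring the value-prefix LSTM's hidden-state resets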
if current_index < traj_len:
                    # since the horizon is short and the discount is close to 1,
                    # the plain reward sum approximates the discounted value prefix
if self.value_prefix:
value_prefix += traj.reward_lst[current_index]
else:
value_prefix = traj.reward_lst[current_index]
target_value_prefixs.append(value_prefix)
else:
target_value_prefixs.append(value_prefix)
if self.episodic:
if current_index < traj_len:
target_values.append(value_lst[value_index])
else:
target_values.append(0)
else:
if current_index <= traj_len:
target_values.append(value_lst[value_index])
else:
target_values.append(0)
value_index += 1
batch_value_prefixes.append(target_value_prefixs)
batch_values.append(target_values)
value_masks = np.asarray(top_value_masks)
return np.asarray(batch_value_prefixes), np.asarray(batch_values), np.asarray(td_steps_lst).flatten(), \
(None, None, None, None), value_masks
def prepare_policy_non_reanalyze(self, traj_lst, transition_pos_lst):
# policy
batch_policies = []
# load searched policy in self-play
for traj, state_index in zip(traj_lst, transition_pos_lst):
traj_len = len(traj)
target_policies = []
for current_index in range(state_index, state_index + self.unroll_steps + 1):
if current_index < traj_len:
target_policies.append(traj.policy_lst[current_index])
else:
target_policies.append([0 for _ in range(self.action_space_size)])
batch_policies.append(target_policies)
return batch_policies
def prepare_policy_reanalyze(self, trained_steps, traj_lst, transition_pos_lst, indices_lst, state_lst=None, value_lst=None, policy_lst=None, policy_mask=None):
# policy
reanalyzed_values = []
batch_policies = []
# init
if value_lst is None:
policy_obs_lst, policy_mask = [], [] # mask: 0 -> out of traj
zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)
# get obs_{t} instead of obs_{t+k}
for traj, state_index in zip(traj_lst, transition_pos_lst):
traj_len = len(traj)
game_obs = traj.get_index_stacked_obs(state_index)
for current_index in range(state_index, state_index + self.unroll_steps + 1):
if current_index < traj_len:
policy_mask.append(1)
beg_index = current_index - state_index
end_index = beg_index + self.n_stack
obs = game_obs[beg_index:end_index]
else:
policy_mask.append(0)
obs = np.asarray(zero_obs)
policy_obs_lst.append(obs)
# reanalyze the search policy pi_{t}
state_lst, value_lst, policy_lst = self.efficient_inference(policy_obs_lst, only_value=False)
# tree search for policies
batch_size = len(state_lst)
# temperature
temperature = self.agent.get_temperature(trained_steps=trained_steps) #* np.ones((batch_size, 1))
tree = mcts.names[self.config.mcts.language](
num_actions=self.action_space_size if self.env == 'Atari' else self.config.mcts.num_sampled_actions,
discount=self.config.rl.discount,
env=self.env,
**self.config.mcts, # pass mcts related params
**self.config.model, # pass the value and reward support params
)
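        # Atari searches over the full discrete action set; continuous envs
        # (DMC/Gym) search over num_sampled_actions actions sampled from the
        # policy, in the spirit of Sampled MuZero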
if self.env == 'Atari':
if self.config.mcts.use_gumbel:
r_values, r_policies, best_actions, _ = tree.search(
self.model,
# self.latest_model,
batch_size, state_lst, value_lst, policy_lst, use_gumble_noise=True, temperature=temperature
)
else:
r_values, r_policies, best_actions, _ = tree.search_ori_mcts(
self.model, batch_size, state_lst, value_lst, policy_lst, use_noise=True, temperature=temperature, is_reanalyze=True
)
sampled_actions = best_actions
search_best_indexes = best_actions
else:
r_values, r_policies, best_actions, sampled_actions, search_best_indexes, _ = tree.search_continuous(
self.model,
batch_size, state_lst, value_lst, policy_lst, temperature=temperature,
)
if self.config.train.optimal_Q:
r_values = self.efficient_recurrent(state_lst, policy_lst)
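            # note: efficient_recurrent is not defined in this file; it is
            # assumed to be provided elsewhere (e.g. by a subclass) and to
            # return value estimates for the root states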
r_values = r_values.reshape(-1) * np.array(policy_mask)
r_values = r_values.tolist()
# concat policy
policy_index = 0
policy_masks = []
mismatch_index = []
for traj, state_index, ind in zip(traj_lst, transition_pos_lst, indices_lst):
target_policies = []
search_values = []
policy_masks.append([])
for current_index in range(state_index, state_index + self.unroll_steps + 1):
traj_len = len(traj)
assert (current_index < traj_len) == (policy_mask[policy_index])
if policy_mask[policy_index]:
target_policies.append(r_policies[policy_index])
search_values.append(r_values[policy_index])
                    # mask out policy targets whose reanalyzed distribution's
                    # argmax disagrees with the search's chosen best action
if r_policies[policy_index].argmax() != search_best_indexes[policy_index]:
policy_mask[policy_index] = 0
mismatch_index.append(ind + current_index - state_index)
else:
search_values.append(0.0)
if self.env in ['DMC','Gym']:
target_policies.append([0 for _ in range(sampled_actions.shape[1])])
else:
target_policies.append([0 for _ in range(self.action_space_size)])
policy_masks[-1].append(policy_mask[policy_index])
policy_index += 1
batch_policies.append(target_policies)
reanalyzed_values.append(search_values)
if self.rank == 0 and self.config.eval.analysis_value:
new_log_index = trained_steps // 5000
if new_log_index > self.last_log_index:
self.last_log_index = new_log_index
min_idx = np.asarray(indices_lst).argmin()
r_value = reanalyzed_values[min_idx][0]
self.storage.add_log_scalar.remote({
'batch_worker/search_value': r_value
})
policy_masks = np.asarray(policy_masks)
return batch_policies, sampled_actions, best_actions, reanalyzed_values, (state_lst, value_lst, policy_lst, policy_mask), policy_masks
@torch.no_grad()
def imagine_episodes(self, pre_lst, traj_lst, transition_pos_lst, trained_steps, policy='search'):
length = 1
times = 3
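        # roll the model forward `length` step(s) from each root state,
        # replicated `times` times, then average the returns to reduce variance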
# input_obs = np.concatenate([stack_obs for _ in range(times)], axis=0)
# states, values, policies = self.efficient_inference(input_obs)
states, values, policies, policy_mask = pre_lst
states = torch.cat([states for _ in range(times)], dim=0)
values = np.concatenate([values for _ in range(times)], axis=0)
policies = torch.cat([policies for _ in range(times)], dim=0)
reward_hidden = (torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda(),
torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda())
last_values_prefixes = np.zeros(len(states))
reward_lst = []
value_lst = []
temperature = self.agent.get_temperature(trained_steps=trained_steps) * np.ones((len(states), 1))
for i in range(length):
if policy == 'search':
tree = mcts.names[self.config.mcts.language](
num_actions=self.config.env.action_space_size if self.env == 'Atari' else self.config.mcts.num_top_actions,
discount=self.config.rl.discount,
**self.config.mcts, # pass mcts related params
**self.config.model, # pass the value and reward support params
)
if self.env == 'Atari':
if self.config.mcts.use_gumbel:
r_values, r_policies, best_actions, _ = tree.search(
self.model, len(states), states, values, policies,
use_gumble_noise=True, temperature=temperature
)
else:
r_values, r_policies, best_actions, _ = tree.search_ori_mcts(
self.model, len(states), states, values, policies, use_noise=True,
temperature=temperature, is_reanalyze=True
)
else:
r_values, r_policies, best_actions, sampled_actions, _, _ = tree.search_continuous(
self.model, len(states), states, values, policies,
use_gumble_noise=False, temperature=temperature)
if policy == 'search':
actions = torch.from_numpy(np.asarray(best_actions)).cuda().float()
else:
if self.env == 'Atari':
actions = F.gumbel_softmax(policies, hard=True, dim=-1, tau=1e-4)
actions = actions.argmax(dim=-1)
else:
actions = policies[:, :policies.shape[-1]//2]
actions = actions.unsqueeze(1)
with autocast():
states, value_prefixes, values, policies, reward_hidden = \
self.model.recurrent_inference(states, actions, reward_hidden)
values = values.squeeze().detach().cpu().numpy()
value_lst.append(values)
if self.value_prefix and (i + 1) % self.lstm_horizon_len == 0:
reward_hidden = (torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda(),
torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda())
true_rewards = value_prefixes.squeeze().detach().cpu().numpy()
# last_values_prefixes = np.zeros(len(states))
else:
true_rewards = value_prefixes.squeeze().detach().cpu().numpy() - last_values_prefixes
last_values_prefixes = value_prefixes.squeeze().detach().cpu().numpy()
reward_lst.append(true_rewards)
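        # n-step return over the imagined rollout:
        #   value = sum_i gamma^i * r_i + gamma^length * V(s_length)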
value = 0
for i, reward in enumerate(reward_lst):
value += reward * (self.config.rl.discount ** i)
value += (self.config.rl.discount ** length) * value_lst[-1]
value_reshaped = []
batch_size = len(states) // times
for i in range(times):
value_reshaped.append(value[batch_size * i:batch_size * (i+1)])
value_reshaped = np.asarray(value_reshaped).mean(0)
output_values = []
policy_index = 0
for traj, state_index in zip(traj_lst, transition_pos_lst):
imagined_values = []
for current_index in range(state_index, state_index + self.unroll_steps + 1):
traj_len = len(traj)
# assert (current_index < traj_len) == (policy_mask[policy_index])
if policy_mask[policy_index]:
imagined_values.append(value_reshaped[policy_index])
else:
imagined_values.append(0.0)
policy_index += 1
output_values.append(imagined_values)
return np.asarray(output_values)
def efficient_inference(self, obs_lst, only_value=False, value_idx=0):
batch_size = len(obs_lst)
obs_lst = np.asarray(obs_lst)
state_lst, value_lst, policy_lst = [], [], []
# split a full batch into slices of mini_infer_size
mini_batch = self.config.train.mini_batch_size
slices = np.ceil(batch_size / mini_batch).astype(np.int32)
with torch.no_grad():
for i in range(slices):
beg_index = mini_batch * i
end_index = mini_batch * (i + 1)
current_obs = obs_lst[beg_index:end_index]
current_obs = formalize_obs_lst(current_obs, self.image_based)
# obtain the statistics at current steps
with autocast():
states, values, policies = self.model.initial_inference(current_obs)
# process outputs
values = values.detach().cpu().numpy().flatten()
# concat
value_lst.append(values)
if not only_value:
state_lst.append(states)
policy_lst.append(policies)
value_lst = np.concatenate(value_lst)
if not only_value:
state_lst = torch.cat(state_lst)
policy_lst = torch.cat(policy_lst)
return state_lst, value_lst, policy_lst
# ======================================================================================================================
# batch worker
# ======================================================================================================================
def start_batch_worker(rank, agent, replay_buffer, storage, batch_storage, config):
"""
Start a GPU batch worker. Call this method remotely.
"""
worker = BatchWorker.remote(rank, agent, replay_buffer, storage, batch_storage, config)
print(f"[Batch worker GPU] Starting batch worker GPU {rank} at process {os.getpid()}.")
worker.run.remote()
def start_batch_worker_cpu(rank, agent, replay_buffer, storage, prebatch_storage, config):
worker = BatchWorker_CPU.remote(rank, agent, replay_buffer, storage, prebatch_storage, config)
print(f"[Batch worker CPU] Starting batch worker CPU {rank} at process {os.getpid()}.")
worker.run.remote()
def start_batch_worker_gpu(rank, agent, replay_buffer, storage, prebatch_storage, batch_storage, config):
worker = BatchWorker_GPU.remote(rank, agent, replay_buffer, storage, prebatch_storage, batch_storage, config)
print(f"[Batch worker GPU] Starting batch worker GPU {rank} at process {os.getpid()}.")
worker.run.remote()
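# Minimal wiring sketch (assumes the training entry point has built `agent`,
# `replay_buffer`, `storage`, `batch_storage`, and `config`, and that Ray is
# initialized; `num_batch_workers` is a hypothetical name for the worker count):
#
#   ray.init()
#   for rank in range(num_batch_workers):
#       start_batch_worker(rank, agent, replay_buffer, storage,
#                          batch_storage, config)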