# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import os
import time
import ray
import torch
import copy
import gym
import imageio
from PIL import Image, ImageDraw
import numpy as np
from torch.cuda.amp import autocast as autocast
import torch.nn.functional as F

from .base import Worker
from ez import mcts
from ez.envs import make_envs
from ez.utils.distribution import SquashedNormal, TruncatedNormal, ContDist
from ez.utils.format import formalize_obs_lst, DiscreteSupport, LinearSchedule, prepare_obs_lst, allocate_gpu, \
    profile, symexp
from ez.data.trajectory import GameTrajectory
from ez.mcts.cy_mcts import Gumbel_MCTS


@ray.remote(num_gpus=0.03)
# @ray.remote(num_gpus=0.14)
class BatchWorker(Worker):
    def __init__(self, rank, agent, replay_buffer, storage, batch_storage, config):
        super().__init__(rank, agent, replay_buffer, storage, config)
        self.model_update_interval = config.train.reanalyze_update_interval
        self.batch_storage = batch_storage
        self.beta_schedule = LinearSchedule(self.total_steps, initial_p=config.priority.priority_prob_beta,
                                            final_p=1.0)
        self.total_transitions = self.config.data.total_transitions
        self.auto_td_steps = self.config.rl.auto_td_steps
        self.td_steps = self.config.rl.td_steps
        self.unroll_steps = self.config.rl.unroll_steps
        self.n_stack = self.config.env.n_stack
        self.discount = self.config.rl.discount
        self.value_support = self.config.model.value_support
        self.action_space_size = self.config.env.action_space_size
        self.batch_size = self.config.train.batch_size
        self.PER_alpha = self.config.priority.priority_prob_alpha
        self.env = self.config.env.env
        self.image_based = self.config.env.image_based
        self.reanalyze_ratio = self.config.train.reanalyze_ratio
        self.value_target = self.config.train.value_target
        self.value_target_type = self.config.model.value_target
        self.GAE_max_steps = self.config.model.GAE_max_steps
        self.episodic = self.config.env.episodic
        self.value_prefix = self.config.model.value_prefix
        self.lstm_horizon_len = self.config.model.lstm_horizon_len
        self.training_steps = self.config.train.training_steps
        self.td_lambda = self.config.rl.td_lambda
        self.gray_scale = self.config.env.gray_scale
        self.obs_shape = self.config.env.obs_shape
        self.trajectory_size = self.config.data.trajectory_size
        self.mixed_value_threshold = self.config.train.mixed_value_threshold
        self.lstm_hidden_size = self.config.model.lstm_hidden_size
        self.cnt = 0

    def concat_trajs(self, items):
        obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts, \
            bootstrapped_value_lsts = items
        traj_lst = []
        for obs_lst, reward_lst, policy_lst, action_lst, pred_value_lst, search_value_lst, bootstrapped_value_lst in \
                zip(obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts,
                    bootstrapped_value_lsts):
            # traj = GameTrajectory(**self.config.env, **self.config.rl, **self.config.model)
            traj = GameTrajectory(
                n_stack=self.n_stack, discount=self.discount, gray_scale=self.gray_scale,
                unroll_steps=self.unroll_steps, td_steps=self.td_steps, td_lambda=self.td_lambda,
                obs_shape=self.obs_shape, max_size=self.trajectory_size, image_based=self.image_based,
                episodic=self.episodic, GAE_max_steps=self.GAE_max_steps
            )
            traj.obs_lst = obs_lst
            traj.reward_lst = reward_lst
            traj.policy_lst = policy_lst
            traj.action_lst = action_lst
            traj.pred_value_lst = pred_value_lst
            traj.search_value_lst = search_value_lst
            traj.bootstrapped_value_lst = bootstrapped_value_lst
            traj_lst.append(traj)
        return traj_lst
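
    # Main worker loop (see run() below): wait for the trainer's start signal, then repeatedly read the
    # trained-step counter, refresh the 'reanalyze' and 'latest' model weights, build a reanalyzed batch
    # with make_batch(), and push it to batch_storage for the trainer to consume.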
    def run(self):
        trained_steps = 0
        # create the model for self-play data collection
        self.model = self.agent.build_model()
        self.latest_model = self.agent.build_model()
        if self.config.eval.analysis_value:
            weights = torch.load(self.config.eval.model_path)
            self.model.load_state_dict(weights)
            print('analysis begin')
        self.model.cuda()
        self.latest_model.cuda()
        if int(torch.__version__[0]) == 2:
            self.model = torch.compile(self.model)
            self.latest_model = torch.compile(self.latest_model)
        self.model.eval()
        self.latest_model.eval()
        self.resume_model()

        # wait for starting to train
        while not ray.get(self.storage.get_start_signal.remote()):
            time.sleep(0.5)

        # begin to make batch
        prev_trained_steps = -10
        while not self.is_finished(trained_steps):
            trained_steps = ray.get(self.storage.get_counter.remote())

            if self.config.ray.single_process:
                if trained_steps <= prev_trained_steps:
                    time.sleep(0.1)
                    continue
                prev_trained_steps = trained_steps
                print(f'reanalyze[{self.rank}] makes batch at step {trained_steps}')

            # get the fresh model weights
            self.get_recent_model(trained_steps, 'reanalyze')
            self.get_latest_model(trained_steps, 'latest')
            # if self.config.model.noisy_net:
            #     self.model.value_policy_model.reset_noise()

            start_time = time.time()
            ray_time = self.make_batch(trained_steps, self.cnt)
            self.cnt += 1
            end_time = time.time()
            # if self.cnt % 100 == 0:
            #     print(f'make batch time={end_time-start_time:.3f}s, ray get time={ray_time:.3f}s')

    # @torch.no_grad()
    # @profile
    def make_batch(self, trained_steps, cnt, real_time=False):
        beta = self.beta_schedule.value(trained_steps)
        batch_size = self.batch_size

        # obtain the batch context from replay buffer
        x = time.time()
        batch_context = ray.get(
            self.replay_buffer.prepare_batch_context.remote(batch_size=batch_size, alpha=self.PER_alpha,
                                                            beta=beta, rank=self.rank, cnt=cnt)
        )
        batch_context, validation_flag = batch_context
        ray_time = time.time() - x
        traj_lst, transition_pos_lst, indices_lst, weights_lst, make_time_lst, transition_num, prior_lst = batch_context
        traj_lst = self.concat_trajs(traj_lst)

        # part of policy will be reanalyzed
        reanalyze_batch_size = batch_size if self.env in ['DMC', 'Gym'] \
            else int(batch_size * self.config.train.reanalyze_ratio)
        assert 0 <= reanalyze_batch_size <= batch_size

        # ==============================================================================================================
        # make inputs
        # ==============================================================================================================
        collected_transitions = ray.get(self.replay_buffer.get_transition_num.remote())
        # make observations, actions and masks (if unrolled steps are out of trajectory)
        obs_lst, action_lst, mask_lst = [], [], []
        top_new_masks = []
        # prepare the inputs of a batch
        for i in range(batch_size):
            traj = traj_lst[i]
            state_index = transition_pos_lst[i]
            sample_idx = indices_lst[i]
            top_new_masks.append(int(sample_idx > collected_transitions - self.mixed_value_threshold))

            if self.env in ['DMC', 'Gym']:
                _actions = traj.action_lst[state_index:state_index + self.unroll_steps]
                _unroll_actions = traj.action_lst[state_index + 1:state_index + 1 + self.unroll_steps]
                # _unroll_actions = traj.action_lst[state_index:state_index + self.unroll_steps]
                _mask = [1. for _ in range(_unroll_actions.shape[0])]
                _mask += [0. for _ in range(self.unroll_steps - len(_mask))]
                _rand_actions = np.zeros((self.unroll_steps - _actions.shape[0], self.action_space_size))
                _actions = np.concatenate((_actions, _rand_actions), axis=0)
            else:
                _actions = traj.action_lst[state_index:state_index + self.unroll_steps].tolist()
                _mask = [1. for _ in range(len(_actions))]
                _mask += [0. for _ in range(self.unroll_steps - len(_mask))]
                _actions += [np.random.randint(0, self.action_space_size)
                             for _ in range(self.unroll_steps - len(_actions))]
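            # Example of the padding above: with unroll_steps = 5 and, e.g., only 3 valid transitions left
            # in the trajectory, _mask becomes [1, 1, 1, 0, 0]; the missing actions are filled with zeros
            # (continuous control) or random discrete actions so every sample has a fixed unroll length.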
            # obtain the input observations
            obs_lst.append(traj.get_index_stacked_obs(state_index, padding=True))
            action_lst.append(_actions)
            mask_lst.append(_mask)

        obs_lst = prepare_obs_lst(obs_lst, self.image_based)
        inputs_batch = [obs_lst, action_lst, mask_lst, indices_lst, weights_lst, make_time_lst, prior_lst]
        for i in range(len(inputs_batch)):
            inputs_batch[i] = np.asarray(inputs_batch[i])

        # ==============================================================================================================
        # make targets
        # ==============================================================================================================
        if self.value_target in ['sarsa', 'mixed', 'max']:
            if self.value_target_type == 'GAE':
                prepare_func = self.prepare_reward_value_gae
            elif self.value_target_type == 'bootstrapped':
                prepare_func = self.prepare_reward_value
            else:
                raise NotImplementedError
        elif self.value_target == 'search':
            prepare_func = self.prepare_reward
        else:
            raise NotImplementedError

        # obtain the value prefix (reward), and the value
        batch_value_prefixes, batch_values, td_steps, pre_calc, value_masks = \
            prepare_func(traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps)

        # obtain the reanalyzed policy
        if reanalyze_batch_size > 0:
            batch_policies_re, sampled_actions, best_actions, reanalyzed_values, pre_lst, policy_masks = \
                self.prepare_policy_reanalyze(
                    trained_steps, traj_lst[:reanalyze_batch_size], transition_pos_lst[:reanalyze_batch_size],
                    indices_lst[:reanalyze_batch_size],
                    state_lst=pre_calc[0], value_lst=pre_calc[1], policy_lst=pre_calc[2], policy_mask=pre_calc[3]
                )
        else:
            batch_policies_re = []
        # obtain the non-reanalyzed policy
        if batch_size - reanalyze_batch_size > 0:
            batch_policies_non_re = self.prepare_policy_non_reanalyze(traj_lst[reanalyze_batch_size:],
                                                                      transition_pos_lst[reanalyze_batch_size:])
        else:
            batch_policies_non_re = []
        # concat target policy
        batch_policies = batch_policies_re + batch_policies_non_re

        if self.env in ['DMC', 'Gym']:
            batch_best_actions = best_actions.reshape(batch_size, self.unroll_steps + 1, self.action_space_size)
        else:
            batch_best_actions = np.asarray(best_actions).reshape(batch_size, self.unroll_steps + 1)

        # target value prefix (reward), value, policy
        if self.env not in ['DMC', 'Gym']:
            batch_actions = np.ones_like(batch_policies)
        else:
            batch_actions = sampled_actions.reshape(
                batch_size, self.unroll_steps + 1, -1, self.action_space_size
            )
        targets_batch = [batch_value_prefixes, batch_values, batch_actions, batch_policies, batch_best_actions,
                         top_new_masks, policy_masks, reanalyzed_values]
        for i in range(len(targets_batch)):
            targets_batch[i] = np.asarray(targets_batch[i])

        # ==============================================================================================================
        # push batch into batch queue
        # ==============================================================================================================
        # full batch data: [obs_lst, other stuff, target stuff]
        # batch = [inputs_batch[0], inputs_batch[1:], targets_batch]
        batch = [inputs_batch, targets_batch]
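        # Layout pushed to the trainer: inputs_batch holds per-sample observations, unrolled actions,
        # unroll masks, replay indices, IS weights, make-times and priorities; targets_batch holds the
        # value-prefix / value / policy targets (plus the sampled and best actions and the various masks)
        # for the unroll_steps + 1 positions of every sample.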
        # log
        self.storage.add_log_scalar.remote({
            'batch_worker/td_step': np.mean(td_steps)
        })

        if real_time:
            return batch
        else:
            # push into batch storage
            self.batch_storage.push(batch)

        return ray_time

    # @profile
    def prepare_reward_value_gae_faster(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions,
                                        trained_steps):
        # value prefix (or reward), value
        batch_value_prefixes, batch_values = [], []
        extra = max(0, min(int(1 / (1 - self.td_lambda)), self.GAE_max_steps) - self.unroll_steps - 1)

        # init
        value_obs_lst, td_steps_lst, value_mask, policy_mask = [], [], [], []   # mask: 0 -> out of traj
        # policy_obs_lst, policy_mask = [], []
        zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)

        # get obs_{t+k}
        td_steps = 1
        for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
            traj_len = len(traj)

            # prepare the corresponding observations for bootstrapped values o_{t+k}
            traj_obs = traj.get_index_stacked_obs(state_index, extra=extra + td_steps)

            for current_index in range(state_index, state_index + self.unroll_steps + 1 + extra + td_steps):
                bootstrap_index = current_index

                if not self.episodic:
                    if bootstrap_index <= traj_len:
                        beg_index = bootstrap_index - state_index
                        end_index = beg_index + self.n_stack
                        obs = traj_obs[beg_index:end_index]
                        value_mask.append(1)
                        if bootstrap_index < traj_len:
                            policy_mask.append(1)
                        else:
                            policy_mask.append(0)
                    else:
                        value_mask.append(0)
                        policy_mask.append(0)
                        obs = np.asarray(zero_obs)
                else:
                    if bootstrap_index < traj_len:
                        beg_index = bootstrap_index - state_index
                        end_index = beg_index + self.n_stack
                        obs = traj_obs[beg_index:end_index]
                        value_mask.append(1)
                        policy_mask.append(1)
                    else:
                        value_mask.append(0)
                        policy_mask.append(0)
                        obs = np.asarray(zero_obs)

                value_obs_lst.append(obs)
                td_steps_lst.append(td_steps)

        # reanalyze the bootstrapped value v_{t+k}
        state_lst, value_lst, policy_lst = self.efficient_inference(value_obs_lst, only_value=False)

        # v_{t+k}
        batch_size = len(value_lst)
        value_lst = value_lst.reshape(-1) * (np.array([self.discount for _ in range(batch_size)]) ** td_steps_lst)
        value_lst = value_lst * np.array(value_mask)
        # value_lst = np.zeros_like(value_lst)  # for unit test, remove if training
        td_value_lst = copy.deepcopy(value_lst)
        value_lst = value_lst.tolist()
        td_value_lst = td_value_lst.tolist()

        re_state_lst, re_value_lst, re_policy_lst, re_policy_mask = [], [], [], []
        for i in range(len(state_lst)):
            if i % (self.unroll_steps + extra + 1 + td_steps) < self.unroll_steps + 1:
                re_state_lst.append(state_lst[i].unsqueeze(0))
                re_value_lst.append(value_lst[i])
                re_policy_lst.append(policy_lst[i].unsqueeze(0))
                re_policy_mask.append(policy_mask[i])
        re_state_lst = torch.cat(re_state_lst, dim=0)
        re_value_lst = np.asarray(re_value_lst)
        re_policy_lst = torch.cat(re_policy_lst, dim=0)

        # v_{t} = r + ... + gamma ^ k * v_{t+k}
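        # Off-policy correction used in the backward pass below: the effective td_lambda shrinks linearly
        # with the age of the sample (collected_transitions - idx), clipped to [0.65, td_lambda], so stale
        # trajectories rely less on their recorded rewards; for the 'mixed' / 'max' value targets the decay
        # is disabled.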
        value_index = 0
        td_lambdas = []
        for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
            traj_len = len(traj)
            target_values = []
            target_value_prefixs = []

            delta_lambda = 0.1 * (collected_transitions - idx) / self.auto_td_steps
            if self.value_target in ['mixed', 'max']:
                delta_lambda = 0.0
            td_lambda = self.td_lambda - delta_lambda
            td_lambda = np.clip(td_lambda, 0.65, self.td_lambda)
            td_lambdas.append(td_lambda)

            delta = np.zeros(self.unroll_steps + 1 + extra)
            advantage = np.zeros(self.unroll_steps + 1 + extra + 1)
            index = self.unroll_steps + extra
            for current_index in reversed(range(state_index, state_index + self.unroll_steps + 1 + extra)):
                bootstrap_index = current_index + td_steps_lst[value_index + index]
                for i, reward in enumerate(traj.reward_lst[current_index:bootstrap_index]):
                    td_value_lst[value_index + index + td_steps] += reward * self.discount ** i
                delta[index] = td_value_lst[value_index + index + td_steps] - value_lst[value_index + index]
                advantage[index] = delta[index] + self.discount * td_lambda * advantage[index + 1]
                index -= 1

            target_values_tmp = advantage[:self.unroll_steps + 1] \
                + np.asarray(value_lst)[value_index:value_index + self.unroll_steps + 1]

            horizon_id = 0
            value_prefix = 0.0
            for i, current_index in enumerate(range(state_index, state_index + self.unroll_steps + 1)):
                # reset every lstm_horizon_len
                if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
                    value_prefix = 0.0
                horizon_id += 1

                if current_index < traj_len:
                    # Since the horizon is small and the discount is close to 1.
                    # Compute the reward sum to approximate the value prefix for simplification
                    if self.value_prefix:
                        value_prefix += traj.reward_lst[current_index]
                    else:
                        value_prefix = traj.reward_lst[current_index]
                    target_value_prefixs.append(value_prefix)
                else:
                    target_value_prefixs.append(value_prefix)

                if self.episodic:
                    if current_index < traj_len:
                        target_values.append(target_values_tmp[i])
                    else:
                        target_values.append(0)
                else:
                    if current_index <= traj_len:
                        target_values.append(target_values_tmp[i])
                    else:
                        target_values.append(0)

            value_index += (self.unroll_steps + 1 + extra + td_steps)
            batch_value_prefixes.append(target_value_prefixs)
            batch_values.append(target_values)

        if self.rank == 0 and self.cnt % 20 == 0:
            print(f'--------------- lambda={np.asarray(td_lambdas).mean():.3f} -------------------')
        self.storage.add_log_scalar.remote({
            'batch_worker/td_lambda': np.asarray(td_lambdas).mean()
        })

        value_index = 0
        value_masks, policy_masks = [], []
        for i, idx in enumerate(indices_lst):
            value_masks.append(int(idx > collected_transitions - self.mixed_value_threshold))
            value_index += (self.unroll_steps + 1 + extra)
        value_masks = np.asarray(value_masks)

        return np.asarray(batch_value_prefixes), np.asarray(batch_values), np.asarray(td_steps_lst).flatten(), \
            (re_state_lst, re_value_lst, re_policy_lst, re_policy_mask), value_masks

    # @profile
    def prepare_reward_value_gae(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions,
                                 trained_steps):
        # value prefix (or reward), value
        batch_value_prefixes, batch_values = [], []
        extra = max(0, min(int(1 / (1 - self.td_lambda)), self.GAE_max_steps) - self.unroll_steps - 1)

        # init
        value_obs_lst, td_steps_lst, value_mask = [], [], []    # mask: 0 -> out of traj
        policy_obs_lst, policy_mask = [], []
        zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)

        # get obs_{t+k}
        for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
            traj_len = len(traj)
            td_steps = 1
            # prepare the corresponding observations for bootstrapped values o_{t+k}
            traj_obs = traj.get_index_stacked_obs(state_index + td_steps, extra=extra)
            game_obs = traj.get_index_stacked_obs(state_index, extra=extra)

            for current_index in range(state_index, state_index + self.unroll_steps + 1 + extra):
                bootstrap_index = current_index + td_steps

                if not self.episodic:
                    if bootstrap_index <= traj_len:
                        value_mask.append(1)
                        beg_index = bootstrap_index - (state_index + td_steps)
                        end_index = beg_index + self.n_stack
                        obs = traj_obs[beg_index:end_index]
                    else:
                        value_mask.append(0)
                        obs = np.asarray(zero_obs)
                else:
                    if bootstrap_index < traj_len:
                        value_mask.append(1)
                        beg_index = bootstrap_index - (state_index + td_steps)
                        end_index = beg_index + self.n_stack
                        obs = traj_obs[beg_index:end_index]
                    else:
                        value_mask.append(0)
                        obs = np.asarray(zero_obs)

                value_obs_lst.append(obs)
                td_steps_lst.append(td_steps)

                if current_index < traj_len:
                    policy_mask.append(1)
                    beg_index = current_index - state_index
                    end_index = beg_index + self.n_stack
                    obs = game_obs[beg_index:end_index]
                else:
                    policy_mask.append(0)
                    obs = np.asarray(zero_obs)
                policy_obs_lst.append(obs)

        # reanalyze the bootstrapped value v_{t+k}
        _, value_lst, _ = self.efficient_inference(value_obs_lst, only_value=True)
        state_lst, ori_cur_value_lst, policy_lst = self.efficient_inference(policy_obs_lst, only_value=False)

        # v_{t+k}
        batch_size = len(value_lst)
        value_lst = value_lst.reshape(-1) * (np.array([self.discount for _ in range(batch_size)]) ** td_steps_lst)
        value_lst = value_lst * np.array(value_mask)
        # value_lst = np.zeros_like(value_lst)  # for unit test, remove if training
        value_lst = value_lst.tolist()

        cur_value_lst = ori_cur_value_lst.reshape(-1) * np.array(policy_mask)
        # cur_value_lst = np.zeros_like(cur_value_lst)  # for unit test, remove if training
        cur_value_lst = cur_value_lst.tolist()

        state_lst_cut, ori_cur_value_lst_cut, policy_lst_cut, policy_mask_cut = [], [], [], []
        for i in range(len(state_lst)):
            if i % (self.unroll_steps + extra + 1) < self.unroll_steps + 1:
                state_lst_cut.append(state_lst[i].unsqueeze(0))
                ori_cur_value_lst_cut.append(ori_cur_value_lst[i])
                policy_lst_cut.append(policy_lst[i].unsqueeze(0))
                policy_mask_cut.append(policy_mask[i])
        state_lst_cut = torch.cat(state_lst_cut, dim=0)
        ori_cur_value_lst_cut = np.asarray(ori_cur_value_lst_cut)
        policy_lst_cut = torch.cat(policy_lst_cut, dim=0)

        # v_{t} = r + ... + gamma ^ k * v_{t+k}
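        # The backward pass below accumulates lambda-returns from 1-step TD errors:
        #   delta_t  = r_t + gamma * v_{t+1} - v_t
        #   A_t      = delta_t + gamma * td_lambda * A_{t+1}
        #   target_t = A_t + v_t
        # computed over unroll_steps + 1 + extra positions per sample.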
        value_index = 0
        td_lambdas = []
        for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
            traj_len = len(traj)
            target_values = []
            target_value_prefixs = []

            delta_lambda = 0.1 * (collected_transitions - idx) / self.auto_td_steps
            if self.value_target in ['mixed', 'max']:
                delta_lambda = 0.0
            td_lambda = self.td_lambda - delta_lambda
            td_lambda = np.clip(td_lambda, 0.65, self.td_lambda)
            td_lambdas.append(td_lambda)

            delta = np.zeros(self.unroll_steps + 1 + extra)
            advantage = np.zeros(self.unroll_steps + 1 + extra + 1)
            index = self.unroll_steps + extra
            for current_index in reversed(range(state_index, state_index + self.unroll_steps + 1 + extra)):
                bootstrap_index = current_index + td_steps_lst[value_index + index]
                for i, reward in enumerate(traj.reward_lst[current_index:bootstrap_index]):
                    value_lst[value_index + index] += reward * self.discount ** i
                delta[index] = value_lst[value_index + index] - cur_value_lst[value_index + index]
                advantage[index] = delta[index] + self.discount * td_lambda * advantage[index + 1]
                index -= 1

            target_values_tmp = advantage[:self.unroll_steps + 1] \
                + np.asarray(cur_value_lst)[value_index:value_index + self.unroll_steps + 1]

            horizon_id = 0
            value_prefix = 0.0
            for i, current_index in enumerate(range(state_index, state_index + self.unroll_steps + 1)):
                # reset every lstm_horizon_len
                if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
                    value_prefix = 0.0
                horizon_id += 1

                if current_index < traj_len:
                    # Since the horizon is small and the discount is close to 1.
                    # Compute the reward sum to approximate the value prefix for simplification
                    if self.value_prefix:
                        value_prefix += traj.reward_lst[current_index]
                    else:
                        value_prefix = traj.reward_lst[current_index]
                    target_value_prefixs.append(value_prefix)
                else:
                    target_value_prefixs.append(value_prefix)

                if self.episodic:
                    if current_index < traj_len:
                        target_values.append(target_values_tmp[i])
                    else:
                        target_values.append(0)
                else:
                    if current_index <= traj_len:
                        target_values.append(target_values_tmp[i])
                    else:
                        target_values.append(0)

            value_index += (self.unroll_steps + 1 + extra)
            batch_value_prefixes.append(target_value_prefixs)
            batch_values.append(target_values)

        if self.rank == 0 and self.cnt % 20 == 0:
            print(f'--------------- lambda={np.asarray(td_lambdas).mean():.3f} -------------------')
        self.storage.add_log_scalar.remote({
            'batch_worker/td_lambda': np.asarray(td_lambdas).mean()
        })

        value_index = 0
        value_masks, policy_masks = [], []
        for i, idx in enumerate(indices_lst):
            value_masks.append(int(idx > collected_transitions - self.mixed_value_threshold))
            value_index += (self.unroll_steps + 1 + extra)
        value_masks = np.asarray(value_masks)

        return np.asarray(batch_value_prefixes), np.asarray(batch_values), np.asarray(td_steps_lst).flatten(), \
            (state_lst_cut, ori_cur_value_lst_cut, policy_lst_cut, policy_mask_cut), value_masks

    def prepare_reward(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps):
        # value prefix (or reward), value
        batch_value_prefixes = []

        # v_{t} = r + ... + gamma ^ k * v_{t+k}
        value_index = 0
        top_value_masks = []
        for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
            traj_len = len(traj)
            target_value_prefixs = []

            horizon_id = 0
            value_prefix = 0.0
            top_value_masks.append(int(idx > collected_transitions - self.config.train.start_use_mix_training_steps))
            for current_index in range(state_index, state_index + self.unroll_steps + 1):
                # reset every lstm_horizon_len
                if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
                    value_prefix = 0.0
                horizon_id += 1

                if current_index < traj_len:
                    # Since the horizon is small and the discount is close to 1.
                    # Compute the reward sum to approximate the value prefix for simplification
                    if self.value_prefix:
                        value_prefix += traj.reward_lst[current_index]
                    else:
                        value_prefix = traj.reward_lst[current_index]
                    target_value_prefixs.append(value_prefix)
                else:
                    target_value_prefixs.append(value_prefix)

                value_index += 1

            batch_value_prefixes.append(target_value_prefixs)

        value_masks = np.asarray(top_value_masks)
        batch_value_prefixes = np.asarray(batch_value_prefixes)
        batch_values = np.zeros_like(batch_value_prefixes)
        td_steps_lst = np.ones_like(batch_value_prefixes)
        return batch_value_prefixes, np.asarray(batch_values), td_steps_lst.flatten(), \
            (None, None, None, None), value_masks

    def prepare_reward_value(self, traj_lst, transition_pos_lst, indices_lst, collected_transitions, trained_steps):
        # value prefix (or reward), value
        batch_value_prefixes, batch_values = [], []
        # search_values = []

        # init
        value_obs_lst, td_steps_lst, value_mask = [], [], []    # mask: 0 -> out of traj
        zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)

        # get obs_{t+k}
        for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
            traj_len = len(traj)

            # off-policy correction: shorter horizon of td steps
            delta_td = (collected_transitions - idx) // self.auto_td_steps
            if self.value_target in ['mixed', 'max']:
                delta_td = 0
            td_steps = self.td_steps - delta_td
            # td_steps = self.td_steps  # for test off-policy issue
            if not self.episodic:
                td_steps = min(traj_len - state_index, td_steps)
            td_steps = np.clip(td_steps, 1, self.td_steps).astype(np.int32)

            obs_idx = state_index + td_steps
            # prepare the corresponding observations for bootstrapped values o_{t+k}
            traj_obs = traj.get_index_stacked_obs(state_index + td_steps)

            for current_index in range(state_index, state_index + self.unroll_steps + 1):
                if not self.episodic:
                    td_steps = min(traj_len - current_index, td_steps)
                    td_steps = max(td_steps, 1)

                bootstrap_index = current_index + td_steps

                if not self.episodic:
                    if bootstrap_index <= traj_len:
                        value_mask.append(1)
                        beg_index = bootstrap_index - obs_idx
                        end_index = beg_index + self.n_stack
                        obs = traj_obs[beg_index:end_index]
                    else:
                        value_mask.append(0)
                        obs = zero_obs
                else:
                    if bootstrap_index < traj_len:
                        value_mask.append(1)
                        beg_index = bootstrap_index - (state_index + td_steps)
                        end_index = beg_index + self.n_stack
                        obs = traj_obs[beg_index:end_index]
                    else:
                        value_mask.append(0)
                        obs = zero_obs

                value_obs_lst.append(obs)
                td_steps_lst.append(td_steps)

        # reanalyze the bootstrapped value v_{t+k}
        state_lst, value_lst, policy_lst = self.efficient_inference(value_obs_lst, only_value=True)

        batch_size = len(value_lst)
        value_lst = value_lst.reshape(-1) * (np.array([self.discount for _ in range(batch_size)]) ** td_steps_lst)
        value_lst = value_lst * np.array(value_mask)
        # value_lst = np.zeros_like(value_lst)  # for unit test, remove if training
        value_lst = value_lst.tolist()

        # v_{t} = r + ... + gamma ^ k * v_{t+k}
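        # n-step bootstrapped target built below:
        #   z_t = r_t + gamma * r_{t+1} + ... + gamma^{k-1} * r_{t+k-1} + gamma^k * v_{t+k}
        # where k = td_steps_lst[...] was already shortened for stale samples (off-policy correction)
        # and v_{t+k} was pre-multiplied by gamma^k and masked above.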
        value_index = 0
        top_value_masks = []
        for traj, state_index, idx in zip(traj_lst, transition_pos_lst, indices_lst):
            traj_len = len(traj)
            target_values = []
            target_value_prefixs = []

            horizon_id = 0
            value_prefix = 0.0
            top_value_masks.append(int(idx > collected_transitions - self.mixed_value_threshold))
            for current_index in range(state_index, state_index + self.unroll_steps + 1):
                bootstrap_index = current_index + td_steps_lst[value_index]
                for i, reward in enumerate(traj.reward_lst[current_index:bootstrap_index]):
                    value_lst[value_index] += reward * self.discount ** i

                # reset every lstm_horizon_len
                if horizon_id % self.lstm_horizon_len == 0 and self.value_prefix:
                    value_prefix = 0.0
                horizon_id += 1

                if current_index < traj_len:
                    # Since the horizon is small and the discount is close to 1.
                    # Compute the reward sum to approximate the value prefix for simplification
                    if self.value_prefix:
                        value_prefix += traj.reward_lst[current_index]
                    else:
                        value_prefix = traj.reward_lst[current_index]
                    target_value_prefixs.append(value_prefix)
                else:
                    target_value_prefixs.append(value_prefix)

                if self.episodic:
                    if current_index < traj_len:
                        target_values.append(value_lst[value_index])
                    else:
                        target_values.append(0)
                else:
                    if current_index <= traj_len:
                        target_values.append(value_lst[value_index])
                    else:
                        target_values.append(0)

                value_index += 1

            batch_value_prefixes.append(target_value_prefixs)
            batch_values.append(target_values)

        value_masks = np.asarray(top_value_masks)
        return np.asarray(batch_value_prefixes), np.asarray(batch_values), np.asarray(td_steps_lst).flatten(), \
            (None, None, None, None), value_masks

    def prepare_policy_non_reanalyze(self, traj_lst, transition_pos_lst):
        # policy
        batch_policies = []

        # load searched policy in self-play
        for traj, state_index in zip(traj_lst, transition_pos_lst):
            traj_len = len(traj)
            target_policies = []

            for current_index in range(state_index, state_index + self.unroll_steps + 1):
                if current_index < traj_len:
                    target_policies.append(traj.policy_lst[current_index])
                else:
                    target_policies.append([0 for _ in range(self.action_space_size)])
            batch_policies.append(target_policies)
        return batch_policies

    def prepare_policy_reanalyze(self, trained_steps, traj_lst, transition_pos_lst, indices_lst,
                                 state_lst=None, value_lst=None, policy_lst=None, policy_mask=None):
        # policy
        reanalyzed_values = []
        batch_policies = []

        # init
        if value_lst is None:
            policy_obs_lst, policy_mask = [], []    # mask: 0 -> out of traj
            zero_obs = traj_lst[0].get_zero_obs(self.n_stack, channel_first=False)

            # get obs_{t} instead of obs_{t+k}
            for traj, state_index in zip(traj_lst, transition_pos_lst):
                traj_len = len(traj)
                game_obs = traj.get_index_stacked_obs(state_index)

                for current_index in range(state_index, state_index + self.unroll_steps + 1):
                    if current_index < traj_len:
                        policy_mask.append(1)
                        beg_index = current_index - state_index
                        end_index = beg_index + self.n_stack
                        obs = game_obs[beg_index:end_index]
                    else:
                        policy_mask.append(0)
                        obs = np.asarray(zero_obs)
                    policy_obs_lst.append(obs)

            # reanalyze the search policy pi_{t}
            state_lst, value_lst, policy_lst = self.efficient_inference(policy_obs_lst, only_value=False)

        # tree search for policies
        batch_size = len(state_lst)
        # temperature
        temperature = self.agent.get_temperature(trained_steps=trained_steps)   # * np.ones((batch_size, 1))
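        # Re-run search with the current model to refresh the policy targets: Gumbel search (or the
        # original MCTS) over the discrete action set for Atari, and sampled-action continuous search
        # for DMC / Gym, which additionally returns the sampled actions and the best-action indexes.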
        tree = mcts.names[self.config.mcts.language](
            num_actions=self.action_space_size if self.env == 'Atari' else self.config.mcts.num_sampled_actions,
            discount=self.config.rl.discount,
            env=self.env,
            **self.config.mcts,     # pass mcts related params
            **self.config.model,    # pass the value and reward support params
        )
        if self.env == 'Atari':
            if self.config.mcts.use_gumbel:
                r_values, r_policies, best_actions, _ = tree.search(
                    self.model,     # self.latest_model,
                    batch_size, state_lst, value_lst, policy_lst,
                    use_gumble_noise=True, temperature=temperature
                )
            else:
                r_values, r_policies, best_actions, _ = tree.search_ori_mcts(
                    self.model, batch_size, state_lst, value_lst, policy_lst,
                    use_noise=True, temperature=temperature, is_reanalyze=True
                )
            sampled_actions = best_actions
            search_best_indexes = best_actions
        else:
            r_values, r_policies, best_actions, sampled_actions, search_best_indexes, _ = tree.search_continuous(
                self.model, batch_size, state_lst, value_lst, policy_lst,
                temperature=temperature,
            )

        if self.config.train.optimal_Q:
            r_values = self.efficient_recurrent(state_lst, policy_lst)

        r_values = r_values.reshape(-1) * np.array(policy_mask)
        r_values = r_values.tolist()

        # concat policy
        policy_index = 0
        policy_masks = []
        mismatch_index = []
        for traj, state_index, ind in zip(traj_lst, transition_pos_lst, indices_lst):
            target_policies = []
            search_values = []
            policy_masks.append([])

            for current_index in range(state_index, state_index + self.unroll_steps + 1):
                traj_len = len(traj)
                assert (current_index < traj_len) == (policy_mask[policy_index])

                if policy_mask[policy_index]:
                    target_policies.append(r_policies[policy_index])
                    search_values.append(r_values[policy_index])
                    # mask best-action & pi_prime mismatches
                    if r_policies[policy_index].argmax() != search_best_indexes[policy_index]:
                        policy_mask[policy_index] = 0
                        mismatch_index.append(ind + current_index - state_index)
                else:
                    search_values.append(0.0)
                    if self.env in ['DMC', 'Gym']:
                        target_policies.append([0 for _ in range(sampled_actions.shape[1])])
                    else:
                        target_policies.append([0 for _ in range(self.action_space_size)])

                policy_masks[-1].append(policy_mask[policy_index])
                policy_index += 1

            batch_policies.append(target_policies)
            reanalyzed_values.append(search_values)

        if self.rank == 0 and self.config.eval.analysis_value:
            new_log_index = trained_steps // 5000
            if new_log_index > self.last_log_index:
                self.last_log_index = new_log_index
                min_idx = np.asarray(indices_lst).argmin()
                r_value = reanalyzed_values[min_idx][0]
                self.storage.add_log_scalar.remote({
                    'batch_worker/search_value': r_value
                })

        policy_masks = np.asarray(policy_masks)
        return batch_policies, sampled_actions, best_actions, reanalyzed_values, \
            (state_lst, value_lst, policy_lst, policy_mask), policy_masks

    @torch.no_grad()
    def imagine_episodes(self, pre_lst, traj_lst, transition_pos_lst, trained_steps, policy='search'):
        length = 1
        times = 3
        # input_obs = np.concatenate([stack_obs for _ in range(times)], axis=0)
        # states, values, policies = self.efficient_inference(input_obs)
        states, values, policies, policy_mask = pre_lst
        states = torch.cat([states for _ in range(times)], dim=0)
        values = np.concatenate([values for _ in range(times)], axis=0)
        policies = torch.cat([policies for _ in range(times)], dim=0)
        reward_hidden = (torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda(),
                         torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda())
        last_values_prefixes = np.zeros(len(states))
        reward_lst = []
        value_lst = []
        temperature = self.agent.get_temperature(trained_steps=trained_steps) * np.ones((len(states), 1))
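        # Roll the learned model forward for `length` imagined steps from `times` copies of each root
        # state, picking actions either with a fresh search or from the current policy head, then average
        # the discounted imagined returns over the copies to obtain a model-based value estimate.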
        for i in range(length):
            if policy == 'search':
                tree = mcts.names[self.config.mcts.language](
                    num_actions=self.config.env.action_space_size if self.env == 'Atari' else self.config.mcts.num_top_actions,
                    discount=self.config.rl.discount,
                    **self.config.mcts,     # pass mcts related params
                    **self.config.model,    # pass the value and reward support params
                )
                if self.env == 'Atari':
                    if self.config.mcts.use_gumbel:
                        r_values, r_policies, best_actions, _ = tree.search(
                            self.model, len(states), states, values, policies,
                            use_gumble_noise=True, temperature=temperature
                        )
                    else:
                        r_values, r_policies, best_actions, _ = tree.search_ori_mcts(
                            self.model, len(states), states, values, policies,
                            use_noise=True, temperature=temperature, is_reanalyze=True
                        )
                else:
                    r_values, r_policies, best_actions, sampled_actions, _, _ = tree.search_continuous(
                        self.model, len(states), states, values, policies,
                        use_gumble_noise=False, temperature=temperature)

            if policy == 'search':
                actions = torch.from_numpy(np.asarray(best_actions)).cuda().float()
            else:
                if self.env == 'Atari':
                    actions = F.gumbel_softmax(policies, hard=True, dim=-1, tau=1e-4)
                    actions = actions.argmax(dim=-1)
                else:
                    actions = policies[:, :policies.shape[-1] // 2]
                    actions = actions.unsqueeze(1)

            with autocast():
                states, value_prefixes, values, policies, reward_hidden = \
                    self.model.recurrent_inference(states, actions, reward_hidden)

            values = values.squeeze().detach().cpu().numpy()
            value_lst.append(values)

            if self.value_prefix and (i + 1) % self.lstm_horizon_len == 0:
                reward_hidden = (torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda(),
                                 torch.zeros(1, len(states), self.config.model.lstm_hidden_size).cuda())
                true_rewards = value_prefixes.squeeze().detach().cpu().numpy()
                # last_values_prefixes = np.zeros(len(states))
            else:
                true_rewards = value_prefixes.squeeze().detach().cpu().numpy() - last_values_prefixes
                last_values_prefixes = value_prefixes.squeeze().detach().cpu().numpy()
            reward_lst.append(true_rewards)

        value = 0
        for i, reward in enumerate(reward_lst):
            value += reward * (self.config.rl.discount ** i)
        value += (self.config.rl.discount ** length) * value_lst[-1]

        value_reshaped = []
        batch_size = len(states) // times
        for i in range(times):
            value_reshaped.append(value[batch_size * i:batch_size * (i + 1)])
        value_reshaped = np.asarray(value_reshaped).mean(0)

        output_values = []
        policy_index = 0
        for traj, state_index in zip(traj_lst, transition_pos_lst):
            imagined_values = []
            for current_index in range(state_index, state_index + self.unroll_steps + 1):
                traj_len = len(traj)
                # assert (current_index < traj_len) == (policy_mask[policy_index])

                if policy_mask[policy_index]:
                    imagined_values.append(value_reshaped[policy_index])
                else:
                    imagined_values.append(0.0)
                policy_index += 1
            output_values.append(imagined_values)

        return np.asarray(output_values)

    def efficient_inference(self, obs_lst, only_value=False, value_idx=0):
        batch_size = len(obs_lst)
        obs_lst = np.asarray(obs_lst)
        state_lst, value_lst, policy_lst = [], [], []

        # split a full batch into slices of mini_infer_size
        mini_batch = self.config.train.mini_batch_size
        slices = np.ceil(batch_size / mini_batch).astype(np.int32)
        with torch.no_grad():
            for i in range(slices):
                beg_index = mini_batch * i
                end_index = mini_batch * (i + 1)
                current_obs = obs_lst[beg_index:end_index]
                current_obs = formalize_obs_lst(current_obs, self.image_based)
                # obtain the statistics at current steps
                with autocast():
                    states, values, policies = self.model.initial_inference(current_obs)

                # process outputs
                values = values.detach().cpu().numpy().flatten()

                # concat
                value_lst.append(values)
                if not only_value:
                    state_lst.append(states)
                    policy_lst.append(policies)

        value_lst = np.concatenate(value_lst)
        if not only_value:
            state_lst = torch.cat(state_lst)
            policy_lst = torch.cat(policy_lst)
        return state_lst, value_lst, policy_lst

# ======================================================================================================================
# batch worker
# ======================================================================================================================
def start_batch_worker(rank, agent, replay_buffer, storage, batch_storage, config):
    """
    Start a GPU batch worker. Call this method remotely.
    """
    worker = BatchWorker.remote(rank, agent, replay_buffer, storage, batch_storage, config)
    print(f"[Batch worker GPU] Starting batch worker GPU {rank} at process {os.getpid()}.")
    worker.run.remote()


def start_batch_worker_cpu(rank, agent, replay_buffer, storage, prebatch_storage, config):
    worker = BatchWorker_CPU.remote(rank, agent, replay_buffer, storage, prebatch_storage, config)
    print(f"[Batch worker CPU] Starting batch worker CPU {rank} at process {os.getpid()}.")
    worker.run.remote()


def start_batch_worker_gpu(rank, agent, replay_buffer, storage, prebatch_storage, batch_storage, config):
    worker = BatchWorker_GPU.remote(rank, agent, replay_buffer, storage, prebatch_storage, batch_storage, config)
    print(f"[Batch worker GPU] Starting batch worker GPU {rank} at process {os.getpid()}.")
    worker.run.remote()
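

# ----------------------------------------------------------------------------------------------------------------------
# Illustrative sketch (not called by the worker): the lambda-return recursion that prepare_reward_value_gae
# unrolls above, written as a standalone NumPy function under the assumption of 1-step TD errors. The
# function name and arguments are hypothetical and only document the arithmetic:
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t),  A_t = delta_t + gamma * lambda * A_{t+1},  target_t = A_t + V(s_t)
# ----------------------------------------------------------------------------------------------------------------------
def _lambda_return_sketch(rewards, values, bootstrap_value, discount=0.997, td_lambda=0.95):
    rewards = np.asarray(rewards, dtype=np.float64)
    values_ext = np.append(np.asarray(values, dtype=np.float64), bootstrap_value)
    targets = np.zeros(len(rewards), dtype=np.float64)
    advantage = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + discount * values_ext[t + 1] - values_ext[t]   # 1-step TD error
        advantage = delta + discount * td_lambda * advantage                # GAE-style accumulation
        targets[t] = advantage + values_ext[t]                              # lambda-return value target
    return targets
# Example usage: _lambda_return_sketch([1., 0., 1.], [0.5, 0.4, 0.3], 0.2) returns one value target per step.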