EfficientZeroV2/ez/data/trajectory.py

# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import copy
import torch
import numpy as np
import ray
from ez.utils.format import str_to_arr
class GameTrajectory:
    def __init__(self, **kwargs):
        """A single game trajectory (history block) of transitions and search statistics.

        Configured via keyword arguments such as n_stack, discount, obs_to_string,
        gray_scale, unroll_steps, td_steps, td_lambda, obs_shape, trajectory_size,
        image_based, episodic and GAE_max_steps.
        """
        # self.raw_obs_lst = []
self.obs_lst = []
self.reward_lst = []
self.policy_lst = []
self.action_lst = []
self.pred_value_lst = []
self.search_value_lst = []
self.bootstrapped_value_lst = []
self.mc_value_lst = []
self.temp_obs = []
self.snapshot_lst = []
self.n_stack = kwargs.get('n_stack')
self.discount = kwargs.get('discount')
self.obs_to_string = kwargs.get('obs_to_string')
self.gray_scale = kwargs.get('gray_scale')
self.unroll_steps = kwargs.get('unroll_steps')
self.td_steps = kwargs.get('td_steps')
self.td_lambda = kwargs.get('td_lambda')
self.obs_shape = kwargs.get('obs_shape')
self.max_size = kwargs.get('trajectory_size')
self.image_based = kwargs.get('image_based')
self.episodic = kwargs.get('episodic')
self.GAE_max_steps = kwargs.get('GAE_max_steps')
    def init(self, init_frames):
        # pre-fill the observation list with n_stack initial frames so that
        # obs_lst[i : i + n_stack] is the stacked observation at time step i
        assert len(init_frames) == self.n_stack
        for obs in init_frames:
            self.obs_lst.append(copy.deepcopy(obs))
def append(self, action, obs, reward):
assert self.__len__() <= self.max_size
# append a transition tuple
self.action_lst.append(action)
self.obs_lst.append(obs)
self.reward_lst.append(reward)
    def pad_over(self, tail_obs, tail_rewards, tail_pred_values, tail_search_values, tail_policies):
        """To ensure the correctness of value targets, we need to append (o_t, r_t, etc.) from
        the next history block, which are necessary to bootstrap the values at the end
        states of this history block.
        E.g.: len = 100; target value v_100 = r_100 + gamma r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
        but r_101, r_102, ... come from the next history block.
        Parameters
        ----------
        tail_obs: list
            tail o_t from the next trajectory block
        tail_rewards: list
            tail r_t from the next trajectory block
        tail_pred_values: list
            tail v_t from the next trajectory block (predicted by the network)
        tail_search_values: list
            tail v_t from the next trajectory block (obtained by MCTS search)
        tail_policies: list
            tail pi_t from the next trajectory block
        """
        assert len(tail_obs) <= self.unroll_steps
        assert len(tail_policies) <= self.unroll_steps
        # assert len(tail_search_values) <= self.unroll_steps + self.td_steps
        # assert len(tail_rewards) <= self.unroll_steps + self.td_steps - 1
        # note: the next-block observations should start from (stacked_observation - 1) in the next trajectory
for obs in tail_obs:
self.obs_lst.append(copy.deepcopy(obs))
for reward in tail_rewards:
self.reward_lst.append(reward)
for pred_value in tail_pred_values:
self.pred_value_lst.append(pred_value)
for search_value in tail_search_values:
self.search_value_lst.append(search_value)
for policy in tail_policies:
self.policy_lst.append(policy)
        # calculate the bootstrapped (n-step) value targets now that the tail rewards/values are available
        self.bootstrapped_value_lst = self.get_bootstrapped_value(value_type='prediction')
        # GAE value targets are computed later, on demand, via get_gae_value
    def save_to_memory(self):
        """
        Post-process the data once a history block is full.
        """
        # convert to numpy; the (potentially large) observation array is moved into the Ray object store
        self.obs_lst = ray.put(np.array(self.obs_lst))
        # self.obs_lst = np.array(self.obs_lst)
self.reward_lst = np.array(self.reward_lst)
self.policy_lst = np.array(self.policy_lst)
self.action_lst = np.array(self.action_lst)
self.pred_value_lst = np.array(self.pred_value_lst)
self.search_value_lst = np.array(self.search_value_lst)
self.bootstrapped_value_lst = np.array(self.bootstrapped_value_lst)
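
    # After save_to_memory, obs_lst is a Ray ObjectRef rather than a Python list,
    # which is why the observation getters below read it through ray.get.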
    def make_target(self, index):
        # slice unroll_steps + 1 consecutive training targets starting from `index`
        assert index < self.__len__()
target_obs = self.get_index_stacked_obs(index)
target_reward = self.reward_lst[index:index+self.unroll_steps+1]
target_pred_value = self.pred_value_lst[index:index+self.unroll_steps+1]
target_search_value = self.search_value_lst[index:index+self.unroll_steps+1]
target_bt_value = self.bootstrapped_value_lst[index:index+self.unroll_steps+1]
target_policy = self.policy_lst[index:index+self.unroll_steps+1]
assert len(target_reward) == len(target_pred_value) == len(target_search_value) == len(target_policy)
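        # Resulting shapes: target_obs stacks n_stack + unroll_steps frames, while the
        # reward/value/policy targets each cover unroll_steps + 1 positions
        # (the root step plus one per unrolled model step).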
return np.array(target_obs), np.array(target_reward), np.array(target_pred_value), np.array(target_search_value), np.array(target_bt_value), np.array(target_policy)
def store_search_results(self, pred_value, search_value, policy, idx: int = None):
# store the visit count distributions and value of the root node after MCTS
if idx is None:
self.pred_value_lst.append(pred_value)
self.search_value_lst.append(search_value)
self.policy_lst.append(policy)
else:
self.pred_value_lst[idx] = pred_value
self.search_value_lst[idx] = search_value
self.policy_lst[idx] = policy
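
    # Note: the `idx` branch above lets stored root statistics be refreshed in
    # place, e.g. by a reanalyze-style worker that recomputes stale values and
    # policies (an assumption about the caller, not enforced here).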
    def get_gae_value(self, value_type='prediction', value_lst_input=None, index=None, collected_transitions=None):
        # compute TD(lambda)-style value targets via generalized advantage estimation (GAE),
        # built from 1-step TD errors
        td_steps = 1
        assert value_type in ['prediction', 'search']
if value_type == 'prediction':
value_lst = self.pred_value_lst
else:
value_lst = self.search_value_lst
if value_lst_input is not None:
value_lst = value_lst_input
traj_len = self.__len__()
assert traj_len
        if index is None:
            td_lambda = self.td_lambda
        else:
            # anneal lambda for stale transitions: the older the transition relative to
            # the total number collected so far, the smaller the lambda, clipped to [0.65, td_lambda]
            delta_lambda = 0.1 * (collected_transitions - index) / 3e4
            td_lambda = self.td_lambda - delta_lambda
            td_lambda = np.clip(td_lambda, 0.65, self.td_lambda)
lens = self.GAE_max_steps
gae_values = np.zeros(traj_len)
for i in range(traj_len):
delta = np.zeros(lens)
advantage = np.zeros(lens + 1)
            cnt = lens - 1  # fill delta/advantage from the back of the window
for idx in reversed(range(i, i + lens)):
if idx + td_steps < traj_len:
bt_value = (self.discount ** td_steps) * value_lst[idx + td_steps]
for n in range(td_steps):
bt_value += (self.discount ** n) * self.reward_lst[idx + n]
else:
steps = idx + td_steps - len(value_lst) + 1
bt_value = 0 if self.episodic else (self.discount ** (td_steps - steps)) * value_lst[-1]
for n in range(td_steps - steps):
assert idx + n < len(self.reward_lst)
bt_value += (self.discount ** n) * self.reward_lst[idx + n]
                try:
                    delta[cnt] = bt_value - value_lst[idx]
                except IndexError:
                    # the window extends past the stored values; treat the TD error as zero
                    delta[cnt] = 0
advantage[cnt] = delta[cnt] + self.discount * td_lambda * advantage[cnt + 1]
cnt -= 1
            value_window = np.asarray(value_lst)[i:i + lens]
            value_lst_tmp = np.concatenate((value_window, np.zeros(lens - value_window.shape[0])))
gae_values_tmp = advantage[:lens] + value_lst_tmp
gae_values[i] = gae_values_tmp[0]
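            # advantage[0] is the GAE advantage A_i = sum_k (gamma * lambda)^k * delta_{i+k},
            # where delta_t = r_t + gamma * v_{t+1} - v_t (td_steps = 1), so
            # gae_values[i] = v_i + A_i is a TD(lambda)-style value target.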
return gae_values
    def get_bootstrapped_value(self, value_type='prediction', value_lst_input=None, index=None, collected_transitions=None):
        # compute n-step bootstrapped value targets:
        #   z_t = sum_{n=0}^{k-1} gamma^n * r_{t+n} + gamma^k * v_{t+k}
        assert value_type in ['prediction', 'search']
if value_type == 'prediction':
value_lst = self.pred_value_lst
else:
value_lst = self.search_value_lst
if value_lst_input is not None:
value_lst = value_lst_input
bt_values = []
traj_len = self.__len__()
assert traj_len
        if index is None:
            td_steps = self.td_steps
        else:
            # anneal the bootstrap horizon for stale transitions: older transitions
            # use fewer td steps, clipped to [1, td_steps]
            delta_td = (collected_transitions - index) // 3e4
            td_steps = self.td_steps - delta_td
            td_steps = np.clip(td_steps, 1, self.td_steps).astype(np.int32)
for idx in range(traj_len):
if idx + td_steps < len(value_lst):
bt_value = (self.discount ** td_steps) * value_lst[idx + td_steps]
for n in range(td_steps):
bt_value += (self.discount ** n) * self.reward_lst[idx + n]
else:
steps = idx + td_steps - len(value_lst) + 1
bt_value = 0 if self.episodic else (self.discount ** (td_steps - steps)) * value_lst[-1]
for n in range(td_steps - steps):
assert idx + n < len(self.reward_lst)
bt_value += (self.discount ** n) * self.reward_lst[idx + n]
bt_values.append(bt_value)
return bt_values
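
    # Worked example: with gamma = 0.5, td_steps = 2, rewards r_t = r_{t+1} = 1 and
    # v_{t+2} = 3, the target is z_t = 1 + 0.5 * 1 + 0.25 * 3 = 2.25.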
    def get_zero_obs(self, n_stack, channel_first=True):
        # build a stack of constant placeholder frames (note: filled with ones despite the name)
        if self.image_based:
if channel_first:
return [np.ones(self.obs_shape, dtype=np.uint8) for _ in
range(n_stack)]
else:
return [np.ones((self.obs_shape[1], self.obs_shape[2], self.obs_shape[0]), dtype=np.uint8)
for _ in range(n_stack)]
else:
return [np.ones(self.obs_shape, dtype=np.float32) for _ in range(n_stack)]
    def get_current_stacked_obs(self):
        # return the current stacked observation in the correct format for model inference
index = len(self.reward_lst)
frames = ray.get(self.obs_lst)[index:index + self.n_stack]
# frames = self.obs_lst[index:index + self.n_stack]
if self.obs_to_string:
frames = [str_to_arr(obs, self.gray_scale) for obs in frames]
return frames
    def get_index_stacked_obs(self, index, padding=False, extra=0):
        """Obtain observations of the correct format: o[index : index + stack frames + unroll + extra]
        Parameters
        ----------
        index: int
            time step index
        padding: bool
            True -> pad with the last frame if the window runs past the end of the trajectory
        extra: int
            extra length appended to the unroll window
        """
unroll_steps = self.unroll_steps + extra
frames = ray.get(self.obs_lst)[index:index + self.n_stack + unroll_steps]
# frames = self.obs_lst[index:index + self.n_stack + unroll_steps]
if padding:
pad_len = self.n_stack + unroll_steps - len(frames)
if pad_len > 0:
pad_frames = [frames[-1] for _ in range(pad_len)]
frames = np.concatenate((frames, pad_frames))
if self.obs_to_string:
frames = [str_to_arr(obs, self.gray_scale) for obs in frames]
return frames
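
    # Usage sketch (hypothetical caller): a replay sampler would typically call
    #   frames = traj.get_index_stacked_obs(i, padding=True)
    # so that windows reaching past the trajectory end are padded by repeating
    # the last frame rather than returning a short window.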
def set_inf_len(self):
self.max_size = 100000000
def is_full(self):
# history block is full
return self.__len__() >= self.max_size
def __len__(self):
# if self.length is not None:
# return self.length
return len(self.action_lst)
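

if __name__ == '__main__':
    # Minimal smoke test (a sketch: the hyperparameters, shapes and reward values
    # below are illustrative assumptions, not the repository's training config).
    ray.init()  # save_to_memory stores observations in the Ray object store
    n_stack, unroll, td = 4, 5, 5
    traj = GameTrajectory(
        n_stack=n_stack, discount=0.997, obs_to_string=False, gray_scale=False,
        unroll_steps=unroll, td_steps=td, td_lambda=0.95, obs_shape=(8,),
        trajectory_size=20, image_based=False, episodic=False, GAE_max_steps=10,
    )
    traj.init([np.zeros(8, dtype=np.float32)] * n_stack)
    for _ in range(20):
        traj.append(action=0, obs=np.zeros(8, dtype=np.float32), reward=1.0)
        traj.store_search_results(pred_value=0.0, search_value=0.0, policy=np.ones(2) / 2)
    # pad with a fake tail from the "next" block so the last states can bootstrap
    traj.pad_over(
        tail_obs=[np.zeros(8, dtype=np.float32)] * unroll,
        tail_rewards=[1.0] * (unroll + td - 1),
        tail_pred_values=[0.0] * (unroll + td),
        tail_search_values=[0.0] * (unroll + td),
        tail_policies=[np.ones(2) / 2] * unroll,
    )
    traj.save_to_memory()
    targets = traj.make_target(0)
    print([np.asarray(t).shape for t in targets])
    ray.shutdown()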