# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import copy
import torch
import numpy as np
import ray

from ez.utils.format import str_to_arr


class GameTrajectory:
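    """Trajectory container for one block of an episode.

    It stores per-step observations, actions, and rewards, together with the
    root statistics gathered during acting (values predicted by the network,
    values returned by MCTS, visit-count policies) and the value targets
    derived from them.
    """
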
    def __init__(self, **kwargs):
        # self.raw_obs_lst = []
        self.obs_lst = []
        self.reward_lst = []
        self.policy_lst = []
        self.action_lst = []
        self.pred_value_lst = []
        self.search_value_lst = []
        self.bootstrapped_value_lst = []
        self.mc_value_lst = []
        self.temp_obs = []
        self.snapshot_lst = []

        self.n_stack = kwargs.get('n_stack')
        self.discount = kwargs.get('discount')
        self.obs_to_string = kwargs.get('obs_to_string')
        self.gray_scale = kwargs.get('gray_scale')
        self.unroll_steps = kwargs.get('unroll_steps')
        self.td_steps = kwargs.get('td_steps')
        self.td_lambda = kwargs.get('td_lambda')
        self.obs_shape = kwargs.get('obs_shape')
        self.max_size = kwargs.get('trajectory_size')
        self.image_based = kwargs.get('image_based')
        self.episodic = kwargs.get('episodic')
        self.GAE_max_steps = kwargs.get('GAE_max_steps')
    def init(self, init_frames):
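        """Pre-fill the observation list with the initial stacked frames."""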
        assert len(init_frames) == self.n_stack

        for obs in init_frames:
            self.obs_lst.append(copy.deepcopy(obs))

    def append(self, action, obs, reward):
        assert self.__len__() <= self.max_size

        # append a transition tuple
        self.action_lst.append(action)
        self.obs_lst.append(obs)
        self.reward_lst.append(reward)
    def pad_over(self, tail_obs, tail_rewards, tail_pred_values, tail_search_values, tail_policies):
        """Pad the tail of this block with data from the next trajectory block.

        To ensure the correctness of the value targets, we append (o_t, r_t, etc.)
        from the next history block; they are needed for the bootstrapped values
        at the final states of this block.
        E.g. len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
        but r_101, r_102, ... come from the next history block.

        Parameters
        ----------
        tail_obs: list
            tail o_t from the next trajectory block
        tail_rewards: list
            tail r_t from the next trajectory block
        tail_pred_values: list
            tail v_t from the next trajectory block (predicted by the network)
        tail_search_values: list
            tail v_t from the next trajectory block (searched by MCTS)
        tail_policies: list
            tail pi_t from the next trajectory block
        """
        assert len(tail_obs) <= self.unroll_steps
        assert len(tail_policies) <= self.unroll_steps
        # assert len(tail_search_values) <= self.unroll_steps + self.td_steps
        # assert len(tail_rewards) <= self.unroll_steps + self.td_steps - 1

        # notice: the next block's observations should start from (stacked_observation - 1) in the next trajectory
        for obs in tail_obs:
            self.obs_lst.append(copy.deepcopy(obs))

        for reward in tail_rewards:
            self.reward_lst.append(reward)

        for pred_value in tail_pred_values:
            self.pred_value_lst.append(pred_value)

        for search_value in tail_search_values:
            self.search_value_lst.append(search_value)

        for policy in tail_policies:
            self.policy_lst.append(policy)

        # calculate bootstrapped value
        self.bootstrapped_value_lst = self.get_bootstrapped_value(value_type='prediction')
        # calculate GAE value
    def save_to_memory(self):
        """Post-process the data once a history block is full."""
        # convert lists to numpy arrays; the (large) observation array is placed
        # into Ray's object store so that workers can fetch it from shared memory
        self.obs_lst = ray.put(np.array(self.obs_lst))
        # self.obs_lst = np.array(self.obs_lst)
        self.reward_lst = np.array(self.reward_lst)
        self.policy_lst = np.array(self.policy_lst)
        self.action_lst = np.array(self.action_lst)
        self.pred_value_lst = np.array(self.pred_value_lst)
        self.search_value_lst = np.array(self.search_value_lst)
        self.bootstrapped_value_lst = np.array(self.bootstrapped_value_lst)
    def make_target(self, index):
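        """Return the training targets for an unroll starting at `index`.

        The targets cover `unroll_steps + 1` consecutive steps: stacked
        observations, rewards, predicted values, search values, bootstrapped
        values, and target policies.
        """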
        assert index < self.__len__()

        target_obs = self.get_index_stacked_obs(index)
        target_reward = self.reward_lst[index:index + self.unroll_steps + 1]
        target_pred_value = self.pred_value_lst[index:index + self.unroll_steps + 1]
        target_search_value = self.search_value_lst[index:index + self.unroll_steps + 1]
        target_bt_value = self.bootstrapped_value_lst[index:index + self.unroll_steps + 1]
        target_policy = self.policy_lst[index:index + self.unroll_steps + 1]

        assert len(target_reward) == len(target_pred_value) == len(target_search_value) == len(target_policy)
        return (np.array(target_obs), np.array(target_reward), np.array(target_pred_value),
                np.array(target_search_value), np.array(target_bt_value), np.array(target_policy))
    def store_search_results(self, pred_value, search_value, policy, idx: int = None):
        # store the root value predicted by the network, the searched root value,
        # and the visit-count distribution of the root node after MCTS
        if idx is None:
            self.pred_value_lst.append(pred_value)
            self.search_value_lst.append(search_value)
            self.policy_lst.append(policy)
        else:
            self.pred_value_lst[idx] = pred_value
            self.search_value_lst[idx] = search_value
            self.policy_lst[idx] = policy
    def get_gae_value(self, value_type='prediction', value_lst_input=None, index=None, collected_transitions=None):
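        """Compute lambda-return value targets with a GAE-style backward pass.

        For each position i, the 1-step TD residuals delta_t = r_t + gamma * v_{t+1} - v_t
        are accumulated backwards over a window of at most `GAE_max_steps` steps with
        decay gamma * lambda, and the resulting advantage is added back to v_i.
        When `index` and `collected_transitions` are given, lambda is annealed from
        `td_lambda` down to 0.65 as the data gets older.
        """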
        td_steps = 1
        assert value_type in ['prediction', 'search']
        if value_type == 'prediction':
            value_lst = self.pred_value_lst
        else:
            value_lst = self.search_value_lst

        if value_lst_input is not None:
            value_lst = value_lst_input

        traj_len = self.__len__()
        assert traj_len

        if index is None:
            td_lambda = self.td_lambda
        else:
            delta_lambda = 0.1 * (collected_transitions - index) / 3e4
            td_lambda = self.td_lambda - delta_lambda
            td_lambda = np.clip(td_lambda, 0.65, self.td_lambda)

        lens = self.GAE_max_steps
        gae_values = np.zeros(traj_len)
        for i in range(traj_len):
            delta = np.zeros(lens)
            advantage = np.zeros(lens + 1)
            cnt = lens - 1
            for idx in reversed(range(i, i + lens)):
                if idx + td_steps < traj_len:
                    bt_value = (self.discount ** td_steps) * value_lst[idx + td_steps]
                    for n in range(td_steps):
                        bt_value += (self.discount ** n) * self.reward_lst[idx + n]
                else:
                    steps = idx + td_steps - len(value_lst) + 1
                    bt_value = 0 if self.episodic else (self.discount ** (td_steps - steps)) * value_lst[-1]
                    for n in range(td_steps - steps):
                        assert idx + n < len(self.reward_lst)
                        bt_value += (self.discount ** n) * self.reward_lst[idx + n]

                try:
                    delta[cnt] = bt_value - value_lst[idx]
                except IndexError:
                    delta[cnt] = 0
                advantage[cnt] = delta[cnt] + self.discount * td_lambda * advantage[cnt + 1]
                cnt -= 1

            value_lst_tmp = np.concatenate(
                (np.asarray(value_lst)[i:i + lens], np.zeros(lens - np.asarray(value_lst)[i:i + lens].shape[0])))
            gae_values_tmp = advantage[:lens] + value_lst_tmp
            gae_values[i] = gae_values_tmp[0]

        return gae_values
    def get_bootstrapped_value(self, value_type='prediction', value_lst_input=None, index=None, collected_transitions=None):
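        """Compute n-step bootstrapped value targets for every step of the trajectory.

        The target for step t is
        z_t = sum_{n=0}^{td_steps-1} gamma^n * r_{t+n} + gamma^td_steps * v_{t+td_steps}.
        Near the end of the trajectory the horizon is truncated: for episodic tasks
        the bootstrap term is 0, otherwise the last available value is used. When
        `index` and `collected_transitions` are given, `td_steps` is annealed from
        `self.td_steps` down to 1 as the data gets older.
        """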
        assert value_type in ['prediction', 'search']
        if value_type == 'prediction':
            value_lst = self.pred_value_lst
        else:
            value_lst = self.search_value_lst

        if value_lst_input is not None:
            value_lst = value_lst_input

        bt_values = []
        traj_len = self.__len__()
        assert traj_len

        if index is None:
            td_steps = self.td_steps
        else:
            delta_td = (collected_transitions - index) // 3e4
            td_steps = self.td_steps - delta_td
            td_steps = np.clip(td_steps, 1, self.td_steps).astype(np.int32)

        for idx in range(traj_len):
            if idx + td_steps < len(value_lst):
                bt_value = (self.discount ** td_steps) * value_lst[idx + td_steps]
                for n in range(td_steps):
                    bt_value += (self.discount ** n) * self.reward_lst[idx + n]
            else:
                steps = idx + td_steps - len(value_lst) + 1
                bt_value = 0 if self.episodic else (self.discount ** (td_steps - steps)) * value_lst[-1]
                for n in range(td_steps - steps):
                    assert idx + n < len(self.reward_lst)
                    bt_value += (self.discount ** n) * self.reward_lst[idx + n]

            bt_values.append(bt_value)
        return bt_values
    def get_zero_obs(self, n_stack, channel_first=True):
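        """Return `n_stack` constant (all-ones) placeholder frames shaped like real observations.

        Image observations are uint8 arrays (channel-first or channel-last);
        vector observations are float32 arrays.
        """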
        if self.image_based:
            if channel_first:
                return [np.ones(self.obs_shape, dtype=np.uint8) for _ in range(n_stack)]
            else:
                return [np.ones((self.obs_shape[1], self.obs_shape[2], self.obs_shape[0]), dtype=np.uint8)
                        for _ in range(n_stack)]
        else:
            return [np.ones(self.obs_shape, dtype=np.float32) for _ in range(n_stack)]
    def get_current_stacked_obs(self):
        # return the current stacked observation in the correct format for model inference
        index = len(self.reward_lst)
        frames = ray.get(self.obs_lst)[index:index + self.n_stack]
        # frames = self.obs_lst[index:index + self.n_stack]
        if self.obs_to_string:
            frames = [str_to_arr(obs, self.gray_scale) for obs in frames]
        return frames
    def get_index_stacked_obs(self, index, padding=False, extra=0):
        """Obtain observations in the correct format: o[index, index + stack frames + unroll steps + extra]

        Parameters
        ----------
        index: int
            time step index
        padding: bool
            True -> pad with the last frame if (index + stack frames + unroll steps) runs past the trajectory
        extra: int
            extra unroll length added on top of `self.unroll_steps`
        """
        unroll_steps = self.unroll_steps + extra
        frames = ray.get(self.obs_lst)[index:index + self.n_stack + unroll_steps]
        # frames = self.obs_lst[index:index + self.n_stack + unroll_steps]
        if padding:
            pad_len = self.n_stack + unroll_steps - len(frames)
            if pad_len > 0:
                pad_frames = [frames[-1] for _ in range(pad_len)]
                frames = np.concatenate((frames, pad_frames))
        if self.obs_to_string:
            frames = [str_to_arr(obs, self.gray_scale) for obs in frames]
        return frames
    def set_inf_len(self):
        self.max_size = 100000000

    def is_full(self):
        # history block is full
        return self.__len__() >= self.max_size

    def __len__(self):
        # if self.length is not None:
        #     return self.length
        return len(self.action_lst)
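

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): the keyword
    # arguments below are illustrative assumptions, not the project's real
    # configuration, which is normally supplied by the training pipeline.
    traj = GameTrajectory(
        n_stack=4, discount=0.997, obs_to_string=False, gray_scale=False,
        unroll_steps=5, td_steps=5, td_lambda=0.95, obs_shape=(8,),
        trajectory_size=200, image_based=False, episodic=True, GAE_max_steps=10,
    )
    traj.init(traj.get_zero_obs(traj.n_stack))
    for _ in range(10):
        # dummy transition and dummy root statistics
        traj.append(action=0, obs=np.zeros(8, dtype=np.float32), reward=1.0)
        traj.store_search_results(pred_value=0.0, search_value=0.0, policy=np.ones(2) / 2)
    # n-step value targets; save_to_memory() / the stacked-obs getters additionally need ray.init()
    print(len(traj), traj.get_bootstrapped_value(value_type='prediction')[:3])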