# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.

import os
import pickle
import time

import numpy as np
import ray


@ray.remote
class ReplayBuffer:
    def __init__(self, **kwargs):
        self.batch_size = kwargs.get('batch_size')
        self.buffer_size = kwargs.get('buffer_size')
        self.top_transitions = kwargs.get('top_transitions')
        self.use_priority = kwargs.get('use_priority')
        self.env = kwargs.get('env')
        self.total_transitions = kwargs.get('total_transitions')

        self.base_idx = 0    # offset of buffer[0] in the global trajectory numbering
        self.clear_time = 0  # timestamp of the last buffer clear
        self.buffer = []     # stored trajectories
        self.priorities = np.array([])    # one priority per transition
        self.snapshots = []  # one environment snapshot per transition
        self.transition_idx_look_up = []  # flat transition index -> (global traj idx, step position)

    def save_pools(self, traj_pool, priorities):
        # save a list of game trajectories, skipping any empty ones
        for traj in traj_pool:
            if len(traj) > 0:
                self.save_trajectory(traj, priorities)

    def save_trajectory(self, traj, priorities):
        # append one trajectory and register one priority entry per transition
        traj_len = len(traj)
        if priorities is None:
            # new transitions enter at the current maximum priority so they are sampled at least once
            max_prio = self.priorities.max() if self.buffer else 1
            self.priorities = np.concatenate((self.priorities, [max_prio for _ in range(traj_len)]))
        else:
            assert len(traj) == len(priorities), "priorities should have the same length as the game steps"
            priorities = priorities.copy().reshape(-1)
            # note: every new transition gets the running maximum, not its own per-step priority
            max_prio = self.priorities.max() if self.buffer else 1
            self.priorities = np.concatenate((self.priorities, [max(max_prio, priorities.max()) for _ in range(traj_len)]))

        # keep the per-transition snapshots aligned with the look-up table
        for snapshot in traj.snapshot_lst:
            self.snapshots.append(snapshot)

        self.buffer.append(traj)
        self.transition_idx_look_up += [(self.base_idx + len(self.buffer) - 1, step_pos) for step_pos in range(traj_len)]
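
    # e.g. after saving two trajectories of lengths 3 and 2, transition_idx_look_up is
    # [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1)]: flat index 4 resolves to step 1 of trajectory 1.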

    def get_item(self, idx):
        # resolve a flat transition index to (trajectory, step position)
        traj_idx, state_index = self.transition_idx_look_up[idx]
        traj_idx -= self.base_idx
        traj = self.buffer[traj_idx]
        return traj, state_index

    def prepare_batch_context(self, batch_size, alpha, beta, rank, cnt):
        batch_context = self._prepare_batch_context(batch_size, alpha, beta)
        batch_context = (batch_context, False)
        return batch_context

    def _prepare_batch_context_supervised(self, batch_size, alpha=None, beta=None, is_validation=False, force_uniform=False):
        # the first 95% of transitions form the training set, the last 5% the validation set
        transition_num = self.get_transition_num()
        if is_validation:
            validation_set = np.arange(int(transition_num * 0.95), transition_num)
            indices_lst = np.random.choice(validation_set, batch_size, replace=False)
            weights_lst = (1 / batch_size) * np.ones_like(indices_lst)
        else:
            # sample training data, with probability proportional to priority ** alpha if enabled
            if self.use_priority:
                probs = self.priorities ** alpha
            else:
                probs = np.ones_like(self.priorities)
            probs = probs[:int(0.95 * transition_num)]
            probs = probs / probs.sum()

            training_set = np.arange(int(transition_num * 0.95))
            if force_uniform:
                indices_lst = np.random.choice(training_set, batch_size, replace=False)
                weights_lst = (1 / batch_size) * np.ones_like(indices_lst)
            else:
                indices_lst = np.random.choice(training_set, batch_size, p=probs, replace=False)
                # importance-sampling weights: w_i = (N * P(i)) ** (-beta), normalised by the maximum
                weights_lst = (transition_num * probs[indices_lst]) ** (-beta)
                weights_lst = weights_lst / weights_lst.max()

        # collect the trajectory and step position for each sampled transition
        traj_lst, transition_pos_lst = [], []
        for idx in indices_lst:
            traj, state_index = self.get_item(idx)
            traj_lst.append(traj)
            transition_pos_lst.append(state_index)

        make_time_lst = [time.time() for _ in range(len(indices_lst))]
        context = [self.split_trajs(traj_lst), transition_pos_lst, indices_lst, weights_lst, make_time_lst,
                   transition_num, self.priorities[indices_lst]]
        return context

    def _prepare_batch_context(self, batch_size, alpha, beta):
        transition_num = self.get_transition_num()

        # sampling probabilities: P(i) proportional to priority_i ** alpha (uniform if priorities are disabled)
        if self.use_priority:
            probs = self.priorities ** alpha
        else:
            probs = np.ones_like(self.priorities)

        # restrict sampling to the most recent top_transitions of the current buffer
        if self.env in ['DMC', 'Gym'] and len(self.priorities) > self.top_transitions:
            idx = int(len(self.priorities) - self.top_transitions)
            probs[:idx] = 0
            self.priorities[:idx] = 0

        probs = probs / probs.sum()

        indices_lst = np.random.choice(transition_num, batch_size, p=probs, replace=False)

        # importance-sampling weights: w_i = (N * P(i)) ** (-beta), normalised by the maximum
        weights_lst = (transition_num * probs[indices_lst]) ** (-beta)
        weights_lst = weights_lst / weights_lst.max()
        weights_lst = weights_lst.clip(0.1, 1)  # TODO: try weights clip, prev 0.1

        # collect the trajectory and step position for each sampled transition
        traj_lst, transition_pos_lst = [], []
        for idx in indices_lst:
            traj, state_index = self.get_item(idx)
            traj_lst.append(traj)
            transition_pos_lst.append(state_index)

        make_time_lst = [time.time() for _ in range(len(indices_lst))]

        context = [self.split_trajs(traj_lst), transition_pos_lst, indices_lst, weights_lst, make_time_lst, transition_num, self.priorities[indices_lst]]
        return context
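
    # Worked example of the prioritized sampling above (illustrative numbers, not from
    # the codebase): with priorities p = [1, 2, 4] and alpha = 1, the sampling
    # probabilities are P = p / p.sum() = [1/7, 2/7, 4/7]. With N = 3 transitions and
    # beta = 1, the raw weights (N * P) ** (-1) = [7/3, 7/6, 7/12] become [1, 0.5, 0.25]
    # after dividing by their maximum, so over-sampled transitions are down-weighted.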

    def split_trajs(self, traj_lst):
        # unpack a list of trajectories into per-field lists
        obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts, \
            bootstrapped_value_lsts, snapshot_lsts = [], [], [], [], [], [], [], []
        for traj in traj_lst:
            obs_lsts.append(traj.obs_lst)
            reward_lsts.append(traj.reward_lst)
            policy_lsts.append(traj.policy_lst)
            action_lsts.append(traj.action_lst)
            pred_value_lsts.append(traj.pred_value_lst)
            search_value_lsts.append(traj.search_value_lst)
            bootstrapped_value_lsts.append(traj.bootstrapped_value_lst)
            snapshot_lsts.append(traj.snapshot_lst)
        return [obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts, bootstrapped_value_lsts,
                # snapshot_lsts
                ]

    def update_root_values(self, batch_indices, search_values, transition_positions, unroll_steps):
        # write re-computed search values back into the stored trajectories
        val_idx = 0
        for idx, pos in zip(batch_indices, transition_positions):
            traj_idx, state_index = self.transition_idx_look_up[idx]
            traj_idx -= self.base_idx
            self.buffer[traj_idx].search_value_lst.setflags(write=True)
            for i in range(unroll_steps + 1):
                if pos + i < len(self.buffer[traj_idx].search_value_lst):
                    self.buffer[traj_idx].search_value_lst[pos + i] = search_values[val_idx][i]
            val_idx += 1

    def update_priorities(self, batch_indices, batch_priorities, make_time, mask=None):
        # update the priorities of transitions that are still in the replay buffer
        if mask is None:
            mask = np.ones(len(batch_indices))
        for i in range(len(batch_indices)):
            assert make_time[i] > self.clear_time, "batch is stale: it was sampled before the last buffer clear"
            idx, prio = batch_indices[i], batch_priorities[i]
            if mask[i] == 1:
                self.priorities[idx] = prio

    def get_priorities(self):
        return self.priorities

    def get_snapshots(self, indices_lst):
        selected_snapshots = []
        for idx in indices_lst:
            selected_snapshots.append(self.snapshots[idx])
        return selected_snapshots

    def get_traj_num(self):
        return len(self.buffer)

    def get_transition_num(self):
        # the three per-transition structures must stay aligned
        assert len(self.transition_idx_look_up) == len(self.priorities)
        assert len(self.priorities) == len(self.snapshots)
        return len(self.transition_idx_look_up)

    def save_buffer(self):
        # persist the buffer state to disk as pickle files
        path = '/workspace/EZ-Codebase/buffer/'
        os.makedirs(path, exist_ok=True)
        with open(path + 'buffer.b', 'wb') as f:
            pickle.dump(self.buffer, f)
        with open(path + 'priorities.b', 'wb') as f:
            pickle.dump(self.priorities, f)
        with open(path + 'lookup.b', 'wb') as f:
            pickle.dump(self.transition_idx_look_up, f)
        with open(path + 'snapshots.b', 'wb') as f:
            pickle.dump(self.snapshots, f)
        return True

    def load_buffer(self):
        # restore the buffer state written by save_buffer
        path = '/workspace/EZ-Codebase/buffer/'
        with open(path + 'buffer.b', 'rb') as f:
            self.buffer = pickle.load(f)
        with open(path + 'priorities.b', 'rb') as f:
            self.priorities = pickle.load(f)
        with open(path + 'lookup.b', 'rb') as f:
            self.transition_idx_look_up = pickle.load(f)
        with open(path + 'snapshots.b', 'rb') as f:
            self.snapshots = pickle.load(f)
        return True


# ======================================================================================================================
# replay buffer server
# ======================================================================================================================
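

# ======================================================================================================================
# usage sketch (illustrative only, not part of the training pipeline): how a trainer
# process might drive this actor. `FakeTraj` is a hypothetical stand-in for the real
# game-history class, which also carries observations, rewards, policies, and snapshots.
# ======================================================================================================================
if __name__ == '__main__':
    class FakeTraj:
        def __init__(self, length):
            self.obs_lst = [np.zeros(4) for _ in range(length)]
            self.reward_lst = [0.0] * length
            self.policy_lst = [np.ones(2) / 2] * length
            self.action_lst = [0] * length
            self.pred_value_lst = [0.0] * length
            self.search_value_lst = np.zeros(length)
            self.bootstrapped_value_lst = [0.0] * length
            self.snapshot_lst = [None] * length

        def __len__(self):
            return len(self.obs_lst)

    ray.init()
    replay_buffer = ReplayBuffer.remote(batch_size=4, buffer_size=10 ** 6, top_transitions=10 ** 5,
                                        use_priority=True, env='Gym', total_transitions=10 ** 5)
    # workers push finished trajectories (priorities=None -> optimistic max-priority init)
    ray.get(replay_buffer.save_pools.remote([FakeTraj(16)], None))
    # the trainer samples a prioritized batch (alpha, beta as in prioritized experience replay)
    context, _ = ray.get(replay_buffer.prepare_batch_context.remote(4, 0.6, 0.4, rank=0, cnt=0))
    print('sampled step positions:', context[1])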