EfficientZeroV2/ez/data/replay_buffer.py
# Copyright (c) EVAR Lab, IIIS, Tsinghua University.
#
# This source code is licensed under the GNU License, Version 3.0
# found in the LICENSE file in the root directory of this source tree.
import os
import time
import pickle

import numpy as np
import ray


@ray.remote
class ReplayBuffer:
    def __init__(self, **kwargs):
        self.batch_size = kwargs.get('batch_size')
        self.buffer_size = kwargs.get('buffer_size')
        self.top_transitions = kwargs.get('top_transitions')
        self.use_priority = kwargs.get('use_priority')
        self.env = kwargs.get('env')
        self.total_transitions = kwargs.get('total_transitions')

        self.base_idx = 0
        self.clear_time = 0
        self.buffer = []
        self.priorities = []
        self.snapshots = []
        self.transition_idx_look_up = []
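    # `buffer` stores whole trajectories, while `priorities`, `snapshots` and
    # `transition_idx_look_up` are flat, with one entry per transition.  Each
    # look-up entry maps a flat transition index to (absolute trajectory index,
    # step within the trajectory); `base_idx` is subtracted from the absolute
    # trajectory index when an item is fetched (see get_item).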

    def save_pools(self, traj_pool, priorities):
        # save a list of game histories
        for traj in traj_pool:
            if len(traj) > 0:
                self.save_trajectory(traj, priorities)

    def save_trajectory(self, traj, priorities):
        traj_len = len(traj)

        if priorities is None:
            max_prio = self.priorities.max() if self.buffer else 1
            self.priorities = np.concatenate((self.priorities, [max_prio for _ in range(traj_len)]))
        else:
            assert len(traj) == len(priorities), "priorities should be the same length as the game steps"
            priorities = priorities.copy().reshape(-1)
            max_prio = self.priorities.max() if self.buffer else 1
            self.priorities = np.concatenate((self.priorities, [max(max_prio, priorities.max()) for _ in range(traj_len)]))

        for snapshot in traj.snapshot_lst:
            self.snapshots.append(snapshot)

        self.buffer.append(traj)
        self.transition_idx_look_up += [(self.base_idx + len(self.buffer) - 1, step_pos) for step_pos in range(traj_len)]

    def get_item(self, idx):
        traj_idx, state_index = self.transition_idx_look_up[idx]
        traj_idx -= self.base_idx
        traj = self.buffer[traj_idx]
        return traj, state_index

    def prepare_batch_context(self, batch_size, alpha, beta, rank, cnt):
        batch_context = self._prepare_batch_context(batch_size, alpha, beta)
        batch_context = (batch_context, False)
        return batch_context

    def _prepare_batch_context_supervised(self, batch_size, alpha=None, beta=None, is_validation=False, force_uniform=False):
        transition_num = self.get_transition_num()

        if is_validation:
            # the most recently added 5% of transitions form the validation set
            validation_set = np.arange(int(transition_num * 0.95), transition_num)
            indices_lst = np.random.choice(validation_set, batch_size, replace=False)
            weights_lst = (1 / batch_size) * np.ones_like(indices_lst)
        else:
            # sample data from the first 95% of transitions
            if self.use_priority:
                probs = self.priorities ** alpha
            else:
                probs = np.ones_like(self.priorities)
            probs = probs[:int(0.95 * transition_num)]
            probs = probs / probs.sum()

            training_set = np.arange(int(transition_num * 0.95))
            if force_uniform:
                indices_lst = np.random.choice(training_set, batch_size, replace=False)
                weights_lst = (1 / batch_size) * np.ones_like(indices_lst)
            else:
                indices_lst = np.random.choice(training_set, batch_size, p=probs, replace=False)
                weights_lst = (transition_num * probs[indices_lst]) ** (-beta)
                weights_lst = weights_lst / weights_lst.max()

        # obtain the sampled trajectories and the positions of the sampled transitions
        traj_lst, transition_pos_lst = [], []
        for idx in indices_lst:
            traj, state_index = self.get_item(idx)
            traj_lst.append(traj)
            transition_pos_lst.append(state_index)

        make_time_lst = [time.time() for _ in range(len(indices_lst))]
        context = [self.split_trajs(traj_lst), transition_pos_lst, indices_lst, weights_lst, make_time_lst,
                   transition_num, self.priorities[indices_lst]]
        return context

    def _prepare_batch_context(self, batch_size, alpha, beta):
        transition_num = self.get_transition_num()

        # sample data
        if self.use_priority:
            probs = self.priorities ** alpha
        else:
            probs = np.ones_like(self.priorities)

        # sample only from the most recent `top_transitions` of the current buffer
        if self.env in ['DMC', 'Gym'] and len(self.priorities) > self.top_transitions:
            idx = int(len(self.priorities) - self.top_transitions)
            probs[:idx] = 0
            self.priorities[:idx] = 0

        probs = probs / probs.sum()
        indices_lst = np.random.choice(transition_num, batch_size, p=probs, replace=False)

        # importance-sampling weights
        weights_lst = (transition_num * probs[indices_lst]) ** (-beta)
        weights_lst = weights_lst / weights_lst.max()
        weights_lst = weights_lst.clip(0.1, 1)  # TODO: try weights clip, prev 0.1

        # obtain the sampled trajectories and the positions of the sampled transitions
        traj_lst, transition_pos_lst = [], []
        for idx in indices_lst:
            traj, state_index = self.get_item(idx)
            traj_lst.append(traj)
            transition_pos_lst.append(state_index)

        make_time_lst = [time.time() for _ in range(len(indices_lst))]
        context = [self.split_trajs(traj_lst), transition_pos_lst, indices_lst, weights_lst, make_time_lst,
                   transition_num, self.priorities[indices_lst]]
        return context
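    # The sampling above is the standard prioritized-replay scheme: with probabilities
    # P(i) = p_i**alpha / sum_j p_j**alpha, each sampled transition gets an importance
    # weight w_i = (transition_num * P(i))**(-beta), normalised by the batch maximum and
    # then clipped to [0.1, 1].  For example, with priorities p = [1, 1, 2], alpha = 1
    # and beta = 1: P = [0.25, 0.25, 0.5], raw weights = [4/3, 4/3, 2/3], and
    # [1.0, 1.0, 0.5] after normalisation.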

    def split_trajs(self, traj_lst):
        obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts, \
            bootstrapped_value_lsts, snapshot_lsts = [], [], [], [], [], [], [], []
        for traj in traj_lst:
            obs_lsts.append(traj.obs_lst)
            reward_lsts.append(traj.reward_lst)
            policy_lsts.append(traj.policy_lst)
            action_lsts.append(traj.action_lst)
            pred_value_lsts.append(traj.pred_value_lst)
            search_value_lsts.append(traj.search_value_lst)
            bootstrapped_value_lsts.append(traj.bootstrapped_value_lst)
            snapshot_lsts.append(traj.snapshot_lst)
        return [obs_lsts, reward_lsts, policy_lsts, action_lsts, pred_value_lsts, search_value_lsts, bootstrapped_value_lsts,
                # snapshot_lsts
                ]

    def update_root_values(self, batch_indices, search_values, transition_positions, unroll_steps):
        val_idx = 0
        for idx, pos in zip(batch_indices, transition_positions):
            traj_idx, state_index = self.transition_idx_look_up[idx]
            traj_idx -= self.base_idx
            for i in range(unroll_steps + 1):
                self.buffer[traj_idx].search_value_lst.setflags(write=True)
                if pos + i < len(self.buffer[traj_idx].search_value_lst):
                    self.buffer[traj_idx].search_value_lst[pos + i] = search_values[val_idx][i]
            val_idx += 1

    def update_priorities(self, batch_indices, batch_priorities, make_time, mask=None):
        # update the priorities for data still in replay buffer
        if mask is None:
            mask = np.ones(len(batch_indices))
        for i in range(len(batch_indices)):
            # if make_time[i] > self.clear_time:
            assert make_time[i] > self.clear_time
            idx, prio = batch_indices[i], batch_priorities[i]
            if mask[i] == 1:
                self.priorities[idx] = prio

    def get_priorities(self):
        return self.priorities

    def get_snapshots(self, indices_lst):
        selected_snapshots = []
        for idx in indices_lst:
            selected_snapshots.append(self.snapshots[idx])
        return selected_snapshots

    def get_traj_num(self):
        return len(self.buffer)

    def get_transition_num(self):
        assert len(self.transition_idx_look_up) == len(self.priorities)
        assert len(self.priorities) == len(self.snapshots)
        return len(self.transition_idx_look_up)

    def save_buffer(self):
        path = '/workspace/EZ-Codebase/buffer/'
        with open(path + 'buffer.b', 'wb') as f_buffer:
            pickle.dump(self.buffer, f_buffer)
        with open(path + 'priorities.b', 'wb') as f_priorities:
            pickle.dump(self.priorities, f_priorities)
        with open(path + 'lookup.b', 'wb') as f_lookup:
            pickle.dump(self.transition_idx_look_up, f_lookup)
        with open(path + 'snapshots.b', 'wb') as f_snapshot:
            pickle.dump(self.snapshots, f_snapshot)
        return True

    def load_buffer(self):
        path = '/workspace/EZ-Codebase/buffer/'
        with open(path + 'buffer.b', 'rb') as f:
            self.buffer = pickle.load(f)
        with open(path + 'priorities.b', 'rb') as f:
            self.priorities = pickle.load(f)
        with open(path + 'lookup.b', 'rb') as f:
            self.transition_idx_look_up = pickle.load(f)
        with open(path + 'snapshots.b', 'rb') as f:
            self.snapshots = pickle.load(f)
        return True
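
# Example usage (a minimal sketch; the configuration values below are illustrative placeholders,
# not settings taken from this repository).  The class is a Ray actor, so construction and
# every method call go through `.remote()`:
#
#   ray.init()
#   buffer = ReplayBuffer.remote(batch_size=256, buffer_size=int(1e6), top_transitions=int(2e5),
#                                use_priority=True, env='DMC', total_transitions=int(1e5))
#   buffer.save_pools.remote(traj_pool, priorities)
#   context, _ = ray.get(buffer.prepare_batch_context.remote(256, alpha=0.6, beta=0.4, rank=0, cnt=0))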
# ======================================================================================================================
# replay buffer server
# ======================================================================================================================