2018-03-03 20:42:34 +08:00
|
|
|
|
|
|
|
|
2018-03-09 15:07:14 +08:00
|
|
|
class DataBufferBase(object):
    """
    Base class for data buffers, including the replay buffer as in DQN and
    the batched dataset as in on-policy algorithms.

    Subclasses must implement :meth:`add` and :meth:`clear`. The shared
    :meth:`sample` implementation reads three attributes that this base class
    does not define, so subclasses are expected to maintain them:

    - ``index``: a sequence with one entry per stored episode (only its
      length is used here).
    - ``index_lengths``: number of data points in each episode, parallel to
      ``index``.
    - ``size``: total number of data points. ``sample`` divides
      ``index_lengths`` by it to obtain episode probabilities, so it should
      equal ``sum(index_lengths)``; otherwise ``np.random.choice`` rejects
      the probabilities because they do not sum to 1.
    """

    def add(self, frame):
        """Add ``frame`` to the buffer. Must be implemented by subclasses."""
        raise NotImplementedError()

    def clear(self):
        """Empty the buffer. Must be implemented by subclasses."""
        raise NotImplementedError()

    def sample(self, batch_size):
        """
        Draw ``batch_size`` data points uniformly at random, with replacement.

        An episode is chosen with probability ``index_lengths[i] / size``.
        A frame offset is then chosen uniformly within that episode. The
        combined effect is that every stored data point has probability
        ``1 / size`` per draw.

        :param batch_size: number of draws. Values ``<= 0`` yield no samples.
        :return: a list with one list per episode. List ``i`` holds the frame
            offsets (ints in ``[0, index_lengths[i])``) drawn from episode
            ``i``, in draw order. Duplicates are possible.
        :raises ValueError: if ``batch_size > 0`` but the buffer holds no data.
        """
        # NOTE(review): no module-level numpy import is visible in this chunk;
        # importing locally keeps the method working either way.
        import numpy as np

        num_episodes = len(self.index)
        sampled_index = [[] for _ in range(num_episodes)]

        if batch_size <= 0:
            # Nothing to draw; avoids normalizing by a possibly-zero size.
            return sampled_index
        if self.size <= 0:
            # Otherwise numpy fails with an obscure message (zero-size
            # population or NaN probabilities from dividing by zero).
            raise ValueError("cannot sample from an empty data buffer")

        prob_episode = np.array(self.index_lengths) * 1. / self.size

        # Draw one item at a time, interleaving choice/randint. This keeps the
        # RNG stream (and thus seeded results) identical to earlier versions.
        for _ in range(batch_size):
            # sample which episode
            sampled_episode_i = int(np.random.choice(num_episodes, p=prob_episode))

            # sample which data point within the sampled episode
            sampled_frame_i = int(np.random.randint(self.index_lengths[sampled_episode_i]))
            sampled_index[sampled_episode_i].append(sampled_frame_i)

        return sampled_index
|