add clear() for replay_buffer

This commit is contained in:
haoshengzou 2018-08-15 09:53:46 +08:00
parent 00d4cb0fca
commit a791916fc4
3 changed files with 23 additions and 2 deletions

View File

@ -26,7 +26,7 @@ class DataBufferBase(object):
raise NotImplementedError()
def clear(self):
"""Empties the data buffer, usually used in batch set but not in replay buffer."""
"""Empties the data buffer, usually used in batch set, also supported in replay buffer."""
raise NotImplementedError()
def sample(self, batch_size):

View File

@ -1,5 +1,6 @@
import logging
import numpy as np
import gc
from .replay_buffer_base import ReplayBufferBase
@ -129,3 +130,19 @@ class VanillaReplayBuffer(ReplayBufferBase):
sampled_index[sampled_episode_i].append(sampled_frame_i)
return sampled_index
def clear(self):
"""
Empties the replay buffer and prepares to add new data.
"""
del self.data
del self.index
del self.index_lengths
gc.collect()
self.data = [[]]
self.index = [[]]
self.candidate_index = 0
self.size = 0
self.index_lengths = [0]

View File

@ -22,8 +22,12 @@ for i in range(capacity):
print('Now buffer with size {}:'.format(buffer.size))
print(buffer.index)
print(buffer.data)
buffer.clear()
print('Cleared buffer with size {}:'.format(buffer.size))
print(buffer.index)
print(buffer.data)
for i in range(5):
for i in range(20):
s = np.random.randint(10)
a = np.random.randint(3)
r = np.random.randint(5)