add clear() for replay_buffer
This commit is contained in:
parent
00d4cb0fca
commit
a791916fc4
@ -26,7 +26,7 @@ class DataBufferBase(object):
|
||||
raise NotImplementedError()
|
||||
|
||||
def clear(self):
|
||||
"""Empties the data buffer, usually used in batch set but not in replay buffer."""
|
||||
"""Empties the data buffer, usually used in batch set, also supported in replay buffer."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def sample(self, batch_size):
|
||||
|
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import numpy as np
|
||||
import gc
|
||||
|
||||
from .replay_buffer_base import ReplayBufferBase
|
||||
|
||||
@ -129,3 +130,19 @@ class VanillaReplayBuffer(ReplayBufferBase):
|
||||
sampled_index[sampled_episode_i].append(sampled_frame_i)
|
||||
|
||||
return sampled_index
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Empties the replay buffer and prepares to add new data.
|
||||
"""
|
||||
del self.data
|
||||
del self.index
|
||||
del self.index_lengths
|
||||
|
||||
gc.collect()
|
||||
|
||||
self.data = [[]]
|
||||
self.index = [[]]
|
||||
self.candidate_index = 0
|
||||
self.size = 0
|
||||
self.index_lengths = [0]
|
||||
|
@ -22,8 +22,12 @@ for i in range(capacity):
|
||||
print('Now buffer with size {}:'.format(buffer.size))
|
||||
print(buffer.index)
|
||||
print(buffer.data)
|
||||
buffer.clear()
|
||||
print('Cleared buffer with size {}:'.format(buffer.size))
|
||||
print(buffer.index)
|
||||
print(buffer.data)
|
||||
|
||||
for i in range(5):
|
||||
for i in range(20):
|
||||
s = np.random.randint(10)
|
||||
a = np.random.randint(3)
|
||||
r = np.random.randint(5)
|
||||
|
Loading…
x
Reference in New Issue
Block a user