add clear() for replay_buffer
This commit is contained in:
parent
00d4cb0fca
commit
a791916fc4
@ -26,7 +26,7 @@ class DataBufferBase(object):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
"""Empties the data buffer, usually used in batch set but not in replay buffer."""
|
"""Empties the data buffer, usually used in batch set, also supported in replay buffer."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def sample(self, batch_size):
|
def sample(self, batch_size):
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import gc
|
||||||
|
|
||||||
from .replay_buffer_base import ReplayBufferBase
|
from .replay_buffer_base import ReplayBufferBase
|
||||||
|
|
||||||
@ -129,3 +130,19 @@ class VanillaReplayBuffer(ReplayBufferBase):
|
|||||||
sampled_index[sampled_episode_i].append(sampled_frame_i)
|
sampled_index[sampled_episode_i].append(sampled_frame_i)
|
||||||
|
|
||||||
return sampled_index
|
return sampled_index
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""
|
||||||
|
Empties the replay buffer and prepares to add new data.
|
||||||
|
"""
|
||||||
|
del self.data
|
||||||
|
del self.index
|
||||||
|
del self.index_lengths
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
self.data = [[]]
|
||||||
|
self.index = [[]]
|
||||||
|
self.candidate_index = 0
|
||||||
|
self.size = 0
|
||||||
|
self.index_lengths = [0]
|
||||||
|
@ -22,8 +22,12 @@ for i in range(capacity):
|
|||||||
print('Now buffer with size {}:'.format(buffer.size))
|
print('Now buffer with size {}:'.format(buffer.size))
|
||||||
print(buffer.index)
|
print(buffer.index)
|
||||||
print(buffer.data)
|
print(buffer.data)
|
||||||
|
buffer.clear()
|
||||||
|
print('Cleared buffer with size {}:'.format(buffer.size))
|
||||||
|
print(buffer.index)
|
||||||
|
print(buffer.data)
|
||||||
|
|
||||||
for i in range(5):
|
for i in range(20):
|
||||||
s = np.random.randint(10)
|
s = np.random.randint(10)
|
||||||
a = np.random.randint(3)
|
a = np.random.randint(3)
|
||||||
r = np.random.randint(5)
|
r = np.random.randint(5)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user