From a791916fc45775d2b1ea194166d2c7eca31a0014 Mon Sep 17 00:00:00 2001 From: haoshengzou Date: Wed, 15 Aug 2018 09:53:46 +0800 Subject: [PATCH] add clear() for replay_buffer --- tianshou/data/data_buffer/base.py | 2 +- tianshou/data/data_buffer/vanilla.py | 17 +++++++++++++++++ tianshou/data/test_replay_buffer.py | 6 +++++- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/tianshou/data/data_buffer/base.py b/tianshou/data/data_buffer/base.py index c60dff2..7c3aa64 100644 --- a/tianshou/data/data_buffer/base.py +++ b/tianshou/data/data_buffer/base.py @@ -26,7 +26,7 @@ class DataBufferBase(object): raise NotImplementedError() def clear(self): - """Empties the data buffer, usually used in batch set but not in replay buffer.""" + """Empties the data buffer, usually used in batch set, also supported in replay buffer.""" raise NotImplementedError() def sample(self, batch_size): diff --git a/tianshou/data/data_buffer/vanilla.py b/tianshou/data/data_buffer/vanilla.py index cb83524..fdef439 100644 --- a/tianshou/data/data_buffer/vanilla.py +++ b/tianshou/data/data_buffer/vanilla.py @@ -1,5 +1,6 @@ import logging import numpy as np +import gc from .replay_buffer_base import ReplayBufferBase @@ -129,3 +130,19 @@ class VanillaReplayBuffer(ReplayBufferBase): sampled_index[sampled_episode_i].append(sampled_frame_i) return sampled_index + + def clear(self): + """ + Empties the replay buffer and prepares to add new data. + """ + del self.data + del self.index + del self.index_lengths + + gc.collect() + + self.data = [[]] + self.index = [[]] + self.candidate_index = 0 + self.size = 0 + self.index_lengths = [0] diff --git a/tianshou/data/test_replay_buffer.py b/tianshou/data/test_replay_buffer.py index b3ffd0c..cee704f 100644 --- a/tianshou/data/test_replay_buffer.py +++ b/tianshou/data/test_replay_buffer.py @@ -22,8 +22,12 @@ for i in range(capacity): print('Now buffer with size {}:'.format(buffer.size)) print(buffer.index) print(buffer.data) +buffer.clear() +print('Cleared buffer with size {}:'.format(buffer.size)) +print(buffer.index) +print(buffer.data) -for i in range(5): +for i in range(20): s = np.random.randint(10) a = np.random.randint(3) r = np.random.randint(5)