add ListReplayBuffer

This commit is contained in:
Trinkle23897 2020-03-28 15:14:41 +08:00
parent eb7fb37806
commit f23b0dfac9
3 changed files with 32 additions and 8 deletions

View File

@ -1,10 +1,12 @@
from tianshou.data.batch import Batch from tianshou.data.batch import Batch
from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector from tianshou.data.collector import Collector
__all__ = [ __all__ = [
'Batch', 'Batch',
'ReplayBuffer', 'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer', 'PrioritizedReplayBuffer',
'Collector' 'Collector',
] ]

View File

@ -55,8 +55,11 @@ class ReplayBuffer(object):
self._add_to_buffer('done', done) self._add_to_buffer('done', done)
self._add_to_buffer('obs_next', obs_next) self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info) self._add_to_buffer('info', info)
self._size = min(self._size + 1, self._maxsize) if self._maxsize > 0:
self._index = (self._index + 1) % self._maxsize self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize
else:
self._size = self._index = self._index + 1
def reset(self): def reset(self):
self._index = self._size = 0 self._index = self._size = 0
@ -90,6 +93,25 @@ class ReplayBuffer(object):
) )
class ListReplayBuffer(ReplayBuffer):
"""docstring for ListReplayBuffer"""
def __init__(self):
super().__init__(size=0)
def _add_to_buffer(self, name, inst):
if inst is None:
return
if self.__dict__.get(name, None) is None:
self.__dict__[name] = []
self.__dict__[name].append(inst)
def reset(self):
self._index = self._size = 0
for k in list(self.__dict__.keys()):
if not k.startswith('_'):
self.__dict__[k] = []
class PrioritizedReplayBuffer(ReplayBuffer): class PrioritizedReplayBuffer(ReplayBuffer):
"""docstring for PrioritizedReplayBuffer""" """docstring for PrioritizedReplayBuffer"""

View File

@ -1,10 +1,10 @@
import time import time
import torch import torch
import numpy as np
from copy import deepcopy
import warnings import warnings
import numpy as np
from tianshou.env import BaseVectorEnv from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer from tianshou.data import Batch, ReplayBuffer,\
ListReplayBuffer
from tianshou.utils import MovAvg from tianshou.utils import MovAvg
@ -37,7 +37,7 @@ class Collector(object):
self._multi_buf = True self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer): elif isinstance(self.buffer, ReplayBuffer):
self._cached_buf = [ self._cached_buf = [
deepcopy(self.buffer) for _ in range(self.env_num)] ListReplayBuffer() for _ in range(self.env_num)]
else: else:
raise TypeError('The buffer in data collector is invalid!') raise TypeError('The buffer in data collector is invalid!')
self.reset_env() self.reset_env()