add ListReplayBuffer

This commit is contained in:
Trinkle23897 2020-03-28 15:14:41 +08:00
parent eb7fb37806
commit f23b0dfac9
3 changed files with 32 additions and 8 deletions

View File

@ -1,10 +1,12 @@
from tianshou.data.batch import Batch
from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.buffer import ReplayBuffer, \
ListReplayBuffer, PrioritizedReplayBuffer
from tianshou.data.collector import Collector
__all__ = [
'Batch',
'ReplayBuffer',
'ListReplayBuffer',
'PrioritizedReplayBuffer',
'Collector'
'Collector',
]

View File

@ -55,8 +55,11 @@ class ReplayBuffer(object):
self._add_to_buffer('done', done)
self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info)
self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize
if self._maxsize > 0:
self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize
else:
self._size = self._index = self._index + 1
def reset(self):
self._index = self._size = 0
@ -90,6 +93,25 @@ class ReplayBuffer(object):
)
class ListReplayBuffer(ReplayBuffer):
    """Replay buffer variant that stores every field in a plain Python list.

    The constructor takes no arguments and passes ``size=0`` to the
    ``ReplayBuffer`` constructor. Each field name handed to
    ``_add_to_buffer`` becomes an instance attribute holding a list to
    which new items are appended.
    """

    def __init__(self):
        super().__init__(size=0)

    def _add_to_buffer(self, name, inst):
        """Append ``inst`` to the list stored under the attribute ``name``.

        A ``None`` item is ignored. If the attribute is missing from the
        instance dictionary, or is currently ``None``, a new empty list is
        created for it before appending.
        """
        if inst is None:
            return
        fields = vars(self)
        if fields.get(name) is None:
            fields[name] = []
        fields[name].append(inst)

    def reset(self):
        """Zero the index/size counters and empty every stored field.

        Every instance attribute whose name does not begin with an
        underscore is rebound to a fresh empty list; underscore-prefixed
        attributes other than ``_index`` and ``_size`` are left untouched.
        """
        self._index = self._size = 0
        fields = vars(self)
        # Snapshot the public names first, then rebind each to a new list.
        public_names = [key for key in fields if not key.startswith('_')]
        for key in public_names:
            fields[key] = []
class PrioritizedReplayBuffer(ReplayBuffer):
"""docstring for PrioritizedReplayBuffer"""

View File

@ -1,10 +1,10 @@
import time
import torch
import numpy as np
from copy import deepcopy
import warnings
import numpy as np
from tianshou.env import BaseVectorEnv
from tianshou.data import Batch, ReplayBuffer
from tianshou.data import Batch, ReplayBuffer,\
ListReplayBuffer
from tianshou.utils import MovAvg
@ -37,7 +37,7 @@ class Collector(object):
self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer):
self._cached_buf = [
deepcopy(self.buffer) for _ in range(self.env_num)]
ListReplayBuffer() for _ in range(self.env_num)]
else:
raise TypeError('The buffer in data collector is invalid!')
self.reset_env()