add ListReplayBuffer
This commit is contained in:
parent
eb7fb37806
commit
f23b0dfac9
@ -1,10 +1,12 @@
|
||||
from tianshou.data.batch import Batch
|
||||
from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
|
||||
from tianshou.data.buffer import ReplayBuffer, \
|
||||
ListReplayBuffer, PrioritizedReplayBuffer
|
||||
from tianshou.data.collector import Collector
|
||||
|
||||
__all__ = [
|
||||
'Batch',
|
||||
'ReplayBuffer',
|
||||
'ListReplayBuffer',
|
||||
'PrioritizedReplayBuffer',
|
||||
'Collector'
|
||||
'Collector',
|
||||
]
|
||||
|
@ -55,8 +55,11 @@ class ReplayBuffer(object):
|
||||
self._add_to_buffer('done', done)
|
||||
self._add_to_buffer('obs_next', obs_next)
|
||||
self._add_to_buffer('info', info)
|
||||
self._size = min(self._size + 1, self._maxsize)
|
||||
self._index = (self._index + 1) % self._maxsize
|
||||
if self._maxsize > 0:
|
||||
self._size = min(self._size + 1, self._maxsize)
|
||||
self._index = (self._index + 1) % self._maxsize
|
||||
else:
|
||||
self._size = self._index = self._index + 1
|
||||
|
||||
def reset(self):
|
||||
self._index = self._size = 0
|
||||
@ -90,6 +93,25 @@ class ReplayBuffer(object):
|
||||
)
|
||||
|
||||
|
||||
class ListReplayBuffer(ReplayBuffer):
|
||||
"""docstring for ListReplayBuffer"""
|
||||
def __init__(self):
|
||||
super().__init__(size=0)
|
||||
|
||||
def _add_to_buffer(self, name, inst):
|
||||
if inst is None:
|
||||
return
|
||||
if self.__dict__.get(name, None) is None:
|
||||
self.__dict__[name] = []
|
||||
self.__dict__[name].append(inst)
|
||||
|
||||
def reset(self):
|
||||
self._index = self._size = 0
|
||||
for k in list(self.__dict__.keys()):
|
||||
if not k.startswith('_'):
|
||||
self.__dict__[k] = []
|
||||
|
||||
|
||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""docstring for PrioritizedReplayBuffer"""
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import warnings
|
||||
import numpy as np
|
||||
from tianshou.env import BaseVectorEnv
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data import Batch, ReplayBuffer,\
|
||||
ListReplayBuffer
|
||||
from tianshou.utils import MovAvg
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ class Collector(object):
|
||||
self._multi_buf = True
|
||||
elif isinstance(self.buffer, ReplayBuffer):
|
||||
self._cached_buf = [
|
||||
deepcopy(self.buffer) for _ in range(self.env_num)]
|
||||
ListReplayBuffer() for _ in range(self.env_num)]
|
||||
else:
|
||||
raise TypeError('The buffer in data collector is invalid!')
|
||||
self.reset_env()
|
||||
|
Loading…
x
Reference in New Issue
Block a user