add ListReplayBuffer
This commit is contained in:
parent
eb7fb37806
commit
f23b0dfac9
@ -1,10 +1,12 @@
|
|||||||
from tianshou.data.batch import Batch
|
from tianshou.data.batch import Batch
|
||||||
from tianshou.data.buffer import ReplayBuffer, PrioritizedReplayBuffer
|
from tianshou.data.buffer import ReplayBuffer, \
|
||||||
|
ListReplayBuffer, PrioritizedReplayBuffer
|
||||||
from tianshou.data.collector import Collector
|
from tianshou.data.collector import Collector
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Batch',
|
'Batch',
|
||||||
'ReplayBuffer',
|
'ReplayBuffer',
|
||||||
|
'ListReplayBuffer',
|
||||||
'PrioritizedReplayBuffer',
|
'PrioritizedReplayBuffer',
|
||||||
'Collector'
|
'Collector',
|
||||||
]
|
]
|
||||||
|
|||||||
@ -55,8 +55,11 @@ class ReplayBuffer(object):
|
|||||||
self._add_to_buffer('done', done)
|
self._add_to_buffer('done', done)
|
||||||
self._add_to_buffer('obs_next', obs_next)
|
self._add_to_buffer('obs_next', obs_next)
|
||||||
self._add_to_buffer('info', info)
|
self._add_to_buffer('info', info)
|
||||||
self._size = min(self._size + 1, self._maxsize)
|
if self._maxsize > 0:
|
||||||
self._index = (self._index + 1) % self._maxsize
|
self._size = min(self._size + 1, self._maxsize)
|
||||||
|
self._index = (self._index + 1) % self._maxsize
|
||||||
|
else:
|
||||||
|
self._size = self._index = self._index + 1
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
self._index = self._size = 0
|
self._index = self._size = 0
|
||||||
@ -90,6 +93,25 @@ class ReplayBuffer(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ListReplayBuffer(ReplayBuffer):
    """Replay buffer that stores every field in a plain Python list.

    The parent constructor is called with ``size=0``; each field then
    lives in its own list attribute that grows by one element per stored
    value.
    NOTE(review): presumably ``size=0`` makes the parent's ``add`` take
    its non-wrapping branch, so nothing is overwritten -- confirm against
    ``ReplayBuffer.__init__``.
    """

    def __init__(self):
        super().__init__(size=0)

    def _add_to_buffer(self, name, inst):
        """Append ``inst`` to the list held in attribute ``name``.

        A ``None`` value is ignored.  The list is created on first use,
        and also whenever the attribute currently holds ``None``.
        """
        if inst is None:
            return
        fields = self.__dict__
        if fields.get(name) is None:
            fields[name] = []
        fields[name].append(inst)

    def reset(self):
        """Zero the index/size counters and empty every stored field.

        Every instance attribute whose name does not start with an
        underscore is rebound to a new empty list.
        """
        self._index = self._size = 0
        fields = self.__dict__
        # Collect the names first so the dict is not iterated while its
        # entries are being rebound.
        public_names = [key for key in fields if not key.startswith('_')]
        for key in public_names:
            fields[key] = []
|
||||||
|
|
||||||
|
|
||||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||||
"""docstring for PrioritizedReplayBuffer"""
|
"""docstring for PrioritizedReplayBuffer"""
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
from copy import deepcopy
|
|
||||||
import warnings
|
import warnings
|
||||||
|
import numpy as np
|
||||||
from tianshou.env import BaseVectorEnv
|
from tianshou.env import BaseVectorEnv
|
||||||
from tianshou.data import Batch, ReplayBuffer
|
from tianshou.data import Batch, ReplayBuffer,\
|
||||||
|
ListReplayBuffer
|
||||||
from tianshou.utils import MovAvg
|
from tianshou.utils import MovAvg
|
||||||
|
|
||||||
|
|
||||||
@ -37,7 +37,7 @@ class Collector(object):
|
|||||||
self._multi_buf = True
|
self._multi_buf = True
|
||||||
elif isinstance(self.buffer, ReplayBuffer):
|
elif isinstance(self.buffer, ReplayBuffer):
|
||||||
self._cached_buf = [
|
self._cached_buf = [
|
||||||
deepcopy(self.buffer) for _ in range(self.env_num)]
|
ListReplayBuffer() for _ in range(self.env_num)]
|
||||||
else:
|
else:
|
||||||
raise TypeError('The buffer in data collector is invalid!')
|
raise TypeError('The buffer in data collector is invalid!')
|
||||||
self.reset_env()
|
self.reset_env()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user