Tianshou/tianshou/data/buffer.py

129 lines
4.0 KiB
Python
Raw Normal View History

2020-03-11 09:09:56 +08:00
import numpy as np
from tianshou.data.batch import Batch
class ReplayBuffer(object):
    """Fixed-capacity circular buffer of environment transitions.

    Each field (``obs``, ``act``, ``rew``, ``done``, ``obs_next``, ``info``)
    is allocated lazily on the first :meth:`add` as a NumPy array with
    ``size`` rows. Once ``size`` transitions are stored, new ones overwrite
    the oldest.

    :param size: maximum number of stored transitions. ``size == 0``
        disables wrap-around in :meth:`add`; that only works for a subclass
        that overrides :meth:`_add_to_buffer` with growable storage.
    """

    def __init__(self, size):
        super().__init__()
        self._maxsize = size
        self.reset()

    def __del__(self):
        # Drop every instance attribute, including the stored field arrays.
        for k in list(self.__dict__.keys()):
            del self.__dict__[k]

    def __len__(self):
        """Return the number of transitions currently stored."""
        return self._size

    def _add_to_buffer(self, name, inst):
        """Write ``inst`` into row ``self._index`` of the field ``name``.

        Storage is allocated on first use: an ``ndarray`` gets a float array
        of shape ``(size, *inst.shape)``, a ``dict`` gets an object array of
        empty dicts, and anything else is treated as a scalar. ``None`` is
        ignored (the field is not written).
        """
        if inst is None:
            return
        storage = self.__dict__.get(name, None)
        if storage is None:
            if isinstance(inst, np.ndarray):
                storage = np.zeros([self._maxsize, *inst.shape])
            elif isinstance(inst, dict):
                storage = np.array([{} for _ in range(self._maxsize)])
            else:  # assume `inst` is a number
                storage = np.zeros([self._maxsize])
        if isinstance(inst, np.ndarray) and storage.shape[1:] != inst.shape:
            # NOTE: a change of per-item shape reallocates the field and
            # discards everything stored in it so far.
            storage = np.zeros([self._maxsize, *inst.shape])
        self.__dict__[name] = storage
        storage[self._index] = inst

    def update(self, buffer):
        """Append every transition of ``buffer`` to this buffer, oldest first.

        ``buffer._index`` is the next slot ``buffer`` would write. Once that
        buffer has wrapped around, this slot is also its oldest entry, so the
        copy starts there and walks forward circularly. An empty ``buffer``
        is a no-op.

        :param buffer: a :class:`ReplayBuffer` (or subclass) to copy from.
        """
        n = len(buffer)
        if n == 0:
            return
        begin = buffer._index % n
        # Iterate a fixed count so the loop terminates even if ``buffer`` is
        # ``self`` and grows while being copied.
        for offset in range(n):
            i = (begin + offset) % n
            self.add(
                buffer.obs[i], buffer.act[i], buffer.rew[i],
                buffer.done[i], buffer.obs_next[i], buffer.info[i])

    def add(self, obs, act, rew, done, obs_next=0, info=None, weight=None):
        """Store one transition at the current write position.

        :param info: per-step info dict. ``None`` (the default) stores a
            fresh empty dict, so stored entries never alias one shared
            default object.
        :param weight: importance weight; unused by this buffer.
        :raises AssertionError: if ``info`` is not a dict.
        """
        if info is None:
            info = {}
        assert isinstance(info, dict), \
            'You should return a dict in the last argument of env.step().'
        self._add_to_buffer('obs', obs)
        self._add_to_buffer('act', act)
        self._add_to_buffer('rew', rew)
        self._add_to_buffer('done', done)
        self._add_to_buffer('obs_next', obs_next)
        self._add_to_buffer('info', info)
        if self._maxsize > 0:
            self._size = min(self._size + 1, self._maxsize)
            self._index = (self._index + 1) % self._maxsize
        else:
            # Unbounded mode: size and write position grow together.
            self._size = self._index = self._index + 1

    def reset(self):
        """Rewind the write position and size to zero.

        Already-allocated field arrays are kept; later :meth:`add` calls
        overwrite them.
        """
        self._index = self._size = 0
        self.indice = []

    def sample(self, batch_size):
        """Return ``(batch, indice)``.

        With ``batch_size > 0``, ``indice`` holds ``batch_size`` positions
        drawn uniformly with replacement from the stored transitions.
        Otherwise every stored transition is returned, oldest first.
        """
        if batch_size > 0:
            indice = np.random.choice(self._size, batch_size)
        else:
            # Once wrapped, slots [_index, _size) are older than [0, _index);
            # before wrapping, _index == _size and the first range is empty.
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        return Batch(
            obs=self.obs[indice],
            act=self.act[indice],
            rew=self.rew[indice],
            done=self.done[indice],
            obs_next=self.obs_next[indice],
            info=self.info[indice]
        ), indice

    def __getitem__(self, index):
        """Return the transition(s) at ``index`` as a ``Batch``."""
        return Batch(
            obs=self.obs[index],
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.obs_next[index],
            info=self.info[index]
        )
2020-03-11 09:09:56 +08:00
2020-03-28 15:14:41 +08:00
class ListReplayBuffer(ReplayBuffer):
    """Unbounded replay buffer backed by plain Python lists.

    Passing ``size=0`` to the base class disables wrap-around in its
    ``add``, and :meth:`_add_to_buffer` appends, so every transition is
    kept.
    """

    def __init__(self):
        super().__init__(size=0)

    def _add_to_buffer(self, name, inst):
        """Append ``inst`` to the list stored under ``name``.

        The list is created on first use. ``None`` is ignored.
        """
        if inst is None:
            return
        if self.__dict__.get(name, None) is None:
            self.__dict__[name] = []
        self.__dict__[name].append(inst)

    def reset(self):
        """Reset the counters and empty every stored field.

        Delegates to the base-class reset first, so the attributes it
        initializes (e.g. ``indice``) also exist on this subclass.
        """
        super().reset()
        # Every public attribute holds per-transition data; clear it.
        for k in list(self.__dict__.keys()):
            if not k.startswith('_'):
                self.__dict__[k] = []
2020-03-11 09:09:56 +08:00
class PrioritizedReplayBuffer(ReplayBuffer):
    """Placeholder for a prioritized replay buffer.

    Construction works; :meth:`add`, :meth:`sample_indice` and
    :meth:`sample` all raise ``NotImplementedError`` for now.
    """

    def __init__(self, size):
        super().__init__(size)

    def add(self, obs, act, rew, done, obs_next=0, info=None, weight=None):
        """Not implemented yet.

        The signature mirrors ``ReplayBuffer.add``. ``info`` defaults to
        ``None`` rather than a shared mutable dict, and ``weight`` is
        intended for the importance weight.
        """
        raise NotImplementedError

    def sample_indice(self, batch_size):
        """Not implemented yet."""
        raise NotImplementedError

    def sample(self, batch_size):
        """Not implemented yet."""
        raise NotImplementedError