parent 959955fa2a
commit 80d661907e

README.md
@@ -5,9 +5,9 @@
 ---
 
 [](https://pypi.org/project/tianshou/)
+[](https://tianshou.readthedocs.io)
 [](https://github.com/thu-ml/tianshou/actions)
 [](https://codecov.io/gh/thu-ml/tianshou)
-[](https://tianshou.readthedocs.io)
 [](https://github.com/thu-ml/tianshou/issues)
 [](https://github.com/thu-ml/tianshou/stargazers)
 [](https://github.com/thu-ml/tianshou/network)
@@ -40,7 +40,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is
 Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
 
 ```bash
-pip3 install tianshou -U
+pip3 install tianshou
 ```
 
 You can also install with the newest version through GitHub:
@@ -49,6 +49,17 @@ You can also install with the newest version through GitHub:
 pip3 install git+https://github.com/thu-ml/tianshou.git@master
 ```
 
+If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
+
+```bash
+# create a new virtualenv and install pip, change the env name if you like
+conda create -n myenv pip
+# activate the environment
+conda activate myenv
+# install tianshou
+pip install tianshou
+```
+
 After installation, open your python console and type
 
 ```python
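The hunk above ends right at the opening of a Python code block in the README, so that block's contents are not part of this diff. As a purely illustrative post-install check (an assumption, not necessarily the exact snippet in the README):

```python
# hypothetical verification snippet, not taken from the diff above
import tianshou
print(tianshou.__version__)
```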
@@ -30,13 +30,23 @@ Installation
 Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
 ::
 
-    pip3 install tianshou -U
+    pip3 install tianshou
 
 You can also install with the newest version through GitHub:
 ::
 
     pip3 install git+https://github.com/thu-ml/tianshou.git@master
 
+If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
+::
+
+    # create a new virtualenv and install pip, change the env name if you like
+    conda create -n myenv pip
+    # activate the environment
+    conda activate myenv
+    # install tianshou
+    pip install tianshou
+
 After installation, open your python console and type
 ::
 
@@ -3,15 +3,16 @@ import time
 
 
 class MyTestEnv(gym.Env):
-    def __init__(self, size, sleep=0):
+    def __init__(self, size, sleep=0, dict_state=False):
         self.size = size
         self.sleep = sleep
+        self.dict_state = dict_state
         self.reset()
 
     def reset(self, state=0):
         self.done = False
         self.index = state
-        return self.index
+        return {'index': self.index} if self.dict_state else self.index
 
     def step(self, action):
         if self.done:
@@ -20,11 +21,21 @@ class MyTestEnv(gym.Env):
         time.sleep(self.sleep)
         if self.index == self.size:
             self.done = True
-            return self.index, 0, True, {}
+            if self.dict_state:
+                return {'index': self.index}, 0, True, {}
+            else:
+                return self.index, 0, True, {}
         if action == 0:
             self.index = max(self.index - 1, 0)
-            return self.index, 0, False, {}
+            if self.dict_state:
+                return {'index': self.index}, 0, False, {}
+            else:
+                return self.index, 0, False, {}
        elif action == 1:
             self.index += 1
             self.done = self.index == self.size
-            return self.index, int(self.done), self.done, {'key': 1}
+            if self.dict_state:
+                return {'index': self.index}, int(self.done), self.done, \
+                    {'key': 1}
+            else:
+                return self.index, int(self.done), self.done, {'key': 1}
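A minimal interaction sketch for the patched test env; the return values follow directly from the hunk above, and the import path is an assumption (it matches how the test files import it when run from test/base):

```python
# illustrative usage of MyTestEnv with dict_state=True
from env import MyTestEnv  # test/base/env.py in this repository

env = MyTestEnv(size=5, sleep=0, dict_state=True)
print(env.reset())                  # {'index': 0} instead of a bare int
obs, rew, done, info = env.step(1)
print(obs, rew, done, info)         # {'index': 1} 0 False {'key': 1}
```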
@@ -15,7 +15,7 @@ def test_batch():
     with pytest.raises(IndexError):
         batch[2]
     batch.obs = np.arange(5)
-    for i, b in enumerate(batch.split(1, permute=False)):
+    for i, b in enumerate(batch.split(1, shuffle=False)):
         assert b.obs == batch[i].obs
     print(batch)
 
@@ -2,7 +2,7 @@ import numpy as np
 from torch.utils.tensorboard import SummaryWriter
 
 from tianshou.policy import BasePolicy
-from tianshou.env import SubprocVectorEnv
+from tianshou.env import VectorEnv, SubprocVectorEnv
 from tianshou.data import Collector, Batch, ReplayBuffer
 
 if __name__ == '__main__':
@@ -12,10 +12,13 @@ else: # pytest
 
 
 class MyPolicy(BasePolicy):
-    def __init__(self):
+    def __init__(self, dict_state=False):
         super().__init__()
+        self.dict_state = dict_state
 
     def forward(self, batch, state=None):
+        if self.dict_state:
+            return Batch(act=np.ones(batch.obs['index'].shape[0]))
         return Batch(act=np.ones(batch.obs.shape[0]))
 
     def learn(self):
@@ -75,5 +78,24 @@ def test_collector():
         1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
 
 
+def test_collector_with_dict_state():
+    env = MyTestEnv(size=5, sleep=0, dict_state=True)
+    policy = MyPolicy(dict_state=True)
+    c0 = Collector(policy, env, ReplayBuffer(size=100))
+    c0.collect(n_step=3)
+    c0.collect(n_episode=3)
+    env_fns = [
+        lambda: MyTestEnv(size=2, sleep=0, dict_state=True),
+        lambda: MyTestEnv(size=3, sleep=0, dict_state=True),
+        lambda: MyTestEnv(size=4, sleep=0, dict_state=True),
+        lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
+    ]
+    envs = VectorEnv(env_fns)
+    c1 = Collector(policy, envs, ReplayBuffer(size=100))
+    c1.collect(n_step=10)
+    c1.collect(n_episode=[2, 1, 1, 2])
+
+
 if __name__ == '__main__':
     test_collector()
+    test_collector_with_dict_state()
@@ -1,4 +1,5 @@
 import torch
+import pprint
 import numpy as np
 
 
@@ -23,7 +24,7 @@ class Batch(object):
         )
 
     In short, you can define a :class:`Batch` with any key-value pair. The
-    current implementation of Tianshou typically use 6 keys in
+    current implementation of Tianshou typically use 6 reserved keys in
     :class:`~tianshou.data.Batch`:
 
     * ``obs`` the observation of step :math:`t` ;
@@ -56,7 +57,7 @@ class Batch(object):
         array([0, 11, 22, 0, 11, 22])
 
         >>> # split whole data into multiple small batch
-        >>> for d in data.split(size=2, permute=False):
+        >>> for d in data.split(size=2, shuffle=False):
         ...     print(d.obs, d.rew)
         [ 0 11] [6 6]
         [22 0] [6 6]
@@ -65,24 +66,56 @@ class Batch(object):
 
     def __init__(self, **kwargs):
         super().__init__()
-        self.__dict__.update(kwargs)
+        self._meta = {}
+        for k, v in kwargs.items():
+            if (isinstance(v, list) or isinstance(v, np.ndarray)) \
+                    and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
+                self._meta[k] = list(v[0].keys())
+                for k_ in v[0].keys():
+                    k__ = '_' + k + '@' + k_
+                    self.__dict__[k__] = np.array([
+                        v[i][k_] for i in range(len(v))
+                    ])
+            elif isinstance(v, dict):
+                self._meta[k] = list(v.keys())
+                for k_ in v.keys():
+                    k__ = '_' + k + '@' + k_
+                    self.__dict__[k__] = v[k_]
+            else:
+                self.__dict__[k] = kwargs[k]
 
     def __getitem__(self, index):
         """Return self[index]."""
+        if isinstance(index, str):
+            return self.__getattr__(index)
         b = Batch()
         for k in self.__dict__.keys():
-            if self.__dict__[k] is not None:
+            if k != '_meta' and self.__dict__[k] is not None:
                 b.__dict__.update(**{k: self.__dict__[k][index]})
+        b._meta = self._meta
         return b
 
+    def __getattr__(self, key):
+        """Return self.key"""
+        if key not in self._meta.keys():
+            if key not in self.__dict__.keys():
+                raise AttributeError(key)
+            return self.__dict__[key]
+        d = {}
+        for k_ in self._meta[key]:
+            k__ = '_' + key + '@' + k_
+            d[k_] = self.__dict__[k__]
+        return d
+
     def __repr__(self):
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in sorted(self.__dict__.keys()):
-            if k[0] != '_' and self.__dict__[k] is not None:
+        for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
+            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
+                                k in self._meta.keys()):
                 rpl = '\n' + ' ' * (6 + len(k))
-                obj = str(self.__dict__[k]).replace('\n', rpl)
+                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
                 s += f'    {k}: {obj},\n'
                 flag = True
         if flag:
@@ -91,10 +124,18 @@ class Batch(object):
             s = self.__class__.__name__ + '()\n'
         return s
 
+    def keys(self):
+        """Return self.keys()."""
+        return sorted([i for i in self.__dict__.keys() if i[0] != '_'] +
+                      list(self._meta.keys()))
+
     def append(self, batch):
         """Append a :class:`~tianshou.data.Batch` object to current batch."""
         assert isinstance(batch, Batch), 'Only append Batch is allowed!'
         for k in batch.__dict__.keys():
+            if k == '_meta':
+                self._meta.update(batch.__dict__[k])
+                continue
             if batch.__dict__[k] is None:
                 continue
             if not hasattr(self, k) or self.__dict__[k] is None:
@@ -117,22 +158,22 @@ class Batch(object):
         """Return len(self)."""
         return min([
             len(self.__dict__[k]) for k in self.__dict__.keys()
-            if self.__dict__[k] is not None])
+            if k != '_meta' and self.__dict__[k] is not None])
 
-    def split(self, size=None, permute=True):
+    def split(self, size=None, shuffle=True):
         """Split whole data into multiple small batch.
 
         :param int size: if it is ``None``, it does not split the data batch;
             otherwise it will divide the data batch with the given size.
             Default to ``None``.
-        :param bool permute: randomly shuffle the entire data batch if it is
+        :param bool shuffle: randomly shuffle the entire data batch if it is
            ``True``, otherwise remain in the same. Default to ``True``.
         """
         length = len(self)
         if size is None:
             size = length
         temp = 0
-        if permute:
+        if shuffle:
             index = np.random.permutation(length)
         else:
             index = np.arange(length)
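Taken together, the Batch changes above flatten any dict value into per-sub-key attributes named '_<key>@<subkey>', record the sub-keys in `self._meta`, and rebuild the dict on attribute access. A minimal behaviour sketch, assuming a checkout that includes this commit (the output comments follow from the code in the hunks above, not from a released version):

```python
import numpy as np
from tianshou.data import Batch

# a list of dict observations, e.g. what MyTestEnv(dict_state=True) produces
obs = [{'index': 0}, {'index': 1}, {'index': 2}]
batch = Batch(obs=obs, rew=np.zeros(3))

# stored internally as batch.__dict__['_obs@index'] = array([0, 1, 2]),
# with batch._meta['obs'] == ['index']
print(batch.obs)              # {'index': array([0, 1, 2])}, via __getattr__
print(batch.obs['index'][0])  # 0, the access pattern MyPolicy.forward relies on
print(batch[1].obs)           # {'index': 1}; __getitem__ carries _meta along
```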
@@ -1,3 +1,4 @@
+import pprint
 import numpy as np
 from tianshou.data.batch import Batch
 
@@ -92,6 +93,7 @@ class ReplayBuffer(object):
         self._maxsize = size
         self._stack = stack_num
         self._save_s_ = not ignore_obs_next
+        self._meta = {}
         self.reset()
 
     def __len__(self):
@@ -102,10 +104,11 @@ class ReplayBuffer(object):
         """Return str(self)."""
         s = self.__class__.__name__ + '(\n'
         flag = False
-        for k in self.__dict__.keys():
-            if k[0] != '_' and self.__dict__[k] is not None:
+        for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
+            if k[0] != '_' and (self.__dict__.get(k, None) is not None or
+                                k in self._meta.keys()):
                 rpl = '\n' + ' ' * (6 + len(k))
-                obj = str(self.__dict__[k]).replace('\n', rpl)
+                obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
                 s += f'    {k}: {obj},\n'
                 flag = True
         if flag:
@@ -114,23 +117,51 @@ class ReplayBuffer(object):
             s = self.__class__.__name__ + '()\n'
         return s
 
+    def __getattr__(self, key):
+        """Return self.key"""
+        if key not in self._meta.keys():
+            if key not in self.__dict__.keys():
+                raise AttributeError(key)
+            return self.__dict__[key]
+        d = {}
+        for k_ in self._meta[key]:
+            k__ = '_' + key + '@' + k_
+            d[k_] = self.__dict__[k__]
+        return d
+
     def _add_to_buffer(self, name, inst):
         if inst is None:
             if getattr(self, name, None) is None:
                 self.__dict__[name] = None
             return
+        if name in self._meta.keys():
+            for k in inst.keys():
+                self._add_to_buffer('_' + name + '@' + k, inst[k])
+            return
         if self.__dict__.get(name, None) is None:
             if isinstance(inst, np.ndarray):
                 self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
             elif isinstance(inst, dict):
-                self.__dict__[name] = np.array(
-                    [{} for _ in range(self._maxsize)])
+                if name == 'info':
+                    self.__dict__[name] = np.array(
+                        [{} for _ in range(self._maxsize)])
+                else:
+                    if self._meta.get(name, None) is None:
+                        self._meta[name] = [
+                            '_' + name + '@' + k for k in inst.keys()]
+                    for k in inst.keys():
+                        k_ = '_' + name + '@' + k
+                        self._add_to_buffer(k_, inst[k])
             else:  # assume `inst` is a number
                 self.__dict__[name] = np.zeros([self._maxsize])
         if isinstance(inst, np.ndarray) and \
                 self.__dict__[name].shape[1:] != inst.shape:
-            self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
-        self.__dict__[name][self._index] = inst
+            raise ValueError(
+                "Cannot add data to a buffer with different shape, "
+                f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
+                f"given shape: {inst.shape}.")
+        if name not in self._meta.keys():
+            self.__dict__[name][self._index] = inst
 
     def update(self, buffer):
         """Move the data from the given buffer to self."""
@@ -144,7 +175,8 @@ class ReplayBuffer(object):
             if i == begin:
                 break
 
-    def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None):
+    def add(self, obs, act, rew, done, obs_next=None, info={}, policy={},
+            **kwargs):
         """Add a batch of data into replay buffer."""
         assert isinstance(info, dict), \
             'You should return a dict in the last argument of env.step().'
@@ -155,6 +187,7 @@ class ReplayBuffer(object):
         if self._save_s_:
             self._add_to_buffer('obs_next', obs_next)
         self._add_to_buffer('info', info)
+        self._add_to_buffer('policy', policy)
         if self._maxsize > 0:
             self._size = min(self._size + 1, self._maxsize)
             self._index = (self._index + 1) % self._maxsize
@@ -180,11 +213,13 @@ class ReplayBuffer(object):
         ])
         return self[indice], indice
 
-    def get(self, indice, key):
+    def get(self, indice, key, stack_num=None):
         """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
         where s is self.key, t is indice. The stack_num (here equals to 4) is
         given from buffer initialization procedure.
         """
+        if stack_num is None:
+            stack_num = self._stack
         if not isinstance(indice, np.ndarray):
             if np.isscalar(indice):
                 indice = np.array(indice)
@@ -200,18 +235,37 @@ class ReplayBuffer(object):
             indice += 1 - self.done[indice].astype(np.int)
             indice[indice == self._size] = 0
             key = 'obs'
-        if self._stack == 0:
+        if stack_num == 0:
             self.done[last_index] = last_done
-            return self.__dict__[key][indice]
-        stack = []
-        for i in range(self._stack):
-            stack = [self.__dict__[key][indice]] + stack
+            if key in self._meta:
+                return {k.split('@')[-1]: self.__dict__[k][indice]
+                        for k in self._meta[key]}
+            else:
+                return self.__dict__[key][indice]
+        if key in self._meta:
+            many_keys = self._meta[key]
+            stack = {k.split('@')[-1]: [] for k in self._meta[key]}
+        else:
+            stack = []
+            many_keys = None
+        for i in range(stack_num):
+            if many_keys is not None:
+                for k_ in many_keys:
+                    k = k_.split('@')[-1]
+                    stack[k] = [self.__dict__[k_][indice]] + stack[k]
+            else:
+                stack = [self.__dict__[key][indice]] + stack
             pre_indice = indice - 1
             pre_indice[pre_indice == -1] = self._size - 1
             indice = pre_indice + self.done[pre_indice].astype(np.int)
             indice[indice == self._size] = 0
         self.done[last_index] = last_done
-        return np.stack(stack, axis=1)
+        if many_keys is not None:
+            for k in stack:
+                stack[k] = np.stack(stack[k], axis=1)
+        else:
+            stack = np.stack(stack, axis=1)
+        return stack
 
     def __getitem__(self, index):
         """Return a data batch: self[index]. If stack_num is set to be > 0,
@@ -223,7 +277,8 @@ class ReplayBuffer(object):
             rew=self.rew[index],
             done=self.done[index],
             obs_next=self.get(index, 'obs_next'),
-            info=self.info[index]
+            info=self.info[index],
+            policy=self.get(index, 'policy'),
         )
 
 
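The same '_<key>@<subkey>' convention now applies to the buffer: dict observations are split into one array per sub-key in `_add_to_buffer`, and `get()`/`__getitem__` reassemble them. A small round-trip sketch, assuming a checkout with this commit (it mirrors `test_collector_with_dict_state` above; comments describe the expected behaviour, not verified output of a released version):

```python
from tianshou.data import ReplayBuffer

buf = ReplayBuffer(size=10)
# each dict observation is stored as a per-key array, here under '_obs@index'
buf.add(obs={'index': 0}, act=1, rew=0, done=False, obs_next={'index': 1})
buf.add(obs={'index': 1}, act=1, rew=1, done=True, obs_next={'index': 1})

batch, indice = buf.sample(batch_size=2)
# dict observations come back as a dict of stacked arrays
print(batch.obs)                   # {'index': array([...])}
print(batch.obs['index'].shape)    # (2,)
```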
@ -234,7 +289,7 @@ class ListReplayBuffer(ReplayBuffer):
|
|||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
Please refer to :class:`~tianshou.data.ListReplayBuffer` for more
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
|
||||||
detailed explanation.
|
detailed explanation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -256,7 +311,13 @@ class ListReplayBuffer(ReplayBuffer):
|
|||||||
|
|
||||||
|
|
||||||
class PrioritizedReplayBuffer(ReplayBuffer):
|
class PrioritizedReplayBuffer(ReplayBuffer):
|
||||||
"""docstring for PrioritizedReplayBuffer"""
|
"""Prioritized replay buffer implementation.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
|
||||||
|
detailed explanation.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, size, alpha: float, beta: float,
|
def __init__(self, size, alpha: float, beta: float,
|
||||||
mode: str = 'weight', **kwargs):
|
mode: str = 'weight', **kwargs):
|
||||||
@ -270,17 +331,18 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
self._amortization_freq = 50
|
self._amortization_freq = 50
|
||||||
self._amortization_counter = 0
|
self._amortization_counter = 0
|
||||||
|
|
||||||
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=1.0):
|
def add(self, obs, act, rew, done, obs_next=0, info={}, policy={},
|
||||||
|
weight=1.0):
|
||||||
"""Add a batch of data into replay buffer."""
|
"""Add a batch of data into replay buffer."""
|
||||||
self._weight_sum += np.abs(weight)**self._alpha - \
|
self._weight_sum += np.abs(weight)**self._alpha - \
|
||||||
self.weight[self._index]
|
self.weight[self._index]
|
||||||
# we have to sacrifice some convenience for speed :(
|
# we have to sacrifice some convenience for speed :(
|
||||||
self._add_to_buffer('weight', np.abs(weight)**self._alpha)
|
self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
|
||||||
super().add(obs, act, rew, done, obs_next, info)
|
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||||
self._check_weight_sum()
|
self._check_weight_sum()
|
||||||
|
|
||||||
def sample(self, batch_size: int = 0, importance_sample: bool = True):
|
def sample(self, batch_size: int = 0, importance_sample: bool = True):
|
||||||
""" Get a random sample from buffer with priority probability. \
|
"""Get a random sample from buffer with priority probability. \
|
||||||
Return all the data in the buffer if batch_size is ``0``.
|
Return all the data in the buffer if batch_size is ``0``.
|
||||||
|
|
||||||
:return: Sample data and its corresponding index inside the buffer.
|
:return: Sample data and its corresponding index inside the buffer.
|
||||||
@ -290,7 +352,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
# will cause weight update conflict
|
# will cause weight update conflict
|
||||||
indice = np.random.choice(
|
indice = np.random.choice(
|
||||||
self._size, batch_size,
|
self._size, batch_size,
|
||||||
p=(self.weight/self.weight.sum())[:self._size], replace=False)
|
p=(self.weight / self.weight.sum())[:self._size],
|
||||||
|
replace=False)
|
||||||
# self._weight_sum is not work for the accuracy issue
|
# self._weight_sum is not work for the accuracy issue
|
||||||
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
||||||
elif batch_size == 0:
|
elif batch_size == 0:
|
||||||
@ -305,8 +368,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
batch = self[indice]
|
batch = self[indice]
|
||||||
if importance_sample:
|
if importance_sample:
|
||||||
impt_weight = Batch(
|
impt_weight = Batch(
|
||||||
impt_weight=1/np.power(
|
impt_weight=1 / np.power(
|
||||||
self._size*(batch.weight/self._weight_sum), self._beta))
|
self._size * (batch.weight / self._weight_sum),
|
||||||
|
self._beta))
|
||||||
batch.append(impt_weight)
|
batch.append(impt_weight)
|
||||||
self._check_weight_sum()
|
self._check_weight_sum()
|
||||||
return batch, indice
|
return batch, indice
|
||||||
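For reference, the `impt_weight` expression in `sample()` above is the standard prioritized-replay importance weight w_i = (N * P(i)) ** (-beta), where the stored `self.weight` already holds p_i ** alpha (see `add()` above), so P(i) = weight / weight_sum. A plain-NumPy restatement of the same expression, purely illustrative:

```python
import numpy as np

def per_importance_weight(weight, weight_sum, size, beta):
    """Mirrors `1 / np.power(size * (weight / weight_sum), beta)` from
    PrioritizedReplayBuffer.sample, i.e. (N * P(i)) ** (-beta)."""
    return 1 / np.power(size * (weight / weight_sum), beta)
```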
@ -316,7 +380,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
super().reset()
|
super().reset()
|
||||||
|
|
||||||
def update_weight(self, indice, new_weight: np.ndarray):
|
def update_weight(self, indice, new_weight: np.ndarray):
|
||||||
"""update priority weight by indice in this buffer
|
"""Update priority weight by indice in this buffer.
|
||||||
|
|
||||||
:param indice: indice you want to update weight
|
:param indice: indice you want to update weight
|
||||||
:param new_weight: new priority weight you wangt to update
|
:param new_weight: new priority weight you wangt to update
|
||||||
@ -333,7 +397,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
done=self.done[index],
|
done=self.done[index],
|
||||||
obs_next=self.get(index, 'obs_next'),
|
obs_next=self.get(index, 'obs_next'),
|
||||||
info=self.info[index],
|
info=self.info[index],
|
||||||
weight=self.weight[index]
|
weight=self.weight[index],
|
||||||
|
policy=self.get(index, 'policy'),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_weight_sum(self):
|
def _check_weight_sum(self):
|
||||||
|
@ -54,7 +54,7 @@ class A2CPolicy(PGPolicy):
|
|||||||
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||||
v_ = []
|
v_ = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for b in batch.split(self._batch, permute=False):
|
for b in batch.split(self._batch, shuffle=False):
|
||||||
v_.append(self.critic(b.obs_next).detach().cpu().numpy())
|
v_.append(self.critic(b.obs_next).detach().cpu().numpy())
|
||||||
v_ = np.concatenate(v_, axis=0)
|
v_ = np.concatenate(v_, axis=0)
|
||||||
return self.compute_episodic_return(
|
return self.compute_episodic_return(
|
||||||
|
@ -74,7 +74,7 @@ class PPOPolicy(PGPolicy):
|
|||||||
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
batch, None, gamma=self._gamma, gae_lambda=self._lambda)
|
||||||
v_ = []
|
v_ = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for b in batch.split(self._batch, permute=False):
|
for b in batch.split(self._batch, shuffle=False):
|
||||||
v_.append(self.critic(b.obs_next))
|
v_.append(self.critic(b.obs_next))
|
||||||
v_ = torch.cat(v_, dim=0).cpu().numpy()
|
v_ = torch.cat(v_, dim=0).cpu().numpy()
|
||||||
return self.compute_episodic_return(
|
return self.compute_episodic_return(
|
||||||
@ -111,7 +111,7 @@ class PPOPolicy(PGPolicy):
|
|||||||
v = []
|
v = []
|
||||||
old_log_prob = []
|
old_log_prob = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for b in batch.split(batch_size, permute=False):
|
for b in batch.split(batch_size, shuffle=False):
|
||||||
v.append(self.critic(b.obs))
|
v.append(self.critic(b.obs))
|
||||||
old_log_prob.append(self(b).dist.log_prob(
|
old_log_prob.append(self(b).dist.log_prob(
|
||||||
torch.tensor(b.act, device=v[0].device)))
|
torch.tensor(b.act, device=v[0].device)))
|
||||||