Multimodal obs (#38, #27, #25)

This commit is contained in:
Trinkle23897 2020-04-28 20:56:02 +08:00
parent 959955fa2a
commit 80d661907e
9 changed files with 212 additions and 52 deletions

View File

@ -5,9 +5,9 @@
--- ---
[![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/) [![PyPI](https://img.shields.io/pypi/v/tianshou)](https://pypi.org/project/tianshou/)
[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io)
[![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions) [![Unittest](https://github.com/thu-ml/tianshou/workflows/Unittest/badge.svg?branch=master)](https://github.com/thu-ml/tianshou/actions)
[![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou) [![codecov](https://img.shields.io/codecov/c/gh/thu-ml/tianshou)](https://codecov.io/gh/thu-ml/tianshou)
[![Documentation Status](https://readthedocs.org/projects/tianshou/badge/?version=latest)](https://tianshou.readthedocs.io)
[![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues) [![GitHub issues](https://img.shields.io/github/issues/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/issues)
[![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers)
[![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network) [![GitHub forks](https://img.shields.io/github/forks/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/network)
@ -40,7 +40,7 @@ In Chinese, Tianshou means the innate talent, not taught by others. Tianshou is
Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command: Tianshou is currently hosted on [PyPI](https://pypi.org/project/tianshou/). It requires Python >= 3.6. You can simply install Tianshou with the following command:
```bash ```bash
pip3 install tianshou -U pip3 install tianshou
``` ```
You can also install with the newest version through GitHub: You can also install with the newest version through GitHub:
@ -49,6 +49,17 @@ You can also install with the newest version through GitHub:
pip3 install git+https://github.com/thu-ml/tianshou.git@master pip3 install git+https://github.com/thu-ml/tianshou.git@master
``` ```
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
```bash
# create a new virtualenv and install pip, change the env name if you like
conda create -n myenv pip
# activate the environment
conda activate myenv
# install tianshou
pip install tianshou
```
After installation, open your python console and type After installation, open your python console and type
```python ```python

View File

@ -30,13 +30,23 @@ Installation
Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command: Tianshou is currently hosted on `PyPI <https://pypi.org/project/tianshou/>`_. You can simply install Tianshou with the following command:
:: ::
pip3 install tianshou -U pip3 install tianshou
You can also install with the newest version through GitHub: You can also install with the newest version through GitHub:
:: ::
pip3 install git+https://github.com/thu-ml/tianshou.git@master pip3 install git+https://github.com/thu-ml/tianshou.git@master
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
::
# create a new virtualenv and install pip, change the env name if you like
conda create -n myenv pip
# activate the environment
conda activate myenv
# install tianshou
pip install tianshou
After installation, open your python console and type After installation, open your python console and type
:: ::

View File

@ -3,15 +3,16 @@ import time
class MyTestEnv(gym.Env): class MyTestEnv(gym.Env):
def __init__(self, size, sleep=0): def __init__(self, size, sleep=0, dict_state=False):
self.size = size self.size = size
self.sleep = sleep self.sleep = sleep
self.dict_state = dict_state
self.reset() self.reset()
def reset(self, state=0): def reset(self, state=0):
self.done = False self.done = False
self.index = state self.index = state
return self.index return {'index': self.index} if self.dict_state else self.index
def step(self, action): def step(self, action):
if self.done: if self.done:
@ -20,11 +21,21 @@ class MyTestEnv(gym.Env):
time.sleep(self.sleep) time.sleep(self.sleep)
if self.index == self.size: if self.index == self.size:
self.done = True self.done = True
if self.dict_state:
return {'index': self.index}, 0, True, {}
else:
return self.index, 0, True, {} return self.index, 0, True, {}
if action == 0: if action == 0:
self.index = max(self.index - 1, 0) self.index = max(self.index - 1, 0)
if self.dict_state:
return {'index': self.index}, 0, False, {}
else:
return self.index, 0, False, {} return self.index, 0, False, {}
elif action == 1: elif action == 1:
self.index += 1 self.index += 1
self.done = self.index == self.size self.done = self.index == self.size
if self.dict_state:
return {'index': self.index}, int(self.done), self.done, \
{'key': 1}
else:
return self.index, int(self.done), self.done, {'key': 1} return self.index, int(self.done), self.done, {'key': 1}

View File

@ -15,7 +15,7 @@ def test_batch():
with pytest.raises(IndexError): with pytest.raises(IndexError):
batch[2] batch[2]
batch.obs = np.arange(5) batch.obs = np.arange(5)
for i, b in enumerate(batch.split(1, permute=False)): for i, b in enumerate(batch.split(1, shuffle=False)):
assert b.obs == batch[i].obs assert b.obs == batch[i].obs
print(batch) print(batch)

View File

@ -2,7 +2,7 @@ import numpy as np
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.env import SubprocVectorEnv from tianshou.env import VectorEnv, SubprocVectorEnv
from tianshou.data import Collector, Batch, ReplayBuffer from tianshou.data import Collector, Batch, ReplayBuffer
if __name__ == '__main__': if __name__ == '__main__':
@ -12,10 +12,13 @@ else: # pytest
class MyPolicy(BasePolicy): class MyPolicy(BasePolicy):
def __init__(self): def __init__(self, dict_state=False):
super().__init__() super().__init__()
self.dict_state = dict_state
def forward(self, batch, state=None): def forward(self, batch, state=None):
if self.dict_state:
return Batch(act=np.ones(batch.obs['index'].shape[0]))
return Batch(act=np.ones(batch.obs.shape[0])) return Batch(act=np.ones(batch.obs.shape[0]))
def learn(self): def learn(self):
@ -75,5 +78,24 @@ def test_collector():
1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]) 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5])
def test_collector_with_dict_state():
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100))
c0.collect(n_step=3)
c0.collect(n_episode=3)
env_fns = [
lambda: MyTestEnv(size=2, sleep=0, dict_state=True),
lambda: MyTestEnv(size=3, sleep=0, dict_state=True),
lambda: MyTestEnv(size=4, sleep=0, dict_state=True),
lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
]
envs = VectorEnv(env_fns)
c1 = Collector(policy, envs, ReplayBuffer(size=100))
c1.collect(n_step=10)
c1.collect(n_episode=[2, 1, 1, 2])
if __name__ == '__main__': if __name__ == '__main__':
test_collector() test_collector()
test_collector_with_dict_state()

View File

@ -1,4 +1,5 @@
import torch import torch
import pprint
import numpy as np import numpy as np
@ -23,7 +24,7 @@ class Batch(object):
) )
In short, you can define a :class:`Batch` with any key-value pair. The In short, you can define a :class:`Batch` with any key-value pair. The
current implementation of Tianshou typically use 6 keys in current implementation of Tianshou typically use 6 reserved keys in
:class:`~tianshou.data.Batch`: :class:`~tianshou.data.Batch`:
* ``obs`` the observation of step :math:`t` ; * ``obs`` the observation of step :math:`t` ;
@ -56,7 +57,7 @@ class Batch(object):
array([0, 11, 22, 0, 11, 22]) array([0, 11, 22, 0, 11, 22])
>>> # split whole data into multiple small batch >>> # split whole data into multiple small batch
>>> for d in data.split(size=2, permute=False): >>> for d in data.split(size=2, shuffle=False):
... print(d.obs, d.rew) ... print(d.obs, d.rew)
[ 0 11] [6 6] [ 0 11] [6 6]
[22 0] [6 6] [22 0] [6 6]
@ -65,24 +66,56 @@ class Batch(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__() super().__init__()
self.__dict__.update(kwargs) self._meta = {}
for k, v in kwargs.items():
if (isinstance(v, list) or isinstance(v, np.ndarray)) \
and len(v) > 0 and isinstance(v[0], dict) and k != 'info':
self._meta[k] = list(v[0].keys())
for k_ in v[0].keys():
k__ = '_' + k + '@' + k_
self.__dict__[k__] = np.array([
v[i][k_] for i in range(len(v))
])
elif isinstance(v, dict):
self._meta[k] = list(v.keys())
for k_ in v.keys():
k__ = '_' + k + '@' + k_
self.__dict__[k__] = v[k_]
else:
self.__dict__[k] = kwargs[k]
def __getitem__(self, index): def __getitem__(self, index):
"""Return self[index].""" """Return self[index]."""
if isinstance(index, str):
return self.__getattr__(index)
b = Batch() b = Batch()
for k in self.__dict__.keys(): for k in self.__dict__.keys():
if self.__dict__[k] is not None: if k != '_meta' and self.__dict__[k] is not None:
b.__dict__.update(**{k: self.__dict__[k][index]}) b.__dict__.update(**{k: self.__dict__[k][index]})
b._meta = self._meta
return b return b
def __getattr__(self, key):
"""Return self.key"""
if key not in self._meta.keys():
if key not in self.__dict__.keys():
raise AttributeError(key)
return self.__dict__[key]
d = {}
for k_ in self._meta[key]:
k__ = '_' + key + '@' + k_
d[k_] = self.__dict__[k__]
return d
def __repr__(self): def __repr__(self):
"""Return str(self).""" """Return str(self)."""
s = self.__class__.__name__ + '(\n' s = self.__class__.__name__ + '(\n'
flag = False flag = False
for k in sorted(self.__dict__.keys()): for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
if k[0] != '_' and self.__dict__[k] is not None: if k[0] != '_' and (self.__dict__.get(k, None) is not None or
k in self._meta.keys()):
rpl = '\n' + ' ' * (6 + len(k)) rpl = '\n' + ' ' * (6 + len(k))
obj = str(self.__dict__[k]).replace('\n', rpl) obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
s += f' {k}: {obj},\n' s += f' {k}: {obj},\n'
flag = True flag = True
if flag: if flag:
@ -91,10 +124,18 @@ class Batch(object):
s = self.__class__.__name__ + '()\n' s = self.__class__.__name__ + '()\n'
return s return s
def keys(self):
"""Return self.keys()."""
return sorted([i for i in self.__dict__.keys() if i[0] != '_'] +
list(self._meta.keys()))
def append(self, batch): def append(self, batch):
"""Append a :class:`~tianshou.data.Batch` object to current batch.""" """Append a :class:`~tianshou.data.Batch` object to current batch."""
assert isinstance(batch, Batch), 'Only append Batch is allowed!' assert isinstance(batch, Batch), 'Only append Batch is allowed!'
for k in batch.__dict__.keys(): for k in batch.__dict__.keys():
if k == '_meta':
self._meta.update(batch.__dict__[k])
continue
if batch.__dict__[k] is None: if batch.__dict__[k] is None:
continue continue
if not hasattr(self, k) or self.__dict__[k] is None: if not hasattr(self, k) or self.__dict__[k] is None:
@ -117,22 +158,22 @@ class Batch(object):
"""Return len(self).""" """Return len(self)."""
return min([ return min([
len(self.__dict__[k]) for k in self.__dict__.keys() len(self.__dict__[k]) for k in self.__dict__.keys()
if self.__dict__[k] is not None]) if k != '_meta' and self.__dict__[k] is not None])
def split(self, size=None, permute=True): def split(self, size=None, shuffle=True):
"""Split whole data into multiple small batch. """Split whole data into multiple small batch.
:param int size: if it is ``None``, it does not split the data batch; :param int size: if it is ``None``, it does not split the data batch;
otherwise it will divide the data batch with the given size. otherwise it will divide the data batch with the given size.
Default to ``None``. Default to ``None``.
:param bool permute: randomly shuffle the entire data batch if it is :param bool shuffle: randomly shuffle the entire data batch if it is
``True``, otherwise remain in the same. Default to ``True``. ``True``, otherwise remain in the same. Default to ``True``.
""" """
length = len(self) length = len(self)
if size is None: if size is None:
size = length size = length
temp = 0 temp = 0
if permute: if shuffle:
index = np.random.permutation(length) index = np.random.permutation(length)
else: else:
index = np.arange(length) index = np.arange(length)

View File

@ -1,3 +1,4 @@
import pprint
import numpy as np import numpy as np
from tianshou.data.batch import Batch from tianshou.data.batch import Batch
@ -92,6 +93,7 @@ class ReplayBuffer(object):
self._maxsize = size self._maxsize = size
self._stack = stack_num self._stack = stack_num
self._save_s_ = not ignore_obs_next self._save_s_ = not ignore_obs_next
self._meta = {}
self.reset() self.reset()
def __len__(self): def __len__(self):
@ -102,10 +104,11 @@ class ReplayBuffer(object):
"""Return str(self).""" """Return str(self)."""
s = self.__class__.__name__ + '(\n' s = self.__class__.__name__ + '(\n'
flag = False flag = False
for k in self.__dict__.keys(): for k in sorted(list(self.__dict__.keys()) + list(self._meta.keys())):
if k[0] != '_' and self.__dict__[k] is not None: if k[0] != '_' and (self.__dict__.get(k, None) is not None or
k in self._meta.keys()):
rpl = '\n' + ' ' * (6 + len(k)) rpl = '\n' + ' ' * (6 + len(k))
obj = str(self.__dict__[k]).replace('\n', rpl) obj = pprint.pformat(self.__getattr__(k)).replace('\n', rpl)
s += f' {k}: {obj},\n' s += f' {k}: {obj},\n'
flag = True flag = True
if flag: if flag:
@ -114,22 +117,50 @@ class ReplayBuffer(object):
s = self.__class__.__name__ + '()\n' s = self.__class__.__name__ + '()\n'
return s return s
def __getattr__(self, key):
"""Return self.key"""
if key not in self._meta.keys():
if key not in self.__dict__.keys():
raise AttributeError(key)
return self.__dict__[key]
d = {}
for k_ in self._meta[key]:
k__ = '_' + key + '@' + k_
d[k_] = self.__dict__[k__]
return d
def _add_to_buffer(self, name, inst): def _add_to_buffer(self, name, inst):
if inst is None: if inst is None:
if getattr(self, name, None) is None: if getattr(self, name, None) is None:
self.__dict__[name] = None self.__dict__[name] = None
return return
if name in self._meta.keys():
for k in inst.keys():
self._add_to_buffer('_' + name + '@' + k, inst[k])
return
if self.__dict__.get(name, None) is None: if self.__dict__.get(name, None) is None:
if isinstance(inst, np.ndarray): if isinstance(inst, np.ndarray):
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) self.__dict__[name] = np.zeros([self._maxsize, *inst.shape])
elif isinstance(inst, dict): elif isinstance(inst, dict):
if name == 'info':
self.__dict__[name] = np.array( self.__dict__[name] = np.array(
[{} for _ in range(self._maxsize)]) [{} for _ in range(self._maxsize)])
else:
if self._meta.get(name, None) is None:
self._meta[name] = [
'_' + name + '@' + k for k in inst.keys()]
for k in inst.keys():
k_ = '_' + name + '@' + k
self._add_to_buffer(k_, inst[k])
else: # assume `inst` is a number else: # assume `inst` is a number
self.__dict__[name] = np.zeros([self._maxsize]) self.__dict__[name] = np.zeros([self._maxsize])
if isinstance(inst, np.ndarray) and \ if isinstance(inst, np.ndarray) and \
self.__dict__[name].shape[1:] != inst.shape: self.__dict__[name].shape[1:] != inst.shape:
self.__dict__[name] = np.zeros([self._maxsize, *inst.shape]) raise ValueError(
"Cannot add data to a buffer with different shape, "
f"key: {name}, expect shape: {self.__dict__[name].shape[1:]}, "
f"given shape: {inst.shape}.")
if name not in self._meta.keys():
self.__dict__[name][self._index] = inst self.__dict__[name][self._index] = inst
def update(self, buffer): def update(self, buffer):
@ -144,7 +175,8 @@ class ReplayBuffer(object):
if i == begin: if i == begin:
break break
def add(self, obs, act, rew, done, obs_next=None, info={}, weight=None): def add(self, obs, act, rew, done, obs_next=None, info={}, policy={},
**kwargs):
"""Add a batch of data into replay buffer.""" """Add a batch of data into replay buffer."""
assert isinstance(info, dict), \ assert isinstance(info, dict), \
'You should return a dict in the last argument of env.step().' 'You should return a dict in the last argument of env.step().'
@ -155,6 +187,7 @@ class ReplayBuffer(object):
if self._save_s_: if self._save_s_:
self._add_to_buffer('obs_next', obs_next) self._add_to_buffer('obs_next', obs_next)
self._add_to_buffer('info', info) self._add_to_buffer('info', info)
self._add_to_buffer('policy', policy)
if self._maxsize > 0: if self._maxsize > 0:
self._size = min(self._size + 1, self._maxsize) self._size = min(self._size + 1, self._maxsize)
self._index = (self._index + 1) % self._maxsize self._index = (self._index + 1) % self._maxsize
@ -180,11 +213,13 @@ class ReplayBuffer(object):
]) ])
return self[indice], indice return self[indice], indice
def get(self, indice, key): def get(self, indice, key, stack_num=None):
"""Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t], """Return the stacked result, e.g. [s_{t-3}, s_{t-2}, s_{t-1}, s_t],
where s is self.key, t is indice. The stack_num (here equals to 4) is where s is self.key, t is indice. The stack_num (here equals to 4) is
given from buffer initialization procedure. given from buffer initialization procedure.
""" """
if stack_num is None:
stack_num = self._stack
if not isinstance(indice, np.ndarray): if not isinstance(indice, np.ndarray):
if np.isscalar(indice): if np.isscalar(indice):
indice = np.array(indice) indice = np.array(indice)
@ -200,18 +235,37 @@ class ReplayBuffer(object):
indice += 1 - self.done[indice].astype(np.int) indice += 1 - self.done[indice].astype(np.int)
indice[indice == self._size] = 0 indice[indice == self._size] = 0
key = 'obs' key = 'obs'
if self._stack == 0: if stack_num == 0:
self.done[last_index] = last_done self.done[last_index] = last_done
if key in self._meta:
return {k.split('@')[-1]: self.__dict__[k][indice]
for k in self._meta[key]}
else:
return self.__dict__[key][indice] return self.__dict__[key][indice]
if key in self._meta:
many_keys = self._meta[key]
stack = {k.split('@')[-1]: [] for k in self._meta[key]}
else:
stack = [] stack = []
for i in range(self._stack): many_keys = None
for i in range(stack_num):
if many_keys is not None:
for k_ in many_keys:
k = k_.split('@')[-1]
stack[k] = [self.__dict__[k_][indice]] + stack[k]
else:
stack = [self.__dict__[key][indice]] + stack stack = [self.__dict__[key][indice]] + stack
pre_indice = indice - 1 pre_indice = indice - 1
pre_indice[pre_indice == -1] = self._size - 1 pre_indice[pre_indice == -1] = self._size - 1
indice = pre_indice + self.done[pre_indice].astype(np.int) indice = pre_indice + self.done[pre_indice].astype(np.int)
indice[indice == self._size] = 0 indice[indice == self._size] = 0
self.done[last_index] = last_done self.done[last_index] = last_done
return np.stack(stack, axis=1) if many_keys is not None:
for k in stack:
stack[k] = np.stack(stack[k], axis=1)
else:
stack = np.stack(stack, axis=1)
return stack
def __getitem__(self, index): def __getitem__(self, index):
"""Return a data batch: self[index]. If stack_num is set to be > 0, """Return a data batch: self[index]. If stack_num is set to be > 0,
@ -223,7 +277,8 @@ class ReplayBuffer(object):
rew=self.rew[index], rew=self.rew[index],
done=self.done[index], done=self.done[index],
obs_next=self.get(index, 'obs_next'), obs_next=self.get(index, 'obs_next'),
info=self.info[index] info=self.info[index],
policy=self.get(index, 'policy'),
) )
@ -234,7 +289,7 @@ class ListReplayBuffer(ReplayBuffer):
.. seealso:: .. seealso::
Please refer to :class:`~tianshou.data.ListReplayBuffer` for more Please refer to :class:`~tianshou.data.ReplayBuffer` for more
detailed explanation. detailed explanation.
""" """
@ -256,7 +311,13 @@ class ListReplayBuffer(ReplayBuffer):
class PrioritizedReplayBuffer(ReplayBuffer): class PrioritizedReplayBuffer(ReplayBuffer):
"""docstring for PrioritizedReplayBuffer""" """Prioritized replay buffer implementation.
.. seealso::
Please refer to :class:`~tianshou.data.ReplayBuffer` for more
detailed explanation.
"""
def __init__(self, size, alpha: float, beta: float, def __init__(self, size, alpha: float, beta: float,
mode: str = 'weight', **kwargs): mode: str = 'weight', **kwargs):
@ -270,13 +331,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self._amortization_freq = 50 self._amortization_freq = 50
self._amortization_counter = 0 self._amortization_counter = 0
def add(self, obs, act, rew, done, obs_next=0, info={}, weight=1.0): def add(self, obs, act, rew, done, obs_next=0, info={}, policy={},
weight=1.0):
"""Add a batch of data into replay buffer.""" """Add a batch of data into replay buffer."""
self._weight_sum += np.abs(weight)**self._alpha - \ self._weight_sum += np.abs(weight)**self._alpha - \
self.weight[self._index] self.weight[self._index]
# we have to sacrifice some convenience for speed :( # we have to sacrifice some convenience for speed :(
self._add_to_buffer('weight', np.abs(weight) ** self._alpha) self._add_to_buffer('weight', np.abs(weight) ** self._alpha)
super().add(obs, act, rew, done, obs_next, info) super().add(obs, act, rew, done, obs_next, info, policy)
self._check_weight_sum() self._check_weight_sum()
def sample(self, batch_size: int = 0, importance_sample: bool = True): def sample(self, batch_size: int = 0, importance_sample: bool = True):
@ -290,7 +352,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
# will cause weight update conflict # will cause weight update conflict
indice = np.random.choice( indice = np.random.choice(
self._size, batch_size, self._size, batch_size,
p=(self.weight/self.weight.sum())[:self._size], replace=False) p=(self.weight / self.weight.sum())[:self._size],
replace=False)
# self._weight_sum is not work for the accuracy issue # self._weight_sum is not work for the accuracy issue
# p=(self.weight/self._weight_sum)[:self._size], replace=False) # p=(self.weight/self._weight_sum)[:self._size], replace=False)
elif batch_size == 0: elif batch_size == 0:
@ -306,7 +369,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
if importance_sample: if importance_sample:
impt_weight = Batch( impt_weight = Batch(
impt_weight=1 / np.power( impt_weight=1 / np.power(
self._size*(batch.weight/self._weight_sum), self._beta)) self._size * (batch.weight / self._weight_sum),
self._beta))
batch.append(impt_weight) batch.append(impt_weight)
self._check_weight_sum() self._check_weight_sum()
return batch, indice return batch, indice
@ -316,7 +380,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
super().reset() super().reset()
def update_weight(self, indice, new_weight: np.ndarray): def update_weight(self, indice, new_weight: np.ndarray):
"""update priority weight by indice in this buffer """Update priority weight by indice in this buffer.
:param indice: indice you want to update weight :param indice: indice you want to update weight
:param new_weight: new priority weight you wangt to update :param new_weight: new priority weight you wangt to update
@ -333,7 +397,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
done=self.done[index], done=self.done[index],
obs_next=self.get(index, 'obs_next'), obs_next=self.get(index, 'obs_next'),
info=self.info[index], info=self.info[index],
weight=self.weight[index] weight=self.weight[index],
policy=self.get(index, 'policy'),
) )
def _check_weight_sum(self): def _check_weight_sum(self):

View File

@ -54,7 +54,7 @@ class A2CPolicy(PGPolicy):
batch, None, gamma=self._gamma, gae_lambda=self._lambda) batch, None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = [] v_ = []
with torch.no_grad(): with torch.no_grad():
for b in batch.split(self._batch, permute=False): for b in batch.split(self._batch, shuffle=False):
v_.append(self.critic(b.obs_next).detach().cpu().numpy()) v_.append(self.critic(b.obs_next).detach().cpu().numpy())
v_ = np.concatenate(v_, axis=0) v_ = np.concatenate(v_, axis=0)
return self.compute_episodic_return( return self.compute_episodic_return(

View File

@ -74,7 +74,7 @@ class PPOPolicy(PGPolicy):
batch, None, gamma=self._gamma, gae_lambda=self._lambda) batch, None, gamma=self._gamma, gae_lambda=self._lambda)
v_ = [] v_ = []
with torch.no_grad(): with torch.no_grad():
for b in batch.split(self._batch, permute=False): for b in batch.split(self._batch, shuffle=False):
v_.append(self.critic(b.obs_next)) v_.append(self.critic(b.obs_next))
v_ = torch.cat(v_, dim=0).cpu().numpy() v_ = torch.cat(v_, dim=0).cpu().numpy()
return self.compute_episodic_return( return self.compute_episodic_return(
@ -111,7 +111,7 @@ class PPOPolicy(PGPolicy):
v = [] v = []
old_log_prob = [] old_log_prob = []
with torch.no_grad(): with torch.no_grad():
for b in batch.split(batch_size, permute=False): for b in batch.split(batch_size, shuffle=False):
v.append(self.critic(b.obs)) v.append(self.critic(b.obs))
old_log_prob.append(self(b).dist.log_prob( old_log_prob.append(self(b).dist.log_prob(
torch.tensor(b.act, device=v[0].device))) torch.tensor(b.act, device=v[0].device)))