fix a bug of storing batch over batch data into buffer

This commit is contained in:
Trinkle23897 2020-06-09 18:46:14 +08:00
parent b32b96cd3e
commit f1951780ab
5 changed files with 63 additions and 15 deletions

View File

@ -3,7 +3,7 @@ Cheat Sheet
This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
By the way, some of these issues can be resolved by using a ``gym.wrapper``. It can be a solution in the policy-environment interaction.
By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction.
.. _network_api:
@ -48,6 +48,54 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte
env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
venv = SubprocVectorEnv(env_fns)
.. _preprocess_fn:
Handle Batched Data Stream in Collector
---------------------------------------
This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer.
This function receives typically 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a dict or a Batch. For example, you can write your hook as:
::
import numpy as np
from collections import deque
class MyProcessor:
def __init__(self, size=100):
self.episode_log = None
self.main_log = deque(maxlen=size)
self.main_log.append(0)
self.baseline = 0
def preprocess_fn(**kwargs):
"""change reward to zero mean"""
if 'rew' not in kwargs:
# means that it is called after env.reset(), it can only process the obs
return {} # none of the variables are needed to be updated
else:
n = len(kwargs['rew']) # the number of envs in collector
if self.episode_log is None:
self.episode_log = [[] for i in range(n)]
for i in range(n):
self.episode_log[i].append(kwargs['rew'][i])
kwargs['rew'][i] -= self.baseline
for i in range(n):
if kwargs['done']:
self.main_log.append(np.mean(self.episode_log[i]))
self.episode_log[i] = []
self.baseline = np.mean(self.main_log)
return Batch(rew=kwargs['rew'])
# you can also return with {'rew': kwargs['rew']}
And finally,
::
test_processor = MyProcessor(size=100)
collector = Collector(policy, env, buffer, test_processor.preprocess_fn)
Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.
.. _rnn_training:
RNN-style Training
@ -86,6 +134,10 @@ First of all, your self-defined environment must follow the Gym's API, some of t
- close() -> None
- observation_space
- action_space
The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example:
::

View File

@ -53,12 +53,7 @@ class Logger(object):
def test_collector():
writer = SummaryWriter('log/collector')
logger = Logger(writer)
env_fns = [
lambda: MyTestEnv(size=2, sleep=0),
lambda: MyTestEnv(size=3, sleep=0),
lambda: MyTestEnv(size=4, sleep=0),
lambda: MyTestEnv(size=5, sleep=0),
]
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
venv = SubprocVectorEnv(env_fns)
policy = MyPolicy()
@ -100,12 +95,8 @@ def test_collector_with_dict_state():
c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
c0.collect(n_step=3)
c0.collect(n_episode=3)
env_fns = [
lambda: MyTestEnv(size=2, sleep=0, dict_state=True),
lambda: MyTestEnv(size=3, sleep=0, dict_state=True),
lambda: MyTestEnv(size=4, sleep=0, dict_state=True),
lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
]
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
for i in [2, 3, 4, 5]]
envs = VectorEnv(env_fns)
c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
c1.collect(n_step=10)

View File

@ -133,7 +133,10 @@ class ReplayBuffer(object):
d = {}
for k_ in self._meta[key]:
k__ = '_' + key + '@' + k_
d[k_] = self.__dict__[k__]
if k__ in self.__dict__:
d[k_] = self.__dict__[k__]
else:
d[k_] = self.__getattr__(k__)
return Batch(**d)
def _add_to_buffer(self, name: str, inst: Any) -> None:

View File

@ -179,7 +179,7 @@ class Collector(object):
if hasattr(self.env, 'close'):
self.env.close()
def _make_batch(self, data: Any) -> Union[Any, np.ndarray]:
def _make_batch(self, data: Any) -> np.ndarray:
"""Return [data]."""
if isinstance(data, np.ndarray):
return data[None]

View File

@ -51,6 +51,8 @@ class BasePolicy(ABC, nn.Module):
def __init__(self, **kwargs) -> None:
super().__init__()
self.observation_space = kwargs.get('observation_space')
self.action_space = kwargs.get('action_space')
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
indice: np.ndarray) -> Batch: