fix a bug of storing batch over batch data into buffer
This commit is contained in:
parent
b32b96cd3e
commit
f1951780ab
@ -3,7 +3,7 @@ Cheat Sheet
|
||||
|
||||
This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
|
||||
|
||||
By the way, some of these issues can be resolved by using a ``gym.wrapper``. It can be a solution in the policy-environment interaction.
|
||||
By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction.
|
||||
|
||||
.. _network_api:
|
||||
|
||||
@ -48,6 +48,54 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
|
||||
venv = SubprocVectorEnv(env_fns)
|
||||
|
||||
.. _preprocess_fn:
|
||||
|
||||
Handle Batched Data Stream in Collector
|
||||
---------------------------------------
|
||||
|
||||
This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
|
||||
|
||||
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer.
|
||||
|
||||
This function receives typically 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a dict or a Batch. For example, you can write your hook as:
|
||||
::
|
||||
|
||||
import numpy as np
|
||||
from collections import deque
|
||||
class MyProcessor:
|
||||
def __init__(self, size=100):
|
||||
self.episode_log = None
|
||||
self.main_log = deque(maxlen=size)
|
||||
self.main_log.append(0)
|
||||
self.baseline = 0
|
||||
def preprocess_fn(**kwargs):
|
||||
"""change reward to zero mean"""
|
||||
if 'rew' not in kwargs:
|
||||
# means that it is called after env.reset(), it can only process the obs
|
||||
return {} # none of the variables are needed to be updated
|
||||
else:
|
||||
n = len(kwargs['rew']) # the number of envs in collector
|
||||
if self.episode_log is None:
|
||||
self.episode_log = [[] for i in range(n)]
|
||||
for i in range(n):
|
||||
self.episode_log[i].append(kwargs['rew'][i])
|
||||
kwargs['rew'][i] -= self.baseline
|
||||
for i in range(n):
|
||||
if kwargs['done']:
|
||||
self.main_log.append(np.mean(self.episode_log[i]))
|
||||
self.episode_log[i] = []
|
||||
self.baseline = np.mean(self.main_log)
|
||||
return Batch(rew=kwargs['rew'])
|
||||
# you can also return with {'rew': kwargs['rew']}
|
||||
|
||||
And finally,
|
||||
::
|
||||
|
||||
test_processor = MyProcessor(size=100)
|
||||
collector = Collector(policy, env, buffer, test_processor.preprocess_fn)
|
||||
|
||||
Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.
|
||||
|
||||
.. _rnn_training:
|
||||
|
||||
RNN-style Training
|
||||
@ -86,6 +134,10 @@ First of all, your self-defined environment must follow the Gym's API, some of t
|
||||
|
||||
- close() -> None
|
||||
|
||||
- observation_space
|
||||
|
||||
- action_space
|
||||
|
||||
The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example:
|
||||
::
|
||||
|
||||
|
@ -53,12 +53,7 @@ class Logger(object):
|
||||
def test_collector():
|
||||
writer = SummaryWriter('log/collector')
|
||||
logger = Logger(writer)
|
||||
env_fns = [
|
||||
lambda: MyTestEnv(size=2, sleep=0),
|
||||
lambda: MyTestEnv(size=3, sleep=0),
|
||||
lambda: MyTestEnv(size=4, sleep=0),
|
||||
lambda: MyTestEnv(size=5, sleep=0),
|
||||
]
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
|
||||
|
||||
venv = SubprocVectorEnv(env_fns)
|
||||
policy = MyPolicy()
|
||||
@ -100,12 +95,8 @@ def test_collector_with_dict_state():
|
||||
c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
|
||||
c0.collect(n_step=3)
|
||||
c0.collect(n_episode=3)
|
||||
env_fns = [
|
||||
lambda: MyTestEnv(size=2, sleep=0, dict_state=True),
|
||||
lambda: MyTestEnv(size=3, sleep=0, dict_state=True),
|
||||
lambda: MyTestEnv(size=4, sleep=0, dict_state=True),
|
||||
lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
|
||||
]
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
|
||||
for i in [2, 3, 4, 5]]
|
||||
envs = VectorEnv(env_fns)
|
||||
c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
|
||||
c1.collect(n_step=10)
|
||||
|
@ -133,7 +133,10 @@ class ReplayBuffer(object):
|
||||
d = {}
|
||||
for k_ in self._meta[key]:
|
||||
k__ = '_' + key + '@' + k_
|
||||
d[k_] = self.__dict__[k__]
|
||||
if k__ in self.__dict__:
|
||||
d[k_] = self.__dict__[k__]
|
||||
else:
|
||||
d[k_] = self.__getattr__(k__)
|
||||
return Batch(**d)
|
||||
|
||||
def _add_to_buffer(self, name: str, inst: Any) -> None:
|
||||
|
@ -179,7 +179,7 @@ class Collector(object):
|
||||
if hasattr(self.env, 'close'):
|
||||
self.env.close()
|
||||
|
||||
def _make_batch(self, data: Any) -> Union[Any, np.ndarray]:
|
||||
def _make_batch(self, data: Any) -> np.ndarray:
|
||||
"""Return [data]."""
|
||||
if isinstance(data, np.ndarray):
|
||||
return data[None]
|
||||
|
@ -51,6 +51,8 @@ class BasePolicy(ABC, nn.Module):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__()
|
||||
self.observation_space = kwargs.get('observation_space')
|
||||
self.action_space = kwargs.get('action_space')
|
||||
|
||||
def process_fn(self, batch: Batch, buffer: ReplayBuffer,
|
||||
indice: np.ndarray) -> Batch:
|
||||
|
Loading…
x
Reference in New Issue
Block a user