Fix a bug when storing batch-over-batch data into the buffer
parent b32b96cd3e
commit f1951780ab
@@ -3,7 +3,7 @@ Cheat Sheet

 This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.

-By the way, some of these issues can be resolved by using a ``gym.wrapper``. It can be a solution in the policy-environment interaction.
+By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction.

 .. _network_api:

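For readers who have not used wrappers before, intercepting the policy-environment interaction with ``gym.Wrapper`` (mentioned above) looks roughly like the sketch below; ``ScaleReward`` and its ``scale`` parameter are made up for illustration and are not part of this commit.

::

    import gym

    class ScaleReward(gym.Wrapper):
        """Example wrapper: rescale the reward before the agent ever sees it."""

        def __init__(self, env, scale=0.1):
            super().__init__(env)
            self.scale = scale

        def step(self, action):
            obs, rew, done, info = self.env.step(action)
            # any per-step pre/post-processing of the interaction goes here
            return obs, rew * self.scale, done, info

    # env = ScaleReward(gym.make('CartPole-v0'), scale=0.1)
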
@@ -48,6 +48,54 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte
     env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
     venv = SubprocVectorEnv(env_fns)
+
+.. _preprocess_fn:
+
+Handle Batched Data Stream in Collector
+---------------------------------------
+
+This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
+
+If you want to gather log statistics from the data stream / pre-process batched images / modify the reward with the given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook that will be called before the data is added into the buffer.
+
+This function typically receives 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a dict or a Batch. For example, you can write your hook as:
+
+::
+
+    import numpy as np
+    from collections import deque
+    from tianshou.data import Batch
+
+    class MyProcessor:
+        def __init__(self, size=100):
+            self.episode_log = None
+            self.main_log = deque(maxlen=size)
+            self.main_log.append(0)
+            self.baseline = 0
+
+        def preprocess_fn(self, **kwargs):
+            """change reward to zero mean"""
+            if 'rew' not in kwargs:
+                # called after env.reset(), so only the obs is available
+                return {}  # none of the variables need to be updated
+            else:
+                n = len(kwargs['rew'])  # the number of envs in the collector
+                if self.episode_log is None:
+                    self.episode_log = [[] for i in range(n)]
+                for i in range(n):
+                    self.episode_log[i].append(kwargs['rew'][i])
+                    kwargs['rew'][i] -= self.baseline
+                for i in range(n):
+                    if kwargs['done'][i]:
+                        self.main_log.append(np.mean(self.episode_log[i]))
+                        self.episode_log[i] = []
+                        self.baseline = np.mean(self.main_log)
+                return Batch(rew=kwargs['rew'])
+                # you can also return with {'rew': kwargs['rew']}
+
+And finally,
+
+::
+
+    test_processor = MyProcessor(size=100)
+    collector = Collector(policy, env, buffer, test_processor.preprocess_fn)
+
+Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.
+
 .. _rnn_training:

 RNN-style Training

@@ -86,6 +134,10 @@ First of all, your self-defined environment must follow the Gym's API, some of t

 - close() -> None

+- observation_space
+
+- action_space
+
 The state can be a ``numpy.ndarray`` or a Python dictionary. Take ``FetchReach-v1`` as an example:
 ::

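As a rough illustration of such a dict state (not taken from this commit), the goal-based robotics environments return observations keyed by ``observation``, ``achieved_goal`` and ``desired_goal``:

::

    import gym

    # FetchReach-v1 needs the robotics extras of gym;
    # any dict-observation environment behaves the same way
    env = gym.make('FetchReach-v1')
    obs = env.reset()
    print(obs.keys())
    # -> the keys are 'observation', 'achieved_goal', 'desired_goal'
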
@@ -53,12 +53,7 @@ class Logger(object):
 def test_collector():
     writer = SummaryWriter('log/collector')
     logger = Logger(writer)
-    env_fns = [
-        lambda: MyTestEnv(size=2, sleep=0),
-        lambda: MyTestEnv(size=3, sleep=0),
-        lambda: MyTestEnv(size=4, sleep=0),
-        lambda: MyTestEnv(size=5, sleep=0),
-    ]
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]

     venv = SubprocVectorEnv(env_fns)
     policy = MyPolicy()

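The one-liner above relies on binding the loop variable through a default argument (``x=i``); a bare ``lambda: MyTestEnv(size=i)`` inside a comprehension would late-bind ``i`` and every factory would build the same environment. A minimal standalone sketch of the difference:

::

    # late binding: every lambda sees the final value of i
    broken = [lambda: print(i) for i in [2, 3, 4, 5]]
    # default-argument binding: each lambda keeps its own value
    correct = [lambda x=i: print(x) for i in [2, 3, 4, 5]]

    broken[0]()   # prints 5
    correct[0]()  # prints 2
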
@@ -100,12 +95,8 @@ def test_collector_with_dict_state():
     c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
     c0.collect(n_step=3)
     c0.collect(n_episode=3)
-    env_fns = [
-        lambda: MyTestEnv(size=2, sleep=0, dict_state=True),
-        lambda: MyTestEnv(size=3, sleep=0, dict_state=True),
-        lambda: MyTestEnv(size=4, sleep=0, dict_state=True),
-        lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
-    ]
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
+               for i in [2, 3, 4, 5]]
     envs = VectorEnv(env_fns)
     c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
     c1.collect(n_step=10)

@@ -133,7 +133,10 @@ class ReplayBuffer(object):
         d = {}
         for k_ in self._meta[key]:
             k__ = '_' + key + '@' + k_
-            d[k_] = self.__dict__[k__]
+            if k__ in self.__dict__:
+                d[k_] = self.__dict__[k__]
+            else:
+                d[k_] = self.__getattr__(k__)
         return Batch(**d)

     def _add_to_buffer(self, name: str, inst: Any) -> None:

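The fallback added above is what the commit title refers to: sub-keys of a dict or nested ``Batch`` observation are stored as flat ``'_obs@<subkey>'`` attributes, and a sub-key that has no flat array of its own (because it is itself nested, or has not been written yet) is now resolved through ``__getattr__`` again instead of a failing ``self.__dict__`` lookup. A rough usage sketch, assuming the ``ReplayBuffer.add`` keyword signature of this Tianshou version; whether this exact call exercises the new branch depends on the buffer internals, it only shows where ``__getattr__`` comes into play:

::

    import numpy as np
    from tianshou.data import ReplayBuffer

    buf = ReplayBuffer(size=10)
    obs = {'observation': np.zeros(3), 'achieved_goal': np.zeros(3)}
    # storing dict observations creates the flat '_obs@...' attributes internally
    buf.add(obs=obs, act=0, rew=0.0, done=False, obs_next=obs)
    print(buf.obs)  # reading them back goes through the patched __getattr__
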
@@ -179,7 +179,7 @@ class Collector(object):
         if hasattr(self.env, 'close'):
             self.env.close()

-    def _make_batch(self, data: Any) -> Union[Any, np.ndarray]:
+    def _make_batch(self, data: Any) -> np.ndarray:
         """Return [data]."""
         if isinstance(data, np.ndarray):
             return data[None]

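``_make_batch`` only prepends a batch axis so that a single transition from one environment looks like a batch of size one. A standalone sketch of the same idea (not the actual ``Collector`` method, whose non-array branch is not shown in this hunk):

::

    import numpy as np

    def make_batch(data):
        """Return [data]: add a leading batch dimension."""
        if isinstance(data, np.ndarray):
            return data[None]       # shape (4,) -> (1, 4)
        return np.array([data])     # wrap anything else into a length-1 array

    make_batch(np.zeros(4)).shape   # (1, 4)
    make_batch(1.0)                 # array([1.])
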
@@ -51,6 +51,8 @@ class BasePolicy(ABC, nn.Module):

     def __init__(self, **kwargs) -> None:
         super().__init__()
+        self.observation_space = kwargs.get('observation_space')
+        self.action_space = kwargs.get('action_space')

     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:

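With this addition a policy can carry the spaces of the environment it was built for. A minimal usage sketch; ``MyPolicy`` is a throwaway subclass written only to show the keyword plumbing and is not part of this commit:

::

    import gym
    from tianshou.policy import BasePolicy

    class MyPolicy(BasePolicy):
        """Stub policy: only here to make the abstract BasePolicy instantiable."""
        def forward(self, batch, state=None, **kwargs):
            pass

        def learn(self, batch, **kwargs):
            pass

    env = gym.make('CartPole-v0')
    policy = MyPolicy(observation_space=env.observation_space,
                      action_space=env.action_space)
    print(policy.action_space)  # Discrete(2)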