From f1951780ab7e5ee37eb93a06d8992f0a19d1b5d4 Mon Sep 17 00:00:00 2001
From: Trinkle23897 <463003665@qq.com>
Date: Tue, 9 Jun 2020 18:46:14 +0800
Subject: [PATCH] fix a bug when storing batch-over-batch data into the buffer

---
 docs/tutorials/cheatsheet.rst | 54 ++++++++++++++++++++++++++++++++++-
 test/base/test_collector.py   | 15 ++--------
 tianshou/data/buffer.py       |  5 +++-
 tianshou/data/collector.py    |  2 +-
 tianshou/policy/base.py       |  2 ++
 5 files changed, 63 insertions(+), 15 deletions(-)

diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst
index 178fb32..878dd59 100644
--- a/docs/tutorials/cheatsheet.rst
+++ b/docs/tutorials/cheatsheet.rst
@@ -3,7 +3,7 @@ Cheat Sheet
 
 This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
 
-By the way, some of these issues can be resolved by using a ``gym.wrapper``. It can be a solution in the policy-environment interaction.
+By the way, some of these issues can be resolved by using a ``gym.wrapper``. It could be a universal solution in the policy-environment interaction.
 
 .. _network_api:
 
@@ -48,6 +48,54 @@ where ``env_fns`` is a list of callable env hooker. The above code can be writte
     env_fns = [lambda x=i: MyTestEnv(size=x) for i in [2, 3, 4, 5]]
     venv = SubprocVectorEnv(env_fns)
 
+.. _preprocess_fn:
+
+Handle Batched Data Stream in Collector
+---------------------------------------
+
+This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
+
+If you want to gather log statistics from the data stream, pre-process batched images, or modify the reward with the given environment info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook that will be called before the data is added into the buffer.
+
+This function typically receives 7 keys, as listed in :class:`~tianshou.data.Batch`, and returns the modified part within a dict or a :class:`~tianshou.data.Batch`. For example, you can write your hook as:
+::
+
+    import numpy as np
+    from collections import deque
+    from tianshou.data import Batch
+
+    class MyProcessor:
+        def __init__(self, size=100):
+            self.episode_log = None
+            self.main_log = deque(maxlen=size)
+            self.main_log.append(0)
+            self.baseline = 0
+
+        def preprocess_fn(self, **kwargs):
+            """Normalize the reward to zero mean."""
+            if 'rew' not in kwargs:
+                # called after env.reset(); only the observation is available
+                return {}  # no variables need to be updated
+            else:
+                n = len(kwargs['rew'])  # the number of envs in the collector
+                if self.episode_log is None:
+                    self.episode_log = [[] for _ in range(n)]
+                for i in range(n):
+                    self.episode_log[i].append(kwargs['rew'][i])
+                    kwargs['rew'][i] -= self.baseline
+                for i in range(n):
+                    if kwargs['done'][i]:
+                        self.main_log.append(np.mean(self.episode_log[i]))
+                        self.episode_log[i] = []
+                        self.baseline = np.mean(self.main_log)
+                return Batch(rew=kwargs['rew'])
+                # you can also return with {'rew': kwargs['rew']}
+
+And finally,
+::
+
+    test_processor = MyProcessor(size=100)
+    collector = Collector(policy, env, buffer, test_processor.preprocess_fn)
+
+Some examples are in `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_.
+
 .. _rnn_training:
 
 RNN-style Training
@@ -86,6 +134,10 @@ First of all, your self-defined environment must follow the Gym's API, some of t
 
 - close() -> None
 
+- observation_space
+
+- action_space
+
 The state can be a ``numpy.ndarray`` or a Python dictionary.
 
 Take ``FetchReach-v1`` as an example:
 ::

diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 95d8da1..183397a 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -53,12 +53,7 @@ class Logger(object):
 def test_collector():
     writer = SummaryWriter('log/collector')
     logger = Logger(writer)
-    env_fns = [
-        lambda: MyTestEnv(size=2, sleep=0),
-        lambda: MyTestEnv(size=3, sleep=0),
-        lambda: MyTestEnv(size=4, sleep=0),
-        lambda: MyTestEnv(size=5, sleep=0),
-    ]
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
     venv = SubprocVectorEnv(env_fns)
 
     policy = MyPolicy()
@@ -100,12 +95,8 @@ def test_collector_with_dict_state():
     c0 = Collector(policy, env, ReplayBuffer(size=100), preprocess_fn)
     c0.collect(n_step=3)
     c0.collect(n_episode=3)
-    env_fns = [
-        lambda: MyTestEnv(size=2, sleep=0, dict_state=True),
-        lambda: MyTestEnv(size=3, sleep=0, dict_state=True),
-        lambda: MyTestEnv(size=4, sleep=0, dict_state=True),
-        lambda: MyTestEnv(size=5, sleep=0, dict_state=True),
-    ]
+    env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True)
+               for i in [2, 3, 4, 5]]
     envs = VectorEnv(env_fns)
     c1 = Collector(policy, envs, ReplayBuffer(size=100), preprocess_fn)
     c1.collect(n_step=10)
diff --git a/tianshou/data/buffer.py b/tianshou/data/buffer.py
index 92e34a6..a87f53b 100644
--- a/tianshou/data/buffer.py
+++ b/tianshou/data/buffer.py
@@ -133,7 +133,10 @@ class ReplayBuffer(object):
         d = {}
         for k_ in self._meta[key]:
             k__ = '_' + key + '@' + k_
-            d[k_] = self.__dict__[k__]
+            if k__ in self.__dict__:
+                d[k_] = self.__dict__[k__]
+            else:
+                d[k_] = self.__getattr__(k__)
         return Batch(**d)
 
     def _add_to_buffer(self, name: str, inst: Any) -> None:
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 837f4fa..92d39e9 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -179,7 +179,7 @@ class Collector(object):
         if hasattr(self.env, 'close'):
             self.env.close()
 
-    def _make_batch(self, data: Any) -> Union[Any, np.ndarray]:
+    def _make_batch(self, data: Any) -> np.ndarray:
         """Return [data]."""
         if isinstance(data, np.ndarray):
             return data[None]
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index e7fa474..f60a472 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -51,6 +51,8 @@ class BasePolicy(ABC, nn.Module):
 
     def __init__(self, **kwargs) -> None:
         super().__init__()
+        self.observation_space = kwargs.get('observation_space')
+        self.action_space = kwargs.get('action_space')
 
     def process_fn(self, batch: Batch, buffer: ReplayBuffer,
                    indice: np.ndarray) -> Batch:
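
A minimal usage sketch, assuming the patched Tianshou API shown above: it wires a ``preprocess_fn`` hook into a :class:`~tianshou.data.Collector` and passes ``observation_space`` / ``action_space`` through the new ``BasePolicy`` keyword arguments. ``RandomPolicy``, the reward-clipping hook, and ``CartPole-v0`` are illustrative assumptions, not code from this patch.
::

    import gym
    import numpy as np

    from tianshou.data import Batch, Collector, ReplayBuffer
    from tianshou.policy import BasePolicy


    class RandomPolicy(BasePolicy):
        """Hypothetical policy that samples random actions.

        ``self.action_space`` is available because ``BasePolicy.__init__``
        now stores the ``observation_space`` / ``action_space`` kwargs
        (see the tianshou/policy/base.py hunk above).
        """

        def forward(self, batch, state=None, **kwargs):
            # one random action per environment in the batch
            act = np.array([self.action_space.sample() for _ in batch.obs])
            return Batch(act=act)

        def learn(self, batch, **kwargs):
            return {}


    def preprocess_fn(**kwargs):
        """Clip rewards to [-1, 1] before they are stored in the buffer."""
        if 'rew' not in kwargs:
            # called right after env.reset(); nothing to modify
            return {}
        return Batch(rew=np.clip(kwargs['rew'], -1, 1))


    env = gym.make('CartPole-v0')
    policy = RandomPolicy(observation_space=env.observation_space,
                          action_space=env.action_space)
    collector = Collector(policy, env, ReplayBuffer(size=1000), preprocess_fn)
    collector.collect(n_step=10)

Returning ``Batch(rew=...)`` (or ``{'rew': ...}``) updates only the reward before it reaches the buffer; all other keys are stored unchanged.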