From c19876179a52acdc4d2e7dcb10dfbb4557adca86 Mon Sep 17 00:00:00 2001 From: n+e Date: Mon, 5 Jul 2021 09:50:39 +0800 Subject: [PATCH] add env_id in preprocess fn (#391) --- docs/tutorials/cheatsheet.rst | 6 +++--- test/base/test_collector.py | 4 ++-- test/multiagent/tic_tac_toe.py | 4 ++-- tianshou/data/collector.py | 21 +++++++++++++-------- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/docs/tutorials/cheatsheet.rst b/docs/tutorials/cheatsheet.rst index 7f8095e..8dac44a 100644 --- a/docs/tutorials/cheatsheet.rst +++ b/docs/tutorials/cheatsheet.rst @@ -129,7 +129,7 @@ This is related to `Issue 42 `_. If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer. -It will receive with only "obs" when the collector resets the environment, and will receive five keys "obs_next", "rew", "done", "info", "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. +It will receive with "obs" and "env_id" when the collector resets the environment, and will receive six keys "obs_next", "rew", "done", "info", "policy", "env_id" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values. These variables are intended to gather all the information requires to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation. @@ -149,8 +149,8 @@ For example, you can write your hook as: def preprocess_fn(**kwargs): """change reward to zero mean""" - # if only obs exist -> reset - # if obs_next/act/rew/done/policy exist -> normal step + # if obs && env_id exist -> reset + # if obs_next/act/rew/done/policy/env_id exist -> normal step if 'rew' not in kwargs: # means that it is called after env.reset(), it can only process the obs return Batch() # none of the variables are needed to be updated diff --git a/test/base/test_collector.py b/test/base/test_collector.py index b275526..79d1430 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -50,8 +50,8 @@ class Logger: def preprocess_fn(self, **kwargs): # modify info before adding into the buffer, and recorded into tfb - # if only obs exist -> reset - # if obs_next/rew/done/info exist -> normal step + # if obs && env_id exist -> reset + # if obs_next/rew/done/info/env_id exist -> normal step if 'rew' in kwargs: info = kwargs['info'] info.rew = kwargs['rew'] diff --git a/test/multiagent/tic_tac_toe.py b/test/multiagent/tic_tac_toe.py index 3e92838..dc4a443 100644 --- a/test/multiagent/tic_tac_toe.py +++ b/test/multiagent/tic_tac_toe.py @@ -128,7 +128,7 @@ def train_agent( policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True) - test_collector = Collector(policy, test_envs) + test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) train_collector.collect(n_step=args.batch_size * args.training_num) # log @@ -180,7 +180,7 @@ def watch( args, agent_learn=agent_learn, agent_opponent=agent_opponent) policy.eval() policy.policies[args.agent_id - 1].set_eps(args.eps_test) - collector = Collector(policy, env) + collector = Collector(policy, env, exploration_noise=True) result = collector.collect(n_episode=1, render=args.render) rews, lens = result["rews"], result["lens"] print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}") diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index f9fede5..192213a 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -35,11 +35,11 @@ class Collector(object): exploration noise into action. Default to False. The "preprocess_fn" is a function called before the data has been added to the - buffer with batch format. It will receive with only "obs" when the collector resets - the environment, and will receive five keys "obs_next", "rew", "done", "info", and - "policy" in a normal env step. It returns either a dict or a - :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in - "test/base/test_collector.py". + buffer with batch format. It will receive only "obs" and "env_id" when the + collector resets the environment, and will receive six keys "obs_next", "rew", + "done", "info", "policy" and "env_id" in a normal env step. It returns either a + dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples + are in "test/base/test_collector.py". .. note:: @@ -115,7 +115,8 @@ class Collector(object): """Reset all of the environments.""" obs = self.env.reset() if self.preprocess_fn: - obs = self.preprocess_fn(obs=obs).get("obs", obs) + obs = self.preprocess_fn( + obs=obs, env_id=np.arange(self.env_num)).get("obs", obs) self.data.obs = obs def _reset_state(self, id: Union[int, List[int]]) -> None: @@ -235,6 +236,7 @@ class Collector(object): done=self.data.done, info=self.data.info, policy=self.data.policy, + env_id=ready_env_ids, )) if render: @@ -260,7 +262,8 @@ class Collector(object): # finished episodes, we have to reset finished envs first. obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: - obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset) + obs_reset = self.preprocess_fn( + obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i) @@ -442,6 +445,7 @@ class AsyncCollector(Collector): rew=self.data.rew, done=self.data.done, info=self.data.info, + env_id=ready_env_ids, )) if render: @@ -467,7 +471,8 @@ class AsyncCollector(Collector): # finished episodes, we have to reset finished envs first. obs_reset = self.env.reset(env_ind_global) if self.preprocess_fn: - obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset) + obs_reset = self.preprocess_fn( + obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset) self.data.obs_next[env_ind_local] = obs_reset for i in env_ind_local: self._reset_state(i)