add env_id in preprocess fn (#391)

Author: n+e (committed by GitHub)
Date: 2021-07-05 09:50:39 +08:00
Parent: ebaca6f8da
Commit: c19876179a
4 changed files with 20 additions and 15 deletions


@@ -129,7 +129,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
 If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer.
-It will receive with only "obs" when the collector resets the environment, and will receive five keys "obs_next", "rew", "done", "info", "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
+It will receive with "obs" and "env_id" when the collector resets the environment, and will receive six keys "obs_next", "rew", "done", "info", "policy", "env_id" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
 These variables are intended to gather all the information requires to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
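
For orientation, a minimal hook that respects this new contract might look like the sketch below; everything beyond the documented keys (the zero-mean reward tweak, the function name) is illustrative, not part of this commit:

    import numpy as np
    from tianshou.data import Batch

    def my_preprocess_fn(**kwargs):
        # reset call: only "obs" and "env_id" are present
        if 'rew' not in kwargs:
            return Batch()  # nothing to modify
        # step call: "obs_next", "rew", "done", "info", "policy", "env_id"
        rew = np.asarray(kwargs['rew'])
        return Batch(rew=rew - rew.mean())  # center this batch's rewards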
@@ -149,8 +149,8 @@ For example, you can write your hook as:
     def preprocess_fn(**kwargs):
         """change reward to zero mean"""
-        # if only obs exist -> reset
-        # if obs_next/act/rew/done/policy exist -> normal step
+        # if obs && env_id exist -> reset
+        # if obs_next/act/rew/done/policy/env_id exist -> normal step
         if 'rew' not in kwargs:
             # means that it is called after env.reset(), it can only process the obs
             return Batch()  # none of the variables are needed to be updated
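
The hunk ends before the normal-step branch of this example. A plausible stateful completion consistent with the "zero mean" docstring (the running-mean bookkeeping is a sketch, not the repository's code):

    import numpy as np
    from tianshou.data import Batch

    class RewardCenterer:
        """Keep a running mean and subtract it from each reward batch."""
        def __init__(self):
            self.count, self.mean = 0, 0.0

        def preprocess_fn(self, **kwargs):
            if 'rew' not in kwargs:  # reset: only "obs" and "env_id"
                return Batch()
            rew = np.asarray(kwargs['rew'], dtype=np.float64)
            self.count += rew.size
            self.mean += (rew.sum() - rew.size * self.mean) / self.count
            return Batch(rew=rew - self.mean)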


@@ -50,8 +50,8 @@ class Logger:
     def preprocess_fn(self, **kwargs):
         # modify info before adding into the buffer, and recorded into tfb
-        # if only obs exist -> reset
-        # if obs_next/rew/done/info exist -> normal step
+        # if obs && env_id exist -> reset
+        # if obs_next/rew/done/info/env_id exist -> normal step
         if 'rew' in kwargs:
             info = kwargs['info']
             info.rew = kwargs['rew']
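
This Logger hunk is also truncated after the in-place edits to info. Presumably (hedged; the remainder is not shown in this view) the step branch finishes by handing the modified info back so the collector stores it:

    # continuation sketch, not part of the hunk above
    return Batch(info=info)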


@@ -128,7 +128,7 @@ def train_agent(
         policy, train_envs,
         VectorReplayBuffer(args.buffer_size, len(train_envs)),
         exploration_noise=True)
-    test_collector = Collector(policy, test_envs)
+    test_collector = Collector(policy, test_envs, exploration_noise=True)
     # policy.set_eps(1)
     train_collector.collect(n_step=args.batch_size * args.training_num)
     # log
@@ -180,7 +180,7 @@ def watch(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent)
     policy.eval()
     policy.policies[args.agent_id - 1].set_eps(args.eps_test)
-    collector = Collector(policy, env)
+    collector = Collector(policy, env, exploration_noise=True)
     result = collector.collect(n_episode=1, render=args.render)
     rews, lens = result["rews"], result["lens"]
     print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")


@@ -35,11 +35,11 @@ class Collector(object):
         exploration noise into action. Default to False.
     The "preprocess_fn" is a function called before the data has been added to the
-    buffer with batch format. It will receive with only "obs" when the collector resets
-    the environment, and will receive five keys "obs_next", "rew", "done", "info", and
-    "policy" in a normal env step. It returns either a dict or a
-    :class:`~tianshou.data.Batch` with the modified keys and values. Examples are in
-    "test/base/test_collector.py".
+    buffer with batch format. It will receive only "obs" and "env_id" when the
+    collector resets the environment, and will receive six keys "obs_next", "rew",
+    "done", "info", "policy" and "env_id" in a normal env step. It returns either a
+    dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples
+    are in "test/base/test_collector.py".

     .. note::
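
A sketch of wiring a hook into a collector under the new contract; the environment choice and sizes are placeholders, and `policy` construction is elided:

    import gym
    from tianshou.data import Batch, Collector, VectorReplayBuffer
    from tianshou.env import DummyVectorEnv

    def preprocess_fn(**kwargs):
        print("called for envs:", kwargs["env_id"])  # present in every call now
        return Batch()  # leave all data unchanged

    envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(4)])
    buf = VectorReplayBuffer(2000, len(envs))
    collector = Collector(policy, envs, buf, preprocess_fn=preprocess_fn)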
@@ -115,7 +115,8 @@ class Collector(object):
         """Reset all of the environments."""
         obs = self.env.reset()
         if self.preprocess_fn:
-            obs = self.preprocess_fn(obs=obs).get("obs", obs)
+            obs = self.preprocess_fn(
+                obs=obs, env_id=np.arange(self.env_num)).get("obs", obs)
         self.data.obs = obs

     def _reset_state(self, id: Union[int, List[int]]) -> None:
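
Note the value passed here: a full reset reports every worker index, while the partial resets in the hunks below pass only the finished subset (`env_ind_global`). For a 4-env collector, illustratively:

    import numpy as np
    np.arange(4)  # full reset -> env_id == array([0, 1, 2, 3])
    # episodes ended in workers 1 and 3 -> env_id == array([1, 3])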
@@ -235,6 +236,7 @@ class Collector(object):
                 done=self.data.done,
                 info=self.data.info,
                 policy=self.data.policy,
+                env_id=ready_env_ids,
             ))

         if render:
@@ -260,7 +262,8 @@ class Collector(object):
                 # finished episodes, we have to reset finished envs first.
                 obs_reset = self.env.reset(env_ind_global)
                 if self.preprocess_fn:
-                    obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
+                    obs_reset = self.preprocess_fn(
+                        obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
                 self.data.obs_next[env_ind_local] = obs_reset
                 for i in env_ind_local:
                     self._reset_state(i)
@@ -442,6 +445,7 @@ class AsyncCollector(Collector):
                 rew=self.data.rew,
                 done=self.data.done,
                 info=self.data.info,
+                env_id=ready_env_ids,
             ))

         if render:
@@ -467,7 +471,8 @@ class AsyncCollector(Collector):
                 # finished episodes, we have to reset finished envs first.
                 obs_reset = self.env.reset(env_ind_global)
                 if self.preprocess_fn:
-                    obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
+                    obs_reset = self.preprocess_fn(
+                        obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
                 self.data.obs_next[env_ind_local] = obs_reset
                 for i in env_ind_local:
                     self._reset_state(i)
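
In the spirit of the tests referenced above ("test/base/test_collector.py"), a hook that asserts the new calling convention end to end; the function name is mine:

    import numpy as np
    from tianshou.data import Batch

    def check_contract_fn(**kwargs):
        # reset calls carry exactly obs + env_id; normal steps carry six keys
        assert set(kwargs) in (
            {"obs", "env_id"},
            {"obs_next", "rew", "done", "info", "policy", "env_id"},
        )
        assert isinstance(kwargs["env_id"], np.ndarray)
        return Batch()  # pass data through unchanged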