add env_id in preprocess fn (#391)

This commit is contained in:
n+e 2021-07-05 09:50:39 +08:00 committed by GitHub
parent ebaca6f8da
commit c19876179a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 15 deletions

View File

@ -129,7 +129,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
If you want to get log statistics from the data stream / pre-process batch-image / modify the reward with given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data is added into the buffer.
It will receive with only "obs" when the collector resets the environment, and will receive five keys "obs_next", "rew", "done", "info", "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
It will receive "obs" and "env_id" when the collector resets the environment, and will receive six keys "obs_next", "rew", "done", "info", "policy", "env_id" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
These variables are intended to gather all the information required to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
@ -149,8 +149,8 @@ For example, you can write your hook as:
def preprocess_fn(**kwargs):
"""change reward to zero mean"""
# if only obs exist -> reset
# if obs_next/act/rew/done/policy exist -> normal step
# if obs && env_id exist -> reset
# if obs_next/act/rew/done/policy/env_id exist -> normal step
if 'rew' not in kwargs:
# means that it is called after env.reset(), it can only process the obs
return Batch() # none of the variables are needed to be updated

View File

@ -50,8 +50,8 @@ class Logger:
def preprocess_fn(self, **kwargs):
# modify info before adding into the buffer, and recorded into tfb
# if only obs exist -> reset
# if obs_next/rew/done/info exist -> normal step
# if obs && env_id exist -> reset
# if obs_next/rew/done/info/env_id exist -> normal step
if 'rew' in kwargs:
info = kwargs['info']
info.rew = kwargs['rew']

View File

@ -128,7 +128,7 @@ def train_agent(
policy, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True)
test_collector = Collector(policy, test_envs)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
# log
@ -180,7 +180,7 @@ def watch(
args, agent_learn=agent_learn, agent_opponent=agent_opponent)
policy.eval()
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
collector = Collector(policy, env)
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)
rews, lens = result["rews"], result["lens"]
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")

View File

@ -35,11 +35,11 @@ class Collector(object):
exploration noise into action. Default to False.
The "preprocess_fn" is a function called before the data has been added to the
buffer with batch format. It will receive with only "obs" when the collector resets
the environment, and will receive five keys "obs_next", "rew", "done", "info", and
"policy" in a normal env step. It returns either a dict or a
:class:`~tianshou.data.Batch` with the modified keys and values. Examples are in
"test/base/test_collector.py".
buffer with batch format. It will receive only "obs" and "env_id" when the
collector resets the environment, and will receive six keys "obs_next", "rew",
"done", "info", "policy" and "env_id" in a normal env step. It returns either a
dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples
are in "test/base/test_collector.py".
.. note::
@ -115,7 +115,8 @@ class Collector(object):
"""Reset all of the environments."""
obs = self.env.reset()
if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs).get("obs", obs)
obs = self.preprocess_fn(
obs=obs, env_id=np.arange(self.env_num)).get("obs", obs)
self.data.obs = obs
def _reset_state(self, id: Union[int, List[int]]) -> None:
@ -235,6 +236,7 @@ class Collector(object):
done=self.data.done,
info=self.data.info,
policy=self.data.policy,
env_id=ready_env_ids,
))
if render:
@ -260,7 +262,8 @@ class Collector(object):
# finished episodes, we have to reset finished envs first.
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
obs_reset = self.preprocess_fn(
obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
for i in env_ind_local:
self._reset_state(i)
@ -442,6 +445,7 @@ class AsyncCollector(Collector):
rew=self.data.rew,
done=self.data.done,
info=self.data.info,
env_id=ready_env_ids,
))
if render:
@ -467,7 +471,8 @@ class AsyncCollector(Collector):
# finished episodes, we have to reset finished envs first.
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
obs_reset = self.preprocess_fn(
obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
for i in env_ind_local:
self._reset_state(i)