add env_id in preprocess fn (#391)
This commit is contained in:
parent
ebaca6f8da
commit
c19876179a
@ -129,7 +129,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
|
||||
|
||||
If you want to log statistics from the data stream, pre-process batch images, or modify the reward with the given env info, use ``preprocess_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data is added to the buffer.
|
||||
|
||||
It will receive only "obs" when the collector resets the environment, and will receive five keys "obs_next", "rew", "done", "info", "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
|
||||
It will receive "obs" and "env_id" when the collector resets the environment, and will receive six keys "obs_next", "rew", "done", "info", "policy", "env_id" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
|
||||
|
||||
These variables are intended to gather all the information required to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
|
||||
|
||||
@ -149,8 +149,8 @@ For example, you can write your hook as:
|
||||
|
||||
def preprocess_fn(**kwargs):
|
||||
"""change reward to zero mean"""
|
||||
# if only obs exist -> reset
|
||||
# if obs_next/act/rew/done/policy exist -> normal step
|
||||
# if obs && env_id exist -> reset
|
||||
# if obs_next/rew/done/info/policy/env_id exist -> normal step
|
||||
if 'rew' not in kwargs:
|
||||
# means that it is called after env.reset(), it can only process the obs
|
||||
return Batch() # none of the variables are needed to be updated
|
||||
|
@ -50,8 +50,8 @@ class Logger:
|
||||
|
||||
def preprocess_fn(self, **kwargs):
|
||||
# modify info before it is added to the buffer and recorded into tfb
|
||||
# if only obs exist -> reset
|
||||
# if obs_next/rew/done/info exist -> normal step
|
||||
# if obs && env_id exist -> reset
|
||||
# if obs_next/rew/done/info/env_id exist -> normal step
|
||||
if 'rew' in kwargs:
|
||||
info = kwargs['info']
|
||||
info.rew = kwargs['rew']
|
||||
|
@ -128,7 +128,7 @@ def train_agent(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# log
|
||||
@ -180,7 +180,7 @@ def watch(
|
||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent)
|
||||
policy.eval()
|
||||
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
|
||||
collector = Collector(policy, env)
|
||||
collector = Collector(policy, env, exploration_noise=True)
|
||||
result = collector.collect(n_episode=1, render=args.render)
|
||||
rews, lens = result["rews"], result["lens"]
|
||||
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
|
||||
|
@ -35,11 +35,11 @@ class Collector(object):
|
||||
exploration noise into action. Default to False.
|
||||
|
||||
The "preprocess_fn" is a function called before the data has been added to the
|
||||
buffer with batch format. It will receive only "obs" when the collector resets
|
||||
the environment, and will receive five keys "obs_next", "rew", "done", "info", and
|
||||
"policy" in a normal env step. It returns either a dict or a
|
||||
:class:`~tianshou.data.Batch` with the modified keys and values. Examples are in
|
||||
"test/base/test_collector.py".
|
||||
buffer with batch format. It will receive only "obs" and "env_id" when the
|
||||
collector resets the environment, and will receive six keys "obs_next", "rew",
|
||||
"done", "info", "policy" and "env_id" in a normal env step. It returns either a
|
||||
dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples
|
||||
are in "test/base/test_collector.py".
|
||||
|
||||
.. note::
|
||||
|
||||
@ -115,7 +115,8 @@ class Collector(object):
|
||||
"""Reset all of the environments."""
|
||||
obs = self.env.reset()
|
||||
if self.preprocess_fn:
|
||||
obs = self.preprocess_fn(obs=obs).get("obs", obs)
|
||||
obs = self.preprocess_fn(
|
||||
obs=obs, env_id=np.arange(self.env_num)).get("obs", obs)
|
||||
self.data.obs = obs
|
||||
|
||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||
@ -235,6 +236,7 @@ class Collector(object):
|
||||
done=self.data.done,
|
||||
info=self.data.info,
|
||||
policy=self.data.policy,
|
||||
env_id=ready_env_ids,
|
||||
))
|
||||
|
||||
if render:
|
||||
@ -260,7 +262,8 @@ class Collector(object):
|
||||
# finished episodes, we have to reset finished envs first.
|
||||
obs_reset = self.env.reset(env_ind_global)
|
||||
if self.preprocess_fn:
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
||||
obs_reset = self.preprocess_fn(
|
||||
obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
|
||||
self.data.obs_next[env_ind_local] = obs_reset
|
||||
for i in env_ind_local:
|
||||
self._reset_state(i)
|
||||
@ -442,6 +445,7 @@ class AsyncCollector(Collector):
|
||||
rew=self.data.rew,
|
||||
done=self.data.done,
|
||||
info=self.data.info,
|
||||
env_id=ready_env_ids,
|
||||
))
|
||||
|
||||
if render:
|
||||
@ -467,7 +471,8 @@ class AsyncCollector(Collector):
|
||||
# finished episodes, we have to reset finished envs first.
|
||||
obs_reset = self.env.reset(env_ind_global)
|
||||
if self.preprocess_fn:
|
||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
||||
obs_reset = self.preprocess_fn(
|
||||
obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
|
||||
self.data.obs_next[env_ind_local] = obs_reset
|
||||
for i in env_ind_local:
|
||||
self._reset_state(i)
|
||||
|
Loading…
x
Reference in New Issue
Block a user