add env_id in preprocess fn (#391)
This commit is contained in:
parent
ebaca6f8da
commit
c19876179a
@ -129,7 +129,7 @@ This is related to `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_.
|
|||||||
|
|
||||||
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer.
|
If you want to get log stat from data stream / pre-process batch-image / modify the reward with given env info, use ``preproces_fn`` in :class:`~tianshou.data.Collector`. This is a hook which will be called before the data adding into the buffer.
|
||||||
|
|
||||||
It will receive with only "obs" when the collector resets the environment, and will receive five keys "obs_next", "rew", "done", "info", "policy" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
|
It will receive with "obs" and "env_id" when the collector resets the environment, and will receive six keys "obs_next", "rew", "done", "info", "policy", "env_id" in a normal env step. It returns either a dict or a :class:`~tianshou.data.Batch` with the modified keys and values.
|
||||||
|
|
||||||
These variables are intended to gather all the information requires to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
|
These variables are intended to gather all the information requires to keep track of a simulation step, namely the (observation, action, reward, done flag, next observation, info, intermediate result of the policy) at time t, for the whole duration of the simulation.
|
||||||
|
|
||||||
@ -149,8 +149,8 @@ For example, you can write your hook as:
|
|||||||
|
|
||||||
def preprocess_fn(**kwargs):
|
def preprocess_fn(**kwargs):
|
||||||
"""change reward to zero mean"""
|
"""change reward to zero mean"""
|
||||||
# if only obs exist -> reset
|
# if obs && env_id exist -> reset
|
||||||
# if obs_next/act/rew/done/policy exist -> normal step
|
# if obs_next/act/rew/done/policy/env_id exist -> normal step
|
||||||
if 'rew' not in kwargs:
|
if 'rew' not in kwargs:
|
||||||
# means that it is called after env.reset(), it can only process the obs
|
# means that it is called after env.reset(), it can only process the obs
|
||||||
return Batch() # none of the variables are needed to be updated
|
return Batch() # none of the variables are needed to be updated
|
||||||
|
@ -50,8 +50,8 @@ class Logger:
|
|||||||
|
|
||||||
def preprocess_fn(self, **kwargs):
|
def preprocess_fn(self, **kwargs):
|
||||||
# modify info before adding into the buffer, and recorded into tfb
|
# modify info before adding into the buffer, and recorded into tfb
|
||||||
# if only obs exist -> reset
|
# if obs && env_id exist -> reset
|
||||||
# if obs_next/rew/done/info exist -> normal step
|
# if obs_next/rew/done/info/env_id exist -> normal step
|
||||||
if 'rew' in kwargs:
|
if 'rew' in kwargs:
|
||||||
info = kwargs['info']
|
info = kwargs['info']
|
||||||
info.rew = kwargs['rew']
|
info.rew = kwargs['rew']
|
||||||
|
@ -128,7 +128,7 @@ def train_agent(
|
|||||||
policy, train_envs,
|
policy, train_envs,
|
||||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||||
exploration_noise=True)
|
exploration_noise=True)
|
||||||
test_collector = Collector(policy, test_envs)
|
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||||
# policy.set_eps(1)
|
# policy.set_eps(1)
|
||||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||||
# log
|
# log
|
||||||
@ -180,7 +180,7 @@ def watch(
|
|||||||
args, agent_learn=agent_learn, agent_opponent=agent_opponent)
|
args, agent_learn=agent_learn, agent_opponent=agent_opponent)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
|
policy.policies[args.agent_id - 1].set_eps(args.eps_test)
|
||||||
collector = Collector(policy, env)
|
collector = Collector(policy, env, exploration_noise=True)
|
||||||
result = collector.collect(n_episode=1, render=args.render)
|
result = collector.collect(n_episode=1, render=args.render)
|
||||||
rews, lens = result["rews"], result["lens"]
|
rews, lens = result["rews"], result["lens"]
|
||||||
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
|
print(f"Final reward: {rews[:, args.agent_id - 1].mean()}, length: {lens.mean()}")
|
||||||
|
@ -35,11 +35,11 @@ class Collector(object):
|
|||||||
exploration noise into action. Default to False.
|
exploration noise into action. Default to False.
|
||||||
|
|
||||||
The "preprocess_fn" is a function called before the data has been added to the
|
The "preprocess_fn" is a function called before the data has been added to the
|
||||||
buffer with batch format. It will receive with only "obs" when the collector resets
|
buffer with batch format. It will receive only "obs" and "env_id" when the
|
||||||
the environment, and will receive five keys "obs_next", "rew", "done", "info", and
|
collector resets the environment, and will receive six keys "obs_next", "rew",
|
||||||
"policy" in a normal env step. It returns either a dict or a
|
"done", "info", "policy" and "env_id" in a normal env step. It returns either a
|
||||||
:class:`~tianshou.data.Batch` with the modified keys and values. Examples are in
|
dict or a :class:`~tianshou.data.Batch` with the modified keys and values. Examples
|
||||||
"test/base/test_collector.py".
|
are in "test/base/test_collector.py".
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
@ -115,7 +115,8 @@ class Collector(object):
|
|||||||
"""Reset all of the environments."""
|
"""Reset all of the environments."""
|
||||||
obs = self.env.reset()
|
obs = self.env.reset()
|
||||||
if self.preprocess_fn:
|
if self.preprocess_fn:
|
||||||
obs = self.preprocess_fn(obs=obs).get("obs", obs)
|
obs = self.preprocess_fn(
|
||||||
|
obs=obs, env_id=np.arange(self.env_num)).get("obs", obs)
|
||||||
self.data.obs = obs
|
self.data.obs = obs
|
||||||
|
|
||||||
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
def _reset_state(self, id: Union[int, List[int]]) -> None:
|
||||||
@ -235,6 +236,7 @@ class Collector(object):
|
|||||||
done=self.data.done,
|
done=self.data.done,
|
||||||
info=self.data.info,
|
info=self.data.info,
|
||||||
policy=self.data.policy,
|
policy=self.data.policy,
|
||||||
|
env_id=ready_env_ids,
|
||||||
))
|
))
|
||||||
|
|
||||||
if render:
|
if render:
|
||||||
@ -260,7 +262,8 @@ class Collector(object):
|
|||||||
# finished episodes, we have to reset finished envs first.
|
# finished episodes, we have to reset finished envs first.
|
||||||
obs_reset = self.env.reset(env_ind_global)
|
obs_reset = self.env.reset(env_ind_global)
|
||||||
if self.preprocess_fn:
|
if self.preprocess_fn:
|
||||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
obs_reset = self.preprocess_fn(
|
||||||
|
obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
|
||||||
self.data.obs_next[env_ind_local] = obs_reset
|
self.data.obs_next[env_ind_local] = obs_reset
|
||||||
for i in env_ind_local:
|
for i in env_ind_local:
|
||||||
self._reset_state(i)
|
self._reset_state(i)
|
||||||
@ -442,6 +445,7 @@ class AsyncCollector(Collector):
|
|||||||
rew=self.data.rew,
|
rew=self.data.rew,
|
||||||
done=self.data.done,
|
done=self.data.done,
|
||||||
info=self.data.info,
|
info=self.data.info,
|
||||||
|
env_id=ready_env_ids,
|
||||||
))
|
))
|
||||||
|
|
||||||
if render:
|
if render:
|
||||||
@ -467,7 +471,8 @@ class AsyncCollector(Collector):
|
|||||||
# finished episodes, we have to reset finished envs first.
|
# finished episodes, we have to reset finished envs first.
|
||||||
obs_reset = self.env.reset(env_ind_global)
|
obs_reset = self.env.reset(env_ind_global)
|
||||||
if self.preprocess_fn:
|
if self.preprocess_fn:
|
||||||
obs_reset = self.preprocess_fn(obs=obs_reset).get("obs", obs_reset)
|
obs_reset = self.preprocess_fn(
|
||||||
|
obs=obs_reset, env_id=env_ind_global).get("obs", obs_reset)
|
||||||
self.data.obs_next[env_ind_local] = obs_reset
|
self.data.obs_next[env_ind_local] = obs_reset
|
||||||
for i in env_ind_local:
|
for i in env_ind_local:
|
||||||
self._reset_state(i)
|
self._reset_state(i)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user