From 4b619c51baecf85a99f5ad199f402fe6e006db90 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Fri, 26 Apr 2024 16:46:03 +0200 Subject: [PATCH] Collector: extracted interface BaseCollector, minor simplifications Renamed is_eval kwarg --- examples/atari/atari_c51.py | 4 +- examples/atari/atari_dqn.py | 4 +- examples/atari/atari_fqf.py | 4 +- examples/atari/atari_iqn.py | 4 +- examples/atari/atari_ppo.py | 4 +- examples/atari/atari_qrdqn.py | 4 +- examples/atari/atari_rainbow.py | 4 +- examples/atari/atari_sac.py | 4 +- examples/box2d/acrobot_dualdqn.py | 2 +- examples/box2d/bipedal_bdq.py | 2 +- examples/box2d/bipedal_hardcore_sac.py | 2 +- examples/box2d/lunarlander_dqn.py | 2 +- examples/box2d/mcc_sac.py | 2 +- examples/inverse/irl_gail.py | 2 +- examples/mujoco/fetch_her_ddpg.py | 2 +- examples/mujoco/mujoco_a2c.py | 2 +- examples/mujoco/mujoco_ddpg.py | 2 +- examples/mujoco/mujoco_npg.py | 2 +- examples/mujoco/mujoco_ppo.py | 2 +- examples/mujoco/mujoco_redq.py | 2 +- examples/mujoco/mujoco_reinforce.py | 2 +- examples/mujoco/mujoco_sac.py | 2 +- examples/mujoco/mujoco_td3.py | 2 +- examples/mujoco/mujoco_trpo.py | 2 +- examples/offline/atari_bcq.py | 2 +- examples/offline/atari_cql.py | 2 +- examples/offline/atari_crr.py | 2 +- examples/offline/atari_il.py | 2 +- examples/offline/d4rl_bcq.py | 4 +- examples/offline/d4rl_cql.py | 4 +- examples/offline/d4rl_il.py | 4 +- examples/offline/d4rl_td3_bc.py | 4 +- examples/vizdoom/vizdoom_c51.py | 4 +- examples/vizdoom/vizdoom_ppo.py | 4 +- test/offline/gather_cartpole_data.py | 2 +- test/offline/test_bcq.py | 2 +- test/pettingzoo/pistonball.py | 2 +- test/pettingzoo/pistonball_continuous.py | 2 +- test/pettingzoo/tic_tac_toe.py | 2 +- tianshou/data/collector.py | 470 ++++++++++++----------- tianshou/highlevel/experiment.py | 2 +- 41 files changed, 303 insertions(+), 275 deletions(-) diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index 16694af..98d1d08 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -174,7 +174,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -184,7 +184,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index aed46a2..7d60654 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -216,7 +216,7 @@ def main(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -226,7 +226,7 @@ def main(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, 
) result.pprint_asdict() diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 92a6140..185cff1 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -187,7 +187,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -197,7 +197,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 3c4a695..5216d7c 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -184,7 +184,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -194,7 +194,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index 3c2c68d..969d00a 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -240,7 +240,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -250,7 +250,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index c821fad..5b02581 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -178,7 +178,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -188,7 +188,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_rainbow.py 
b/examples/atari/atari_rainbow.py index 0d14f5d..5bb69a3 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -214,7 +214,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: beta=args.beta, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -224,7 +224,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index a8d759a..be13884 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -227,7 +227,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -237,7 +237,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index ed6a65b..ac2c800 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -150,7 +150,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index da425e6..7c89cd2 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -168,7 +168,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 8dfbed6..2147213 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -212,7 +212,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 4bd2896..8280e23 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -147,7 +147,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 2576411..97f7149 100644 --- a/examples/box2d/mcc_sac.py +++ 
b/examples/box2d/mcc_sac.py @@ -158,7 +158,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 3c10359..e30b3f6 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -269,7 +269,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index 01a89b6..8b3439a 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -243,7 +243,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) collector_stats.pprint_asdict() diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 0195702..834e0ce 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -224,7 +224,7 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index 98852dc..066c7d4 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -173,7 +173,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index c2a193a..40cc443 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -221,7 +221,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 9e20c46..eb6817e 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -229,7 +229,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index 791005c..9101f86 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -201,7 +201,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 8912ee6..cbd4d3e 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -201,7 +201,7 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff 
--git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index 6c0bbdc..8d9d61e 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -195,7 +195,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 8e0c315..d1b5e69 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -193,7 +193,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 96da0cf..b5a91c0 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -226,7 +226,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 80687d1..d4e985e 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -191,7 +191,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 6107145..69ae433 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -175,7 +175,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index ad5c231..4e0771c 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -191,7 +191,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index 495b630..928112c 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -148,7 +148,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None: test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) + result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 
9834247..308a872 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -207,7 +207,7 @@ def test_bcq() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -233,7 +233,7 @@ def test_bcq() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 4e6a127..9d95d29 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -345,7 +345,7 @@ def test_cql() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -371,7 +371,7 @@ def test_cql() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index a8a6a66..d64cfe9 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -143,7 +143,7 @@ def test_il() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -169,7 +169,7 @@ def test_il() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 24138a4..5719eba 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -192,7 +192,7 @@ def test_td3_bc() -> None: policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -221,7 +221,7 @@ def test_td3_bc() -> None: collector_stats = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) print(collector_stats) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 3566051..8a52e5f 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -180,7 +180,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -190,7 +190,7 @@ def test_c51(args: 
argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index 907b0eb..adbfb05 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -242,7 +242,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size, is_eval=True) + result = collector.collect(n_step=args.buffer_size, eval_mode=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) @@ -252,7 +252,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: result = test_collector.collect( n_episode=args.test_num, render=args.render, - is_eval=True, + eval_mode=True, ) result.pprint_asdict() diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 91ee284..7d6aba1 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: policy.set_eps(0.2) collector = Collector(policy, test_envs, buf, exploration_noise=True) collector.reset() - collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True) + collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) else: diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py index e28368e..20e4dd6 100644 --- a/test/offline/test_bcq.py +++ b/test/offline/test_bcq.py @@ -185,7 +185,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None: torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), ) collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35, is_eval=True) + collector.collect(n_episode=1, render=1 / 35, eval_mode=True) # trainer result = OfflineTrainer( diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 7fc9134..27392f7 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non policy, _, _ = get_agents(args) [agent.set_eps(args.eps_test) for agent in policy.policies.values()] collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render, is_eval=True) + result = collector.collect(n_episode=1, render=args.render, eval_mode=True) result.pprint_asdict() diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 0100470..ed08522 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non ) policy, _, _ = get_agents(args) collector = Collector(policy, env) - collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True) + collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True) collector_result.pprint_asdict() diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index b63636a..fc46e2a 100644 --- 
a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -230,5 +230,5 @@ def watch( policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) collector = Collector(policy, env, exploration_noise=True) - result = collector.collect(n_episode=1, render=args.render, is_eval=True) + result = collector.collect(n_episode=1, render=args.render, eval_mode=True) result.pprint_asdict() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 1a60bd0..0b9fd6d 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,5 +1,7 @@ +import logging import time import warnings +from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass from typing import Any, Self, TypeVar, cast @@ -7,11 +9,11 @@ from typing import Any, Self, TypeVar, cast import gymnasium as gym import numpy as np import torch +from overrides import override from tianshou.data import ( Batch, CachedReplayBuffer, - PrioritizedReplayBuffer, ReplayBuffer, ReplayBufferManager, SequenceSummaryStats, @@ -25,6 +27,9 @@ from tianshou.data.types import ( from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin +from tianshou.utils.torch_utils import in_eval_mode, in_train_mode + +log = logging.getLogger(__name__) @dataclass(kw_only=True) @@ -122,23 +127,12 @@ def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: return result_batch_parent.info -class Collector: - """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. - - :param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. - :param env: a ``gym.Env`` environment or an instance of the - :class:`~tianshou.env.BaseVectorEnv` class. - :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` - as the default buffer. - :param exploration_noise: determine whether the action needs to be modified - with the corresponding policy's exploration noise. If so, "policy. - exploration_noise(act, batch)" will be called automatically to add the - exploration noise into action. Default to False. +class BaseCollector(ABC): + """Used to collect data from a vector environment into a buffer using a given policy. .. note:: - Please make sure the given environment has a time limitation if using n_episode + Please make sure the given environment has a time limitation if using `n_episode` collect option. .. 
note:: @@ -150,72 +144,70 @@ class Collector: def __init__( self, policy: BasePolicy, - env: gym.Env | BaseVectorEnv, + env: BaseVectorEnv | gym.Env, buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> None: - super().__init__() if isinstance(env, gym.Env) and not hasattr(env, "__len__"): warnings.warn("Single environment detected, wrap to DummyVectorEnv.") # Unfortunately, mypy seems to ignore the isinstance in lambda, maybe a bug in mypy - self.env = DummyVectorEnv([lambda: env]) - else: - self.env = env # type: ignore - self.env_num = len(self.env) - self.exploration_noise = exploration_noise - self.buffer = self._assign_buffer(buffer) + env = DummyVectorEnv([lambda: env]) # type: ignore + + if buffer is None: + buffer = VectorReplayBuffer(len(env), len(env)) + + self.buffer: ReplayBuffer = buffer self.policy = policy - self._action_space = self.env.action_space - - self._pre_collect_obs_RO: np.ndarray | None = None - self._pre_collect_info_R: np.ndarray | None = None - self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None - - self._is_closed = False + self.env = cast(BaseVectorEnv, env) + self.exploration_noise = exploration_noise self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + self._action_space = self.env.action_space + self._is_closed = False + + self._validate_buffer() + + def _validate_buffer(self) -> None: + buf = self.buffer + # TODO: a bit weird but true - all VectorReplayBuffers inherit from ReplayBufferManager. + # We should probably rename the manager + if isinstance(buf, ReplayBufferManager) and buf.buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.buffer_num} buffers, but at least {self.env_num=} are needed.", + ) + if isinstance(buf, CachedReplayBuffer) and buf.cached_buffer_num < self.env_num: + raise ValueError( + f"Buffer has only {buf.cached_buffer_num} cached buffers, but at least {self.env_num=} are needed.", + ) + # Non-VectorReplayBuffer. TODO: probably shouldn't rely on isinstance + if not isinstance(buf, ReplayBufferManager): + if buf.maxsize == 0: + raise ValueError("Buffer maxsize should be greater than 0.") + if self.env_num > 1: + raise ValueError( + f"Cannot use {type(buf).__name__} to collect from multiple envs ({self.env_num=}). 
" + f"Please use the corresponding VectorReplayBuffer instead.", + ) + + @property + def env_num(self) -> int: + return len(self.env) + + @property + def action_space(self) -> gym.spaces.Space: + return self._action_space + def close(self) -> None: """Close the collector and the environment.""" self.env.close() - self._pre_collect_obs_RO = None - self._pre_collect_info_R = None self._is_closed = True - @property - def is_closed(self) -> bool: - """Return True if the collector is closed.""" - return self._is_closed - - def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: - """Check if the buffer matches the constraint.""" - if buffer is None: - buffer = VectorReplayBuffer(self.env_num, self.env_num) - elif isinstance(buffer, ReplayBufferManager): - assert buffer.buffer_num >= self.env_num - if isinstance(buffer, CachedReplayBuffer): - assert buffer.cached_buffer_num >= self.env_num - else: # ReplayBuffer or PrioritizedReplayBuffer - assert buffer.maxsize > 0 - if self.env_num > 1: - if isinstance(buffer, ReplayBuffer): - buffer_type = "ReplayBuffer" - vector_type = "VectorReplayBuffer" - if isinstance(buffer, PrioritizedReplayBuffer): - buffer_type = "PrioritizedReplayBuffer" - vector_type = "PrioritizedVectorReplayBuffer" - raise TypeError( - f"Cannot use {buffer_type}(size={buffer.maxsize}, ...) to collect " - f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" - f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.", - ) - return buffer - def reset( self, reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the environment, statistics, and data needed to start the collection. :param reset_buffer: if true, reset the replay buffer attached @@ -224,12 +216,13 @@ class Collector: :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. Defaults to None (extra keyword arguments) """ - self.reset_env(gym_reset_kwargs=gym_reset_kwargs) + obs_NO, info_N = self.reset_env(gym_reset_kwargs=gym_reset_kwargs) if reset_buffer: self.reset_buffer() if reset_stats: self.reset_stat() self._is_closed = False + return obs_NO, info_N def reset_stat(self) -> None: """Reset the statistic variables.""" @@ -242,18 +235,168 @@ class Collector: def reset_env( self, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the environments and the initial obs, info, and hidden state of the collector.""" gym_reset_kwargs = gym_reset_kwargs or {} - self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs) + obs_NO, info_N = self.env.reset(**gym_reset_kwargs) # TODO: hack, wrap envpool envs such that they don't return a dict - if isinstance(self._pre_collect_info_R, dict): # type: ignore[unreachable] + if isinstance(info_N, dict): # type: ignore[unreachable] # this can happen if the env is an envpool env. 
Then the thing returned by reset is a dict # with array entries instead of an array of dicts # We use Batch to turn it into an array of dicts - self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R) # type: ignore[unreachable] + info_N = _dict_of_arr_to_arr_of_dicts(info_N) # type: ignore[unreachable] + return obs_NO, info_N + @abstractmethod + def _collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> CollectStats: + pass + + def collect( + self, + n_step: int | None = None, + n_episode: int | None = None, + random: bool = False, + render: float | None = None, + no_grad: bool = True, + reset_before_collect: bool = False, + gym_reset_kwargs: dict[str, Any] | None = None, + eval_mode: bool = False, + ) -> CollectStats: + """Collect a specified number of steps or episodes. + + To ensure an unbiased sampling result with the n_episode option, this function will + first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` + episodes, they will be collected evenly from each env. + + :param n_step: how many steps you want to collect. + :param n_episode: how many episodes you want to collect. + :param random: whether to use a random policy for collecting data. + :param no_grad: whether to disable gradient tracking in policy.forward() (True by default). + :param render: the sleep time between rendering consecutive frames. + :param reset_before_collect: whether to reset the environment before collecting data. + (The collector needs the initial obs and info to function properly.) + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Only used if reset_before_collect is True. + :param eval_mode: whether to collect data in evaluation mode. Will + set the policy to training mode otherwise. + + .. note:: + + One and only one collection number specification is permitted, either + ``n_step`` or ``n_episode``. + + :return: The collected stats + """ + # check that exactly one of n_step and n_episode is set and that it is larger than zero + self._validate_n_step_n_episode(n_episode, n_step) + + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + policy_mode_context = in_eval_mode if eval_mode else in_train_mode + with policy_mode_context(self.policy): + return self._collect( + n_step=n_step, + n_episode=n_episode, + random=random, + render=render, + no_grad=no_grad, + gym_reset_kwargs=gym_reset_kwargs, + ) + + def _validate_n_step_n_episode(self, n_episode: int | None, n_step: int | None) -> None: + if not n_step and not n_episode: + raise ValueError( + f"One of n_step and n_episode must be set to a value larger than zero, " + f"but got {n_step=}, {n_episode=}.", + ) + if n_step is not None and n_episode is not None: + raise ValueError( + f"Only one of n_step and n_episode may be set, but got {n_step=}, {n_episode=}.", + ) + if n_step and n_step % self.env_num != 0: + warnings.warn( + f"{n_step=} is not a multiple of ({self.env_num=}), " + "which may cause extra transitions being collected into the buffer.", + ) + if n_episode and self.env_num > n_episode: + warnings.warn( + f"{n_episode=} should be at least equal to {self.env_num=} to " + f"collect at least one trajectory in each environment.", + ) + + +class Collector(BaseCollector): + # NAMING CONVENTION (mostly suffixes): + # episode - An episode means a rollout until done (terminated or truncated).
After an episode is completed, + # the corresponding env is either reset or removed from the ready envs. + # N - number of envs, always fixed and >= R. + # R - number of ready env ids. Note that this might change when envs get idle. + # This can only happen in n_episode case, see explanation in the corresponding block. + # For n_step, we always use all envs to collect the data, while for n_episode, + # R will be at most n_episode at the beginning, but can decrease during the collection. + # O - dimension(s) of observations + # A - dimension(s) of actions + # H - dimension(s) of hidden state + # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. + # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # Only used in n_episode case. Then, R becomes R-S. + + # Note: the switch of the policy's modules to eval/train mode (affecting modules like dropout and batchnorm) + # and of the policy's evaluation mode (affecting whether actions are sampled or the mode is used, depending on + # policy.deterministic_eval) is handled by BaseCollector.collect via the in_eval_mode/in_train_mode contexts. + + def __init__( + self, + policy: BasePolicy, + env: gym.Env | BaseVectorEnv, + buffer: ReplayBuffer | None = None, + exploration_noise: bool = False, + ) -> None: + """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class. + :param env: a ``gym.Env`` environment or an instance of the + :class:`~tianshou.env.BaseVectorEnv` class. + :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. + :param exploration_noise: determine whether the action needs to be modified + with the corresponding policy's exploration noise. If so, "policy. + exploration_noise(act, batch)" will be called automatically to add the + exploration noise into action. Default to False. + """ + super().__init__(policy, env, buffer, exploration_noise=exploration_noise) + self._pre_collect_obs_RO: np.ndarray | None = None + self._pre_collect_info_R: np.ndarray | None = None + self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + + self._is_closed = False + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + def close(self) -> None: + super().close() + self._pre_collect_obs_RO = None + self._pre_collect_info_R = None + + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Reset the environments and the initial obs, info, and hidden state of the collector.""" + obs_NO, info_N = super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + # We assume that R = N when reset is called. + # TODO: there is currently no mechanism that ensures this and it's a public method! + self._pre_collect_obs_RO = obs_NO + self._pre_collect_info_R = info_N self._pre_collect_hidden_state_RH = None + return obs_NO, info_N def _compute_action_policy_hidden( self, @@ -309,98 +452,30 @@ class Collector: return act_RA, act_normalized_RA, policy_R, hidden_state_RH # TODO: reduce complexity, remove the noqa - def collect( + def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, no_grad: bool = True, - reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, - is_eval: bool = False, ) -> CollectStats: - """Collect a specified number of steps or episodes.
- - To ensure an unbiased sampling result with the n_episode option, this function will - first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` - episodes, they will be collected evenly from each env. - - :param n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. - :param render: the sleep time between rendering consecutive frames. - :param no_grad: whether to retain gradient in policy.forward(). - :param reset_before_collect: whether to reset the environment before collecting data. - (The collector needs the initial obs and info to function properly.) - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Only used if reset_before_collect is True. - :param is_eval: whether to collect data in evaluation mode. - - .. note:: - - One and only one collection number specification is permitted, either - ``n_step`` or ``n_episode``. - - :return: The collected stats - """ - # NAMING CONVENTION (mostly suffixes): - # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, - # the corresponding env is either reset or removed from the ready envs. - # R - number ready env ids. Note that this might change when envs get idle. - # This can only happen in n_episode case, see explanation in the corresponding block. - # For n_step, we always use all envs to collect the data, while for n_episode, - # R will be at most n_episode at the beginning, but can decrease during the collection. - # O - dimension(s) of observations - # A - dimension(s) of actions - # H - dimension(s) of hidden state - # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. - # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. - # Only used in n_episode case. Then, R becomes R-S. - - # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy - # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on - # policy.deterministic_eval) - self.policy.eval() - pre_collect_is_eval = self.policy.is_eval - self.policy.is_eval = is_eval - - use_grad = not no_grad - gym_reset_kwargs = gym_reset_kwargs or {} - - # Input validation - assert not self.env.is_async, "Please use AsyncCollector if using async venv." - if n_step is not None: - assert n_episode is None, ( - f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got {n_step=}, {n_episode=}." + # TODO: can't do it init since AsyncCollector is currently a subclass of Collector + if self.env.is_async: + raise ValueError( + f"Please use {AsyncCollector.__name__} for asynchronous environments. 
" + f"Env class: {self.env.__class__.__name__}.", ) - assert n_step > 0 - if n_step % self.env_num != 0: - warnings.warn( - f"{n_step=} is not a multiple of ({self.env_num=}), " - "which may cause extra transitions being collected into the buffer.", - ) + + if n_step is not None: ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: - assert n_episode > 0 - if self.env_num > n_episode: - warnings.warn( - f"{n_episode=} should be larger than {self.env_num=} to " - f"collect at least one trajectory in each environment.", - ) ready_env_ids_R = np.arange(min(self.env_num, n_episode)) - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) + + use_grad = not no_grad start_time = time.time() - - if reset_before_collect: - self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: raise ValueError( "Initial obs and info should not be None. " @@ -491,7 +566,8 @@ class Collector: step_count += len(ready_env_ids_R) # preparing for the next iteration - # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + # obs_next, info and hidden_state will be modified inplace in the code below, + # so we copy to not affect the data in the buffer last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) last_hidden_state_RH = copy(hidden_state_RH) @@ -509,6 +585,7 @@ class Collector: # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. + gym_reset_kwargs = gym_reset_kwargs or {} obs_reset_DO, info_reset_D = self.env.reset( env_id=env_ind_global_D, **gym_reset_kwargs, @@ -577,9 +654,6 @@ class Collector: # reset envs and the _pre_collect fields self.reset_env(gym_reset_kwargs) # todo still necessary? - # set the policy back to pre collect mode - self.policy.is_eval = pre_collect_is_eval - return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), @@ -608,8 +682,7 @@ class Collector: class AsyncCollector(Collector): """Async Collector handles async vector environment. - The arguments are exactly the same as :class:`~tianshou.data.Collector`, please - refer to :class:`~tianshou.data.Collector` for more detailed explanation. + Please refer to :class:`~tianshou.data.Collector` for a more detailed explanation. """ def __init__( @@ -619,6 +692,12 @@ class AsyncCollector(Collector): buffer: ReplayBuffer | None = None, exploration_noise: bool = False, ) -> None: + if not env.is_async: + # TODO: raise an exception? + log.error( + f"Please use {Collector.__name__} if not using async venv. " + f"Env class: {env.__class__.__name__}", + ) # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") super().__init__( @@ -644,7 +723,7 @@ class AsyncCollector(Collector): reset_buffer: bool = True, reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, - ) -> None: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the environment, statistics, and data needed to start the collection. :param reset_buffer: if true, reset the replay buffer attached @@ -654,7 +733,7 @@ class AsyncCollector(Collector): reset function. 
Defaults to None (extra keyword arguments) """ # This sets the _pre_collect attrs - super().reset( + result = super().reset( reset_buffer=reset_buffer, reset_stats=reset_stats, gym_reset_kwargs=gym_reset_kwargs, @@ -667,78 +746,29 @@ class AsyncCollector(Collector): self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) self._current_action_in_all_envs_EA = np.empty(self.env_num) self._current_policy_in_all_envs_E = None + return result - def collect( + @override + def reset_env( + self, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + # we need to step through the envs and wait until they are ready to be able to interact with them + if self.env.waiting_id: + self.env.step(None, id=self.env.waiting_id) + return super().reset_env(gym_reset_kwargs=gym_reset_kwargs) + + @override + def _collect( self, n_step: int | None = None, n_episode: int | None = None, random: bool = False, render: float | None = None, no_grad: bool = True, - reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, - is_eval: bool = False, ) -> CollectStats: - """Collect a specified number of steps or episodes with async env setting. - - This function does not collect an exact number of transitions specified by n_step or - n_episode. Instead, to support the asynchronous setting, it may collect more transitions - than requested by n_step or n_episode and save them into the buffer. - - :param n_step: how many steps you want to collect. - :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy_R for collecting data. Default - to False. - :param render: the sleep time between rendering consecutive frames. - Default to None (no rendering). - :param no_grad: whether to retain gradient in policy_R.forward(). Default to - True (no gradient retaining). - :param reset_before_collect: whether to reset the environment before - collecting data. It has only an effect if n_episode is not None, i.e. - if one wants to collect a fixed number of episodes. - (The collector needs the initial obs and info to function properly.) - :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Defaults to None (extra keyword arguments) - :param is_eval: whether to collect data in evaluation mode. - - .. note:: - - One and only one collection number specification is permitted, either - ``n_step`` or ``n_episode``. - - :return: A dataclass object - """ - # set the policy's modules to eval mode (this affects modules like dropout and batchnorm) and the policy - # evaluation mode (this affects the policy's behavior, i.e., whether to sample or use the mode depending on - # policy.deterministic_eval) - self.policy.eval() - pre_collect_is_eval = self.policy.is_eval - self.policy.is_eval = is_eval - use_grad = not no_grad - gym_reset_kwargs = gym_reset_kwargs or {} - - # collect at least n_step or n_episode - if n_step is not None: - assert n_episode is None, ( - "Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}." 
- ) - assert n_step > 0 - elif n_episode is not None: - assert n_episode > 0 - else: - raise TypeError( - "Please specify at least one (either n_step or n_episode) " - "in AsyncCollector.collect().", - ) - - if reset_before_collect: - # first we need to step all envs to be able to interact with them - if self.env.waiting_id: - self.env.step(None, id=self.env.waiting_id) - self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) - start_time = time.time() step_count = 0 @@ -868,12 +898,12 @@ class AsyncCollector(Collector): num_collected_episodes += num_episodes_done_this_iter # preparing for the next iteration - # todo do we need the copy stuff (tests pass also without) # todo seem we can get rid of this last_sth stuff altogether last_obs_RO = copy(obs_next_RO) last_info_R = copy(info_R) - last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R]) # type: ignore[index] - + last_hidden_state_RH = copy( + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R], # type: ignore[index] + ) if num_episodes_done_this_iter: env_ind_local_D = np.where(done_R)[0] env_ind_global_D = ready_env_ids_R[env_ind_local_D] @@ -883,6 +913,7 @@ class AsyncCollector(Collector): # now we copy obs_next_RO to obs, but since there might be # finished episodes, we have to reset finished envs first. + gym_reset_kwargs = gym_reset_kwargs or {} obs_reset_DO, info_reset_D = self.env.reset( env_id=env_ind_global_D, **gym_reset_kwargs, @@ -923,9 +954,6 @@ class AsyncCollector(Collector): # persist for future collect iterations self._ready_env_ids_R = ready_env_ids_R - # set the policy back to pre collect mode - self.policy.is_eval = pre_collect_is_eval - return CollectStats.with_autogenerated_stats( returns=np.array(episode_returns), lens=np.array(episode_lens), diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 71b8159..df87aca 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -337,7 +337,7 @@ class Experiment(ToStringMixin): ) -> None: collector = Collector(policy, env) collector.reset() - result = collector.collect(n_episode=num_episodes, render=render, is_eval=True) + result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy log.info(
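A minimal usage sketch of the renamed API (not taken from the patch itself): it shows eval_mode=True in place of the old is_eval=True, together with reset_before_collect=True from the new BaseCollector.collect signature. The environment id, the number of parallel envs, the buffer size, and the evaluate_policy helper are illustrative assumptions; policy is any BasePolicy instance built as in the example scripts above.

import gymnasium as gym

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BasePolicy


def evaluate_policy(policy: BasePolicy, num_episodes: int = 10) -> None:
    # four parallel test environments; any gymnasium env id works the same way
    test_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
    buffer = VectorReplayBuffer(total_size=2000, buffer_num=len(test_envs))
    collector = Collector(policy, test_envs, buffer)
    # eval_mode=True (formerly is_eval=True) keeps the policy in evaluation mode for the
    # duration of the call; reset_before_collect=True makes the collector fetch the
    # initial obs/info itself instead of requiring a prior collector.reset().
    stats = collector.collect(
        n_episode=num_episodes,
        eval_mode=True,
        reset_before_collect=True,
    )
    stats.pprint_asdict()

Under the old API the same evaluation would have been a separate collector.reset() followed by collector.collect(n_episode=num_episodes, is_eval=True), as seen in the example diffs above.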