diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py
index 98d1d08..d611ab1 100644
--- a/examples/atari/atari_c51.py
+++ b/examples/atari/atari_c51.py
@@ -174,18 +174,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py
index 7d60654..eeb9bcc 100644
--- a/examples/atari/atari_dqn.py
+++ b/examples/atari/atari_dqn.py
@@ -216,18 +216,14 @@ def main(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py
index 185cff1..58aff46 100644
--- a/examples/atari/atari_fqf.py
+++ b/examples/atari/atari_fqf.py
@@ -187,18 +187,14 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py
index 5216d7c..c609052 100644
--- a/examples/atari/atari_iqn.py
+++ b/examples/atari/atari_iqn.py
@@ -184,18 +184,14 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py
index 969d00a..dd75de7 100644
--- a/examples/atari/atari_ppo.py
+++ b/examples/atari/atari_ppo.py
@@ -240,18 +240,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py
index 5b02581..b973131 100644
--- a/examples/atari/atari_qrdqn.py
+++ b/examples/atari/atari_qrdqn.py
@@ -178,18 +178,14 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py
index 5bb69a3..952d35f 100644
--- a/examples/atari/atari_rainbow.py
+++ b/examples/atari/atari_rainbow.py
@@ -214,18 +214,14 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
                 beta=args.beta,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py
index be13884..4d01a88 100644
--- a/examples/atari/atari_sac.py
+++ b/examples/atari/atari_sac.py
@@ -227,18 +227,14 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py
index ac2c800..365c073 100644
--- a/examples/box2d/acrobot_dualdqn.py
+++ b/examples/box2d/acrobot_dualdqn.py
@@ -147,11 +147,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py
index 7c89cd2..c817831 100644
--- a/examples/box2d/bipedal_bdq.py
+++ b/examples/box2d/bipedal_bdq.py
@@ -165,11 +165,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py
index 2147213..66e5f31 100644
--- a/examples/box2d/bipedal_hardcore_sac.py
+++ b/examples/box2d/bipedal_hardcore_sac.py
@@ -209,11 +209,7 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py
index 8280e23..f9bbd6f 100644
--- a/examples/box2d/lunarlander_dqn.py
+++ b/examples/box2d/lunarlander_dqn.py
@@ -144,11 +144,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     policy.set_eps(args.eps_test)
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py
index 97f7149..7617b7b 100644
--- a/examples/box2d/mcc_sac.py
+++ b/examples/box2d/mcc_sac.py
@@ -155,11 +155,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

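All of the example-script hunks above (and the analogous ones below) are the same mechanical migration: drop `eval_mode=True` from `collect()`, since the collector now enters evaluation mode on its own (see the `tianshou/data/collector.py` hunk further down). A minimal sketch of the resulting call-site pattern; `watch_agent` is a hypothetical helper, not part of this diff:

```python
from tianshou.data import Collector


def watch_agent(test_collector: Collector, n_episodes: int, render: float) -> None:
    # eval_mode=True is gone: BaseCollector.collect() now wraps the rollout
    # in in_eval_mode(policy) itself, so call sites shrink to one line.
    test_collector.reset()
    stats = test_collector.collect(n_episode=n_episodes, render=render)
    stats.pprint_asdict()
```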
diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py
index e30b3f6..42e5bc2 100644
--- a/examples/inverse/irl_gail.py
+++ b/examples/inverse/irl_gail.py
@@ -266,11 +266,7 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py
index 8b3439a..bbf68c2 100644
--- a/examples/mujoco/fetch_her_ddpg.py
+++ b/examples/mujoco/fetch_her_ddpg.py
@@ -240,11 +240,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     collector_stats.pprint_asdict()


diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py
index 834e0ce..194d9b5 100755
--- a/examples/mujoco/mujoco_a2c.py
+++ b/examples/mujoco/mujoco_a2c.py
@@ -221,11 +221,7 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
index 066c7d4..db90bab 100755
--- a/examples/mujoco/mujoco_ddpg.py
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -170,11 +170,7 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py
index 40cc443..4d8530a 100755
--- a/examples/mujoco/mujoco_npg.py
+++ b/examples/mujoco/mujoco_npg.py
@@ -218,11 +218,7 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py
index eb6817e..7c3f268 100755
--- a/examples/mujoco/mujoco_ppo.py
+++ b/examples/mujoco/mujoco_ppo.py
@@ -226,11 +226,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py
index 9101f86..8951b03 100755
--- a/examples/mujoco/mujoco_redq.py
+++ b/examples/mujoco/mujoco_redq.py
@@ -198,11 +198,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py
index cbd4d3e..ff7e340 100755
--- a/examples/mujoco/mujoco_reinforce.py
+++ b/examples/mujoco/mujoco_reinforce.py
@@ -198,11 +198,7 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
index 8d9d61e..af13983 100755
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -192,11 +192,7 @@ def test_sac(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
index d1b5e69..6cc8bb2 100755
--- a/examples/mujoco/mujoco_td3.py
+++ b/examples/mujoco/mujoco_td3.py
@@ -190,11 +190,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py
index b5a91c0..eefdfcc 100755
--- a/examples/mujoco/mujoco_trpo.py
+++ b/examples/mujoco/mujoco_trpo.py
@@ -223,11 +223,7 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py
index d4e985e..3af40cc 100644
--- a/examples/offline/atari_bcq.py
+++ b/examples/offline/atari_bcq.py
@@ -191,7 +191,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py
index 69ae433..b2c0c87 100644
--- a/examples/offline/atari_cql.py
+++ b/examples/offline/atari_cql.py
@@ -175,7 +175,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py
index 4e0771c..8b6320a 100644
--- a/examples/offline/atari_crr.py
+++ b/examples/offline/atari_crr.py
@@ -191,7 +191,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py
index 928112c..39aee31 100644
--- a/examples/offline/atari_il.py
+++ b/examples/offline/atari_il.py
@@ -148,7 +148,7 @@ def test_il(args: argparse.Namespace = get_args()) -> None:
         test_envs.seed(args.seed)
         print("Testing agent ...")
         test_collector.reset()
-        result = test_collector.collect(n_episode=args.test_num, render=args.render, eval_mode=True)
+        result = test_collector.collect(n_episode=args.test_num, render=args.render)
         result.pprint_asdict()

     if args.watch:
diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py
index 308a872..9ed1826 100644
--- a/examples/offline/d4rl_bcq.py
+++ b/examples/offline/d4rl_bcq.py
@@ -207,7 +207,7 @@ def test_bcq() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -230,11 +230,7 @@ def test_bcq() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py
index 9d95d29..90d6b15 100644
--- a/examples/offline/d4rl_cql.py
+++ b/examples/offline/d4rl_cql.py
@@ -345,7 +345,7 @@ def test_cql() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -368,11 +368,7 @@ def test_cql() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py
index d64cfe9..e03deed 100644
--- a/examples/offline/d4rl_il.py
+++ b/examples/offline/d4rl_il.py
@@ -143,7 +143,7 @@ def test_il() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -166,11 +166,7 @@ def test_il() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)


diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py
index 5719eba..6b448b3 100644
--- a/examples/offline/d4rl_td3_bc.py
+++ b/examples/offline/d4rl_td3_bc.py
@@ -192,7 +192,7 @@ def test_td3_bc() -> None:
         policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     if not args.watch:
         replay_buffer = load_buffer_d4rl(args.expert_data_task)
@@ -218,11 +218,7 @@ def test_td3_bc() -> None:
     # Let's watch its performance!
     test_envs.seed(args.seed)
     test_collector.reset()
-    collector_stats = test_collector.collect(
-        n_episode=args.test_num,
-        render=args.render,
-        eval_mode=True,
-    )
+    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
     print(collector_stats)

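The d4rl scripts above additionally have a resume branch that watches a single rendered episode; the same simplification applies there. A sketch of that branch as a hypothetical `watch_one_episode` helper, assuming an already-loaded policy and a single Gym env:

```python
from tianshou.data import Collector


def watch_one_episode(policy, env) -> None:
    collector = Collector(policy, env)
    collector.reset()
    # render=1 / 35 sleeps ~1/35 s between frames, i.e. roughly 35 FPS.
    collector.collect(n_episode=1, render=1 / 35)
```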
diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py
index 8a52e5f..25ad804 100644
--- a/examples/vizdoom/vizdoom_c51.py
+++ b/examples/vizdoom/vizdoom_c51.py
@@ -180,18 +180,14 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py
index adbfb05..7fc09f6 100644
--- a/examples/vizdoom/vizdoom_ppo.py
+++ b/examples/vizdoom/vizdoom_ppo.py
@@ -242,18 +242,14 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
                 stack_num=args.frames_stack,
             )
             collector = Collector(policy, test_envs, buffer, exploration_noise=True)
-            result = collector.collect(n_step=args.buffer_size, eval_mode=True)
+            result = collector.collect(n_step=args.buffer_size)
             print(f"Save buffer into {args.save_buffer_name}")
             # Unfortunately, pickle will cause oom with 1M buffer size
             buffer.save_hdf5(args.save_buffer_name)
         else:
             print("Testing agent ...")
             test_collector.reset()
-            result = test_collector.collect(
-                n_episode=args.test_num,
-                render=args.render,
-                eval_mode=True,
-            )
+            result = test_collector.collect(n_episode=args.test_num, render=args.render)
             result.pprint_asdict()

     if args.watch:
diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py
index ce8a936..35bf3e2 100644
--- a/test/base/test_env_finite.py
+++ b/test/base/test_env_finite.py
@@ -250,7 +250,7 @@ def test_finite_dummy_vector_env() -> None:
         envs.tracker = MetricTracker()
         try:
             # TODO: why on earth 10**18?
-            test_collector.collect(n_step=10 ** 18)
+            test_collector.collect(n_step=10**18)
         except StopIteration:
             envs.tracker.validate()

@@ -265,6 +265,6 @@ def test_finite_subproc_vector_env() -> None:
     for _ in range(3):
         envs.tracker = MetricTracker()
         try:
-            test_collector.collect(n_step=10 ** 18)
+            test_collector.collect(n_step=10**18)
         except StopIteration:
             envs.tracker.validate()
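For the atari/vizdoom `save_buffer_name` branches, the flow is unchanged apart from the dropped argument: fill a `VectorReplayBuffer` by steps, then persist it as HDF5 (pickling a 1M-transition buffer runs out of memory, per the comment in those hunks). A condensed sketch with a hypothetical `save_rollout_buffer` helper:

```python
from tianshou.data import Collector, VectorReplayBuffer


def save_rollout_buffer(policy, test_envs, buffer_size: int, frames_stack: int, path: str) -> None:
    buffer = VectorReplayBuffer(
        buffer_size,
        buffer_num=len(test_envs),
        ignore_obs_next=True,
        save_only_last_obs=True,
        stack_num=frames_stack,
    )
    collector = Collector(policy, test_envs, buffer, exploration_noise=True)
    collector.reset()
    collector.collect(n_step=buffer_size)
    buffer.save_hdf5(path)  # HDF5 instead of pickle to avoid OOM at this size
```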
diff --git a/test/base/test_policy.py b/test/base/test_policy.py
index f286156..4d26905 100644
--- a/test/base/test_policy.py
+++ b/test/base/test_policy.py
@@ -64,7 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:

 class TestPolicyBasics:
     def test_get_action(self, policy: PPOPolicy) -> None:
-        policy.is_eval = True
+        policy.is_within_training_step = False
         sample_obs = torch.randn(obs_shape)
         policy.deterministic_eval = False
         actions = [policy.compute_action(sample_obs) for _ in range(10)]
diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py
index 7d6aba1..19ba653 100644
--- a/test/offline/gather_cartpole_data.py
+++ b/test/offline/gather_cartpole_data.py
@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size, eval_mode=True)
+    collector_stats = collector.collect(n_step=args.buffer_size)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:
diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py
index 20e4dd6..8b31c19 100644
--- a/test/offline/test_bcq.py
+++ b/test/offline/test_bcq.py
@@ -185,7 +185,7 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
             torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
         )
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)
+        collector.collect(n_episode=1, render=1 / 35)

     # trainer
     result = OfflineTrainer(
diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py
index 27392f7..c57522d 100644
--- a/test/pettingzoo/pistonball.py
+++ b/test/pettingzoo/pistonball.py
@@ -190,5 +190,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
     policy, _, _ = get_agents(args)
     [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    result = collector.collect(n_episode=1, render=args.render)
     result.pprint_asdict()
diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py
index ed08522..38de811 100644
--- a/test/pettingzoo/pistonball_continuous.py
+++ b/test/pettingzoo/pistonball_continuous.py
@@ -285,5 +285,5 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
     )
     policy, _, _ = get_agents(args)
     collector = Collector(policy, env)
-    collector_result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    collector_result = collector.collect(n_episode=1, render=args.render)
     collector_result.pprint_asdict()
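The `test_policy.py` hunk above shows the semantic flip: the test used to force `policy.is_eval = True`; with the inverted flag it ensures `is_within_training_step` is False, which is what lets `deterministic_eval` take effect outside training (see the policy hunks below). A sketch of how the two flags interact, assuming a policy exposing `compute_action`:

```python
import torch


def sample_eval_actions(policy, obs: torch.Tensor, n: int = 10) -> list:
    policy.is_within_training_step = False  # the default outside the trainer
    policy.deterministic_eval = False       # stochastic: samples may differ
    actions = [policy.compute_action(obs) for _ in range(n)]
    # With deterministic_eval = True (and the flag still False), forward()
    # would return dist.mode instead, so all n actions would be identical.
    return actions
```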
diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py
index fc46e2a..966c9e0 100644
--- a/test/pettingzoo/tic_tac_toe.py
+++ b/test/pettingzoo/tic_tac_toe.py
@@ -230,5 +230,5 @@ def watch(
     policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
     policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render, eval_mode=True)
+    result = collector.collect(n_episode=1, render=args.render)
     result.pprint_asdict()
diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py
index 10cf663..cf897a5 100644
--- a/tianshou/data/collector.py
+++ b/tianshou/data/collector.py
@@ -260,17 +260,16 @@ class BaseCollector(ABC):
     ) -> CollectStats:
         pass

     def collect(
         self,
         n_step: int | None = None,
         n_episode: int | None = None,
         random: bool = False,
         render: float | None = None,
         no_grad: bool = True,
         reset_before_collect: bool = False,
         gym_reset_kwargs: dict[str, Any] | None = None,
-        eval_mode: bool = False,
     ) -> CollectStats:
         """Collect a specified number of steps or episodes.

         To ensure an unbiased sampling result with the n_episode option, this function will
@@ -286,9 +285,6 @@ class BaseCollector(ABC):
             (The collector needs the initial obs and info to function properly.)
         :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
             reset function. Only used if reset_before_collect is True.
-        :param eval_mode: whether to collect data in evaluation mode. Will
-            set the policy to training mode otherwise.
-
         .. note::

             One and only one collection number specification is permitted, either
@@ -302,8 +298,7 @@ class BaseCollector(ABC):
         if reset_before_collect:
             self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs)

-        policy_mode_context = in_eval_mode if eval_mode else in_train_mode
-        with policy_mode_context(self.policy):
+        with in_eval_mode(self.policy):  # safety precaution only
             return self._collect(
                 n_step=n_step,
                 n_episode=n_episode,
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index c131326..fdfc4c0 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -130,10 +130,7 @@ class AgentFactory(ABC, ToStringMixin):
             log.info(
                 f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
             )
-            train_collector.collect(
-                n_step=self.sampling_config.start_timesteps,
-                random=self.sampling_config.start_timesteps_random,
-            )
+            train_collector.collect(n_step=self.sampling_config.start_timesteps, random=self.sampling_config.start_timesteps_random)
         return train_collector, test_collector

     def set_policy_wrapper_factory(
diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py
index df87aca..e7ccc9f 100644
--- a/tianshou/highlevel/experiment.py
+++ b/tianshou/highlevel/experiment.py
@@ -337,7 +337,7 @@ class Experiment(ToStringMixin):
         ) -> None:
             collector = Collector(policy, env)
             collector.reset()
-            result = collector.collect(n_episode=num_episodes, render=render, eval_mode=True)
+            result = collector.collect(n_episode=num_episodes, render=render)
             assert result.returns_stat is not None  # for mypy
             assert result.lens_stat is not None  # for mypy
             log.info(
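With `eval_mode` gone, `BaseCollector.collect()` (hunk above) unconditionally wraps the rollout in `in_eval_mode(self.policy)` as a safety precaution. The helper itself is not shown in this diff; a plausible sketch, assuming it mirrors torch's train/eval toggling and restores the previous state on exit:

```python
from collections.abc import Iterator
from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def in_eval_mode(module: nn.Module) -> Iterator[None]:
    was_training = module.training
    module.eval()  # disable dropout / batch-norm updates during the rollout
    try:
        yield
    finally:
        module.train(was_training)  # restore whatever mode the caller had
```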
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index e566d0c..498a333 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -226,8 +226,18 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
         self.action_scaling = action_scaling
         self.action_bound_method = action_bound_method
         self.lr_scheduler = lr_scheduler
-        # whether the policy is in evaluation mode
-        self.is_eval = False  # TODO: remove in favor of kwarg in compute_action/forward?
+        self.is_within_training_step = False
+        """
+        Flag indicating whether we are currently within a training step, which encompasses data collection
+        for training and the policy update (gradient steps).
+
+        It can be used, for example, to control whether a flag enabling deterministic evaluation should
+        actually be applied, because within a training step we typically always want stochastic action
+        sampling (even if such a flag is enabled).
+
+        This flag should normally remain False and should be set to True only by the algorithm that
+        performs training steps.
+        """
         self._compile()

     @property
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index 7e731b1..8c79428 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -107,7 +107,7 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
     ) -> Batch:
         logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Categorical(logits=logits_BA)
-        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
+        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
         return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)

     def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py
index 4792db8..3ef82be 100644
--- a/tianshou/policy/modelfree/pg.py
+++ b/tianshou/policy/modelfree/pg.py
@@ -197,7 +197,7 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
         # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
         dist = self.dist_fn(action_dist_input_BD)

-        act_B = dist.mode if self.deterministic_eval and self.is_eval else dist.sample()
+        act_B = dist.mode if self.deterministic_eval and not self.is_within_training_step else dist.sample()
         # act is of dimension BA in continuous case and of dimension B in discrete
         result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
         return cast(DistBatchProtocol, result)
diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py
index a216cf9..25f2997 100644
--- a/tianshou/policy/modelfree/redq.py
+++ b/tianshou/policy/modelfree/redq.py
@@ -153,7 +153,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
     ) -> Batch:
         (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc_B, scale_B), 1)
-        if self.deterministic_eval and self.is_eval:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 3dbea75..a5a05c0 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -175,7 +175,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]):  # t
     ) -> DistLogProbBatchProtocol:
         (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
         dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
-        if self.deterministic_eval and self.is_eval:
+        if self.deterministic_eval and not self.is_within_training_step:
             act_B = dist.mode
         else:
             act_B = dist.rsample()
diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py
index 825c80c..213da7b 100644
--- a/tianshou/trainer/base.py
+++ b/tianshou/trainer/base.py
@@ -3,6 +3,7 @@
 import time
 from abc import ABC, abstractmethod
 from collections import defaultdict, deque
 from collections.abc import Callable
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import Optional, Tuple
@@ -407,23 +408,34 @@ class BaseTrainer(ABC):

         return test_stat, stop_fn_flag

+    @contextmanager
+    def _is_within_training_step_enabled(self, is_within_training_step: bool):
+        old_value = self.policy.is_within_training_step
+        try:
+            self.policy.is_within_training_step = is_within_training_step
+            yield
+        finally:
+            self.policy.is_within_training_step = old_value
+
     def training_step(self) -> Tuple[CollectStatsBase, Optional[TrainingStats], bool]:
-        should_stop_training = False
+        with self._is_within_training_step_enabled(True):

-        if self.train_collector is not None:
-            collect_stats = self._collect_training_data()
-            should_stop_training = self._test_in_train(collect_stats)
-        else:
-            collect_stats = CollectStatsBase(
-                n_collected_episodes=len(self.buffer),
-            )
+            should_stop_training = False

-        if not should_stop_training:
-            training_stats = self.policy_update_fn(collect_stats)
-        else:
-            training_stats = None
+            if self.train_collector is not None:
+                collect_stats = self._collect_training_data()
+                should_stop_training = self._test_in_train(collect_stats)
+            else:
+                collect_stats = CollectStatsBase(
+                    n_collected_episodes=len(self.buffer),
+                )

-        return collect_stats, training_stats, should_stop_training
+            if not should_stop_training:
+                training_stats = self.policy_update_fn(collect_stats)
+            else:
+                training_stats = None
+
+            return collect_stats, training_stats, should_stop_training

     def _collect_training_data(self) -> CollectStats:
         """Performs training data collection
@@ -434,10 +446,7 @@ class BaseTrainer(ABC):
         assert self.train_collector is not None
         if self.train_fn:
             self.train_fn(self.epoch, self.env_step)
-        collect_stats = self.train_collector.collect(
-            n_step=self.step_per_collect,
-            n_episode=self.episode_per_collect,
-        )
+        collect_stats = self.train_collector.collect(n_step=self.step_per_collect, n_episode=self.episode_per_collect)

         self.env_step += collect_stats.n_collected_steps
@@ -467,26 +476,28 @@ class BaseTrainer(ABC):
         """
         should_stop_training = False

-        if (
-            collect_stats.n_collected_episodes > 0
-            and self.test_in_train
-            and self.stop_fn
-            and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
-        ):
-            assert self.test_collector is not None
-            test_result = test_episode(
-                self.test_collector,
-                self.test_fn,
-                self.epoch,
-                self.episode_per_test,
-                self.logger,
-                self.env_step,
-            )
-            assert test_result.returns_stat is not None  # for mypy
-            if self.stop_fn(test_result.returns_stat.mean):
-                should_stop_training = True
-                self.best_reward = test_result.returns_stat.mean
-                self.best_reward_std = test_result.returns_stat.std
+        # Because we need to evaluate the policy, we temporarily leave the "is_within_training_step" semantics
+        with self._is_within_training_step_enabled(False):
+            if (
+                collect_stats.n_collected_episodes > 0
+                and self.test_in_train
+                and self.stop_fn
+                and self.stop_fn(collect_stats.returns_stat.mean)  # type: ignore
+            ):
+                assert self.test_collector is not None
+                test_result = test_episode(
+                    self.test_collector,
+                    self.test_fn,
+                    self.epoch,
+                    self.episode_per_test,
+                    self.logger,
+                    self.env_step,
+                )
+                assert test_result.returns_stat is not None  # for mypy
+                if self.stop_fn(test_result.returns_stat.mean):
+                    should_stop_training = True
+                    self.best_reward = test_result.returns_stat.mean
+                    self.best_reward_std = test_result.returns_stat.std

         return should_stop_training
diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py
index 0c2bf18..767e76d 100644
--- a/tianshou/trainer/utils.py
+++ b/tianshou/trainer/utils.py
@@ -27,7 +27,7 @@ def test_episode(
     collector.reset(reset_stats=False)
     if test_fn:
         test_fn(epoch, global_step)
-    result = collector.collect(n_episode=n_episode, eval_mode=True)
+    result = collector.collect(n_episode=n_episode)
     if reward_metric:  # TODO: move into collector
         rew = reward_metric(result.returns)
         result.returns = rew
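Taken together, the trainer hunks establish the flag's lifecycle: `training_step()` runs with `is_within_training_step = True`, `_test_in_train()` temporarily drops back to False for evaluation, and the `finally` clause guarantees restoration. A self-contained sketch of that save/set/restore pattern (`PolicyStub` is a stand-in for illustration, not tianshou API):

```python
from collections.abc import Iterator
from contextlib import contextmanager


class PolicyStub:
    is_within_training_step: bool = False


@contextmanager
def within_training_step(policy: PolicyStub, enabled: bool) -> Iterator[None]:
    old_value = policy.is_within_training_step
    try:
        policy.is_within_training_step = enabled
        yield
    finally:
        policy.is_within_training_step = old_value  # restored even on error


policy = PolicyStub()
with within_training_step(policy, True):       # training_step()
    assert policy.is_within_training_step
    with within_training_step(policy, False):  # _test_in_train() evaluation
        assert not policy.is_within_training_step
assert not policy.is_within_training_step
```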