diff --git a/examples/atari/atari_c51.py b/examples/atari/atari_c51.py index fc04a21..16694af 100644 --- a/examples/atari/atari_c51.py +++ b/examples/atari/atari_c51.py @@ -162,7 +162,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -175,14 +174,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_dqn.py b/examples/atari/atari_dqn.py index f237c5a..aed46a2 100644 --- a/examples/atari/atari_dqn.py +++ b/examples/atari/atari_dqn.py @@ -204,7 +204,6 @@ def main(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -217,14 +216,18 @@ def main(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_fqf.py b/examples/atari/atari_fqf.py index 127a14b..92a6140 100644 --- a/examples/atari/atari_fqf.py +++ b/examples/atari/atari_fqf.py @@ -175,7 +175,6 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -188,14 +187,18 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_iqn.py b/examples/atari/atari_iqn.py index 8b16252..3c4a695 100644 --- a/examples/atari/atari_iqn.py +++ b/examples/atari/atari_iqn.py @@ -172,7 +172,6 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -185,14 +184,18 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_ppo.py b/examples/atari/atari_ppo.py index f1a89ef..3c2c68d 100644 --- a/examples/atari/atari_ppo.py +++ b/examples/atari/atari_ppo.py @@ -229,7 +229,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -241,14 +240,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_qrdqn.py b/examples/atari/atari_qrdqn.py index dfb9641..c821fad 100644 --- a/examples/atari/atari_qrdqn.py +++ b/examples/atari/atari_qrdqn.py @@ -166,7 +166,6 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -179,14 +178,18 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_rainbow.py b/examples/atari/atari_rainbow.py index 7b341c0..0d14f5d 100644 --- a/examples/atari/atari_rainbow.py +++ b/examples/atari/atari_rainbow.py @@ -200,7 +200,6 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -215,14 +214,18 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: beta=args.beta, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/atari/atari_sac.py b/examples/atari/atari_sac.py index f06964c..a8d759a 100644 --- a/examples/atari/atari_sac.py +++ b/examples/atari/atari_sac.py @@ -216,7 +216,6 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -228,14 +227,18 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/box2d/acrobot_dualdqn.py b/examples/box2d/acrobot_dualdqn.py index 96e61b6..ed6a65b 100644 --- a/examples/box2d/acrobot_dualdqn.py +++ b/examples/box2d/acrobot_dualdqn.py @@ -144,11 +144,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/bipedal_bdq.py b/examples/box2d/bipedal_bdq.py index 8b1e8ca..da425e6 100644 --- a/examples/box2d/bipedal_bdq.py +++ b/examples/box2d/bipedal_bdq.py @@ -162,11 +162,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/bipedal_hardcore_sac.py b/examples/box2d/bipedal_hardcore_sac.py index 2c071bc..8dfbed6 100644 --- a/examples/box2d/bipedal_hardcore_sac.py +++ b/examples/box2d/bipedal_hardcore_sac.py @@ -207,10 +207,13 @@ def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/lunarlander_dqn.py b/examples/box2d/lunarlander_dqn.py index 9e5db58..4bd2896 100644 --- a/examples/box2d/lunarlander_dqn.py +++ b/examples/box2d/lunarlander_dqn.py @@ -141,11 +141,14 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/box2d/mcc_sac.py b/examples/box2d/mcc_sac.py index 5c09309..2576411 100644 --- a/examples/box2d/mcc_sac.py +++ b/examples/box2d/mcc_sac.py @@ -153,10 +153,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: if __name__ == "__main__": pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/discrete/discrete_dqn.py b/examples/discrete/discrete_dqn.py index 4f1a82b..3ba22a4 100644 --- a/examples/discrete/discrete_dqn.py +++ b/examples/discrete/discrete_dqn.py @@ -80,7 +80,6 @@ def main() -> None: print(f"Finished training in {result.timing.total_time} seconds") # watch performance - policy.eval() policy.set_eps(eps_test) collector = ts.data.Collector(policy, env, exploration_noise=True) collector.collect(n_episode=100, render=1 / 35) diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 2d013a0..3c10359 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -264,10 +264,13 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py index cd6ceec..01a89b6 100644 --- a/examples/mujoco/fetch_her_ddpg.py +++ b/examples/mujoco/fetch_her_ddpg.py @@ -238,10 +238,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) collector_stats.pprint_asdict() diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index ea6ab8f..0195702 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -219,10 +219,13 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py index b2a4087..98852dc 100755 --- a/examples/mujoco/mujoco_ddpg.py +++ b/examples/mujoco/mujoco_ddpg.py @@ -168,10 +168,13 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 8a379da..c2a193a 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -216,10 +216,13 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index 0004288..9e20c46 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -224,10 +224,13 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_redq.py b/examples/mujoco/mujoco_redq.py index ae46b22..791005c 100755 --- a/examples/mujoco/mujoco_redq.py +++ b/examples/mujoco/mujoco_redq.py @@ -196,10 +196,13 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index 109f1cc..8912ee6 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -196,10 +196,13 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py index a0bd567..6c0bbdc 100755 --- a/examples/mujoco/mujoco_sac.py +++ b/examples/mujoco/mujoco_sac.py @@ -190,10 +190,13 @@ def test_sac(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py index 6b6dfdc..8e0c315 100755 --- a/examples/mujoco/mujoco_td3.py +++ b/examples/mujoco/mujoco_td3.py @@ -188,10 +188,13 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index 2195933..96da0cf 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -221,10 +221,13 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: pprint.pprint(result) # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/atari_bcq.py b/examples/offline/atari_bcq.py index 1fc0dc7..80687d1 100644 --- a/examples/offline/atari_bcq.py +++ b/examples/offline/atari_bcq.py @@ -187,12 +187,11 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_cql.py b/examples/offline/atari_cql.py index 40d91c1..6107145 100644 --- a/examples/offline/atari_cql.py +++ b/examples/offline/atari_cql.py @@ -171,12 +171,11 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_crr.py b/examples/offline/atari_crr.py index a4b31c4..ad5c231 100644 --- a/examples/offline/atari_crr.py +++ b/examples/offline/atari_crr.py @@ -188,11 +188,10 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/atari_il.py b/examples/offline/atari_il.py index bb7822e..495b630 100644 --- a/examples/offline/atari_il.py +++ b/examples/offline/atari_il.py @@ -145,11 +145,10 @@ def test_il(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True) result.pprint_asdict() if args.watch: diff --git a/examples/offline/d4rl_bcq.py b/examples/offline/d4rl_bcq.py index 80b233c..9834247 100644 --- a/examples/offline/d4rl_bcq.py +++ b/examples/offline/d4rl_bcq.py @@ -206,9 +206,8 @@ def test_bcq() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -229,10 +228,13 @@ def test_bcq() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/d4rl_cql.py b/examples/offline/d4rl_cql.py index 7ca8ae2..4e6a127 100644 --- a/examples/offline/d4rl_cql.py +++ b/examples/offline/d4rl_cql.py @@ -344,9 +344,8 @@ def test_cql() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -367,10 +366,13 @@ def test_cql() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/d4rl_il.py b/examples/offline/d4rl_il.py index c2152a7..a8a6a66 100644 --- a/examples/offline/d4rl_il.py +++ b/examples/offline/d4rl_il.py @@ -142,9 +142,8 @@ def test_il() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -165,10 +164,13 @@ def test_il() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/offline/d4rl_td3_bc.py b/examples/offline/d4rl_td3_bc.py index 4d6159f..24138a4 100644 --- a/examples/offline/d4rl_td3_bc.py +++ b/examples/offline/d4rl_td3_bc.py @@ -191,9 +191,8 @@ def test_td3_bc() -> None: args.resume_path = os.path.join(log_path, "policy.pth") policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu"))) - policy.eval() collector = Collector(policy, env) - collector.collect(n_episode=1, render=1 / 35) + collector.collect(n_episode=1, render=1 / 35, is_eval=True) if not args.watch: replay_buffer = load_buffer_d4rl(args.expert_data_task) @@ -217,10 +216,13 @@ def test_td3_bc() -> None: watch() # Let's watch its performance! - policy.eval() test_envs.seed(args.seed) test_collector.reset() - collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) + collector_stats = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) print(collector_stats) diff --git a/examples/vizdoom/vizdoom_c51.py b/examples/vizdoom/vizdoom_c51.py index 4211585..3566051 100644 --- a/examples/vizdoom/vizdoom_c51.py +++ b/examples/vizdoom/vizdoom_c51.py @@ -168,7 +168,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() policy.set_eps(args.eps_test) test_envs.seed(args.seed) if args.save_buffer_name: @@ -181,14 +180,18 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: diff --git a/examples/vizdoom/vizdoom_ppo.py b/examples/vizdoom/vizdoom_ppo.py index f5abf0b..907b0eb 100644 --- a/examples/vizdoom/vizdoom_ppo.py +++ b/examples/vizdoom/vizdoom_ppo.py @@ -231,7 +231,6 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # watch agent's performance def watch() -> None: print("Setup test envs ...") - policy.eval() test_envs.seed(args.seed) if args.save_buffer_name: print(f"Generate buffer with size {args.buffer_size}") @@ -243,14 +242,18 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: stack_num=args.frames_stack, ) collector = Collector(policy, test_envs, buffer, exploration_noise=True) - result = collector.collect(n_step=args.buffer_size) + result = collector.collect(n_step=args.buffer_size, is_eval=True) print(f"Save buffer into {args.save_buffer_name}") # Unfortunately, pickle will cause oom with 1M buffer size buffer.save_hdf5(args.save_buffer_name) else: print("Testing agent ...") test_collector.reset() - result = test_collector.collect(n_episode=args.test_num, render=args.render) + result = test_collector.collect( + n_episode=args.test_num, + render=args.render, + is_eval=True, + ) result.pprint_asdict() if args.watch: