diff --git a/test/base/test_policy.py b/test/base/test_policy.py
index 7c3aacc..f286156 100644
--- a/test/base/test_policy.py
+++ b/test/base/test_policy.py
@@ -64,6 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:

 class TestPolicyBasics:
     def test_get_action(self, policy: PPOPolicy) -> None:
+        policy.is_eval = True
         sample_obs = torch.randn(obs_shape)
         policy.deterministic_eval = False
         actions = [policy.compute_action(sample_obs) for _ in range(10)]
diff --git a/test/continuous/test_ddpg.py b/test/continuous/test_ddpg.py
index a17c3b5..4b776ce 100644
--- a/test/continuous/test_ddpg.py
+++ b/test/continuous/test_ddpg.py
@@ -138,9 +138,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py
index 8e0a50d..c60aa2b 100644
--- a/test/continuous/test_npg.py
+++ b/test/continuous/test_npg.py
@@ -160,9 +160,8 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py
index 5a522de..33bcc55 100644
--- a/test/continuous/test_ppo.py
+++ b/test/continuous/test_ppo.py
@@ -195,9 +195,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(epoch_stat)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py
index 697b59e..a97b443 100644
--- a/test/continuous/test_redq.py
+++ b/test/continuous/test_redq.py
@@ -169,9 +169,8 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/continuous/test_sac_with_il.py b/test/continuous/test_sac_with_il.py
index d13b03d..0bf80e5 100644
--- a/test/continuous/test_sac_with_il.py
+++ b/test/continuous/test_sac_with_il.py
@@ -161,7 +161,6 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     assert stop_fn(result.best_reward)

     # here we define an imitation collector with a trivial policy
-    policy.eval()
     if args.task.startswith("Pendulum"):
         args.reward_threshold -= 50  # lower the goal
     il_net = Net(
diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py
index ea55da0..d5efc57 100644
--- a/test/continuous/test_td3.py
+++ b/test/continuous/test_td3.py
@@ -160,10 +160,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(epoch_stat.info_stat)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
         collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py
index ae788d1..c0debf7 100644
--- a/test/continuous/test_trpo.py
+++ b/test/continuous/test_trpo.py
@@ -160,9 +160,8 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py
index f60857e..a2b9fb4 100644
--- a/test/discrete/test_a2c_with_il.py
+++ b/test/discrete/test_a2c_with_il.py
@@ -4,12 +4,12 @@ import pprint

 import gymnasium as gym
 import numpy as np
-import pytest
 import torch
 from gymnasium.spaces import Box
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.policy.base import BasePolicy
 from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer
@@ -25,7 +25,7 @@ except ImportError:

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)
@@ -60,29 +60,35 @@ def get_args() -> argparse.Namespace:
     return parser.parse_known_args()[0]


-@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
 def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
-    # if you want to use python vector env, please refer to other test scripts
-    train_envs = env = envpool.make(
-        args.task,
-        env_type="gymnasium",
-        num_envs=args.training_num,
-        seed=args.seed,
-    )
-    test_envs = envpool.make(
-        args.task,
-        env_type="gymnasium",
-        num_envs=args.test_num,
-        seed=args.seed,
-    )
-    args.state_shape = env.observation_space.shape or env.observation_space.n
-    args.action_shape = env.action_space.shape or env.action_space.n
-    if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
-        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
+
+    if envpool is not None:
+        train_envs = env = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.training_num,
+            seed=args.seed,
+        )
+        test_envs = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.test_num,
+            seed=args.seed,
+        )
+    else:
+        env = gym.make(args.task)
+        train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
+        test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
+        train_envs.seed(args.seed)
+        test_envs.seed(args.seed)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    if args.reward_threshold is None:
+        default_reward_threshold = {"CartPole-v1": 195}
+        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
     # model
     net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
@@ -145,14 +151,13 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

-    policy.eval()
     # here we define an imitation collector with a trivial policy
-    # if args.task == 'CartPole-v0':
+    # if args.task == 'CartPole-v1':
     #     env.spec.reward_threshold = 190  # lower the goal
     net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
@@ -162,9 +167,23 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         optim=optim,
         action_space=env.action_space,
     )
+    if envpool is not None:
+        il_env = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.test_num,
+            seed=args.seed,
+        )
+    else:
+        il_env = SubprocVectorEnv(
+            [lambda: gym.make(args.task) for _ in range(args.test_num)],
+            context="fork",
+        )
+    il_env.seed(args.seed)
+
     il_test_collector = Collector(
         il_policy,
-        envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed),
+        il_env,
     )
     train_collector.reset()
     result = OffpolicyTrainer(
@@ -186,9 +205,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        il_policy.eval()
         collector = Collector(il_policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py
index 1089d4b..6c93b02 100644
--- a/test/discrete/test_bdq.py
+++ b/test/discrete/test_bdq.py
@@ -148,11 +148,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         collector_stats.pprint_asdict()
diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py
index 4d25d43..3abd45b 100644
--- a/test/discrete/test_c51.py
+++ b/test/discrete/test_c51.py
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -68,7 +68,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -206,10 +206,10 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py
index b62a93c..cd8b3f5 100644
--- a/test/discrete/test_dqn.py
+++ b/test/discrete/test_dqn.py
@@ -24,7 +24,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -62,7 +62,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -159,10 +159,10 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py
index 5c24518..bae341f 100644
--- a/test/discrete/test_drqn.py
+++ b/test/discrete/test_drqn.py
@@ -19,7 +19,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -55,7 +55,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -136,9 +136,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py
index 8ff9eeb..399672e 100644
--- a/test/discrete/test_fqf.py
+++ b/test/discrete/test_fqf.py
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -67,7 +67,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -176,10 +176,10 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py
index 765bbf9..f11ea46 100644
--- a/test/discrete/test_iqn.py
+++ b/test/discrete/test_iqn.py
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -67,7 +67,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -172,10 +172,10 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_pg.py b/test/discrete/test_pg.py
index 95db43c..51142eb 100644
--- a/test/discrete/test_pg.py
+++ b/test/discrete/test_pg.py
@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)
@@ -51,7 +51,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -129,9 +129,9 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 132cbea..03f6f8d 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -23,7 +23,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--buffer-size", type=int, default=20000)
@@ -64,7 +64,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -156,9 +156,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py
index 6485637..6d39d9a 100644
--- a/test/discrete/test_qrdqn.py
+++ b/test/discrete/test_qrdqn.py
@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -60,10 +60,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape

-    if args.task == "CartPole-v0" and env.spec:
+    if args.task == "CartPole-v1" and env.spec:
         env.spec.reward_threshold = 190  # lower the goal
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -161,10 +161,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py
index ff4ef1c..0dabd9e 100644
--- a/test/discrete/test_rainbow.py
+++ b/test/discrete/test_rainbow.py
@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -69,7 +69,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -223,10 +223,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/discrete/test_sac.py b/test/discrete/test_sac.py
index b2f466f..cdc53d7 100644
--- a/test/discrete/test_sac.py
+++ b/test/discrete/test_sac.py
@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)
@@ -60,7 +60,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 170}  # lower the goal
+        default_reward_threshold = {"CartPole-v1": 170}  # lower the goal
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -147,9 +147,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/highlevel/env_factory.py b/test/highlevel/env_factory.py
index ddfce7b..4a131e5 100644
--- a/test/highlevel/env_factory.py
+++ b/test/highlevel/env_factory.py
@@ -7,7 +7,7 @@ from tianshou.highlevel.env import (
 class DiscreteTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
         super().__init__(
-            task="CartPole-v0",
+            task="CartPole-v1",
             train_seed=42,
             test_seed=1337,
             venv_type=VectorEnvType.DUMMY,
diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py
index 9a4206e..01e6572 100644
--- a/test/modelbased/test_dqn_icm.py
+++ b/test/modelbased/test_dqn_icm.py
@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -79,7 +79,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -202,10 +202,9 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/modelbased/test_ppo_icm.py b/test/modelbased/test_ppo_icm.py
index ebf93cd..fcb541e 100644
--- a/test/modelbased/test_ppo_icm.py
+++ b/test/modelbased/test_ppo_icm.py
@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--buffer-size", type=int, default=20000)
@@ -83,7 +83,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -194,9 +194,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py
index 72742b7..ca7a388 100644
--- a/test/modelbased/test_psrl.py
+++ b/test/modelbased/test_psrl.py
@@ -120,10 +120,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         test_envs.seed(args.seed)
         test_collector.reset()
-        stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
         stats.pprint_asdict()
     elif env.spec.reward_threshold:
         assert result.best_reward >= env.spec.reward_threshold
diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py
index 9387794..91ee284 100644
--- a/test/offline/gather_cartpole_data.py
+++ b/test/offline/gather_cartpole_data.py
@@ -19,12 +19,12 @@ from tianshou.utils.space_info import SpaceInfo


 def expert_file_name() -> str:
-    return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl")
+    return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl")


 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)
@@ -67,7 +67,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 190}
+        default_reward_threshold = {"CartPole-v1": 190}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size)
+    collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:
diff --git a/test/offline/test_bcq.py b/test/offline/test_bcq.py
index 1839d86..660c607 100644
--- a/test/offline/test_bcq.py
+++ b/test/offline/test_bcq.py
@@ -189,9 +189,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
         policy.load_state_dict(
             torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
         )
-        policy.eval()
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35)
+        collector.collect(n_episode=1, render=1 / 35, is_eval=True)

     # trainer
     result = OfflineTrainer(
@@ -213,9 +212,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/offline/test_cql.py b/test/offline/test_cql.py
index 1e31b1f..53aaf1e 100644
--- a/test/offline/test_cql.py
+++ b/test/offline/test_cql.py
@@ -210,9 +210,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(epoch_stat.info_stat)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_result = collector.collect(n_episode=1, render=args.render)
+        collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
         if collector_result.returns_stat and collector_result.lens_stat:
             print(
                 f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}",
diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py
index 7779080..e151633 100644
--- a/test/offline/test_discrete_bcq.py
+++ b/test/offline/test_discrete_bcq.py
@@ -25,7 +25,7 @@ else:  # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)
@@ -61,7 +61,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 185}
+        default_reward_threshold = {"CartPole-v1": 185}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -169,10 +169,9 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py
index 7323eac..309f3b6 100644
--- a/test/offline/test_discrete_cql.py
+++ b/test/offline/test_discrete_cql.py
@@ -24,7 +24,7 @@ else:  # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)
@@ -58,7 +58,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 170}
+        default_reward_threshold = {"CartPole-v1": 170}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -131,10 +131,9 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py
index b3cb646..c3f2549 100644
--- a/test/offline/test_discrete_crr.py
+++ b/test/offline/test_discrete_crr.py
@@ -25,7 +25,7 @@ else:  # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--lr", type=float, default=7e-4)
@@ -56,7 +56,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 180}
+        default_reward_threshold = {"CartPole-v1": 180}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
@@ -135,9 +135,8 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py
index 256140c..df0f7bf 100644
--- a/test/offline/test_gail.py
+++ b/test/offline/test_gail.py
@@ -231,9 +231,8 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/offline/test_td3_bc.py b/test/offline/test_td3_bc.py
index 1877856..43762dd 100644
--- a/test/offline/test_td3_bc.py
+++ b/test/offline/test_td3_bc.py
@@ -198,9 +198,8 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(epoch_stat.info_stat)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py
index abd0c88..7fc9134 100644
--- a/test/pettingzoo/pistonball.py
+++ b/test/pettingzoo/pistonball.py
@@ -188,8 +188,7 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
             "watching random agents, as loading pre-trained policies is currently not supported",
         )
         policy, _, _ = get_agents(args)
-    policy.eval()
     [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render)
+    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     result.pprint_asdict()
diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py
index 54d6066..0100470 100644
--- a/test/pettingzoo/pistonball_continuous.py
+++ b/test/pettingzoo/pistonball_continuous.py
@@ -284,7 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
             "watching random agents, as loading pre-trained policies is currently not supported",
         )
         policy, _, _ = get_agents(args)
-    policy.eval()
     collector = Collector(policy, env)
-    collector_result = collector.collect(n_episode=1, render=args.render)
+    collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     collector_result.pprint_asdict()
diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py
index 7ed6319..b63636a 100644
--- a/test/pettingzoo/tic_tac_toe.py
+++ b/test/pettingzoo/tic_tac_toe.py
@@ -228,8 +228,7 @@ def watch(
 ) -> None:
     env = DummyVectorEnv([partial(get_env, render_mode="human")])
     policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
-    policy.eval()
     policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render)
+    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     result.pprint_asdict()
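
For reference, the evaluation pattern the hunks above converge on, written out as a stand-alone helper. This is a sketch only: the watch() name, its defaults, and the fact that the caller supplies an already-trained policy are illustrative assumptions; Collector, collector.reset(), and the is_eval keyword of Collector.collect are the pieces actually shown in the diff.

import gymnasium as gym

from tianshou.data import Collector
from tianshou.policy.base import BasePolicy


def watch(policy: BasePolicy, task: str = "CartPole-v1", render: float = 0.0) -> None:
    """Roll out one evaluation episode of an already-trained policy (hypothetical helper)."""
    env = gym.make(task)
    collector = Collector(policy, env)
    collector.reset()
    # evaluation behaviour is requested per collect() call via is_eval=True,
    # instead of toggling the torch module with policy.eval()
    stats = collector.collect(n_episode=1, render=render, is_eval=True)
    print(stats)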