update tests: pass is_eval=True to Collector.collect instead of calling policy.eval(), migrate CartPole-v0 defaults to CartPole-v1, add explicit collector.reset() calls, and fall back to gymnasium vector envs when envpool is unavailable

Maximilian Huettenrauch 2024-04-24 17:06:59 +02:00
parent 8cb17de190
commit 49c750fb09
35 changed files with 129 additions and 127 deletions
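
The change repeated across the hunks below is that evaluation behaviour is now requested per collect call rather than by switching the policy module into eval mode beforehand. A minimal sketch of the new pattern, assuming `policy` is an already-constructed Tianshou policy (not defined here) and using only the Collector calls that appear in this diff:

# minimal sketch; `policy` is assumed to be a trained Tianshou policy created elsewhere
import gymnasium as gym

from tianshou.data import Collector

env = gym.make("CartPole-v1")
collector = Collector(policy, env)
collector.reset()
# evaluation is selected per call instead of via policy.eval()
stats = collector.collect(n_episode=1, is_eval=True)
print(stats)

For epsilon-greedy policies the tests still call policy.set_eps(args.eps_test) before collecting, as the DQN-family hunks show.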

View File

@@ -64,6 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:
class TestPolicyBasics:
def test_get_action(self, policy: PPOPolicy) -> None:
policy.is_eval = True
sample_obs = torch.randn(obs_shape)
policy.deterministic_eval = False
actions = [policy.compute_action(sample_obs) for _ in range(10)]

View File

@@ -138,9 +138,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -160,9 +160,8 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -195,9 +195,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(epoch_stat)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -169,9 +169,8 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -161,7 +161,6 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
assert stop_fn(result.best_reward)
# here we define an imitation collector with a trivial policy
policy.eval()
if args.task.startswith("Pendulum"):
args.reward_threshold -= 50 # lower the goal
il_net = Net(

View File

@@ -160,10 +160,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(epoch_stat.info_stat)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -160,9 +160,8 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -4,12 +4,12 @@ import pprint
import gymnasium as gym
import numpy as np
import pytest
import torch
from gymnasium.spaces import Box
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import A2CPolicy, ImitationPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer
@@ -25,7 +25,7 @@ except ImportError:
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--buffer-size", type=int, default=20000)
@@ -60,29 +60,35 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0]
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.training_num,
seed=args.seed,
)
test_envs = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.test_num,
seed=args.seed,
)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if envpool is not None:
train_envs = env = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.training_num,
seed=args.seed,
)
test_envs = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.test_num,
seed=args.seed,
)
else:
env = gym.make(args.task)
train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
train_envs.seed(args.seed)
test_envs.seed(args.seed)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
# model
net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
@@ -145,14 +151,13 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
policy.eval()
# here we define an imitation collector with a trivial policy
# if args.task == 'CartPole-v0':
# if args.task == 'CartPole-v1':
# env.spec.reward_threshold = 190 # lower the goal
net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device)
@@ -162,9 +167,23 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
optim=optim,
action_space=env.action_space,
)
if envpool is not None:
il_env = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.test_num,
seed=args.seed,
)
else:
il_env = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
context="fork",
)
il_env.seed(args.seed)
il_test_collector = Collector(
il_policy,
envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed),
il_env,
)
train_collector.reset()
result = OffpolicyTrainer(
@@ -186,9 +205,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
il_policy.eval()
collector = Collector(il_policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -148,11 +148,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
is_eval=True,
)
collector_stats.pprint_asdict()

View File

@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -68,7 +68,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -206,10 +206,10 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -24,7 +24,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -62,7 +62,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -159,10 +159,10 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -19,7 +19,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -55,7 +55,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -136,9 +136,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -67,7 +67,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -176,10 +176,10 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -67,7 +67,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -172,10 +172,10 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--buffer-size", type=int, default=20000)
@@ -51,7 +51,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -129,9 +129,9 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -23,7 +23,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--buffer-size", type=int, default=20000)
@@ -64,7 +64,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -156,9 +156,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -60,10 +60,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.task == "CartPole-v0" and env.spec:
if args.task == "CartPole-v1" and env.spec:
env.spec.reward_threshold = 190 # lower the goal
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -161,10 +161,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -69,7 +69,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -223,10 +223,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--buffer-size", type=int, default=20000)
@@ -60,7 +60,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 170} # lower the goal
default_reward_threshold = {"CartPole-v1": 170} # lower the goal
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -147,9 +147,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -7,7 +7,7 @@ from tianshou.highlevel.env import (
class DiscreteTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None:
super().__init__(
task="CartPole-v0",
task="CartPole-v1",
train_seed=42,
test_seed=1337,
venv_type=VectorEnvType.DUMMY,

View File

@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -79,7 +79,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -202,10 +202,9 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--buffer-size", type=int, default=20000)
@@ -83,7 +83,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -194,9 +194,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -120,10 +120,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
stats = test_collector.collect(n_episode=args.test_num, render=args.render)
stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
stats.pprint_asdict()
elif env.spec.reward_threshold:
assert result.best_reward >= env.spec.reward_threshold

View File

@@ -19,12 +19,12 @@ from tianshou.utils.space_info import SpaceInfo
def expert_file_name() -> str:
return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl")
return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl")
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05)
@@ -67,7 +67,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 190}
default_reward_threshold = {"CartPole-v1": 190}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
collector.reset()
collector_stats = collector.collect(n_step=args.buffer_size)
collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True)
if args.save_buffer_name.endswith(".hdf5"):
buf.save_hdf5(args.save_buffer_name)
else:

View File

@@ -189,9 +189,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
policy.load_state_dict(
torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
)
policy.eval()
collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35)
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
# trainer
result = OfflineTrainer(
@@ -213,9 +212,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__":
pprint.pprint(result)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -210,9 +210,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__":
pprint.pprint(epoch_stat.info_stat)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_result = collector.collect(n_episode=1, render=args.render)
collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
if collector_result.returns_stat and collector_result.lens_stat:
print(
f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}",

View File

@@ -25,7 +25,7 @@ else: # pytest
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
@@ -61,7 +61,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 185}
default_reward_threshold = {"CartPole-v1": 185}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -169,10 +169,9 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -24,7 +24,7 @@ else: # pytest
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
@@ -58,7 +58,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 170}
default_reward_threshold = {"CartPole-v1": 170}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -131,10 +131,9 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -25,7 +25,7 @@ else: # pytest
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0")
parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--lr", type=float, default=7e-4)
@@ -56,7 +56,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 180}
default_reward_threshold = {"CartPole-v1": 180}
args.reward_threshold = default_reward_threshold.get(
args.task,
env.spec.reward_threshold if env.spec else None,
@@ -135,9 +135,8 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -231,9 +231,8 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -198,9 +198,8 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__":
pprint.pprint(epoch_stat.info_stat)
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)

View File

@@ -188,8 +188,7 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
"watching random agents, as loading pre-trained policies is currently not supported",
)
policy, _, _ = get_agents(args)
policy.eval()
[agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)
result = collector.collect(n_episode=1, render=args.render, is_eval=True)
result.pprint_asdict()

View File

@@ -284,7 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
"watching random agents, as loading pre-trained policies is currently not supported",
)
policy, _, _ = get_agents(args)
policy.eval()
collector = Collector(policy, env)
collector_result = collector.collect(n_episode=1, render=args.render)
collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
collector_result.pprint_asdict()

View File

@@ -228,8 +228,7 @@ def watch(
) -> None:
env = DummyVectorEnv([partial(get_env, render_mode="human")])
policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
policy.eval()
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render)
result = collector.collect(n_episode=1, render=args.render, is_eval=True)
result.pprint_asdict()