update tests
parent 8cb17de190
commit 49c750fb09
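
Two changes repeat throughout this commit: watch-time evaluation moves from a mode toggle on the policy (policy.eval()) to a per-call flag on the collector (collect(..., is_eval=True)), and the CartPole tests migrate from the retired CartPole-v0 to CartPole-v1. A minimal sketch of the new watch idiom, assuming a trained tianshou policy; the env name and render value are placeholders:

    import gymnasium as gym

    from tianshou.data import Collector

    # minimal sketch, assuming `policy` is a trained tianshou policy
    env = gym.make("CartPole-v1")
    collector = Collector(policy, env)
    collector.reset()
    # evaluation behaviour is now requested per collect() call
    stats = collector.collect(n_episode=1, render=0.0, is_eval=True)
    print(stats)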
@@ -64,6 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:

 class TestPolicyBasics:
     def test_get_action(self, policy: PPOPolicy) -> None:
+        policy.is_eval = True
         sample_obs = torch.randn(obs_shape)
         policy.deterministic_eval = False
         actions = [policy.compute_action(sample_obs) for _ in range(10)]
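For contrast with the stochastic check above, a hypothetical deterministic variant (illustrative only, not part of the commit; obs_shape and the PPO fixture are the ones from this test file):

    policy.is_eval = True
    policy.deterministic_eval = True  # take the distribution's mode instead of sampling
    sample_obs = torch.randn(obs_shape)
    actions = [policy.compute_action(sample_obs) for _ in range(10)]
    # with deterministic evaluation, all ten actions should coincide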
@@ -138,9 +138,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -160,9 +160,8 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -195,9 +195,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(epoch_stat)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -169,9 +169,8 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -161,7 +161,6 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     assert stop_fn(result.best_reward)

     # here we define an imitation collector with a trivial policy
-    policy.eval()
     if args.task.startswith("Pendulum"):
         args.reward_threshold -= 50  # lower the goal
     il_net = Net(

@@ -160,10 +160,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(epoch_stat.info_stat)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
         collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -160,9 +160,8 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -4,12 +4,12 @@ import pprint

 import gymnasium as gym
 import numpy as np
 import pytest
 import torch
 from gymnasium.spaces import Box
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.policy.base import BasePolicy
 from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer

@@ -25,7 +25,7 @@ except ImportError:

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -60,29 +60,35 @@ def get_args() -> argparse.Namespace:
     return parser.parse_known_args()[0]

 @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
 def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
-    # if you want to use python vector env, please refer to other test scripts
-    train_envs = env = envpool.make(
-        args.task,
-        env_type="gymnasium",
-        num_envs=args.training_num,
-        seed=args.seed,
-    )
-    test_envs = envpool.make(
-        args.task,
-        env_type="gymnasium",
-        num_envs=args.test_num,
-        seed=args.seed,
-    )
-    args.state_shape = env.observation_space.shape or env.observation_space.n
-    args.action_shape = env.action_space.shape or env.action_space.n
-    if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
-        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)

+    if envpool is not None:
+        train_envs = env = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.training_num,
+            seed=args.seed,
+        )
+        test_envs = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.test_num,
+            seed=args.seed,
+        )
+    else:
+        env = gym.make(args.task)
+        train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
+        test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
+        train_envs.seed(args.seed)
+        test_envs.seed(args.seed)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    if args.reward_threshold is None:
+        default_reward_threshold = {"CartPole-v1": 195}
+        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
     # model
     net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
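The hunk above guards the env construction at runtime: with envpool present the vectorized envs come from envpool, otherwise from Python-level DummyVectorEnv. A condensed sketch of the guard, assuming the test's own imports (gym, envpool, DummyVectorEnv); the num_envs and seed values here are placeholders:

    if envpool is not None:
        # envpool builds the vectorized envs natively
        train_envs = env = envpool.make("CartPole-v1", env_type="gymnasium", num_envs=4, seed=0)
    else:
        # fallback: Python-level vectorization via DummyVectorEnv
        env = gym.make("CartPole-v1")
        train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
        train_envs.seed(0)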
@@ -145,14 +151,13 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

-    policy.eval()
     # here we define an imitation collector with a trivial policy
-    # if args.task == 'CartPole-v0':
+    # if args.task == 'CartPole-v1':
     #     env.spec.reward_threshold = 190  # lower the goal
     net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)

@@ -162,9 +167,23 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         optim=optim,
         action_space=env.action_space,
     )
+    if envpool is not None:
+        il_env = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.test_num,
+            seed=args.seed,
+        )
+    else:
+        il_env = SubprocVectorEnv(
+            [lambda: gym.make(args.task) for _ in range(args.test_num)],
+            context="fork",
+        )
+    il_env.seed(args.seed)
+
     il_test_collector = Collector(
         il_policy,
-        envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed),
+        il_env,
     )
     train_collector.reset()
     result = OffpolicyTrainer(

@@ -186,9 +205,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        il_policy.eval()
         collector = Collector(il_policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -148,11 +148,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         collector_stats.pprint_asdict()
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -68,7 +68,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -206,10 +206,10 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -24,7 +24,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -62,7 +62,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -159,10 +159,10 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -19,7 +19,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -55,7 +55,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -136,9 +136,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -67,7 +67,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -176,10 +176,10 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -67,7 +67,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -172,10 +172,10 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -51,7 +51,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -129,9 +129,9 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -23,7 +23,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -64,7 +64,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -156,9 +156,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -60,10 +60,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape

-    if args.task == "CartPole-v0" and env.spec:
+    if args.task == "CartPole-v1" and env.spec:
         env.spec.reward_threshold = 190  # lower the goal
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -161,10 +161,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -69,7 +69,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -223,10 +223,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -60,7 +60,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 170}  # lower the goal
+        default_reward_threshold = {"CartPole-v1": 170}  # lower the goal
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -147,9 +147,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -7,7 +7,7 @@ from tianshou.highlevel.env import (
 class DiscreteTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
         super().__init__(
-            task="CartPole-v0",
+            task="CartPole-v1",
             train_seed=42,
             test_seed=1337,
             venv_type=VectorEnvType.DUMMY,
@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -79,7 +79,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -202,10 +202,9 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -83,7 +83,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -194,9 +194,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -120,10 +120,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         test_envs.seed(args.seed)
         test_collector.reset()
-        stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
         stats.pprint_asdict()
     elif env.spec.reward_threshold:
         assert result.best_reward >= env.spec.reward_threshold
@@ -19,12 +19,12 @@ from tianshou.utils.space_info import SpaceInfo


 def expert_file_name() -> str:
-    return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl")
+    return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl")


 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -67,7 +67,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 190}
+        default_reward_threshold = {"CartPole-v1": 190}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size)
+    collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:
|
@ -189,9 +189,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
|
||||
policy.load_state_dict(
|
||||
torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
|
||||
)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.collect(n_episode=1, render=1 / 35)
|
||||
collector.collect(n_episode=1, render=1 / 35, is_eval=True)
|
||||
|
||||
# trainer
|
||||
result = OfflineTrainer(
|
||||
@ -213,9 +212,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
|
||||
if __name__ == "__main__":
|
||||
pprint.pprint(result)
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector_stats = collector.collect(n_episode=1, render=args.render)
|
||||
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
|
||||
print(collector_stats)
|
||||
|
||||
|
||||
|
@@ -210,9 +210,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(epoch_stat.info_stat)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_result = collector.collect(n_episode=1, render=args.render)
+        collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
         if collector_result.returns_stat and collector_result.lens_stat:
             print(
                 f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}",
@@ -25,7 +25,7 @@ else: # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)

@@ -61,7 +61,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 185}
+        default_reward_threshold = {"CartPole-v1": 185}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -169,10 +169,9 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -24,7 +24,7 @@ else: # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)

@@ -58,7 +58,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 170}
+        default_reward_threshold = {"CartPole-v1": 170}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -131,10 +131,9 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -25,7 +25,7 @@ else: # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--lr", type=float, default=7e-4)

@@ -56,7 +56,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 180}
+        default_reward_threshold = {"CartPole-v1": 180}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -135,9 +135,8 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -231,9 +231,8 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -198,9 +198,8 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(epoch_stat.info_stat)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -188,8 +188,7 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None
         "watching random agents, as loading pre-trained policies is currently not supported",
     )
     policy, _, _ = get_agents(args)
-    policy.eval()
     [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render)
+    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     result.pprint_asdict()

@@ -284,7 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None
         "watching random agents, as loading pre-trained policies is currently not supported",
     )
     policy, _, _ = get_agents(args)
-    policy.eval()
     collector = Collector(policy, env)
-    collector_result = collector.collect(n_episode=1, render=args.render)
+    collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     collector_result.pprint_asdict()

@@ -228,8 +228,7 @@ def watch(
 ) -> None:
     env = DummyVectorEnv([partial(get_env, render_mode="human")])
     policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
-    policy.eval()
     policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render)
+    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     result.pprint_asdict()
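The three hunks above apply the same change to the multi-agent examples. Condensed, the watch pattern they now share looks like this (names as in the diff; env and args come from the surrounding scripts):

    policy, _, _ = get_agents(args)
    # epsilon-greedy exploration is configured per sub-policy; evaluation is
    # then requested per collect() call rather than via policy.eval()
    [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
    collector = Collector(policy, env, exploration_noise=True)
    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
    result.pprint_asdict()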