update tests
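Two mechanical changes recur throughout the hunks below: every CartPole task id is bumped from "CartPole-v0" to "CartPole-v1", and the pre-collection policy.eval() calls are dropped in favour of passing is_eval=True to Collector.collect(). A condensed sketch of the resulting watch pattern, assuming a trained Tianshou policy (watch_once is hypothetical shorthand, not code from this commit):

import gymnasium as gym

from tianshou.data import Collector
from tianshou.policy.base import BasePolicy


def watch_once(policy: BasePolicy, task: str, render: float = 0.0) -> None:
    env = gym.make(task)
    collector = Collector(policy, env)
    collector.reset()
    # eval behaviour is requested per collect() call rather than by
    # toggling the module with policy.eval() beforehand
    collector_stats = collector.collect(n_episode=1, render=render, is_eval=True)
    print(collector_stats)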
parent 8cb17de190
commit 49c750fb09
@@ -64,6 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:
 class TestPolicyBasics:
     def test_get_action(self, policy: PPOPolicy) -> None:
+        policy.is_eval = True
         sample_obs = torch.randn(obs_shape)
         policy.deterministic_eval = False
         actions = [policy.compute_action(sample_obs) for _ in range(10)]
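The hunk above exercises the policy stochastically while in eval mode (is_eval=True together with deterministic_eval=False). A hypothetical sibling test could pin that behaviour down, assuming the module-level obs_shape and a discrete action space (hence the int() cast):

import torch


def test_get_action_is_stochastic(policy, obs_shape) -> None:
    policy.is_eval = True
    policy.deterministic_eval = False
    sample_obs = torch.randn(obs_shape)
    actions = [int(policy.compute_action(sample_obs)) for _ in range(10)]
    # sampled actions should not all coincide (fails only with tiny probability)
    assert len(set(actions)) > 1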
@@ -138,9 +138,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -160,9 +160,8 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -195,9 +195,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(epoch_stat)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -169,9 +169,8 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -161,7 +161,6 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
     assert stop_fn(result.best_reward)

     # here we define an imitation collector with a trivial policy
-    policy.eval()
     if args.task.startswith("Pendulum"):
         args.reward_threshold -= 50  # lower the goal
     il_net = Net(

@@ -160,10 +160,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(epoch_stat.info_stat)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
         collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -160,9 +160,8 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -4,12 +4,12 @@ import pprint

+import gymnasium as gym
 import numpy as np
-import pytest
 import torch
 from gymnasium.spaces import Box
 from torch.utils.tensorboard import SummaryWriter

 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 from tianshou.policy import A2CPolicy, ImitationPolicy
 from tianshou.policy.base import BasePolicy
 from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer

@@ -25,7 +25,7 @@ except ImportError:

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -60,29 +60,35 @@ def get_args() -> argparse.Namespace:
     return parser.parse_known_args()[0]


-@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
 def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
     # if you want to use python vector env, please refer to other test scripts
-    train_envs = env = envpool.make(
-        args.task,
-        env_type="gymnasium",
-        num_envs=args.training_num,
-        seed=args.seed,
-    )
-    test_envs = envpool.make(
-        args.task,
-        env_type="gymnasium",
-        num_envs=args.test_num,
-        seed=args.seed,
-    )
-    args.state_shape = env.observation_space.shape or env.observation_space.n
-    args.action_shape = env.action_space.shape or env.action_space.n
-    if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
-        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
     # seed
     np.random.seed(args.seed)
     torch.manual_seed(args.seed)
+
+    if envpool is not None:
+        train_envs = env = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.training_num,
+            seed=args.seed,
+        )
+        test_envs = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.test_num,
+            seed=args.seed,
+        )
+    else:
+        env = gym.make(args.task)
+        train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
+        test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
+        train_envs.seed(args.seed)
+        test_envs.seed(args.seed)
+    args.state_shape = env.observation_space.shape or env.observation_space.n
+    args.action_shape = env.action_space.shape or env.action_space.n
+    if args.reward_threshold is None:
+        default_reward_threshold = {"CartPole-v1": 195}
+        args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
     # model
     net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)
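The hunk above replaces the hard envpool dependency with a runtime fallback. Isolated, under the file's existing try/except import guard, the pattern is roughly the following (make_vector_env is an illustrative name, not from the commit):

import gymnasium as gym

from tianshou.env import DummyVectorEnv

try:
    import envpool
except ImportError:
    envpool = None


def make_vector_env(task: str, num_envs: int, seed: int):
    if envpool is not None:
        # EnvPool builds and seeds the vectorized env in one call
        return envpool.make(task, env_type="gymnasium", num_envs=num_envs, seed=seed)
    # fallback: plain gymnasium envs wrapped in a Tianshou vector env
    venv = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_envs)])
    venv.seed(seed)
    return venv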
@@ -145,14 +151,13 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

-    policy.eval()
     # here we define an imitation collector with a trivial policy
-    # if args.task == 'CartPole-v0':
+    # if args.task == 'CartPole-v1':
     #     env.spec.reward_threshold = 190  # lower the goal
     net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
     actor = Actor(net, args.action_shape, device=args.device).to(args.device)

@@ -162,9 +167,23 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         optim=optim,
         action_space=env.action_space,
     )
+    if envpool is not None:
+        il_env = envpool.make(
+            args.task,
+            env_type="gymnasium",
+            num_envs=args.test_num,
+            seed=args.seed,
+        )
+    else:
+        il_env = SubprocVectorEnv(
+            [lambda: gym.make(args.task) for _ in range(args.test_num)],
+            context="fork",
+        )
+        il_env.seed(args.seed)
+
     il_test_collector = Collector(
         il_policy,
-        envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed),
+        il_env,
     )
     train_collector.reset()
     result = OffpolicyTrainer(

@@ -186,9 +205,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        il_policy.eval()
         collector = Collector(il_policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -148,11 +148,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         policy.set_eps(args.eps_test)
         test_envs.seed(args.seed)
         test_collector.reset()
-        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        collector_stats = test_collector.collect(
+            n_episode=args.test_num,
+            render=args.render,
+            is_eval=True,
+        )
         collector_stats.pprint_asdict()
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -68,7 +68,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,
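This default_reward_threshold lookup recurs in nearly every file below. As a hypothetical standalone helper (dict contents mirror the hunk above) it reads:

def resolve_reward_threshold(task: str, env) -> float | None:
    # prefer the per-task default; otherwise fall back to the threshold
    # registered on the env spec, which may itself be absent
    defaults = {"CartPole-v1": 195}
    return defaults.get(task, env.spec.reward_threshold if env.spec else None)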
@@ -206,10 +206,10 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -24,7 +24,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -62,7 +62,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -159,10 +159,10 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -19,7 +19,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -55,7 +55,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -136,9 +136,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -67,7 +67,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -176,10 +176,10 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=0)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -67,7 +67,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -172,10 +172,10 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -51,7 +51,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -129,9 +129,9 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -23,7 +23,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -64,7 +64,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -156,9 +156,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -60,10 +60,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape

-    if args.task == "CartPole-v0" and env.spec:
+    if args.task == "CartPole-v1" and env.spec:
         env.spec.reward_threshold = 190  # lower the goal
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -161,10 +161,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -69,7 +69,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -223,10 +223,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -60,7 +60,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 170}  # lower the goal
+        default_reward_threshold = {"CartPole-v1": 170}  # lower the goal
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -147,9 +147,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector.reset()
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -7,7 +7,7 @@ from tianshou.highlevel.env import (
 class DiscreteTestEnvFactory(EnvFactoryRegistered):
     def __init__(self) -> None:
         super().__init__(
-            task="CartPole-v0",
+            task="CartPole-v1",
             train_seed=42,
             test_seed=1337,
             venv_type=VectorEnvType.DUMMY,

@@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -79,7 +79,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -202,10 +202,9 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--buffer-size", type=int, default=20000)

@@ -83,7 +83,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 195}
+        default_reward_threshold = {"CartPole-v1": 195}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -194,9 +194,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -120,10 +120,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         # Let's watch its performance!
-        policy.eval()
         test_envs.seed(args.seed)
         test_collector.reset()
-        stats = test_collector.collect(n_episode=args.test_num, render=args.render)
+        stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
         stats.pprint_asdict()
     elif env.spec.reward_threshold:
         assert result.best_reward >= env.spec.reward_threshold
@@ -19,12 +19,12 @@ from tianshou.utils.space_info import SpaceInfo


 def expert_file_name() -> str:
-    return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl")
+    return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl")


 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1)
     parser.add_argument("--eps-test", type=float, default=0.05)

@@ -67,7 +67,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     args.action_shape = space_info.action_info.action_shape

     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 190}
+        default_reward_threshold = {"CartPole-v1": 190}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
     policy.set_eps(0.2)
     collector = Collector(policy, test_envs, buf, exploration_noise=True)
     collector.reset()
-    collector_stats = collector.collect(n_step=args.buffer_size)
+    collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True)
     if args.save_buffer_name.endswith(".hdf5"):
         buf.save_hdf5(args.save_buffer_name)
     else:
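On the save branch above: save_hdf5 is the buffer's native format, while the else: arm is cut off by the hunk. A plausible sketch of that elided branch (with args and buf as in gather_data), assuming it pickles the buffer:

import pickle

if args.save_buffer_name.endswith(".hdf5"):
    buf.save_hdf5(args.save_buffer_name)
else:
    # assumed fallback: serialize the whole buffer object
    with open(args.save_buffer_name, "wb") as f:
        pickle.dump(buf, f)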
@@ -189,9 +189,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
         policy.load_state_dict(
             torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
         )
-        policy.eval()
         collector = Collector(policy, env)
-        collector.collect(n_episode=1, render=1 / 35)
+        collector.collect(n_episode=1, render=1 / 35, is_eval=True)

     # trainer
     result = OfflineTrainer(

@@ -213,9 +212,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(result)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -210,9 +210,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(epoch_stat.info_stat)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_result = collector.collect(n_episode=1, render=args.render)
+        collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
         if collector_result.returns_stat and collector_result.lens_stat:
             print(
                 f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}",
@@ -25,7 +25,7 @@ else:  # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)

@@ -61,7 +61,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 185}
+        default_reward_threshold = {"CartPole-v1": 185}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -169,10 +169,9 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -24,7 +24,7 @@ else:  # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--eps-test", type=float, default=0.001)

@@ -58,7 +58,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 170}
+        default_reward_threshold = {"CartPole-v1": 170}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -131,10 +131,9 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         policy.set_eps(args.eps_test)
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -25,7 +25,7 @@ else:  # pytest

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
-    parser.add_argument("--task", type=str, default="CartPole-v0")
+    parser.add_argument("--task", type=str, default="CartPole-v1")
     parser.add_argument("--reward-threshold", type=float, default=None)
     parser.add_argument("--seed", type=int, default=1626)
     parser.add_argument("--lr", type=float, default=7e-4)

@@ -56,7 +56,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
     args.state_shape = space_info.observation_info.obs_shape
     args.action_shape = space_info.action_info.action_shape
     if args.reward_threshold is None:
-        default_reward_threshold = {"CartPole-v0": 180}
+        default_reward_threshold = {"CartPole-v1": 180}
         args.reward_threshold = default_reward_threshold.get(
             args.task,
             env.spec.reward_threshold if env.spec else None,

@@ -135,9 +135,8 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -231,9 +231,8 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
         pprint.pprint(result)
         # Let's watch its performance!
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)

@@ -198,9 +198,8 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
     if __name__ == "__main__":
         pprint.pprint(epoch_stat.info_stat)
         env = gym.make(args.task)
-        policy.eval()
         collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render)
+        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
         print(collector_stats)
@@ -188,8 +188,7 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None
             "watching random agents, as loading pre-trained policies is currently not supported",
         )
         policy, _, _ = get_agents(args)
-    policy.eval()
     [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render)
+    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     result.pprint_asdict()

@@ -284,7 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = None
             "watching random agents, as loading pre-trained policies is currently not supported",
         )
         policy, _, _ = get_agents(args)
-    policy.eval()
     collector = Collector(policy, env)
-    collector_result = collector.collect(n_episode=1, render=args.render)
+    collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     collector_result.pprint_asdict()

@@ -228,8 +228,7 @@ def watch(
 ) -> None:
     env = DummyVectorEnv([partial(get_env, render_mode="human")])
     policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
-    policy.eval()
     policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
     collector = Collector(policy, env, exploration_noise=True)
-    result = collector.collect(n_episode=1, render=args.render)
+    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
     result.pprint_asdict()
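Condensing the three multi-agent watch() hunks above: exploration epsilon remains a per-sub-policy setting, while eval behaviour now rides on the collect call. A sketch with the files' helpers passed in explicitly (get_env, get_agents, and args are assumed to behave as in those files):

from functools import partial

from tianshou.data import Collector
from tianshou.env import DummyVectorEnv


def watch_marl(args, get_env, get_agents) -> None:
    env = DummyVectorEnv([partial(get_env, render_mode="human")])
    policy, _, agents = get_agents(args)
    # epsilon for the watched agent is still configured on its sub-policy
    policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
    collector = Collector(policy, env, exploration_noise=True)
    # eval behaviour requested on the call itself, not via policy.eval()
    result = collector.collect(n_episode=1, render=args.render, is_eval=True)
    result.pprint_asdict()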