update tests

This commit is contained in:
Maximilian Huettenrauch 2024-04-24 17:06:59 +02:00
parent 8cb17de190
commit 49c750fb09
35 changed files with 129 additions and 127 deletions

View File

@ -64,6 +64,7 @@ def policy(request: pytest.FixtureRequest) -> PPOPolicy:
class TestPolicyBasics: class TestPolicyBasics:
def test_get_action(self, policy: PPOPolicy) -> None: def test_get_action(self, policy: PPOPolicy) -> None:
policy.is_eval = True
sample_obs = torch.randn(obs_shape) sample_obs = torch.randn(obs_shape)
policy.deterministic_eval = False policy.deterministic_eval = False
actions = [policy.compute_action(sample_obs) for _ in range(10)] actions = [policy.compute_action(sample_obs) for _ in range(10)]

View File

@ -138,9 +138,8 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -160,9 +160,8 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -195,9 +195,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(epoch_stat) pprint.pprint(epoch_stat)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -169,9 +169,8 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -161,7 +161,6 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
assert stop_fn(result.best_reward) assert stop_fn(result.best_reward)
# here we define an imitation collector with a trivial policy # here we define an imitation collector with a trivial policy
policy.eval()
if args.task.startswith("Pendulum"): if args.task.startswith("Pendulum"):
args.reward_threshold -= 50 # lower the goal args.reward_threshold -= 50 # lower the goal
il_net = Net( il_net = Net(

View File

@ -160,10 +160,9 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(epoch_stat.info_stat) pprint.pprint(epoch_stat.info_stat)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector.reset() collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -160,9 +160,8 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -4,12 +4,12 @@ import pprint
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
import pytest
import torch import torch
from gymnasium.spaces import Box from gymnasium.spaces import Box
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import A2CPolicy, ImitationPolicy from tianshou.policy import A2CPolicy, ImitationPolicy
from tianshou.policy.base import BasePolicy from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer from tianshou.trainer import OffpolicyTrainer, OnpolicyTrainer
@ -25,7 +25,7 @@ except ImportError:
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--buffer-size", type=int, default=20000) parser.add_argument("--buffer-size", type=int, default=20000)
@ -60,29 +60,35 @@ def get_args() -> argparse.Namespace:
return parser.parse_known_args()[0] return parser.parse_known_args()[0]
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
# if you want to use python vector env, please refer to other test scripts
train_envs = env = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.training_num,
seed=args.seed,
)
test_envs = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.test_num,
seed=args.seed,
)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195}
args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
# seed # seed
np.random.seed(args.seed) np.random.seed(args.seed)
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
if envpool is not None:
train_envs = env = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.training_num,
seed=args.seed,
)
test_envs = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.test_num,
seed=args.seed,
)
else:
env = gym.make(args.task)
train_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.training_num)])
test_envs = DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
train_envs.seed(args.seed)
test_envs.seed(args.seed)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get(args.task, env.spec.reward_threshold)
# model # model
net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device)
@ -145,14 +151,13 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)
policy.eval()
# here we define an imitation collector with a trivial policy # here we define an imitation collector with a trivial policy
# if args.task == 'CartPole-v0': # if args.task == 'CartPole-v1':
# env.spec.reward_threshold = 190 # lower the goal # env.spec.reward_threshold = 190 # lower the goal
net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device) net = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(net, args.action_shape, device=args.device).to(args.device) actor = Actor(net, args.action_shape, device=args.device).to(args.device)
@ -162,9 +167,23 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
optim=optim, optim=optim,
action_space=env.action_space, action_space=env.action_space,
) )
if envpool is not None:
il_env = envpool.make(
args.task,
env_type="gymnasium",
num_envs=args.test_num,
seed=args.seed,
)
else:
il_env = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
context="fork",
)
il_env.seed(args.seed)
il_test_collector = Collector( il_test_collector = Collector(
il_policy, il_policy,
envpool.make(args.task, env_type="gymnasium", num_envs=args.test_num, seed=args.seed), il_env,
) )
train_collector.reset() train_collector.reset()
result = OffpolicyTrainer( result = OffpolicyTrainer(
@ -186,9 +205,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
il_policy.eval()
collector = Collector(il_policy, env) collector = Collector(il_policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -148,11 +148,14 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__": if __name__ == "__main__":
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
test_envs.seed(args.seed) test_envs.seed(args.seed)
test_collector.reset() test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render) collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
is_eval=True,
)
collector_stats.pprint_asdict() collector_stats.pprint_asdict()

View File

@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -68,7 +68,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -206,10 +206,10 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -24,7 +24,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -62,7 +62,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -159,10 +159,10 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -19,7 +19,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -55,7 +55,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -136,9 +136,9 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -67,7 +67,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -176,10 +176,10 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -25,7 +25,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -67,7 +67,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -172,10 +172,10 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--buffer-size", type=int, default=20000) parser.add_argument("--buffer-size", type=int, default=20000)
@ -51,7 +51,7 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -129,9 +129,9 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -23,7 +23,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--buffer-size", type=int, default=20000) parser.add_argument("--buffer-size", type=int, default=20000)
@ -64,7 +64,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -156,9 +156,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -20,7 +20,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -60,10 +60,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.task == "CartPole-v0" and env.spec: if args.task == "CartPole-v1" and env.spec:
env.spec.reward_threshold = 190 # lower the goal env.spec.reward_threshold = 190 # lower the goal
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -161,10 +161,10 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -69,7 +69,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -223,10 +223,10 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--buffer-size", type=int, default=20000) parser.add_argument("--buffer-size", type=int, default=20000)
@ -60,7 +60,7 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 170} # lower the goal default_reward_threshold = {"CartPole-v1": 170} # lower the goal
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -147,9 +147,9 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -7,7 +7,7 @@ from tianshou.highlevel.env import (
class DiscreteTestEnvFactory(EnvFactoryRegistered): class DiscreteTestEnvFactory(EnvFactoryRegistered):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__( super().__init__(
task="CartPole-v0", task="CartPole-v1",
train_seed=42, train_seed=42,
test_seed=1337, test_seed=1337,
venv_type=VectorEnvType.DUMMY, venv_type=VectorEnvType.DUMMY,

View File

@ -21,7 +21,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -79,7 +79,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -202,10 +202,9 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -22,7 +22,7 @@ from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--buffer-size", type=int, default=20000) parser.add_argument("--buffer-size", type=int, default=20000)
@ -83,7 +83,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 195} default_reward_threshold = {"CartPole-v1": 195}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -194,9 +194,8 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -120,10 +120,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__": if __name__ == "__main__":
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
policy.eval()
test_envs.seed(args.seed) test_envs.seed(args.seed)
test_collector.reset() test_collector.reset()
stats = test_collector.collect(n_episode=args.test_num, render=args.render) stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
stats.pprint_asdict() stats.pprint_asdict()
elif env.spec.reward_threshold: elif env.spec.reward_threshold:
assert result.best_reward >= env.spec.reward_threshold assert result.best_reward >= env.spec.reward_threshold

View File

@ -19,12 +19,12 @@ from tianshou.utils.space_info import SpaceInfo
def expert_file_name() -> str: def expert_file_name() -> str:
return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v0.pkl") return os.path.join(os.path.dirname(__file__), "expert_QRDQN_CartPole-v1.pkl")
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1) parser.add_argument("--seed", type=int, default=1)
parser.add_argument("--eps-test", type=float, default=0.05) parser.add_argument("--eps-test", type=float, default=0.05)
@ -67,7 +67,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 190} default_reward_threshold = {"CartPole-v1": 190}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -167,7 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
policy.set_eps(0.2) policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True) collector = Collector(policy, test_envs, buf, exploration_noise=True)
collector.reset() collector.reset()
collector_stats = collector.collect(n_step=args.buffer_size) collector_stats = collector.collect(n_step=args.buffer_size, is_eval=True)
if args.save_buffer_name.endswith(".hdf5"): if args.save_buffer_name.endswith(".hdf5"):
buf.save_hdf5(args.save_buffer_name) buf.save_hdf5(args.save_buffer_name)
else: else:

View File

@ -189,9 +189,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
policy.load_state_dict( policy.load_state_dict(
torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")), torch.load(os.path.join(log_path, "policy.pth"), map_location=torch.device("cpu")),
) )
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector.collect(n_episode=1, render=1 / 35) collector.collect(n_episode=1, render=1 / 35, is_eval=True)
# trainer # trainer
result = OfflineTrainer( result = OfflineTrainer(
@ -213,9 +212,8 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__": if __name__ == "__main__":
pprint.pprint(result) pprint.pprint(result)
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -210,9 +210,8 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__": if __name__ == "__main__":
pprint.pprint(epoch_stat.info_stat) pprint.pprint(epoch_stat.info_stat)
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_result = collector.collect(n_episode=1, render=args.render) collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
if collector_result.returns_stat and collector_result.lens_stat: if collector_result.returns_stat and collector_result.lens_stat:
print( print(
f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}", f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}",

View File

@ -25,7 +25,7 @@ else: # pytest
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--eps-test", type=float, default=0.001)
@ -61,7 +61,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 185} default_reward_threshold = {"CartPole-v1": 185}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -169,10 +169,9 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -24,7 +24,7 @@ else: # pytest
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001) parser.add_argument("--eps-test", type=float, default=0.001)
@ -58,7 +58,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 170} default_reward_threshold = {"CartPole-v1": 170}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -131,10 +131,9 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
policy.set_eps(args.eps_test) policy.set_eps(args.eps_test)
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -25,7 +25,7 @@ else: # pytest
def get_args() -> argparse.Namespace: def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="CartPole-v0") parser.add_argument("--task", type=str, default="CartPole-v1")
parser.add_argument("--reward-threshold", type=float, default=None) parser.add_argument("--reward-threshold", type=float, default=None)
parser.add_argument("--seed", type=int, default=1626) parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--lr", type=float, default=7e-4) parser.add_argument("--lr", type=float, default=7e-4)
@ -56,7 +56,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
args.state_shape = space_info.observation_info.obs_shape args.state_shape = space_info.observation_info.obs_shape
args.action_shape = space_info.action_info.action_shape args.action_shape = space_info.action_info.action_shape
if args.reward_threshold is None: if args.reward_threshold is None:
default_reward_threshold = {"CartPole-v0": 180} default_reward_threshold = {"CartPole-v1": 180}
args.reward_threshold = default_reward_threshold.get( args.reward_threshold = default_reward_threshold.get(
args.task, args.task,
env.spec.reward_threshold if env.spec else None, env.spec.reward_threshold if env.spec else None,
@ -135,9 +135,8 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -231,9 +231,8 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
pprint.pprint(result) pprint.pprint(result)
# Let's watch its performance! # Let's watch its performance!
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -198,9 +198,8 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
if __name__ == "__main__": if __name__ == "__main__":
pprint.pprint(epoch_stat.info_stat) pprint.pprint(epoch_stat.info_stat)
env = gym.make(args.task) env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats) print(collector_stats)

View File

@ -188,8 +188,7 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
"watching random agents, as loading pre-trained policies is currently not supported", "watching random agents, as loading pre-trained policies is currently not supported",
) )
policy, _, _ = get_agents(args) policy, _, _ = get_agents(args)
policy.eval()
[agent.set_eps(args.eps_test) for agent in policy.policies.values()] [agent.set_eps(args.eps_test) for agent in policy.policies.values()]
collector = Collector(policy, env, exploration_noise=True) collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render) result = collector.collect(n_episode=1, render=args.render, is_eval=True)
result.pprint_asdict() result.pprint_asdict()

View File

@ -284,7 +284,6 @@ def watch(args: argparse.Namespace = get_args(), policy: BasePolicy | None = Non
"watching random agents, as loading pre-trained policies is currently not supported", "watching random agents, as loading pre-trained policies is currently not supported",
) )
policy, _, _ = get_agents(args) policy, _, _ = get_agents(args)
policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector_result = collector.collect(n_episode=1, render=args.render) collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
collector_result.pprint_asdict() collector_result.pprint_asdict()

View File

@ -228,8 +228,7 @@ def watch(
) -> None: ) -> None:
env = DummyVectorEnv([partial(get_env, render_mode="human")]) env = DummyVectorEnv([partial(get_env, render_mode="human")])
policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent) policy, optim, agents = get_agents(args, agent_learn=agent_learn, agent_opponent=agent_opponent)
policy.eval()
policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test) policy.policies[agents[args.agent_id - 1]].set_eps(args.eps_test)
collector = Collector(policy, env, exploration_noise=True) collector = Collector(policy, env, exploration_noise=True)
result = collector.collect(n_episode=1, render=args.render) result = collector.collect(n_episode=1, render=args.render, is_eval=True)
result.pprint_asdict() result.pprint_asdict()