Tests: removed all instances of if __name__ == ... in tests

A test is not a script and should not be used as such. Also marked the pistonball test as skipped, since it doesn't actually test anything.
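The pattern being removed looks like the following; a minimal sketch with a hypothetical test_example.py, not a file from this commit. The __main__ block let the module be run as a script, bypassing pytest's discovery, fixtures, and markers:

def test_addition() -> None:
    assert 1 + 1 == 2


# Removed by this commit: running the module as a script
# (python test_example.py) silently skipped all pytest machinery.
if __name__ == "__main__":
    test_addition()

With the block gone, tests are selected on the pytest command line instead, e.g. pytest test_example.py or pytest test_example.py::test_addition.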
This commit is contained in:
parent 7d59302095
commit 12d4262f80
@@ -48,10 +48,3 @@ def test_shmem_vec_env_action_space() -> None:
     action2 = [ac_space.sample() for ac_space in envs.action_space]

     assert action1 == action2
-
-
-if __name__ == "__main__":
-    test_gym_env_action_space()
-    test_dummy_vec_env_action_space()
-    test_subproc_vec_env_action_space()
-    test_shmem_vec_env_action_space()
@@ -749,16 +749,3 @@ class TestToTorch:
         assert id_batch == id(batch)
         assert isinstance(batch.b, torch.Tensor)
         assert isinstance(batch.c.d, torch.Tensor)
-
-
-if __name__ == "__main__":
-    test_batch()
-    test_batch_over_batch()
-    test_batch_over_batch_to_torch()
-    test_utils_to_torch_numpy()
-    test_batch_pickle()
-    test_batch_from_to_numpy_without_copy()
-    test_batch_standard_compatibility()
-    test_batch_cat_and_stack()
-    test_batch_copy()
-    test_batch_empty()
@@ -1,7 +1,7 @@
 import os
 import pickle
 import tempfile
 from timeit import timeit
+from test.base.env import MoveToRightEnv, MyGoalEnv

 import h5py
 import numpy as np
@@ -22,11 +22,6 @@ from tianshou.data import (
 )
 from tianshou.data.utils.converter import to_hdf5

-if __name__ == "__main__":
-    from env import MoveToRightEnv, MyGoalEnv
-else:  # pytest
-    from test.base.env import MoveToRightEnv, MyGoalEnv
-

 def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None:
     env = MoveToRightEnv(size)
@@ -607,24 +602,6 @@ def test_segtree() -> None:
         index = tree.get_prefix_sum_idx(scalar)
         assert naive[:index].sum() <= scalar <= naive[: index + 1].sum()

-    # profile
-    if __name__ == "__main__":
-        size = 100000
-        bsz = 64
-        naive = np.random.rand(size)
-        tree = SegmentTree(size)
-        tree[np.arange(size)] = naive
-
-        def sample_npbuf() -> np.ndarray:
-            return np.random.choice(size, bsz, p=naive / naive.sum())
-
-        def sample_tree() -> int | np.ndarray:
-            scalar = np.random.rand(bsz) * tree.reduce()
-            return tree.get_prefix_sum_idx(scalar)
-
-        print("npbuf", timeit(sample_npbuf, setup=sample_npbuf, number=1000))
-        print("tree", timeit(sample_tree, setup=sample_tree, number=1000))
-

 def test_pickle() -> None:
     size = 100
@@ -1401,21 +1378,3 @@ def test_custom_key() -> None:
     ):
         assert batch.__dict__[key].is_empty()
         assert sampled_batch.__dict__[key].is_empty()
-
-
-if __name__ == "__main__":
-    test_replaybuffer()
-    test_ignore_obs_next()
-    test_stack()
-    test_segtree()
-    test_priortized_replaybuffer()
-    test_update()
-    test_pickle()
-    test_hdf5()
-    test_replaybuffermanager()
-    test_cachedbuffer()
-    test_multibuf_stack()
-    test_multibuf_hdf5()
-    test_from_data()
-    test_herreplaybuffer()
-    test_custom_key()
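A side effect of the cleanup in this file: the conditional import dance is gone. The old code had to import the helper envs differently depending on whether the module ran as a script or under pytest; with script execution gone, the single absolute import shown at the top of the diff is enough. A before/after sketch, using the names from the diff above:

# Before: two import paths, chosen at runtime by how the file was invoked.
if __name__ == "__main__":
    from env import MoveToRightEnv, MyGoalEnv  # script run from inside test/base/
else:  # pytest
    from test.base.env import MoveToRightEnv, MyGoalEnv  # pytest from the repo root

# After: tests only ever run under pytest from the repo root, so one import suffices.
from test.base.env import MoveToRightEnv, MyGoalEnv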
@@ -1,4 +1,5 @@
 from collections.abc import Callable, Sequence
+from test.base.env import MoveToRightEnv, NXEnv
 from typing import Any

 import gymnasium as gym
@@ -25,11 +26,6 @@ try:
 except ImportError:
     envpool = None

-if __name__ == "__main__":
-    from env import MoveToRightEnv, NXEnv
-else:  # pytest
-    from test.base.env import MoveToRightEnv, NXEnv
-

 class MaxActionPolicy(BasePolicy):
     def __init__(
@@ -963,13 +959,3 @@ def test_async_collector_with_vector_env() -> None:
     assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens)
     c2r = c1.collect(n_step=20)
     assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens)
-
-
-if __name__ == "__main__":
-    test_collector()
-    test_collector_with_dict_state()
-    test_collector_with_multi_agent()
-    test_collector_with_atari_setting()
-    test_collector_envpool_gym_reset_return_info()
-    test_collector_with_vector_env()
-    test_async_collector_with_vector_env()
@@ -1,6 +1,7 @@
 import sys
 import time
 from collections.abc import Callable
+from test.base.env import MoveToRightEnv, NXEnv
 from typing import Any, Literal

 import gymnasium as gym
@@ -22,11 +23,6 @@ from tianshou.env.gym_wrappers import TruncatedAsTerminated
 from tianshou.env.venvs import BaseVectorEnv
 from tianshou.utils import RunningMeanStd

-if __name__ == "__main__":
-    from env import MoveToRightEnv, NXEnv
-else:  # pytest
-    from test.base.env import MoveToRightEnv, NXEnv
-
 try:
     import envpool
 except ImportError:
@@ -190,19 +186,6 @@ def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None:
     for info in infos:
         assert recurse_comp(infos[0], info)
-
-    if __name__ == "__main__":
-        t = [0.0] * len(venv)
-        for i, e in enumerate(venv):
-            t[i] = time.time()
-            e.reset()
-            for a in action_list:
-                done = e.step(np.array([a] * num))[2]
-                if sum(done) > 0:
-                    e.reset(np.where(done)[0])
-            t[i] = time.time() - t[i]
-        for i, v in enumerate(venv):
-            print(f"{type(v)}: {t[i]:.6f}s")

 def assert_get(v: BaseVectorEnv, expected: list) -> None:
     assert v.get_env_attr("size") == expected
     assert v.get_env_attr("size", id=0) == [expected[0]]
@@ -437,17 +420,3 @@ def test_venv_wrapper_envpool_gym_reset_return_info() -> None:
     for _, v in _info.items():
         if not isinstance(v, dict):
             assert v.shape[0] == num_envs
-
-
-if __name__ == "__main__":
-    test_venv_norm_obs()
-    test_venv_wrapper_gym()
-    test_venv_wrapper_envpool()
-    test_venv_wrapper_envpool_gym_reset_return_info()
-    test_env_obs_dtype()
-    test_vecenv()
-    test_attr_unwrapped()
-    test_async_env()
-    test_async_check_id()
-    test_env_reset_optional_kwargs()
-    test_gym_wrappers()
@@ -268,8 +268,3 @@ def test_finite_subproc_vector_env() -> None:
         test_collector.collect(n_step=10**18)
     except StopIteration:
         envs.tracker.validate()
-
-
-if __name__ == "__main__":
-    test_finite_dummy_vector_env()
-    test_finite_subproc_vector_env()
@@ -1,10 +1,7 @@
-from timeit import timeit
-
 import numpy as np
 import torch

 from tianshou.data import Batch, ReplayBuffer, to_numpy
-from tianshou.data.types import BatchWithReturnsProtocol
 from tianshou.policy import BasePolicy


@@ -142,28 +139,6 @@ def test_episodic_returns(size: int = 2560) -> None:
     )
     assert np.allclose(returns, ground_truth)

-    if __name__ == "__main__":
-        buf = ReplayBuffer(size)
-        batch = Batch(
-            terminated=np.random.randint(100, size=size) == 0,
-            truncated=np.zeros(size),
-            rew=np.random.random(size),
-        )
-        for b in iter(batch):
-            b.obs = b.act = 1
-            buf.add(b)
-        indices = buf.sample_indices(0)
-
-        def vanilla() -> Batch:
-            return compute_episodic_return_base(batch, gamma=0.1)
-
-        def optimized() -> tuple[np.ndarray, np.ndarray]:
-            return fn(batch, buf, indices, gamma=0.1, gae_lambda=1.0)
-
-        cnt = 3000
-        print("GAE vanilla", timeit(vanilla, setup=vanilla, number=cnt))
-        print("GAE optim ", timeit(optimized, setup=optimized, number=cnt))
-

 def target_q_fn(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
     # return the next reward
@@ -356,41 +331,3 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None:
         ).pop("returns"),
     )
     assert np.allclose(returns_multidim, returns[:, np.newaxis])
-
-    if __name__ == "__main__":
-        buf = ReplayBuffer(size)
-        for i in range(int(size * 1.5)):
-            buf.add(
-                Batch(
-                    obs=0,
-                    act=0,
-                    rew=i + 1,
-                    terminated=np.random.randint(3) == 0,
-                    truncated=i % 33 == 0,
-                    info={},
-                ),
-            )
-        batch, indices = buf.sample(256)
-
-        def vanilla() -> np.ndarray:
-            return compute_nstep_return_base(3, 0.1, buf, indices)
-
-        def optimized() -> BatchWithReturnsProtocol:
-            return BasePolicy.compute_nstep_return(
-                batch,
-                buf,
-                indices,
-                target_q_fn,
-                gamma=0.1,
-                n_step=3,
-            )
-
-        cnt = 3000
-        print("nstep vanilla", timeit(vanilla, setup=vanilla, number=cnt))
-        print("nstep optim ", timeit(optimized, setup=optimized, number=cnt))
-
-
-if __name__ == "__main__":
-    test_nstep_returns()
-    test_nstep_returns_with_timelimit()
-    test_episodic_returns()
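The __main__ sections deleted in this file were timing comparisons, not assertions; nothing was tested by them. If the measurements are still wanted, the same timeit pattern works in a standalone benchmark script, where a __main__ guard is appropriate. A minimal sketch; the two measured callables are stand-ins, not tianshou code:

from timeit import timeit


def vanilla() -> float:
    # stand-in for the slow reference implementation being compared
    return sum(i * 0.1 for i in range(10_000))


def optimized() -> float:
    # stand-in for the vectorized implementation
    return 10_000 * (10_000 - 1) / 2 * 0.1


if __name__ == "__main__":  # fine here: this is a script, not a test module
    cnt = 3000
    # passing the callable as setup warms it up once before timing,
    # exactly as the removed benchmark code did
    print("vanilla", timeit(vanilla, setup=vanilla, number=cnt))
    print("optim  ", timeit(optimized, setup=optimized, number=cnt))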
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -133,15 +132,3 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_ddpg()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -155,15 +154,3 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_npg()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -191,19 +190,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:

     assert stop_fn(epoch_stat.info_stat.best_reward)

-    if __name__ == "__main__":
-        pprint.pprint(epoch_stat)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-

 def test_ppo_resume(args: argparse.Namespace = get_args()) -> None:
     args.resume = True
     test_ppo(args)
-
-
-if __name__ == "__main__":
-    test_ppo()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -164,15 +163,3 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_redq()
@@ -204,7 +204,3 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-
-if __name__ == "__main__":
-    test_sac_with_il()
@@ -155,16 +155,3 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
         # print(info)

     assert stop_fn(epoch_stat.info_stat.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(epoch_stat.info_stat)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_td3()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -155,15 +154,3 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_trpo()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -147,15 +146,6 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)

     # here we define an imitation collector with a trivial policy
     # if args.task == 'CartPole-v1':
     #     env.spec.reward_threshold = 190  # lower the goal
@@ -200,16 +190,3 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(il_policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_a2c_with_il()
@@ -1,5 +1,4 @@
 import argparse
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -129,7 +128,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
         return mean_rewards >= args.reward_threshold

     # trainer
-    result = OffpolicyTrainer(
+    OffpolicyTrainer(
         policy=policy,
         train_collector=train_collector,
         test_collector=test_collector,
@@ -143,21 +142,3 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
         test_fn=test_fn,
         stop_fn=stop_fn,
     ).run()
-
-    # assert stop_fn(result.best_reward)
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        policy.set_eps(args.eps_test)
-        test_envs.seed(args.seed)
-        test_collector.reset()
-        collector_stats = test_collector.collect(
-            n_episode=args.test_num,
-            render=args.render,
-            is_eval=True,
-        )
-        collector_stats.pprint_asdict()
-
-
-if __name__ == "__main__":
-    test_bdq(get_args())
@@ -1,7 +1,6 @@
 import argparse
 import os
 import pickle
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -202,16 +201,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)


 def test_c51_resume(args: argparse.Namespace = get_args()) -> None:
     args.resume = True
@@ -223,7 +212,3 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None:
     args.gamma = 0.95
     args.seed = 1
     test_c51(args)
-
-
-if __name__ == "__main__":
-    test_c51(get_args())
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -155,23 +154,9 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)


 def test_pdqn(args: argparse.Namespace = get_args()) -> None:
     args.prioritized_replay = True
     args.gamma = 0.95
     args.seed = 1
     test_dqn(args)
-
-
-if __name__ == "__main__":
-    test_dqn(get_args())
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -131,16 +130,3 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_drqn(get_args())
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -172,22 +171,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)


 def test_pfqf(args: argparse.Namespace = get_args()) -> None:
     args.prioritized_replay = True
     args.gamma = 0.95
     test_fqf(args)
-
-
-if __name__ == "__main__":
-    test_fqf(get_args())
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -168,22 +167,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)


 def test_piqn(args: argparse.Namespace = get_args()) -> None:
     args.prioritized_replay = True
     args.gamma = 0.95
     test_iqn(args)
-
-
-if __name__ == "__main__":
-    test_iqn(get_args())
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -124,16 +123,3 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_pg()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -151,16 +150,3 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_ppo()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -157,22 +156,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)


 def test_pqrdqn(args: argparse.Namespace = get_args()) -> None:
     args.prioritized_replay = True
     args.gamma = 0.95
     test_qrdqn(args)
-
-
-if __name__ == "__main__":
-    test_pqrdqn(get_args())
@@ -1,7 +1,6 @@
 import argparse
 import os
 import pickle
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -219,16 +218,6 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)


 def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None:
     args.resume = True
@@ -240,7 +229,3 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None:
     args.gamma = 0.95
     args.seed = 1
     test_rainbow(args)
-
-
-if __name__ == "__main__":
-    test_rainbow(get_args())
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -142,16 +141,3 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
         test_in_train=False,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector.reset()
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_discrete_sac()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -197,16 +196,3 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_dqn_icm(get_args())
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import gymnasium as gym
 import numpy as np
@@ -189,15 +188,3 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_ppo()
@@ -1,6 +1,5 @@
 import argparse
 import os
-import pprint

 import numpy as np
 import pytest
@@ -116,17 +115,4 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
         logger=logger,
         test_in_train=False,
     ).run()
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        test_envs.seed(args.seed)
-        test_collector.reset()
-        stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
-        stats.pprint_asdict()
-    elif env.spec.reward_threshold:
-        assert result.best_reward >= env.spec.reward_threshold
-
-
-if __name__ == "__main__":
-    test_psrl()
+    assert result.best_reward >= env.spec.reward_threshold
@@ -2,7 +2,7 @@ import argparse
 import datetime
 import os
 import pickle
 import pprint
+from test.offline.gather_pendulum_data import expert_file_name, gather_data

 import gymnasium as gym
 import numpy as np
@@ -19,11 +19,6 @@ from tianshou.utils.net.common import MLP, Net
 from tianshou.utils.net.continuous import VAE, Critic, Perturbation
 from tianshou.utils.space_info import SpaceInfo

-if __name__ == "__main__":
-    from gather_pendulum_data import expert_file_name, gather_data
-else:  # pytest
-    from test.offline.gather_pendulum_data import expert_file_name, gather_data
-

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -207,15 +202,3 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
         show_progress=args.show_progress,
     ).run()
     assert stop_fn(result.best_reward)
-
-    # Let's watch its performance!
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_bcq()
@@ -3,6 +3,7 @@ import datetime
 import os
 import pickle
 import pprint
+from test.offline.gather_pendulum_data import expert_file_name, gather_data

 import gymnasium as gym
 import numpy as np
@@ -19,11 +20,6 @@ from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb, Critic
 from tianshou.utils.space_info import SpaceInfo

-if __name__ == "__main__":
-    from gather_pendulum_data import expert_file_name, gather_data
-else:  # pytest
-    from test.offline.gather_pendulum_data import expert_file_name, gather_data
-

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -205,18 +201,3 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
         # print(info)

     assert stop_fn(epoch_stat.info_stat.best_reward)
-
-    # Let's watch its performance!
-    if __name__ == "__main__":
-        pprint.pprint(epoch_stat.info_stat)
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        if collector_result.returns_stat and collector_result.lens_stat:
-            print(
-                f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}",
-            )
-
-
-if __name__ == "__main__":
-    test_cql()
@@ -1,7 +1,7 @@
 import argparse
 import os
 import pickle
 import pprint
+from test.offline.gather_cartpole_data import expert_file_name, gather_data

 import gymnasium as gym
 import numpy as np
@@ -17,11 +17,6 @@ from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor
 from tianshou.utils.space_info import SpaceInfo

-if __name__ == "__main__":
-    from gather_cartpole_data import expert_file_name, gather_data
-else:  # pytest
-    from test.offline.gather_cartpole_data import expert_file_name, gather_data
-

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -165,21 +160,8 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
     ).run()
     assert stop_fn(result.best_reward)

-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)


 def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None:
-    test_discrete_bcq()
     args.resume = True
     test_discrete_bcq(args)
-
-
-if __name__ == "__main__":
-    test_discrete_bcq(get_args())
@@ -1,7 +1,7 @@
 import argparse
 import os
 import pickle
 import pprint
+from test.offline.gather_cartpole_data import expert_file_name, gather_data

 import gymnasium as gym
 import numpy as np
@@ -16,11 +16,6 @@ from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.utils.space_info import SpaceInfo

-if __name__ == "__main__":
-    from gather_cartpole_data import expert_file_name, gather_data
-else:  # pytest
-    from test.offline.gather_cartpole_data import expert_file_name, gather_data
-

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -126,16 +121,3 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
     ).run()

     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        policy.set_eps(args.eps_test)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_discrete_cql(get_args())
@@ -1,7 +1,7 @@
 import argparse
 import os
 import pickle
 import pprint
+from test.offline.gather_cartpole_data import expert_file_name, gather_data

 import gymnasium as gym
 import numpy as np
@@ -17,11 +17,6 @@ from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.discrete import Actor, Critic
 from tianshou.utils.space_info import SpaceInfo

-if __name__ == "__main__":
-    from gather_cartpole_data import expert_file_name, gather_data
-else:  # pytest
-    from test.offline.gather_cartpole_data import expert_file_name, gather_data
-

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -130,15 +125,3 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
     ).run()

     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_discrete_crr(get_args())
@@ -1,7 +1,7 @@
 import argparse
 import os
 import pickle
 import pprint
+from test.offline.gather_pendulum_data import expert_file_name, gather_data

 import gymnasium as gym
 import numpy as np
@@ -18,11 +18,6 @@ from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
 from tianshou.utils.space_info import SpaceInfo

-if __name__ == "__main__":
-    from gather_pendulum_data import expert_file_name, gather_data
-else:  # pytest
-    from test.offline.gather_pendulum_data import expert_file_name, gather_data
-

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -226,15 +221,3 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
         save_checkpoint_fn=save_checkpoint_fn,
     ).run()
     assert stop_fn(result.best_reward)
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_gail()
@@ -2,7 +2,7 @@ import argparse
 import datetime
 import os
 import pickle
 import pprint
+from test.offline.gather_pendulum_data import expert_file_name, gather_data

 import gymnasium as gym
 import numpy as np
@@ -20,11 +20,6 @@ from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import Actor, Critic
 from tianshou.utils.space_info import SpaceInfo

-if __name__ == "__main__":
-    from gather_pendulum_data import expert_file_name, gather_data
-else:  # pytest
-    from test.offline.gather_pendulum_data import expert_file_name, gather_data
-

 def get_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser()
@@ -193,15 +188,3 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
         # print(info)

     assert stop_fn(epoch_stat.info_stat.best_reward)
-
-    # Let's watch its performance!
-    if __name__ == "__main__":
-        pprint.pprint(epoch_stat.info_stat)
-        env = gym.make(args.task)
-        collector = Collector(policy, env)
-        collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
-        print(collector_stats)
-
-
-if __name__ == "__main__":
-    test_td3_bc()
@@ -1,22 +1,14 @@
 import argparse
-import pprint

+import pytest
 from pistonball import get_args, train_agent, watch


+@pytest.mark.skip(reason="Performance bound was never tested, no point in running this for now")
 def test_piston_ball(args: argparse.Namespace = get_args()) -> None:
     if args.watch:
         watch(args)
         return

-    result, agent = train_agent(args)
+    train_agent(args)
     # assert result.best_reward >= args.win_rate
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        watch(args, agent)
-
-
-if __name__ == "__main__":
-    test_piston_ball(get_args())
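For the pistonball test above, the commented-out reward assertion means the test asserts nothing, so it is now skipped rather than silently passing. @pytest.mark.skip keeps the test collected and reports the reason in pytest's summary; a minimal illustration with a hypothetical test, not code from this commit:

import pytest


@pytest.mark.skip(reason="Performance bound was never tested, no point in running this for now")
def test_placeholder() -> None:  # hypothetical example test
    raise AssertionError("never runs; pytest reports the test as skipped with the reason above")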
@@ -1,5 +1,4 @@
 import argparse
-import pprint

 import pytest
 from pistonball_continuous import get_args, train_agent, watch
@@ -13,12 +12,3 @@ def test_piston_ball_continuous(args: argparse.Namespace = get_args()) -> None:

     result, agent = train_agent(args)
     # assert result.best_reward >= 30.0
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        watch(args, agent)
-
-
-if __name__ == "__main__":
-    test_piston_ball_continuous(get_args())
@@ -1,5 +1,4 @@
 import argparse
-import pprint

 from tic_tac_toe import get_args, train_agent, watch

@@ -11,12 +10,3 @@ def test_tic_tac_toe(args: argparse.Namespace = get_args()) -> None:

     result, agent = train_agent(args)
     assert result.best_reward >= args.win_rate
-
-    if __name__ == "__main__":
-        pprint.pprint(result)
-        # Let's watch its performance!
-        watch(args, agent)
-
-
-if __name__ == "__main__":
-    test_tic_tac_toe(get_args())