Tests: removed all instances of if __name__ == ... in tests

A test is not a script and should not be used as such

Also marked pistonball test as skipped since it doesn't actually test anything
This commit is contained in:
Michael Panchenko 2024-04-26 14:58:58 +02:00
parent 7d59302095
commit 12d4262f80
39 changed files with 15 additions and 651 deletions

View File

@ -48,10 +48,3 @@ def test_shmem_vec_env_action_space() -> None:
action2 = [ac_space.sample() for ac_space in envs.action_space]
assert action1 == action2
if __name__ == "__main__":
test_gym_env_action_space()
test_dummy_vec_env_action_space()
test_subproc_vec_env_action_space()
test_shmem_vec_env_action_space()

View File

@ -749,16 +749,3 @@ class TestToTorch:
assert id_batch == id(batch)
assert isinstance(batch.b, torch.Tensor)
assert isinstance(batch.c.d, torch.Tensor)
if __name__ == "__main__":
test_batch()
test_batch_over_batch()
test_batch_over_batch_to_torch()
test_utils_to_torch_numpy()
test_batch_pickle()
test_batch_from_to_numpy_without_copy()
test_batch_standard_compatibility()
test_batch_cat_and_stack()
test_batch_copy()
test_batch_empty()

View File

@ -1,7 +1,7 @@
import os
import pickle
import tempfile
from timeit import timeit
from test.base.env import MoveToRightEnv, MyGoalEnv
import h5py
import numpy as np
@ -22,11 +22,6 @@ from tianshou.data import (
)
from tianshou.data.utils.converter import to_hdf5
if __name__ == "__main__":
from env import MoveToRightEnv, MyGoalEnv
else: # pytest
from test.base.env import MoveToRightEnv, MyGoalEnv
def test_replaybuffer(size: int = 10, bufsize: int = 20) -> None:
env = MoveToRightEnv(size)
@ -607,24 +602,6 @@ def test_segtree() -> None:
index = tree.get_prefix_sum_idx(scalar)
assert naive[:index].sum() <= scalar <= naive[: index + 1].sum()
# profile
if __name__ == "__main__":
size = 100000
bsz = 64
naive = np.random.rand(size)
tree = SegmentTree(size)
tree[np.arange(size)] = naive
def sample_npbuf() -> np.ndarray:
return np.random.choice(size, bsz, p=naive / naive.sum())
def sample_tree() -> int | np.ndarray:
scalar = np.random.rand(bsz) * tree.reduce()
return tree.get_prefix_sum_idx(scalar)
print("npbuf", timeit(sample_npbuf, setup=sample_npbuf, number=1000))
print("tree", timeit(sample_tree, setup=sample_tree, number=1000))
def test_pickle() -> None:
size = 100
@ -1401,21 +1378,3 @@ def test_custom_key() -> None:
):
assert batch.__dict__[key].is_empty()
assert sampled_batch.__dict__[key].is_empty()
if __name__ == "__main__":
test_replaybuffer()
test_ignore_obs_next()
test_stack()
test_segtree()
test_priortized_replaybuffer()
test_update()
test_pickle()
test_hdf5()
test_replaybuffermanager()
test_cachedbuffer()
test_multibuf_stack()
test_multibuf_hdf5()
test_from_data()
test_herreplaybuffer()
test_custom_key()

View File

@ -1,4 +1,5 @@
from collections.abc import Callable, Sequence
from test.base.env import MoveToRightEnv, NXEnv
from typing import Any
import gymnasium as gym
@ -25,11 +26,6 @@ try:
except ImportError:
envpool = None
if __name__ == "__main__":
from env import MoveToRightEnv, NXEnv
else: # pytest
from test.base.env import MoveToRightEnv, NXEnv
class MaxActionPolicy(BasePolicy):
def __init__(
@ -963,13 +959,3 @@ def test_async_collector_with_vector_env() -> None:
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens)
c2r = c1.collect(n_step=20)
assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens)
if __name__ == "__main__":
test_collector()
test_collector_with_dict_state()
test_collector_with_multi_agent()
test_collector_with_atari_setting()
test_collector_envpool_gym_reset_return_info()
test_collector_with_vector_env()
test_async_collector_with_vector_env()

View File

@ -1,6 +1,7 @@
import sys
import time
from collections.abc import Callable
from test.base.env import MoveToRightEnv, NXEnv
from typing import Any, Literal
import gymnasium as gym
@ -22,11 +23,6 @@ from tianshou.env.gym_wrappers import TruncatedAsTerminated
from tianshou.env.venvs import BaseVectorEnv
from tianshou.utils import RunningMeanStd
if __name__ == "__main__":
from env import MoveToRightEnv, NXEnv
else: # pytest
from test.base.env import MoveToRightEnv, NXEnv
try:
import envpool
except ImportError:
@ -190,19 +186,6 @@ def test_vecenv(size: int = 10, num: int = 8, sleep: float = 0.001) -> None:
for info in infos:
assert recurse_comp(infos[0], info)
if __name__ == "__main__":
t = [0.0] * len(venv)
for i, e in enumerate(venv):
t[i] = time.time()
e.reset()
for a in action_list:
done = e.step(np.array([a] * num))[2]
if sum(done) > 0:
e.reset(np.where(done)[0])
t[i] = time.time() - t[i]
for i, v in enumerate(venv):
print(f"{type(v)}: {t[i]:.6f}s")
def assert_get(v: BaseVectorEnv, expected: list) -> None:
assert v.get_env_attr("size") == expected
assert v.get_env_attr("size", id=0) == [expected[0]]
@ -437,17 +420,3 @@ def test_venv_wrapper_envpool_gym_reset_return_info() -> None:
for _, v in _info.items():
if not isinstance(v, dict):
assert v.shape[0] == num_envs
if __name__ == "__main__":
test_venv_norm_obs()
test_venv_wrapper_gym()
test_venv_wrapper_envpool()
test_venv_wrapper_envpool_gym_reset_return_info()
test_env_obs_dtype()
test_vecenv()
test_attr_unwrapped()
test_async_env()
test_async_check_id()
test_env_reset_optional_kwargs()
test_gym_wrappers()

View File

@ -268,8 +268,3 @@ def test_finite_subproc_vector_env() -> None:
test_collector.collect(n_step=10**18)
except StopIteration:
envs.tracker.validate()
if __name__ == "__main__":
test_finite_dummy_vector_env()
test_finite_subproc_vector_env()

View File

@ -1,10 +1,7 @@
from timeit import timeit
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy
from tianshou.data.types import BatchWithReturnsProtocol
from tianshou.policy import BasePolicy
@ -142,28 +139,6 @@ def test_episodic_returns(size: int = 2560) -> None:
)
assert np.allclose(returns, ground_truth)
if __name__ == "__main__":
buf = ReplayBuffer(size)
batch = Batch(
terminated=np.random.randint(100, size=size) == 0,
truncated=np.zeros(size),
rew=np.random.random(size),
)
for b in iter(batch):
b.obs = b.act = 1
buf.add(b)
indices = buf.sample_indices(0)
def vanilla() -> Batch:
return compute_episodic_return_base(batch, gamma=0.1)
def optimized() -> tuple[np.ndarray, np.ndarray]:
return fn(batch, buf, indices, gamma=0.1, gae_lambda=1.0)
cnt = 3000
print("GAE vanilla", timeit(vanilla, setup=vanilla, number=cnt))
print("GAE optim ", timeit(optimized, setup=optimized, number=cnt))
def target_q_fn(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
# return the next reward
@ -356,41 +331,3 @@ def test_nstep_returns_with_timelimit(size: int = 10000) -> None:
).pop("returns"),
)
assert np.allclose(returns_multidim, returns[:, np.newaxis])
if __name__ == "__main__":
buf = ReplayBuffer(size)
for i in range(int(size * 1.5)):
buf.add(
Batch(
obs=0,
act=0,
rew=i + 1,
terminated=np.random.randint(3) == 0,
truncated=i % 33 == 0,
info={},
),
)
batch, indices = buf.sample(256)
def vanilla() -> np.ndarray:
return compute_nstep_return_base(3, 0.1, buf, indices)
def optimized() -> BatchWithReturnsProtocol:
return BasePolicy.compute_nstep_return(
batch,
buf,
indices,
target_q_fn,
gamma=0.1,
n_step=3,
)
cnt = 3000
print("nstep vanilla", timeit(vanilla, setup=vanilla, number=cnt))
print("nstep optim ", timeit(optimized, setup=optimized, number=cnt))
if __name__ == "__main__":
test_nstep_returns()
test_nstep_returns_with_timelimit()
test_episodic_returns()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -133,15 +132,3 @@ def test_ddpg(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_ddpg()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -155,15 +154,3 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_npg()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -191,19 +190,7 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
assert stop_fn(epoch_stat.info_stat.best_reward)
if __name__ == "__main__":
pprint.pprint(epoch_stat)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_ppo_resume(args: argparse.Namespace = get_args()) -> None:
args.resume = True
test_ppo(args)
if __name__ == "__main__":
test_ppo()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -164,15 +163,3 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_redq()

View File

@ -204,7 +204,3 @@ def test_sac_with_il(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
test_sac_with_il()

View File

@ -155,16 +155,3 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
# print(info)
assert stop_fn(epoch_stat.info_stat.best_reward)
if __name__ == "__main__":
pprint.pprint(epoch_stat.info_stat)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_td3()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -155,15 +154,3 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_trpo()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -147,15 +146,6 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
# here we define an imitation collector with a trivial policy
# if args.task == 'CartPole-v1':
# env.spec.reward_threshold = 190 # lower the goal
@ -200,16 +190,3 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(il_policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_a2c_with_il()

View File

@ -1,5 +1,4 @@
import argparse
import pprint
import gymnasium as gym
import numpy as np
@ -129,7 +128,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
return mean_rewards >= args.reward_threshold
# trainer
result = OffpolicyTrainer(
OffpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
@ -143,21 +142,3 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
test_fn=test_fn,
stop_fn=stop_fn,
).run()
# assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(
n_episode=args.test_num,
render=args.render,
is_eval=True,
)
collector_stats.pprint_asdict()
if __name__ == "__main__":
test_bdq(get_args())

View File

@ -1,7 +1,6 @@
import argparse
import os
import pickle
import pprint
import gymnasium as gym
import numpy as np
@ -202,16 +201,6 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_c51_resume(args: argparse.Namespace = get_args()) -> None:
args.resume = True
@ -223,7 +212,3 @@ def test_pc51(args: argparse.Namespace = get_args()) -> None:
args.gamma = 0.95
args.seed = 1
test_c51(args)
if __name__ == "__main__":
test_c51(get_args())

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -155,23 +154,9 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_pdqn(args: argparse.Namespace = get_args()) -> None:
args.prioritized_replay = True
args.gamma = 0.95
args.seed = 1
test_dqn(args)
if __name__ == "__main__":
test_dqn(get_args())

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -131,16 +130,3 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_drqn(get_args())

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -172,22 +171,8 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_pfqf(args: argparse.Namespace = get_args()) -> None:
args.prioritized_replay = True
args.gamma = 0.95
test_fqf(args)
if __name__ == "__main__":
test_fqf(get_args())

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -168,22 +167,8 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_piqn(args: argparse.Namespace = get_args()) -> None:
args.prioritized_replay = True
args.gamma = 0.95
test_iqn(args)
if __name__ == "__main__":
test_iqn(get_args())

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -124,16 +123,3 @@ def test_pg(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_pg()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -151,16 +150,3 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_ppo()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -157,22 +156,8 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_pqrdqn(args: argparse.Namespace = get_args()) -> None:
args.prioritized_replay = True
args.gamma = 0.95
test_qrdqn(args)
if __name__ == "__main__":
test_pqrdqn(get_args())

View File

@ -1,7 +1,6 @@
import argparse
import os
import pickle
import pprint
import gymnasium as gym
import numpy as np
@ -219,16 +218,6 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_rainbow_resume(args: argparse.Namespace = get_args()) -> None:
args.resume = True
@ -240,7 +229,3 @@ def test_prainbow(args: argparse.Namespace = get_args()) -> None:
args.gamma = 0.95
args.seed = 1
test_rainbow(args)
if __name__ == "__main__":
test_rainbow(get_args())

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -142,16 +141,3 @@ def test_discrete_sac(args: argparse.Namespace = get_args()) -> None:
test_in_train=False,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_discrete_sac()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -197,16 +196,3 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_dqn_icm(get_args())

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import gymnasium as gym
import numpy as np
@ -189,15 +188,3 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
logger=logger,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_ppo()

View File

@ -1,6 +1,5 @@
import argparse
import os
import pprint
import numpy as np
import pytest
@ -116,17 +115,4 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
logger=logger,
test_in_train=False,
).run()
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
test_envs.seed(args.seed)
test_collector.reset()
stats = test_collector.collect(n_episode=args.test_num, render=args.render, is_eval=True)
stats.pprint_asdict()
elif env.spec.reward_threshold:
assert result.best_reward >= env.spec.reward_threshold
if __name__ == "__main__":
test_psrl()
assert result.best_reward >= env.spec.reward_threshold

View File

@ -2,7 +2,7 @@ import argparse
import datetime
import os
import pickle
import pprint
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
import numpy as np
@ -19,11 +19,6 @@ from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
from tianshou.utils.space_info import SpaceInfo
if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
@ -207,15 +202,3 @@ def test_bcq(args: argparse.Namespace = get_args()) -> None:
show_progress=args.show_progress,
).run()
assert stop_fn(result.best_reward)
# Let's watch its performance!
if __name__ == "__main__":
pprint.pprint(result)
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_bcq()

View File

@ -3,6 +3,7 @@ import datetime
import os
import pickle
import pprint
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
import numpy as np
@ -19,11 +20,6 @@ from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.space_info import SpaceInfo
if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
@ -205,18 +201,3 @@ def test_cql(args: argparse.Namespace = get_args()) -> None:
# print(info)
assert stop_fn(epoch_stat.info_stat.best_reward)
# Let's watch its performance!
if __name__ == "__main__":
pprint.pprint(epoch_stat.info_stat)
env = gym.make(args.task)
collector = Collector(policy, env)
collector_result = collector.collect(n_episode=1, render=args.render, is_eval=True)
if collector_result.returns_stat and collector_result.lens_stat:
print(
f"Final reward: {collector_result.returns_stat.mean}, length: {collector_result.lens_stat.mean}",
)
if __name__ == "__main__":
test_cql()

View File

@ -1,7 +1,7 @@
import argparse
import os
import pickle
import pprint
from test.offline.gather_cartpole_data import expert_file_name, gather_data
import gymnasium as gym
import numpy as np
@ -17,11 +17,6 @@ from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor
from tianshou.utils.space_info import SpaceInfo
if __name__ == "__main__":
from gather_cartpole_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_cartpole_data import expert_file_name, gather_data
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
@ -165,21 +160,8 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None:
test_discrete_bcq()
args.resume = True
test_discrete_bcq(args)
if __name__ == "__main__":
test_discrete_bcq(get_args())

View File

@ -1,7 +1,7 @@
import argparse
import os
import pickle
import pprint
from test.offline.gather_cartpole_data import expert_file_name, gather_data
import gymnasium as gym
import numpy as np
@ -16,11 +16,6 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.space_info import SpaceInfo
if __name__ == "__main__":
from gather_cartpole_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_cartpole_data import expert_file_name, gather_data
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
@ -126,16 +121,3 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
policy.set_eps(args.eps_test)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_discrete_cql(get_args())

View File

@ -1,7 +1,7 @@
import argparse
import os
import pickle
import pprint
from test.offline.gather_cartpole_data import expert_file_name, gather_data
import gymnasium as gym
import numpy as np
@ -17,11 +17,6 @@ from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.utils.space_info import SpaceInfo
if __name__ == "__main__":
from gather_cartpole_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_cartpole_data import expert_file_name, gather_data
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
@ -130,15 +125,3 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_discrete_crr(get_args())

View File

@ -1,7 +1,7 @@
import argparse
import os
import pickle
import pprint
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
import numpy as np
@ -18,11 +18,6 @@ from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.space_info import SpaceInfo
if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
@ -226,15 +221,3 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
save_checkpoint_fn=save_checkpoint_fn,
).run()
assert stop_fn(result.best_reward)
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_gail()

View File

@ -2,7 +2,7 @@ import argparse
import datetime
import os
import pickle
import pprint
from test.offline.gather_pendulum_data import expert_file_name, gather_data
import gymnasium as gym
import numpy as np
@ -20,11 +20,6 @@ from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor, Critic
from tianshou.utils.space_info import SpaceInfo
if __name__ == "__main__":
from gather_pendulum_data import expert_file_name, gather_data
else: # pytest
from test.offline.gather_pendulum_data import expert_file_name, gather_data
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
@ -193,15 +188,3 @@ def test_td3_bc(args: argparse.Namespace = get_args()) -> None:
# print(info)
assert stop_fn(epoch_stat.info_stat.best_reward)
# Let's watch its performance!
if __name__ == "__main__":
pprint.pprint(epoch_stat.info_stat)
env = gym.make(args.task)
collector = Collector(policy, env)
collector_stats = collector.collect(n_episode=1, render=args.render, is_eval=True)
print(collector_stats)
if __name__ == "__main__":
test_td3_bc()

View File

@ -1,22 +1,14 @@
import argparse
import pprint
import pytest
from pistonball import get_args, train_agent, watch
@pytest.mark.skip(reason="Performance bound was never tested, no point in running this for now")
def test_piston_ball(args: argparse.Namespace = get_args()) -> None:
if args.watch:
watch(args)
return
result, agent = train_agent(args)
train_agent(args)
# assert result.best_reward >= args.win_rate
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == "__main__":
test_piston_ball(get_args())

View File

@ -1,5 +1,4 @@
import argparse
import pprint
import pytest
from pistonball_continuous import get_args, train_agent, watch
@ -13,12 +12,3 @@ def test_piston_ball_continuous(args: argparse.Namespace = get_args()) -> None:
result, agent = train_agent(args)
# assert result.best_reward >= 30.0
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == "__main__":
test_piston_ball_continuous(get_args())

View File

@ -1,5 +1,4 @@
import argparse
import pprint
from tic_tac_toe import get_args, train_agent, watch
@ -11,12 +10,3 @@ def test_tic_tac_toe(args: argparse.Namespace = get_args()) -> None:
result, agent = train_agent(args)
assert result.best_reward >= args.win_rate
if __name__ == "__main__":
pprint.pprint(result)
# Let's watch its performance!
watch(args, agent)
if __name__ == "__main__":
test_tic_tac_toe(get_args())