From 69f07a8f12bce7003aad1e897476f422763da247 Mon Sep 17 00:00:00 2001
From: Michael Panchenko
Date: Fri, 26 Apr 2024 17:37:12 +0200
Subject: [PATCH] Tests: fixed typing issues by declaring union types and no
 longer reusing var names

---
 test/base/test_buffer.py          | 50 ++++++++++++++++---------------
 test/base/test_collector.py       |  8 ++---
 test/base/test_utils.py           |  6 ++--
 test/offline/test_discrete_bcq.py |  3 +-
 test/offline/test_discrete_cql.py |  3 +-
 test/offline/test_discrete_crr.py |  3 +-
 tianshou/utils/torch_utils.py     |  5 ++--
 7 files changed, 42 insertions(+), 36 deletions(-)

diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py
index 40f450c..1b3593d 100644
--- a/test/base/test_buffer.py
+++ b/test/base/test_buffer.py
@@ -2,6 +2,7 @@ import os
 import pickle
 import tempfile
 from test.base.env import MoveToRightEnv, MyGoalEnv
+from typing import cast
 
 import h5py
 import numpy as np
@@ -381,25 +382,25 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        rew = buf[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        rew_in_buf = buf[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g[0])
         assert np.all(g_next == g_next[0])
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_in_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf.next(tmp_indices)
 
     # Check that goals are correctly restored
     buf._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -411,24 +412,24 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        rew = buf2[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        rew_buf = buf2[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g_next)
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf2.next(tmp_indices)
 
     # Check that goals are correctly restored
     buf2._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
        assert np.all(g == env_size)
        assert np.all(g_next == g_next[0])
        assert np.all(g == g[0])
@@ -442,7 +443,6 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4
     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
     buf._index = 5  # shifted start index
     buf.future_p = 1
-    action_list = [1] * 10
     for ep_len in [5, 10]:
         obs, _ = env.reset()
         for i in range(ep_len):
@@ -1030,6 +1030,7 @@ def test_multibuf_stack() -> None:
         size,
     )
     obs, info = env.reset(options={"state": 1})
+    obs = cast(np.ndarray, obs)
     for i in range(18):
         obs_next, rew, terminated, truncated, info = env.step(1)
         done = terminated or truncated
@@ -1057,7 +1058,8 @@ def test_multibuf_stack() -> None:
         assert np.all(buf4.truncated == buf5.truncated)
         obs = obs_next
         if done:
-            obs, info = env.reset(options={"state": 1})
+            # obs is an array, but the env is malformed, so we can't properly type it
+            obs, info = env.reset(options={"state": 1})  # type: ignore[assignment]
     # check the `add` order is correct
     assert np.allclose(
         buf4.obs.reshape(-1),
diff --git a/test/base/test_collector.py b/test/base/test_collector.py
index 08483bb..d03a54d 100644
--- a/test/base/test_collector.py
+++ b/test/base/test_collector.py
@@ -218,11 +218,11 @@ def test_collector() -> None:
     c_dummy_venv_4_envs.collect(n_episode=4, random=True)
 
     # test corner case
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         c_dummy_venv_4_envs.collect()
 
 def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]:
@@ -260,7 +260,7 @@ class TestAsyncCollector:
         async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
     ) -> None:
         c1, env_lens = async_collector_and_env_lens
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             c1.collect()
 
     def test_collect_one_episode_async(
diff --git a/test/base/test_utils.py b/test/base/test_utils.py
index 1d8e37c..ac3b2fa 100644
--- a/test/base/test_utils.py
+++ b/test/base/test_utils.py
@@ -134,7 +134,7 @@ def test_lr_schedulers() -> None:
     )
 
 
-def test_in_eval_mode():
+def test_in_eval_mode() -> None:
     module = nn.Linear(3, 4)
     module.train()
     with in_eval_mode(module):
@@ -142,9 +142,9 @@ def test_in_eval_mode():
     assert module.training
 
 
-def test_in_train_mode():
+def test_in_train_mode() -> None:
     module = nn.Linear(3, 4)
     module.eval()
     with in_train_mode(module):
         assert module.training
-    assert not module.training
\ No newline at end of file
+    assert not module.training
diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py
index 5898099..6e8e878 100644
--- a/test/offline/test_discrete_bcq.py
+++ b/test/offline/test_discrete_bcq.py
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteBCQPolicy
 from tianshou.trainer import OfflineTrainer
@@ -96,6 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         imitation_logits_penalty=args.imitation_logits_penalty,
     )
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
diff --git a/test/offline/test_discrete_cql.py b/test/offline/test_discrete_cql.py
index b62fdcc..f2a60e0 100644
--- a/test/offline/test_discrete_cql.py
+++ b/test/offline/test_discrete_cql.py
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCQLPolicy
 from tianshou.trainer import OfflineTrainer
@@ -85,6 +85,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         min_q_weight=args.min_q_weight,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
diff --git a/test/offline/test_discrete_crr.py b/test/offline/test_discrete_crr.py
index eee8725..bc54dd9 100644
--- a/test/offline/test_discrete_crr.py
+++ b/test/offline/test_discrete_crr.py
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCRRPolicy
 from tianshou.trainer import OfflineTrainer
@@ -89,6 +89,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         target_update_freq=args.target_update_freq,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
diff --git a/tianshou/utils/torch_utils.py b/tianshou/utils/torch_utils.py
index 1676b52..2fb70da 100644
--- a/tianshou/utils/torch_utils.py
+++ b/tianshou/utils/torch_utils.py
@@ -1,10 +1,11 @@
+from collections.abc import Iterator
 from contextlib import contextmanager
 
 from torch import nn
 
 
 @contextmanager
-def in_eval_mode(module: nn.Module) -> None:
+def in_eval_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to evaluation mode."""
     train = module.training
     try:
@@ -15,7 +16,7 @@
 
 
 @contextmanager
-def in_train_mode(module: nn.Module) -> None:
+def in_train_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to training mode."""
     train = module.training
     try:
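
Note on the typing pattern applied in tianshou/utils/torch_utils.py: a function decorated
with @contextmanager is a generator that yields exactly once, so mypy expects its declared
return type to be Iterator[None] rather than None. The sketch below illustrates the pattern;
the body past the try: line is not visible in the truncated hunk above and is reconstructed
here for illustration only, followed by a usage snippet in the spirit of test_in_eval_mode.

    from collections.abc import Iterator
    from contextlib import contextmanager

    from torch import nn


    @contextmanager
    def in_eval_mode(module: nn.Module) -> Iterator[None]:
        """Temporarily switch to evaluation mode."""
        train = module.training  # remember the mode that was active on entry
        try:
            module.eval()
            yield
        finally:
            module.train(train)  # restore the original mode even if the body raised


    # usage: inside the block the module is in eval mode, afterwards it is restored
    layer = nn.Linear(3, 4)
    layer.train()
    with in_eval_mode(layer):
        assert not layer.training
    assert layer.training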