Tests: fixed typing issues by declaring union types and no longer reusing var names

Michael Panchenko 2024-04-26 17:37:12 +02:00
parent 4b619c51ba
commit 69f07a8f12
7 changed files with 42 additions and 36 deletions
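The union-type declarations work because mypy otherwise pins a variable to the type of its first assignment and then rejects a later assignment of a sibling type in another branch. A minimal sketch of the pattern, not part of this commit (the `make_buffer` helper and its constructor arguments are invented for illustration):

    from tianshou.data import PrioritizedVectorReplayBuffer, VectorReplayBuffer

    def make_buffer(prioritized: bool) -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
        # Declare the union up front so both branches type-check.
        buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
        if prioritized:
            buffer = PrioritizedVectorReplayBuffer(1000, 4, alpha=0.6, beta=0.4)
        else:
            buffer = VectorReplayBuffer(1000, 4)
        return buffer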

View File

@@ -2,6 +2,7 @@ import os
 import pickle
 import tempfile
 from test.base.env import MoveToRightEnv, MyGoalEnv
+from typing import cast

 import h5py
 import numpy as np
@@ -381,25 +382,25 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        rew = buf[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        rew_in_buf = buf[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g[0])
         assert np.all(g_next == g_next[0])
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_in_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf.next(tmp_indices)

     # Check that goals are correctly restored
     buf._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -411,24 +412,24 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        rew = buf2[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        rew_buf = buf2[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g_next)
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf2.next(tmp_indices)

     # Check that goals are correctly restored
     buf2._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -442,7 +443,6 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
     buf._index = 5  # shifted start index
     buf.future_p = 1
-    action_list = [1] * 10
     for ep_len in [5, 10]:
         obs, _ = env.reset()
         for i in range(ep_len):
@@ -1030,6 +1030,7 @@ def test_multibuf_stack() -> None:
         size,
     )
     obs, info = env.reset(options={"state": 1})
+    obs = cast(np.ndarray, obs)
     for i in range(18):
         obs_next, rew, terminated, truncated, info = env.step(1)
         done = terminated or truncated
@@ -1057,7 +1058,8 @@ def test_multibuf_stack() -> None:
         assert np.all(buf4.truncated == buf5.truncated)
         obs = obs_next
         if done:
-            obs, info = env.reset(options={"state": 1})
+            # obs is an array, but the env is malformed, so we can't properly type it
+            obs, info = env.reset(options={"state": 1})  # type: ignore[assignment]
     # check the `add` order is correct
     assert np.allclose(
         buf4.obs.reshape(-1),
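The `cast` calls above do no runtime conversion; they only tell mypy which member of a union type to assume from that point on (here, that the sampled `obs` is a `Batch` rather than an array), while the targeted `# type: ignore[assignment]` is the fallback when no single annotation fits. A minimal sketch of the cast idiom outside this diff (the toy `sample` batch is invented for illustration):

    from typing import cast

    import numpy as np

    from tianshou.data import Batch

    sample = Batch(obs=Batch(desired_goal=np.zeros((4, 3))), rew=np.zeros(4))
    # cast() narrows the type for the checker only; it performs no runtime check.
    obs = cast(Batch, sample.obs)
    g = obs.desired_goal.reshape(4, -1)[:, 0]
    assert np.all(g == g[0])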

View File

@@ -218,11 +218,11 @@ def test_collector() -> None:
     c_dummy_venv_4_envs.collect(n_episode=4, random=True)
     # test corner case
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         c_dummy_venv_4_envs.collect()


 def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]:
@@ -260,7 +260,7 @@ class TestAsyncCollector:
         async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
     ) -> None:
         c1, env_lens = async_collector_and_env_lens
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             c1.collect()

     def test_collect_one_episode_async(
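These assertions now expect ValueError because `pytest.raises` passes only when the block raises the named exception type (or a subclass), so the declared type must match what the code under test actually raises. A minimal, self-contained sketch of the assertion style (the `parse_port` helper is invented for illustration):

    import pytest

    def parse_port(value: str) -> int:
        port = int(value)
        if not 0 < port < 65536:
            raise ValueError(f"invalid port: {port}")
        return port

    def test_parse_port_rejects_zero() -> None:
        # Passes only if the body raises ValueError; any other outcome fails the test.
        with pytest.raises(ValueError):
            parse_port("0")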

View File

@@ -134,7 +134,7 @@ def test_lr_schedulers() -> None:
     )


-def test_in_eval_mode():
+def test_in_eval_mode() -> None:
     module = nn.Linear(3, 4)
     module.train()
     with in_eval_mode(module):
@@ -142,9 +142,9 @@ def test_in_eval_mode():
     assert module.training


-def test_in_train_mode():
+def test_in_train_mode() -> None:
     module = nn.Linear(3, 4)
     module.eval()
     with in_train_mode(module):
         assert module.training
     assert not module.training

View File

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteBCQPolicy
 from tianshou.trainer import OfflineTrainer
@@ -96,6 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         imitation_logits_penalty=args.imitation_logits_penalty,
     )
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

View File

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCQLPolicy
 from tianshou.trainer import OfflineTrainer
@@ -85,6 +85,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         min_q_weight=args.min_q_weight,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

View File

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCRRPolicy
 from tianshou.trainer import OfflineTrainer
@@ -89,6 +89,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         target_update_freq=args.target_update_freq,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

View File

@@ -1,10 +1,11 @@
+from collections.abc import Iterator
 from contextlib import contextmanager

 from torch import nn


 @contextmanager
-def in_eval_mode(module: nn.Module) -> None:
+def in_eval_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to evaluation mode."""
     train = module.training
     try:
@@ -15,7 +16,7 @@ def in_eval_mode(module: nn.Module) -> None:


 @contextmanager
-def in_train_mode(module: nn.Module) -> None:
+def in_train_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to training mode."""
     train = module.training
     try:
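The `Iterator[None]` annotation is what a type checker expects for a generator function wrapped by `@contextmanager`: the function body yields once rather than returning None, and the decorator turns that generator into an object usable in a `with` statement. A minimal sketch of the same pattern with a hypothetical helper (`freeze_params` is not part of this commit):

    from collections.abc import Iterator
    from contextlib import contextmanager

    from torch import nn

    @contextmanager
    def freeze_params(module: nn.Module) -> Iterator[None]:
        """Temporarily disable gradients for all parameters of `module`."""
        flags = [p.requires_grad for p in module.parameters()]
        try:
            for p in module.parameters():
                p.requires_grad_(False)
            yield
        finally:
            # Restore the original requires_grad flags even if the body raised.
            for p, flag in zip(module.parameters(), flags):
                p.requires_grad_(flag)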