Tests: fixed typing issues by declaring union types and no longer reusing var names

Michael Panchenko 2024-04-26 17:37:12 +02:00
parent 4b619c51ba
commit 69f07a8f12
7 changed files with 42 additions and 36 deletions
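
The commit message names two recurring mypy patterns. A minimal sketch of the variable-reuse half (hypothetical names, not code from this diff): reusing one name for values of different types pins the variable to the first inferred type, so the second assignment is rejected; giving each type its own name sidesteps the error.

def before() -> None:
    obs = "raw"  # mypy infers str
    obs = len(obs)  # error: incompatible types in assignment (int vs. str)

def after() -> None:
    obs_raw = "raw"  # str
    obs_len = len(obs_raw)  # int; each name keeps a single, stable type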

View File

@@ -2,6 +2,7 @@ import os
 import pickle
 import tempfile
 from test.base.env import MoveToRightEnv, MyGoalEnv
+from typing import cast

 import h5py
 import numpy as np
@@ -381,25 +382,25 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        rew = buf[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        rew_in_buf = buf[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g[0])
         assert np.all(g_next == g_next[0])
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_in_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf.next(tmp_indices)

     # Check that goals are correctly restored
     buf._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -411,24 +412,24 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        rew = buf2[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        rew_buf = buf2[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g_next)
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf2.next(tmp_indices)

     # Check that goals are correctly restored
     buf2._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -442,7 +443,6 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
     buf._index = 5  # shifted start index
     buf.future_p = 1
-    action_list = [1] * 10
     for ep_len in [5, 10]:
         obs, _ = env.reset()
         for i in range(ep_len):
@@ -1030,6 +1030,7 @@ def test_multibuf_stack() -> None:
         size,
     )
     obs, info = env.reset(options={"state": 1})
+    obs = cast(np.ndarray, obs)
     for i in range(18):
         obs_next, rew, terminated, truncated, info = env.step(1)
         done = terminated or truncated
@@ -1057,7 +1058,8 @@ def test_multibuf_stack() -> None:
         assert np.all(buf4.truncated == buf5.truncated)
         obs = obs_next
         if done:
-            obs, info = env.reset(options={"state": 1})
+            # obs is an array, but the env is malformed, so we can't properly type it
+            obs, info = env.reset(options={"state": 1})  # type: ignore[assignment]
     # check the `add` order is correct
     assert np.allclose(
         buf4.obs.reshape(-1),
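
The buffer-test hunks above use the checker's two escape hatches side by side: typing.cast where the runtime type is known but inference is too weak, and a targeted # type: ignore[assignment] where the annotation cannot be made accurate at all. A minimal sketch of the distinction (hypothetical function, not from this diff):

from typing import Any, cast

def sum_items(value: Any) -> int:
    # cast() is a static-only assertion: it performs no runtime check and
    # has no runtime cost, it merely tells the checker the concrete type.
    items = cast(list[int], value)
    return sum(items)

# A "# type: ignore[assignment]" comment instead silences exactly one error
# code on one line, for spots where the types are genuinely inconsistent.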

View File

@@ -218,11 +218,11 @@ def test_collector() -> None:
     c_dummy_venv_4_envs.collect(n_episode=4, random=True)
     # test corner case
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         c_dummy_venv_4_envs.collect()


 def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]:
@@ -260,7 +260,7 @@ class TestAsyncCollector:
         async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
     ) -> None:
         c1, env_lens = async_collector_and_env_lens
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             c1.collect()

     def test_collect_one_episode_async(
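
The corner-case assertions now expect ValueError instead of TypeError, presumably tracking a change in the Collector's argument validation. For reference, pytest.raises passes only if the block raises the given exception type or a subclass; a self-contained sketch with a hypothetical validator:

import pytest

def validate_buffer_num(buffer_num: int, env_num: int) -> None:
    # Hypothetical check mirroring the kind of validation under test.
    if buffer_num < env_num:
        raise ValueError("buffer_num must cover all environments")

def test_validation() -> None:
    with pytest.raises(ValueError):
        validate_buffer_num(buffer_num=1, env_num=4)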

View File

@@ -134,7 +134,7 @@ def test_lr_schedulers() -> None:
     )


-def test_in_eval_mode():
+def test_in_eval_mode() -> None:
     module = nn.Linear(3, 4)
     module.train()
     with in_eval_mode(module):
@@ -142,9 +142,9 @@
     assert module.training


-def test_in_train_mode():
+def test_in_train_mode() -> None:
     module = nn.Linear(3, 4)
     module.eval()
     with in_train_mode(module):
         assert module.training
-    assert not module.training
\ No newline at end of file
+    assert not module.training

View File

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteBCQPolicy
 from tianshou.trainer import OfflineTrainer
@@ -96,6 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         imitation_logits_penalty=args.imitation_logits_penalty,
     )
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
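
This annotation is the union-type half of the commit message: when a variable is assigned different types on different branches, mypy pins it to the type of the first assignment unless the union is declared up front. The same pattern in isolation (a sketch, assuming tianshou's standard constructor signatures):

from tianshou.data import PrioritizedVectorReplayBuffer, VectorReplayBuffer

def make_buffer(prioritized: bool) -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
    # Declaring the union before branching lets mypy accept both assignments;
    # without it, the second branch would be an incompatible-type error.
    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
    if prioritized:
        buffer = PrioritizedVectorReplayBuffer(20000, 10, alpha=0.6, beta=0.4)
    else:
        buffer = VectorReplayBuffer(20000, 10)
    return buffer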

View File

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCQLPolicy
 from tianshou.trainer import OfflineTrainer
@@ -85,6 +85,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         min_q_weight=args.min_q_weight,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

View File

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCRRPolicy
 from tianshou.trainer import OfflineTrainer
@@ -89,6 +89,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         target_update_freq=args.target_update_freq,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

View File

@@ -1,10 +1,11 @@
+from collections.abc import Iterator
 from contextlib import contextmanager

 from torch import nn


 @contextmanager
-def in_eval_mode(module: nn.Module) -> None:
+def in_eval_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to evaluation mode."""
     train = module.training
     try:
@@ -15,7 +16,7 @@ def in_eval_mode(module: nn.Module) -> None:


 @contextmanager
-def in_train_mode(module: nn.Module) -> None:
+def in_train_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to training mode."""
     train = module.training
     try:
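
The corrected signatures reflect what a @contextmanager-decorated function really is: a generator that yields exactly once, so the function itself is annotated as returning Iterator[None] and the decorator turns it into a context-manager factory. A self-contained sketch of the same save/switch/restore pattern (in_mode is a hypothetical merged variant, not part of tianshou):

from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn

@contextmanager
def in_mode(module: nn.Module, train: bool) -> Iterator[None]:
    """Temporarily put `module` into the given train/eval mode."""
    previous = module.training  # save the current mode
    try:
        module.train(train)  # switch for the duration of the `with` body
        yield
    finally:
        module.train(previous)  # restore even if the body raised

Used as `with in_mode(module, train=False): ...`, this behaves like the in_eval_mode above.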