Tests: fixed typing issues by declaring union types and no longer reusing var names
parent 4b619c51ba
commit 69f07a8f12
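For context, here is a minimal sketch (hypothetical names, not taken from the Tianshou code) of the two mypy-friendly patterns the diff below applies: declaring a union type before a branch-dependent assignment, and binding a `cast()` result to a fresh variable instead of reusing one name for differently typed values.

```python
from typing import cast

# Illustrative sketch only; names are hypothetical.

def load_flag() -> bool:
    return True

# 1) Declare the union type up front so mypy accepts that the variable
#    may legitimately hold either type depending on the branch taken.
buffer: int | str
if load_flag():
    buffer = 42
else:
    buffer = "fallback"

# 2) Instead of rebinding one name to values of different types (which
#    mypy rejects), bind the cast result to a fresh name. cast() only
#    informs the type checker; it does nothing at runtime.
raw: object = {"desired_goal": 1}
obs_in_buf = cast(dict, raw)
print(buffer, obs_in_buf["desired_goal"])
```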
@@ -2,6 +2,7 @@ import os
 import pickle
 import tempfile
 from test.base.env import MoveToRightEnv, MyGoalEnv
+from typing import cast
 
 import h5py
 import numpy as np
@@ -381,25 +382,25 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        rew = buf[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        rew_in_buf = buf[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g[0])
         assert np.all(g_next == g_next[0])
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_in_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf.next(tmp_indices)
 
     # Check that goals are correctly restored
     buf._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -411,24 +412,24 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        rew = buf2[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        rew_buf = buf2[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g_next)
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf2.next(tmp_indices)
 
     # Check that goals are correctly restored
     buf2._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -442,7 +443,6 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
     buf._index = 5  # shifted start index
     buf.future_p = 1
-    action_list = [1] * 10
     for ep_len in [5, 10]:
         obs, _ = env.reset()
         for i in range(ep_len):
@@ -1030,6 +1030,7 @@ def test_multibuf_stack() -> None:
         size,
     )
     obs, info = env.reset(options={"state": 1})
+    obs = cast(np.ndarray, obs)
     for i in range(18):
         obs_next, rew, terminated, truncated, info = env.step(1)
         done = terminated or truncated
@@ -1057,7 +1058,8 @@ def test_multibuf_stack() -> None:
         assert np.all(buf4.truncated == buf5.truncated)
         obs = obs_next
         if done:
-            obs, info = env.reset(options={"state": 1})
+            # obs is an array, but the env is malformed, so we can't properly type it
+            obs, info = env.reset(options={"state": 1})  # type: ignore[assignment]
     # check the `add` order is correct
     assert np.allclose(
         buf4.obs.reshape(-1),

@@ -218,11 +218,11 @@ def test_collector() -> None:
     c_dummy_venv_4_envs.collect(n_episode=4, random=True)
 
     # test corner case
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         c_dummy_venv_4_envs.collect()
 
 def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]:
@@ -260,7 +260,7 @@ class TestAsyncCollector:
         async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
     ) -> None:
         c1, env_lens = async_collector_and_env_lens
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             c1.collect()
 
     def test_collect_one_episode_async(

@@ -134,7 +134,7 @@ def test_lr_schedulers() -> None:
     )
 
 
-def test_in_eval_mode():
+def test_in_eval_mode() -> None:
     module = nn.Linear(3, 4)
     module.train()
     with in_eval_mode(module):
@@ -142,9 +142,9 @@ def test_in_eval_mode():
     assert module.training
 
 
-def test_in_train_mode():
+def test_in_train_mode() -> None:
     module = nn.Linear(3, 4)
     module.eval()
     with in_train_mode(module):
         assert module.training
     assert not module.training

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteBCQPolicy
 from tianshou.trainer import OfflineTrainer
@@ -96,6 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         imitation_logits_penalty=args.imitation_logits_penalty,
     )
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCQLPolicy
 from tianshou.trainer import OfflineTrainer
@@ -85,6 +85,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         min_q_weight=args.min_q_weight,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
    if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCRRPolicy
 from tianshou.trainer import OfflineTrainer
@@ -89,6 +89,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         target_update_freq=args.target_update_freq,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)

@@ -1,10 +1,11 @@
+from collections.abc import Iterator
 from contextlib import contextmanager
 
 from torch import nn
 
 
 @contextmanager
-def in_eval_mode(module: nn.Module) -> None:
+def in_eval_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to evaluation mode."""
     train = module.training
     try:
@@ -15,7 +16,7 @@ def in_eval_mode(module: nn.Module) -> None:
 
 
 @contextmanager
-def in_train_mode(module: nn.Module) -> None:
+def in_train_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to training mode."""
     train = module.training
     try:
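For reference, a short sketch of why the generator functions above are now annotated as returning `Iterator[None]`: a function decorated with `@contextmanager` is itself a generator that yields once, so its declared return type is an iterator of the yielded value, and the decorator turns that generator into a context manager. The helper below is hypothetical and only illustrates the pattern, assuming PyTorch is available.

```python
from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn


@contextmanager
def frozen_params(module: nn.Module) -> Iterator[None]:
    # The undecorated function is a generator (it yields None once),
    # hence the Iterator[None] annotation; @contextmanager wraps it
    # into a context manager. Hypothetical helper for illustration.
    states = [p.requires_grad for p in module.parameters()]
    try:
        for p in module.parameters():
            p.requires_grad_(False)
        yield
    finally:
        # Restore the original requires_grad flags on exit.
        for p, s in zip(module.parameters(), states):
            p.requires_grad_(s)


if __name__ == "__main__":
    layer = nn.Linear(3, 4)
    with frozen_params(layer):
        assert all(not p.requires_grad for p in layer.parameters())
    assert all(p.requires_grad for p in layer.parameters())
```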