Tests: fixed typing issues by declaring union types and no longer reusing variable names
parent 4b619c51ba
commit 69f07a8f12
@@ -2,6 +2,7 @@ import os
 import pickle
 import tempfile
 from test.base.env import MoveToRightEnv, MyGoalEnv
+from typing import cast

 import h5py
 import numpy as np
@@ -381,25 +382,25 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        rew = buf[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        rew_in_buf = buf[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g[0])
         assert np.all(g_next == g_next[0])
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_in_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf.next(tmp_indices)

     # Check that goals are correctly restored
     buf._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf[tmp_indices].obs
-        obs_next = buf[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
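The renames and `cast` calls above are the heart of the fix: `typing.cast` performs no runtime conversion and only tells mypy to treat the buffer lookup as a `Batch`, so attribute access such as `.desired_goal` type-checks; giving `obs_in_buf` its own name also stops one variable from being rebound to values of different types. A minimal, self-contained sketch of the idea (the `Batch` stand-in and `lookup` helper here are illustrative, not tianshou's API):

from typing import Any, cast

class Batch:  # stand-in for tianshou.data.Batch
    desired_goal: Any = None

def lookup() -> Any:  # indexing a replay buffer yields a loosely typed value
    return Batch()

obs_in_buf = cast(Batch, lookup())  # no runtime effect, checker-only
goal = obs_in_buf.desired_goal      # attribute access now type-checks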
@@ -411,24 +412,24 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     # Check that goals are the same for the episode (only 1 ep in buffer)
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        rew = buf2[tmp_indices].rew
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        ag_next = obs_next.achieved_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        rew_buf = buf2[tmp_indices].rew
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        ag_next = obs_next_buf.achieved_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == g_next)
-        assert np.all(rew == (ag_next == g).astype(np.float32))
+        assert np.all(rew_buf == (ag_next == g).astype(np.float32))
         tmp_indices = buf2.next(tmp_indices)

     # Check that goals are correctly restored
     buf2._restore_cache()
     tmp_indices = indices.copy()
     for _ in range(2 * env_size):
-        obs = buf2[tmp_indices].obs
-        obs_next = buf2[tmp_indices].obs_next
-        g = obs.desired_goal.reshape(sample_sz, -1)[:, 0]
-        g_next = obs_next.desired_goal.reshape(sample_sz, -1)[:, 0]
+        obs_in_buf = cast(Batch, buf2[tmp_indices].obs)
+        obs_next_buf = cast(Batch, buf2[tmp_indices].obs_next)
+        g = obs_in_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
+        g_next = obs_next_buf.desired_goal.reshape(sample_sz, -1)[:, 0]
         assert np.all(g == env_size)
         assert np.all(g_next == g_next[0])
         assert np.all(g == g[0])
@@ -442,7 +443,6 @@ def test_herreplaybuffer(size: int = 10, bufsize: int = 100, sample_sz: int = 4)
     buf = HERReplayBuffer(bufsize, compute_reward_fn=compute_reward_fn, horizon=30, future_k=8)
     buf._index = 5  # shifted start index
     buf.future_p = 1
-    action_list = [1] * 10
     for ep_len in [5, 10]:
         obs, _ = env.reset()
         for i in range(ep_len):
@@ -1030,6 +1030,7 @@ def test_multibuf_stack() -> None:
         size,
     )
     obs, info = env.reset(options={"state": 1})
+    obs = cast(np.ndarray, obs)
     for i in range(18):
         obs_next, rew, terminated, truncated, info = env.step(1)
         done = terminated or truncated
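The same technique narrows the observation returned by `env.reset()`: the env's declared return type is too loose for mypy, and the added `cast(np.ndarray, obs)` asserts the concrete type without changing anything at runtime. A hedged sketch with a stand-in `reset` in place of the real env:

from typing import Any, cast

import numpy as np

def reset() -> tuple[Any, dict]:  # stand-in for env.reset()
    return np.zeros(4), {}

obs, info = reset()
obs = cast(np.ndarray, obs)  # tell the checker obs is an ndarray
print(obs.shape)  # (4,) -- ndarray attributes now type-check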
@@ -1057,7 +1058,8 @@ def test_multibuf_stack() -> None:
         assert np.all(buf4.truncated == buf5.truncated)
         obs = obs_next
         if done:
-            obs, info = env.reset(options={"state": 1})
+            # obs is an array, but the env is malformed, so we can't properly type it
+            obs, info = env.reset(options={"state": 1})  # type: ignore[assignment]
     # check the `add` order is correct
     assert np.allclose(
         buf4.obs.reshape(-1),
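Where the test env deliberately returns mismatched types, a cast would be misleading, so the commit instead scopes a suppression to mypy's `assignment` error code; the bracketed code silences only that one diagnostic on the line, and any other error there still surfaces. Tiny illustration:

x: int = 0
x = "oops"  # type: ignore[assignment]  # silences only this error code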
@@ -218,11 +218,11 @@ def test_collector() -> None:
     c_dummy_venv_4_envs.collect(n_episode=4, random=True)

     # test corner case
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
-    with pytest.raises(TypeError):
+    with pytest.raises(ValueError):
         c_dummy_venv_4_envs.collect()

 def get_env_factory(i: int, t: str) -> Callable[[], NXEnv]:
@@ -260,7 +260,7 @@ class TestAsyncCollector:
         async_collector_and_env_lens: tuple[AsyncCollector, list[int]],
     ) -> None:
         c1, env_lens = async_collector_and_env_lens
-        with pytest.raises(TypeError):
+        with pytest.raises(ValueError):
             c1.collect()

     def test_collect_one_episode_async(
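The four `pytest.raises` assertions above were retargeted from `TypeError` to `ValueError`: `pytest.raises` passes only if the block raises the named exception (or a subclass of it), so the expected type has to track what the implementation actually throws. A self-contained sketch with a stand-in `collect` (the error message is illustrative):

import pytest

def collect() -> None:  # stand-in for calling collect() without arguments
    raise ValueError("either n_step or n_episode must be specified")

def test_collect_requires_arguments() -> None:
    with pytest.raises(ValueError):
        collect()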
@@ -134,7 +134,7 @@ def test_lr_schedulers() -> None:
     )


-def test_in_eval_mode():
+def test_in_eval_mode() -> None:
     module = nn.Linear(3, 4)
     module.train()
     with in_eval_mode(module):
@@ -142,9 +142,9 @@ def test_in_eval_mode():
     assert module.training


-def test_in_train_mode():
+def test_in_train_mode() -> None:
     module = nn.Linear(3, 4)
     module.eval()
     with in_train_mode(module):
         assert module.training
     assert not module.training
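The only change in the two hunks above is the `-> None` return annotation. By default mypy skips the body of an unannotated function, so annotating test functions opts them into type checking (the same effect can be forced globally with `--check-untyped-defs`). Illustration:

def unchecked():  # body ignored under default mypy settings
    x: int = "oops"  # not reported

def checked() -> None:  # body is type-checked
    y: int = "oops"  # error: incompatible types in assignment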
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteBCQPolicy
 from tianshou.trainer import OfflineTrainer
@@ -96,6 +96,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
         imitation_logits_penalty=args.imitation_logits_penalty,
     )
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
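The added `buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer` line is a bare variable annotation. Without it, mypy infers the type of `buffer` from whichever branch assigns first and then rejects assignments of the other class; declaring the union once up front lets every branch assign either type. A minimal sketch of the pattern (stand-in classes, Python 3.10+ union syntax):

class VectorReplayBuffer: ...
class PrioritizedVectorReplayBuffer: ...

def make_buffer(prioritized: bool) -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer  # declare the union once
    if prioritized:
        buffer = PrioritizedVectorReplayBuffer()
    else:
        buffer = VectorReplayBuffer()
    return buffer

The same annotation is repeated in the discrete CQL and CRR scripts below.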
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCQLPolicy
 from tianshou.trainer import OfflineTrainer
@@ -85,6 +85,7 @@ def test_discrete_cql(args: argparse.Namespace = get_args()) -> None:
         min_q_weight=args.min_q_weight,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
@@ -8,7 +8,7 @@ import numpy as np
 import torch
 from torch.utils.tensorboard import SummaryWriter

-from tianshou.data import Collector, VectorReplayBuffer
+from tianshou.data import Collector, PrioritizedVectorReplayBuffer, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import BasePolicy, DiscreteCRRPolicy
 from tianshou.trainer import OfflineTrainer
@@ -89,6 +89,7 @@ def test_discrete_crr(args: argparse.Namespace = get_args()) -> None:
         target_update_freq=args.target_update_freq,
     ).to(args.device)
     # buffer
+    buffer: VectorReplayBuffer | PrioritizedVectorReplayBuffer
     if os.path.exists(args.load_buffer_name) and os.path.isfile(args.load_buffer_name):
         if args.load_buffer_name.endswith(".hdf5"):
             buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
@@ -1,10 +1,11 @@
+from collections.abc import Iterator
 from contextlib import contextmanager

 from torch import nn


 @contextmanager
-def in_eval_mode(module: nn.Module) -> None:
+def in_eval_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to evaluation mode."""
     train = module.training
     try:
@@ -15,7 +16,7 @@ def in_eval_mode(module: nn.Module) -> None:


 @contextmanager
-def in_train_mode(module: nn.Module) -> None:
+def in_train_mode(module: nn.Module) -> Iterator[None]:
     """Temporarily switch to training mode."""
     train = module.training
     try:
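`Iterator[None]` is the right return type here because a function decorated with `@contextmanager` must be a generator that yields exactly once, and `None` is the type of the value it yields. A plausible completion of the helper (the body after `try:` is inferred from the docstring and is not part of the diff):

from collections.abc import Iterator
from contextlib import contextmanager

from torch import nn

@contextmanager
def in_eval_mode(module: nn.Module) -> Iterator[None]:
    """Temporarily switch to evaluation mode."""
    train = module.training
    try:
        module.eval()
        yield
    finally:
        module.train(train)  # restore the previous mode

Used as in the tests above, `with in_eval_mode(module): ...` leaves `module.training` as it found it.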