Improves typing in examples and tests, moving towards making mypy pass there. Introduces the SpaceInfo utility.
72 lines
2.3 KiB
Python
72 lines
2.3 KiB
Python
import gymnasium as gym
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from torch.distributions import Categorical, Independent, Normal
|
|
|
|
from tianshou.policy import PPOPolicy
|
|
from tianshou.utils.net.common import ActorCritic, Net
|
|
from tianshou.utils.net.continuous import ActorProb, Critic
|
|
from tianshou.utils.net.discrete import Actor
|
|
|
|
# Observation shape shared by the networks built in the `policy` fixture and by the tests.
obs_shape: tuple[int, ...] = (5,)
|
|
|
|
|
|
def _to_hashable(x: np.ndarray | int):
|
|
return x if isinstance(x, int) else tuple(x.tolist())
|
|
|
|
|
|
@pytest.fixture(params=["continuous", "discrete"])
def policy(request: pytest.FixtureRequest) -> PPOPolicy:
    """Build a PPOPolicy over either a continuous (Box) or a discrete action space.

    The fixture is parametrized over both action types so every test using it
    runs against both. The returned policy is switched to eval mode.

    :raises ValueError: if the fixture param is not "continuous" or "discrete".
    """
    action_space: gym.spaces.Box | gym.spaces.Discrete
    action_type = request.param
    if action_type == "continuous":
        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
        actor = ActorProb(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
            action_shape=action_space.shape,
        )

        # PEP 8 (E731): use a def instead of assigning a lambda to a name.
        def dist_fn(*logits: torch.Tensor) -> Independent:
            return Independent(Normal(*logits), 1)

    elif action_type == "discrete":
        action_space = gym.spaces.Discrete(3)
        actor = Actor(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
            action_shape=action_space.n,
        )

        def dist_fn(logits: torch.Tensor) -> Categorical:  # type: ignore[misc]
            return Categorical(logits=logits)

    else:
        raise ValueError(f"Unknown action type: {action_type}")

    critic = Critic(
        Net(obs_shape, hidden_sizes=[64, 64]),
    )

    # One optimizer over both actor and critic parameters.
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

    policy: PPOPolicy = PPOPolicy(
        actor=actor,
        critic=critic,
        dist_fn=dist_fn,
        optim=optim,
        action_space=action_space,
        # scaling disabled so compute_action output stays in the raw action space
        action_scaling=False,
    )
    policy.eval()
    return policy
|
|
|
|
|
|
class TestPolicyBasics:
    """Smoke tests for basic policy behavior, run for both action-space types."""

    def test_get_action(self, policy: PPOPolicy) -> None:
        """Computed actions lie in the action space; eval mode controls determinism."""
        sample_obs = torch.randn(obs_shape)
        policy.deterministic_eval = False
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        assert all(policy.action_space.contains(a) for a in actions)

        # check that the actions are different in non-deterministic mode
        assert len(set(map(_to_hashable, actions))) > 1

        policy.deterministic_eval = True
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        # check that the actions are the same in deterministic mode
        assert len(set(map(_to_hashable, actions))) == 1
|