# Tianshou/test/base/test_policy.py
import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor

obs_shape = (5,)
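

# Discrete actions are plain ints; continuous actions are numpy arrays, which
# are unhashable, so they are converted to tuples for the set-based uniqueness
# checks in the tests below.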
def _to_hashable(x: np.ndarray | int) -> int | tuple:
    return x if isinstance(x, int) else tuple(x.tolist())
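

# Parametrized fixture: builds one PPOPolicy per action-space type, so each
# test in this module runs against both continuous and discrete actions.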
@pytest.fixture(params=["continuous", "discrete"])
def policy(request: pytest.FixtureRequest) -> PPOPolicy:
    action_type = request.param
    if action_type == "continuous":
        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
        actor = ActorProb(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
            action_shape=action_space.shape,
        )
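        # ActorProb outputs (mu, sigma); Independent(Normal(mu, sigma), 1)
        # treats the 3 action dims as one event, i.e. a diagonal Gaussian.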
        dist_fn = lambda *logits: Independent(Normal(*logits), 1)
    elif action_type == "discrete":
        action_space = gym.spaces.Discrete(3)
        actor = Actor(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
            action_shape=action_space.n,
        )
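        # Actor outputs per-action logits, which feed a Categorical over the 3 actions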
        dist_fn = lambda logits: Categorical(logits=logits)
    else:
        raise ValueError(f"Unknown action type: {action_type}")
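
    # ActorCritic bundles both modules, so the single Adam optimizer below
    # updates the actor and critic parameters together.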
    critic = Critic(
        Net(obs_shape, hidden_sizes=[64, 64]),
    )
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)
    policy: PPOPolicy = PPOPolicy(
        actor=actor,
        critic=critic,
        dist_fn=dist_fn,
        optim=optim,
        action_space=action_space,
        action_scaling=False,
    )
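    # evaluation mode; the deterministic_eval flag toggled in the tests below
    # decides whether eval-time actions are sampled or taken as the dist's mode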
    policy.eval()
    return policy


class TestPolicyBasics:
    def test_get_action(self, policy: PPOPolicy) -> None:
        sample_obs = torch.randn(obs_shape)
        policy.deterministic_eval = False
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        assert all(policy.action_space.contains(a) for a in actions)
        # check that the actions are different in non-deterministic mode
        assert len(set(map(_to_hashable, actions))) > 1

        policy.deterministic_eval = True
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        # check that the actions are the same in deterministic mode
        assert len(set(map(_to_hashable, actions))) == 1
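

# A minimal sketch for invoking these tests directly rather than via the pytest
# CLI (assumes nothing beyond pytest itself; pytest.main is its standard
# programmatic entry point and returns an exit code):
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))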