Tianshou/test/base/test_policy.py

import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.policy import BasePolicy, PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor

obs_shape = (5,)
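
# Helper used by the tests below: convert an action (int or ndarray) into a
# hashable value so that distinct sampled actions can be counted via a set.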
def _to_hashable(x: np.ndarray | int) -> int | tuple[list]:
    return x if isinstance(x, int) else tuple(x.tolist())
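
# Parametrized fixture providing a PPOPolicy over either a continuous (Box)
# or a discrete action space, with a matching actor, critic and dist_fn.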
@pytest.fixture(params=["continuous", "discrete"])
def policy(request: pytest.FixtureRequest) -> PPOPolicy:
    action_type = request.param
    action_space: gym.spaces.Box | gym.spaces.Discrete
    actor: Actor | ActorProb
    if action_type == "continuous":
        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
        actor = ActorProb(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
            action_shape=action_space.shape,
        )

        def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
            loc, scale = loc_scale
            return Independent(Normal(loc, scale), 1)

    elif action_type == "discrete":
        action_space = gym.spaces.Discrete(3)
        actor = Actor(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
            action_shape=action_space.n,
        )
        dist_fn = Categorical
    else:
        raise ValueError(f"Unknown action type: {action_type}")

    critic = Critic(
        Net(obs_shape, hidden_sizes=[64, 64]),
    )
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

    policy: BasePolicy
    policy = PPOPolicy(
        actor=actor,
        critic=critic,
        dist_fn=dist_fn,
        optim=optim,
        action_space=action_space,
        action_scaling=False,
    )
    policy.eval()
    return policy
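
# Deterministic evaluation is only applied outside of a training step, so the
# test sets is_within_training_step = False before toggling deterministic_eval.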
class TestPolicyBasics:
    def test_get_action(self, policy: PPOPolicy) -> None:
        policy.is_within_training_step = False
        sample_obs = torch.randn(obs_shape)

        policy.deterministic_eval = False
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        assert all(policy.action_space.contains(a) for a in actions)
        # check that the actions are different in non-deterministic mode
        assert len(set(map(_to_hashable, actions))) > 1

        policy.deterministic_eval = True
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        # check that the actions are the same in deterministic mode
        assert len(set(map(_to_hashable, actions))) == 1
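
# Run with: pytest test/base/test_policy.py  (exercises both fixture params)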