Tianshou/test/base/test_policy.py
Erni bf0d632108
Naming and typing improvements in Actor/Critic/Policy forwards (#1032)
Closes #917 

### Internal Improvements
- Better variable names related to model outputs (logits, dist input, etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc., instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see the associated breaking change). #1032
- Use `.mode` of the distribution instead of relying on knowledge of the distribution type (see the sketch after this list). #1032
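A minimal sketch of what `.mode`-based deterministic selection looks like with plain `torch.distributions` (illustrative only, not the exact Tianshou internals; requires a PyTorch version that provides `Distribution.mode`):

```python
import torch
from torch.distributions import Categorical, Independent, Normal

# Both families expose `.mode`, so deterministic evaluation no longer needs
# to branch on the concrete distribution type.
cont_dist = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
disc_dist = Categorical(logits=torch.tensor([0.1, 2.0, -1.0]))

deterministic_cont_action = cont_dist.mode  # the mean for a Normal
deterministic_disc_action = disc_dist.mode  # argmax of the probabilities
```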

### Breaking Changes

- Changed the interface of `dist_fn` in `PGPolicy` and all its subclasses to take a single argument in both the continuous and the discrete case (sketched below). #1032
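Illustratively, the unified single-argument form matches what the test fixture below constructs (the function names here are only for illustration):

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

def continuous_dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    # The actor's (loc, scale) output is passed as one tuple argument.
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)

def discrete_dist_fn(logits: torch.Tensor) -> Distribution:
    # The actor's logits tensor is passed as the single argument.
    return Categorical(logits=logits)
```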

---------

Co-authored-by: Arnau Jimenez <arnau.jimenez@zeiss.com>
Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2024-04-01 17:14:17 +02:00


import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor

obs_shape = (5,)


def _to_hashable(x: np.ndarray | int):
    # Convert an action to a hashable type so sets of actions can be compared.
    return x if isinstance(x, int) else tuple(x.tolist())


@pytest.fixture(params=["continuous", "discrete"])
def policy(request):
    """Build a PPOPolicy with either a continuous or a discrete action space."""
    action_type = request.param
    if action_type == "continuous":
        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
        actor = ActorProb(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
            action_shape=action_space.shape,
        )

        # dist_fn takes the single (loc, scale) tuple produced by the actor.
        def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
            loc, scale = loc_scale
            return Independent(Normal(loc, scale), 1)

    elif action_type == "discrete":
        action_space = gym.spaces.Discrete(3)
        actor = Actor(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
            action_shape=action_space.n,
        )

        # dist_fn takes the logits tensor produced by the actor.
        def dist_fn(logits: torch.Tensor) -> Distribution:
            return Categorical(logits=logits)

    else:
        raise ValueError(f"Unknown action type: {action_type}")

    critic = Critic(
        Net(obs_shape, hidden_sizes=[64, 64]),
    )
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

    policy: PPOPolicy = PPOPolicy(
        actor=actor,
        critic=critic,
        dist_fn=dist_fn,
        optim=optim,
        action_space=action_space,
        action_scaling=False,
    )
    policy.eval()
    return policy


class TestPolicyBasics:
    def test_get_action(self, policy) -> None:
        sample_obs = torch.randn(obs_shape)
        policy.deterministic_eval = False
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        assert all(policy.action_space.contains(a) for a in actions)
        # check that the actions are different in non-deterministic mode
        assert len(set(map(_to_hashable, actions))) > 1

        policy.deterministic_eval = True
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        # check that the actions are the same in deterministic mode
        assert len(set(map(_to_hashable, actions))) == 1