This PR adds a new method for getting actions from an env's observation and info. This is useful for standard inference and stands in contrast to the batch-based methods that are currently used in training and evaluation. Without it, users have to go through some gymnastics to actually perform inference with a trained policy. I have also added a test for the new method (included below). In future PRs, this method should be included in the examples (in the "watch" section).

Adding this required improving several type annotations and, importantly, _simplifying the signature of `forward` in many policies!_ This is a **breaking change**, but it will likely affect no users. The `input` parameter of `forward` was a rather hacky mechanism; I believe it is good that it's gone now. It will also help with #948.

The main functional change is the addition of `compute_action` to `BasePolicy`. Other minor changes:

- improvements in typing
- updated PR and issue templates
- improved handling of `max_action_num`

Closes #981
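For illustration, here is a minimal sketch of the inference loop this enables. The env choice and episode handling are illustrative only, and `policy` is assumed to be a trained policy:

```python
import gymnasium as gym

env = gym.make("CartPole-v1")  # any env; chosen here just for illustration
obs, info = env.reset()
done = False
while not done:
    # single-step inference from a raw observation -- no Batch construction needed
    action = policy.compute_action(obs, info)
    obs, rew, terminated, truncated, info = env.step(action)
    done = terminated or truncated
env.close()
```

This is the kind of loop the "watch" sections of the examples run, which is why the method should land there in a follow-up.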
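As for the breaking change, the gist of the `forward` simplification is sketched below. The "before" is a rough reconstruction of the DQN-family policies, not an exact diff:

```python
# before (roughly): the observation field to read was passed in as a string key
class DQNPolicy:
    def forward(self, batch, state=None, model="model", input="obs", **kwargs):
        obs = batch[input]  # callers could redirect this to e.g. "obs_next"
        ...

# after: forward always reads batch.obs directly
class DQNPolicy:
    def forward(self, batch, state=None, model="model", **kwargs):
        obs = batch.obs
        ...
```

Only callers that passed `input=...` explicitly will break; everyone else is unaffected.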
The added test:
```python
import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor

obs_shape = (5,)


def _to_hashable(x: np.ndarray | int):
    return x if isinstance(x, int) else tuple(x.tolist())


@pytest.fixture(params=["continuous", "discrete"])
def policy(request):
    action_type = request.param
    if action_type == "continuous":
        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
        actor = ActorProb(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
            action_shape=action_space.shape,
        )
        dist_fn = lambda *logits: Independent(Normal(*logits), 1)
    elif action_type == "discrete":
        action_space = gym.spaces.Discrete(3)
        actor = Actor(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
            action_shape=action_space.n,
        )
        dist_fn = lambda logits: Categorical(logits=logits)
    else:
        raise ValueError(f"Unknown action type: {action_type}")

    critic = Critic(
        Net(obs_shape, hidden_sizes=[64, 64]),
    )

    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

    policy = PPOPolicy(
        actor=actor,
        critic=critic,
        dist_fn=dist_fn,
        optim=optim,
        action_space=action_space,
        action_scaling=False,
    )
    policy.eval()
    return policy


class TestPolicyBasics:
    def test_get_action(self, policy):
        sample_obs = torch.randn(obs_shape)
        policy.deterministic_eval = False
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        assert all(policy.action_space.contains(a) for a in actions)

        # check that the actions are different in non-deterministic mode
        assert len(set(map(_to_hashable, actions))) > 1

        policy.deterministic_eval = True
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        # check that the actions are the same in deterministic mode
        assert len(set(map(_to_hashable, actions))) == 1
```