import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal

from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor

obs_shape = (5,)


def _to_hashable(x: np.ndarray | int):
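    """Convert an action to a hashable value so distinct actions can be counted in a set."""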
    return x if isinstance(x, int) else tuple(x.tolist())


@pytest.fixture(params=["continuous", "discrete"])
def policy(request):
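    """Build a minimal PPOPolicy, parametrized over continuous and discrete action spaces."""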
    action_type = request.param
    if action_type == "continuous":
        action_space = gym.spaces.Box(low=-1, high=1, shape=(3,))
        actor = ActorProb(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
            action_shape=action_space.shape,
        )
        dist_fn = lambda *logits: Independent(Normal(*logits), 1)
    elif action_type == "discrete":
        action_space = gym.spaces.Discrete(3)
        actor = Actor(
            Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.n),
            action_shape=action_space.n,
        )
        dist_fn = lambda logits: Categorical(logits=logits)
    else:
        raise ValueError(f"Unknown action type: {action_type}")

    critic = Critic(
        Net(obs_shape, hidden_sizes=[64, 64]),
    )

    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=1e-3)

    policy: PPOPolicy = PPOPolicy(
        actor=actor,
        critic=critic,
        dist_fn=dist_fn,
        optim=optim,
        action_space=action_space,
        action_scaling=False,
    )
    policy.eval()
    return policy


class TestPolicyBasics:
    def test_get_action(self, policy) -> None:
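        """Sampled actions should vary across calls; deterministic_eval should make them identical."""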
        sample_obs = torch.randn(obs_shape)
        policy.deterministic_eval = False
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        assert all(policy.action_space.contains(a) for a in actions)

        # check that the actions are different in non-deterministic mode
        assert len(set(map(_to_hashable, actions))) > 1

        policy.deterministic_eval = True
        actions = [policy.compute_action(sample_obs) for _ in range(10)]
        # check that the actions are the same in deterministic mode
        assert len(set(map(_to_hashable, actions))) == 1