2022-10-02 18:33:12 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								import warnings
							 | 
						
					
						
							
								
									
										
										
										
											2022-02-15 17:56:45 +03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from abc import ABC
							 | 
						
					
						
							
								
									
										
										
										
											2023-02-03 20:57:27 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from typing import Any, Dict, List, Tuple
							 | 
						
					
						
							
								
									
										
										
										
											2022-02-15 17:56:45 +03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2022-10-02 18:33:12 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								import pettingzoo
							 | 
						
					
						
							
								
									
										
										
										
											2022-11-11 20:25:35 +00:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from gymnasium import spaces
							 | 
						
					
						
							
								
									
										
										
										
											2022-10-02 18:33:12 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								from packaging import version
							 | 
						
					
						
							
								
									
										
										
										
											2022-02-15 17:56:45 +03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from pettingzoo.utils.env import AECEnv
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								from pettingzoo.utils.wrappers import BaseWrapper
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2022-10-02 18:33:12 +02:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
								    # Warn (without failing) when the installed PettingZoo predates 1.21.0.
								    warnings.warn(
								        f"You are using PettingZoo {pettingzoo.__version__}. "
								        "Future tianshou versions may not support PettingZoo<1.21.0. "
								        "Consider upgrading your PettingZoo version.",
								        DeprecationWarning,
								    )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2022-02-15 17:56:45 +03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2022-03-16 14:38:51 +01:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								class PettingZooEnv(AECEnv, ABC):
								    """The interface for petting zoo environments.

								    Multi-agent environments must be wrapped as
								    :class:`~tianshou.env.PettingZooEnv`. Here is the usage:
								    ::

								        env = PettingZooEnv(...)
								        # obs is a dict containing obs, agent_id, and mask
								        obs, info = env.reset()
								        action = policy(obs)
								        obs, rew, term, trunc, info = env.step(action)
								        env.close()

								    The available action's mask is set to True, otherwise it is set to False.
								    When the raw observation carries no ``action_mask`` and the action space
								    is :class:`gymnasium.spaces.Discrete`, every action is marked available.
								    Further usage can be found at :ref:`marl_example`.
								    """

								    def __init__(self, env: BaseWrapper):
								        """Wrap ``env`` and reset it once.

								        :param env: the PettingZoo AEC environment to wrap. All of its agents
								            must share one observation space and one action space.
								        """
								        super().__init__()
								        self.env = env
								        # agent id list, plus the reverse mapping agent_id -> list position
								        self.agents = self.env.possible_agents
								        self.agent_idx = {agent_id: i for i, agent_id in enumerate(self.agents)}

								        # Latest reward per agent, ordered like ``self.agents``.
								        self.rewards = [0] * len(self.agents)

								        # Get first observation space, assuming all agents have equal space
								        self.observation_space: Any = self.env.observation_space(self.agents[0])

								        # Get first action space, assuming all agents have equal space
								        self.action_space: Any = self.env.action_space(self.agents[0])

								        assert all(self.env.observation_space(agent) == self.observation_space
								                   for agent in self.agents), \
								            "Observation spaces for all agents must be identical. Perhaps " \
								            "SuperSuit's pad_observations wrapper can help (usage: " \
								            "`supersuit.aec_wrappers.pad_observations(env)`)"

								        assert all(self.env.action_space(agent) == self.action_space
								                   for agent in self.agents), \
								            "Action spaces for all agents must be identical. Perhaps " \
								            "SuperSuit's pad_action_space wrapper can help (usage: " \
								            "`supersuit.aec_wrappers.pad_action_space(env)`)"

								        self.reset()

								    def reset(self, *args: Any, **kwargs: Any) -> Tuple[dict, dict]:
								        """Reset the wrapped env and return the first observation.

								        All arguments are forwarded unchanged to the wrapped env's ``reset``.

								        :return: ``(obs_dict, info)`` where ``obs_dict`` is produced by
								            :meth:`_build_obs_dict` from ``self.env.last()``.
								        """
								        self.env.reset(*args, **kwargs)
								        # Drop rewards recorded during the previous episode.
								        self.rewards = [0] * len(self.agents)
								        observation, _, _, _, info = self.env.last()
								        return self._build_obs_dict(observation), info

								    def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]:
								        """Apply ``action`` in the wrapped env.

								        :return: ``(obs_dict, rewards, terminated, truncated, info)``.
								            ``obs_dict`` is built for ``self.env.agent_selection`` after the
								            step; ``rewards`` lists the latest reward recorded for each agent,
								            ordered like ``self.agents``.
								        """
								        self.env.step(action)

								        observation, _, term, trunc, info = self.env.last()
								        obs = self._build_obs_dict(observation)

								        for agent_id, reward in self.env.rewards.items():
								            self.rewards[self.agent_idx[agent_id]] = reward

								        # Return a copy: ``self.rewards`` is updated in place on every step,
								        # so handing out the list itself would let later steps rewrite
								        # rewards a caller has already stored.
								        return obs, list(self.rewards), term, trunc, info

								    def _build_obs_dict(self, observation: Any) -> Dict[str, Any]:
								        """Pack a raw observation from ``self.env.last()`` into a dict.

								        The result always has ``agent_id`` (``self.env.agent_selection``) and
								        ``obs``. It also has ``mask`` (one bool per action) when the raw
								        observation is a dict with an ``action_mask`` entry, or when the action
								        space is :class:`spaces.Discrete`.
								        """
								        agent_id = self.env.agent_selection
								        if isinstance(observation, dict) and 'action_mask' in observation:
								            # Mask entries equal to 1 mark available actions; store plain bools.
								            return {
								                'agent_id': agent_id,
								                'obs': observation['observation'],
								                'mask': [bool(obm == 1) for obm in observation['action_mask']],
								            }
								        if isinstance(self.action_space, spaces.Discrete):
								            # No mask supplied by the env: mark every discrete action available.
								            return {
								                'agent_id': agent_id,
								                'obs': observation,
								                'mask': [True] * self.env.action_space(agent_id).n,
								            }
								        return {'agent_id': agent_id, 'obs': observation}

								    def close(self) -> None:
								        """Close the wrapped environment."""
								        self.env.close()

								    def seed(self, seed: Any = None) -> None:
								        """Seed the wrapped environment.

								        Falls back to ``self.env.reset(seed=seed)`` when the wrapped env has
								        no ``seed`` method or it raises ``NotImplementedError``.
								        """
								        try:
								            self.env.seed(seed)
								        except (NotImplementedError, AttributeError):
								            self.env.reset(seed=seed)

								    def render(self) -> Any:
								        """Return whatever the wrapped environment's ``render()`` returns."""
								        return self.env.render()
							 |