Preparation for #914 and #920 Changes formatting to ruff and black. Remove python 3.8 ## Additional Changes - Removed flake8 dependencies - Adjusted pre-commit. Now CI and Make use pre-commit, reducing the duplication of linting calls - Removed check-docstyle option (ruff is doing that) - Merged format and lint. In CI the format-lint step fails if any changes are done, so it fulfills the lint functionality. --------- Co-authored-by: Jiayi Weng <jiayi@openai.com>
		
			
				
	
	
		
			133 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			133 lines
		
	
	
		
			4.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import warnings
 | |
| from abc import ABC
 | |
| from typing import Any
 | |
| 
 | |
| import pettingzoo
 | |
| from gymnasium import spaces
 | |
| from packaging import version
 | |
| from pettingzoo.utils.env import AECEnv
 | |
| from pettingzoo.utils.wrappers import BaseWrapper
 | |
| 
 | |
| if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
 | |
|     warnings.warn(
 | |
|         f"You are using PettingZoo {pettingzoo.__version__}. "
 | |
|         f"Future tianshou versions may not support PettingZoo<1.21.0. "
 | |
|         f"Consider upgrading your PettingZoo version.",
 | |
|         DeprecationWarning,
 | |
|     )
 | |
| 
 | |
| 
 | |
| class PettingZooEnv(AECEnv, ABC):
 | |
|     """The interface for petting zoo environments.
 | |
| 
 | |
|     Multi-agent environments must be wrapped as
 | |
|     :class:`~tianshou.env.PettingZooEnv`. Here is the usage:
 | |
|     ::
 | |
| 
 | |
|         env = PettingZooEnv(...)
 | |
|         # obs is a dict containing obs, agent_id, and mask
 | |
|         obs = env.reset()
 | |
|         action = policy(obs)
 | |
|         obs, rew, trunc, term, info = env.step(action)
 | |
|         env.close()
 | |
| 
 | |
|     The available action's mask is set to True, otherwise it is set to False.
 | |
|     Further usage can be found at :ref:`marl_example`.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, env: BaseWrapper):
 | |
|         super().__init__()
 | |
|         self.env = env
 | |
|         # agent idx list
 | |
|         self.agents = self.env.possible_agents
 | |
|         self.agent_idx = {}
 | |
|         for i, agent_id in enumerate(self.agents):
 | |
|             self.agent_idx[agent_id] = i
 | |
| 
 | |
|         self.rewards = [0] * len(self.agents)
 | |
| 
 | |
|         # Get first observation space, assuming all agents have equal space
 | |
|         self.observation_space: Any = self.env.observation_space(self.agents[0])
 | |
| 
 | |
|         # Get first action space, assuming all agents have equal space
 | |
|         self.action_space: Any = self.env.action_space(self.agents[0])
 | |
| 
 | |
|         assert all(
 | |
|             self.env.observation_space(agent) == self.observation_space for agent in self.agents
 | |
|         ), (
 | |
|             "Observation spaces for all agents must be identical. Perhaps "
 | |
|             "SuperSuit's pad_observations wrapper can help (useage: "
 | |
|             "`supersuit.aec_wrappers.pad_observations(env)`"
 | |
|         )
 | |
| 
 | |
|         assert all(self.env.action_space(agent) == self.action_space for agent in self.agents), (
 | |
|             "Action spaces for all agents must be identical. Perhaps "
 | |
|             "SuperSuit's pad_action_space wrapper can help (useage: "
 | |
|             "`supersuit.aec_wrappers.pad_action_space(env)`"
 | |
|         )
 | |
| 
 | |
|         self.reset()
 | |
| 
 | |
|     def reset(self, *args: Any, **kwargs: Any) -> tuple[dict, dict]:
 | |
|         self.env.reset(*args, **kwargs)
 | |
| 
 | |
|         observation, reward, terminated, truncated, info = self.env.last(self)
 | |
| 
 | |
|         if isinstance(observation, dict) and "action_mask" in observation:
 | |
|             observation_dict = {
 | |
|                 "agent_id": self.env.agent_selection,
 | |
|                 "obs": observation["observation"],
 | |
|                 "mask": [obm == 1 for obm in observation["action_mask"]],
 | |
|             }
 | |
|         else:
 | |
|             if isinstance(self.action_space, spaces.Discrete):
 | |
|                 observation_dict = {
 | |
|                     "agent_id": self.env.agent_selection,
 | |
|                     "obs": observation,
 | |
|                     "mask": [True] * self.env.action_space(self.env.agent_selection).n,
 | |
|                 }
 | |
|             else:
 | |
|                 observation_dict = {
 | |
|                     "agent_id": self.env.agent_selection,
 | |
|                     "obs": observation,
 | |
|                 }
 | |
| 
 | |
|         return observation_dict, info
 | |
| 
 | |
|     def step(self, action: Any) -> tuple[dict, list[int], bool, bool, dict]:
 | |
|         self.env.step(action)
 | |
| 
 | |
|         observation, rew, term, trunc, info = self.env.last()
 | |
| 
 | |
|         if isinstance(observation, dict) and "action_mask" in observation:
 | |
|             obs = {
 | |
|                 "agent_id": self.env.agent_selection,
 | |
|                 "obs": observation["observation"],
 | |
|                 "mask": [obm == 1 for obm in observation["action_mask"]],
 | |
|             }
 | |
|         else:
 | |
|             if isinstance(self.action_space, spaces.Discrete):
 | |
|                 obs = {
 | |
|                     "agent_id": self.env.agent_selection,
 | |
|                     "obs": observation,
 | |
|                     "mask": [True] * self.env.action_space(self.env.agent_selection).n,
 | |
|                 }
 | |
|             else:
 | |
|                 obs = {"agent_id": self.env.agent_selection, "obs": observation}
 | |
| 
 | |
|         for agent_id, reward in self.env.rewards.items():
 | |
|             self.rewards[self.agent_idx[agent_id]] = reward
 | |
|         return obs, self.rewards, term, trunc, info
 | |
| 
 | |
|     def close(self) -> None:
 | |
|         self.env.close()
 | |
| 
 | |
|     def seed(self, seed: Any = None) -> None:
 | |
|         try:
 | |
|             self.env.seed(seed)
 | |
|         except (NotImplementedError, AttributeError):
 | |
|             self.env.reset(seed=seed)
 | |
| 
 | |
|     def render(self) -> Any:
 | |
|         return self.env.render()
 |