2022-10-02 18:33:12 +02:00
|
|
|
import warnings
|
2022-02-15 17:56:45 +03:00
|
|
|
from abc import ABC
|
2023-02-03 20:57:27 +01:00
|
|
|
from typing import Any, Dict, List, Tuple
|
2022-02-15 17:56:45 +03:00
|
|
|
|
2022-10-02 18:33:12 +02:00
|
|
|
import pettingzoo
|
2022-11-11 20:25:35 +00:00
|
|
|
from gymnasium import spaces
|
2022-10-02 18:33:12 +02:00
|
|
|
from packaging import version
|
2022-02-15 17:56:45 +03:00
|
|
|
from pettingzoo.utils.env import AECEnv
|
|
|
|
from pettingzoo.utils.wrappers import BaseWrapper
|
|
|
|
|
2022-10-02 18:33:12 +02:00
|
|
|
# Emit a deprecation notice at import time when the installed PettingZoo
# release is older than 1.21.0.
if version.parse("1.21.0") > version.parse(pettingzoo.__version__):
    warnings.warn(
        f"You are using PettingZoo {pettingzoo.__version__}. "
        "Future tianshou versions may not support PettingZoo<1.21.0. "
        "Consider upgrading your PettingZoo version.",
        category=DeprecationWarning,
    )
|
|
|
|
|
2022-02-15 17:56:45 +03:00
|
|
|
|
2022-03-16 14:38:51 +01:00
|
|
|
class PettingZooEnv(AECEnv, ABC):
    """The interface for petting zoo environments.

    Multi-agent environments must be wrapped as
    :class:`~tianshou.env.PettingZooEnv`. Here is the usage:
    ::

        env = PettingZooEnv(...)
        # obs is a dict containing obs, agent_id, and (when available) mask
        obs, info = env.reset()
        action = policy(obs)
        obs, rew, term, trunc, info = env.step(action)
        env.close()

    The available action's mask is set to True, otherwise it is set to False.
    Further usage can be found at :ref:`marl_example`.
    """

    def __init__(self, env: BaseWrapper):
        """Wrap ``env`` and validate that all agents share identical spaces.

        :param env: the PettingZoo environment to wrap. It is reset once at
            the end of construction.
        :raises AssertionError: if the observation spaces or the action spaces
            differ between agents.
        """
        super().__init__()
        self.env = env
        # agent idx list: maps every agent id to its position in self.agents,
        # which is also its position in self.rewards
        self.agents = self.env.possible_agents
        self.agent_idx = {agent_id: i for i, agent_id in enumerate(self.agents)}

        self.rewards = [0] * len(self.agents)

        # Get first observation space, assuming all agents have equal space
        self.observation_space: Any = self.env.observation_space(self.agents[0])

        # Get first action space, assuming all agents have equal space
        self.action_space: Any = self.env.action_space(self.agents[0])

        assert all(
            self.env.observation_space(agent) == self.observation_space
            for agent in self.agents
        ), (
            "Observation spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_observations wrapper can help (usage: "
            "`supersuit.aec_wrappers.pad_observations(env)`)"
        )

        assert all(
            self.env.action_space(agent) == self.action_space
            for agent in self.agents
        ), (
            "Action spaces for all agents must be identical. Perhaps "
            "SuperSuit's pad_action_space wrapper can help (usage: "
            "`supersuit.aec_wrappers.pad_action_space(env)`)"
        )

        self.reset()

    def _build_observation(self, observation: Any) -> dict:
        """Convert a raw observation of the selected agent into a dict.

        :param observation: the observation returned by ``self.env.last()``.
        :return: a dict with the keys ``agent_id`` and ``obs``. A ``mask`` key
            (list of bool) is added when the raw observation carries an
            ``action_mask`` entry or when the action space is discrete.
        """
        agent_id = self.env.agent_selection
        if isinstance(observation, dict) and 'action_mask' in observation:
            # The environment supplies its own mask: entries equal to 1 mark
            # available actions.
            return {
                'agent_id': agent_id,
                'obs': observation['observation'],
                'mask': [bool(obm == 1) for obm in observation['action_mask']],
            }
        if isinstance(self.action_space, spaces.Discrete):
            # No mask supplied: every discrete action is marked as available.
            return {
                'agent_id': agent_id,
                'obs': observation,
                'mask': [True] * self.env.action_space(agent_id).n,
            }
        return {'agent_id': agent_id, 'obs': observation}

    def reset(self, *args: Any, **kwargs: Any) -> Tuple[dict, dict]:
        """Reset the wrapped environment.

        All positional and keyword arguments are forwarded to
        ``self.env.reset``.

        :return: a tuple ``(obs, info)`` where ``obs`` is the observation dict
            of the currently selected agent and ``info`` comes from
            ``self.env.last()``.
        """
        self.env.reset(*args, **kwargs)

        # Only the observation and the info are needed after a reset.
        observation, _, _, _, info = self.env.last()

        return self._build_observation(observation), info

    def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]:
        """Apply ``action`` in the wrapped environment.

        :param action: the action forwarded to ``self.env.step``.
        :return: a tuple ``(obs, rewards, terminated, truncated, info)``.
            ``obs`` is the observation dict of the currently selected agent
            and ``rewards`` is ``self.rewards``, a list ordered like
            ``self.agents``. The same list object is updated in place and
            returned on every call.
        """
        self.env.step(action)

        observation, rew, term, trunc, info = self.env.last()

        obs = self._build_observation(observation)

        # Take the rewards of all agents from the wrapped env rather than the
        # single value returned by last().
        for agent_id, reward in self.env.rewards.items():
            self.rewards[self.agent_idx[agent_id]] = reward
        return obs, self.rewards, term, trunc, info

    def close(self) -> None:
        """Close the wrapped environment."""
        self.env.close()

    def seed(self, seed: Any = None) -> None:
        """Seed the wrapped environment.

        ``self.env.seed(seed)`` is tried first. If it raises
        ``NotImplementedError`` or ``AttributeError``, the seed is passed via
        ``self.env.reset(seed=seed)`` instead, which also resets the
        environment.

        :param seed: the seed forwarded to the wrapped environment.
        """
        try:
            self.env.seed(seed)
        except (NotImplementedError, AttributeError):
            self.env.reset(seed=seed)

    def render(self) -> Any:
        """Render the wrapped environment and return its result."""
        return self.env.render()
|