Added support for new PettingZoo API (#751)

This commit is contained in:
Markus Krimmel 2022-10-02 18:33:12 +02:00 committed by GitHub
parent b0c8d28a7d
commit 128feb677f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 9 deletions

View File

@ -1,10 +1,8 @@
import pprint
import pytest
from pistonball import get_args, train_agent, watch
@pytest.mark.skip(reason="TODO(Markus28): fix later")
def test_piston_ball(args=get_args()):
if args.watch:
watch(args)

View File

@ -1,10 +1,8 @@
import pprint
import pytest
from tic_tac_toe import get_args, train_agent, watch
@pytest.mark.skip(reason="TODO(Markus28): fix later")
def test_tic_tac_toe(args=get_args()):
if args.watch:
watch(args)

View File

@ -1,10 +1,20 @@
import warnings
from abc import ABC
from typing import Any, Dict, List, Tuple, Union
import gym.spaces
import pettingzoo
from packaging import version
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
warnings.warn(
f"You are using PettingZoo {pettingzoo.__version__}. "
f"Future tianshou versions may not support PettingZoo<1.21.0. "
f"Consider upgrading your PettingZoo version.", DeprecationWarning
)
class PettingZooEnv(AECEnv, ABC):
"""The interface for petting zoo environments.
@ -57,7 +67,20 @@ class PettingZooEnv(AECEnv, ABC):
def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
self.env.reset(*args, **kwargs)
observation, _, _, info = self.env.last(self)
# Here, we do not label the return values explicitly to keep compatibility with
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
last_return = self.env.last(self)
if len(last_return) == 4:
warnings.warn(
"The PettingZoo environment is using the old step API. "
"This API may not be supported in future versions of tianshou. "
"We recommend that you update the environment code or apply a "
"compatibility wrapper.", DeprecationWarning
)
observation, info = last_return[0], last_return[-1]
if isinstance(observation, dict) and 'action_mask' in observation:
observation_dict = {
'agent_id': self.env.agent_selection,
@ -83,9 +106,16 @@ class PettingZooEnv(AECEnv, ABC):
else:
return observation_dict
def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
def step(
self, action: Any
) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
Dict]]:
self.env.step(action)
observation, rew, done, info = self.env.last()
# Here, we do not label the return values explicitly to keep compatibility with
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
last_return = self.env.last()
observation = last_return[0]
if isinstance(observation, dict) and 'action_mask' in observation:
obs = {
'agent_id': self.env.agent_selection,
@ -105,7 +135,7 @@ class PettingZooEnv(AECEnv, ABC):
for agent_id, reward in self.env.rewards.items():
self.rewards[self.agent_idx[agent_id]] = reward
return obs, self.rewards, done, info
return (obs, self.rewards, *last_return[2:]) # type: ignore
def close(self) -> None:
self.env.close()
@ -113,7 +143,7 @@ class PettingZooEnv(AECEnv, ABC):
def seed(self, seed: Any = None) -> None:
try:
self.env.seed(seed)
except NotImplementedError:
except (NotImplementedError, AttributeError):
self.env.reset(seed=seed)
def render(self, mode: str = "human") -> Any: