Added support for new PettingZoo API (#751)
This commit is contained in:
parent
b0c8d28a7d
commit
128feb677f
@ -1,10 +1,8 @@
|
||||
import pprint
|
||||
|
||||
import pytest
|
||||
from pistonball import get_args, train_agent, watch
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO(Markus28): fix later")
|
||||
def test_piston_ball(args=get_args()):
|
||||
if args.watch:
|
||||
watch(args)
|
||||
|
@ -1,10 +1,8 @@
|
||||
import pprint
|
||||
|
||||
import pytest
|
||||
from tic_tac_toe import get_args, train_agent, watch
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="TODO(Markus28): fix later")
|
||||
def test_tic_tac_toe(args=get_args()):
|
||||
if args.watch:
|
||||
watch(args)
|
||||
|
40
tianshou/env/pettingzoo_env.py
vendored
40
tianshou/env/pettingzoo_env.py
vendored
@ -1,10 +1,20 @@
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
import gym.spaces
|
||||
import pettingzoo
|
||||
from packaging import version
|
||||
from pettingzoo.utils.env import AECEnv
|
||||
from pettingzoo.utils.wrappers import BaseWrapper
|
||||
|
||||
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
|
||||
warnings.warn(
|
||||
f"You are using PettingZoo {pettingzoo.__version__}. "
|
||||
f"Future tianshou versions may not support PettingZoo<1.21.0. "
|
||||
f"Consider upgrading your PettingZoo version.", DeprecationWarning
|
||||
)
|
||||
|
||||
|
||||
class PettingZooEnv(AECEnv, ABC):
|
||||
"""The interface for petting zoo environments.
|
||||
@ -57,7 +67,20 @@ class PettingZooEnv(AECEnv, ABC):
|
||||
|
||||
def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
|
||||
self.env.reset(*args, **kwargs)
|
||||
observation, _, _, info = self.env.last(self)
|
||||
|
||||
# Here, we do not label the return values explicitly to keep compatibility with
|
||||
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
|
||||
last_return = self.env.last(self)
|
||||
|
||||
if len(last_return) == 4:
|
||||
warnings.warn(
|
||||
"The PettingZoo environment is using the old step API. "
|
||||
"This API may not be supported in future versions of tianshou. "
|
||||
"We recommend that you update the environment code or apply a "
|
||||
"compatibility wrapper.", DeprecationWarning
|
||||
)
|
||||
|
||||
observation, info = last_return[0], last_return[-1]
|
||||
if isinstance(observation, dict) and 'action_mask' in observation:
|
||||
observation_dict = {
|
||||
'agent_id': self.env.agent_selection,
|
||||
@ -83,9 +106,16 @@ class PettingZooEnv(AECEnv, ABC):
|
||||
else:
|
||||
return observation_dict
|
||||
|
||||
def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
|
||||
def step(
|
||||
self, action: Any
|
||||
) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
|
||||
Dict]]:
|
||||
self.env.step(action)
|
||||
observation, rew, done, info = self.env.last()
|
||||
|
||||
# Here, we do not label the return values explicitly to keep compatibility with
|
||||
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
|
||||
last_return = self.env.last()
|
||||
observation = last_return[0]
|
||||
if isinstance(observation, dict) and 'action_mask' in observation:
|
||||
obs = {
|
||||
'agent_id': self.env.agent_selection,
|
||||
@ -105,7 +135,7 @@ class PettingZooEnv(AECEnv, ABC):
|
||||
|
||||
for agent_id, reward in self.env.rewards.items():
|
||||
self.rewards[self.agent_idx[agent_id]] = reward
|
||||
return obs, self.rewards, done, info
|
||||
return (obs, self.rewards, *last_return[2:]) # type: ignore
|
||||
|
||||
def close(self) -> None:
|
||||
self.env.close()
|
||||
@ -113,7 +143,7 @@ class PettingZooEnv(AECEnv, ABC):
|
||||
def seed(self, seed: Any = None) -> None:
|
||||
try:
|
||||
self.env.seed(seed)
|
||||
except NotImplementedError:
|
||||
except (NotImplementedError, AttributeError):
|
||||
self.env.reset(seed=seed)
|
||||
|
||||
def render(self, mode: str = "human") -> Any:
|
||||
|
Loading…
x
Reference in New Issue
Block a user