Added support for new PettingZoo API (#751)
This commit is contained in:
parent
b0c8d28a7d
commit
128feb677f
@ -1,10 +1,8 @@
|
|||||||
import pprint
|
import pprint
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pistonball import get_args, train_agent, watch
|
from pistonball import get_args, train_agent, watch
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TODO(Markus28): fix later")
|
|
||||||
def test_piston_ball(args=get_args()):
|
def test_piston_ball(args=get_args()):
|
||||||
if args.watch:
|
if args.watch:
|
||||||
watch(args)
|
watch(args)
|
||||||
|
@ -1,10 +1,8 @@
|
|||||||
import pprint
|
import pprint
|
||||||
|
|
||||||
import pytest
|
|
||||||
from tic_tac_toe import get_args, train_agent, watch
|
from tic_tac_toe import get_args, train_agent, watch
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TODO(Markus28): fix later")
|
|
||||||
def test_tic_tac_toe(args=get_args()):
|
def test_tic_tac_toe(args=get_args()):
|
||||||
if args.watch:
|
if args.watch:
|
||||||
watch(args)
|
watch(args)
|
||||||
|
40
tianshou/env/pettingzoo_env.py
vendored
40
tianshou/env/pettingzoo_env.py
vendored
@ -1,10 +1,20 @@
|
|||||||
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Any, Dict, List, Tuple, Union
|
from typing import Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
import gym.spaces
|
import gym.spaces
|
||||||
|
import pettingzoo
|
||||||
|
from packaging import version
|
||||||
from pettingzoo.utils.env import AECEnv
|
from pettingzoo.utils.env import AECEnv
|
||||||
from pettingzoo.utils.wrappers import BaseWrapper
|
from pettingzoo.utils.wrappers import BaseWrapper
|
||||||
|
|
||||||
|
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
|
||||||
|
warnings.warn(
|
||||||
|
f"You are using PettingZoo {pettingzoo.__version__}. "
|
||||||
|
f"Future tianshou versions may not support PettingZoo<1.21.0. "
|
||||||
|
f"Consider upgrading your PettingZoo version.", DeprecationWarning
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PettingZooEnv(AECEnv, ABC):
|
class PettingZooEnv(AECEnv, ABC):
|
||||||
"""The interface for petting zoo environments.
|
"""The interface for petting zoo environments.
|
||||||
@ -57,7 +67,20 @@ class PettingZooEnv(AECEnv, ABC):
|
|||||||
|
|
||||||
def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
|
def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
|
||||||
self.env.reset(*args, **kwargs)
|
self.env.reset(*args, **kwargs)
|
||||||
observation, _, _, info = self.env.last(self)
|
|
||||||
|
# Here, we do not label the return values explicitly to keep compatibility with
|
||||||
|
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
|
||||||
|
last_return = self.env.last(self)
|
||||||
|
|
||||||
|
if len(last_return) == 4:
|
||||||
|
warnings.warn(
|
||||||
|
"The PettingZoo environment is using the old step API. "
|
||||||
|
"This API may not be supported in future versions of tianshou. "
|
||||||
|
"We recommend that you update the environment code or apply a "
|
||||||
|
"compatibility wrapper.", DeprecationWarning
|
||||||
|
)
|
||||||
|
|
||||||
|
observation, info = last_return[0], last_return[-1]
|
||||||
if isinstance(observation, dict) and 'action_mask' in observation:
|
if isinstance(observation, dict) and 'action_mask' in observation:
|
||||||
observation_dict = {
|
observation_dict = {
|
||||||
'agent_id': self.env.agent_selection,
|
'agent_id': self.env.agent_selection,
|
||||||
@ -83,9 +106,16 @@ class PettingZooEnv(AECEnv, ABC):
|
|||||||
else:
|
else:
|
||||||
return observation_dict
|
return observation_dict
|
||||||
|
|
||||||
def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
|
def step(
|
||||||
|
self, action: Any
|
||||||
|
) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
|
||||||
|
Dict]]:
|
||||||
self.env.step(action)
|
self.env.step(action)
|
||||||
observation, rew, done, info = self.env.last()
|
|
||||||
|
# Here, we do not label the return values explicitly to keep compatibility with
|
||||||
|
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
|
||||||
|
last_return = self.env.last()
|
||||||
|
observation = last_return[0]
|
||||||
if isinstance(observation, dict) and 'action_mask' in observation:
|
if isinstance(observation, dict) and 'action_mask' in observation:
|
||||||
obs = {
|
obs = {
|
||||||
'agent_id': self.env.agent_selection,
|
'agent_id': self.env.agent_selection,
|
||||||
@ -105,7 +135,7 @@ class PettingZooEnv(AECEnv, ABC):
|
|||||||
|
|
||||||
for agent_id, reward in self.env.rewards.items():
|
for agent_id, reward in self.env.rewards.items():
|
||||||
self.rewards[self.agent_idx[agent_id]] = reward
|
self.rewards[self.agent_idx[agent_id]] = reward
|
||||||
return obs, self.rewards, done, info
|
return (obs, self.rewards, *last_return[2:]) # type: ignore
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
self.env.close()
|
self.env.close()
|
||||||
@ -113,7 +143,7 @@ class PettingZooEnv(AECEnv, ABC):
|
|||||||
def seed(self, seed: Any = None) -> None:
|
def seed(self, seed: Any = None) -> None:
|
||||||
try:
|
try:
|
||||||
self.env.seed(seed)
|
self.env.seed(seed)
|
||||||
except NotImplementedError:
|
except (NotImplementedError, AttributeError):
|
||||||
self.env.reset(seed=seed)
|
self.env.reset(seed=seed)
|
||||||
|
|
||||||
def render(self, mode: str = "human") -> Any:
|
def render(self, mode: str = "human") -> Any:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user