bugfixes: gym->gymnasium; render() update (#769)

Credits (names from the Farama Discord):

- @nrwahl2
- @APN-Pucky
- chattershuts
Author: Will Dudley
Committed: 2022-11-11 20:25:35 +00:00 (via GitHub)
Commit: b9a6d8b5f0
Parent: 06aaad460e
7 changed files with 28 additions and 22 deletions
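
Review note (illustrative, not part of the diff below): the behavioural change being adapted to is that gymnasium-style environments fix the render mode at construction time instead of per render() call. A minimal sketch of the before/after usage, assuming pettingzoo and gymnasium are installed:

    # old gym-style pattern (what the docs and tests used before this commit):
    #   env = PettingZooEnv(tictactoe_v3.env())
    #   env.render(mode="human")          # mode chosen at call time
    # new gymnasium-style pattern (what the hunks below switch to):
    from pettingzoo.classic import tictactoe_v3
    from tianshou.env import PettingZooEnv

    env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))  # mode fixed up front
    obs = env.reset()
    env.render()  # no arguments; uses the render_mode given at construction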

View File

@@ -1,6 +1,12 @@
 Multi-Agent RL
 ==============
 
+Tianshou use `PettingZoo` environment for multi-agent RL training. Here are some helpful tutorial links:
+
+* https://pettingzoo.farama.org/tutorials/tianshou/beginner/
+* https://pettingzoo.farama.org/tutorials/tianshou/intermediate/
+* https://pettingzoo.farama.org/tutorials/tianshou/advanced/
+
 In this section, we describe how to use Tianshou to implement multi-agent reinforcement learning. Specifically, we will design an algorithm to learn how to play `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ (see the image below) against a random opponent.
 
 .. image:: ../_static/images/tic-tac-toe.png
@@ -10,7 +16,7 @@ In this section, we describe how to use Tianshou to implement multi-agent reinfo
 Tic-Tac-Toe Environment
 -----------------------
-The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any [PettingZoo](https://www.pettingzoo.ml/) environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
+The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any `PettingZoo <https://www.pettingzoo.ml/>`_ environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
 
 ::
 
     >>> from tianshou.env import PettingZooEnv  # wrapper for PettingZoo environments
@@ -19,7 +25,7 @@ The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~ti
     >>> # Players place 'x' and 'o' in turn on the board
     >>> # The player who first gets 3 consecutive 'x's or 'o's wins
     >>>
-    >>> env = PettingZooEnv(tictactoe_v3.env())
+    >>> env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))
     >>> obs = env.reset()
     >>> env.render()  # render the empty board
     board (step 0):
@@ -333,8 +339,8 @@ With the above preparation, we are close to the first learned agent. The followi
 ::
 
-    def get_env():
-        return PettingZooEnv(tictactoe_v3.env())
+    def get_env(render_mode=None):
+        return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))
 
 
     def train_agent(
@@ -427,7 +433,7 @@ With the above preparation, we are close to the first learned agent. The followi
         agent_learn: Optional[BasePolicy] = None,
         agent_opponent: Optional[BasePolicy] = None,
     ) -> None:
-        env = get_env()
+        env = get_env(render_mode="human")
         env = DummyVectorEnv([lambda: env])
         policy, optim, agents = get_agents(
             args, agent_learn=agent_learn, agent_opponent=agent_opponent

View File

@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# isort: skip_file
 
 import argparse
 import datetime

View File

@@ -1,6 +1,7 @@
 import argparse
 import os
 from copy import deepcopy
+from functools import partial
 from typing import Optional, Tuple
 
 import gym
@@ -24,8 +25,8 @@ from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 
 
-def get_env():
-    return PettingZooEnv(tictactoe_v3.env())
+def get_env(render_mode: Optional[str] = None):
+    return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))
 
 
 def get_parser() -> argparse.ArgumentParser:
@@ -230,7 +231,7 @@ def watch(
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
 ) -> None:
-    env = DummyVectorEnv([get_env])
+    env = DummyVectorEnv([partial(get_env, render_mode="human")])
     policy, optim, agents = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent
     )
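
Aside (sketch, not from the commit): DummyVectorEnv takes a list of zero-argument environment factories, so the render mode has to be bound into the factory itself; functools.partial does that without defining a wrapper function. Roughly, using the get_env defined in this file:

    from functools import partial

    factory = partial(get_env, render_mode="human")   # what the diff uses
    # equivalent alternative with a lambda:
    # factory = lambda: get_env(render_mode="human")
    # env = DummyVectorEnv([factory])  # each worker calls factory() to build its own env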

View File

@@ -73,12 +73,10 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
 
 # Note: object is used as a proxy for objects that can be pickled
 # Note: mypy does not support cyclic definition currently
-Hdf5ConvertibleValues = Union[  # type: ignore
-    int, float, Batch, np.ndarray, torch.Tensor, object,
-    'Hdf5ConvertibleType',  # type: ignore
-]
+Hdf5ConvertibleValues = Union[int, float, Batch, np.ndarray, torch.Tensor, object,
+                              "Hdf5ConvertibleType"]
 
-Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]
 
 
 def to_hdf5(
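
Aside (illustrative only): the reformatted aliases still describe the same recursive structure, a str-keyed dict whose values are ints, floats, arrays, tensors, Batches, picklable objects, or further such dicts. A hedged usage sketch, assuming the to_hdf5/from_hdf5 helpers in this converter module keep their h5py-group-based signatures:

    import h5py
    import numpy as np

    from tianshou.data import Batch
    from tianshou.data.utils.converter import from_hdf5, to_hdf5

    data = {"obs": np.zeros((4, 3)), "meta": {"step": 7}, "batch": Batch(a=np.ones(2))}
    with h5py.File("checkpoint.h5", "w") as f:
        to_hdf5(data, f)  # every value must be Hdf5Convertible
    with h5py.File("checkpoint.h5", "r") as f:
        restored = from_hdf5(f)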

View File

@@ -2,8 +2,8 @@ import warnings
 from abc import ABC
 from typing import Any, Dict, List, Tuple, Union
 
-import gym.spaces
 import pettingzoo
+from gymnasium import spaces
 from packaging import version
 from pettingzoo.utils.env import AECEnv
 from pettingzoo.utils.wrappers import BaseWrapper
@@ -27,7 +27,7 @@ class PettingZooEnv(AECEnv, ABC):
         # obs is a dict containing obs, agent_id, and mask
         obs = env.reset()
         action = policy(obs)
-        obs, rew, done, info = env.step(action)
+        obs, rew, trunc, term, info = env.step(action)
         env.close()
 
     The available action's mask is set to True, otherwise it is set to False.
@@ -89,7 +89,7 @@ class PettingZooEnv(AECEnv, ABC):
                 [True if obm == 1 else False for obm in observation['action_mask']]
             }
         else:
-            if isinstance(self.action_space, gym.spaces.Discrete):
+            if isinstance(self.action_space, spaces.Discrete):
                 observation_dict = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
@@ -124,7 +124,7 @@ class PettingZooEnv(AECEnv, ABC):
                 [True if obm == 1 else False for obm in observation['action_mask']]
             }
         else:
-            if isinstance(self.action_space, gym.spaces.Discrete):
+            if isinstance(self.action_space, spaces.Discrete):
                 obs = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
@@ -146,5 +146,5 @@ class PettingZooEnv(AECEnv, ABC):
         except (NotImplementedError, AttributeError):
             self.env.reset(seed=seed)
 
-    def render(self, mode: str = "human") -> Any:
-        return self.env.render(mode)
+    def render(self) -> Any:
+        return self.env.render()
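
Review note (sketch, not from the diff): render() no longer accepts a mode argument, so callers bake the mode into the wrapped PettingZoo environment at construction; the dict observation documented above is unchanged. For a masked game such as tic-tac-toe it still looks roughly like:

    # illustrative contents of the dict returned for an empty tic-tac-toe board;
    # agent ids and mask length come from the wrapped PettingZoo environment
    obs = {
        "agent_id": "player_1",
        "obs": ...,               # raw board observation from PettingZoo
        "mask": [True] * 9,       # True for each of the 9 legal opening moves
    }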

View File

@@ -44,7 +44,7 @@ class EnvWorker(ABC):
             self.result = self.reset()
         else:
             self.is_reset = False
-            self.send_action(action)  # type: ignore
+            self.send_action(action)
 
     def recv(
         self
@@ -63,7 +63,7 @@ class EnvWorker(ABC):
             "Please use send and recv for your own EnvWorker."
         )
         if not self.is_reset:
-            self.result = self.get_result()  # type: ignore
+            self.result = self.get_result()
         return self.result
 
     @abstractmethod

View File

@@ -139,4 +139,4 @@ class LazyLogger(BaseLogger):
         pass
 
     def restore_data(self) -> Tuple[int, int, int]:
-        pass
+        return 0, 0, 0
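
Aside (sketch, not part of the diff): restore_data is annotated as returning Tuple[int, int, int], and resuming code typically unpacks those three counters, so the old pass body (which returned None) would break as soon as the result is unpacked. Assuming such a call site:

    from tianshou.utils import LazyLogger

    logger = LazyLogger()
    start_epoch, env_step, gradient_step = logger.restore_data()  # now (0, 0, 0)
    # with the old `pass` body this line would raise:
    # TypeError: cannot unpack non-iterable NoneType object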