bugfixes: gym->gymnasium; render() update (#769)
Credits (names from the Farama Discord):

- @nrwahl2
- @APN-Pucky
- chattershuts
parent 06aaad460e
commit b9a6d8b5f0
@@ -1,6 +1,12 @@
 Multi-Agent RL
 ==============

+Tianshou uses the `PettingZoo` environment for multi-agent RL training. Here are some helpful tutorial links:
+
+* https://pettingzoo.farama.org/tutorials/tianshou/beginner/
+* https://pettingzoo.farama.org/tutorials/tianshou/intermediate/
+* https://pettingzoo.farama.org/tutorials/tianshou/advanced/
+
 In this section, we describe how to use Tianshou to implement multi-agent reinforcement learning. Specifically, we will design an algorithm to learn how to play `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ (see the image below) against a random opponent.

 .. image:: ../_static/images/tic-tac-toe.png
@@ -10,7 +16,7 @@ In this section, we describe how to use Tianshou to implement multi-agent reinfo
 Tic-Tac-Toe Environment
 -----------------------

-The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any [PettingZoo](https://www.pettingzoo.ml/) environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
+The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any `PettingZoo <https://www.pettingzoo.ml/>`_ environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
 ::

 >>> from tianshou.env import PettingZooEnv # wrapper for PettingZoo environments
@@ -19,7 +25,7 @@ The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~ti
 >>> # Players place 'x' and 'o' in turn on the board
 >>> # The player who first gets 3 consecutive 'x's or 'o's wins
 >>>
->>> env = PettingZooEnv(tictactoe_v3.env())
+>>> env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))
 >>> obs = env.reset()
 >>> env.render() # render the empty board
 board (step 0):
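For anyone trying the updated snippet outside the docs, here is a minimal, self-contained version of the session above. It assumes ``pettingzoo[classic]`` and a gymnasium-based Tianshou are installed; the calls mirror the documentation hunk rather than additional API:

    from pettingzoo.classic import tictactoe_v3

    from tianshou.env import PettingZooEnv

    # Under the new API the render mode is chosen when the environment is
    # created, instead of being passed to render() later.
    env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))
    obs = env.reset()
    env.render()  # prints the empty 3x3 board
    env.close()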
@@ -333,8 +339,8 @@ With the above preparation, we are close to the first learned agent. The followi

 ::

-    def get_env():
-        return PettingZooEnv(tictactoe_v3.env())
+    def get_env(render_mode=None):
+        return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))


     def train_agent(
@@ -427,7 +433,7 @@ With the above preparation, we are close to the first learned agent. The followi
         agent_learn: Optional[BasePolicy] = None,
         agent_opponent: Optional[BasePolicy] = None,
     ) -> None:
-        env = get_env()
+        env = get_env(render_mode="human")
         env = DummyVectorEnv([lambda: env])
         policy, optim, agents = get_agents(
             args, agent_learn=agent_learn, agent_opponent=agent_opponent
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# isort: skip_file

 import argparse
 import datetime
@@ -1,6 +1,7 @@
 import argparse
 import os
 from copy import deepcopy
+from functools import partial
 from typing import Optional, Tuple

 import gym
@@ -24,8 +25,8 @@ from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net


-def get_env():
-    return PettingZooEnv(tictactoe_v3.env())
+def get_env(render_mode: Optional[str] = None):
+    return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))


 def get_parser() -> argparse.ArgumentParser:
@@ -230,7 +231,7 @@ def watch(
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
 ) -> None:
-    env = DummyVectorEnv([get_env])
+    env = DummyVectorEnv([partial(get_env, render_mode="human")])
     policy, optim, agents = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent
     )
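The switch from a bare ``get_env`` reference to ``functools.partial`` turns the factory into a zero-argument callable with the render mode already bound, which is what the vector env expects from each entry in the list. A small illustrative sketch of the same pattern in isolation (assuming ``pettingzoo[classic]`` is installed; this is a sketch, not the full test script):

    from functools import partial

    from pettingzoo.classic import tictactoe_v3

    from tianshou.env import DummyVectorEnv, PettingZooEnv


    def get_env(render_mode=None):
        # Factory: each vectorized worker builds its own environment instance.
        return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))


    # partial(...) bakes render_mode into a no-argument callable.
    venv = DummyVectorEnv([partial(get_env, render_mode="human")])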
@@ -73,12 +73,10 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:

 # Note: object is used as a proxy for objects that can be pickled
 # Note: mypy does not support cyclic definition currently
-Hdf5ConvertibleValues = Union[  # type: ignore
-    int, float, Batch, np.ndarray, torch.Tensor, object,
-    'Hdf5ConvertibleType',  # type: ignore
-]
+Hdf5ConvertibleValues = Union[int, float, Batch, np.ndarray, torch.Tensor, object,
+                              "Hdf5ConvertibleType"]

-Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]


 def to_hdf5(
tianshou/env/pettingzoo_env.py (12 changed lines)
@@ -2,8 +2,8 @@ import warnings
 from abc import ABC
 from typing import Any, Dict, List, Tuple, Union

-import gym.spaces
 import pettingzoo
+from gymnasium import spaces
 from packaging import version
 from pettingzoo.utils.env import AECEnv
 from pettingzoo.utils.wrappers import BaseWrapper
@@ -27,7 +27,7 @@ class PettingZooEnv(AECEnv, ABC):
         # obs is a dict containing obs, agent_id, and mask
         obs = env.reset()
         action = policy(obs)
-        obs, rew, done, info = env.step(action)
+        obs, rew, trunc, term, info = env.step(action)
         env.close()

     The available action's mask is set to True, otherwise it is set to False.
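For context, the docstring now reflects the gymnasium-style five-value step. A minimal reference sketch with a plain gymnasium environment (not the PettingZoo wrapper above, whose exact return layout is defined by the wrapper itself):

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    obs, info = env.reset(seed=0)
    done = False
    while not done:
        action = env.action_space.sample()
        # gymnasium's step returns (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
    env.close()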
@@ -89,7 +89,7 @@ class PettingZooEnv(AECEnv, ABC):
                 [True if obm == 1 else False for obm in observation['action_mask']]
             }
         else:
-            if isinstance(self.action_space, gym.spaces.Discrete):
+            if isinstance(self.action_space, spaces.Discrete):
                 observation_dict = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
@@ -124,7 +124,7 @@ class PettingZooEnv(AECEnv, ABC):
                 [True if obm == 1 else False for obm in observation['action_mask']]
             }
         else:
-            if isinstance(self.action_space, gym.spaces.Discrete):
+            if isinstance(self.action_space, spaces.Discrete):
                 obs = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
@@ -146,5 +146,5 @@ class PettingZooEnv(AECEnv, ABC):
         except (NotImplementedError, AttributeError):
             self.env.reset(seed=seed)

-    def render(self, mode: str = "human") -> Any:
-        return self.env.render(mode)
+    def render(self) -> Any:
+        return self.env.render()
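The signature change follows gymnasium's convention that the render mode is fixed at construction time, so ``render()`` itself takes no arguments. A minimal sketch with a standard gymnasium environment, for comparison:

    import gymnasium as gym

    # The render mode is chosen once, when the environment is created.
    env = gym.make("CartPole-v1", render_mode="rgb_array")
    obs, info = env.reset()
    frame = env.render()  # returns an RGB array because of the mode above
    env.close()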
tianshou/env/worker/base.py (4 changed lines)
@@ -44,7 +44,7 @@ class EnvWorker(ABC):
             self.result = self.reset()
         else:
             self.is_reset = False
-            self.send_action(action)  # type: ignore
+            self.send_action(action)

     def recv(
         self
@@ -63,7 +63,7 @@ class EnvWorker(ABC):
                 "Please use send and recv for your own EnvWorker."
             )
         if not self.is_reset:
-            self.result = self.get_result()  # type: ignore
+            self.result = self.get_result()
         return self.result

     @abstractmethod
@@ -139,4 +139,4 @@ class LazyLogger(BaseLogger):
         pass

     def restore_data(self) -> Tuple[int, int, int]:
-        pass
+        return 0, 0, 0
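Returning ``0, 0, 0`` instead of ``pass`` makes the stub honour its declared ``Tuple[int, int, int]`` return type, so callers that unpack three values keep working. A hedged sketch of the idea (the class and caller variable names below are illustrative, not taken from this commit):

    from typing import Tuple


    class NoOpLogger:
        """Illustrative stub: records nothing but satisfies the logger interface."""

        def restore_data(self) -> Tuple[int, int, int]:
            # `pass` would return None, breaking callers that unpack three ints,
            # e.g. start_epoch, env_step, gradient_step = logger.restore_data()
            return 0, 0, 0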