bugfixes: gym->gymnasium; render() update (#769)
Credits (names from the Farama Discord):
- @nrwahl2
- @APN-Pucky
- chattershuts

parent 06aaad460e
commit b9a6d8b5f0
@@ -1,6 +1,12 @@
Multi-Agent RL
==============
Tianshou uses `PettingZoo` environments for multi-agent RL training. Here are some helpful tutorial links:
* https://pettingzoo.farama.org/tutorials/tianshou/beginner/
* https://pettingzoo.farama.org/tutorials/tianshou/intermediate/
* https://pettingzoo.farama.org/tutorials/tianshou/advanced/
In this section, we describe how to use Tianshou to implement multi-agent reinforcement learning. Specifically, we will design an algorithm to learn how to play `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ (see the image below) against a random opponent.
.. image:: ../_static/images/tic-tac-toe.png
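
As a quick preview, here is a minimal sketch of the wrap-and-vectorize pattern the updated tutorial below relies on; the ``get_env`` helper and the ``render_mode`` argument mirror the changes introduced in this commit, and the snippet is illustrative rather than part of the original file::

    from functools import partial

    from pettingzoo.classic import tictactoe_v3
    from tianshou.env import DummyVectorEnv, PettingZooEnv

    def get_env(render_mode=None):
        # wrap the PettingZoo AEC environment so Tianshou can consume it
        return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))

    # DummyVectorEnv expects zero-argument environment factories
    envs = DummyVectorEnv([partial(get_env, render_mode="human")])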
@@ -10,7 +16,7 @@ In this section, we describe how to use Tianshou to implement multi-agent reinfo
Tic-Tac-Toe Environment
-----------------------
-The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any [PettingZoo](https://www.pettingzoo.ml/) environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
+The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any `PettingZoo <https://www.pettingzoo.ml/>`_ environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
::
>>> from tianshou.env import PettingZooEnv # wrapper for PettingZoo environments
@@ -19,7 +25,7 @@ The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~ti
>>> # Players place 'x' and 'o' in turn on the board
>>> # The player who first gets 3 consecutive 'x's or 'o's wins
>>>
->>> env = PettingZooEnv(tictactoe_v3.env())
+>>> env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))
>>> obs = env.reset()
>>> env.render() # render the empty board
board (step 0):
@@ -333,8 +339,8 @@ With the above preparation, we are close to the first learned agent. The followi
::
-def get_env():
-    return PettingZooEnv(tictactoe_v3.env())
+def get_env(render_mode=None):
+    return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))

def train_agent(
@@ -427,7 +433,7 @@ With the above preparation, we are close to the first learned agent. The followi
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
-env = get_env()
+env = get_env(render_mode="human")
env = DummyVectorEnv([lambda: env])
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# isort: skip_file
import argparse
import datetime
@@ -1,6 +1,7 @@
import argparse
import os
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple
import gym
@@ -24,8 +25,8 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net

-def get_env():
-    return PettingZooEnv(tictactoe_v3.env())
+def get_env(render_mode: Optional[str] = None):
+    return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))

def get_parser() -> argparse.ArgumentParser:
@@ -230,7 +231,7 @@ def watch(
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
-env = DummyVectorEnv([get_env])
+env = DummyVectorEnv([partial(get_env, render_mode="human")])
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
)
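
A note on the partial call above: DummyVectorEnv invokes each entry with no arguments, so render_mode has to be bound when the factory is built. A minimal sketch of two equivalent ways to do that, assuming the get_env helper defined earlier in this file:

    from functools import partial

    # both factories take no arguments and build a human-rendered env when called
    factory = partial(get_env, render_mode="human")
    # factory = lambda: get_env(render_mode="human")  # equivalent here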
@@ -73,12 +73,10 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
# Note: object is used as a proxy for objects that can be pickled
# Note: mypy does not support cyclic definition currently
-Hdf5ConvertibleValues = Union[  # type: ignore
-    int, float, Batch, np.ndarray, torch.Tensor, object,
-    'Hdf5ConvertibleType',  # type: ignore
-]
+Hdf5ConvertibleValues = Union[int, float, Batch, np.ndarray, torch.Tensor, object,
+                              "Hdf5ConvertibleType"]

-Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]

def to_hdf5(
tianshou/env/pettingzoo_env.py

@@ -2,8 +2,8 @@ import warnings
from abc import ABC
from typing import Any, Dict, List, Tuple, Union
-import gym.spaces
import pettingzoo
+from gymnasium import spaces
from packaging import version
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper
@@ -27,7 +27,7 @@ class PettingZooEnv(AECEnv, ABC):
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
action = policy(obs)
-obs, rew, done, info = env.step(action)
+obs, rew, trunc, term, info = env.step(action)
env.close()
The mask entry of each available action is set to True; all other entries are set to False.
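
For concreteness, a hedged sketch of the dict this wrapper returns for the acting agent in Tic-Tac-Toe (the keys follow the docstring above; the values shown are illustrative only):

    import numpy as np

    example_obs = {
        "agent_id": "player_1",                     # whose turn it is
        "obs": np.zeros((3, 3, 2), dtype=np.int8),  # underlying PettingZoo observation (empty board)
        "mask": [True] * 9,                         # True wherever placing a mark is legal
    }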
@@ -89,7 +89,7 @@ class PettingZooEnv(AECEnv, ABC):
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
-if isinstance(self.action_space, gym.spaces.Discrete):
+if isinstance(self.action_space, spaces.Discrete):
observation_dict = {
'agent_id': self.env.agent_selection,
'obs': observation,
@@ -124,7 +124,7 @@ class PettingZooEnv(AECEnv, ABC):
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
-if isinstance(self.action_space, gym.spaces.Discrete):
+if isinstance(self.action_space, spaces.Discrete):
obs = {
'agent_id': self.env.agent_selection,
'obs': observation,
@@ -146,5 +146,5 @@ class PettingZooEnv(AECEnv, ABC):
except (NotImplementedError, AttributeError):
self.env.reset(seed=seed)
-def render(self, mode: str = "human") -> Any:
-    return self.env.render(mode)
+def render(self) -> Any:
+    return self.env.render()
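
The render() change follows the Gymnasium rendering API: the render mode is fixed when the environment is constructed instead of being passed to every render() call. A minimal sketch of the resulting usage, mirroring the tutorial snippet earlier in this diff:

    from pettingzoo.classic import tictactoe_v3
    from tianshou.env import PettingZooEnv

    env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))  # mode chosen up front
    env.reset()
    env.render()  # no mode argument anymore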
tianshou/env/worker/base.py

@@ -44,7 +44,7 @@ class EnvWorker(ABC):
self.result = self.reset()
else:
self.is_reset = False
-self.send_action(action)  # type: ignore
+self.send_action(action)
def recv(
self
@@ -63,7 +63,7 @@
"Please use send and recv for your own EnvWorker."
)
if not self.is_reset:
-self.result = self.get_result()  # type: ignore
+self.result = self.get_result()
return self.result
@abstractmethod
@@ -139,4 +139,4 @@ class LazyLogger(BaseLogger):
pass
def restore_data(self) -> Tuple[int, int, int]:
-    pass
+    return 0, 0, 0
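
LazyLogger.restore_data is annotated to return Tuple[int, int, int], so returning zeros keeps the no-op logger consistent with its declared signature instead of silently returning None. A hedged sketch of how a caller might unpack it (variable names are illustrative, not taken from this diff):

    from tianshou.utils import LazyLogger

    logger = LazyLogger()
    start_epoch, env_step, gradient_step = logger.restore_data()  # (0, 0, 0) for the no-op logger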