bugfixes: gym->gymnasium; render() update (#769)

Credits (names from the Farama Discord):

- @nrwahl2
- @APN-Pucky
- chattershuts
Will Dudley 2022-11-11 20:25:35 +00:00 committed by GitHub
parent 06aaad460e
commit b9a6d8b5f0
7 changed files with 28 additions and 22 deletions
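For orientation, the API shift behind most of these changes: in gymnasium, the render mode is fixed when the environment is constructed, ``render()`` takes no ``mode`` argument, and ``step()`` reports termination and truncation separately. A minimal sketch of the new-style calls (``CartPole-v1`` is only an illustrative stand-in, not an environment touched here):

    import gymnasium as gym

    # the render mode is chosen once, at construction time
    env = gym.make("CartPole-v1", render_mode="rgb_array")

    obs, info = env.reset(seed=0)   # reset() also returns an info dict
    frame = env.render()            # no mode argument; returns an RGB frame here
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    env.close()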

View File

@@ -1,6 +1,12 @@
Multi-Agent RL
==============
Tianshou uses `PettingZoo` environments for multi-agent RL training. Here are some helpful tutorial links:
* https://pettingzoo.farama.org/tutorials/tianshou/beginner/
* https://pettingzoo.farama.org/tutorials/tianshou/intermediate/
* https://pettingzoo.farama.org/tutorials/tianshou/advanced/
In this section, we describe how to use Tianshou to implement multi-agent reinforcement learning. Specifically, we will design an algorithm to learn how to play `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ (see the image below) against a random opponent.
.. image:: ../_static/images/tic-tac-toe.png
@@ -10,7 +16,7 @@ In this section, we describe how to use Tianshou to implement multi-agent reinfo
Tic-Tac-Toe Environment
-----------------------
The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any [PettingZoo](https://www.pettingzoo.ml/) environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any `PettingZoo <https://www.pettingzoo.ml/>`_ environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
::
>>> from tianshou.env import PettingZooEnv # wrapper for PettingZoo environments
@@ -19,7 +25,7 @@ The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~ti
>>> # Players place 'x' and 'o' in turn on the board
>>> # The player who first gets 3 consecutive 'x's or 'o's wins
>>>
>>> env = PettingZooEnv(tictactoe_v3.env())
>>> env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))
>>> obs = env.reset()
>>> env.render() # render the empty board
board (step 0):
@@ -333,8 +339,8 @@ With the above preparation, we are close to the first learned agent. The followi
::
def get_env():
return PettingZooEnv(tictactoe_v3.env())
def get_env(render_mode=None):
return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))
def train_agent(
@@ -427,7 +433,7 @@ With the above preparation, we are close to the first learned agent. The followi
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = get_env()
env = get_env(render_mode="human")
env = DummyVectorEnv([lambda: env])
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# isort: skip_file
import argparse
import datetime

View File

@@ -1,6 +1,7 @@
import argparse
import os
from copy import deepcopy
from functools import partial
from typing import Optional, Tuple
import gym
@@ -24,8 +25,8 @@ from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
def get_env():
return PettingZooEnv(tictactoe_v3.env())
def get_env(render_mode: Optional[str] = None):
return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))
def get_parser() -> argparse.ArgumentParser:
@@ -230,7 +231,7 @@ def watch(
agent_learn: Optional[BasePolicy] = None,
agent_opponent: Optional[BasePolicy] = None,
) -> None:
env = DummyVectorEnv([get_env])
env = DummyVectorEnv([partial(get_env, render_mode="human")])
policy, optim, agents = get_agents(
args, agent_learn=agent_learn, agent_opponent=agent_opponent
)
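Here ``partial(get_env, render_mode="human")`` turns the one-argument factory into the zero-argument callable that ``DummyVectorEnv`` expects while still pinning the render mode, instead of capturing a single pre-built env in a lambda. A small standalone sketch of the pattern; ``make_env`` and ``CartPole-v1`` are illustrative stand-ins, not code from this commit:

    from functools import partial

    import gymnasium as gym

    def make_env(render_mode=None):
        # stand-in factory; the commit's get_env wraps tictactoe_v3 the same way
        return gym.make("CartPole-v1", render_mode=render_mode)

    # partial pre-binds the keyword, so each call still builds a fresh env
    factory = partial(make_env, render_mode="rgb_array")
    env = factory()   # equivalent to make_env(render_mode="rgb_array")
    env.close()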

View File

@@ -73,12 +73,10 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
# Note: object is used as a proxy for objects that can be pickled
# Note: mypy does not support cyclic definition currently
Hdf5ConvertibleValues = Union[ # type: ignore
int, float, Batch, np.ndarray, torch.Tensor, object,
'Hdf5ConvertibleType', # type: ignore
]
Hdf5ConvertibleValues = Union[int, float, Batch, np.ndarray, torch.Tensor, object,
"Hdf5ConvertibleType"]
Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues] # type: ignore
Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]
def to_hdf5(

View File

@@ -2,8 +2,8 @@ import warnings
from abc import ABC
from typing import Any, Dict, List, Tuple, Union
import gym.spaces
import pettingzoo
from gymnasium import spaces
from packaging import version
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import BaseWrapper
@@ -27,7 +27,7 @@ class PettingZooEnv(AECEnv, ABC):
# obs is a dict containing obs, agent_id, and mask
obs = env.reset()
action = policy(obs)
obs, rew, done, info = env.step(action)
obs, rew, trunc, term, info = env.step(action)
env.close()
In the mask, available actions are marked True and unavailable actions are marked False.
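Downstream code that drives the wrapped environment therefore unpacks five values from ``step``. A minimal sketch against the Tic-Tac-Toe environment, with a random mask-based move as a hypothetical stand-in for a trained policy:

    import numpy as np
    from pettingzoo.classic import tictactoe_v3
    from tianshou.env import PettingZooEnv

    env = PettingZooEnv(tictactoe_v3.env())
    obs = env.reset()   # dict with 'obs', 'agent_id', and 'mask'

    # hypothetical policy: pick any action the mask marks as legal
    action = int(np.random.choice(np.flatnonzero(obs["mask"])))

    # step now yields five values; the episode ends when either flag is set
    obs, rew, trunc, term, info = env.step(action)
    env.close()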
@@ -89,7 +89,7 @@ class PettingZooEnv(AECEnv, ABC):
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
if isinstance(self.action_space, spaces.Discrete):
observation_dict = {
'agent_id': self.env.agent_selection,
'obs': observation,
@@ -124,7 +124,7 @@ class PettingZooEnv(AECEnv, ABC):
[True if obm == 1 else False for obm in observation['action_mask']]
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
if isinstance(self.action_space, spaces.Discrete):
obs = {
'agent_id': self.env.agent_selection,
'obs': observation,
@@ -146,5 +146,5 @@ class PettingZooEnv(AECEnv, ABC):
except (NotImplementedError, AttributeError):
self.env.reset(seed=seed)
def render(self, mode: str = "human") -> Any:
return self.env.render(mode)
def render(self) -> Any:
return self.env.render()

View File

@@ -44,7 +44,7 @@ class EnvWorker(ABC):
self.result = self.reset()
else:
self.is_reset = False
self.send_action(action) # type: ignore
self.send_action(action)
def recv(
self
@@ -63,7 +63,7 @@
"Please use send and recv for your own EnvWorker."
)
if not self.is_reset:
self.result = self.get_result() # type: ignore
self.result = self.get_result()
return self.result
@abstractmethod

View File

@@ -139,4 +139,4 @@ class LazyLogger(BaseLogger):
pass
def restore_data(self) -> Tuple[int, int, int]:
pass
return 0, 0, 0