diff --git a/docs/tutorials/tictactoe.rst b/docs/tutorials/tictactoe.rst
index 84460bd..bf03300 100644
--- a/docs/tutorials/tictactoe.rst
+++ b/docs/tutorials/tictactoe.rst
@@ -1,6 +1,12 @@
 Multi-Agent RL
 ==============
 
+Tianshou uses `PettingZoo` environments for multi-agent RL training. Here are some helpful tutorial links:
+
+* https://pettingzoo.farama.org/tutorials/tianshou/beginner/
+* https://pettingzoo.farama.org/tutorials/tianshou/intermediate/
+* https://pettingzoo.farama.org/tutorials/tianshou/advanced/
+
 In this section, we describe how to use Tianshou to implement multi-agent reinforcement learning. Specifically, we will design an algorithm to learn how to play `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ (see the image below) against a random opponent.
 
 .. image:: ../_static/images/tic-tac-toe.png
@@ -10,7 +16,7 @@ In this section, we describe how to use Tianshou to implement multi-agent reinfo
 Tic-Tac-Toe Environment
 -----------------------
 
-The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any [PettingZoo](https://www.pettingzoo.ml/) environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
+The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~tianshou.env.PettingZooEnv` which can wrap any `PettingZoo <https://pettingzoo.farama.org/>`_ environment. PettingZoo offers a 3x3 Tic-Tac-Toe environment, let's first explore it.
 ::
 
     >>> from tianshou.env import PettingZooEnv  # wrapper for PettingZoo environments
@@ -19,7 +25,7 @@ The scripts are located at ``test/pettingzoo/``. We have implemented :class:`~ti
     >>> # Players place 'x' and 'o' in turn on the board
     >>> # The player who first gets 3 consecutive 'x's or 'o's wins
     >>>
-    >>> env = PettingZooEnv(tictactoe_v3.env())
+    >>> env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))
     >>> obs = env.reset()
     >>> env.render()  # render the empty board
     board (step 0):
@@ -333,8 +339,8 @@ With the above preparation, we are close to the first learned agent. The followi
 
 ::
 
-    def get_env():
-        return PettingZooEnv(tictactoe_v3.env())
+    def get_env(render_mode=None):
+        return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))
 
 
     def train_agent(
@@ -427,7 +433,7 @@ With the above preparation, we are close to the first learned agent. The followi
         agent_learn: Optional[BasePolicy] = None,
         agent_opponent: Optional[BasePolicy] = None,
     ) -> None:
-        env = get_env()
+        env = get_env(render_mode="human")
         env = DummyVectorEnv([lambda: env])
         policy, optim, agents = get_agents(
             args, agent_learn=agent_learn, agent_opponent=agent_opponent
diff --git a/examples/mujoco/fetch_her_ddpg.py b/examples/mujoco/fetch_her_ddpg.py
index 8939125..2890085 100644
--- a/examples/mujoco/fetch_her_ddpg.py
+++ b/examples/mujoco/fetch_her_ddpg.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# isort: skip_file
 
 import argparse
 import datetime
diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py
index 62d9a09..7610cee 100644
--- a/test/pettingzoo/tic_tac_toe.py
+++ b/test/pettingzoo/tic_tac_toe.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 from copy import deepcopy
+from functools import partial
 from typing import Optional, Tuple
 
 import gym
@@ -24,8 +25,8 @@ from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 
 
-def get_env():
-    return PettingZooEnv(tictactoe_v3.env())
+def get_env(render_mode: Optional[str] = None):
+    return PettingZooEnv(tictactoe_v3.env(render_mode=render_mode))
 
 
 def get_parser() -> argparse.ArgumentParser:
@@ -230,7 +231,7 @@ def watch(
     agent_learn: Optional[BasePolicy] = None,
     agent_opponent: Optional[BasePolicy] = None,
 ) -> None:
-    env = DummyVectorEnv([get_env])
+    env = DummyVectorEnv([partial(get_env, render_mode="human")])
     policy, optim, agents = get_agents(
         args, agent_learn=agent_learn, agent_opponent=agent_opponent
     )
diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py
index ed705e2..68e7508 100644
--- a/tianshou/data/utils/converter.py
+++ b/tianshou/data/utils/converter.py
@@ -73,12 +73,10 @@ def to_torch_as(x: Any, y: torch.Tensor) -> Union[Batch, torch.Tensor]:
 
 # Note: object is used as a proxy for objects that can be pickled
 # Note: mypy does not support cyclic definition currently
-Hdf5ConvertibleValues = Union[  # type: ignore
-    int, float, Batch, np.ndarray, torch.Tensor, object,
-    'Hdf5ConvertibleType',  # type: ignore
-]
+Hdf5ConvertibleValues = Union[int, float, Batch, np.ndarray, torch.Tensor, object,
+                              "Hdf5ConvertibleType"]
 
-Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]  # type: ignore
+Hdf5ConvertibleType = Dict[str, Hdf5ConvertibleValues]
 
 
 def to_hdf5(
diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py
index d1ab131..2817333 100644
--- a/tianshou/env/pettingzoo_env.py
+++ b/tianshou/env/pettingzoo_env.py
@@ -2,8 +2,8 @@ import warnings
 from abc import ABC
 from typing import Any, Dict, List, Tuple, Union
 
-import gym.spaces
 import pettingzoo
+from gymnasium import spaces
 from packaging import version
 from pettingzoo.utils.env import AECEnv
 from pettingzoo.utils.wrappers import BaseWrapper
@@ -27,7 +27,7 @@ class PettingZooEnv(AECEnv, ABC):
         # obs is a dict containing obs, agent_id, and mask
         obs = env.reset()
         action = policy(obs)
-        obs, rew, done, info = env.step(action)
+        obs, rew, term, trunc, info = env.step(action)
         env.close()
 
     The available action's mask is set to True, otherwise it is set to False.
@@ -89,7 +89,7 @@ class PettingZooEnv(AECEnv, ABC):
                 [True if obm == 1 else False for obm in observation['action_mask']]
             }
         else:
-            if isinstance(self.action_space, gym.spaces.Discrete):
+            if isinstance(self.action_space, spaces.Discrete):
                 observation_dict = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
@@ -124,7 +124,7 @@ class PettingZooEnv(AECEnv, ABC):
                 [True if obm == 1 else False for obm in observation['action_mask']]
             }
         else:
-            if isinstance(self.action_space, gym.spaces.Discrete):
+            if isinstance(self.action_space, spaces.Discrete):
                 obs = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
@@ -146,5 +146,5 @@ class PettingZooEnv(AECEnv, ABC):
         except (NotImplementedError, AttributeError):
             self.env.reset(seed=seed)
 
-    def render(self, mode: str = "human") -> Any:
-        return self.env.render(mode)
+    def render(self) -> Any:
+        return self.env.render()
diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py
index 056a6f0..779a822 100644
--- a/tianshou/env/worker/base.py
+++ b/tianshou/env/worker/base.py
@@ -44,7 +44,7 @@ class EnvWorker(ABC):
             self.result = self.reset()
         else:
             self.is_reset = False
-            self.send_action(action)  # type: ignore
+            self.send_action(action)
 
     def recv(
         self
@@ -63,7 +63,7 @@ class EnvWorker(ABC):
             "Please use send and recv for your own EnvWorker."
         )
         if not self.is_reset:
-            self.result = self.get_result()  # type: ignore
+            self.result = self.get_result()
         return self.result
 
     @abstractmethod
diff --git a/tianshou/utils/logger/base.py b/tianshou/utils/logger/base.py
index 2dc47ed..e4e6cc3 100644
--- a/tianshou/utils/logger/base.py
+++ b/tianshou/utils/logger/base.py
@@ -139,4 +139,4 @@ class LazyLogger(BaseLogger):
         pass
 
     def restore_data(self) -> Tuple[int, int, int]:
-        pass
+        return 0, 0, 0
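
Not part of the patch above: a minimal usage sketch of the API this change moves to, assuming a Gymnasium-based PettingZoo release and the updated tianshou.env.PettingZooEnv. The random-move loop, the variable names, and the defensive handling of reset() are illustrative assumptions, not code taken from the diff:

    import numpy as np
    from pettingzoo.classic import tictactoe_v3

    from tianshou.env import PettingZooEnv

    # render_mode is now chosen when the PettingZoo env is constructed,
    # instead of being passed to render() later.
    env = PettingZooEnv(tictactoe_v3.env(render_mode="human"))

    reset_result = env.reset()
    # Depending on the tianshou version, reset() returns either obs or (obs, info).
    obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result

    done = False
    while not done:
        # obs is a dict with 'agent_id', 'obs' and a boolean action 'mask';
        # pick a random legal move for the current player.
        action = np.random.choice(np.flatnonzero(obs["mask"]))
        # step() follows the Gymnasium convention with two episode-end flags.
        obs, rew, *flags, info = env.step(action)
        done = any(flags)
        env.render()  # render() no longer takes a mode argument
    env.close()

For the vectorized watch() path, DummyVectorEnv expects zero-argument env factories, which is why the patch wraps the factory as partial(get_env, render_mode="human") instead of passing the render mode at call time.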