From aba2d01d251ead1d33a14a22a400c5b586737781 Mon Sep 17 00:00:00 2001 From: Anas BELFADIL <56280198+BFAnas@users.noreply.github.com> Date: Mon, 13 Jun 2022 00:18:22 +0200 Subject: [PATCH] MultiDiscrete to discrete gym action space wrapper (#664) Has been tested to work with DQN and a custom MultiDiscrete gym env. --- test/base/test_env.py | 43 ++++++++++++++++++++++++++++++- tianshou/env/__init__.py | 3 ++- tianshou/env/gym_wrappers.py | 49 ++++++++++++++++++++++++++++-------- 3 files changed, 83 insertions(+), 12 deletions(-) diff --git a/test/base/test_env.py b/test/base/test_env.py index 3b3a74c..002799b 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -1,13 +1,16 @@ import sys import time +import gym import numpy as np import pytest from gym.spaces.discrete import Discrete from tianshou.data import Batch from tianshou.env import ( + ContinuousToDiscrete, DummyVectorEnv, + MultiDiscreteToDiscrete, RayVectorEnv, ShmemVectorEnv, SubprocVectorEnv, @@ -265,6 +268,43 @@ def test_venv_norm_obs(): run_align_norm_obs(raw, train_env, test_env, action_list) +def test_gym_wrappers(): + + class DummyEnv(gym.Env): + + def __init__(self): + self.action_space = gym.spaces.Box( + low=-1.0, high=2.0, shape=(4, ), dtype=np.float32 + ) + + bsz = 10 + action_per_branch = [4, 6, 10, 7] + env = DummyEnv() + original_act = env.action_space.high + # convert continous to multidiscrete action space + # with different action number per dimension + env_m = ContinuousToDiscrete(env, action_per_branch) + # check conversion is working properly for one action + np.testing.assert_allclose(env_m.action(env_m.action_space.nvec - 1), original_act) + # check conversion is working properly for a batch of actions + np.testing.assert_allclose( + env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)), + np.array([original_act] * bsz) + ) + # convert multidiscrete with different action number per + # dimension to discrete action space + env_d = MultiDiscreteToDiscrete(env_m) + # check conversion is working properly for one action + np.testing.assert_allclose( + env_d.action(env_d.action_space.n - 1), env_m.action_space.nvec - 1 + ) + # check conversion is working properly for a batch of actions + np.testing.assert_allclose( + env_d.action(np.array([env_d.action_space.n - 1] * bsz)), + np.array([env_m.action_space.nvec - 1] * bsz) + ) + + @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_venv_wrapper_envpool(): raw = envpool.make_gym("Ant-v3", num_envs=4) @@ -279,10 +319,11 @@ def test_venv_wrapper_envpool(): run_align_norm_obs(raw, train, test, actions) -if __name__ == '__main__': +if __name__ == "__main__": test_venv_norm_obs() test_venv_wrapper_envpool() test_env_obs_dtype() test_vecenv() test_async_env() test_async_check_id() + test_gym_wrappers() diff --git a/tianshou/env/__init__.py b/tianshou/env/__init__.py index 8b1c71a..6abea32 100644 --- a/tianshou/env/__init__.py +++ b/tianshou/env/__init__.py @@ -1,6 +1,6 @@ """Env package.""" -from tianshou.env.gym_wrappers import ContinuousToDiscrete +from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper from tianshou.env.venvs import ( BaseVectorEnv, @@ -25,4 +25,5 @@ __all__ = [ "VectorEnvNormObs", "PettingZooEnv", "ContinuousToDiscrete", + "MultiDiscreteToDiscrete", ] diff --git a/tianshou/env/gym_wrappers.py b/tianshou/env/gym_wrappers.py index f63bc9e..5b98e77 100644 --- a/tianshou/env/gym_wrappers.py +++ b/tianshou/env/gym_wrappers.py @@ -1,3 +1,5 @@ +from typing import List, Union + import gym import numpy as np @@ -6,23 +8,50 @@ class ContinuousToDiscrete(gym.ActionWrapper): """Gym environment wrapper to take discrete action in a continuous environment. :param gym.Env env: gym environment with continuous action space. - :param int action_per_branch: number of discrete actions in each dimension + :param int action_per_dim: number of discrete actions in each dimension of the action space. """ - def __init__(self, env: gym.Env, action_per_branch: int) -> None: + def __init__(self, env: gym.Env, action_per_dim: Union[int, List[int]]) -> None: super().__init__(env) assert isinstance(env.action_space, gym.spaces.Box) low, high = env.action_space.low, env.action_space.high - num_branches = env.action_space.shape[0] - self.action_space = gym.spaces.MultiDiscrete( - [action_per_branch] * num_branches + if isinstance(action_per_dim, int): + action_per_dim = [action_per_dim] * env.action_space.shape[0] + assert len(action_per_dim) == env.action_space.shape[0] + self.action_space = gym.spaces.MultiDiscrete(action_per_dim) + self.mesh = np.array( + [np.linspace(lo, hi, a) for lo, hi, a in zip(low, high, action_per_dim)], + dtype=object ) - mesh = [] - for lo, hi in zip(low, high): - mesh.append(np.linspace(lo, hi, action_per_branch)) - self.mesh = np.array(mesh) def action(self, act: np.ndarray) -> np.ndarray: # modify act - return np.array([self.mesh[i][a] for i, a in enumerate(act)]) + assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}." + if len(act.shape) == 1: + return np.array([self.mesh[i][a] for i, a in enumerate(act)]) + return np.array([[self.mesh[i][a] for i, a in enumerate(a_)] for a_ in act]) + + +class MultiDiscreteToDiscrete(gym.ActionWrapper): + """Gym environment wrapper to take discrete action in multidiscrete environment. + + :param gym.Env env: gym environment with multidiscrete action space. + """ + + def __init__(self, env: gym.Env) -> None: + super().__init__(env) + assert isinstance(env.action_space, gym.spaces.MultiDiscrete) + nvec = env.action_space.nvec + assert nvec.ndim == 1 + self.bases = np.ones_like(nvec) + for i in range(1, len(self.bases)): + self.bases[i] = self.bases[i - 1] * nvec[-i] + self.action_space = gym.spaces.Discrete(np.prod(nvec)) + + def action(self, act: np.ndarray) -> np.ndarray: + converted_act = [] + for b in np.flip(self.bases): + converted_act.append(act // b) + act = act % b + return np.array(converted_act).transpose()