MultiDiscrete to discrete gym action space wrapper (#664)

Has been tested to work with DQN and a custom MultiDiscrete gym env.
This commit is contained in:
Anas BELFADIL 2022-06-13 00:18:22 +02:00 committed by GitHub
parent 21b15803ac
commit aba2d01d25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 83 additions and 12 deletions

View File

@ -1,13 +1,16 @@
import sys import sys
import time import time
import gym
import numpy as np import numpy as np
import pytest import pytest
from gym.spaces.discrete import Discrete from gym.spaces.discrete import Discrete
from tianshou.data import Batch from tianshou.data import Batch
from tianshou.env import ( from tianshou.env import (
ContinuousToDiscrete,
DummyVectorEnv, DummyVectorEnv,
MultiDiscreteToDiscrete,
RayVectorEnv, RayVectorEnv,
ShmemVectorEnv, ShmemVectorEnv,
SubprocVectorEnv, SubprocVectorEnv,
@ -265,6 +268,43 @@ def test_venv_norm_obs():
run_align_norm_obs(raw, train_env, test_env, action_list) run_align_norm_obs(raw, train_env, test_env, action_list)
def test_gym_wrappers():
    """Check both action-space wrappers on a single action and on a batch."""

    class DummyEnv(gym.Env):
        """Bare environment that only declares a 4-dim continuous action space."""

        def __init__(self):
            self.action_space = gym.spaces.Box(
                low=-1.0, high=2.0, shape=(4, ), dtype=np.float32
            )

    batch_size = 10
    bins_per_dim = [4, 6, 10, 7]
    base_env = DummyEnv()
    upper_bound = base_env.action_space.high

    # continuous -> multidiscrete, using a different number of bins
    # for every action dimension
    multi_env = ContinuousToDiscrete(base_env, bins_per_dim)
    last_multi = multi_env.action_space.nvec - 1
    # single action: the last bin of each dimension maps to the upper bound
    np.testing.assert_allclose(multi_env.action(last_multi), upper_bound)
    # batch of actions: same mapping, row by row
    np.testing.assert_allclose(
        multi_env.action(np.array([last_multi] * batch_size)),
        np.array([upper_bound] * batch_size),
    )

    # multidiscrete (different bin count per dimension) -> discrete
    flat_env = MultiDiscreteToDiscrete(multi_env)
    last_flat = flat_env.action_space.n - 1
    # single action: the largest flat index decodes to the last bin everywhere
    np.testing.assert_allclose(flat_env.action(last_flat), last_multi)
    # batch of actions: same decoding, row by row
    np.testing.assert_allclose(
        flat_env.action(np.array([last_flat] * batch_size)),
        np.array([last_multi] * batch_size),
    )
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool(): def test_venv_wrapper_envpool():
raw = envpool.make_gym("Ant-v3", num_envs=4) raw = envpool.make_gym("Ant-v3", num_envs=4)
@ -279,10 +319,11 @@ def test_venv_wrapper_envpool():
run_align_norm_obs(raw, train, test, actions) run_align_norm_obs(raw, train, test, actions)
if __name__ == '__main__': if __name__ == "__main__":
test_venv_norm_obs() test_venv_norm_obs()
test_venv_wrapper_envpool() test_venv_wrapper_envpool()
test_env_obs_dtype() test_env_obs_dtype()
test_vecenv() test_vecenv()
test_async_env() test_async_env()
test_async_check_id() test_async_check_id()
test_gym_wrappers()

View File

@ -1,6 +1,6 @@
"""Env package.""" """Env package."""
from tianshou.env.gym_wrappers import ContinuousToDiscrete from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete
from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper
from tianshou.env.venvs import ( from tianshou.env.venvs import (
BaseVectorEnv, BaseVectorEnv,
@ -25,4 +25,5 @@ __all__ = [
"VectorEnvNormObs", "VectorEnvNormObs",
"PettingZooEnv", "PettingZooEnv",
"ContinuousToDiscrete", "ContinuousToDiscrete",
"MultiDiscreteToDiscrete",
] ]

View File

@ -1,3 +1,5 @@
from typing import List, Union
import gym import gym
import numpy as np import numpy as np
@ -6,23 +8,50 @@ class ContinuousToDiscrete(gym.ActionWrapper):
"""Gym environment wrapper to take discrete action in a continuous environment. """Gym environment wrapper to take discrete action in a continuous environment.
:param gym.Env env: gym environment with continuous action space. :param gym.Env env: gym environment with continuous action space.
:param int action_per_branch: number of discrete actions in each dimension :param int action_per_dim: number of discrete actions in each dimension
of the action space. of the action space.
""" """
def __init__(self, env: gym.Env, action_per_branch: int) -> None: def __init__(self, env: gym.Env, action_per_dim: Union[int, List[int]]) -> None:
super().__init__(env) super().__init__(env)
assert isinstance(env.action_space, gym.spaces.Box) assert isinstance(env.action_space, gym.spaces.Box)
low, high = env.action_space.low, env.action_space.high low, high = env.action_space.low, env.action_space.high
num_branches = env.action_space.shape[0] if isinstance(action_per_dim, int):
self.action_space = gym.spaces.MultiDiscrete( action_per_dim = [action_per_dim] * env.action_space.shape[0]
[action_per_branch] * num_branches assert len(action_per_dim) == env.action_space.shape[0]
self.action_space = gym.spaces.MultiDiscrete(action_per_dim)
self.mesh = np.array(
[np.linspace(lo, hi, a) for lo, hi, a in zip(low, high, action_per_dim)],
dtype=object
) )
mesh = []
for lo, hi in zip(low, high):
mesh.append(np.linspace(lo, hi, action_per_branch))
self.mesh = np.array(mesh)
def action(self, act: np.ndarray) -> np.ndarray: def action(self, act: np.ndarray) -> np.ndarray:
# modify act # modify act
return np.array([self.mesh[i][a] for i, a in enumerate(act)]) assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}."
if len(act.shape) == 1:
return np.array([self.mesh[i][a] for i, a in enumerate(act)])
return np.array([[self.mesh[i][a] for i, a in enumerate(a_)] for a_ in act])
class MultiDiscreteToDiscrete(gym.ActionWrapper):
    """Gym environment wrapper to take discrete action in multidiscrete environment.

    A flat action index is decoded as a mixed-radix number whose digits are the
    per-dimension actions; the first dimension is the most significant digit.

    :param gym.Env env: gym environment with multidiscrete action space.
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert isinstance(env.action_space, gym.spaces.MultiDiscrete)
        nvec = env.action_space.nvec
        assert nvec.ndim == 1
        # Widen before multiplying: nvec keeps the dtype the space was built
        # with, and a narrow integer dtype would silently overflow below.
        nvec = np.asarray(nvec, dtype=np.int64)
        # bases[i] is the product of the last i entries of nvec (bases[0] == 1),
        # i.e. the place value of dimension len(nvec) - 1 - i.
        self.bases = np.ones_like(nvec)
        self.bases[1:] = np.cumprod(nvec[:0:-1])
        self.action_space = gym.spaces.Discrete(int(np.prod(nvec)))

    def action(self, act: np.ndarray) -> np.ndarray:
        """Decode flat discrete action(s) into multidiscrete action(s).

        :param act: a single flat action index or a 1-d array of them.
        :return: array with one entry per wrapped action dimension; for a
            batch of indices the result has one row per index.
        """
        converted_act = []
        # Peel off one digit at a time, from the most significant base down.
        for b in np.flip(self.bases):
            converted_act.append(act // b)
            act = act % b
        # Digits were collected dimension-major; put the batch axis first.
        return np.array(converted_act).transpose()