MultiDiscrete to discrete gym action space wrapper (#664)
Has been tested to work with DQN and a custom MultiDiscrete gym env.
This commit is contained in:
parent
21b15803ac
commit
aba2d01d25
@ -1,13 +1,16 @@
|
||||
import sys
|
||||
import time
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
from gym.spaces.discrete import Discrete
|
||||
|
||||
from tianshou.data import Batch
|
||||
from tianshou.env import (
|
||||
ContinuousToDiscrete,
|
||||
DummyVectorEnv,
|
||||
MultiDiscreteToDiscrete,
|
||||
RayVectorEnv,
|
||||
ShmemVectorEnv,
|
||||
SubprocVectorEnv,
|
||||
@ -265,6 +268,43 @@ def test_venv_norm_obs():
|
||||
run_align_norm_obs(raw, train_env, test_env, action_list)
|
||||
|
||||
|
||||
def test_gym_wrappers():
|
||||
|
||||
class DummyEnv(gym.Env):
|
||||
|
||||
def __init__(self):
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=-1.0, high=2.0, shape=(4, ), dtype=np.float32
|
||||
)
|
||||
|
||||
bsz = 10
|
||||
action_per_branch = [4, 6, 10, 7]
|
||||
env = DummyEnv()
|
||||
original_act = env.action_space.high
|
||||
# convert continous to multidiscrete action space
|
||||
# with different action number per dimension
|
||||
env_m = ContinuousToDiscrete(env, action_per_branch)
|
||||
# check conversion is working properly for one action
|
||||
np.testing.assert_allclose(env_m.action(env_m.action_space.nvec - 1), original_act)
|
||||
# check conversion is working properly for a batch of actions
|
||||
np.testing.assert_allclose(
|
||||
env_m.action(np.array([env_m.action_space.nvec - 1] * bsz)),
|
||||
np.array([original_act] * bsz)
|
||||
)
|
||||
# convert multidiscrete with different action number per
|
||||
# dimension to discrete action space
|
||||
env_d = MultiDiscreteToDiscrete(env_m)
|
||||
# check conversion is working properly for one action
|
||||
np.testing.assert_allclose(
|
||||
env_d.action(env_d.action_space.n - 1), env_m.action_space.nvec - 1
|
||||
)
|
||||
# check conversion is working properly for a batch of actions
|
||||
np.testing.assert_allclose(
|
||||
env_d.action(np.array([env_d.action_space.n - 1] * bsz)),
|
||||
np.array([env_m.action_space.nvec - 1] * bsz)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||
def test_venv_wrapper_envpool():
|
||||
raw = envpool.make_gym("Ant-v3", num_envs=4)
|
||||
@ -279,10 +319,11 @@ def test_venv_wrapper_envpool():
|
||||
run_align_norm_obs(raw, train, test, actions)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_venv_norm_obs()
|
||||
test_venv_wrapper_envpool()
|
||||
test_env_obs_dtype()
|
||||
test_vecenv()
|
||||
test_async_env()
|
||||
test_async_check_id()
|
||||
test_gym_wrappers()
|
||||
|
3
tianshou/env/__init__.py
vendored
3
tianshou/env/__init__.py
vendored
@ -1,6 +1,6 @@
|
||||
"""Env package."""
|
||||
|
||||
from tianshou.env.gym_wrappers import ContinuousToDiscrete
|
||||
from tianshou.env.gym_wrappers import ContinuousToDiscrete, MultiDiscreteToDiscrete
|
||||
from tianshou.env.venv_wrappers import VectorEnvNormObs, VectorEnvWrapper
|
||||
from tianshou.env.venvs import (
|
||||
BaseVectorEnv,
|
||||
@ -25,4 +25,5 @@ __all__ = [
|
||||
"VectorEnvNormObs",
|
||||
"PettingZooEnv",
|
||||
"ContinuousToDiscrete",
|
||||
"MultiDiscreteToDiscrete",
|
||||
]
|
||||
|
47
tianshou/env/gym_wrappers.py
vendored
47
tianshou/env/gym_wrappers.py
vendored
@ -1,3 +1,5 @@
|
||||
from typing import List, Union
|
||||
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
@ -6,23 +8,50 @@ class ContinuousToDiscrete(gym.ActionWrapper):
|
||||
"""Gym environment wrapper to take discrete action in a continuous environment.
|
||||
|
||||
:param gym.Env env: gym environment with continuous action space.
|
||||
:param int action_per_branch: number of discrete actions in each dimension
|
||||
:param int action_per_dim: number of discrete actions in each dimension
|
||||
of the action space.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env, action_per_branch: int) -> None:
|
||||
def __init__(self, env: gym.Env, action_per_dim: Union[int, List[int]]) -> None:
|
||||
super().__init__(env)
|
||||
assert isinstance(env.action_space, gym.spaces.Box)
|
||||
low, high = env.action_space.low, env.action_space.high
|
||||
num_branches = env.action_space.shape[0]
|
||||
self.action_space = gym.spaces.MultiDiscrete(
|
||||
[action_per_branch] * num_branches
|
||||
if isinstance(action_per_dim, int):
|
||||
action_per_dim = [action_per_dim] * env.action_space.shape[0]
|
||||
assert len(action_per_dim) == env.action_space.shape[0]
|
||||
self.action_space = gym.spaces.MultiDiscrete(action_per_dim)
|
||||
self.mesh = np.array(
|
||||
[np.linspace(lo, hi, a) for lo, hi, a in zip(low, high, action_per_dim)],
|
||||
dtype=object
|
||||
)
|
||||
mesh = []
|
||||
for lo, hi in zip(low, high):
|
||||
mesh.append(np.linspace(lo, hi, action_per_branch))
|
||||
self.mesh = np.array(mesh)
|
||||
|
||||
def action(self, act: np.ndarray) -> np.ndarray:
|
||||
# modify act
|
||||
assert len(act.shape) <= 2, f"Unknown action format with shape {act.shape}."
|
||||
if len(act.shape) == 1:
|
||||
return np.array([self.mesh[i][a] for i, a in enumerate(act)])
|
||||
return np.array([[self.mesh[i][a] for i, a in enumerate(a_)] for a_ in act])
|
||||
|
||||
|
||||
class MultiDiscreteToDiscrete(gym.ActionWrapper):
|
||||
"""Gym environment wrapper to take discrete action in multidiscrete environment.
|
||||
|
||||
:param gym.Env env: gym environment with multidiscrete action space.
|
||||
"""
|
||||
|
||||
def __init__(self, env: gym.Env) -> None:
|
||||
super().__init__(env)
|
||||
assert isinstance(env.action_space, gym.spaces.MultiDiscrete)
|
||||
nvec = env.action_space.nvec
|
||||
assert nvec.ndim == 1
|
||||
self.bases = np.ones_like(nvec)
|
||||
for i in range(1, len(self.bases)):
|
||||
self.bases[i] = self.bases[i - 1] * nvec[-i]
|
||||
self.action_space = gym.spaces.Discrete(np.prod(nvec))
|
||||
|
||||
def action(self, act: np.ndarray) -> np.ndarray:
|
||||
converted_act = []
|
||||
for b in np.flip(self.bases):
|
||||
converted_act.append(act // b)
|
||||
act = act % b
|
||||
return np.array(converted_act).transpose()
|
||||
|
Loading…
x
Reference in New Issue
Block a user