Tianshou/tianshou/env/gym_wrappers.py

29 lines
1.0 KiB
Python

import gym
import numpy as np
class ContinuousToDiscrete(gym.ActionWrapper):
    """Gym environment wrapper to take discrete action in a continuous environment.

    Every dimension of the wrapped ``Box`` action space is replaced by
    ``action_per_branch`` evenly spaced values running from that dimension's
    lower bound to its upper bound (both endpoints included). The wrapper
    exposes a ``MultiDiscrete`` action space whose entries are indices into
    those values.

    :param gym.Env env: gym environment with continuous action space.
    :param int action_per_branch: number of discrete actions in each dimension
        of the action space.
    """

    def __init__(self, env: gym.Env, action_per_branch: int) -> None:
        super().__init__(env)
        # Kept as an assert (rather than a raise) so that callers relying on
        # AssertionError for a non-Box action space keep working.
        assert isinstance(env.action_space, gym.spaces.Box), (
            "ContinuousToDiscrete requires an env with a Box action space."
        )
        low, high = env.action_space.low, env.action_space.high
        # The number of branches is taken from the first axis of the Box
        # shape, i.e. the action space is treated as one-dimensional.
        num_branches = env.action_space.shape[0]
        self.action_space = gym.spaces.MultiDiscrete(
            [action_per_branch] * num_branches
        )
        # mesh[i][k] is the k-th continuous value available for dimension i.
        self.mesh = np.array(
            [np.linspace(lo, hi, action_per_branch) for lo, hi in zip(low, high)]
        )

    def action(self, act: np.ndarray) -> np.ndarray:
        """Map discrete per-dimension indices to continuous action values.

        :param act: integer indices, one per action dimension. Either a single
            action of shape ``(num_branches,)`` or a batch of actions whose
            last axis indexes the action dimensions, e.g.
            ``(batch, num_branches)``.
        :return: array with the same shape as ``act`` where entry ``[..., i]``
            is ``self.mesh[i][act[..., i]]``.
        :raises TypeError: if ``act`` is a scalar (zero-dimensional).
        :raises IndexError: if an index falls outside the mesh.
        """
        indices = np.asarray(act)
        if indices.ndim == 0:
            raise TypeError(
                "act must contain one index per action dimension, got a scalar."
            )
        # Pair every position along the last axis with its own mesh row; the
        # branch index broadcasts over any leading batch axes, so batched
        # input is looked up row-by-row instead of mixing batch and branch.
        branches = np.arange(indices.shape[-1])
        return self.mesh[branches, indices]