Upgrade gym (#613)

Fixes some deprecation warnings introduced by changes in gym 0.23:
- use `env.np_random.integers` instead of `env.np_random.randint`
- support `seed` and `return_info` arguments for reset (addresses https://github.com/thu-ml/tianshou/issues/605)
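
A minimal sketch (not part of the commit) of the two compatibility patterns applied throughout the diff, assuming the gym 0.23 API and using `CartPole-v1` purely as a placeholder environment:

```python
import gym

# Sketch only; assumes gym >= 0.23 and uses CartPole-v1 as a placeholder env.
env = gym.make("CartPole-v1")

# Pattern 1: newer gym exposes a numpy Generator (`.integers`) as np_random,
# older versions a RandomState (`.randint`), so probe for the new method.
rng = env.unwrapped.np_random
noops = rng.integers(1, 31) if hasattr(rng, "integers") else rng.randint(1, 31)

# Pattern 2: reset() may accept `seed`/`return_info` and may return either
# `obs` or `(obs, info)`, so inspect the return value before unpacking.
rval = env.reset(seed=0, return_info=True)
if isinstance(rval, tuple) and len(rval) == 2 and isinstance(rval[1], dict):
    obs, info = rval
else:
    obs, info = rval, {}
```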
Yifei Cheng 2022-06-27 18:52:21 -04:00 committed by GitHub
parent aba2d01d25
commit 43792bf5ab
13 changed files with 382 additions and 107 deletions

View File

@@ -32,7 +32,10 @@ class NoopResetEnv(gym.Wrapper):
def reset(self):
self.env.reset()
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
if hasattr(self.unwrapped.np_random, "integers"):
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
else:
noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
for _ in range(noops):
obs, _, done, _ = self.env.step(self.noop_action)
if done:

View File

@@ -15,7 +15,7 @@ def get_version() -> str:
def get_install_requires() -> str:
return [
"gym>=0.15.4",
"gym>=0.23.1",
"tqdm",
"numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793
"tensorboard>=2.5.0",

View File

@@ -79,11 +79,16 @@ class MyTestEnv(gym.Env):
self.rng = np.random.RandomState(seed)
return [seed]
def reset(self, state=0):
def reset(self, state=0, seed=None, return_info=False):
if seed is not None:
self.rng = np.random.RandomState(seed)
self.done = False
self.do_sleep()
self.index = state
return self._get_state()
if return_info:
return self._get_state(), {'key': 1, 'env': self}
else:
return self._get_state()
def _get_reward(self):
"""Generate a non-scalar reward if ma_rew is True."""

View File

@@ -15,6 +15,11 @@ from tianshou.data import (
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import BasePolicy
try:
import envpool
except ImportError:
envpool = None
if __name__ == '__main__':
from env import MyTestEnv, NXEnv
else: # pytest
@@ -23,7 +28,7 @@ else: # pytest
class MyPolicy(BasePolicy):
def __init__(self, dict_state=False, need_state=True):
def __init__(self, dict_state=False, need_state=True, action_shape=None):
"""
:param bool dict_state: if the observation of the environment is a dict
:param bool need_state: if the policy needs the hidden state (for RNN)
@@ -31,6 +36,7 @@ class MyPolicy(BasePolicy):
super().__init__()
self.dict_state = dict_state
self.need_state = need_state
self.action_shape = action_shape
def forward(self, batch, state=None):
if self.need_state:
@@ -39,8 +45,12 @@ class MyPolicy(BasePolicy):
else:
state += 1
if self.dict_state:
return Batch(act=np.ones(len(batch.obs['index'])), state=state)
return Batch(act=np.ones(len(batch.obs)), state=state)
action_shape = self.action_shape if self.action_shape else len(
batch.obs['index']
)
return Batch(act=np.ones(action_shape), state=state)
action_shape = self.action_shape if self.action_shape else len(batch.obs)
return Batch(act=np.ones(action_shape), state=state)
def learn(self):
pass
@@ -77,7 +87,8 @@ class Logger:
return Batch()
def test_collector():
@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
def test_collector(gym_reset_kwargs):
writer = SummaryWriter('log/collector')
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
@@ -86,52 +97,102 @@ def test_collector():
dum = DummyVectorEnv(env_fns)
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
c0.collect(n_step=3)
c0 = Collector(
policy,
env,
ReplayBuffer(size=100),
logger.preprocess_fn,
)
c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
assert len(c0.buffer) == 3
assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
c0.collect(n_episode=3)
keys = np.zeros(100)
keys[:3] = 1
assert np.allclose(c0.buffer.info["key"], keys)
for e in c0.buffer.info["env"][:3]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c0.buffer.info["env_id"], 0)
rews = np.zeros(100)
rews[:3] = [0, 1, 0]
assert np.allclose(c0.buffer.info["rew"], rews)
c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
assert len(c0.buffer) == 8
assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
c0.collect(n_step=3, random=True)
assert np.allclose(c0.buffer.info["key"][:8], 1)
for e in c0.buffer.info["env"][:8]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c0.buffer.info["env_id"][:8], 0)
assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)
c1 = Collector(
policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn
)
c1.collect(n_step=8)
c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
obs = np.zeros(100)
obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]
valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
assert np.allclose(c1.buffer.obs[:, 0], obs)
assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
c1.collect(n_episode=4)
keys = np.zeros(100)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c1.buffer.info["key"], keys)
for e in c1.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv)
env_ids = np.zeros(100)
env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
assert np.allclose(c1.buffer.info["env_id"], env_ids)
rews = np.zeros(100)
rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
assert np.allclose(c1.buffer.info["rew"], rews)
c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
assert len(c1.buffer) == 16
valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
assert np.allclose(c1.buffer.obs[:, 0], obs)
assert np.allclose(
c1.buffer[:].obs_next[..., 0],
[1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
)
c1.collect(n_episode=4, random=True)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c1.buffer.info["key"], keys)
for e in c1.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv)
env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
assert np.allclose(c1.buffer.info["env_id"], env_ids)
rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
assert np.allclose(c1.buffer.info["rew"], rews)
c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)
c2 = Collector(
policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn
)
c2.collect(n_episode=7)
c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
obs1 = obs.copy()
obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
obs2 = obs.copy()
obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
c2obs = c2.buffer.obs[:, 0]
assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
c2.reset_env()
c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
c2.reset_buffer()
assert c2.collect(n_episode=8)['n/ep'] == 8
obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8
valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
assert np.all(c2.buffer.obs[:, 0] == obs)
c2.collect(n_episode=4, random=True)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c2.buffer.info["key"], keys)
for e in c2.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv)
env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
assert np.allclose(c2.buffer.info["env_id"], env_ids)
rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
assert np.allclose(c2.buffer.info["rew"], rews)
c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)
# test corner case
with pytest.raises(TypeError):
@@ -147,11 +208,12 @@ def test_collector():
[lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]
)
c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c3.collect(n_step=6)
c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
assert c3.buffer.obs.dtype == object
def test_collector_with_async():
@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
def test_collector_with_async(gym_reset_kwargs):
env_lens = [2, 3, 4, 5]
writer = SummaryWriter('log/async_collector')
logger = Logger(writer)
@@ -163,12 +225,14 @@ def test_collector_with_async():
policy = MyPolicy()
bufsize = 60
c1 = AsyncCollector(
policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
logger.preprocess_fn
policy,
venv,
VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
logger.preprocess_fn,
)
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
assert result["n/ep"] >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
@@ -183,7 +247,7 @@ def test_collector_with_async():
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step)
result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
assert result["n/st"] >= n_step
for i in range(4):
env_len = i + 2
@@ -618,9 +682,29 @@ def test_collector_with_atari_setting():
assert np.allclose(result2[key], result[key])
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_collector_envpool_gym_reset_return_info():
envs = envpool.make_gym("Pendulum-v0", num_envs=4, gym_reset_return_info=True)
policy = MyPolicy(action_shape=(len(envs), 1))
c0 = Collector(
policy,
envs,
VectorReplayBuffer(len(envs) * 10, len(envs)),
exploration_noise=True
)
c0.collect(n_step=8)
env_ids = np.zeros(len(envs) * 10)
env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3]
assert np.allclose(c0.buffer.info["env_id"], env_ids)
if __name__ == '__main__':
test_collector()
test_collector(gym_reset_kwargs=None)
test_collector(gym_reset_kwargs=dict(return_info=True))
test_collector_with_dict_state()
test_collector_with_ma()
test_collector_with_atari_setting()
test_collector_with_async()
test_collector_with_async(gym_reset_kwargs=None)
test_collector_with_async(gym_reset_kwargs=dict(return_info=True))
test_collector_envpool_gym_reset_return_info()

View File

@@ -222,6 +222,18 @@ def test_env_obs_dtype():
assert obs.dtype == object
def test_env_reset_optional_kwargs(size=10000, num=8):
env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
if has_ray():
test_cls += [RayVectorEnv]
for cls in test_cls:
v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
_, info = v.reset(seed=1, return_info=True)
assert len(info) == len(env_fns)
assert isinstance(info[0], dict)
def run_align_norm_obs(raw_env, train_env, test_env, action_list):
eps = np.finfo(np.float32).eps.item()
raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
@@ -319,11 +331,25 @@ def test_venv_wrapper_envpool():
run_align_norm_obs(raw, train, test, actions)
if __name__ == "__main__":
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info():
num_envs = 4
env = VectorEnvNormObs(
envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
)
obs, info = env.reset()
assert obs.shape[0] == num_envs
for _, v in info.items():
if not isinstance(v, dict):
assert v.shape[0] == num_envs
if __name__ == '__main__':
test_venv_norm_obs()
test_venv_wrapper_envpool()
test_env_obs_dtype()
test_vecenv()
test_async_env()
test_async_check_id()
test_env_reset_optional_kwargs()
test_gym_wrappers()

View File

@@ -100,18 +100,24 @@ class Collector(object):
)
self.buffer = buffer
def reset(self, reset_buffer: bool = True) -> None:
def reset(
self,
reset_buffer: bool = True,
gym_reset_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Reset the environment, statistics, current data and possibly replay memory.
:param bool reset_buffer: if true, reset the replay buffer that is attached
to the collector.
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (no extra keyword arguments).
"""
# use empty Batch for "state" so that self.data supports slicing
# convert empty Batch to None when passing data to policy
self.data = Batch(
obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
)
self.reset_env()
self.reset_env(gym_reset_kwargs)
if reset_buffer:
self.reset_buffer()
self.reset_stat()
@@ -124,12 +130,27 @@ class Collector(object):
"""Reset the data buffer."""
self.buffer.reset(keep_statistics=keep_statistics)
def reset_env(self) -> None:
def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
"""Reset all of the environments."""
obs = self.env.reset()
if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs,
env_id=np.arange(self.env_num)).get("obs", obs)
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(**gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
)
if returns_info:
obs, info = rval
if self.preprocess_fn:
processed_data = self.preprocess_fn(
obs=obs, info=info, env_id=np.arange(self.env_num)
)
obs = processed_data.get("obs", obs)
info = processed_data.get("info", info)
self.data.info = info
else:
obs = rval
if self.preprocess_fn:
obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num
)).get("obs", obs)
self.data.obs = obs
def _reset_state(self, id: Union[int, List[int]]) -> None:
@@ -143,6 +164,33 @@ class Collector(object):
elif isinstance(state, Batch):
state.empty_(id)
def _reset_env_with_ids(
self,
local_ids: Union[List[int], np.ndarray],
global_ids: Union[List[int], np.ndarray],
gym_reset_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
rval = self.env.reset(global_ids, **gym_reset_kwargs)
returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore
)
if returns_info:
obs_reset, info = rval
if self.preprocess_fn:
processed_data = self.preprocess_fn(
obs=obs_reset, info=info, env_id=global_ids
)
obs_reset = processed_data.get("obs", obs_reset)
info = processed_data.get("info", info)
self.data.info[local_ids] = info
else:
obs_reset = rval
if self.preprocess_fn:
obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids
).get("obs", obs_reset)
self.data.obs_next[local_ids] = obs_reset
def collect(
self,
n_step: Optional[int] = None,
@@ -150,6 +198,7 @@ class Collector(object):
random: bool = False,
render: Optional[float] = None,
no_grad: bool = True,
gym_reset_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Collect a specified number of step or episode.
@@ -165,6 +214,8 @@ class Collector(object):
Default to None (no rendering).
:param bool no_grad: whether to retain gradient in policy.forward(). Default to
True (no gradient retaining).
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (no extra keyword arguments).
.. note::
@@ -288,12 +339,9 @@ class Collector(object):
episode_start_indices.append(ep_idx[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(
obs=obs_reset, env_id=env_ind_global
).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
self._reset_env_with_ids(
env_ind_local, env_ind_global, gym_reset_kwargs
)
for i in env_ind_local:
self._reset_state(i)
@@ -367,10 +415,16 @@ class AsyncCollector(Collector):
) -> None:
# assert env.is_async
warnings.warn("Using async setting may collect extra transitions into buffer.")
super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
super().__init__(
policy,
env,
buffer,
preprocess_fn,
exploration_noise,
)
def reset_env(self) -> None:
super().reset_env()
def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
super().reset_env(gym_reset_kwargs)
self._ready_env_ids = np.arange(self.env_num)
def collect(
@@ -380,6 +434,7 @@ class AsyncCollector(Collector):
random: bool = False,
render: Optional[float] = None,
no_grad: bool = True,
gym_reset_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Collect a specified number of step or episode with async env setting.
@@ -395,6 +450,8 @@ class AsyncCollector(Collector):
Default to None (no rendering).
:param bool no_grad: whether to retain gradient in policy.forward(). Default to
True (no gradient retaining).
:param gym_reset_kwargs: extra keyword arguments to pass into the environment's
reset function. Defaults to None (no extra keyword arguments).
.. note::
@@ -528,12 +585,9 @@ class AsyncCollector(Collector):
episode_start_indices.append(ep_idx[env_ind_local])
# now we copy obs_next to obs, but since there might be
# finished episodes, we have to reset finished envs first.
obs_reset = self.env.reset(env_ind_global)
if self.preprocess_fn:
obs_reset = self.preprocess_fn(
obs=obs_reset, env_id=env_ind_global
).get("obs", obs_reset)
self.data.obs_next[env_ind_local] = obs_reset
self._reset_env_with_ids(
env_ind_local, env_ind_global, gym_reset_kwargs
)
for i in env_ind_local:
self._reset_state(i)

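Hypothetically, the `gym_reset_kwargs` argument added to the Collector above could be used end to end as in the sketch below; `RandomPolicy` and `CartPole-v1` are made up for illustration and are not part of the commit:

```python
import gym
import numpy as np

from tianshou.data import Batch, Collector, ReplayBuffer
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Hypothetical stand-in policy; it only exists to drive the collector."""

    def forward(self, batch, state=None, **kwargs):
        # one (valid) discrete action per environment in the batch
        return Batch(act=np.zeros(len(batch.obs), dtype=int))

    def learn(self, batch, **kwargs):
        return {}


env = gym.make("CartPole-v1")
collector = Collector(RandomPolicy(), env, ReplayBuffer(size=100))
# The kwargs are forwarded to env.reset() whenever the collector (re)sets
# finished environments, e.g. to seed them; return_info=True is also accepted.
collector.collect(n_step=10, gym_reset_kwargs=dict(seed=0))
```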
View File

@@ -1,5 +1,5 @@
from abc import ABC
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union
import gym.spaces
from pettingzoo.utils.env import AECEnv
@@ -55,11 +55,11 @@ class PettingZooEnv(AECEnv, ABC):
self.reset()
def reset(self, *args: Any, **kwargs: Any) -> dict:
def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
self.env.reset(*args, **kwargs)
observation = self.env.observe(self.env.agent_selection)
observation, _, _, info = self.env.last(self)
if isinstance(observation, dict) and 'action_mask' in observation:
return {
observation_dict = {
'agent_id': self.env.agent_selection,
'obs': observation['observation'],
'mask':
@@ -67,13 +67,21 @@
}
else:
if isinstance(self.action_space, gym.spaces.Discrete):
return {
observation_dict = {
'agent_id': self.env.agent_selection,
'obs': observation,
'mask': [True] * self.env.action_space(self.env.agent_selection).n
}
else:
return {'agent_id': self.env.agent_selection, 'obs': observation}
observation_dict = {
'agent_id': self.env.agent_selection,
'obs': observation,
}
if "return_info" in kwargs and kwargs["return_info"]:
return observation_dict, info
else:
return observation_dict
def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
self.env.step(action)

View File

@@ -37,11 +37,12 @@ class VectorEnvWrapper(BaseVectorEnv):
) -> None:
return self.venv.set_env_attr(key, value, id)
# TODO: compatible issue with reset -> (obs, info)
def reset(
self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> np.ndarray:
return self.venv.reset(id)
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
return self.venv.reset(id, **kwargs)
def step(
self,
@@ -86,14 +87,33 @@ class VectorEnvNormObs(VectorEnvWrapper):
self.clip_max = clip_obs
self.eps = epsilon
# TODO: compatible issue with reset -> (obs, info)
def reset(
self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> np.ndarray:
obs = self.venv.reset(id)
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
retval = self.venv.reset(id, **kwargs)
reset_returns_info = isinstance(
retval, (tuple, list)
) and len(retval) == 2 and isinstance(retval[1], dict)
if reset_returns_info:
obs, info = retval
else:
obs = retval
if isinstance(obs, tuple):
raise TypeError(
"Tuple observation space is not supported. ",
"Please change it to array or dict space",
)
if self.obs_rms and self.update_obs_rms:
self.obs_rms.update(obs)
return self._norm_obs(obs)
obs = self._norm_obs(obs)
if reset_returns_info:
return obs, info
else:
return obs
def step(
self,

tianshou/env/venvs.py
View File

@@ -181,10 +181,11 @@ class BaseVectorEnv(object):
assert i in self.ready_id, \
f"Can only interact with ready environments {self.ready_id}."
# TODO: compatible issue with reset -> (obs, info)
def reset(
self, id: Optional[Union[int, List[int], np.ndarray]] = None
) -> np.ndarray:
self,
id: Optional[Union[int, List[int], np.ndarray]] = None,
**kwargs: Any,
) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
"""Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return
@@ -195,15 +196,35 @@
id = self._wrap_id(id)
if self.is_async:
self._assert_id(id)
# send(None) == reset() in worker
for i in id:
self.workers[i].send(None)
obs_list = [self.workers[i].recv() for i in id]
self.workers[i].send(None, **kwargs)
ret_list = [self.workers[i].recv() for i in id]
reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len(
ret_list[0]
) == 2 and isinstance(ret_list[0][1], dict)
if reset_returns_info:
obs_list = [r[0] for r in ret_list]
else:
obs_list = ret_list
if isinstance(obs_list[0], tuple):
raise TypeError(
"Tuple observation space is not supported. ",
"Please change it to array or dict space",
)
try:
obs = np.stack(obs_list)
except ValueError: # different len(obs)
obs = np.array(obs_list, dtype=object)
return obs
if reset_returns_info:
infos = [r[1] for r in ret_list]
return obs, infos # type: ignore
else:
return obs
def step(
self,
@@ -248,7 +269,7 @@ class BaseVectorEnv(object):
self.workers[j].send(action[i])
result = []
for j in id:
obs, rew, done, info = self.workers[j].recv()
obs, rew, done, info = self.workers[j].recv() # type: ignore
info["env_id"] = j
result.append((obs, rew, done, info))
else:
@@ -270,7 +291,7 @@ class BaseVectorEnv(object):
waiting_index = self.waiting_conn.index(conn)
self.waiting_conn.pop(waiting_index)
env_id = self.waiting_id.pop(waiting_index)
obs, rew, done, info = conn.recv()
obs, rew, done, info = conn.recv() # type: ignore
info["env_id"] = env_id
result.append((obs, rew, done, info))
self.ready_id.append(env_id)

View File

@@ -14,7 +14,7 @@ class EnvWorker(ABC):
self._env_fn = env_fn
self.is_closed = False
self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
np.ndarray]
Tuple[np.ndarray, dict], np.ndarray]
self.action_space = self.get_env_attr("action_space") # noqa: B009
self.is_reset = False
@@ -47,7 +47,8 @@
def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[
np.ndarray, dict], np.ndarray]: # noqa:E125
"""Receive result from low-level worker.
If the last "send" function sends a NULL action, it only returns a
@@ -63,9 +64,9 @@
self.result = self.get_result() # type: ignore
return self.result
def reset(self) -> np.ndarray:
self.send(None)
return self.recv() # type: ignore
@abstractmethod
def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
pass
def step(
self, action: np.ndarray

View File

@@ -1,4 +1,4 @@
from typing import Any, Callable, List, Optional
from typing import Any, Callable, List, Optional, Tuple, Union
import gym
import numpy as np
@@ -19,8 +19,10 @@ class DummyEnvWorker(EnvWorker):
def set_env_attr(self, key: str, value: Any) -> None:
setattr(self.env, key, value)
def reset(self) -> Any:
return self.env.reset()
def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
if "seed" in kwargs:
super().seed(kwargs["seed"])
return self.env.reset(**kwargs)
@staticmethod
def wait( # type: ignore
@@ -29,15 +31,19 @@
# Sequential EnvWorker objects are always ready
return workers
def send(self, action: Optional[np.ndarray]) -> None:
def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
if action is None:
self.result = self.env.reset() # type: ignore
self.result = self.env.reset(**kwargs)
else:
self.result = self.env.step(action) # type: ignore
def seed(self, seed: Optional[int] = None) -> List[int]:
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
super().seed(seed)
return self.env.seed(seed)
try:
return self.env.seed(seed)
except NotImplementedError:
self.env.reset(seed=seed)
return [seed] # type: ignore
def render(self, **kwargs: Any) -> Any:
return self.env.render(**kwargs)

View File

@@ -35,8 +35,10 @@ class RayEnvWorker(EnvWorker):
def set_env_attr(self, key: str, value: Any) -> None:
ray.get(self.env.set_env_attr.remote(key, value))
def reset(self) -> Any:
return ray.get(self.env.reset.remote())
def reset(self, **kwargs: Any) -> Any:
if "seed" in kwargs:
super().seed(kwargs["seed"])
return ray.get(self.env.reset.remote(**kwargs))
@staticmethod
def wait( # type: ignore
@@ -46,10 +48,10 @@ class RayEnvWorker(EnvWorker):
ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
return [workers[results.index(result)] for result in ready_results]
def send(self, action: Optional[np.ndarray]) -> None:
# self.action is actually a handle
def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
# self.result is actually a handle
if action is None:
self.result = self.env.reset.remote()
self.result = self.env.reset.remote(**kwargs)
else:
self.result = self.env.step.remote(action)
@@ -58,9 +60,13 @@
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
return ray.get(self.result) # type: ignore
def seed(self, seed: Optional[int] = None) -> List[int]:
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
super().seed(seed)
return ray.get(self.env.seed.remote(seed))
try:
return ray.get(self.env.seed.remote(seed))
except NotImplementedError:
self.env.reset.remote(seed=seed)
return None
def render(self, **kwargs: Any) -> Any:
return ray.get(self.env.render.remote(**kwargs))

View File

@@ -86,17 +86,27 @@ def _worker(
p.close()
break
if cmd == "step":
if data is None: # reset
obs = env.reset()
else:
obs, reward, done, info = env.step(data)
obs, reward, done, info = env.step(data)
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
if data is None:
p.send(obs)
p.send((obs, reward, done, info))
elif cmd == "reset":
retval = env.reset(**data)
reset_returns_info = isinstance(
retval, (tuple, list)
) and len(retval) == 2 and isinstance(retval[1], dict)
if reset_returns_info:
obs, info = retval
else:
p.send((obs, reward, done, info))
obs = retval
if obs_bufs is not None:
_encode_obs(obs, obs_bufs)
obs = None
if reset_returns_info:
p.send((obs, info))
else:
p.send(obs)
elif cmd == "close":
p.send(env.close())
p.close()
@@ -104,7 +114,11 @@ def _worker(
elif cmd == "render":
p.send(env.render(**data) if hasattr(env, "render") else None)
elif cmd == "seed":
p.send(env.seed(data) if hasattr(env, "seed") else None)
if hasattr(env, "seed"):
p.send(env.seed(data))
else:
env.reset(seed=data)
p.send(None)
elif cmd == "getattr":
p.send(getattr(env, data) if hasattr(env, data) else None)
elif cmd == "setattr":
@@ -140,7 +154,6 @@ class SubprocEnvWorker(EnvWorker):
self.process = Process(target=_worker, args=args, daemon=True)
self.process.start()
self.child_remote.close()
self.is_reset = False
super().__init__(env_fn)
def get_env_attr(self, key: str) -> Any:
@@ -186,14 +199,25 @@ class SubprocEnvWorker(EnvWorker):
remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
return [workers[conns.index(con)] for con in ready_conns]
def send(self, action: Optional[np.ndarray]) -> None:
self.parent_remote.send(["step", action])
def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
if action is None:
if "seed" in kwargs:
super().seed(kwargs["seed"])
self.parent_remote.send(["reset", kwargs])
else:
self.parent_remote.send(["step", action])
def recv(
self
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[
np.ndarray, dict], np.ndarray]: # noqa:E125
result = self.parent_remote.recv()
if isinstance(result, tuple):
if len(result) == 2:
obs, info = result
if self.share_memory:
obs = self._decode_obs()
return obs, info
obs, rew, done, info = result
if self.share_memory:
obs = self._decode_obs()
@@ -204,6 +228,23 @@ class SubprocEnvWorker(EnvWorker):
obs = self._decode_obs()
return obs
def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
if "seed" in kwargs:
super().seed(kwargs["seed"])
self.parent_remote.send(["reset", kwargs])
result = self.parent_remote.recv()
if isinstance(result, tuple):
obs, info = result
if self.share_memory:
obs = self._decode_obs()
return obs, info
else:
obs = result
if self.share_memory:
obs = self._decode_obs()
return obs
def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
super().seed(seed)
self.parent_remote.send(["seed", seed])