Upgrade gym (#613)

Fixes some deprecation warnings introduced by gym 0.23:
- use `env.np_random.integers` instead of `env.np_random.randint`
- support the `seed` and `return_info` arguments of `reset()` (addresses https://github.com/thu-ml/tianshou/issues/605); a short usage sketch of both changes follows the changed-files summary below
Yifei Cheng 2022-06-27 18:52:21 -04:00 committed by GitHub
parent aba2d01d25
commit 43792bf5ab
13 changed files with 382 additions and 107 deletions
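
Both changes boil down to the calling pattern sketched below. This is a minimal, illustrative snippet against a plain gym >= 0.23 environment; the `CartPole-v1` id and the 1..30 no-op range are placeholders and not part of this diff.

import gym

env = gym.make("CartPole-v1")

# gym >= 0.23: seeding moves into reset(); return_info=True makes reset()
# return an (obs, info) tuple instead of a bare observation.
obs, info = env.reset(seed=0, return_info=True)

# gym >= 0.23 exposes a numpy Generator as env.np_random, which provides
# `integers`; older versions expose a legacy RandomState, which only has
# `randint`. The hasattr check keeps the code working on both.
if hasattr(env.np_random, "integers"):
    noops = env.np_random.integers(1, 31)
else:
    noops = env.np_random.randint(1, 31)

The same `hasattr` fallback appears in the first hunk below, and the `(obs, info)` return handling is what the collector and vector-env changes in this commit propagate through the rest of the code base.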


@@ -32,7 +32,10 @@ class NoopResetEnv(gym.Wrapper):
     def reset(self):
         self.env.reset()
-        noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
+        if hasattr(self.unwrapped.np_random, "integers"):
+            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
+        else:
+            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
         for _ in range(noops):
             obs, _, done, _ = self.env.step(self.noop_action)
             if done:


@@ -15,7 +15,7 @@ def get_version() -> str:
 def get_install_requires() -> str:
     return [
-        "gym>=0.15.4",
+        "gym>=0.23.1",
         "tqdm",
         "numpy>1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard>=2.5.0",


@@ -79,11 +79,16 @@ class MyTestEnv(gym.Env):
         self.rng = np.random.RandomState(seed)
         return [seed]

-    def reset(self, state=0):
+    def reset(self, state=0, seed=None, return_info=False):
+        if seed is not None:
+            self.rng = np.random.RandomState(seed)
         self.done = False
         self.do_sleep()
         self.index = state
-        return self._get_state()
+        if return_info:
+            return self._get_state(), {'key': 1, 'env': self}
+        else:
+            return self._get_state()

     def _get_reward(self):
         """Generate a non-scalar reward if ma_rew is True."""


@@ -15,6 +15,11 @@ from tianshou.data import (
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 from tianshou.policy import BasePolicy

+try:
+    import envpool
+except ImportError:
+    envpool = None
+
 if __name__ == '__main__':
     from env import MyTestEnv, NXEnv
 else:  # pytest
@@ -23,7 +28,7 @@ else:  # pytest
 class MyPolicy(BasePolicy):

-    def __init__(self, dict_state=False, need_state=True):
+    def __init__(self, dict_state=False, need_state=True, action_shape=None):
         """
         :param bool dict_state: if the observation of the environment is a dict
         :param bool need_state: if the policy needs the hidden state (for RNN)
@@ -31,6 +36,7 @@ class MyPolicy(BasePolicy):
         super().__init__()
         self.dict_state = dict_state
         self.need_state = need_state
+        self.action_shape = action_shape

     def forward(self, batch, state=None):
         if self.need_state:
@@ -39,8 +45,12 @@ class MyPolicy(BasePolicy):
             else:
                 state += 1
         if self.dict_state:
-            return Batch(act=np.ones(len(batch.obs['index'])), state=state)
-        return Batch(act=np.ones(len(batch.obs)), state=state)
+            action_shape = self.action_shape if self.action_shape else len(
+                batch.obs['index']
+            )
+            return Batch(act=np.ones(action_shape), state=state)
+        action_shape = self.action_shape if self.action_shape else len(batch.obs)
+        return Batch(act=np.ones(action_shape), state=state)

     def learn(self):
         pass
@@ -77,7 +87,8 @@ class Logger:
         return Batch()


-def test_collector():
+@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
+def test_collector(gym_reset_kwargs):
     writer = SummaryWriter('log/collector')
     logger = Logger(writer)
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
@@ -86,52 +97,102 @@ def test_collector():
     dum = DummyVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
-    c0.collect(n_step=3)
+    c0 = Collector(
+        policy,
+        env,
+        ReplayBuffer(size=100),
+        logger.preprocess_fn,
+    )
+    c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c0.buffer) == 3
     assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
     assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
-    c0.collect(n_episode=3)
+    keys = np.zeros(100)
+    keys[:3] = 1
+    assert np.allclose(c0.buffer.info["key"], keys)
+    for e in c0.buffer.info["env"][:3]:
+        assert isinstance(e, MyTestEnv)
+    assert np.allclose(c0.buffer.info["env_id"], 0)
+    rews = np.zeros(100)
+    rews[:3] = [0, 1, 0]
+    assert np.allclose(c0.buffer.info["rew"], rews)
+    c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c0.buffer) == 8
     assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
     assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
-    c0.collect(n_step=3, random=True)
+    assert np.allclose(c0.buffer.info["key"][:8], 1)
+    for e in c0.buffer.info["env"][:8]:
+        assert isinstance(e, MyTestEnv)
+    assert np.allclose(c0.buffer.info["env_id"][:8], 0)
+    assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
+    c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)
     c1 = Collector(
         policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4),
         logger.preprocess_fn
     )
-    c1.collect(n_step=8)
+    c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
     obs = np.zeros(100)
-    obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]
+    valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
+    obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
     assert np.allclose(c1.buffer.obs[:, 0], obs)
     assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
-    c1.collect(n_episode=4)
+    keys = np.zeros(100)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c1.buffer.info["key"], keys)
+    for e in c1.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids = np.zeros(100)
+    env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
+    assert np.allclose(c1.buffer.info["env_id"], env_ids)
+    rews = np.zeros(100)
+    rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
+    assert np.allclose(c1.buffer.info["rew"], rews)
+    c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c1.buffer) == 16
+    valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
     obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
     assert np.allclose(c1.buffer.obs[:, 0], obs)
     assert np.allclose(
         c1.buffer[:].obs_next[..., 0],
         [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
     )
-    c1.collect(n_episode=4, random=True)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c1.buffer.info["key"], keys)
+    for e in c1.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
+    assert np.allclose(c1.buffer.info["env_id"], env_ids)
+    rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
+    assert np.allclose(c1.buffer.info["rew"], rews)
+    c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)
     c2 = Collector(
         policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4),
         logger.preprocess_fn
     )
-    c2.collect(n_episode=7)
+    c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
     obs1 = obs.copy()
     obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
     obs2 = obs.copy()
     obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
     c2obs = c2.buffer.obs[:, 0]
     assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
-    c2.reset_env()
+    c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
     c2.reset_buffer()
-    assert c2.collect(n_episode=8)['n/ep'] == 8
-    obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
+    assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8
+    valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
+    obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
     assert np.all(c2.buffer.obs[:, 0] == obs)
-    c2.collect(n_episode=4, random=True)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c2.buffer.info["key"], keys)
+    for e in c2.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
+    assert np.allclose(c2.buffer.info["env_id"], env_ids)
+    rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
+    assert np.allclose(c2.buffer.info["rew"], rews)
+    c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

     # test corner case
     with pytest.raises(TypeError):
@@ -147,11 +208,12 @@ def test_collector():
             [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]
         )
         c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
-        c3.collect(n_step=6)
+        c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
         assert c3.buffer.obs.dtype == object


-def test_collector_with_async():
+@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
+def test_collector_with_async(gym_reset_kwargs):
     env_lens = [2, 3, 4, 5]
     writer = SummaryWriter('log/async_collector')
     logger = Logger(writer)
@@ -163,12 +225,14 @@ def test_collector_with_async():
     policy = MyPolicy()
     bufsize = 60
     c1 = AsyncCollector(
-        policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
-        logger.preprocess_fn
+        policy,
+        venv,
+        VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
+        logger.preprocess_fn,
     )
     ptr = [0, 0, 0, 0]
     for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
-        result = c1.collect(n_episode=n_episode)
+        result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
         assert result["n/ep"] >= n_episode
         # check buffer data, obs and obs_next, env_id
         for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
@@ -183,7 +247,7 @@ def test_collector_with_async():
             assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
     # test async n_step, for now the buffer should be full of data
     for n_step in tqdm.trange(1, 15, desc="test async n_step"):
-        result = c1.collect(n_step=n_step)
+        result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
         assert result["n/st"] >= n_step
         for i in range(4):
             env_len = i + 2
@@ -618,9 +682,29 @@ def test_collector_with_atari_setting():
         assert np.allclose(result2[key], result[key])


+@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
+def test_collector_envpool_gym_reset_return_info():
+    envs = envpool.make_gym("Pendulum-v0", num_envs=4, gym_reset_return_info=True)
+    policy = MyPolicy(action_shape=(len(envs), 1))
+    c0 = Collector(
+        policy,
+        envs,
+        VectorReplayBuffer(len(envs) * 10, len(envs)),
+        exploration_noise=True
+    )
+    c0.collect(n_step=8)
+    env_ids = np.zeros(len(envs) * 10)
+    env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3]
+    assert np.allclose(c0.buffer.info["env_id"], env_ids)
+
+
 if __name__ == '__main__':
-    test_collector()
+    test_collector(gym_reset_kwargs=None)
+    test_collector(gym_reset_kwargs=dict(return_info=True))
     test_collector_with_dict_state()
     test_collector_with_ma()
     test_collector_with_atari_setting()
-    test_collector_with_async()
+    test_collector_with_async(gym_reset_kwargs=None)
+    test_collector_with_async(gym_reset_kwargs=dict(return_info=True))
+    test_collector_envpool_gym_reset_return_info()


@@ -222,6 +222,18 @@ def test_env_obs_dtype():
         assert obs.dtype == object


+def test_env_reset_optional_kwargs(size=10000, num=8):
+    env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
+    test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
+    if has_ray():
+        test_cls += [RayVectorEnv]
+    for cls in test_cls:
+        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
+        _, info = v.reset(seed=1, return_info=True)
+        assert len(info) == len(env_fns)
+        assert isinstance(info[0], dict)
+
+
 def run_align_norm_obs(raw_env, train_env, test_env, action_list):
     eps = np.finfo(np.float32).eps.item()
     raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
@@ -319,11 +331,25 @@ def test_venv_wrapper_envpool():
     run_align_norm_obs(raw, train, test, actions)


-if __name__ == "__main__":
+@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
+def test_venv_wrapper_envpool_gym_reset_return_info():
+    num_envs = 4
+    env = VectorEnvNormObs(
+        envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
+    )
+    obs, info = env.reset()
+    assert obs.shape[0] == num_envs
+    for _, v in info.items():
+        if not isinstance(v, dict):
+            assert v.shape[0] == num_envs
+
+
+if __name__ == '__main__':
     test_venv_norm_obs()
     test_venv_wrapper_envpool()
     test_env_obs_dtype()
     test_vecenv()
     test_async_env()
     test_async_check_id()
+    test_env_reset_optional_kwargs()
     test_gym_wrappers()


@@ -100,18 +100,24 @@ class Collector(object):
             )
         self.buffer = buffer

-    def reset(self, reset_buffer: bool = True) -> None:
+    def reset(
+        self,
+        reset_buffer: bool = True,
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         """Reset the environment, statistics, current data and possibly replay memory.

         :param bool reset_buffer: if true, reset the replay buffer that is attached
             to the collector.
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Defaults to None (extra keyword arguments)
         """
         # use empty Batch for "state" so that self.data supports slicing
         # convert empty Batch to None when passing data to policy
         self.data = Batch(
             obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
         )
-        self.reset_env()
+        self.reset_env(gym_reset_kwargs)
         if reset_buffer:
             self.reset_buffer()
         self.reset_stat()
@@ -124,12 +130,27 @@ class Collector(object):
         """Reset the data buffer."""
         self.buffer.reset(keep_statistics=keep_statistics)

-    def reset_env(self) -> None:
+    def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
         """Reset all of the environments."""
-        obs = self.env.reset()
-        if self.preprocess_fn:
-            obs = self.preprocess_fn(obs=obs,
-                                     env_id=np.arange(self.env_num)).get("obs", obs)
+        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
+        rval = self.env.reset(**gym_reset_kwargs)
+        returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
+            isinstance(rval[1], dict) or isinstance(rval[1][0], dict)  # type: ignore
+        )
+        if returns_info:
+            obs, info = rval
+            if self.preprocess_fn:
+                processed_data = self.preprocess_fn(
+                    obs=obs, info=info, env_id=np.arange(self.env_num)
+                )
+                obs = processed_data.get("obs", obs)
+                info = processed_data.get("info", info)
+            self.data.info = info
+        else:
+            obs = rval
+            if self.preprocess_fn:
+                obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num
+                                                                   )).get("obs", obs)
         self.data.obs = obs

     def _reset_state(self, id: Union[int, List[int]]) -> None:
@@ -143,6 +164,33 @@ class Collector(object):
         elif isinstance(state, Batch):
             state.empty_(id)

+    def _reset_env_with_ids(
+        self,
+        local_ids: Union[List[int], np.ndarray],
+        global_ids: Union[List[int], np.ndarray],
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
+        rval = self.env.reset(global_ids, **gym_reset_kwargs)
+        returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
+            isinstance(rval[1], dict) or isinstance(rval[1][0], dict)  # type: ignore
+        )
+        if returns_info:
+            obs_reset, info = rval
+            if self.preprocess_fn:
+                processed_data = self.preprocess_fn(
+                    obs=obs_reset, info=info, env_id=global_ids
+                )
+                obs_reset = processed_data.get("obs", obs_reset)
+                info = processed_data.get("info", info)
+            self.data.info[local_ids] = info
+        else:
+            obs_reset = rval
+            if self.preprocess_fn:
+                obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids
+                                               ).get("obs", obs_reset)
+        self.data.obs_next[local_ids] = obs_reset
+
     def collect(
         self,
         n_step: Optional[int] = None,
@@ -150,6 +198,7 @@ class Collector(object):
         random: bool = False,
         render: Optional[float] = None,
         no_grad: bool = True,
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         """Collect a specified number of step or episode.
@@ -165,6 +214,8 @@ class Collector(object):
             Default to None (no rendering).
         :param bool no_grad: whether to retain gradient in policy.forward(). Default to
             True (no gradient retaining).
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Defaults to None (extra keyword arguments)

         .. note::
@@ -288,12 +339,9 @@ class Collector(object):
                 episode_start_indices.append(ep_idx[env_ind_local])
                 # now we copy obs_next to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.
-                obs_reset = self.env.reset(env_ind_global)
-                if self.preprocess_fn:
-                    obs_reset = self.preprocess_fn(
-                        obs=obs_reset, env_id=env_ind_global
-                    ).get("obs", obs_reset)
-                self.data.obs_next[env_ind_local] = obs_reset
+                self._reset_env_with_ids(
+                    env_ind_local, env_ind_global, gym_reset_kwargs
+                )
                 for i in env_ind_local:
                     self._reset_state(i)
@@ -367,10 +415,16 @@ class AsyncCollector(Collector):
     ) -> None:
         # assert env.is_async
         warnings.warn("Using async setting may collect extra transitions into buffer.")
-        super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
+        super().__init__(
+            policy,
+            env,
+            buffer,
+            preprocess_fn,
+            exploration_noise,
+        )

-    def reset_env(self) -> None:
-        super().reset_env()
+    def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
+        super().reset_env(gym_reset_kwargs)
         self._ready_env_ids = np.arange(self.env_num)

     def collect(
@@ -380,6 +434,7 @@ class AsyncCollector(Collector):
         random: bool = False,
         render: Optional[float] = None,
         no_grad: bool = True,
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Collect a specified number of step or episode with async env setting.
@@ -395,6 +450,8 @@ class AsyncCollector(Collector):
            Default to None (no rendering).
        :param bool no_grad: whether to retain gradient in policy.forward(). Default to
            True (no gradient retaining).
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Defaults to None (extra keyword arguments)

        .. note::
@@ -528,12 +585,9 @@ class AsyncCollector(Collector):
                episode_start_indices.append(ep_idx[env_ind_local])
                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
-                obs_reset = self.env.reset(env_ind_global)
-                if self.preprocess_fn:
-                    obs_reset = self.preprocess_fn(
-                        obs=obs_reset, env_id=env_ind_global
-                    ).get("obs", obs_reset)
-                self.data.obs_next[env_ind_local] = obs_reset
+                self._reset_env_with_ids(
+                    env_ind_local, env_ind_global, gym_reset_kwargs
+                )
                for i in env_ind_local:
                    self._reset_state(i)


@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union

 import gym.spaces
 from pettingzoo.utils.env import AECEnv
@@ -55,11 +55,11 @@ class PettingZooEnv(AECEnv, ABC):
         self.reset()

-    def reset(self, *args: Any, **kwargs: Any) -> dict:
+    def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
         self.env.reset(*args, **kwargs)
-        observation = self.env.observe(self.env.agent_selection)
+        observation, _, _, info = self.env.last(self)
         if isinstance(observation, dict) and 'action_mask' in observation:
-            return {
+            observation_dict = {
                 'agent_id': self.env.agent_selection,
                 'obs': observation['observation'],
                 'mask':
@@ -67,13 +67,21 @@ class PettingZooEnv(AECEnv, ABC):
             }
         else:
             if isinstance(self.action_space, gym.spaces.Discrete):
-                return {
+                observation_dict = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
                     'mask': [True] * self.env.action_space(self.env.agent_selection).n
                 }
             else:
-                return {'agent_id': self.env.agent_selection, 'obs': observation}
+                observation_dict = {
+                    'agent_id': self.env.agent_selection,
+                    'obs': observation,
+                }
+
+        if "return_info" in kwargs and kwargs["return_info"]:
+            return observation_dict, info
+        else:
+            return observation_dict

     def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
         self.env.step(action)


@@ -37,11 +37,12 @@ class VectorEnvWrapper(BaseVectorEnv):
     ) -> None:
         return self.venv.set_env_attr(key, value, id)

+    # TODO: compatible issue with reset -> (obs, info)
     def reset(
-        self, id: Optional[Union[int, List[int], np.ndarray]] = None
-    ) -> np.ndarray:
-        return self.venv.reset(id)
+        self,
+        id: Optional[Union[int, List[int], np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
+        return self.venv.reset(id, **kwargs)

     def step(
         self,
@@ -86,14 +87,33 @@ class VectorEnvNormObs(VectorEnvWrapper):
         self.clip_max = clip_obs
         self.eps = epsilon

+    # TODO: compatible issue with reset -> (obs, info)
     def reset(
-        self, id: Optional[Union[int, List[int], np.ndarray]] = None
-    ) -> np.ndarray:
-        obs = self.venv.reset(id)
+        self,
+        id: Optional[Union[int, List[int], np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
+        retval = self.venv.reset(id, **kwargs)
+        reset_returns_info = isinstance(
+            retval, (tuple, list)
+        ) and len(retval) == 2 and isinstance(retval[1], dict)
+        if reset_returns_info:
+            obs, info = retval
+        else:
+            obs = retval
+
+        if isinstance(obs, tuple):
+            raise TypeError(
+                "Tuple observation space is not supported. ",
+                "Please change it to array or dict space",
+            )
         if self.obs_rms and self.update_obs_rms:
             self.obs_rms.update(obs)
-        return self._norm_obs(obs)
+        obs = self._norm_obs(obs)
+        if reset_returns_info:
+            return obs, info
+        else:
+            return obs

     def step(
         self,
tianshou/env/venvs.py

@@ -181,10 +181,11 @@ class BaseVectorEnv(object):
             assert i in self.ready_id, \
                 f"Can only interact with ready environments {self.ready_id}."

+    # TODO: compatible issue with reset -> (obs, info)
     def reset(
-        self, id: Optional[Union[int, List[int], np.ndarray]] = None
-    ) -> np.ndarray:
+        self,
+        id: Optional[Union[int, List[int], np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
         """Reset the state of some envs and return initial observations.

         If id is None, reset the state of all the environments and return
@@ -195,15 +196,35 @@ class BaseVectorEnv(object):
         id = self._wrap_id(id)
         if self.is_async:
             self._assert_id(id)
         # send(None) == reset() in worker
         for i in id:
-            self.workers[i].send(None)
-        obs_list = [self.workers[i].recv() for i in id]
+            self.workers[i].send(None, **kwargs)
+        ret_list = [self.workers[i].recv() for i in id]
+        reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len(
+            ret_list[0]
+        ) == 2 and isinstance(ret_list[0][1], dict)
+        if reset_returns_info:
+            obs_list = [r[0] for r in ret_list]
+        else:
+            obs_list = ret_list
+
+        if isinstance(obs_list[0], tuple):
+            raise TypeError(
+                "Tuple observation space is not supported. ",
+                "Please change it to array or dict space",
+            )
         try:
             obs = np.stack(obs_list)
         except ValueError:  # different len(obs)
             obs = np.array(obs_list, dtype=object)
-        return obs
+
+        if reset_returns_info:
+            infos = [r[1] for r in ret_list]
+            return obs, infos  # type: ignore
+        else:
+            return obs

     def step(
         self,
@@ -248,7 +269,7 @@ class BaseVectorEnv(object):
                 self.workers[j].send(action[i])
             result = []
             for j in id:
-                obs, rew, done, info = self.workers[j].recv()
+                obs, rew, done, info = self.workers[j].recv()  # type: ignore
                 info["env_id"] = j
                 result.append((obs, rew, done, info))
         else:
@@ -270,7 +291,7 @@ class BaseVectorEnv(object):
                 waiting_index = self.waiting_conn.index(conn)
                 self.waiting_conn.pop(waiting_index)
                 env_id = self.waiting_id.pop(waiting_index)
-                obs, rew, done, info = conn.recv()
+                obs, rew, done, info = conn.recv()  # type: ignore
                 info["env_id"] = env_id
                 result.append((obs, rew, done, info))
                 self.ready_id.append(env_id)


@@ -14,7 +14,7 @@ class EnvWorker(ABC):
         self._env_fn = env_fn
         self.is_closed = False
         self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-                           np.ndarray]
+                           Tuple[np.ndarray, dict], np.ndarray]
         self.action_space = self.get_env_attr("action_space")  # noqa: B009
         self.is_reset = False
@@ -47,7 +47,8 @@ class EnvWorker(ABC):
     def recv(
         self
-    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
+    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[
+            np.ndarray, dict], np.ndarray]:  # noqa:E125
         """Receive result from low-level worker.

         If the last "send" function sends a NULL action, it only returns a
@@ -63,9 +64,9 @@ class EnvWorker(ABC):
             self.result = self.get_result()  # type: ignore
         return self.result

-    def reset(self) -> np.ndarray:
-        self.send(None)
-        return self.recv()  # type: ignore
+    @abstractmethod
+    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+        pass

     def step(
         self, action: np.ndarray


@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union

 import gym
 import numpy as np
@@ -19,8 +19,10 @@ class DummyEnvWorker(EnvWorker):
     def set_env_attr(self, key: str, value: Any) -> None:
         setattr(self.env, key, value)

-    def reset(self) -> Any:
-        return self.env.reset()
+    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+        if "seed" in kwargs:
+            super().seed(kwargs["seed"])
+        return self.env.reset(**kwargs)

     @staticmethod
     def wait(  # type: ignore
@@ -29,15 +31,19 @@ class DummyEnvWorker(EnvWorker):
         # Sequential EnvWorker objects are always ready
         return workers

-    def send(self, action: Optional[np.ndarray]) -> None:
+    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
         if action is None:
-            self.result = self.env.reset()  # type: ignore
+            self.result = self.env.reset(**kwargs)
         else:
             self.result = self.env.step(action)  # type: ignore

-    def seed(self, seed: Optional[int] = None) -> List[int]:
+    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
         super().seed(seed)
-        return self.env.seed(seed)
+        try:
+            return self.env.seed(seed)
+        except NotImplementedError:
+            self.env.reset(seed=seed)
+            return [seed]  # type: ignore

     def render(self, **kwargs: Any) -> Any:
         return self.env.render(**kwargs)


@@ -35,8 +35,10 @@ class RayEnvWorker(EnvWorker):
     def set_env_attr(self, key: str, value: Any) -> None:
         ray.get(self.env.set_env_attr.remote(key, value))

-    def reset(self) -> Any:
-        return ray.get(self.env.reset.remote())
+    def reset(self, **kwargs: Any) -> Any:
+        if "seed" in kwargs:
+            super().seed(kwargs["seed"])
+        return ray.get(self.env.reset.remote(**kwargs))

     @staticmethod
     def wait(  # type: ignore
@@ -46,10 +48,10 @@ class RayEnvWorker(EnvWorker):
         ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
         return [workers[results.index(result)] for result in ready_results]

-    def send(self, action: Optional[np.ndarray]) -> None:
-        # self.action is actually a handle
+    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
+        # self.result is actually a handle
         if action is None:
-            self.result = self.env.reset.remote()
+            self.result = self.env.reset.remote(**kwargs)
         else:
             self.result = self.env.step.remote(action)

@@ -58,9 +60,13 @@ class RayEnvWorker(EnvWorker):
     ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
         return ray.get(self.result)  # type: ignore

-    def seed(self, seed: Optional[int] = None) -> List[int]:
+    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
         super().seed(seed)
-        return ray.get(self.env.seed.remote(seed))
+        try:
+            return ray.get(self.env.seed.remote(seed))
+        except NotImplementedError:
+            self.env.reset.remote(seed=seed)
+            return None

     def render(self, **kwargs: Any) -> Any:
         return ray.get(self.env.render.remote(**kwargs))


@@ -86,17 +86,27 @@ def _worker(
                 p.close()
                 break
             if cmd == "step":
-                if data is None:  # reset
-                    obs = env.reset()
-                else:
-                    obs, reward, done, info = env.step(data)
+                obs, reward, done, info = env.step(data)
                 if obs_bufs is not None:
                     _encode_obs(obs, obs_bufs)
                     obs = None
-                if data is None:
-                    p.send(obs)
-                else:
-                    p.send((obs, reward, done, info))
+                p.send((obs, reward, done, info))
+            elif cmd == "reset":
+                retval = env.reset(**data)
+                reset_returns_info = isinstance(
+                    retval, (tuple, list)
+                ) and len(retval) == 2 and isinstance(retval[1], dict)
+                if reset_returns_info:
+                    obs, info = retval
+                else:
+                    obs = retval
+                if obs_bufs is not None:
+                    _encode_obs(obs, obs_bufs)
+                    obs = None
+                if reset_returns_info:
+                    p.send((obs, info))
+                else:
+                    p.send(obs)
             elif cmd == "close":
                 p.send(env.close())
                 p.close()
@@ -104,7 +114,11 @@ def _worker(
             elif cmd == "render":
                 p.send(env.render(**data) if hasattr(env, "render") else None)
             elif cmd == "seed":
-                p.send(env.seed(data) if hasattr(env, "seed") else None)
+                if hasattr(env, "seed"):
+                    p.send(env.seed(data))
+                else:
+                    env.reset(seed=data)
+                    p.send(None)
             elif cmd == "getattr":
                 p.send(getattr(env, data) if hasattr(env, data) else None)
             elif cmd == "setattr":
@@ -140,7 +154,6 @@ class SubprocEnvWorker(EnvWorker):
         self.process = Process(target=_worker, args=args, daemon=True)
         self.process.start()
         self.child_remote.close()
-        self.is_reset = False
         super().__init__(env_fn)

     def get_env_attr(self, key: str) -> Any:
@@ -186,14 +199,25 @@ class SubprocEnvWorker(EnvWorker):
             remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
         return [workers[conns.index(con)] for con in ready_conns]

-    def send(self, action: Optional[np.ndarray]) -> None:
-        self.parent_remote.send(["step", action])
+    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
+        if action is None:
+            if "seed" in kwargs:
+                super().seed(kwargs["seed"])
+            self.parent_remote.send(["reset", kwargs])
+        else:
+            self.parent_remote.send(["step", action])

     def recv(
         self
-    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
+    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[
+            np.ndarray, dict], np.ndarray]:  # noqa:E125
         result = self.parent_remote.recv()
         if isinstance(result, tuple):
+            if len(result) == 2:
+                obs, info = result
+                if self.share_memory:
+                    obs = self._decode_obs()
+                return obs, info
             obs, rew, done, info = result
             if self.share_memory:
                 obs = self._decode_obs()
@@ -204,6 +228,23 @@ class SubprocEnvWorker(EnvWorker):
             obs = self._decode_obs()
         return obs

+    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+        if "seed" in kwargs:
+            super().seed(kwargs["seed"])
+        self.parent_remote.send(["reset", kwargs])
+
+        result = self.parent_remote.recv()
+        if isinstance(result, tuple):
+            obs, info = result
+            if self.share_memory:
+                obs = self._decode_obs()
+            return obs, info
+        else:
+            obs = result
+            if self.share_memory:
+                obs = self._decode_obs()
+            return obs
+
     def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
         super().seed(seed)
         self.parent_remote.send(["seed", seed])