Upgrade gym (#613)
Fixes some deprecation warnings caused by changes in gym 0.23:

- use `env.np_random.integers` instead of the deprecated `env.np_random.randint`
- support the `seed` and `return_info` arguments of `reset` (addresses https://github.com/thu-ml/tianshou/issues/605)
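For context, the gym 0.23 APIs this commit targets look roughly like this (a minimal sketch, not taken from the diff; the environment id is only illustrative):

import gym

env = gym.make("CartPole-v1")  # illustrative env id, not from this commit

# np_random is now a numpy Generator: use integers() instead of randint()
noops = env.unwrapped.np_random.integers(1, 31)

# reset() can seed the env and optionally return an info dict
obs, info = env.reset(seed=0, return_info=True)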
parent aba2d01d25 · commit 43792bf5ab
@@ -32,7 +32,10 @@ class NoopResetEnv(gym.Wrapper):

     def reset(self):
         self.env.reset()
-        noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
+        if hasattr(self.unwrapped.np_random, "integers"):
+            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
+        else:
+            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
         for _ in range(noops):
             obs, _, done, _ = self.env.step(self.noop_action)
             if done:
setup.py (2 changed lines)

@@ -15,7 +15,7 @@ def get_version() -> str:

 def get_install_requires() -> str:
     return [
-        "gym>=0.15.4",
+        "gym>=0.23.1",
         "tqdm",
         "numpy>1.16.0",  # https://github.com/numpy/numpy/issues/12793
         "tensorboard>=2.5.0",
@@ -79,11 +79,16 @@ class MyTestEnv(gym.Env):
         self.rng = np.random.RandomState(seed)
         return [seed]

-    def reset(self, state=0):
+    def reset(self, state=0, seed=None, return_info=False):
+        if seed is not None:
+            self.rng = np.random.RandomState(seed)
         self.done = False
         self.do_sleep()
         self.index = state
-        return self._get_state()
+        if return_info:
+            return self._get_state(), {'key': 1, 'env': self}
+        else:
+            return self._get_state()

     def _get_reward(self):
         """Generate a non-scalar reward if ma_rew is True."""
@@ -15,6 +15,11 @@ from tianshou.data import (
 from tianshou.env import DummyVectorEnv, SubprocVectorEnv
 from tianshou.policy import BasePolicy

+try:
+    import envpool
+except ImportError:
+    envpool = None
+
 if __name__ == '__main__':
     from env import MyTestEnv, NXEnv
 else:  # pytest
@@ -23,7 +28,7 @@ else:  # pytest

 class MyPolicy(BasePolicy):

-    def __init__(self, dict_state=False, need_state=True):
+    def __init__(self, dict_state=False, need_state=True, action_shape=None):
         """
         :param bool dict_state: if the observation of the environment is a dict
         :param bool need_state: if the policy needs the hidden state (for RNN)
@@ -31,6 +36,7 @@ class MyPolicy(BasePolicy):
         super().__init__()
         self.dict_state = dict_state
         self.need_state = need_state
+        self.action_shape = action_shape

     def forward(self, batch, state=None):
         if self.need_state:
@@ -39,8 +45,12 @@ class MyPolicy(BasePolicy):
         else:
             state += 1
         if self.dict_state:
-            return Batch(act=np.ones(len(batch.obs['index'])), state=state)
-        return Batch(act=np.ones(len(batch.obs)), state=state)
+            action_shape = self.action_shape if self.action_shape else len(
+                batch.obs['index']
+            )
+            return Batch(act=np.ones(action_shape), state=state)
+        action_shape = self.action_shape if self.action_shape else len(batch.obs)
+        return Batch(act=np.ones(action_shape), state=state)

     def learn(self):
         pass
@@ -77,7 +87,8 @@ class Logger:
         return Batch()


-def test_collector():
+@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
+def test_collector(gym_reset_kwargs):
     writer = SummaryWriter('log/collector')
     logger = Logger(writer)
     env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
@@ -86,52 +97,102 @@ def test_collector():
     dum = DummyVectorEnv(env_fns)
     policy = MyPolicy()
     env = env_fns[0]()
-    c0 = Collector(policy, env, ReplayBuffer(size=100), logger.preprocess_fn)
-    c0.collect(n_step=3)
+    c0 = Collector(
+        policy,
+        env,
+        ReplayBuffer(size=100),
+        logger.preprocess_fn,
+    )
+    c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c0.buffer) == 3
     assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
     assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
-    c0.collect(n_episode=3)
+    keys = np.zeros(100)
+    keys[:3] = 1
+    assert np.allclose(c0.buffer.info["key"], keys)
+    for e in c0.buffer.info["env"][:3]:
+        assert isinstance(e, MyTestEnv)
+    assert np.allclose(c0.buffer.info["env_id"], 0)
+    rews = np.zeros(100)
+    rews[:3] = [0, 1, 0]
+    assert np.allclose(c0.buffer.info["rew"], rews)
+    c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c0.buffer) == 8
     assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
     assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
-    c0.collect(n_step=3, random=True)
+    assert np.allclose(c0.buffer.info["key"][:8], 1)
+    for e in c0.buffer.info["env"][:8]:
+        assert isinstance(e, MyTestEnv)
+    assert np.allclose(c0.buffer.info["env_id"][:8], 0)
+    assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
+    c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)

     c1 = Collector(
         policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4),
         logger.preprocess_fn
     )
-    c1.collect(n_step=8)
+    c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
     obs = np.zeros(100)
-    obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1]
+    valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
+    obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
     assert np.allclose(c1.buffer.obs[:, 0], obs)
     assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
-    c1.collect(n_episode=4)
+    keys = np.zeros(100)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c1.buffer.info["key"], keys)
+    for e in c1.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids = np.zeros(100)
+    env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
+    assert np.allclose(c1.buffer.info["env_id"], env_ids)
+    rews = np.zeros(100)
+    rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
+    assert np.allclose(c1.buffer.info["rew"], rews)
+    c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
     assert len(c1.buffer) == 16
+    valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
     obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
     assert np.allclose(c1.buffer.obs[:, 0], obs)
     assert np.allclose(
         c1.buffer[:].obs_next[..., 0],
         [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]
     )
-    c1.collect(n_episode=4, random=True)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c1.buffer.info["key"], keys)
+    for e in c1.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
+    assert np.allclose(c1.buffer.info["env_id"], env_ids)
+    rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
+    assert np.allclose(c1.buffer.info["rew"], rews)
+    c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

     c2 = Collector(
         policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4),
         logger.preprocess_fn
     )
-    c2.collect(n_episode=7)
+    c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
     obs1 = obs.copy()
     obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
     obs2 = obs.copy()
     obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
     c2obs = c2.buffer.obs[:, 0]
     assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
-    c2.reset_env()
+    c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
     c2.reset_buffer()
-    assert c2.collect(n_episode=8)['n/ep'] == 8
-    obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
+    assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8
+    valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
+    obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
     assert np.all(c2.buffer.obs[:, 0] == obs)
-    c2.collect(n_episode=4, random=True)
+    keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
+    assert np.allclose(c2.buffer.info["key"], keys)
+    for e in c2.buffer.info["env"][valid_indices]:
+        assert isinstance(e, MyTestEnv)
+    env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
+    assert np.allclose(c2.buffer.info["env_id"], env_ids)
+    rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
+    assert np.allclose(c2.buffer.info["rew"], rews)
+    c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)

     # test corner case
     with pytest.raises(TypeError):
@@ -147,11 +208,12 @@ def test_collector():
         [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]]
     )
     c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
-    c3.collect(n_step=6)
+    c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
     assert c3.buffer.obs.dtype == object


-def test_collector_with_async():
+@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)])
+def test_collector_with_async(gym_reset_kwargs):
     env_lens = [2, 3, 4, 5]
     writer = SummaryWriter('log/async_collector')
     logger = Logger(writer)
@@ -163,12 +225,14 @@ def test_collector_with_async():
     policy = MyPolicy()
     bufsize = 60
     c1 = AsyncCollector(
-        policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
-        logger.preprocess_fn
+        policy,
+        venv,
+        VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
+        logger.preprocess_fn,
     )
     ptr = [0, 0, 0, 0]
     for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
-        result = c1.collect(n_episode=n_episode)
+        result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
         assert result["n/ep"] >= n_episode
         # check buffer data, obs and obs_next, env_id
         for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]):
@@ -183,7 +247,7 @@ def test_collector_with_async():
         assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
     # test async n_step, for now the buffer should be full of data
     for n_step in tqdm.trange(1, 15, desc="test async n_step"):
-        result = c1.collect(n_step=n_step)
+        result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
         assert result["n/st"] >= n_step
         for i in range(4):
             env_len = i + 2
@@ -618,9 +682,29 @@ def test_collector_with_atari_setting():
         assert np.allclose(result2[key], result[key])


+@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
+def test_collector_envpool_gym_reset_return_info():
+    envs = envpool.make_gym("Pendulum-v0", num_envs=4, gym_reset_return_info=True)
+    policy = MyPolicy(action_shape=(len(envs), 1))
+
+    c0 = Collector(
+        policy,
+        envs,
+        VectorReplayBuffer(len(envs) * 10, len(envs)),
+        exploration_noise=True
+    )
+    c0.collect(n_step=8)
+    env_ids = np.zeros(len(envs) * 10)
+    env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3]
+    assert np.allclose(c0.buffer.info["env_id"], env_ids)
+
+
 if __name__ == '__main__':
-    test_collector()
+    test_collector(gym_reset_kwargs=None)
+    test_collector(gym_reset_kwargs=dict(return_info=True))
     test_collector_with_dict_state()
     test_collector_with_ma()
     test_collector_with_atari_setting()
-    test_collector_with_async()
+    test_collector_with_async(gym_reset_kwargs=None)
+    test_collector_with_async(gym_reset_kwargs=dict(return_info=True))
+    test_collector_envpool_gym_reset_return_info()
@@ -222,6 +222,18 @@ def test_env_obs_dtype():
     assert obs.dtype == object


+def test_env_reset_optional_kwargs(size=10000, num=8):
+    env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
+    test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
+    if has_ray():
+        test_cls += [RayVectorEnv]
+    for cls in test_cls:
+        v = cls(env_fns, wait_num=num // 2, timeout=1e-3)
+        _, info = v.reset(seed=1, return_info=True)
+        assert len(info) == len(env_fns)
+        assert isinstance(info[0], dict)
+
+
 def run_align_norm_obs(raw_env, train_env, test_env, action_list):
     eps = np.finfo(np.float32).eps.item()
     raw_obs, train_obs = [raw_env.reset()], [train_env.reset()]
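The new test above exercises the corresponding call pattern on vectorized envs; from user code it looks roughly like this (a hedged sketch reusing the test helpers `MyTestEnv` and `DummyVectorEnv`, not code from the diff):

# Sketch only: vectorized reset with the new optional kwargs.
venv = DummyVectorEnv([lambda: MyTestEnv(size=3) for _ in range(2)])
obs, info = venv.reset(seed=1, return_info=True)  # info: one dict per sub-env
obs_only = venv.reset()  # the old call style still returns just the observations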
@@ -319,11 +331,25 @@ def test_venv_wrapper_envpool():
     run_align_norm_obs(raw, train, test, actions)


-if __name__ == "__main__":
+@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
+def test_venv_wrapper_envpool_gym_reset_return_info():
+    num_envs = 4
+    env = VectorEnvNormObs(
+        envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True)
+    )
+    obs, info = env.reset()
+    assert obs.shape[0] == num_envs
+    for _, v in info.items():
+        if not isinstance(v, dict):
+            assert v.shape[0] == num_envs
+
+
+if __name__ == '__main__':
     test_venv_norm_obs()
     test_venv_wrapper_envpool()
     test_env_obs_dtype()
     test_vecenv()
     test_async_env()
     test_async_check_id()
+    test_env_reset_optional_kwargs()
     test_gym_wrappers()
@@ -100,18 +100,24 @@ class Collector(object):
         )
         self.buffer = buffer

-    def reset(self, reset_buffer: bool = True) -> None:
+    def reset(
+        self,
+        reset_buffer: bool = True,
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
         """Reset the environment, statistics, current data and possibly replay memory.

         :param bool reset_buffer: if true, reset the replay buffer that is attached
             to the collector.
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Defaults to None (no extra keyword arguments).
         """
         # use empty Batch for "state" so that self.data supports slicing
         # convert empty Batch to None when passing data to policy
         self.data = Batch(
             obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={}
         )
-        self.reset_env()
+        self.reset_env(gym_reset_kwargs)
         if reset_buffer:
             self.reset_buffer()
         self.reset_stat()
@@ -124,12 +130,27 @@ class Collector(object):
         """Reset the data buffer."""
         self.buffer.reset(keep_statistics=keep_statistics)

-    def reset_env(self) -> None:
+    def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
         """Reset all of the environments."""
-        obs = self.env.reset()
-        if self.preprocess_fn:
-            obs = self.preprocess_fn(obs=obs,
-                                     env_id=np.arange(self.env_num)).get("obs", obs)
+        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
+        rval = self.env.reset(**gym_reset_kwargs)
+        returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
+            isinstance(rval[1], dict) or isinstance(rval[1][0], dict)  # type: ignore
+        )
+        if returns_info:
+            obs, info = rval
+            if self.preprocess_fn:
+                processed_data = self.preprocess_fn(
+                    obs=obs, info=info, env_id=np.arange(self.env_num)
+                )
+                obs = processed_data.get("obs", obs)
+                info = processed_data.get("info", info)
+            self.data.info = info
+        else:
+            obs = rval
+            if self.preprocess_fn:
+                obs = self.preprocess_fn(
+                    obs=obs, env_id=np.arange(self.env_num)
+                ).get("obs", obs)
         self.data.obs = obs

     def _reset_state(self, id: Union[int, List[int]]) -> None:
@@ -143,6 +164,33 @@ class Collector(object):
         elif isinstance(state, Batch):
             state.empty_(id)

+    def _reset_env_with_ids(
+        self,
+        local_ids: Union[List[int], np.ndarray],
+        global_ids: Union[List[int], np.ndarray],
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
+        rval = self.env.reset(global_ids, **gym_reset_kwargs)
+        returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and (
+            isinstance(rval[1], dict) or isinstance(rval[1][0], dict)  # type: ignore
+        )
+        if returns_info:
+            obs_reset, info = rval
+            if self.preprocess_fn:
+                processed_data = self.preprocess_fn(
+                    obs=obs_reset, info=info, env_id=global_ids
+                )
+                obs_reset = processed_data.get("obs", obs_reset)
+                info = processed_data.get("info", info)
+            self.data.info[local_ids] = info
+        else:
+            obs_reset = rval
+            if self.preprocess_fn:
+                obs_reset = self.preprocess_fn(
+                    obs=obs_reset, env_id=global_ids
+                ).get("obs", obs_reset)
+        self.data.obs_next[local_ids] = obs_reset
+
     def collect(
         self,
         n_step: Optional[int] = None,
@@ -150,6 +198,7 @@ class Collector(object):
         random: bool = False,
         render: Optional[float] = None,
         no_grad: bool = True,
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         """Collect a specified number of steps or episodes.

@@ -165,6 +214,8 @@ class Collector(object):
             Default to None (no rendering).
         :param bool no_grad: whether to retain gradient in policy.forward(). Default to
             True (no gradient retaining).
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Defaults to None (no extra keyword arguments).

         .. note::
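With this signature, callers opt in per call; a minimal usage sketch (assuming an environment whose reset accepts these keyword arguments; policy and env construction are elided):

# Hedged sketch: forward seed/return_info through the collector to env.reset().
collector = Collector(policy, envs, VectorReplayBuffer(20000, len(envs)))
collector.collect(n_step=1000, gym_reset_kwargs=dict(seed=0, return_info=True))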
@@ -288,12 +339,9 @@ class Collector(object):
                 episode_start_indices.append(ep_idx[env_ind_local])
                 # now we copy obs_next to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.
-                obs_reset = self.env.reset(env_ind_global)
-                if self.preprocess_fn:
-                    obs_reset = self.preprocess_fn(
-                        obs=obs_reset, env_id=env_ind_global
-                    ).get("obs", obs_reset)
-                self.data.obs_next[env_ind_local] = obs_reset
+                self._reset_env_with_ids(
+                    env_ind_local, env_ind_global, gym_reset_kwargs
+                )
                 for i in env_ind_local:
                     self._reset_state(i)
@@ -367,10 +415,16 @@ class AsyncCollector(Collector):
     ) -> None:
         # assert env.is_async
         warnings.warn("Using async setting may collect extra transitions into buffer.")
-        super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
+        super().__init__(
+            policy,
+            env,
+            buffer,
+            preprocess_fn,
+            exploration_noise,
+        )

-    def reset_env(self) -> None:
-        super().reset_env()
+    def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
+        super().reset_env(gym_reset_kwargs)
         self._ready_env_ids = np.arange(self.env_num)

     def collect(
@@ -380,6 +434,7 @@ class AsyncCollector(Collector):
         random: bool = False,
         render: Optional[float] = None,
         no_grad: bool = True,
+        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         """Collect a specified number of steps or episodes with async env setting.

@@ -395,6 +450,8 @@ class AsyncCollector(Collector):
             Default to None (no rendering).
         :param bool no_grad: whether to retain gradient in policy.forward(). Default to
             True (no gradient retaining).
+        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
+            reset function. Defaults to None (no extra keyword arguments).

         .. note::
@@ -528,12 +585,9 @@ class AsyncCollector(Collector):
                 episode_start_indices.append(ep_idx[env_ind_local])
                 # now we copy obs_next to obs, but since there might be
                 # finished episodes, we have to reset finished envs first.
-                obs_reset = self.env.reset(env_ind_global)
-                if self.preprocess_fn:
-                    obs_reset = self.preprocess_fn(
-                        obs=obs_reset, env_id=env_ind_global
-                    ).get("obs", obs_reset)
-                self.data.obs_next[env_ind_local] = obs_reset
+                self._reset_env_with_ids(
+                    env_ind_local, env_ind_global, gym_reset_kwargs
+                )
                 for i in env_ind_local:
                     self._reset_state(i)
tianshou/env/pettingzoo_env.py (20 changed lines)

@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Union

 import gym.spaces
 from pettingzoo.utils.env import AECEnv
@@ -55,11 +55,11 @@ class PettingZooEnv(AECEnv, ABC):

         self.reset()

-    def reset(self, *args: Any, **kwargs: Any) -> dict:
+    def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
         self.env.reset(*args, **kwargs)
-        observation = self.env.observe(self.env.agent_selection)
+        observation, _, _, info = self.env.last(self)
         if isinstance(observation, dict) and 'action_mask' in observation:
-            return {
+            observation_dict = {
                 'agent_id': self.env.agent_selection,
                 'obs': observation['observation'],
                 'mask':
@@ -67,13 +67,21 @@ class PettingZooEnv(AECEnv, ABC):
             }
         else:
             if isinstance(self.action_space, gym.spaces.Discrete):
-                return {
+                observation_dict = {
                     'agent_id': self.env.agent_selection,
                     'obs': observation,
                     'mask': [True] * self.env.action_space(self.env.agent_selection).n
                 }
             else:
-                return {'agent_id': self.env.agent_selection, 'obs': observation}
+                observation_dict = {
+                    'agent_id': self.env.agent_selection,
+                    'obs': observation,
+                }
+
+        if "return_info" in kwargs and kwargs["return_info"]:
+            return observation_dict, info
+        else:
+            return observation_dict

     def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
         self.env.step(action)
tianshou/env/venv_wrappers.py (38 changed lines)

@@ -37,11 +37,12 @@ class VectorEnvWrapper(BaseVectorEnv):
     ) -> None:
         return self.venv.set_env_attr(key, value, id)

-    # TODO: compatible issue with reset -> (obs, info)
     def reset(
-        self, id: Optional[Union[int, List[int], np.ndarray]] = None
-    ) -> np.ndarray:
-        return self.venv.reset(id)
+        self,
+        id: Optional[Union[int, List[int], np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
+        return self.venv.reset(id, **kwargs)

     def step(
         self,
@@ -86,14 +87,33 @@ class VectorEnvNormObs(VectorEnvWrapper):
         self.clip_max = clip_obs
         self.eps = epsilon

-    # TODO: compatible issue with reset -> (obs, info)
     def reset(
-        self, id: Optional[Union[int, List[int], np.ndarray]] = None
-    ) -> np.ndarray:
-        obs = self.venv.reset(id)
+        self,
+        id: Optional[Union[int, List[int], np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
+        retval = self.venv.reset(id, **kwargs)
+        reset_returns_info = isinstance(
+            retval, (tuple, list)
+        ) and len(retval) == 2 and isinstance(retval[1], dict)
+        if reset_returns_info:
+            obs, info = retval
+        else:
+            obs = retval
+
+        if isinstance(obs, tuple):
+            raise TypeError(
+                "Tuple observation space is not supported. ",
+                "Please change it to array or dict space",
+            )
+
         if self.obs_rms and self.update_obs_rms:
             self.obs_rms.update(obs)
-        return self._norm_obs(obs)
+        obs = self._norm_obs(obs)
+        if reset_returns_info:
+            return obs, info
+        else:
+            return obs

     def step(
         self,
tianshou/env/venvs.py (37 changed lines)

@@ -181,10 +181,11 @@ class BaseVectorEnv(object):
             assert i in self.ready_id, \
                 f"Can only interact with ready environments {self.ready_id}."

-    # TODO: compatible issue with reset -> (obs, info)
     def reset(
-        self, id: Optional[Union[int, List[int], np.ndarray]] = None
-    ) -> np.ndarray:
+        self,
+        id: Optional[Union[int, List[int], np.ndarray]] = None,
+        **kwargs: Any,
+    ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]:
         """Reset the state of some envs and return initial observations.

         If id is None, reset the state of all the environments and return
@@ -195,15 +196,35 @@ class BaseVectorEnv(object):
         id = self._wrap_id(id)
         if self.is_async:
             self._assert_id(id)
+
         # send(None) == reset() in worker
         for i in id:
-            self.workers[i].send(None)
-        obs_list = [self.workers[i].recv() for i in id]
+            self.workers[i].send(None, **kwargs)
+        ret_list = [self.workers[i].recv() for i in id]
+
+        reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len(
+            ret_list[0]
+        ) == 2 and isinstance(ret_list[0][1], dict)
+        if reset_returns_info:
+            obs_list = [r[0] for r in ret_list]
+        else:
+            obs_list = ret_list
+
+        if isinstance(obs_list[0], tuple):
+            raise TypeError(
+                "Tuple observation space is not supported. ",
+                "Please change it to array or dict space",
+            )
         try:
             obs = np.stack(obs_list)
         except ValueError:  # different len(obs)
             obs = np.array(obs_list, dtype=object)
-        return obs
+
+        if reset_returns_info:
+            infos = [r[1] for r in ret_list]
+            return obs, infos  # type: ignore
+        else:
+            return obs

     def step(
         self,
@@ -248,7 +269,7 @@ class BaseVectorEnv(object):
                 self.workers[j].send(action[i])
             result = []
             for j in id:
-                obs, rew, done, info = self.workers[j].recv()
+                obs, rew, done, info = self.workers[j].recv()  # type: ignore
                 info["env_id"] = j
                 result.append((obs, rew, done, info))
         else:
@@ -270,7 +291,7 @@ class BaseVectorEnv(object):
                 waiting_index = self.waiting_conn.index(conn)
                 self.waiting_conn.pop(waiting_index)
                 env_id = self.waiting_id.pop(waiting_index)
-                obs, rew, done, info = conn.recv()
+                obs, rew, done, info = conn.recv()  # type: ignore
                 info["env_id"] = env_id
                 result.append((obs, rew, done, info))
             self.ready_id.append(env_id)
tianshou/env/worker/base.py (11 changed lines)

@@ -14,7 +14,7 @@ class EnvWorker(ABC):
         self._env_fn = env_fn
         self.is_closed = False
         self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
-                           np.ndarray]
+                           Tuple[np.ndarray, dict], np.ndarray]
         self.action_space = self.get_env_attr("action_space")  # noqa: B009
         self.is_reset = False

@@ -47,7 +47,8 @@ class EnvWorker(ABC):

     def recv(
         self
-    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
+    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[
+            np.ndarray, dict], np.ndarray]:  # noqa:E125
         """Receive result from low-level worker.

         If the last "send" function sends a NULL action, it only returns a
@@ -63,9 +64,9 @@ class EnvWorker(ABC):
             self.result = self.get_result()  # type: ignore
         return self.result

-    def reset(self) -> np.ndarray:
-        self.send(None)
-        return self.recv()  # type: ignore
+    @abstractmethod
+    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+        pass

     def step(
         self, action: np.ndarray
tianshou/env/worker/dummy.py (20 changed lines)

@@ -1,4 +1,4 @@
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union

 import gym
 import numpy as np
@@ -19,8 +19,10 @@ class DummyEnvWorker(EnvWorker):
     def set_env_attr(self, key: str, value: Any) -> None:
         setattr(self.env, key, value)

-    def reset(self) -> Any:
-        return self.env.reset()
+    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+        if "seed" in kwargs:
+            super().seed(kwargs["seed"])
+        return self.env.reset(**kwargs)

     @staticmethod
     def wait(  # type: ignore
@@ -29,15 +31,19 @@ class DummyEnvWorker(EnvWorker):
         # Sequential EnvWorker objects are always ready
         return workers

-    def send(self, action: Optional[np.ndarray]) -> None:
+    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
         if action is None:
-            self.result = self.env.reset()  # type: ignore
+            self.result = self.env.reset(**kwargs)
         else:
             self.result = self.env.step(action)  # type: ignore

-    def seed(self, seed: Optional[int] = None) -> List[int]:
+    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
         super().seed(seed)
-        return self.env.seed(seed)
+        try:
+            return self.env.seed(seed)
+        except NotImplementedError:
+            self.env.reset(seed=seed)
+            return [seed]  # type: ignore

     def render(self, **kwargs: Any) -> Any:
         return self.env.render(**kwargs)
tianshou/env/worker/ray.py (20 changed lines)

@@ -35,8 +35,10 @@ class RayEnvWorker(EnvWorker):
     def set_env_attr(self, key: str, value: Any) -> None:
         ray.get(self.env.set_env_attr.remote(key, value))

-    def reset(self) -> Any:
-        return ray.get(self.env.reset.remote())
+    def reset(self, **kwargs: Any) -> Any:
+        if "seed" in kwargs:
+            super().seed(kwargs["seed"])
+        return ray.get(self.env.reset.remote(**kwargs))

     @staticmethod
     def wait(  # type: ignore
@@ -46,10 +48,10 @@ class RayEnvWorker(EnvWorker):
         ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout)
         return [workers[results.index(result)] for result in ready_results]

-    def send(self, action: Optional[np.ndarray]) -> None:
-        # self.action is actually a handle
+    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
+        # self.result is actually a handle
         if action is None:
-            self.result = self.env.reset.remote()
+            self.result = self.env.reset.remote(**kwargs)
         else:
             self.result = self.env.step.remote(action)

@@ -58,9 +60,13 @@ class RayEnvWorker(EnvWorker):
     ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
         return ray.get(self.result)  # type: ignore

-    def seed(self, seed: Optional[int] = None) -> List[int]:
+    def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
         super().seed(seed)
-        return ray.get(self.env.seed.remote(seed))
+        try:
+            return ray.get(self.env.seed.remote(seed))
+        except NotImplementedError:
+            self.env.reset.remote(seed=seed)
+            return None

     def render(self, **kwargs: Any) -> Any:
         return ray.get(self.env.render.remote(**kwargs))
tianshou/env/worker/subproc.py (65 changed lines)

@@ -86,17 +86,27 @@ def _worker(
                 p.close()
                 break
             if cmd == "step":
-                if data is None:  # reset
-                    obs = env.reset()
-                else:
-                    obs, reward, done, info = env.step(data)
+                obs, reward, done, info = env.step(data)
                 if obs_bufs is not None:
                     _encode_obs(obs, obs_bufs)
                     obs = None
-                if data is None:
-                    p.send(obs)
-                else:
-                    p.send((obs, reward, done, info))
+                p.send((obs, reward, done, info))
+            elif cmd == "reset":
+                retval = env.reset(**data)
+                reset_returns_info = isinstance(
+                    retval, (tuple, list)
+                ) and len(retval) == 2 and isinstance(retval[1], dict)
+                if reset_returns_info:
+                    obs, info = retval
+                else:
+                    obs = retval
+                if obs_bufs is not None:
+                    _encode_obs(obs, obs_bufs)
+                    obs = None
+                if reset_returns_info:
+                    p.send((obs, info))
+                else:
+                    p.send(obs)
             elif cmd == "close":
                 p.send(env.close())
                 p.close()
@@ -104,7 +114,11 @@ def _worker(
             elif cmd == "render":
                 p.send(env.render(**data) if hasattr(env, "render") else None)
             elif cmd == "seed":
-                p.send(env.seed(data) if hasattr(env, "seed") else None)
+                if hasattr(env, "seed"):
+                    p.send(env.seed(data))
+                else:
+                    env.reset(seed=data)
+                    p.send(None)
             elif cmd == "getattr":
                 p.send(getattr(env, data) if hasattr(env, data) else None)
             elif cmd == "setattr":
@@ -140,7 +154,6 @@ class SubprocEnvWorker(EnvWorker):
         self.process = Process(target=_worker, args=args, daemon=True)
         self.process.start()
         self.child_remote.close()
-        self.is_reset = False
         super().__init__(env_fn)

     def get_env_attr(self, key: str) -> Any:
@@ -186,14 +199,25 @@ class SubprocEnvWorker(EnvWorker):
             remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
         return [workers[conns.index(con)] for con in ready_conns]

-    def send(self, action: Optional[np.ndarray]) -> None:
-        self.parent_remote.send(["step", action])
+    def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None:
+        if action is None:
+            if "seed" in kwargs:
+                super().seed(kwargs["seed"])
+            self.parent_remote.send(["reset", kwargs])
+        else:
+            self.parent_remote.send(["step", action])

     def recv(
         self
-    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]:
+    ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[
+            np.ndarray, dict], np.ndarray]:  # noqa:E125
         result = self.parent_remote.recv()
         if isinstance(result, tuple):
+            if len(result) == 2:
+                obs, info = result
+                if self.share_memory:
+                    obs = self._decode_obs()
+                return obs, info
             obs, rew, done, info = result
             if self.share_memory:
                 obs = self._decode_obs()
@@ -204,6 +228,23 @@ class SubprocEnvWorker(EnvWorker):
                 obs = self._decode_obs()
             return obs

+    def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
+        if "seed" in kwargs:
+            super().seed(kwargs["seed"])
+        self.parent_remote.send(["reset", kwargs])
+
+        result = self.parent_remote.recv()
+        if isinstance(result, tuple):
+            obs, info = result
+            if self.share_memory:
+                obs = self._decode_obs()
+            return obs, info
+        else:
+            obs = result
+            if self.share_memory:
+                obs = self._decode_obs()
+            return obs
+
     def seed(self, seed: Optional[int] = None) -> Optional[List[int]]:
         super().seed(seed)
         self.parent_remote.send(["seed", seed])