diff --git a/examples/atari/atari_wrapper.py b/examples/atari/atari_wrapper.py index a9a3f82..293c1ce 100644 --- a/examples/atari/atari_wrapper.py +++ b/examples/atari/atari_wrapper.py @@ -32,7 +32,10 @@ class NoopResetEnv(gym.Wrapper): def reset(self): self.env.reset() - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + if hasattr(self.unwrapped.np_random, "integers"): + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) + else: + noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) for _ in range(noops): obs, _, done, _ = self.env.step(self.noop_action) if done: diff --git a/setup.py b/setup.py index 6a1367e..6125ce2 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ def get_version() -> str: def get_install_requires() -> str: return [ - "gym>=0.15.4", + "gym>=0.23.1", "tqdm", "numpy>1.16.0", # https://github.com/numpy/numpy/issues/12793 "tensorboard>=2.5.0", diff --git a/test/base/env.py b/test/base/env.py index 0a649e5..e29f7ff 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -79,11 +79,16 @@ class MyTestEnv(gym.Env): self.rng = np.random.RandomState(seed) return [seed] - def reset(self, state=0): + def reset(self, state=0, seed=None, return_info=False): + if seed is not None: + self.rng = np.random.RandomState(seed) self.done = False self.do_sleep() self.index = state - return self._get_state() + if return_info: + return self._get_state(), {'key': 1, 'env': self} + else: + return self._get_state() def _get_reward(self): """Generate a non-scalar reward if ma_rew is True.""" diff --git a/test/base/test_collector.py b/test/base/test_collector.py index 9a8d749..56c5b11 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -15,6 +15,11 @@ from tianshou.data import ( from tianshou.env import DummyVectorEnv, SubprocVectorEnv from tianshou.policy import BasePolicy +try: + import envpool +except ImportError: + envpool = None + if __name__ == '__main__': from env import MyTestEnv, NXEnv else: # pytest @@ -23,7 +28,7 @@ else: # pytest class MyPolicy(BasePolicy): - def __init__(self, dict_state=False, need_state=True): + def __init__(self, dict_state=False, need_state=True, action_shape=None): """ :param bool dict_state: if the observation of the environment is a dict :param bool need_state: if the policy needs the hidden state (for RNN) @@ -31,6 +36,7 @@ class MyPolicy(BasePolicy): super().__init__() self.dict_state = dict_state self.need_state = need_state + self.action_shape = action_shape def forward(self, batch, state=None): if self.need_state: @@ -39,8 +45,12 @@ class MyPolicy(BasePolicy): else: state += 1 if self.dict_state: - return Batch(act=np.ones(len(batch.obs['index'])), state=state) - return Batch(act=np.ones(len(batch.obs)), state=state) + action_shape = self.action_shape if self.action_shape else len( + batch.obs['index'] + ) + return Batch(act=np.ones(action_shape), state=state) + action_shape = self.action_shape if self.action_shape else len(batch.obs) + return Batch(act=np.ones(action_shape), state=state) def learn(self): pass @@ -77,7 +87,8 @@ class Logger: return Batch() -def test_collector(): +@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)]) +def test_collector(gym_reset_kwargs): writer = SummaryWriter('log/collector') logger = Logger(writer) env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] @@ -86,52 +97,102 @@ def test_collector(): dum = DummyVectorEnv(env_fns) policy = MyPolicy() env = env_fns[0]() - c0 = Collector(policy, env, 
ReplayBuffer(size=100), logger.preprocess_fn) - c0.collect(n_step=3) + c0 = Collector( + policy, + env, + ReplayBuffer(size=100), + logger.preprocess_fn, + ) + c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs) assert len(c0.buffer) == 3 assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) - c0.collect(n_episode=3) + keys = np.zeros(100) + keys[:3] = 1 + assert np.allclose(c0.buffer.info["key"], keys) + for e in c0.buffer.info["env"][:3]: + assert isinstance(e, MyTestEnv) + assert np.allclose(c0.buffer.info["env_id"], 0) + rews = np.zeros(100) + rews[:3] = [0, 1, 0] + assert np.allclose(c0.buffer.info["rew"], rews) + c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs) assert len(c0.buffer) == 8 assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) - c0.collect(n_step=3, random=True) + assert np.allclose(c0.buffer.info["key"][:8], 1) + for e in c0.buffer.info["env"][:8]: + assert isinstance(e, MyTestEnv) + assert np.allclose(c0.buffer.info["env_id"][:8], 0) + assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1]) + c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs) + c1 = Collector( policy, venv, VectorReplayBuffer(total_size=100, buffer_num=4), logger.preprocess_fn ) - c1.collect(n_step=8) + c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs) obs = np.zeros(100) - obs[[0, 1, 25, 26, 50, 51, 75, 76]] = [0, 1, 0, 1, 0, 1, 0, 1] - + valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] + obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] assert np.allclose(c1.buffer.obs[:, 0], obs) assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) - c1.collect(n_episode=4) + keys = np.zeros(100) + keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] + assert np.allclose(c1.buffer.info["key"], keys) + for e in c1.buffer.info["env"][valid_indices]: + assert isinstance(e, MyTestEnv) + env_ids = np.zeros(100) + env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3] + assert np.allclose(c1.buffer.info["env_id"], env_ids) + rews = np.zeros(100) + rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0] + assert np.allclose(c1.buffer.info["rew"], rews) + c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs) assert len(c1.buffer) == 16 + valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] assert np.allclose(c1.buffer.obs[:, 0], obs) assert np.allclose( c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5] ) - c1.collect(n_episode=4, random=True) + keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] + assert np.allclose(c1.buffer.info["key"], keys) + for e in c1.buffer.info["env"][valid_indices]: + assert isinstance(e, MyTestEnv) + env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3] + assert np.allclose(c1.buffer.info["env_id"], env_ids) + rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1] + assert np.allclose(c1.buffer.info["rew"], rews) + c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) + c2 = Collector( policy, dum, VectorReplayBuffer(total_size=100, buffer_num=4), logger.preprocess_fn ) - c2.collect(n_episode=7) + c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs) obs1 = obs.copy() obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] obs2 = obs.copy() obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] c2obs = c2.buffer.obs[:, 0] assert np.all(c2obs == obs1) or np.all(c2obs == obs2) - c2.reset_env() + 
c2.reset_env(gym_reset_kwargs=gym_reset_kwargs) c2.reset_buffer() - assert c2.collect(n_episode=8)['n/ep'] == 8 - obs[[4, 5, 28, 29, 30, 54, 55, 56, 57]] = [0, 1, 0, 1, 2, 0, 1, 2, 3] + assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs)['n/ep'] == 8 + valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] + obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] assert np.all(c2.buffer.obs[:, 0] == obs) - c2.collect(n_episode=4, random=True) + keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1] + assert np.allclose(c2.buffer.info["key"], keys) + for e in c2.buffer.info["env"][valid_indices]: + assert isinstance(e, MyTestEnv) + env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2] + assert np.allclose(c2.buffer.info["env_id"], env_ids) + rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1] + assert np.allclose(c2.buffer.info["rew"], rews) + c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) # test corner case with pytest.raises(TypeError): @@ -147,11 +208,12 @@ def test_collector(): [lambda i=x: NXEnv(i, obs_type) for x in [5, 10, 15, 20]] ) c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) - c3.collect(n_step=6) + c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs) assert c3.buffer.obs.dtype == object -def test_collector_with_async(): +@pytest.mark.parametrize("gym_reset_kwargs", [None, dict(return_info=True)]) +def test_collector_with_async(gym_reset_kwargs): env_lens = [2, 3, 4, 5] writer = SummaryWriter('log/async_collector') logger = Logger(writer) @@ -163,12 +225,14 @@ def test_collector_with_async(): policy = MyPolicy() bufsize = 60 c1 = AsyncCollector( - policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), - logger.preprocess_fn + policy, + venv, + VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), + logger.preprocess_fn, ) ptr = [0, 0, 0, 0] for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): - result = c1.collect(n_episode=n_episode) + result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) assert result["n/ep"] >= n_episode # check buffer data, obs and obs_next, env_id for i, count in enumerate(np.bincount(result["lens"], minlength=6)[2:]): @@ -183,7 +247,7 @@ def test_collector_with_async(): assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) # test async n_step, for now the buffer should be full of data for n_step in tqdm.trange(1, 15, desc="test async n_step"): - result = c1.collect(n_step=n_step) + result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) assert result["n/st"] >= n_step for i in range(4): env_len = i + 2 @@ -618,9 +682,29 @@ def test_collector_with_atari_setting(): assert np.allclose(result2[key], result[key]) +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") +def test_collector_envpool_gym_reset_return_info(): + envs = envpool.make_gym("Pendulum-v0", num_envs=4, gym_reset_return_info=True) + policy = MyPolicy(action_shape=(len(envs), 1)) + + c0 = Collector( + policy, + envs, + VectorReplayBuffer(len(envs) * 10, len(envs)), + exploration_noise=True + ) + c0.collect(n_step=8) + env_ids = np.zeros(len(envs) * 10) + env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3] + assert np.allclose(c0.buffer.info["env_id"], env_ids) + + if __name__ == '__main__': - test_collector() + test_collector(gym_reset_kwargs=None) + test_collector(gym_reset_kwargs=dict(return_info=True)) test_collector_with_dict_state() test_collector_with_ma() test_collector_with_atari_setting() - 
test_collector_with_async() + test_collector_with_async(gym_reset_kwargs=None) + test_collector_with_async(gym_reset_kwargs=dict(return_info=True)) + test_collector_envpool_gym_reset_return_info() diff --git a/test/base/test_env.py b/test/base/test_env.py index 002799b..87284c7 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -222,6 +222,18 @@ def test_env_obs_dtype(): assert obs.dtype == object +def test_env_reset_optional_kwargs(size=10000, num=8): + env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)] + test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] + if has_ray(): + test_cls += [RayVectorEnv] + for cls in test_cls: + v = cls(env_fns, wait_num=num // 2, timeout=1e-3) + _, info = v.reset(seed=1, return_info=True) + assert len(info) == len(env_fns) + assert isinstance(info[0], dict) + + def run_align_norm_obs(raw_env, train_env, test_env, action_list): eps = np.finfo(np.float32).eps.item() raw_obs, train_obs = [raw_env.reset()], [train_env.reset()] @@ -319,11 +331,25 @@ def test_venv_wrapper_envpool(): run_align_norm_obs(raw, train, test, actions) -if __name__ == "__main__": +@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") +def test_venv_wrapper_envpool_gym_reset_return_info(): + num_envs = 4 + env = VectorEnvNormObs( + envpool.make_gym("Ant-v3", num_envs=num_envs, gym_reset_return_info=True) + ) + obs, info = env.reset() + assert obs.shape[0] == num_envs + for _, v in info.items(): + if not isinstance(v, dict): + assert v.shape[0] == num_envs + + +if __name__ == '__main__': test_venv_norm_obs() test_venv_wrapper_envpool() test_env_obs_dtype() test_vecenv() test_async_env() test_async_check_id() + test_env_reset_optional_kwargs() test_gym_wrappers() diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 537986c..ce96bcf 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -100,18 +100,24 @@ class Collector(object): ) self.buffer = buffer - def reset(self, reset_buffer: bool = True) -> None: + def reset( + self, + reset_buffer: bool = True, + gym_reset_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: """Reset the environment, statistics, current data and possibly replay memory. :param bool reset_buffer: if true, reset the replay buffer that is attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. 
Defaults to None (extra keyword arguments) """ # use empty Batch for "state" so that self.data supports slicing # convert empty Batch to None when passing data to policy self.data = Batch( obs={}, act={}, rew={}, done={}, obs_next={}, info={}, policy={} ) - self.reset_env() + self.reset_env(gym_reset_kwargs) if reset_buffer: self.reset_buffer() self.reset_stat() @@ -124,12 +130,27 @@ class Collector(object): """Reset the data buffer.""" self.buffer.reset(keep_statistics=keep_statistics) - def reset_env(self) -> None: + def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: """Reset all of the environments.""" - obs = self.env.reset() - if self.preprocess_fn: - obs = self.preprocess_fn(obs=obs, - env_id=np.arange(self.env_num)).get("obs", obs) + gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} + rval = self.env.reset(**gym_reset_kwargs) + returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( + isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore + ) + if returns_info: + obs, info = rval + if self.preprocess_fn: + processed_data = self.preprocess_fn( + obs=obs, info=info, env_id=np.arange(self.env_num) + ) + obs = processed_data.get("obs", obs) + info = processed_data.get("info", info) + self.data.info = info + else: + obs = rval + if self.preprocess_fn: + obs = self.preprocess_fn(obs=obs, env_id=np.arange(self.env_num + )).get("obs", obs) self.data.obs = obs def _reset_state(self, id: Union[int, List[int]]) -> None: @@ -143,6 +164,33 @@ class Collector(object): elif isinstance(state, Batch): state.empty_(id) + def _reset_env_with_ids( + self, + local_ids: Union[List[int], np.ndarray], + global_ids: Union[List[int], np.ndarray], + gym_reset_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} + rval = self.env.reset(global_ids, **gym_reset_kwargs) + returns_info = isinstance(rval, (tuple, list)) and len(rval) == 2 and ( + isinstance(rval[1], dict) or isinstance(rval[1][0], dict) # type: ignore + ) + if returns_info: + obs_reset, info = rval + if self.preprocess_fn: + processed_data = self.preprocess_fn( + obs=obs_reset, info=info, env_id=global_ids + ) + obs_reset = processed_data.get("obs", obs_reset) + info = processed_data.get("info", info) + self.data.info[local_ids] = info + else: + obs_reset = rval + if self.preprocess_fn: + obs_reset = self.preprocess_fn(obs=obs_reset, env_id=global_ids + ).get("obs", obs_reset) + self.data.obs_next[local_ids] = obs_reset + def collect( self, n_step: Optional[int] = None, @@ -150,6 +198,7 @@ class Collector(object): random: bool = False, render: Optional[float] = None, no_grad: bool = True, + gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Collect a specified number of step or episode. @@ -165,6 +214,8 @@ class Collector(object): Default to None (no rendering). :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) .. note:: @@ -288,12 +339,9 @@ class Collector(object): episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. 
- obs_reset = self.env.reset(env_ind_global) - if self.preprocess_fn: - obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global - ).get("obs", obs_reset) - self.data.obs_next[env_ind_local] = obs_reset + self._reset_env_with_ids( + env_ind_local, env_ind_global, gym_reset_kwargs + ) for i in env_ind_local: self._reset_state(i) @@ -367,10 +415,16 @@ class AsyncCollector(Collector): ) -> None: # assert env.is_async warnings.warn("Using async setting may collect extra transitions into buffer.") - super().__init__(policy, env, buffer, preprocess_fn, exploration_noise) + super().__init__( + policy, + env, + buffer, + preprocess_fn, + exploration_noise, + ) - def reset_env(self) -> None: - super().reset_env() + def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None: + super().reset_env(gym_reset_kwargs) self._ready_env_ids = np.arange(self.env_num) def collect( @@ -380,6 +434,7 @@ class AsyncCollector(Collector): random: bool = False, render: Optional[float] = None, no_grad: bool = True, + gym_reset_kwargs: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: """Collect a specified number of step or episode with async env setting. @@ -395,6 +450,8 @@ class AsyncCollector(Collector): Default to None (no rendering). :param bool no_grad: whether to retain gradient in policy.forward(). Default to True (no gradient retaining). + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) .. note:: @@ -528,12 +585,9 @@ class AsyncCollector(Collector): episode_start_indices.append(ep_idx[env_ind_local]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. - obs_reset = self.env.reset(env_ind_global) - if self.preprocess_fn: - obs_reset = self.preprocess_fn( - obs=obs_reset, env_id=env_ind_global - ).get("obs", obs_reset) - self.data.obs_next[env_ind_local] = obs_reset + self._reset_env_with_ids( + env_ind_local, env_ind_global, gym_reset_kwargs + ) for i in env_ind_local: self._reset_state(i) diff --git a/tianshou/env/pettingzoo_env.py b/tianshou/env/pettingzoo_env.py index c406872..1722dc5 100644 --- a/tianshou/env/pettingzoo_env.py +++ b/tianshou/env/pettingzoo_env.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union import gym.spaces from pettingzoo.utils.env import AECEnv @@ -55,11 +55,11 @@ class PettingZooEnv(AECEnv, ABC): self.reset() - def reset(self, *args: Any, **kwargs: Any) -> dict: + def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]: self.env.reset(*args, **kwargs) - observation = self.env.observe(self.env.agent_selection) + observation, _, _, info = self.env.last(self) if isinstance(observation, dict) and 'action_mask' in observation: - return { + observation_dict = { 'agent_id': self.env.agent_selection, 'obs': observation['observation'], 'mask': @@ -67,13 +67,21 @@ class PettingZooEnv(AECEnv, ABC): } else: if isinstance(self.action_space, gym.spaces.Discrete): - return { + observation_dict = { 'agent_id': self.env.agent_selection, 'obs': observation, 'mask': [True] * self.env.action_space(self.env.agent_selection).n } else: - return {'agent_id': self.env.agent_selection, 'obs': observation} + observation_dict = { + 'agent_id': self.env.agent_selection, + 'obs': observation, + } + + if "return_info" in kwargs and kwargs["return_info"]: + return observation_dict, info + else: + return observation_dict def 
step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]: self.env.step(action) diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 860c390..bb5e294 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -37,11 +37,12 @@ class VectorEnvWrapper(BaseVectorEnv): ) -> None: return self.venv.set_env_attr(key, value, id) - # TODO: compatible issue with reset -> (obs, info) def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> np.ndarray: - return self.venv.reset(id) + self, + id: Optional[Union[int, List[int], np.ndarray]] = None, + **kwargs: Any, + ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: + return self.venv.reset(id, **kwargs) def step( self, @@ -86,14 +87,33 @@ class VectorEnvNormObs(VectorEnvWrapper): self.clip_max = clip_obs self.eps = epsilon - # TODO: compatible issue with reset -> (obs, info) def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> np.ndarray: - obs = self.venv.reset(id) + self, + id: Optional[Union[int, List[int], np.ndarray]] = None, + **kwargs: Any, + ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: + retval = self.venv.reset(id, **kwargs) + reset_returns_info = isinstance( + retval, (tuple, list) + ) and len(retval) == 2 and isinstance(retval[1], dict) + if reset_returns_info: + obs, info = retval + else: + obs = retval + + if isinstance(obs, tuple): + raise TypeError( + "Tuple observation space is not supported. ", + "Please change it to array or dict space", + ) + if self.obs_rms and self.update_obs_rms: self.obs_rms.update(obs) - return self._norm_obs(obs) + obs = self._norm_obs(obs) + if reset_returns_info: + return obs, info + else: + return obs def step( self, diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index 93558d9..1f12d3f 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -181,10 +181,11 @@ class BaseVectorEnv(object): assert i in self.ready_id, \ f"Can only interact with ready environments {self.ready_id}." - # TODO: compatible issue with reset -> (obs, info) def reset( - self, id: Optional[Union[int, List[int], np.ndarray]] = None - ) -> np.ndarray: + self, + id: Optional[Union[int, List[int], np.ndarray]] = None, + **kwargs: Any, + ) -> Union[np.ndarray, Tuple[np.ndarray, List[dict]]]: """Reset the state of some envs and return initial observations. If id is None, reset the state of all the environments and return @@ -195,15 +196,35 @@ class BaseVectorEnv(object): id = self._wrap_id(id) if self.is_async: self._assert_id(id) + # send(None) == reset() in worker for i in id: - self.workers[i].send(None) - obs_list = [self.workers[i].recv() for i in id] + self.workers[i].send(None, **kwargs) + ret_list = [self.workers[i].recv() for i in id] + + reset_returns_info = isinstance(ret_list[0], (tuple, list)) and len( + ret_list[0] + ) == 2 and isinstance(ret_list[0][1], dict) + if reset_returns_info: + obs_list = [r[0] for r in ret_list] + else: + obs_list = ret_list + + if isinstance(obs_list[0], tuple): + raise TypeError( + "Tuple observation space is not supported. 
", + "Please change it to array or dict space", + ) try: obs = np.stack(obs_list) except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) - return obs + + if reset_returns_info: + infos = [r[1] for r in ret_list] + return obs, infos # type: ignore + else: + return obs def step( self, @@ -248,7 +269,7 @@ class BaseVectorEnv(object): self.workers[j].send(action[i]) result = [] for j in id: - obs, rew, done, info = self.workers[j].recv() + obs, rew, done, info = self.workers[j].recv() # type: ignore info["env_id"] = j result.append((obs, rew, done, info)) else: @@ -270,7 +291,7 @@ class BaseVectorEnv(object): waiting_index = self.waiting_conn.index(conn) self.waiting_conn.pop(waiting_index) env_id = self.waiting_id.pop(waiting_index) - obs, rew, done, info = conn.recv() + obs, rew, done, info = conn.recv() # type: ignore info["env_id"] = env_id result.append((obs, rew, done, info)) self.ready_id.append(env_id) diff --git a/tianshou/env/worker/base.py b/tianshou/env/worker/base.py index b861a15..3ea46d7 100644 --- a/tianshou/env/worker/base.py +++ b/tianshou/env/worker/base.py @@ -14,7 +14,7 @@ class EnvWorker(ABC): self._env_fn = env_fn self.is_closed = False self.result: Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], - np.ndarray] + Tuple[np.ndarray, dict], np.ndarray] self.action_space = self.get_env_attr("action_space") # noqa: B009 self.is_reset = False @@ -47,7 +47,8 @@ class EnvWorker(ABC): def recv( self - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[ + np.ndarray, dict], np.ndarray]: # noqa:E125 """Receive result from low-level worker. If the last "send" function sends a NULL action, it only returns a @@ -63,9 +64,9 @@ class EnvWorker(ABC): self.result = self.get_result() # type: ignore return self.result - def reset(self) -> np.ndarray: - self.send(None) - return self.recv() # type: ignore + @abstractmethod + def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + pass def step( self, action: np.ndarray diff --git a/tianshou/env/worker/dummy.py b/tianshou/env/worker/dummy.py index be87386..58a2fc3 100644 --- a/tianshou/env/worker/dummy.py +++ b/tianshou/env/worker/dummy.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, Optional +from typing import Any, Callable, List, Optional, Tuple, Union import gym import numpy as np @@ -19,8 +19,10 @@ class DummyEnvWorker(EnvWorker): def set_env_attr(self, key: str, value: Any) -> None: setattr(self.env, key, value) - def reset(self) -> Any: - return self.env.reset() + def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + return self.env.reset(**kwargs) @staticmethod def wait( # type: ignore @@ -29,15 +31,19 @@ class DummyEnvWorker(EnvWorker): # Sequential EnvWorker objects are always ready return workers - def send(self, action: Optional[np.ndarray]) -> None: + def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: if action is None: - self.result = self.env.reset() # type: ignore + self.result = self.env.reset(**kwargs) else: self.result = self.env.step(action) # type: ignore - def seed(self, seed: Optional[int] = None) -> List[int]: + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) - return self.env.seed(seed) + try: + return self.env.seed(seed) + except NotImplementedError: + self.env.reset(seed=seed) + return [seed] # type: ignore def 
render(self, **kwargs: Any) -> Any: return self.env.render(**kwargs) diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py index e094692..055fd7a 100644 --- a/tianshou/env/worker/ray.py +++ b/tianshou/env/worker/ray.py @@ -35,8 +35,10 @@ class RayEnvWorker(EnvWorker): def set_env_attr(self, key: str, value: Any) -> None: ray.get(self.env.set_env_attr.remote(key, value)) - def reset(self) -> Any: - return ray.get(self.env.reset.remote()) + def reset(self, **kwargs: Any) -> Any: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + return ray.get(self.env.reset.remote(**kwargs)) @staticmethod def wait( # type: ignore @@ -46,10 +48,10 @@ class RayEnvWorker(EnvWorker): ready_results, _ = ray.wait(results, num_returns=wait_num, timeout=timeout) return [workers[results.index(result)] for result in ready_results] - def send(self, action: Optional[np.ndarray]) -> None: - # self.action is actually a handle + def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: + # self.result is actually a handle if action is None: - self.result = self.env.reset.remote() + self.result = self.env.reset.remote(**kwargs) else: self.result = self.env.step.remote(action) @@ -58,9 +60,13 @@ class RayEnvWorker(EnvWorker): ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: return ray.get(self.result) # type: ignore - def seed(self, seed: Optional[int] = None) -> List[int]: + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) - return ray.get(self.env.seed.remote(seed)) + try: + return ray.get(self.env.seed.remote(seed)) + except NotImplementedError: + self.env.reset.remote(seed=seed) + return None def render(self, **kwargs: Any) -> Any: return ray.get(self.env.render.remote(**kwargs)) diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py index c2119ab..8c91b31 100644 --- a/tianshou/env/worker/subproc.py +++ b/tianshou/env/worker/subproc.py @@ -86,17 +86,27 @@ def _worker( p.close() break if cmd == "step": - if data is None: # reset - obs = env.reset() - else: - obs, reward, done, info = env.step(data) + obs, reward, done, info = env.step(data) if obs_bufs is not None: _encode_obs(obs, obs_bufs) obs = None - if data is None: - p.send(obs) + p.send((obs, reward, done, info)) + elif cmd == "reset": + retval = env.reset(**data) + reset_returns_info = isinstance( + retval, (tuple, list) + ) and len(retval) == 2 and isinstance(retval[1], dict) + if reset_returns_info: + obs, info = retval else: - p.send((obs, reward, done, info)) + obs = retval + if obs_bufs is not None: + _encode_obs(obs, obs_bufs) + obs = None + if reset_returns_info: + p.send((obs, info)) + else: + p.send(obs) elif cmd == "close": p.send(env.close()) p.close() @@ -104,7 +114,11 @@ def _worker( elif cmd == "render": p.send(env.render(**data) if hasattr(env, "render") else None) elif cmd == "seed": - p.send(env.seed(data) if hasattr(env, "seed") else None) + if hasattr(env, "seed"): + p.send(env.seed(data)) + else: + env.reset(seed=data) + p.send(None) elif cmd == "getattr": p.send(getattr(env, data) if hasattr(env, data) else None) elif cmd == "setattr": @@ -140,7 +154,6 @@ class SubprocEnvWorker(EnvWorker): self.process = Process(target=_worker, args=args, daemon=True) self.process.start() self.child_remote.close() - self.is_reset = False super().__init__(env_fn) def get_env_attr(self, key: str) -> Any: @@ -186,14 +199,25 @@ class SubprocEnvWorker(EnvWorker): remain_conns = [conn for conn in remain_conns if conn not in ready_conns] return 
[workers[conns.index(con)] for con in ready_conns] - def send(self, action: Optional[np.ndarray]) -> None: - self.parent_remote.send(["step", action]) + def send(self, action: Optional[np.ndarray], **kwargs: Any) -> None: + if action is None: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + self.parent_remote.send(["reset", kwargs]) + else: + self.parent_remote.send(["step", action]) def recv( self - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]: + ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], Tuple[ + np.ndarray, dict], np.ndarray]: # noqa:E125 result = self.parent_remote.recv() if isinstance(result, tuple): + if len(result) == 2: + obs, info = result + if self.share_memory: + obs = self._decode_obs() + return obs, info obs, rew, done, info = result if self.share_memory: obs = self._decode_obs() @@ -204,6 +228,23 @@ class SubprocEnvWorker(EnvWorker): obs = self._decode_obs() return obs + def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]: + if "seed" in kwargs: + super().seed(kwargs["seed"]) + self.parent_remote.send(["reset", kwargs]) + + result = self.parent_remote.recv() + if isinstance(result, tuple): + obs, info = result + if self.share_memory: + obs = self._decode_obs() + return obs, info + else: + obs = result + if self.share_memory: + obs = self._decode_obs() + return obs + def seed(self, seed: Optional[int] = None) -> Optional[List[int]]: super().seed(seed) self.parent_remote.send(["seed", seed])
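
Usage sketch (illustrative, not part of the diff): how the new reset keyword plumbing is expected to be driven from user code. It assumes gym>=0.23.1 (so env.reset accepts seed/return_info), uses CartPole-v1 as a stand-in environment, and defines a throwaway RandomPolicy purely for this example.

import gym
import numpy as np

from tianshou.data import Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BasePolicy


class RandomPolicy(BasePolicy):
    """Toy policy that picks random discrete actions; illustration only."""

    def forward(self, batch, state=None, **kwargs):
        # One random CartPole action (0 or 1) per active environment.
        return Batch(act=np.random.randint(2, size=len(batch.obs)))

    def learn(self, batch, **kwargs):
        return {}


venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
# seed and return_info are forwarded to every worker's env.reset(), so the
# vectorized reset now yields (obs, list_of_info_dicts) instead of just obs.
obs, info = venv.reset(seed=1, return_info=True)

collector = Collector(
    RandomPolicy(),
    venv,
    VectorReplayBuffer(total_size=400, buffer_num=4),
)
# gym_reset_kwargs is passed through to env.reset() whenever the collector
# resets environments, so reset-time info can reach preprocess_fn and the
# replay buffer's info field.
collector.collect(n_step=16, gym_reset_kwargs=dict(return_info=True))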