From 4f65b131aa7a2fdd0cc099d8dcfccf29d10e74f3 Mon Sep 17 00:00:00 2001 From: bordeauxred <2robert.mueller@gmail.com> Date: Thu, 28 Mar 2024 18:02:31 +0100 Subject: [PATCH 1/5] Feat/refactor collector (#1063) Closes: #1058 ### Api Extensions - Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 - `Collector`s can now be closed, and their reset is more granular. #1063 - Trainers can control whether collectors should be reset prior to training. #1063 - Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 ### Internal Improvements - `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 - Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Improved typing for `exploration_noise` and within Collector. #1063 ### Breaking Changes - Removed `.data` attribute from `Collector` and its child classes. #1063 - Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` expicitly or pass `reset_before_collect=True` . #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 --------- Co-authored-by: Michael Panchenko --- CHANGELOG.md | 23 + docs/02_notebooks/L0_overview.ipynb | 2 +- docs/02_notebooks/L5_Collector.ipynb | 5 +- pyproject.toml | 1 + test/base/env.py | 24 +- test/base/test_buffer.py | 12 +- test/base/test_collector.py | 509 +++++++----- test/base/test_env.py | 22 +- test/base/test_env_finite.py | 19 +- test/continuous/test_redq.py | 1 + test/continuous/test_td3.py | 1 + test/discrete/test_a2c_with_il.py | 2 + test/discrete/test_bdq.py | 2 +- test/discrete/test_c51.py | 2 +- test/discrete/test_dqn.py | 2 +- test/discrete/test_drqn.py | 2 +- test/discrete/test_fqf.py | 2 +- test/discrete/test_iqn.py | 2 +- test/discrete/test_qrdqn.py | 2 +- test/discrete/test_rainbow.py | 2 +- test/modelbased/test_dqn_icm.py | 2 +- test/modelbased/test_psrl.py | 3 +- test/offline/gather_cartpole_data.py | 5 +- test/offline/test_discrete_bcq.py | 1 + test/pettingzoo/pistonball.py | 6 +- test/pettingzoo/pistonball_continuous.py | 2 +- test/pettingzoo/tic_tac_toe.py | 6 +- tianshou/data/batch.py | 34 +- tianshou/data/buffer/manager.py | 2 +- tianshou/data/collector.py | 909 ++++++++++++++-------- tianshou/data/utils/converter.py | 1 + tianshou/env/venv_wrappers.py | 16 +- tianshou/env/venvs.py | 26 +- tianshou/highlevel/agent.py | 11 + tianshou/highlevel/experiment.py | 2 +- tianshou/policy/base.py | 12 +- tianshou/policy/modelbased/icm.py | 10 +- tianshou/policy/modelfree/bdq.py | 9 +- tianshou/policy/modelfree/ddpg.py | 9 +- tianshou/policy/modelfree/discrete_sac.py | 11 +- tianshou/policy/modelfree/dqn.py | 9 +- tianshou/policy/multiagent/mapolicy.py | 22 +- tianshou/trainer/base.py | 28 +- tianshou/trainer/utils.py | 3 +- 44 files changed, 1143 insertions(+), 633 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e0ac65..5a37acb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,27 @@ # Changelog +## Release 1.1.0 + +### Api Extensions +- Batch received two new methods: `to_dict` and `to_list_of_dicts`. #1063 +- `Collector`s can now be closed, and their reset is more granular. 
#1063 +- Trainers can control whether collectors should be reset prior to training. #1063 +- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063 + +### Internal Improvements +- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063 +- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063 +- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 +- Improved typing for `exploration_noise` and within Collector. #1063 + +### Breaking Changes + +- Removed `.data` attribute from `Collector` and its child classes. #1063 +- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset` +expicitly or pass `reset_before_collect=True` . #1063 +- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 +- Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 + + Started after v1.0.0 diff --git a/docs/02_notebooks/L0_overview.ipynb b/docs/02_notebooks/L0_overview.ipynb index ac0514b..37cba0b 100644 --- a/docs/02_notebooks/L0_overview.ipynb +++ b/docs/02_notebooks/L0_overview.ipynb @@ -164,7 +164,7 @@ "source": [ "# Let's watch its performance!\n", "policy.eval()\n", - "eval_result = test_collector.collect(n_episode=1, render=False)\n", + "eval_result = test_collector.collect(n_episode=3, render=False)\n", "print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")" ] }, diff --git a/docs/02_notebooks/L5_Collector.ipynb b/docs/02_notebooks/L5_Collector.ipynb index 1053e15..3e91e0f 100644 --- a/docs/02_notebooks/L5_Collector.ipynb +++ b/docs/02_notebooks/L5_Collector.ipynb @@ -119,7 +119,7 @@ }, "outputs": [], "source": [ - "collect_result = test_collector.collect(n_episode=9)\n", + "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n", "\n", "collect_result.pprint_asdict()" ] @@ -146,8 +146,7 @@ "outputs": [], "source": [ "# Reset the collector\n", - "test_collector.reset()\n", - "collect_result = test_collector.collect(n_episode=9, random=True)\n", + "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n", "\n", "collect_result.pprint_asdict()" ] diff --git a/pyproject.toml b/pyproject.toml index d4679f9..7a79500 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -166,6 +166,7 @@ ignore = [ "RET505", "D106", # undocumented public nested class "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense) + "PLW2901", # overwrite vars in loop ] unfixable = [ "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all diff --git a/test/base/env.py b/test/base/env.py index 8a2de26..c05c987 100644 --- a/test/base/env.py +++ b/test/base/env.py @@ -9,13 +9,24 @@ import numpy as np from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple -class MyTestEnv(gym.Env): - """A task for "going right". The task is to go right ``size`` steps.""" +class MoveToRightEnv(gym.Env): + """A task for "going right". The task is to go right ``size`` steps. + + The observation is the current index, and the action is to go left or right. + Action 0 is to go left, and action 1 is to go right. + Taking action 0 at index 0 will keep the index at 0. + Arriving at index ``size`` means the task is done. 
+ In the current implementation, stepping after the task is done is possible, which will + lead the index to be larger than ``size``. + + Index 0 is the starting point. If reset is called with default options, the index will + be reset to 0. + """ def __init__( self, size: int, - sleep: int = 0, + sleep: float = 0.0, dict_state: bool = False, recurse_state: bool = False, ma_rew: int = 0, @@ -74,8 +85,13 @@ class MyTestEnv(gym.Env): def reset( self, seed: int | None = None, + # TODO: passing a dict here doesn't make any sense options: dict[str, Any] | None = None, ) -> tuple[dict[str, Any] | np.ndarray, dict]: + """:param seed: + :param options: the start index is provided in options["state"] + :return: + """ if options is None: options = {"state": 0} super().reset(seed=seed) @@ -188,7 +204,7 @@ class NXEnv(gym.Env): return self._encode_obs(), 1.0, False, False, {} -class MyGoalEnv(MyTestEnv): +class MyGoalEnv(MoveToRightEnv): def __init__(self, *args: Any, **kwargs: Any) -> None: assert ( kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0 diff --git a/test/base/test_buffer.py b/test/base/test_buffer.py index 99154bb..0806a75 100644 --- a/test/base/test_buffer.py +++ b/test/base/test_buffer.py @@ -22,13 +22,13 @@ from tianshou.data import ( from tianshou.data.utils.converter import to_hdf5 if __name__ == "__main__": - from env import MyGoalEnv, MyTestEnv + from env import MoveToRightEnv, MyGoalEnv else: # pytest - from test.base.env import MyGoalEnv, MyTestEnv + from test.base.env import MoveToRightEnv, MyGoalEnv def test_replaybuffer(size=10, bufsize=20) -> None: - env = MyTestEnv(size) + env = MoveToRightEnv(size) buf = ReplayBuffer(bufsize) buf.update(buf) assert str(buf) == buf.__class__.__name__ + "()" @@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None: def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: - env = MyTestEnv(size) + env = MoveToRightEnv(size) buf = ReplayBuffer(bufsize, stack_num=stack_num) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) @@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: def test_priortized_replaybuffer(size=32, bufsize=15) -> None: - env = MyTestEnv(size) + env = MoveToRightEnv(size) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) obs, info = env.reset() @@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None: bufsize = 9 stack_num = 4 cached_num = 3 - env = MyTestEnv(size) + env = MoveToRightEnv(size) # test if CachedReplayBuffer can handle stack_num + ignore_obs_next buf4 = CachedReplayBuffer( ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), diff --git a/test/base/test_collector.py b/test/base/test_collector.py index f7a24a8..6bc1703 100644 --- a/test/base/test_collector.py +++ b/test/base/test_collector.py @@ -2,7 +2,6 @@ import gymnasium as gym import numpy as np import pytest import tqdm -from torch.utils.tensorboard import SummaryWriter from tianshou.data import ( AsyncCollector, @@ -22,12 +21,12 @@ except ImportError: envpool = None if __name__ == "__main__": - from env import MyTestEnv, NXEnv + from env import MoveToRightEnv, NXEnv else: # pytest - from test.base.env import MyTestEnv, NXEnv + from test.base.env import MoveToRightEnv, NXEnv -class MyPolicy(BasePolicy): +class MaxActionPolicy(BasePolicy): def __init__( self, action_space: gym.spaces.Space | None = None, @@ 
-35,7 +34,9 @@ class MyPolicy(BasePolicy): need_state=True, action_shape=None, ) -> None: - """Mock policy for testing. + """Mock policy for testing, will always return an array of ones of the shape of the action space. + Note that this doesn't make much sense for discrete action space (the output is then intepreted as + logits, meaning all actions would be equally likely). :param action_space: the action space of the environment. If None, a dummy Box space will be used. :param bool dict_state: if the observation of the environment is a dict @@ -63,215 +64,290 @@ class MyPolicy(BasePolicy): pass -class Logger: - def __init__(self, writer) -> None: - self.cnt = 0 - self.writer = writer +def test_collector() -> None: + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] - def preprocess_fn(self, **kwargs): - # modify info before adding into the buffer, and recorded into tfb - # if obs && env_id exist -> reset - # if obs_next/rew/done/info/env_id exist -> normal step - if "rew" in kwargs: - info = kwargs["info"] - info.rew = kwargs["rew"] - if "key" in info: - self.writer.add_scalar("key", np.mean(info.key), global_step=self.cnt) - self.cnt += 1 - return Batch(info=info) - return Batch() - - @staticmethod - def single_preprocess_fn(**kwargs): - # same as above, without tfb - if "rew" in kwargs: - info = kwargs["info"] - info.rew = kwargs["rew"] - return Batch(info=info) - return Batch() - - -@pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector(gym_reset_kwargs) -> None: - writer = SummaryWriter("log/collector") - logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]] - - venv = SubprocVectorEnv(env_fns) - dum = DummyVectorEnv(env_fns) - policy = MyPolicy() - env = env_fns[0]() - c0 = Collector( + subproc_venv_4_envs = SubprocVectorEnv(env_fns) + dummy_venv_4_envs = DummyVectorEnv(env_fns) + policy = MaxActionPolicy() + single_env = env_fns[0]() + c_single_env = Collector( policy, - env, + single_env, ReplayBuffer(size=100), - logger.preprocess_fn, ) - c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs) - assert len(c0.buffer) == 3 - assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) + c_single_env.reset() + c_single_env.collect(n_step=3) + assert len(c_single_env.buffer) == 3 + # TODO: direct attr access is an arcane way of using the buffer, it should be never done + # The placeholders for entries are all zeros, so buffer.obs is an array filled with 3 + # observations, and 97 zeros. + # However, buffer[:] will have all attributes with length three... The non-filled entries are removed there + + # See above. For the single env, we start with obs=0, obs_next=1. 
+ # We move to obs=1, obs_next=2, + # then the env is reset and we move to obs=0 + # Making one more step results in obs_next=1 + # The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access + assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0]) + assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1]) keys = np.zeros(100) keys[:3] = 1 - assert np.allclose(c0.buffer.info["key"], keys) - for e in c0.buffer.info["env"][:3]: - assert isinstance(e, MyTestEnv) - assert np.allclose(c0.buffer.info["env_id"], 0) + assert np.allclose(c_single_env.buffer.info["key"], keys) + for e in c_single_env.buffer.info["env"][:3]: + assert isinstance(e, MoveToRightEnv) + assert np.allclose(c_single_env.buffer.info["env_id"], 0) rews = np.zeros(100) rews[:3] = [0, 1, 0] - assert np.allclose(c0.buffer.info["rew"], rews) - c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs) - assert len(c0.buffer) == 8 - assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) - assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) - assert np.allclose(c0.buffer.info["key"][:8], 1) - for e in c0.buffer.info["env"][:8]: - assert isinstance(e, MyTestEnv) - assert np.allclose(c0.buffer.info["env_id"][:8], 0) - assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1]) - c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs) + assert np.allclose(c_single_env.buffer.rew, rews) + # At this point, the buffer contains obs 0 -> 1 -> 0 - c1 = Collector( + # At start we have 3 entries in the buffer + # We collect 3 episodes, in addition to the transitions we have collected before + # 0 -> 1 -> 0 -> 0 (reset at collection start) -> 1 -> done (0) -> 1 -> done(0) + # obs_next: 1 -> 2 -> 1 -> 1 (reset at collection start) -> 2 -> 1 -> 2 -> 1 -> 2 + # In total, we will have 3 + 6 = 9 entries in the buffer + c_single_env.collect(n_episode=3) + assert len(c_single_env.buffer) == 8 + assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0]) + assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + assert np.allclose(c_single_env.buffer.info["key"][:8], 1) + for e in c_single_env.buffer.info["env"][:8]: + assert isinstance(e, MoveToRightEnv) + assert np.allclose(c_single_env.buffer.info["env_id"][:8], 0) + assert np.allclose(c_single_env.buffer.rew[:8], [0, 1, 0, 1, 0, 1, 0, 1]) + c_single_env.collect(n_step=3, random=True) + + c_subproc_venv_4_envs = Collector( policy, - venv, + subproc_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn, ) - c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs) + c_subproc_venv_4_envs.reset() + + # Collect some steps + c_subproc_venv_4_envs.collect(n_step=8) obs = np.zeros(100) valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] - assert np.allclose(c1.buffer.obs[:, 0], obs) - assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) + assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) + assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) keys = np.zeros(100) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] - assert np.allclose(c1.buffer.info["key"], keys) - for e in c1.buffer.info["env"][valid_indices]: - assert isinstance(e, MyTestEnv) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) + for e in 
c_subproc_venv_4_envs.buffer.info["env"][valid_indices]: + assert isinstance(e, MoveToRightEnv) env_ids = np.zeros(100) env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3] - assert np.allclose(c1.buffer.info["env_id"], env_ids) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids) rews = np.zeros(100) rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0] - assert np.allclose(c1.buffer.info["rew"], rews) - c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs) - assert len(c1.buffer) == 16 + assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews) + + # we previously collected 8 steps, 2 from each env, now we collect 4 episodes + # each env will contribute an episode, which will be of lens 2 (first env was reset), 1, 2, 3 + # So we get 8 + 2+1+2+3 = 16 steps + c_subproc_venv_4_envs.collect(n_episode=4) + assert len(c_subproc_venv_4_envs.buffer) == 16 + valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] - obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] - assert np.allclose(c1.buffer.obs[:, 0], obs) + obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4] + assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs) assert np.allclose( - c1.buffer[:].obs_next[..., 0], + c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], ) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] - assert np.allclose(c1.buffer.info["key"], keys) - for e in c1.buffer.info["env"][valid_indices]: - assert isinstance(e, MyTestEnv) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys) + for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]: + assert isinstance(e, MoveToRightEnv) env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3] - assert np.allclose(c1.buffer.info["env_id"], env_ids) + assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids) rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1] - assert np.allclose(c1.buffer.info["rew"], rews) - c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) + assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews) + c_subproc_venv_4_envs.collect(n_episode=4, random=True) - c2 = Collector( + c_dummy_venv_4_envs = Collector( policy, - dum, + dummy_venv_4_envs, VectorReplayBuffer(total_size=100, buffer_num=4), - logger.preprocess_fn, ) - c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs) + c_dummy_venv_4_envs.reset() + c_dummy_venv_4_envs.collect(n_episode=7) obs1 = obs.copy() obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] obs2 = obs.copy() obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] - c2obs = c2.buffer.obs[:, 0] + c2obs = c_dummy_venv_4_envs.buffer.obs[:, 0] assert np.all(c2obs == obs1) or np.all(c2obs == obs2) - c2.reset_env(gym_reset_kwargs=gym_reset_kwargs) - c2.reset_buffer() - assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs).n_collected_episodes == 8 + c_dummy_venv_4_envs.reset_env() + c_dummy_venv_4_envs.reset_buffer() + assert c_dummy_venv_4_envs.collect(n_episode=8).n_collected_episodes == 8 valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] - assert np.all(c2.buffer.obs[:, 0] == obs) + assert np.all(c_dummy_venv_4_envs.buffer.obs[:, 0] == obs) keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1] - assert np.allclose(c2.buffer.info["key"], keys) - for e in c2.buffer.info["env"][valid_indices]: - assert isinstance(e, MyTestEnv) + assert np.allclose(c_dummy_venv_4_envs.buffer.info["key"], keys) + for e in c_dummy_venv_4_envs.buffer.info["env"][valid_indices]: + assert 
isinstance(e, MoveToRightEnv) env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2] - assert np.allclose(c2.buffer.info["env_id"], env_ids) + assert np.allclose(c_dummy_venv_4_envs.buffer.info["env_id"], env_ids) rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1] - assert np.allclose(c2.buffer.info["rew"], rews) - c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) + assert np.allclose(c_dummy_venv_4_envs.buffer.rew, rews) + c_dummy_venv_4_envs.collect(n_episode=4, random=True) # test corner case with pytest.raises(TypeError): - Collector(policy, dum, ReplayBuffer(10)) + Collector(policy, dummy_venv_4_envs, ReplayBuffer(10)) with pytest.raises(TypeError): - Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5)) + Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5)) with pytest.raises(TypeError): - c2.collect() + c_dummy_venv_4_envs.collect() # test NXEnv for obs_type in ["array", "object"]: envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) - c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) - c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs) - assert c3.buffer.obs.dtype == object + c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) + c_suproc_new.reset() + c_suproc_new.collect(n_step=6) + assert c_suproc_new.buffer.obs.dtype == object -@pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) -def test_collector_with_async(gym_reset_kwargs) -> None: +@pytest.fixture() +def get_AsyncCollector(): env_lens = [2, 3, 4, 5] - writer = SummaryWriter("log/async_collector") - logger = Logger(writer) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens] venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) - policy = MyPolicy() + policy = MaxActionPolicy() bufsize = 60 c1 = AsyncCollector( policy, venv, VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), - logger.preprocess_fn, ) - ptr = [0, 0, 0, 0] - for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): - result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) - assert result.n_collected_episodes >= n_episode - # check buffer data, obs and obs_next, env_id - for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): - env_len = i + 2 - total = env_len * count - indices = np.arange(ptr[i], ptr[i] + total) % bufsize - ptr[i] = (ptr[i] + total) % bufsize - seq = np.arange(env_len) - buf = c1.buffer.buffers[i] - assert np.all(buf.info.env_id[indices] == i) - assert np.all(buf.obs[indices].reshape(count, env_len) == seq) - assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) - # test async n_step, for now the buffer should be full of data - for n_step in tqdm.trange(1, 15, desc="test async n_step"): - result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) - assert result.n_collected_steps >= n_step - for i in range(4): - env_len = i + 2 - seq = np.arange(env_len) - buf = c1.buffer.buffers[i] - assert np.all(buf.info.env_id == i) - assert np.all(buf.obs.reshape(-1, env_len) == seq) - assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) - with pytest.raises(TypeError): - c1.collect() + c1.reset() + return c1, env_lens + + +class TestAsyncCollector: + def test_collect_without_argument_gives_error(self, get_AsyncCollector): + c1, env_lens = get_AsyncCollector + with 
pytest.raises(TypeError): + c1.collect() + + def test_collect_one_episode_async(self, get_AsyncCollector): + c1, env_lens = get_AsyncCollector + result = c1.collect(n_episode=1) + assert result.n_collected_episodes >= 1 + + def test_enough_episodes_two_collection_cycles_n_episode_without_reset( + self, + get_AsyncCollector, + ): + c1, env_lens = get_AsyncCollector + n_episode = 2 + result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False) + assert result_c1.n_collected_episodes >= n_episode + result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False) + assert result_c2.n_collected_episodes >= n_episode + + def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector): + c1, env_lens = get_AsyncCollector + n_episode = 2 + result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True) + assert result_c1.n_collected_episodes >= n_episode + result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=True) + assert result_c2.n_collected_episodes >= n_episode + + def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode( + self, + get_AsyncCollector, + ): + c1, env_lens = get_AsyncCollector + ptr = [0, 0, 0, 0] + bufsize = 60 + for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode) + assert result.n_collected_episodes >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) + + def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step( + self, + get_AsyncCollector, + ): + c1, env_lens = get_AsyncCollector + bufsize = 60 + ptr = [0, 0, 0, 0] + for n_step in tqdm.trange(1, 15, desc="test async n_step"): + result = c1.collect(n_step=n_step) + assert result.n_collected_steps >= n_step + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) + + @pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) + def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step( + self, + get_AsyncCollector, + gym_reset_kwargs, + ): + c1, env_lens = get_AsyncCollector + bufsize = 60 + ptr = [0, 0, 0, 0] + for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): + result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) + assert result.n_collected_episodes >= n_episode + # check buffer data, obs and obs_next, env_id + for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]): + env_len = i + 2 + total = env_len * count + indices = np.arange(ptr[i], ptr[i] + total) % bufsize + ptr[i] = (ptr[i] + total) % bufsize + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert 
np.all(buf.info.env_id[indices] == i) + assert np.all(buf.obs[indices].reshape(count, env_len) == seq) + assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) + # test async n_step, for now the buffer should be full of data, thus no bincount stuff as above + for n_step in tqdm.trange(1, 15, desc="test async n_step"): + result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) + assert result.n_collected_steps >= n_step + for i in range(4): + env_len = i + 2 + seq = np.arange(env_len) + buf = c1.buffer.buffers[i] + assert np.all(buf.info.env_id == i) + assert np.all(buf.obs.reshape(-1, env_len) == seq) + assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) def test_collector_with_dict_state() -> None: - env = MyTestEnv(size=5, sleep=0, dict_state=True) - policy = MyPolicy(dict_state=True) - c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) + env = MoveToRightEnv(size=5, sleep=0, dict_state=True) + policy = MaxActionPolicy(dict_state=True) + c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0.reset() c0.collect(n_step=3) c0.collect(n_episode=2) - assert len(c0.buffer) == 10 - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] + assert len(c0.buffer) == 10 # 3 + two episodes with 5 steps each + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) envs.seed(666) obs, info = envs.reset() @@ -280,8 +356,8 @@ def test_collector_with_dict_state() -> None: policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn, ) + c1.reset() c1.collect(n_step=12) result = c1.collect(n_episode=8) assert result.n_collected_episodes == 8 @@ -396,41 +472,47 @@ def test_collector_with_dict_state() -> None: policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - Logger.single_preprocess_fn, ) + c2.reset() c2.collect(n_episode=10) batch, _ = c2.buffer.sample(10) -def test_collector_with_ma() -> None: - env = MyTestEnv(size=5, sleep=0, ma_rew=4) - policy = MyPolicy() - c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) - # n_step=3 will collect a full episode - rew = c0.collect(n_step=3).returns - assert len(rew) == 0 - rew = c0.collect(n_episode=2).returns - assert rew.shape == (2, 4) - assert np.all(rew == 1) - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] +def test_collector_with_multi_agent() -> None: + multi_agent_env = MoveToRightEnv(size=5, sleep=0, ma_rew=4) + policy = MaxActionPolicy() + c_single_env = Collector(policy, multi_agent_env, ReplayBuffer(size=100)) + c_single_env.reset() + multi_env_returns = c_single_env.collect(n_step=3).returns + # c_single_env has length 3 + # We have no full episodes, so no returns yet + assert len(multi_env_returns) == 0 + + single_env_returns = c_single_env.collect(n_episode=2).returns + # now two episodes. 
Since we have 4 a agents, the returns have shape (2, 4) + assert single_env_returns.shape == (2, 4) + assert np.all(single_env_returns == 1) + + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) - c1 = Collector( + c_multi_env_ma = Collector( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4), - Logger.single_preprocess_fn, ) - rew = c1.collect(n_step=12).returns - assert rew.shape == (2, 4) and np.all(rew == 1), rew - rew = c1.collect(n_episode=8).returns - assert rew.shape == (8, 4) - assert np.all(rew == 1) - batch, _ = c1.buffer.sample(10) + c_multi_env_ma.reset() + multi_env_returns = c_multi_env_ma.collect(n_step=12).returns + # each env makes 3 steps, the first two envs are done and result in two finished episodes + assert multi_env_returns.shape == (2, 4) and np.all(multi_env_returns == 1), multi_env_returns + multi_env_returns = c_multi_env_ma.collect(n_episode=8).returns + assert multi_env_returns.shape == (8, 4) + assert np.all(multi_env_returns == 1) + batch, _ = c_multi_env_ma.buffer.sample(10) print(batch) - c0.buffer.update(c1.buffer) - assert len(c0.buffer) in [42, 43] - if len(c0.buffer) == 42: - rew = [ + c_single_env.buffer.update(c_multi_env_ma.buffer) + assert len(c_single_env.buffer) in [42, 43] + if len(c_single_env.buffer) == 42: + multi_env_returns = [ 0, 0, 0, @@ -475,7 +557,7 @@ def test_collector_with_ma() -> None: 1, ] else: - rew = [ + multi_env_returns = [ 0, 0, 0, @@ -520,17 +602,17 @@ def test_collector_with_ma() -> None: 0, 1, ] - assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) - assert np.all(c0.buffer[:].done == rew) + assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns]) + assert np.all(c_single_env.buffer[:].done == multi_env_returns) c2 = Collector( policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), - Logger.single_preprocess_fn, ) - rew = c2.collect(n_episode=10).returns - assert rew.shape == (10, 4) - assert np.all(rew == 1) + c2.reset() + multi_env_returns = c2.collect(n_episode=10).returns + assert multi_env_returns.shape == (10, 4) + assert np.all(multi_env_returns == 1) batch, _ = c2.buffer.sample(10) @@ -543,20 +625,21 @@ def test_collector_with_atari_setting() -> None: reference_obs[i, 0] = i # atari single buffer - env = MyTestEnv(size=5, sleep=0, array_state=True) - policy = MyPolicy() + env = MoveToRightEnv(size=5, sleep=0, array_state=True) + policy = MaxActionPolicy() c0 = Collector(policy, env, ReplayBuffer(size=100)) + c0.reset() c0.collect(n_step=6) c0.collect(n_episode=2) assert c0.buffer.obs.shape == (100, 4, 84, 84) assert c0.buffer.obs_next.shape == (100, 4, 84, 84) - assert len(c0.buffer) == 15 + assert len(c0.buffer) == 15 # 6 + 2 episodes with 5 steps each obs = np.zeros_like(c0.buffer.obs) obs[np.arange(15)] = reference_obs[np.arange(15) % 5] assert np.all(obs == c0.buffer.obs) c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) - c1.collect(n_episode=3) + c1.collect(n_episode=3, reset_before_collect=True) assert np.allclose(c0.buffer.obs, c1.buffer.obs) with pytest.raises(AttributeError): c1.buffer.obs_next # noqa: B018 @@ -567,6 +650,7 @@ def test_collector_with_atari_setting() -> None: env, ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True), ) + c2.reset() c2.collect(n_step=8) assert c2.buffer.obs.shape == (100, 84, 84) obs = np.zeros_like(c2.buffer.obs) @@ -575,9 +659,10 @@ def test_collector_with_atari_setting() -> None: assert 
np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) # atari multi buffer - env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] envs = DummyVectorEnv(env_fns) c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) + c3.reset() c3.collect(n_step=12) result = c3.collect(n_episode=9) assert result.n_collected_episodes == 9 @@ -606,6 +691,7 @@ def test_collector_with_atari_setting() -> None: save_only_last_obs=True, ), ) + c4.reset() c4.collect(n_step=12) result = c4.collect(n_episode=9) assert result.n_collected_episodes == 9 @@ -672,6 +758,7 @@ def test_collector_with_atari_setting() -> None: buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) + c5.reset() result_ = c5.collect(n_step=12) assert len(buf) == 5 assert len(c5.buffer) == 12 @@ -767,6 +854,7 @@ def test_collector_with_atari_setting() -> None: # test buffer=None c6 = Collector(policy, envs) + c6.reset() result1 = c6.collect(n_step=12) for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: assert np.allclose(getattr(result1, key), getattr(result_, key)) @@ -778,7 +866,7 @@ def test_collector_with_atari_setting() -> None: @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") def test_collector_envpool_gym_reset_return_info() -> None: envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True) - policy = MyPolicy(action_shape=(len(envs), 1)) + policy = MaxActionPolicy(action_shape=(len(envs), 1)) c0 = Collector( policy, @@ -786,18 +874,59 @@ def test_collector_envpool_gym_reset_return_info() -> None: VectorReplayBuffer(len(envs) * 10, len(envs)), exploration_noise=True, ) + c0.reset() c0.collect(n_step=8) env_ids = np.zeros(len(envs) * 10) env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3] assert np.allclose(c0.buffer.info["env_id"], env_ids) +def test_collector_with_vector_env(): + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] + + dum = DummyVectorEnv(env_fns) + policy = MaxActionPolicy() + + c2 = Collector( + policy, + dum, + VectorReplayBuffer(total_size=100, buffer_num=4), + ) + + c2.reset() + + c1r = c2.collect(n_episode=2) + assert np.array_equal(np.array([1, 8]), c1r.lens) + c2r = c2.collect(n_episode=10) + assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 8, 9, 10]), c2r.lens) + c3r = c2.collect(n_step=20) + assert np.array_equal(np.array([1, 1, 1, 1, 1]), c3r.lens) + c4r = c2.collect(n_step=20) + assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens) + + +def test_async_collector_with_vector_env(): + env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]] + + dum = DummyVectorEnv(env_fns) + policy = MaxActionPolicy() + c1 = AsyncCollector( + policy, + dum, + VectorReplayBuffer(total_size=100, buffer_num=4), + ) + + c1r = c1.collect(n_episode=10, reset_before_collect=True) + assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens) + c2r = c1.collect(n_step=20) + assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens) + + if __name__ == "__main__": - test_collector(gym_reset_kwargs=None) - test_collector(gym_reset_kwargs={}) + test_collector() test_collector_with_dict_state() - test_collector_with_ma() + test_collector_with_multi_agent() 
test_collector_with_atari_setting() - test_collector_with_async(gym_reset_kwargs=None) - test_collector_with_async(gym_reset_kwargs={"return_info": True}) test_collector_envpool_gym_reset_return_info() + test_collector_with_vector_env() + test_async_collector_with_vector_env() diff --git a/test/base/test_env.py b/test/base/test_env.py index edeb3f3..f1571ca 100644 --- a/test/base/test_env.py +++ b/test/base/test_env.py @@ -20,9 +20,9 @@ from tianshou.env.gym_wrappers import TruncatedAsTerminated from tianshou.utils import RunningMeanStd if __name__ == "__main__": - from env import MyTestEnv, NXEnv + from env import MoveToRightEnv, NXEnv else: # pytest - from test.base.env import MyTestEnv, NXEnv + from test.base.env import MoveToRightEnv, NXEnv try: import envpool @@ -56,7 +56,7 @@ def recurse_comp(a, b): def test_async_env(size=10000, num=8, sleep=0.1) -> None: # simplify the test case, just keep stepping env_fns = [ - lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True) + lambda i=i: MoveToRightEnv(size=i, sleep=sleep, random_sleep=True) for i in range(size, size + num) ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] @@ -108,10 +108,10 @@ def test_async_env(size=10000, num=8, sleep=0.1) -> None: def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: env_fns = [ - lambda: MyTestEnv(size=size, sleep=sleep * 2), - lambda: MyTestEnv(size=size, sleep=sleep * 3), - lambda: MyTestEnv(size=size, sleep=sleep * 5), - lambda: MyTestEnv(size=size, sleep=sleep * 7), + lambda: MoveToRightEnv(size=size, sleep=sleep * 2), + lambda: MoveToRightEnv(size=size, sleep=sleep * 3), + lambda: MoveToRightEnv(size=size, sleep=sleep * 5), + lambda: MoveToRightEnv(size=size, sleep=sleep * 7), ] test_cls = [SubprocVectorEnv, ShmemVectorEnv] if has_ray(): @@ -156,7 +156,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: def test_vecenv(size=10, num=8, sleep=0.001) -> None: env_fns = [ - lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True) + lambda i=i: MoveToRightEnv(size=i, sleep=sleep, recurse_state=True) for i in range(size, size + num) ] venv = [ @@ -237,7 +237,7 @@ def test_env_obs_dtype() -> None: def test_env_reset_optional_kwargs(size=10000, num=8) -> None: - env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)] + env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] if has_ray(): test_cls += [RayVectorEnv] @@ -257,7 +257,7 @@ def test_venv_wrapper_gym(num_envs: int = 4) -> None: except ValueError: obs, info = envs.reset(return_info=True) assert isinstance(obs, np.ndarray) - assert isinstance(info, list) + assert isinstance(info, np.ndarray) assert isinstance(info[0], dict) assert obs.shape[0] == len(info) == num_envs @@ -334,7 +334,7 @@ def test_venv_norm_obs() -> None: action = np.array([1, 1, 1, 1]) total_step = 30 action_list = [action] * total_step - env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes] + env_fns = [lambda i=x: MoveToRightEnv(size=i, array_state=True) for x in sizes] raw = DummyVectorEnv(env_fns) train_env = VectorEnvNormObs(DummyVectorEnv(env_fns)) print(train_env.observation_space) diff --git a/test/base/test_env_finite.py b/test/base/test_env_finite.py index d1e7802..651e770 100644 --- a/test/base/test_env_finite.py +++ b/test/base/test_env_finite.py @@ -90,20 +90,20 @@ class FiniteVectorEnv(BaseVectorEnv): # END - def reset(self, id=None): - id = self._wrap_id(id) + def reset(self, 
env_id=None): + env_id = self._wrap_id(env_id) self._reset_alive_envs() # ask super to reset alive envs and remap to current index - request_id = list(filter(lambda i: i in self._alive_env_ids, id)) - obs = [None] * len(id) - infos = [None] * len(id) - id2idx = {i: k for k, i in enumerate(id)} + request_id = list(filter(lambda i: i in self._alive_env_ids, env_id)) + obs = [None] * len(env_id) + infos = [None] * len(env_id) + id2idx = {i: k for k, i in enumerate(env_id)} if request_id: for k, o, info in zip(request_id, *super().reset(request_id), strict=True): obs[id2idx[k]] = o infos[id2idx[k]] = info - for i, o in zip(id, obs, strict=True): + for i, o in zip(env_id, obs, strict=True): if o is None and i in self._alive_env_ids: self._alive_env_ids.remove(i) @@ -121,7 +121,7 @@ class FiniteVectorEnv(BaseVectorEnv): self.reset() raise StopIteration - return np.stack(obs), infos + return np.stack(obs), np.array(infos) def step(self, action, id=None): id = self._wrap_id(id) @@ -204,10 +204,12 @@ def test_finite_dummy_vector_env() -> None: envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) + test_collector.reset() for _ in range(3): envs.tracker = MetricTracker() try: + # TODO: why on earth 10**18? test_collector.collect(n_step=10**18) except StopIteration: envs.tracker.validate() @@ -218,6 +220,7 @@ def test_finite_subproc_vector_env() -> None: envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) policy = AnyPolicy() test_collector = Collector(policy, envs, exploration_noise=True) + test_collector.reset() for _ in range(3): envs.tracker = MetricTracker() diff --git a/test/continuous/test_redq.py b/test/continuous/test_redq.py index 20177f4..a180e44 100644 --- a/test/continuous/test_redq.py +++ b/test/continuous/test_redq.py @@ -136,6 +136,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None: exploration_noise=True, ) test_collector = Collector(policy, test_envs) + train_collector.reset() train_collector.collect(n_step=args.start_timesteps, random=True) # log log_path = os.path.join(args.logdir, args.task, "redq") diff --git a/test/continuous/test_td3.py b/test/continuous/test_td3.py index 7b8690f..fb1e28a 100644 --- a/test/continuous/test_td3.py +++ b/test/continuous/test_td3.py @@ -162,6 +162,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) policy.eval() collector = Collector(policy, env) + collector.reset() collector_stats = collector.collect(n_episode=1, render=args.render) print(collector_stats) diff --git a/test/discrete/test_a2c_with_il.py b/test/discrete/test_a2c_with_il.py index 3ca7ce6..f51a8d7 100644 --- a/test/discrete/test_a2c_with_il.py +++ b/test/discrete/test_a2c_with_il.py @@ -109,7 +109,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None: train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)), ) + train_collector.reset() test_collector = Collector(policy, test_envs) + test_collector.reset() # log log_path = os.path.join(args.logdir, args.task, "a2c") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index a45b021..295c6b3 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -108,7 +108,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: ) test_collector = Collector(policy, test_envs, exploration_noise=False) # policy.set_eps(1) - 
train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) def train_fn(epoch: int, env_step: int) -> None: # exp decay eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) diff --git a/test/discrete/test_c51.py b/test/discrete/test_c51.py index 013e2c4..c9731ed 100644 --- a/test/discrete/test_c51.py +++ b/test/discrete/test_c51.py @@ -120,7 +120,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "c51") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_dqn.py b/test/discrete/test_dqn.py index d0aba10..de598e1 100644 --- a/test/discrete/test_dqn.py +++ b/test/discrete/test_dqn.py @@ -111,7 +111,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "dqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_drqn.py b/test/discrete/test_drqn.py index 3ed1f4f..03ece9b 100644 --- a/test/discrete/test_drqn.py +++ b/test/discrete/test_drqn.py @@ -95,7 +95,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None: # the stack_num is for RNN training: sample framestack obs test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "drqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_fqf.py b/test/discrete/test_fqf.py index 7de0901..fa7a4ca 100644 --- a/test/discrete/test_fqf.py +++ b/test/discrete/test_fqf.py @@ -128,7 +128,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "fqf") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_iqn.py b/test/discrete/test_iqn.py index 5ff71f5..1f75ab5 100644 --- a/test/discrete/test_iqn.py +++ b/test/discrete/test_iqn.py @@ -124,7 +124,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log 
log_path = os.path.join(args.logdir, args.task, "iqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_qrdqn.py b/test/discrete/test_qrdqn.py index c1bbcc3..b3e42d5 100644 --- a/test/discrete/test_qrdqn.py +++ b/test/discrete/test_qrdqn.py @@ -113,7 +113,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) diff --git a/test/discrete/test_rainbow.py b/test/discrete/test_rainbow.py index fafa1e0..a38433b 100644 --- a/test/discrete/test_rainbow.py +++ b/test/discrete/test_rainbow.py @@ -128,7 +128,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "rainbow") writer = SummaryWriter(log_path) diff --git a/test/modelbased/test_dqn_icm.py b/test/modelbased/test_dqn_icm.py index 0f957c7..2aa5b4e 100644 --- a/test/modelbased/test_dqn_icm.py +++ b/test/modelbased/test_dqn_icm.py @@ -154,7 +154,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None: train_collector = Collector(policy, train_envs, buf, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "dqn_icm") writer = SummaryWriter(log_path) diff --git a/test/modelbased/test_psrl.py b/test/modelbased/test_psrl.py index 55719f4..2994b11 100644 --- a/test/modelbased/test_psrl.py +++ b/test/modelbased/test_psrl.py @@ -81,7 +81,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: VectorReplayBuffer(args.buffer_size, len(train_envs)), exploration_noise=True, ) + train_collector.reset() test_collector = Collector(policy, test_envs) + test_collector.reset() # Logger log_path = os.path.join(args.logdir, args.task, "psrl") writer = SummaryWriter(log_path) @@ -120,7 +122,6 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None: # Let's watch its performance! 
policy.eval() test_envs.seed(args.seed) - test_collector.reset() result = test_collector.collect(n_episode=args.test_num, render=args.render) print(f"Final reward: {result.rew_mean}, length: {result.len_mean}") elif env.spec.reward_threshold: diff --git a/test/offline/gather_cartpole_data.py b/test/offline/gather_cartpole_data.py index 7f1a312..61450a9 100644 --- a/test/offline/gather_cartpole_data.py +++ b/test/offline/gather_cartpole_data.py @@ -115,9 +115,11 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) # collector train_collector = Collector(policy, train_envs, buf, exploration_noise=True) + train_collector.reset() test_collector = Collector(policy, test_envs, exploration_noise=True) + test_collector.reset() # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, args.task, "qrdqn") writer = SummaryWriter(log_path) @@ -165,6 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer: buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) policy.set_eps(0.2) collector = Collector(policy, test_envs, buf, exploration_noise=True) + collector.reset() collector_stats = collector.collect(n_step=args.buffer_size) if args.save_buffer_name.endswith(".hdf5"): buf.save_hdf5(args.save_buffer_name) diff --git a/test/offline/test_discrete_bcq.py b/test/offline/test_discrete_bcq.py index 32e6d56..81fc899 100644 --- a/test/offline/test_discrete_bcq.py +++ b/test/offline/test_discrete_bcq.py @@ -178,6 +178,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None: def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: + test_discrete_bcq() args.resume = True test_discrete_bcq(args) diff --git a/test/pettingzoo/pistonball.py b/test/pettingzoo/pistonball.py index 0dd750b..990bf46 100644 --- a/test/pettingzoo/pistonball.py +++ b/test/pettingzoo/pistonball.py @@ -83,8 +83,8 @@ def get_agents( if isinstance(env.observation_space, gym.spaces.Dict) else env.observation_space ) - args.state_shape = observation_space.shape or observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = observation_space.shape or int(observation_space.n) + args.action_shape = env.action_space.shape or int(env.action_space.n) if agents is None: agents = [] optims = [] @@ -135,7 +135,7 @@ def train_agent( exploration_noise=True, ) test_collector = Collector(policy, test_envs, exploration_noise=True) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = SummaryWriter(log_path) diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 8bbb20c..0897d73 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -234,7 +234,7 @@ def train_agent( exploration_noise=False, # True ) test_collector = Collector(policy, test_envs) - # train_collector.collect(n_step=args.batch_size * args.training_num) + # train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "pistonball", "dqn") writer = 
SummaryWriter(log_path) diff --git a/test/pettingzoo/tic_tac_toe.py b/test/pettingzoo/tic_tac_toe.py index 62b66df..e1559b1 100644 --- a/test/pettingzoo/tic_tac_toe.py +++ b/test/pettingzoo/tic_tac_toe.py @@ -102,8 +102,8 @@ def get_agents( if isinstance(env.observation_space, gymnasium.spaces.Dict) else env.observation_space ) - args.state_shape = observation_space.shape or observation_space.n - args.action_shape = env.action_space.shape or env.action_space.n + args.state_shape = observation_space.shape or int(observation_space.n) + args.action_shape = env.action_space.shape or int(env.action_space.n) if agent_learn is None: # model net = Net( @@ -170,7 +170,7 @@ def train_agent( ) test_collector = Collector(policy, test_envs, exploration_noise=True) # policy.set_eps(1) - train_collector.collect(n_step=args.batch_size * args.training_num) + train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True) # log log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn") writer = SummaryWriter(log_path) diff --git a/tianshou/data/batch.py b/tianshou/data/batch.py index 7002b55..b9b7024 100644 --- a/tianshou/data/batch.py +++ b/tianshou/data/batch.py @@ -263,6 +263,9 @@ class BatchProtocol(Protocol): def __repr__(self) -> str: ... + def __iter__(self) -> Iterator[Self]: + ... + def to_numpy(self) -> None: """Change all torch.Tensor to numpy.ndarray in-place.""" ... @@ -391,6 +394,12 @@ class BatchProtocol(Protocol): """ ... + def to_dict(self) -> dict[str, Any]: + ... + + def to_list_of_dicts(self) -> list[dict[str, Any]]: + ... + class Batch(BatchProtocol): """See :class:`~tianshou.data.batch.BatchProtocol`.""" @@ -422,6 +431,17 @@ class Batch(BatchProtocol): # Feels like kwargs could be just merged into batch_dict in the beginning self.__init__(kwargs, copy=copy) # type: ignore + def to_dict(self) -> dict[str, Any]: + result = {} + for k, v in self.__dict__.items(): + if isinstance(v, Batch): + v = v.to_dict() + result[k] = v + return result + + def to_list_of_dicts(self) -> list[dict[str, Any]]: + return [entry.to_dict() for entry in self] + def __setattr__(self, key: str, value: Any) -> None: """Set self.key = value.""" self.__dict__[key] = _parse_value(value) @@ -478,6 +498,14 @@ class Batch(BatchProtocol): return new_batch raise IndexError("Cannot access item from empty Batch object.") + def __iter__(self) -> Iterator[Self]: + # TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea + if len(self.__dict__) == 0: + yield from [] + else: + for i in range(len(self)): + yield self[i] + def __setitem__(self, index: str | IndexType, value: Any) -> None: """Assign value to self[index].""" value = _parse_value(value) @@ -601,10 +629,10 @@ class Batch(BatchProtocol): else: # ndarray or scalar if not isinstance(obj, np.ndarray): - obj = np.asanyarray(obj) # noqa: PLW2901 - obj = torch.from_numpy(obj).to(device) # noqa: PLW2901 + obj = np.asanyarray(obj) + obj = torch.from_numpy(obj).to(device) if dtype is not None: - obj = obj.type(dtype) # noqa: PLW2901 + obj = obj.type(dtype) self.__dict__[batch_key] = obj def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: diff --git a/tianshou/data/buffer/manager.py b/tianshou/data/buffer/manager.py index e09b696..a495b0a 100644 --- a/tianshou/data/buffer/manager.py +++ b/tianshou/data/buffer/manager.py @@ -200,7 +200,7 @@ class ReplayBufferManager(ReplayBuffer): return np.concatenate( [ - buf.sample_indices(bsz) + offset + buf.sample_indices(int(bsz)) + offset 
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True) ], ) diff --git a/tianshou/data/collector.py b/tianshou/data/collector.py index 09290c7..751fedf 100644 --- a/tianshou/data/collector.py +++ b/tianshou/data/collector.py @@ -1,8 +1,8 @@ import time import warnings -from collections.abc import Callable +from copy import copy from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Self, TypeVar, cast import gymnasium as gym import numpy as np @@ -18,8 +18,10 @@ from tianshou.data import ( VectorReplayBuffer, to_numpy, ) -from tianshou.data.batch import alloc_by_keys_diff -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ( + ObsBatchProtocol, + RolloutBatchProtocol, +) from tianshou.env import BaseVectorEnv, DummyVectorEnv from tianshou.policy import BasePolicy from tianshou.utils.print import DataclassPPrintMixin @@ -45,13 +47,80 @@ class CollectStats(CollectStatsBase): """The speed of collecting (env_step per second).""" returns: np.ndarray """The collected episode returns.""" - returns_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + returns_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step """Stats of the collected returns.""" lens: np.ndarray """The collected episode lengths.""" - lens_stat: SequenceSummaryStats | None # can be None if no episode ends during collect step + lens_stat: SequenceSummaryStats | None # can be None if no episode ends during the collect step """Stats of the collected episode lengths.""" + @classmethod + def with_autogenerated_stats( + cls, + returns: np.ndarray, + lens: np.ndarray, + n_collected_episodes: int = 0, + n_collected_steps: int = 0, + collect_time: float = 0.0, + collect_speed: float = 0.0, + ) -> Self: + """Return a new instance with the stats autogenerated from the given lists.""" + returns_stat = SequenceSummaryStats.from_sequence(returns) if returns.size > 0 else None + lens_stat = SequenceSummaryStats.from_sequence(lens) if lens.size > 0 else None + return cls( + n_collected_episodes=n_collected_episodes, + n_collected_steps=n_collected_steps, + collect_time=collect_time, + collect_speed=collect_speed, + returns=returns, + returns_stat=returns_stat, + lens=np.array(lens, int), + lens_stat=lens_stat, + ) + + +_TArrLike = TypeVar("_TArrLike", bound="np.ndarray | torch.Tensor | Batch | None") + + +def _nullable_slice(obj: _TArrLike, indices: np.ndarray) -> _TArrLike: + """Return None, or the values at the given indices if the object is not None.""" + if obj is not None: + return obj[indices] # type: ignore[index, return-value] + return None # type: ignore[unreachable] + + +def _dict_of_arr_to_arr_of_dicts(dict_of_arr: dict[str, np.ndarray | dict]) -> np.ndarray: + return np.array(Batch(dict_of_arr).to_list_of_dicts()) + + +def _HACKY_create_info_batch(info_array: np.ndarray) -> Batch: + """TODO: this exists because of multiple bugs in Batch and to restore backwards compatibility. + Batch should be fixed and this function should be removed asap!. 
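+
+    Roughly: the helper takes an object-dtype array of per-env info dicts, some of
+    which may be empty, and turns it into a ``Batch`` that can be indexed per env.
+    An illustrative sketch (hypothetical values, not a doctest)::
+
+        infos = np.array([{"env_id": 0}, {}, {"env_id": 2}], dtype=object)
+        info_batch = _HACKY_create_info_batch(infos)
+        # info_batch now has the keys of the non-empty dicts; the slot that held
+        # an empty dict is blanked out again after the temporary fill-in trick.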
+ """ + if info_array.dtype != np.dtype("O"): + raise ValueError( + f"Expected info_array to have dtype=object, but got {info_array.dtype}.", + ) + + truthy_info_indices = info_array.nonzero()[0] + falsy_info_indices = set(range(len(info_array))) - set(truthy_info_indices) + falsy_info_indices = np.array(list(falsy_info_indices), dtype=int) + + if len(falsy_info_indices) == len(info_array): + return Batch() + + some_nonempty_info = None + for info in info_array: + if info: + some_nonempty_info = info + break + + info_array = copy(info_array) + info_array[falsy_info_indices] = some_nonempty_info + result_batch_parent = Batch(info=info_array) + result_batch_parent.info[falsy_info_indices] = {} + return result_batch_parent.info + class Collector: """Collector enables the policy to interact with different types of envs with exact number of steps or episodes. @@ -60,23 +129,13 @@ class Collector: :param env: a ``gym.Env`` environment or an instance of the :class:`~tianshou.env.BaseVectorEnv` class. :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class. - If set to None, it will not store the data. Default to None. - :param function preprocess_fn: a function called before the data has been added to - the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None. + If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer` + as the default buffer. :param exploration_noise: determine whether the action needs to be modified - with corresponding policy's exploration noise. If so, "policy. + with the corresponding policy's exploration noise. If so, "policy. exploration_noise(act, batch)" will be called automatically to add the exploration noise into action. Default to False. - The "preprocess_fn" is a function called before the data has been added to the - buffer with batch format. It will receive only "obs" and "env_id" when the - collector resets the environment, and will receive the keys "obs_next", "rew", - "terminated", "truncated, "info", "policy" and "env_id" in a normal env step. - Alternatively, it may also accept the keys "obs_next", "rew", "done", "info", - "policy" and "env_id". - It returns either a dict or a :class:`~tianshou.data.Batch` with the modified - keys and values. Examples are in "test/base/test_collector.py". - .. note:: Please make sure the given environment has a time limitation if using n_episode @@ -84,7 +143,7 @@ class Collector: .. note:: - In past versions of Tianshou, the replay buffer that was passed to `__init__` + In past versions of Tianshou, the replay buffer passed to `__init__` was automatically reset. This is not done in the current implementation. 
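+
+    A minimal usage sketch (``policy`` is assumed to be an already constructed
+    :class:`~tianshou.policy.BasePolicy`; the gym task is only illustrative)::
+
+        import gymnasium as gym
+
+        from tianshou.data import Collector, VectorReplayBuffer
+        from tianshou.env import DummyVectorEnv
+
+        envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
+        collector = Collector(policy, envs, VectorReplayBuffer(total_size=200, buffer_num=2))
+        # envs are no longer reset on construction, so reset explicitly ...
+        collector.reset()
+        stats = collector.collect(n_step=20)
+        # ... or let collect() handle it via reset_before_collect=True
+        stats = collector.collect(n_episode=2, reset_before_collect=True)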
""" @@ -93,7 +152,6 @@ class Collector: policy: BasePolicy, env: gym.Env | BaseVectorEnv, buffer: ReplayBuffer | None = None, - preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None, exploration_noise: bool = False, ) -> None: super().__init__() @@ -105,16 +163,30 @@ class Collector: self.env = env # type: ignore self.env_num = len(self.env) self.exploration_noise = exploration_noise - self.buffer: ReplayBuffer - self._assign_buffer(buffer) + self.buffer = self._assign_buffer(buffer) self.policy = policy - self.preprocess_fn = preprocess_fn self._action_space = self.env.action_space - self.data: RolloutBatchProtocol - # avoid creating attribute outside __init__ - self.reset(False) - def _assign_buffer(self, buffer: ReplayBuffer | None) -> None: + self._pre_collect_obs_RO: np.ndarray | None = None + self._pre_collect_info_R: np.ndarray | None = None + self._pre_collect_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None + + self._is_closed = False + self.collect_step, self.collect_episode, self.collect_time = 0, 0, 0.0 + + def close(self) -> None: + """Close the collector and the environment.""" + self.env.close() + self._pre_collect_obs_RO = None + self._pre_collect_info_R = None + self._is_closed = True + + @property + def is_closed(self) -> bool: + """Return True if the collector is closed.""" + return self._is_closed + + def _assign_buffer(self, buffer: ReplayBuffer | None) -> ReplayBuffer: """Check if the buffer matches the constraint.""" if buffer is None: buffer = VectorReplayBuffer(self.env_num, self.env_num) @@ -136,38 +208,28 @@ class Collector: f"{self.env_num} envs,\n\tplease use {vector_type}(total_size=" f"{buffer.maxsize}, buffer_num={self.env_num}, ...) instead.", ) - self.buffer = buffer + return buffer def reset( self, reset_buffer: bool = True, + reset_stats: bool = True, gym_reset_kwargs: dict[str, Any] | None = None, ) -> None: - """Reset the environment, statistics, current data and possibly replay memory. + """Reset the environment, statistics, and data needed to start the collection. - :param reset_buffer: if true, reset the replay buffer that is attached + :param reset_buffer: if true, reset the replay buffer attached to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. 
Defaults to None (extra keyword arguments) """ - # use empty Batch for "state" so that self.data supports slicing - # convert empty Batch to None when passing data to policy - data = Batch( - obs={}, - act={}, - rew={}, - terminated={}, - truncated={}, - done={}, - obs_next={}, - info={}, - policy={}, - ) - self.data = cast(RolloutBatchProtocol, data) - self.reset_env(gym_reset_kwargs) + self.reset_env(gym_reset_kwargs=gym_reset_kwargs) if reset_buffer: self.reset_buffer() - self.reset_stat() + if reset_stats: + self.reset_stat() + self._is_closed = False def reset_stat(self) -> None: """Reset the statistic variables.""" @@ -177,44 +239,76 @@ class Collector: """Reset the data buffer.""" self.buffer.reset(keep_statistics=keep_statistics) - def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None: - """Reset all of the environments.""" - gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - obs, info = self.env.reset(**gym_reset_kwargs) - if self.preprocess_fn: - processed_data = self.preprocess_fn(obs=obs, info=info, env_id=np.arange(self.env_num)) - obs = processed_data.get("obs", obs) - info = processed_data.get("info", info) - self.data.info = info # type: ignore - self.data.obs = obs - - def _reset_state(self, id: int | list[int]) -> None: - """Reset the hidden state: self.data.state[id].""" - if hasattr(self.data.policy, "hidden_state"): - state = self.data.policy.hidden_state # it is a reference - if isinstance(state, torch.Tensor): - state[id].zero_() - elif isinstance(state, np.ndarray): - state[id] = None if state.dtype == object else 0 - elif isinstance(state, Batch): - state.empty_(id) - - def _reset_env_with_ids( + def reset_env( self, - local_ids: list[int] | np.ndarray, - global_ids: list[int] | np.ndarray, gym_reset_kwargs: dict[str, Any] | None = None, ) -> None: - gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {} - obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs) - if self.preprocess_fn: - processed_data = self.preprocess_fn(obs=obs_reset, info=info, env_id=global_ids) - obs_reset = processed_data.get("obs", obs_reset) - info = processed_data.get("info", info) - self.data.info[local_ids] = info # type: ignore + """Reset the environments and the initial obs, info, and hidden state of the collector.""" + gym_reset_kwargs = gym_reset_kwargs or {} + self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs) + # TODO: hack, wrap envpool envs such that they don't return a dict + if isinstance(self._pre_collect_info_R, dict): # type: ignore[unreachable] + # this can happen if the env is an envpool env. 
Then the thing returned by reset is a dict + # with array entries instead of an array of dicts + # We use Batch to turn it into an array of dicts + self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R) # type: ignore[unreachable] - self.data.obs_next[local_ids] = obs_reset # type: ignore + self._pre_collect_hidden_state_RH = None + def _compute_action_policy_hidden( + self, + random: bool, + ready_env_ids_R: np.ndarray, + use_grad: bool, + last_obs_RO: np.ndarray, + last_info_R: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None = None, + ) -> tuple[np.ndarray, np.ndarray, Batch, np.ndarray | torch.Tensor | Batch | None]: + """Returns the action, the normalized action, a "policy" entry, and the hidden state.""" + if random: + try: + act_normalized_RA = np.array( + [self._action_space[i].sample() for i in ready_env_ids_R], + ) + # TODO: test whether envpool env explicitly + except TypeError: # envpool's action space is not for per-env + act_normalized_RA = np.array([self._action_space.sample() for _ in ready_env_ids_R]) + act_RA = self.policy.map_action_inverse(np.array(act_normalized_RA)) + policy_R = Batch() + hidden_state_RH = None + + else: + info_batch = _HACKY_create_info_batch(last_info_R) + obs_batch_R = cast(ObsBatchProtocol, Batch(obs=last_obs_RO, info=info_batch)) + + with torch.set_grad_enabled(use_grad): + act_batch_RA = self.policy( + obs_batch_R, + last_hidden_state_RH, + ) + + act_RA = to_numpy(act_batch_RA.act) + if self.exploration_noise: + act_RA = self.policy.exploration_noise(act_RA, obs_batch_R) + act_normalized_RA = self.policy.map_action(act_RA) + + # TODO: cleanup the whole policy in batch thing + # todo policy_R can also be none, check + policy_R = act_batch_RA.get("policy", Batch()) + if not isinstance(policy_R, Batch): + raise RuntimeError( + f"The policy result should be a {Batch}, but got {type(policy_R)}", + ) + + hidden_state_RH = act_batch_RA.get("state", None) + # TODO: do we need the conditional? Would be better to just add hidden_state which could be None + if hidden_state_RH is not None: + policy_R.hidden_state = ( + hidden_state_RH # save state into buffer through policy attr + ) + return act_RA, act_normalized_RA, policy_R, hidden_state_RH + + # TODO: reduce complexity, remove the noqa def collect( self, n_step: int | None = None, @@ -222,49 +316,74 @@ class Collector: random: bool = False, render: float | None = None, no_grad: bool = True, + reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - """Collect a specified number of step or episode. + """Collect a specified number of steps or episodes. - To ensure unbiased sampling result with n_episode option, this function will + To ensure an unbiased sampling result with the n_episode option, this function will first collect ``n_episode - env_num`` episodes, then for the last ``env_num`` episodes, they will be collected evenly from each env. :param n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. Default - to False. + :param random: whether to use random policy for collecting data. :param render: the sleep time between rendering consecutive frames. - Default to None (no rendering). - :param no_grad: whether to retain gradient in policy.forward(). Default to - True (no gradient retaining). + :param no_grad: whether to retain gradient in policy.forward(). 
+ :param reset_before_collect: whether to reset the environment before + collecting data. + It has only an effect if n_episode is not None, i.e. + if one wants to collect a fixed number of episodes. + (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's - reset function. Defaults to None (extra keyword arguments) + reset function. Only used if reset_before_collect is True. .. note:: One and only one collection number specification is permitted, either ``n_step`` or ``n_episode``. - :return: A dataclass object + :return: The collected stats """ + # NAMING CONVENTION (mostly suffixes): + # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed, + # the corresponding env is either reset or removed from the ready envs. + # R - number ready env ids. Note that this might change when envs get idle. + # This can only happen in n_episode case, see explanation in the corresponding block. + # For n_step, we always use all envs to collect the data, while for n_episode, + # R will be at most n_episode at the beginning, but can decrease during the collection. + # O - dimension(s) of observations + # A - dimension(s) of actions + # H - dimension(s) of hidden state + # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case. + # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration. + # Only used in n_episode case. Then, R becomes R-S. + + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} + + # Input validation assert not self.env.is_async, "Please use AsyncCollector if using async venv." if n_step is not None: assert n_episode is None, ( f"Only one of n_step or n_episode is allowed in Collector." - f"collect, got n_step={n_step}, n_episode={n_episode}." + f"collect, got {n_step=}, {n_episode=}." ) assert n_step > 0 if n_step % self.env_num != 0: warnings.warn( - f"n_step={n_step} is not a multiple of #env ({self.env_num}), " - "which may cause extra transitions collected into the buffer.", + f"{n_step=} is not a multiple of ({self.env_num=}), " + "which may cause extra transitions being collected into the buffer.", ) - ready_env_ids = np.arange(self.env_num) + ready_env_ids_R = np.arange(self.env_num) elif n_episode is not None: assert n_episode > 0 - ready_env_ids = np.arange(min(self.env_num, n_episode)) - self.data = self.data[: min(self.env_num, n_episode)] + if self.env_num > n_episode: + warnings.warn( + f"{n_episode=} should be larger than {self.env_num=} to " + f"collect at least one trajectory in each environment.", + ) + ready_env_ids_R = np.arange(min(self.env_num, n_episode)) else: raise TypeError( "Please specify at least one (either n_step or n_episode) " @@ -273,149 +392,209 @@ class Collector: start_time = time.time() + if reset_before_collect: + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) + + if self._pre_collect_obs_RO is None or self._pre_collect_info_R is None: + raise ValueError( + "Initial obs and info should not be None. 
" + "Either reset the collector (using reset or reset_env) or pass reset_before_collect=True to collect.", + ) + + # get the first obs to be the current obs in the n_step case as + # episodes as a new call to collect does not restart trajectories + # (which we also really don't want) step_count = 0 - episode_count = 0 + num_collected_episodes = 0 episode_returns: list[float] = [] episode_lens: list[int] = [] episode_start_indices: list[int] = [] + # in case we select fewer episodes than envs, we run only some of them + last_obs_RO = _nullable_slice(self._pre_collect_obs_RO, ready_env_ids_R) + last_info_R = _nullable_slice(self._pre_collect_info_R, ready_env_ids_R) + last_hidden_state_RH = _nullable_slice( + self._pre_collect_hidden_state_RH, + ready_env_ids_R, + ) + while True: - assert len(self.data) == len(ready_env_ids) + # todo check if we need this when using cur_rollout_batch + # if len(cur_rollout_batch) != len(ready_env_ids): + # raise RuntimeError( + # f"The length of the collected_rollout_batch {len(cur_rollout_batch)}) is not equal to the length of ready_env_ids" + # f"{len(ready_env_ids)}. This should not happen and could be a bug!", + # ) # restore the state: if the last state is None, it won't store - last_state = self.data.policy.pop("hidden_state", None) # get the next action - if random: - try: - act_sample = [self._action_space[i].sample() for i in ready_env_ids] - except TypeError: # envpool's action space is not for per-env - act_sample = [self._action_space.sample() for _ in ready_env_ids] - act_sample = self.policy.map_action_inverse(act_sample) # type: ignore - self.data.update(act=act_sample) - else: - if no_grad: - with torch.no_grad(): # faster than retain_grad version - # self.data.obs will be used by agent to get result - result = self.policy(self.data, last_state) - else: - result = self.policy(self.data, last_state) - # update state / act / policy into self.data - policy = result.get("policy", Batch()) - assert isinstance(policy, Batch) - state = result.get("state", None) - if state is not None: - policy.hidden_state = state # save state into buffer - act = to_numpy(result.act) - if self.exploration_noise: - act = self.policy.exploration_noise(act, self.data) - self.data.update(policy=policy, act=act) - - # get bounded and remapped actions first (not saved into buffer) - action_remap = self.policy.map_action(self.data.act) - # step in env - - obs_next, rew, terminated, truncated, info = self.env.step( - action_remap, - ready_env_ids, + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, ) - done = np.logical_or(terminated, truncated) - self.data.update( - obs_next=obs_next, - rew=rew, - terminated=terminated, - truncated=truncated, - done=done, - info=info, + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, ) - if self.preprocess_fn: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - policy=self.data.policy, - env_id=ready_env_ids, - act=self.data.act, - ), - ) + if isinstance(info_R, dict): # type: ignore[unreachable] + # This can happen if the env is an envpool env. 
Then the info returned by step is a dict + info_R = _dict_of_arr_to_arr_of_dicts(info_R) # type: ignore[unreachable] + done_R = np.logical_or(terminated_R, truncated_R) + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=last_obs_RO, + act=act_RA, + policy=policy_R, + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), + ) + + # TODO: only makes sense if render_mode is human. + # Also, doubtful whether it makes sense at all for true vectorized envs if render: self.env.render() - if render > 0 and not np.isclose(render, 0): + if not np.isclose(render, 0): time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) # collect statistics - step_count += len(ready_env_ids) + num_episodes_done_this_iter = np.sum(done_R) + num_collected_episodes += num_episodes_done_this_iter + step_count += len(ready_env_ids_R) - if np.any(done): - env_ind_local = np.where(done)[0] - env_ind_global = ready_env_ids[env_ind_local] - episode_count += len(env_ind_local) - episode_lens.extend(ep_len[env_ind_local]) - episode_returns.extend(ep_rew[env_ind_local]) - episode_start_indices.extend(ep_idx[env_ind_local]) + # preparing for the next iteration + # obs_next, info and hidden_state will be modified inplace in the code below, so we copy to not affect the data in the buffer + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(hidden_state_RH) + + # Preparing last_obs_RO, last_info_R, last_hidden_state_RH for the next while-loop iteration + # Resetting envs that reached done, or removing some of them from the collection if needed (see below) + if num_episodes_done_this_iter > 0: + # TODO: adjust the whole index story, don't use np.where, just slice with boolean arrays + # D - number of envs that reached done in the rollout above + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) # now we copy obs_next to obs, but since there might be # finished episodes, we have to reset finished envs first. - self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) - for i in env_ind_local: - self._reset_state(i) - # remove surplus env id from ready_env_ids - # to avoid bias in selecting environments + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + + # Set the hidden state to zero or None for the envs that reached done + # TODO: does it have to be so complicated? 
We should have a single clear type for hidden_state instead of + # this complex logic + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) + + # preparing for the next iteration + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D + + # Handling the case when we have more ready envs than desired and are not done yet + # + # This can only happen if we are collecting a fixed number of episodes + # If we have more ready envs than there are remaining episodes to collect, + # we will remove some of them for the next rollout + # One effect of this is the following: only envs that have completed an episode + # in the last step can ever be removed from the ready envs. + # Thus, this guarantees that each env will contribute at least one episode to the + # collected data (the buffer). This effect was previous called "avoiding bias in selecting environments" + # However, it is not at all clear whether this is actually useful or necessary. + # Additional naming convention: + # S - number of surplus envs + # TODO: can the whole block be removed? If we have too many episodes, we could just strip the last ones. + # Changing R to R-S highly increases the complexity of the code. if n_episode: - surplus_env_num = len(ready_env_ids) - (n_episode - episode_count) + remaining_episodes_to_collect = n_episode - num_collected_episodes + surplus_env_num = len(ready_env_ids_R) - remaining_episodes_to_collect if surplus_env_num > 0: - mask = np.ones_like(ready_env_ids, dtype=bool) - mask[env_ind_local[:surplus_env_num]] = False - ready_env_ids = ready_env_ids[mask] - self.data = self.data[mask] + # R becomes R-S here, preparing for the next iteration in while loop + # Everything that was of length R needs to be filtered and become of length R-S. + # Note that this won't be the last iteration, as one iteration equals one + # step and we still need to collect the remaining episodes to reach the breaking condition. - self.data.obs = self.data.obs_next + # creating the mask + env_to_be_ignored_ind_local_S = env_ind_local_D[:surplus_env_num] + env_should_remain_R = np.ones_like(ready_env_ids_R, dtype=bool) + env_should_remain_R[env_to_be_ignored_ind_local_S] = False + # stripping the "idle" indices, shortening the relevant quantities from R to R-S + ready_env_ids_R = ready_env_ids_R[env_should_remain_R] + last_obs_RO = last_obs_RO[env_should_remain_R] + last_info_R = last_info_R[env_should_remain_R] + if hidden_state_RH is not None: + last_hidden_state_RH = last_hidden_state_RH[env_should_remain_R] # type: ignore[index] - if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): break # generate statistics self.collect_step += step_count - self.collect_episode += episode_count + self.collect_episode += num_collected_episodes collect_time = max(time.time() - start_time, 1e-9) self.collect_time += collect_time - if n_episode: - data = Batch( - obs={}, - act={}, - rew={}, - terminated={}, - truncated={}, - done={}, - obs_next={}, - info={}, - policy={}, - ) - self.data = cast(RolloutBatchProtocol, data) - self.reset_env() + if n_step: + # persist for future collect iterations + self._pre_collect_obs_RO = last_obs_RO + self._pre_collect_info_R = last_info_R + self._pre_collect_hidden_state_RH = last_hidden_state_RH + elif n_episode: + # reset envs and the _pre_collect fields + self.reset_env(gym_reset_kwargs) # todo still necessary? 
- return CollectStats( - n_collected_episodes=episode_count, + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, n_collected_steps=step_count, collect_time=collect_time, collect_speed=step_count / collect_time, - returns=np.array(episode_returns), - returns_stat=SequenceSummaryStats.from_sequence(episode_returns) - if len(episode_returns) > 0 - else None, - lens=np.array(episode_lens, int), - lens_stat=SequenceSummaryStats.from_sequence(episode_lens) - if len(episode_lens) > 0 - else None, ) + def _reset_hidden_state_based_on_type( + self, + env_ind_local_D: np.ndarray, + last_hidden_state_RH: np.ndarray | torch.Tensor | Batch | None, + ) -> None: + if isinstance(last_hidden_state_RH, torch.Tensor): + last_hidden_state_RH[env_ind_local_D].zero_() # type: ignore[index] + elif isinstance(last_hidden_state_RH, np.ndarray): + last_hidden_state_RH[env_ind_local_D] = ( + None if last_hidden_state_RH.dtype == object else 0 + ) + elif isinstance(last_hidden_state_RH, Batch): + last_hidden_state_RH.empty_(env_ind_local_D) + # todo is this inplace magic and just working? + class AsyncCollector(Collector): """Async Collector handles async vector environment. @@ -429,7 +608,6 @@ class AsyncCollector(Collector): policy: BasePolicy, env: BaseVectorEnv, buffer: ReplayBuffer | None = None, - preprocess_fn: Callable[..., RolloutBatchProtocol] | None = None, exploration_noise: bool = False, ) -> None: # assert env.is_async @@ -438,13 +616,48 @@ class AsyncCollector(Collector): policy, env, buffer, - preprocess_fn, exploration_noise, ) + # E denotes the number of parallel environments: self.env_num + # At init, E=R but during collection R <= E + # Keep in sync with reset! + self._ready_env_ids_R: np.ndarray = np.arange(self.env_num) + self._current_obs_in_all_envs_EO: np.ndarray | None = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E: np.ndarray | None = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH: np.ndarray | torch.Tensor | Batch | None = copy( + self._pre_collect_hidden_state_RH, + ) + self._current_action_in_all_envs_EA: np.ndarray = np.empty(self.env_num) + self._current_policy_in_all_envs_E: Batch | None = None - def reset_env(self, gym_reset_kwargs: dict[str, Any] | None = None) -> None: - super().reset_env(gym_reset_kwargs) - self._ready_env_ids = np.arange(self.env_num) + def reset( + self, + reset_buffer: bool = True, + reset_stats: bool = True, + gym_reset_kwargs: dict[str, Any] | None = None, + ) -> None: + """Reset the environment, statistics, and data needed to start the collection. + + :param reset_buffer: if true, reset the replay buffer attached + to the collector. + :param reset_stats: if true, reset the statistics attached to the collector. + :param gym_reset_kwargs: extra keyword arguments to pass into the environment's + reset function. Defaults to None (extra keyword arguments) + """ + # This sets the _pre_collect attrs + super().reset( + reset_buffer=reset_buffer, + reset_stats=reset_stats, + gym_reset_kwargs=gym_reset_kwargs, + ) + # Keep in sync with init! 
+ self._ready_env_ids_R = np.arange(self.env_num) + # E denotes the number of parallel environments self.env_num + self._current_obs_in_all_envs_EO = copy(self._pre_collect_obs_RO) + self._current_info_in_all_envs_E = copy(self._pre_collect_info_R) + self._current_hidden_state_in_all_envs_EH = copy(self._pre_collect_hidden_state_RH) + self._current_action_in_all_envs_EA = np.empty(self.env_num) + self._current_policy_in_all_envs_E = None def collect( self, @@ -453,22 +666,27 @@ class AsyncCollector(Collector): random: bool = False, render: float | None = None, no_grad: bool = True, + reset_before_collect: bool = False, gym_reset_kwargs: dict[str, Any] | None = None, ) -> CollectStats: - """Collect a specified number of step or episode with async env setting. + """Collect a specified number of steps or episodes with async env setting. - This function doesn't collect exactly n_step or n_episode number of - transitions. Instead, in order to support async setting, it may collect more - than given n_step or n_episode transitions and save into buffer. + This function does not collect an exact number of transitions specified by n_step or + n_episode. Instead, to support the asynchronous setting, it may collect more transitions + than requested by n_step or n_episode and save them into the buffer. :param n_step: how many steps you want to collect. :param n_episode: how many episodes you want to collect. - :param random: whether to use random policy for collecting data. Default + :param random: whether to use random policy_R for collecting data. Default to False. :param render: the sleep time between rendering consecutive frames. Default to None (no rendering). - :param no_grad: whether to retain gradient in policy.forward(). Default to + :param no_grad: whether to retain gradient in policy_R.forward(). Default to True (no gradient retaining). + :param reset_before_collect: whether to reset the environment before + collecting data. It has only an effect if n_episode is not None, i.e. + if one wants to collect a fixed number of episodes. + (The collector needs the initial obs and info to function properly.) :param gym_reset_kwargs: extra keyword arguments to pass into the environment's reset function. 
Defaults to None (extra keyword arguments) @@ -479,6 +697,9 @@ class AsyncCollector(Collector): :return: A dataclass object """ + use_grad = not no_grad + gym_reset_kwargs = gym_reset_kwargs or {} + # collect at least n_step or n_episode if n_step is not None: assert n_episode is None, ( @@ -494,104 +715,123 @@ class AsyncCollector(Collector): "in AsyncCollector.collect().", ) - ready_env_ids = self._ready_env_ids + if reset_before_collect: + # first we need to step all envs to be able to interact with them + if self.env.waiting_id: + self.env.step(None, id=self.env.waiting_id) + self.reset(reset_buffer=False, gym_reset_kwargs=gym_reset_kwargs) start_time = time.time() step_count = 0 - episode_count = 0 + num_collected_episodes = 0 episode_returns: list[float] = [] episode_lens: list[int] = [] episode_start_indices: list[int] = [] + ready_env_ids_R = self._ready_env_ids_R + # last_obs_RO= self._current_obs_in_all_envs_EO[ready_env_ids_R] # type: ignore[index] + # last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] # type: ignore[index] + # last_hidden_state_RH = self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] # type: ignore[index] + # last_obs_RO = self._pre_collect_obs_RO + # last_info_R = self._pre_collect_info_R + # last_hidden_state_RH = self._pre_collect_hidden_state_RH + if self._current_obs_in_all_envs_EO is None or self._current_info_in_all_envs_E is None: + raise RuntimeError( + "Current obs or info array is None, did you call reset or pass reset_at_collect=True?", + ) + + last_obs_RO = self._current_obs_in_all_envs_EO[ready_env_ids_R] + last_info_R = self._current_info_in_all_envs_E[ready_env_ids_R] + last_hidden_state_RH = _nullable_slice( + self._current_hidden_state_in_all_envs_EH, + ready_env_ids_R, + ) + # Each iteration of the AsyncCollector is only stepping a subset of the + # envs. The last observation/ hidden state of the ones not included in + # the current iteration has to be retained. while True: - whole_data = self.data - self.data = self.data[ready_env_ids] - assert len(whole_data) == self.env_num # major difference - # restore the state: if the last state is None, it won't store - last_state = self.data.policy.pop("hidden_state", None) + # todo do we need this? 
+ # todo extend to all current attributes but some could be None at init + if self._current_obs_in_all_envs_EO is None: + raise RuntimeError( + "Current obs is None, did you call reset or pass reset_at_collect=True?", + ) + if ( + not len(self._current_obs_in_all_envs_EO) + == len(self._current_action_in_all_envs_EA) + == self.env_num + ): # major difference + raise RuntimeError( + f"{len(self._current_obs_in_all_envs_EO)=} and" + f"{len(self._current_action_in_all_envs_EA)=} have to equal" + f" {self.env_num=} as it tracks the current transition" + f"in all envs", + ) # get the next action - if random: - try: - act_sample = [self._action_space[i].sample() for i in ready_env_ids] - except TypeError: # envpool's action space is not for per-env - act_sample = [self._action_space.sample() for _ in ready_env_ids] - act_sample = self.policy.map_action_inverse(act_sample) # type: ignore - self.data.update(act=act_sample) + ( + act_RA, + act_normalized_RA, + policy_R, + hidden_state_RH, + ) = self._compute_action_policy_hidden( + random=random, + ready_env_ids_R=ready_env_ids_R, + use_grad=use_grad, + last_obs_RO=last_obs_RO, + last_info_R=last_info_R, + last_hidden_state_RH=last_hidden_state_RH, + ) + + # save act_RA/policy_R/ hidden_state_RH before env.step + self._current_action_in_all_envs_EA[ready_env_ids_R] = act_RA + if self._current_policy_in_all_envs_E: + self._current_policy_in_all_envs_E[ready_env_ids_R] = policy_R else: - if no_grad: - with torch.no_grad(): # faster than retain_grad version - # self.data.obs will be used by agent to get result - result = self.policy(self.data, last_state) + self._current_policy_in_all_envs_E = policy_R # first iteration + if hidden_state_RH is not None: + if self._current_hidden_state_in_all_envs_EH is not None: + # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not + # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat + # and hope that if one of the two is a tensor, the other one is as well. 
+ self._current_hidden_state_in_all_envs_EH = cast( + np.ndarray | Batch, + self._current_hidden_state_in_all_envs_EH, + ) + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = hidden_state_RH else: - result = self.policy(self.data, last_state) - # update state / act / policy into self.data - policy = result.get("policy", Batch()) - assert isinstance(policy, Batch) - state = result.get("state", None) - if state is not None: - policy.hidden_state = state # save state into buffer - act = to_numpy(result.act) - if self.exploration_noise: - act = self.policy.exploration_noise(act, self.data) - self.data.update(policy=policy, act=act) + self._current_hidden_state_in_all_envs_EH = hidden_state_RH - # save act/policy before env.step - try: - whole_data.act[ready_env_ids] = self.data.act # type: ignore - whole_data.policy[ready_env_ids] = self.data.policy - except ValueError: - alloc_by_keys_diff(whole_data, self.data, self.env_num, False) - whole_data[ready_env_ids] = self.data # lots of overhead - - # get bounded and remapped actions first (not saved into buffer) - action_remap = self.policy.map_action(self.data.act) # step in env - obs_next, rew, terminated, truncated, info = self.env.step( - action_remap, - ready_env_ids, + obs_next_RO, rew_R, terminated_R, truncated_R, info_R = self.env.step( + act_normalized_RA, + ready_env_ids_R, ) - done = np.logical_or(terminated, truncated) - - # change self.data here because ready_env_ids has changed + done_R = np.logical_or(terminated_R, truncated_R) + # Not all environments of the AsyncCollector might have performed a step in this iteration. + # Change batch_of_envs_with_step_in_this_iteration here to reflect that ready_env_ids_R has changed. + # This means especially that R is potentially changing every iteration try: - ready_env_ids = info["env_id"] + ready_env_ids_R = cast(np.ndarray, info_R["env_id"]) + # TODO: don't use bare Exception! 
except Exception: - ready_env_ids = np.array([i["env_id"] for i in info]) - self.data = whole_data[ready_env_ids] + ready_env_ids_R = np.array([i["env_id"] for i in info_R]) - self.data.update( - obs_next=obs_next, - rew=rew, - terminated=terminated, - truncated=truncated, - info=info, + current_iteration_batch = cast( + RolloutBatchProtocol, + Batch( + obs=self._current_obs_in_all_envs_EO[ready_env_ids_R], + act=self._current_action_in_all_envs_EA[ready_env_ids_R], + policy=self._current_policy_in_all_envs_E[ready_env_ids_R], + obs_next=obs_next_RO, + rew=rew_R, + terminated=terminated_R, + truncated=truncated_R, + done=done_R, + info=info_R, + ), ) - if self.preprocess_fn: - try: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - terminated=self.data.terminated, - truncated=self.data.truncated, - info=self.data.info, - env_id=ready_env_ids, - act=self.data.act, - ), - ) - except TypeError: - self.data.update( - self.preprocess_fn( - obs_next=self.data.obs_next, - rew=self.data.rew, - done=self.data.done, - info=self.data.info, - env_id=ready_env_ids, - act=self.data.act, - ), - ) if render: self.env.render() @@ -599,60 +839,77 @@ class AsyncCollector(Collector): time.sleep(render) # add data into the buffer - ptr, ep_rew, ep_len, ep_idx = self.buffer.add(self.data, buffer_ids=ready_env_ids) + ptr_R, ep_rew_R, ep_len_R, ep_idx_R = self.buffer.add( + current_iteration_batch, + buffer_ids=ready_env_ids_R, + ) # collect statistics - step_count += len(ready_env_ids) + num_episodes_done_this_iter = np.sum(done_R) + step_count += len(ready_env_ids_R) + num_collected_episodes += num_episodes_done_this_iter - if np.any(done): - env_ind_local = np.where(done)[0] - env_ind_global = ready_env_ids[env_ind_local] - episode_count += len(env_ind_local) - episode_lens.extend(ep_len[env_ind_local]) - episode_returns.extend(ep_rew[env_ind_local]) - episode_start_indices.extend(ep_idx[env_ind_local]) - # now we copy obs_next to obs, but since there might be + # preparing for the next iteration + # todo do we need the copy stuff (tests pass also without) + # todo seem we can get rid of this last_sth stuff altogether + last_obs_RO = copy(obs_next_RO) + last_info_R = copy(info_R) + last_hidden_state_RH = copy(self._current_hidden_state_in_all_envs_EH[ready_env_ids_R]) # type: ignore[index] + + if num_episodes_done_this_iter: + env_ind_local_D = np.where(done_R)[0] + env_ind_global_D = ready_env_ids_R[env_ind_local_D] + episode_lens.extend(ep_len_R[env_ind_local_D]) + episode_returns.extend(ep_rew_R[env_ind_local_D]) + episode_start_indices.extend(ep_idx_R[env_ind_local_D]) + + # now we copy obs_next_RO to obs, but since there might be # finished episodes, we have to reset finished envs first. - self._reset_env_with_ids(env_ind_local, env_ind_global, gym_reset_kwargs) - for i in env_ind_local: - self._reset_state(i) + obs_reset_DO, info_reset_D = self.env.reset( + env_id=env_ind_global_D, + **gym_reset_kwargs, + ) + last_obs_RO[env_ind_local_D] = obs_reset_DO + last_info_R[env_ind_local_D] = info_reset_D - try: - # Need to ignore types b/c according to mypy Tensors cannot be indexed - # by arrays (which they can...) 
- whole_data.obs[ready_env_ids] = self.data.obs_next # type: ignore - whole_data.rew[ready_env_ids] = self.data.rew - whole_data.done[ready_env_ids] = self.data.done - whole_data.info[ready_env_ids] = self.data.info # type: ignore - except ValueError: - alloc_by_keys_diff(whole_data, self.data, self.env_num, False) - self.data.obs = self.data.obs_next - # lots of overhead - whole_data[ready_env_ids] = self.data - self.data = whole_data + self._reset_hidden_state_based_on_type(env_ind_local_D, last_hidden_state_RH) - if (n_step and step_count >= n_step) or (n_episode and episode_count >= n_episode): + # update based on the current transition in all envs + self._current_obs_in_all_envs_EO[ready_env_ids_R] = last_obs_RO + # this is a list, so loop over + for idx, ready_env_id in enumerate(ready_env_ids_R): + self._current_info_in_all_envs_E[ready_env_id] = last_info_R[idx] + if self._current_hidden_state_in_all_envs_EH is not None: + # Need to cast since if it's a Tensor, the assignment might in fact fail if hidden_state_RH is not + # a tensor as well. This is hard to express with proper typing, even using @overload, so we cheat + # and hope that if one of the two is a tensor, the other one is as well. + self._current_hidden_state_in_all_envs_EH = cast( + np.ndarray | Batch, + self._current_hidden_state_in_all_envs_EH, + ) + self._current_hidden_state_in_all_envs_EH[ready_env_ids_R] = last_hidden_state_RH + else: + self._current_hidden_state_in_all_envs_EH = last_hidden_state_RH + + if (n_step and step_count >= n_step) or ( + n_episode and num_collected_episodes >= n_episode + ): break - self._ready_env_ids = ready_env_ids - # generate statistics self.collect_step += step_count - self.collect_episode += episode_count + self.collect_episode += num_collected_episodes collect_time = max(time.time() - start_time, 1e-9) self.collect_time += collect_time - return CollectStats( - n_collected_episodes=episode_count, + # persist for future collect iterations + self._ready_env_ids_R = ready_env_ids_R + + return CollectStats.with_autogenerated_stats( + returns=np.array(episode_returns), + lens=np.array(episode_lens), + n_collected_episodes=num_collected_episodes, n_collected_steps=step_count, collect_time=collect_time, collect_speed=step_count / collect_time, - returns=np.array(episode_returns), - returns_stat=SequenceSummaryStats.from_sequence(episode_returns) - if len(episode_returns) > 0 - else None, - lens=np.array(episode_lens, int), - lens_stat=SequenceSummaryStats.from_sequence(episode_lens) - if len(episode_lens) > 0 - else None, ) diff --git a/tianshou/data/utils/converter.py b/tianshou/data/utils/converter.py index 205c2d5..2df462d 100644 --- a/tianshou/data/utils/converter.py +++ b/tianshou/data/utils/converter.py @@ -12,6 +12,7 @@ from tianshou.data.batch import Batch, _parse_value # TODO: confusing name, could actually return a batch... 
# Overrides and generic types should be added +# todo check for ActBatchProtocol @no_type_check def to_numpy(x: Any) -> Batch | np.ndarray: """Return an object without torch.Tensor.""" diff --git a/tianshou/env/venv_wrappers.py b/tianshou/env/venv_wrappers.py index 4e30e71..9297dde 100644 --- a/tianshou/env/venv_wrappers.py +++ b/tianshou/env/venv_wrappers.py @@ -44,14 +44,14 @@ class VectorEnvWrapper(BaseVectorEnv): def reset( self, - id: int | list[int] | np.ndarray | None = None, + env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, - ) -> tuple[np.ndarray, dict | list[dict]]: - return self.venv.reset(id, **kwargs) + ) -> tuple[np.ndarray, np.ndarray]: + return self.venv.reset(env_id, **kwargs) def step( self, - action: np.ndarray | torch.Tensor, + action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: return self.venv.step(action, id) @@ -80,10 +80,10 @@ class VectorEnvNormObs(VectorEnvWrapper): def reset( self, - id: int | list[int] | np.ndarray | None = None, + env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, - ) -> tuple[np.ndarray, dict | list[dict]]: - obs, info = self.venv.reset(id, **kwargs) + ) -> tuple[np.ndarray, np.ndarray]: + obs, info = self.venv.reset(env_id, **kwargs) if isinstance(obs, tuple): # type: ignore raise TypeError( @@ -98,7 +98,7 @@ class VectorEnvNormObs(VectorEnvWrapper): def step( self, - action: np.ndarray | torch.Tensor, + action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: step_results = self.venv.step(action, id) diff --git a/tianshou/env/venvs.py b/tianshou/env/venvs.py index dfcd12e..e9309f9 100644 --- a/tianshou/env/venvs.py +++ b/tianshou/env/venvs.py @@ -190,11 +190,13 @@ class BaseVectorEnv: ), f"Cannot interact with environment {i} which is stepping now." assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}." + # TODO: for now, has to be kept in sync with reset in EnvPoolMixin + # In particular, can't rename env_id to env_ids def reset( self, - id: int | list[int] | np.ndarray | None = None, + env_id: int | list[int] | np.ndarray | None = None, **kwargs: Any, - ) -> tuple[np.ndarray, dict | list[dict]]: + ) -> tuple[np.ndarray, np.ndarray]: """Reset the state of some envs and return initial observations. If id is None, reset the state of all the environments and return @@ -202,14 +204,14 @@ class BaseVectorEnv: the given id, either an int or a list. """ self._assert_is_not_closed() - id = self._wrap_id(id) + env_id = self._wrap_id(env_id) if self.is_async: - self._assert_id(id) + self._assert_id(env_id) # send(None) == reset() in worker - for i in id: - self.workers[i].send(None, **kwargs) - ret_list = [self.workers[i].recv() for i in id] + for id in env_id: + self.workers[id].send(None, **kwargs) + ret_list = [self.workers[id].recv() for id in env_id] assert ( isinstance(ret_list[0], tuple | list) @@ -229,12 +231,12 @@ class BaseVectorEnv: except ValueError: # different len(obs) obs = np.array(obs_list, dtype=object) - infos = [r[1] for r in ret_list] - return obs, infos # type: ignore + infos = np.array([r[1] for r in ret_list]) + return obs, infos def step( self, - action: np.ndarray | torch.Tensor, + action: np.ndarray | torch.Tensor | None, id: int | list[int] | np.ndarray | None = None, ) -> gym_new_venv_step_type: """Run one timestep of some environments' dynamics. @@ -248,6 +250,8 @@ class BaseVectorEnv: batch_done, batch_info) in numpy format. 
:param numpy.ndarray action: a batch of action provided by the agent. + If the venv is async, the action can be None, which will result + in all arrays in the returned tuple being empty. :return: A tuple consisting of either: @@ -271,6 +275,8 @@ class BaseVectorEnv: self._assert_is_not_closed() id = self._wrap_id(id) if not self.is_async: + if action is None: + raise ValueError("action must be not-None for non-async") assert len(action) == len(id) for i, j in enumerate(id): self.workers[j].send(action[i]) diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 39d367e..f71a7f9 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -93,7 +93,14 @@ class AgentFactory(ABC, ToStringMixin): self, policy: BasePolicy, envs: Environments, + reset_collectors: bool = True, ) -> tuple[Collector, Collector]: + """:param policy: + :param envs: + :param reset_collectors: Whether to reset the collectors before returning them. + Setting to True means that the envs will be reset as well. + :return: + """ buffer_size = self.sampling_config.buffer_size train_envs = envs.train_envs buffer: ReplayBuffer @@ -114,6 +121,10 @@ class AgentFactory(ABC, ToStringMixin): ) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) test_collector = Collector(policy, envs.test_envs) + if reset_collectors: + train_collector.reset() + test_collector.reset() + if self.sampling_config.start_timesteps > 0: log.info( f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index 5b7d388..17f0550 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -311,7 +311,7 @@ class Experiment(ToStringMixin): ) -> None: policy.eval() collector = Collector(policy, env) - result = collector.collect(n_episode=num_episodes, render=render) + result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True) assert result.returns_stat is not None # for mypy assert result.lens_stat is not None # for mypy log.info( diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 087e3cb..7df7ebd 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -18,6 +18,7 @@ from tianshou.data.batch import Batch, BatchProtocol, arr_type from tianshou.data.buffer.base import TBuffer from tianshou.data.types import ( ActBatchProtocol, + ActStateBatchProtocol, BatchWithReturnsProtocol, ObsBatchProtocol, RolloutBatchProtocol, @@ -233,11 +234,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): # have a method to add noise to action. # So we add the default behavior here. It's a little messy, maybe one can # find a better way to do this. + + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: """Modify the action from policy.forward with exploration noise. NOTE: currently does not add any noise! 
Needs to be overridden by subclasses @@ -287,7 +291,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): batch: ObsBatchProtocol, state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, - ) -> ActBatchProtocol: + ) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing """Compute action over the given batch data. :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: diff --git a/tianshou/policy/modelbased/icm.py b/tianshou/policy/modelbased/icm.py index 54d560b..9a603b7 100644 --- a/tianshou/policy/modelbased/icm.py +++ b/tianshou/policy/modelbased/icm.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Self +from typing import Any, Literal, Self, TypeVar import gymnasium as gym import numpy as np @@ -105,11 +105,13 @@ class ICMPolicy(BasePolicy[ICMTrainingStats]): """ return self.policy.forward(batch, state, **kwargs) + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: return self.policy.exploration_noise(act, batch) def set_eps(self, eps: float) -> None: diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index a91ea00..ba37477 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -8,6 +8,7 @@ import torch from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( + ActBatchProtocol, BatchWithReturnsProtocol, ModelOutputBatchProtocol, ObsBatchProtocol, @@ -182,11 +183,13 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index 1b371d4..b54860e 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -10,6 +10,7 @@ import torch from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( + ActBatchProtocol, ActStateBatchProtocol, BatchWithReturnsProtocol, ObsBatchProtocol, @@ -208,11 +209,13 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: if self._exploration_noise is None: return act if isinstance(act, np.ndarray): diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index b271cbd..d1054f9 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ 
b/tianshou/policy/modelfree/discrete_sac.py @@ -8,8 +8,7 @@ from overrides import override from torch.distributions import Categorical from tianshou.data import Batch, ReplayBuffer, to_torch -from tianshou.data.batch import BatchProtocol -from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.sac import SACTrainingStats @@ -184,9 +183,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), ) + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: return act diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index 5b90510..ad5f7dd 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -9,6 +9,7 @@ import torch from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.data.batch import BatchProtocol from tianshou.data.types import ( + ActBatchProtocol, BatchWithReturnsProtocol, ModelOutputBatchProtocol, ObsBatchProtocol, @@ -232,11 +233,13 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): bsz = len(act) rand_mask = np.random.rand(bsz) < self.eps diff --git a/tianshou/policy/multiagent/mapolicy.py b/tianshou/policy/multiagent/mapolicy.py index e41e069..7a7be58 100644 --- a/tianshou/policy/multiagent/mapolicy.py +++ b/tianshou/policy/multiagent/mapolicy.py @@ -1,11 +1,11 @@ -from typing import Any, Literal, Protocol, Self, cast, overload +from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload import numpy as np from overrides import override from tianshou.data import Batch, ReplayBuffer from tianshou.data.batch import BatchProtocol, IndexType -from tianshou.data.types import RolloutBatchProtocol +from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats @@ -160,16 +160,18 @@ class MultiAgentPolicyManager(BasePolicy): buffer._meta.rew = save_rew return Batch(results) + _TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol") + def exploration_noise( self, - act: np.ndarray | BatchProtocol, - batch: RolloutBatchProtocol, - ) -> np.ndarray | BatchProtocol: + act: _TArrOrActBatch, + batch: ObsBatchProtocol, + ) -> _TArrOrActBatch: """Add exploration noise from sub-policy onto act.""" - assert isinstance( - batch.obs, - BatchProtocol, - ), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" + if not isinstance(batch.obs, Batch): + raise TypeError( + f"here only observations of type Batch are permitted, but got 
{type(batch.obs)}", + ) for agent_id, policy in self.policies.items(): agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] if len(agent_index) == 0: @@ -223,7 +225,7 @@ class MultiAgentPolicyManager(BasePolicy): results.append((False, np.array([-1]), Batch(), Batch(), Batch())) continue tmp_batch = batch[agent_index] - if isinstance(tmp_batch.rew, np.ndarray): + if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray): # reward can be empty Batch (after initial reset) or nparray. tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] if not hasattr(tmp_batch.obs, "mask"): diff --git a/tianshou/trainer/base.py b/tianshou/trainer/base.py index c6d87d1..675112f 100644 --- a/tianshou/trainer/base.py +++ b/tianshou/trainer/base.py @@ -237,7 +237,13 @@ class BaseTrainer(ABC): self.stop_fn_flag = False self.iter_num = 0 - def reset(self) -> None: + def _reset_collectors(self, reset_buffer: bool = False) -> None: + if self.train_collector is not None: + self.train_collector.reset(reset_buffer=reset_buffer) + if self.test_collector is not None: + self.test_collector.reset(reset_buffer=reset_buffer) + + def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None: """Initialize or reset the instance to yield a new iterator from zero.""" self.is_run = False self.env_step = 0 @@ -250,16 +256,18 @@ class BaseTrainer(ABC): self.last_rew, self.last_len = 0.0, 0.0 self.start_time = time.time() - if self.train_collector is not None: - self.train_collector.reset_stat() - if self.train_collector.policy != self.policy or self.test_collector is None: - self.test_in_train = False + if reset_collectors: + self._reset_collectors(reset_buffer=reset_buffer) + + if self.train_collector is not None and ( + self.train_collector.policy != self.policy or self.test_collector is None + ): + self.test_in_train = False if self.test_collector is not None: assert self.episode_per_test is not None assert not isinstance(self.test_collector, AsyncCollector) # Issue 700 - self.test_collector.reset_stat() test_result = test_episode( self.policy, self.test_collector, @@ -284,7 +292,7 @@ class BaseTrainer(ABC): self.iter_num = 0 def __iter__(self): # type: ignore - self.reset() + self.reset(reset_collectors=True, reset_buffer=False) return self def __next__(self) -> EpochStats: @@ -308,8 +316,8 @@ class BaseTrainer(ABC): # perform n step_per_epoch with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: + train_stat: CollectStatsBase while t.n < t.total and not self.stop_fn_flag: - train_stat: CollectStatsBase if self.train_collector is not None: train_stat, self.stop_fn_flag = self.train_step() pbar_data_dict = { @@ -515,12 +523,14 @@ class BaseTrainer(ABC): stats of the whole dataset """ - def run(self) -> InfoStats: + def run(self, reset_prior_to_run: bool = True) -> InfoStats: """Consume iterator. See itertools - recipes. Use functions that consume iterators at C speed (feed the entire iterator into a zero-length deque). 
""" + if reset_prior_to_run: + self.reset() try: self.is_run = True deque(self, maxlen=0) # feed the entire iterator into a zero-length deque diff --git a/tianshou/trainer/utils.py b/tianshou/trainer/utils.py index 300c7c4..7a96ea0 100644 --- a/tianshou/trainer/utils.py +++ b/tianshou/trainer/utils.py @@ -26,8 +26,7 @@ def test_episode( reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, ) -> CollectStats: """A simple wrapper of testing policy in collector.""" - collector.reset_env() - collector.reset_buffer() + collector.reset(reset_stats=False) policy.eval() if test_fn: test_fn(epoch, global_step) From ecb272c61b2d40a4064c58ec616d4963a84f45e4 Mon Sep 17 00:00:00 2001 From: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com> Date: Thu, 28 Mar 2024 18:06:00 +0100 Subject: [PATCH 2/5] Update CHANGELOG.md [skip ci] --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a37acb..dd50651 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ expicitly or pass `reset_before_collect=True` . #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 +### Tests +- Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081 + Started after v1.0.0 From 23a33a9aa37dd358c25b6c044313f7a5f101f487 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 28 Mar 2024 18:13:15 +0100 Subject: [PATCH 3/5] Removed link to Chinese docs [skip ci] --- docs/04_contributing/04_contributing.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/04_contributing/04_contributing.rst b/docs/04_contributing/04_contributing.rst index b24acbf..01764fc 100644 --- a/docs/04_contributing/04_contributing.rst +++ b/docs/04_contributing/04_contributing.rst @@ -92,8 +92,6 @@ To compile documentation into webpage, run The generated webpage is in ``docs/_build`` and can be viewed with browser (http://0.0.0.0:8000/). -Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/. - Documentation Generation Test ----------------------------- From 5bf923c9bd2a08746dd599010f92e1365bcfb844 Mon Sep 17 00:00:00 2001 From: Michael Panchenko Date: Thu, 28 Mar 2024 18:17:25 +0100 Subject: [PATCH 4/5] Removed more references to Chinese docs [skip ci] --- README.md | 3 --- docs/index.rst | 3 --- 2 files changed, 6 deletions(-) diff --git a/README.md b/README.md index 2ddefd5..0238d9e 100644 --- a/README.md +++ b/README.md @@ -147,9 +147,6 @@ Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https:// Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders. -中文文档位于 [https://tianshou.readthedocs.io/zh/master/](https://tianshou.readthedocs.io/zh/master/)。 - - ## Why Tianshou? 
diff --git a/docs/index.rst b/docs/index.rst index c41c7e6..c7c2177 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -57,9 +57,6 @@ Here is Tianshou's other features: * Support multi-GPU training :ref:`multi_gpu` * Comprehensive `unit tests `_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking -中文文档位于 `https://tianshou.readthedocs.io/zh/master/ `_ - - Installation ------------ From bf0d6321084a7ca6f765ab2564630e5c564fdcc3 Mon Sep 17 00:00:00 2001 From: Erni <38285979+arnaujc91@users.noreply.github.com> Date: Mon, 1 Apr 2024 17:14:17 +0200 Subject: [PATCH 5/5] Naming and typing improvements in Actor/Critic/Policy forwards (#1032) Closes #917 ### Internal Improvements - Better variable names related to model outputs (logits, dist input etc.). #1032 - Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc., instead of just `nn.Module`. #1032 - Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032 - Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032 - Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032 ### Breaking Changes - Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both continuous and discrete cases. #1032 --------- Co-authored-by: Arnau Jimenez Co-authored-by: Michael Panchenko --- CHANGELOG.md | 8 ++ docs/02_notebooks/L4_Policy.ipynb | 4 +- docs/spelling_wordlist.txt | 5 + examples/inverse/irl_gail.py | 5 +- examples/mujoco/mujoco_a2c.py | 5 +- examples/mujoco/mujoco_npg.py | 5 +- examples/mujoco/mujoco_ppo.py | 5 +- examples/mujoco/mujoco_reinforce.py | 5 +- examples/mujoco/mujoco_trpo.py | 5 +- test/base/test_policy.py | 8 +- test/continuous/test_npg.py | 5 +- test/continuous/test_ppo.py | 5 +- test/continuous/test_trpo.py | 5 +- test/offline/test_gail.py | 5 +- test/pettingzoo/pistonball_continuous.py | 5 +- tianshou/highlevel/params/dist_fn.py | 23 +++-- tianshou/highlevel/params/policy_params.py | 4 +- tianshou/policy/base.py | 9 +- tianshou/policy/imitation/base.py | 23 ++++- tianshou/policy/imitation/discrete_bcq.py | 3 +- tianshou/policy/imitation/discrete_cql.py | 3 +- tianshou/policy/imitation/discrete_crr.py | 10 +- tianshou/policy/imitation/gail.py | 15 ++- tianshou/policy/imitation/td3_bc.py | 2 +- tianshou/policy/modelfree/a2c.py | 15 ++- tianshou/policy/modelfree/bdq.py | 10 +- tianshou/policy/modelfree/c51.py | 3 +- tianshou/policy/modelfree/ddpg.py | 8 +- tianshou/policy/modelfree/discrete_sac.py | 20 ++-- tianshou/policy/modelfree/dqn.py | 14 +-- tianshou/policy/modelfree/fqf.py | 3 +- tianshou/policy/modelfree/iqn.py | 3 +- tianshou/policy/modelfree/npg.py | 15 ++- tianshou/policy/modelfree/pg.py | 50 +++++---- tianshou/policy/modelfree/ppo.py | 15 ++- tianshou/policy/modelfree/qrdqn.py | 3 +- tianshou/policy/modelfree/redq.py | 24 +++-- tianshou/policy/modelfree/sac.py | 23 ++--- tianshou/policy/modelfree/td3.py | 2 +- tianshou/policy/modelfree/trpo.py | 15 ++- tianshou/utils/net/common.py | 11 ++ tianshou/utils/net/continuous.py | 114 ++++++++++----------- tianshou/utils/net/discrete.py | 67 ++++++------ 43 files changed, 342 insertions(+), 245 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd50651..fb78698 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,12 @@ - Introduced a first iteration of a naming convention for vars in `Collector`s. 
#1063 - Generally improved readability of Collector code and associated tests (still quite some way to go). #1063 - Improved typing for `exploration_noise` and within Collector. #1063 +- Better variable names related to model outputs (logits, dist input etc.). #1032 +- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc., +instead of just `nn.Module`. #1032 +- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032 +- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032 +- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032 ### Breaking Changes @@ -21,6 +27,8 @@ expicitly or pass `reset_before_collect=True` . #1063 - VectorEnvs now return an array of info-dicts on reset instead of a list. #1063 - Fixed `iter(Batch(...)` which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063 +- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both +continuous and discrete cases. #1032 ### Tests - Fixed env seeding it test_sac_with_il.py so that the test doesn't fail randomly. #1081 diff --git a/docs/02_notebooks/L4_Policy.ipynb b/docs/02_notebooks/L4_Policy.ipynb index 37a1f93..00f7f27 100644 --- a/docs/02_notebooks/L4_Policy.ipynb +++ b/docs/02_notebooks/L4_Policy.ipynb @@ -69,7 +69,7 @@ "from tianshou.policy import BasePolicy\n", "from tianshou.policy.modelfree.pg import (\n", " PGTrainingStats,\n", - " TDistributionFunction,\n", + " TDistFnDiscrOrCont,\n", " TPGTrainingStats,\n", ")\n", "from tianshou.utils import RunningMeanStd\n", @@ -339,7 +339,7 @@ " *,\n", " actor: torch.nn.Module,\n", " optim: torch.optim.Optimizer,\n", - " dist_fn: TDistributionFunction,\n", + " dist_fn: TDistFnDiscrOrCont,\n", " action_space: gym.Space,\n", " discount_factor: float = 0.99,\n", " observation_space: gym.Space | None = None,\n", diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 63ee791..be730ff 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -257,3 +257,8 @@ macOS joblib master Panchenko +BA +BH +BO +BD + diff --git a/examples/inverse/irl_gail.py b/examples/inverse/irl_gail.py index 705acaa..3ee3709 100644 --- a/examples/inverse/irl_gail.py +++ b/examples/inverse/irl_gail.py @@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) # expert replay buffer dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) diff --git a/examples/mujoco/mujoco_a2c.py b/examples/mujoco/mujoco_a2c.py index 95b645d..6caac98 100755 --- a/examples/mujoco/mujoco_a2c.py +++ b/examples/mujoco/mujoco_a2c.py @@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: A2CPolicy = A2CPolicy( actor=actor, diff --git 
a/examples/mujoco/mujoco_npg.py b/examples/mujoco/mujoco_npg.py index 454565a..e8ee97c 100755 --- a/examples/mujoco/mujoco_npg.py +++ b/examples/mujoco/mujoco_npg.py @@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: NPGPolicy = NPGPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_ppo.py b/examples/mujoco/mujoco_ppo.py index c0d868c..218b95d 100755 --- a/examples/mujoco/mujoco_ppo.py +++ b/examples/mujoco/mujoco_ppo.py @@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: PPOPolicy = PPOPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_reinforce.py b/examples/mujoco/mujoco_reinforce.py index f4a8693..06e2bc1 100755 --- a/examples/mujoco/mujoco_reinforce.py +++ b/examples/mujoco/mujoco_reinforce.py @@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: PGPolicy = PGPolicy( actor=actor, diff --git a/examples/mujoco/mujoco_trpo.py b/examples/mujoco/mujoco_trpo.py index b001fd0..c17ba6c 100755 --- a/examples/mujoco/mujoco_trpo.py +++ b/examples/mujoco/mujoco_trpo.py @@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: TRPOPolicy = TRPOPolicy( actor=actor, diff --git a/test/base/test_policy.py b/test/base/test_policy.py index 9fe6f8c..0c51f84 100644 --- a/test/base/test_policy.py +++ b/test/base/test_policy.py @@ -2,7 +2,7 @@ import gymnasium as gym import numpy as np import pytest import torch -from torch.distributions import Categorical, Independent, Normal +from torch.distributions import Categorical, Distribution, Independent, Normal from tianshou.policy import PPOPolicy from tianshou.utils.net.common import ActorCritic, Net @@ -25,7 +25,11 @@ def policy(request): Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), action_shape=action_space.shape, ) - dist_fn = lambda *logits: Independent(Normal(*logits), 1) + + def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) + elif action_type == "discrete": action_space = gym.spaces.Discrete(3) actor = Actor( diff --git a/test/continuous/test_npg.py b/test/continuous/test_npg.py index bcfe6b0..8e0a50d 100644 --- 
a/test/continuous/test_npg.py +++ b/test/continuous/test_npg.py @@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: NPGPolicy[NPGTrainingStats] = NPGPolicy( actor=actor, diff --git a/test/continuous/test_ppo.py b/test/continuous/test_ppo.py index d092bc6..38ddbe8 100644 --- a/test/continuous/test_ppo.py +++ b/test/continuous/test_ppo.py @@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( actor=actor, diff --git a/test/continuous/test_trpo.py b/test/continuous/test_trpo.py index 8070612..9de8128 100644 --- a/test/continuous/test_trpo.py +++ b/test/continuous/test_trpo.py @@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: BasePolicy = TRPOPolicy( actor=actor, diff --git a/test/offline/test_gail.py b/test/offline/test_gail.py index 37eb335..68fab72 100644 --- a/test/offline/test_gail.py +++ b/test/offline/test_gail.py @@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None: # replace DiagGuassian with Independent(Normal) which is equivalent # pass *logits to be consistent with policy.forward - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) policy: BasePolicy = GAILPolicy( actor=actor, diff --git a/test/pettingzoo/pistonball_continuous.py b/test/pettingzoo/pistonball_continuous.py index 0897d73..14b5aac 100644 --- a/test/pettingzoo/pistonball_continuous.py +++ b/test/pettingzoo/pistonball_continuous.py @@ -181,8 +181,9 @@ def get_agents( torch.nn.init.zeros_(m.bias) optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr) - def dist(*logits: torch.Tensor) -> Distribution: - return Independent(Normal(*logits), 1) + def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution: + loc, scale = loc_scale + return Independent(Normal(loc, scale), 1) agent: PPOPolicy = PPOPolicy( actor, diff --git a/tianshou/highlevel/params/dist_fn.py b/tianshou/highlevel/params/dist_fn.py index 9e9c266..c8d2aca 100644 --- a/tianshou/highlevel/params/dist_fn.py +++ b/tianshou/highlevel/params/dist_fn.py @@ -1,40 +1,47 @@ from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any import torch from tianshou.highlevel.env import 
Environments, EnvType -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont from tianshou.utils.string import ToStringMixin class DistributionFunctionFactory(ToStringMixin, ABC): + # True return type defined in subclasses @abstractmethod - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn( + self, + envs: Environments, + ) -> Callable[[Any], torch.distributions.Distribution]: pass class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete: envs.get_type().assert_discrete(self) return self._dist_fn @staticmethod - def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution: + def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical: return torch.distributions.Categorical(logits=p) class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory): - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: envs.get_type().assert_continuous(self) return self._dist_fn @staticmethod - def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution: - return torch.distributions.Independent(torch.distributions.Normal(*p), 1) + def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution: + loc, scale = loc_scale + return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1) class DistributionFunctionFactoryDefault(DistributionFunctionFactory): - def create_dist_fn(self, envs: Environments) -> TDistributionFunction: + def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont: match envs.get_type(): case EnvType.DISCRETE: return DistributionFunctionFactoryCategorical().create_dist_fn(envs) diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 373a413..24674bc 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -19,7 +19,7 @@ from tianshou.highlevel.params.dist_fn import ( from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.noise import NoiseFactory -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils import MultipleLRSchedulers from tianshou.utils.string import ToStringMixin @@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. Does not affect training. """ - dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default" + dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default" """ This can either be a function which maps the model output to a torch distribution or a factory for the creation of such a function. 
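For reference, a minimal runnable sketch of the single-argument `dist_fn` interface used in the hunks above: the continuous variant unpacks a `(loc, scale)` tuple, the discrete variant receives the logits tensor directly (toy shapes, assuming only `torch` is installed):

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal


# Continuous case: dist_fn receives a single (loc, scale) tuple instead of *logits.
def dist_continuous(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)


# Discrete case: dist_fn receives the logits tensor directly.
def dist_discrete(logits: torch.Tensor) -> Categorical:
    return Categorical(logits=logits)


loc, scale = torch.zeros(4, 2), torch.ones(4, 2)
print(dist_continuous((loc, scale)).sample().shape)     # torch.Size([4, 2])
print(dist_discrete(torch.randn(4, 3)).sample().shape)  # torch.Size([4])
```
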
diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 7df7ebd..77602a0 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -213,10 +213,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): super().__init__() self.observation_space = observation_space self.action_space = action_space + self._action_type: Literal["discrete", "continuous"] if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary): - self.action_type = "discrete" + self._action_type = "discrete" elif isinstance(action_space, Box): - self.action_type = "continuous" + self._action_type = "continuous" else: raise ValueError(f"Unsupported action space: {action_space}.") self.agent_id = 0 @@ -226,6 +227,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): self.lr_scheduler = lr_scheduler self._compile() + @property + def action_type(self) -> Literal["discrete", "continuous"]: + return self._action_type + def set_agent_id(self, agent_id: int) -> None: """Set self.agent_id = agent_id, for MARL.""" self.agent_id = agent_id diff --git a/tianshou/policy/imitation/base.py b/tianshou/policy/imitation/base.py index 1daa9ae..6e21016 100644 --- a/tianshou/policy/imitation/base.py +++ b/tianshou/policy/imitation/base.py @@ -16,6 +16,12 @@ from tianshou.data.types import ( from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats +# Dimension Naming Convention +# B - Batch Size +# A - Action +# D - Dist input (usually 2, loc and scale) +# H - Dimension of hidden, can be None + @dataclass(kw_only=True) class ImitationTrainingStats(TrainingStats): @@ -72,9 +78,20 @@ class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTra state: dict | BatchProtocol | np.ndarray | None = None, **kwargs: Any, ) -> ModelOutputBatchProtocol: - logits, hidden = self.actor(batch.obs, state=state, info=batch.info) - act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits - result = Batch(logits=logits, act=act, state=hidden) + # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced + if self.action_type == "discrete": + # If it's discrete, the "actor" is usually a critic that maps obs to action_values + # which then could be turned into logits or a Categorigal + action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + act_B = action_values_BA.argmax(dim=1) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) + elif self.action_type == "continuous": + # If it's continuous, the actor would usually deliver something like loc, scale determining a + # Gaussian dist + dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH) + else: + raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!") return cast(ModelOutputBatchProtocol, result) def learn( diff --git a/tianshou/policy/imitation/discrete_bcq.py b/tianshou/policy/imitation/discrete_bcq.py index 8412e0a..b5258c1 100644 --- a/tianshou/policy/imitation/discrete_bcq.py +++ b/tianshou/policy/imitation/discrete_bcq.py @@ -34,8 +34,7 @@ TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteB class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): """Implementation of discrete BCQ algorithm. arXiv:1910.01708. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. 
(s -> q_value) + :param model: a model following the rules (s_B -> action_values_BA) :param imitator: a model following the rules in :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) :param optim: a torch.optim for optimizing the model. diff --git a/tianshou/policy/imitation/discrete_cql.py b/tianshou/policy/imitation/discrete_cql.py index dc23cb7..b63f83e 100644 --- a/tianshou/policy/imitation/discrete_cql.py +++ b/tianshou/policy/imitation/discrete_cql.py @@ -25,8 +25,7 @@ TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteC class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param action_space: Env's action space. :param min_q_weight: the weight for the cql loss. diff --git a/tianshou/policy/imitation/discrete_crr.py b/tianshou/policy/imitation/discrete_crr.py index 9a3c2db..9c54129 100644 --- a/tianshou/policy/imitation/discrete_crr.py +++ b/tianshou/policy/imitation/discrete_crr.py @@ -11,6 +11,7 @@ from tianshou.data import to_torch, to_torch_as from tianshou.data.types import RolloutBatchProtocol from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats +from tianshou.utils.net.discrete import Actor, Critic @dataclass @@ -26,8 +27,9 @@ TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteC class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the action-value critic (i.e., Q function) network. (s -> Q(s, \*)) :param optim: a torch.optim for optimizing the model. 
@@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | Actor, + critic: torch.nn.Module | Critic, optim: torch.optim.Optimizer, action_space: gym.spaces.Discrete, discount_factor: float = 0.99, diff --git a/tianshou/policy/imitation/gail.py b/tianshou/policy/imitation/gail.py index c98f7af..524f040 100644 --- a/tianshou/policy/imitation/gail.py +++ b/tianshou/policy/imitation/gail.py @@ -15,8 +15,11 @@ from tianshou.data import ( from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import PPOPolicy from tianshou.policy.base import TLearningRateScheduler -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.policy.modelfree.ppo import PPOTrainingStats +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -32,7 +35,9 @@ TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats) class GAILPolicy(PPOPolicy[TGailTrainingStats]): r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, expert_buffer: ReplayBuffer, disc_net: torch.nn.Module, diff --git a/tianshou/policy/imitation/td3_bc.py b/tianshou/policy/imitation/td3_bc.py index 7ef700b..f4b2bfe 100644 --- a/tianshou/policy/imitation/td3_bc.py +++ b/tianshou/policy/imitation/td3_bc.py @@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]): """Implementation of TD3+BC. arXiv:2106.06860. :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :class:`~tianshou.policy.BasePolicy`. (s -> actions) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. 
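The actor contracts repeated in the docstrings above follow the dimension-suffix convention introduced in this patch (B = batch, A = action, D = dist input, H = hidden). A toy illustration of the two contracts, using made-up shapes rather than Tianshou classes:

```python
import torch

batch_size, num_actions = 4, 3


def toy_discrete_actor(obs_B: torch.Tensor) -> tuple[torch.Tensor, None]:
    # s_B -> action_values_BA: one value per action, later argmax'ed or fed to a Categorical
    action_values_BA = torch.randn(obs_B.shape[0], num_actions)
    return action_values_BA, None  # hidden state is None for feed-forward nets


def toy_continuous_actor(obs_B: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor], None]:
    # s_B -> dist_input_BD: here loc and scale for an Independent Normal
    loc_BA = torch.zeros(obs_B.shape[0], num_actions)
    scale_BA = torch.ones(obs_B.shape[0], num_actions)
    return (loc_BA, scale_BA), None


obs_B = torch.randn(batch_size, 8)
print(toy_discrete_actor(obs_B)[0].shape)       # torch.Size([4, 3])
print(toy_continuous_actor(obs_B)[0][0].shape)  # torch.Size([4, 3])
```
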
diff --git a/tianshou/policy/modelfree/a2c.py b/tianshou/policy/modelfree/a2c.py index 2aad187..d41ccb4 100644 --- a/tianshou/policy/modelfree/a2c.py +++ b/tianshou/policy/modelfree/a2c.py @@ -11,8 +11,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import PGPolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -30,7 +33,9 @@ TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats) class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, vf_coef: float = 0.5, ent_coef: float = 0.01, diff --git a/tianshou/policy/modelfree/bdq.py b/tianshou/policy/modelfree/bdq.py index ba37477..d7196a9 100644 --- a/tianshou/policy/modelfree/bdq.py +++ b/tianshou/policy/modelfree/bdq.py @@ -31,7 +31,7 @@ TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats) class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): """Implementation of the Branching dual Q network arXiv:1711.08946. - :param model: BranchingNet mapping (obs, state, info) -> logits. + :param model: BranchingNet mapping (obs, state, info) -> action_values_BA. :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. :param estimation_step: the number of steps to look ahead. 
@@ -156,10 +156,10 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): model = getattr(self, model) obs = batch.obs # TODO: this is very contrived, see also iqn.py - obs_next = obs.obs if hasattr(obs, "obs") else obs - logits, hidden = model(obs_next, state=state, info=batch.info) - act = to_numpy(logits.max(dim=-1)[1]) - result = Batch(logits=logits, act=act, state=hidden) + obs_next_BO = obs.obs if hasattr(obs, "obs") else obs + action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info) + act_B = to_numpy(action_values_BA.argmax(dim=-1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: diff --git a/tianshou/policy/modelfree/c51.py b/tianshou/policy/modelfree/c51.py index bd44914..5bfdba0 100644 --- a/tianshou/policy/modelfree/c51.py +++ b/tianshou/policy/modelfree/c51.py @@ -23,8 +23,7 @@ TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats) class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): """Implementation of Categorical Deep Q-Network. arXiv:1707.06887. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. :param num_atoms: the number of atoms in the support set of the diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py index b54860e..f21744f 100644 --- a/tianshou/policy/modelfree/ddpg.py +++ b/tianshou/policy/modelfree/ddpg.py @@ -19,6 +19,7 @@ from tianshou.data.types import ( from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.net.continuous import Actor, Critic @dataclass(kw_only=True) @@ -33,8 +34,7 @@ TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats) class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. - :param actor: The actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> model_output) + :param actor: The actor network following the rules (s -> actions) :param actor_optim: The optimizer for actor network. :param critic: The critic network. (s, a -> Q(s, a)) :param critic_optim: The optimizer for critic network. 
@@ -60,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | Actor, actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, + critic: torch.nn.Module | Critic, critic_optim: torch.optim.Optimizer, action_space: gym.Space, tau: float = 0.005, diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py index d1054f9..e9f9b3b 100644 --- a/tianshou/policy/modelfree/discrete_sac.py +++ b/tianshou/policy/modelfree/discrete_sac.py @@ -12,6 +12,7 @@ from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatch from tianshou.policy import SACPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.sac import SACTrainingStats +from tianshou.utils.net.discrete import Actor, Critic @dataclass @@ -25,8 +26,7 @@ TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteS class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param actor: the actor network following the rules (s_B -> dist_input_BD) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. @@ -54,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | Actor, actor_optim: torch.optim.Optimizer, - critic: torch.nn.Module, + critic: torch.nn.Module | Critic, critic_optim: torch.optim.Optimizer, action_space: gym.spaces.Discrete, - critic2: torch.nn.Module | None = None, + critic2: torch.nn.Module | Critic | None = None, critic2_optim: torch.optim.Optimizer | None = None, tau: float = 0.005, gamma: float = 0.99, @@ -105,13 +105,13 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> Batch: - logits, hidden = self.actor(batch.obs, state=state, info=batch.info) - dist = Categorical(logits=logits) + logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Categorical(logits=logits_BA) if self.deterministic_eval and not self.training: - act = dist.mode + act_B = dist.mode else: - act = dist.sample() - return Batch(logits=logits, act=act, state=hidden, dist=dist) + act_B = dist.sample() + return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py index ad5f7dd..e0ada07 100644 --- a/tianshou/policy/modelfree/dqn.py +++ b/tianshou/policy/modelfree/dqn.py @@ -17,6 +17,7 @@ from tianshou.data.types import ( ) from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats +from tianshou.utils.net.common import Net @dataclass(kw_only=True) @@ -35,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is implemented in the network side, not here). - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. 
(s -> logits) + :param model: a model following the rules (s -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. :param estimation_step: the number of steps to look ahead. @@ -60,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): def __init__( self, *, - model: torch.nn.Module, + model: torch.nn.Module | Net, optim: torch.optim.Optimizer, # TODO: type violates Liskov substitution principle action_space: gym.spaces.Discrete, @@ -201,12 +201,12 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]): obs = batch.obs # TODO: this is convoluted! See also other places where this is done. obs_next = obs.obs if hasattr(obs, "obs") else obs - logits, hidden = model(obs_next, state=state, info=batch.info) - q = self.compute_q_value(logits, getattr(obs, "mask", None)) + action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info) + q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None)) if self.max_action_num is None: self.max_action_num = q.shape[1] - act = to_numpy(q.max(dim=1)[1]) - result = Batch(logits=logits, act=act, state=hidden) + act_B = to_numpy(q.argmax(dim=1)) + result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH) return cast(ModelOutputBatchProtocol, result) def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: diff --git a/tianshou/policy/modelfree/fqf.py b/tianshou/policy/modelfree/fqf.py index 9f1b083..9c87f9c 100644 --- a/tianshou/policy/modelfree/fqf.py +++ b/tianshou/policy/modelfree/fqf.py @@ -27,8 +27,7 @@ TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats) class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param fraction_model: a FractionProposalNetwork for proposing fractions/quantiles given state. diff --git a/tianshou/policy/modelfree/iqn.py b/tianshou/policy/modelfree/iqn.py index f242c14..75d76a2 100644 --- a/tianshou/policy/modelfree/iqn.py +++ b/tianshou/policy/modelfree/iqn.py @@ -29,8 +29,7 @@ TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats) class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): """Implementation of Implicit Quantile Network. arXiv:1806.06923. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s_B -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param discount_factor: in [0, 1]. :param sample_size: the number of samples for policy evaluation. 
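A small, self-contained sketch of the action-selection pattern the refactored `forward` methods above converge on: Q-value heads take an `argmax`, distribution-based heads use the distribution's `.mode` during deterministic evaluation and `.sample()` otherwise (the tensors here are hypothetical stand-ins for model outputs):

```python
import torch
from torch.distributions import Categorical

action_values_BA = torch.randn(4, 3)  # hypothetical batch of Q-values / logits

# DQN-style: greedy action straight from the value head
act_B = action_values_BA.argmax(dim=1)

# Discrete-SAC / PG-style: build a distribution, then mode (eval) or sample (training)
dist = Categorical(logits=action_values_BA)
deterministic_eval, training = True, False
act_B = dist.mode if deterministic_eval and not training else dist.sample()
print(act_B.shape)  # torch.Size([4])
```
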
diff --git a/tianshou/policy/modelfree/npg.py b/tianshou/policy/modelfree/npg.py index f293945..9e04d3f 100644 --- a/tianshou/policy/modelfree/npg.py +++ b/tianshou/policy/modelfree/npg.py @@ -12,7 +12,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.policy import A2CPolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, optim_critic_iters: int = 5, actor_step_size: float = 0.5, diff --git a/tianshou/policy/modelfree/pg.py b/tianshou/policy/modelfree/pg.py index eb6cb59..9a148fe 100644 --- a/tianshou/policy/modelfree/pg.py +++ b/tianshou/policy/modelfree/pg.py @@ -1,7 +1,7 @@ import warnings from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import Any, Generic, Literal, TypeVar, cast import gymnasium as gym import numpy as np @@ -24,9 +24,22 @@ from tianshou.data.types import ( from tianshou.policy import BasePolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils import RunningMeanStd +from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.discrete import Actor -# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution] -TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution] +# Dimension Naming Convention +# B - Batch Size +# A - Action +# D - Dist input (usually 2, loc and scale) +# H - Dimension of hidden, can be None + +TDistFnContinuous = Callable[ + [tuple[torch.Tensor, torch.Tensor]], + torch.distributions.Distribution, +] +TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical] + +TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete @dataclass(kw_only=True) @@ -40,8 +53,9 @@ TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats) class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): """Implementation of REINFORCE algorithm. 
- :param actor: mapping (s->model_output), should follow the rules in - :class:`~tianshou.policy.BasePolicy`. + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param optim: optimizer for actor network. :param dist_fn: distribution class for computing the action. Maps model_output -> distribution. Typically a Gaussian distribution @@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | ActorProb | Actor, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, discount_factor: float = 0.99, # TODO: rename to return_normalization? @@ -175,20 +189,20 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for more detailed explanation. """ - # TODO: rename? It's not really logits and there are particular - # assumptions about the order of the output and on distribution type - logits, hidden = self.actor(batch.obs, state=state, info=batch.info) - if isinstance(logits, tuple): - dist = self.dist_fn(*logits) - else: - dist = self.dist_fn(logits) + # TODO - ALGO: marked for algorithm refactoring + action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A + # therefore action_dist_input_BD is equivalent to logits_BA + # If discrete, dist_fn will typically map loc, scale to a distribution (usually a Gaussian) + # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked + dist = self.dist_fn(action_dist_input_BD) - # in this case, the dist is unused! if self.deterministic_eval and not self.training: - act = dist.mode + act_B = dist.mode else: - act = dist.sample() - result = Batch(logits=logits, act=act, state=hidden, dist=dist) + act_B = dist.sample() + # act is of dimension BA in continuous case and of dimension B in discrete + result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist) return cast(DistBatchProtocol, result) # TODO: why does mypy complain? diff --git a/tianshou/policy/modelfree/ppo.py b/tianshou/policy/modelfree/ppo.py index fde9e7c..196cd72 100644 --- a/tianshou/policy/modelfree/ppo.py +++ b/tianshou/policy/modelfree/ppo.py @@ -10,8 +10,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.policy import A2CPolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont from tianshou.utils.net.common import ActorCritic +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -29,7 +32,9 @@ TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats) class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. - :param actor: the actor network following the rules in BasePolicy. 
(s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, eps_clip: float = 0.2, dual_clip: float | None = None, diff --git a/tianshou/policy/modelfree/qrdqn.py b/tianshou/policy/modelfree/qrdqn.py index b2f5d1e..71c36de 100644 --- a/tianshou/policy/modelfree/qrdqn.py +++ b/tianshou/policy/modelfree/qrdqn.py @@ -25,8 +25,7 @@ TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats) class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. - :param model: a model following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param model: a model following the rules (s -> action_values_BA) :param optim: a torch.optim for optimizing the model. :param action_space: Env's action space. :param discount_factor: in [0, 1]. diff --git a/tianshou/policy/modelfree/redq.py b/tianshou/policy/modelfree/redq.py index 100a361..f9793f4 100644 --- a/tianshou/policy/modelfree/redq.py +++ b/tianshou/policy/modelfree/redq.py @@ -12,6 +12,7 @@ from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.ddpg import DDPGTrainingStats +from tianshou.utils.net.continuous import ActorProb @dataclass @@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | ActorProb, actor_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, @@ -150,23 +151,28 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]): state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> Batch: - loc_scale, h = self.actor(batch.obs, state=state, info=batch.info) - loc, scale = loc_scale - dist = Independent(Normal(loc, scale), 1) + (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc_B, scale_B), 1) if self.deterministic_eval and not self.training: - act = dist.mode + act_B = dist.mode else: - act = dist.rsample() - log_prob = dist.log_prob(act).unsqueeze(-1) + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. 
- squashed_action = torch.tanh(act) + squashed_action = torch.tanh(act_B) log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( -1, keepdim=True, ) - return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob) + return Batch( + logits=(loc_B, scale_B), + act=squashed_action, + state=h_BH, + dist=dist, + log_prob=log_prob, + ) def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: obs_next_batch = Batch( diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py index a433624..3b39754 100644 --- a/tianshou/policy/modelfree/sac.py +++ b/tianshou/policy/modelfree/sac.py @@ -17,6 +17,7 @@ from tianshou.exploration import BaseNoise from tianshou.policy import DDPGPolicy from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.utils.conversion import to_optional_float +from tianshou.utils.net.continuous import ActorProb from tianshou.utils.optim import clone_optimizer @@ -36,8 +37,7 @@ TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats) class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] """Implementation of Soft Actor-Critic. arXiv:1812.05905. - :param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :param actor: the actor network following the rules (s -> dist_input_BD) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. @@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t def __init__( self, *, - actor: torch.nn.Module, + actor: torch.nn.Module | ActorProb, actor_optim: torch.optim.Optimizer, critic: torch.nn.Module, critic_optim: torch.optim.Optimizer, @@ -173,26 +173,25 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t state: dict | Batch | np.ndarray | None = None, **kwargs: Any, ) -> DistLogProbBatchProtocol: - logits, hidden = self.actor(batch.obs, state=state, info=batch.info) - assert isinstance(logits, tuple) - dist = Independent(Normal(*logits), 1) + (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info) + dist = Independent(Normal(loc=loc_B, scale=scale_B), 1) if self.deterministic_eval and not self.training: - act = dist.mode + act_B = dist.mode else: - act = dist.rsample() - log_prob = dist.log_prob(act).unsqueeze(-1) + act_B = dist.rsample() + log_prob = dist.log_prob(act_B).unsqueeze(-1) # apply correction for Tanh squashing when computing logprob from Gaussian # You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # in appendix C to get some understanding of this equation. - squashed_action = torch.tanh(act) + squashed_action = torch.tanh(act_B) log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( -1, keepdim=True, ) result = Batch( - logits=logits, + logits=(loc_B, scale_B), act=squashed_action, - state=hidden, + state=hidden_BH, dist=dist, log_prob=log_prob, ) diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py index dbf7b65..8c2ae8c 100644 --- a/tianshou/policy/modelfree/td3.py +++ b/tianshou/policy/modelfree/td3.py @@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t """Implementation of TD3, arXiv:1802.09477. 
:param actor: the actor network following the rules in - :class:`~tianshou.policy.BasePolicy`. (s -> logits) + :class:`~tianshou.policy.BasePolicy`. (s -> actions) :param actor_optim: the optimizer for actor network. :param critic: the first critic network. (s, a -> Q(s, a)) :param critic_optim: the optimizer for the first critic network. diff --git a/tianshou/policy/modelfree/trpo.py b/tianshou/policy/modelfree/trpo.py index babc23b..e7aa5cf 100644 --- a/tianshou/policy/modelfree/trpo.py +++ b/tianshou/policy/modelfree/trpo.py @@ -11,7 +11,10 @@ from tianshou.data import Batch, SequenceSummaryStats from tianshou.policy import NPGPolicy from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.modelfree.npg import NPGTrainingStats -from tianshou.policy.modelfree.pg import TDistributionFunction +from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont +from tianshou.utils.net.continuous import ActorProb, Critic +from tianshou.utils.net.discrete import Actor as DiscreteActor +from tianshou.utils.net.discrete import Critic as DiscreteCritic @dataclass(kw_only=True) @@ -25,7 +28,9 @@ TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats) class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): """Implementation of Trust Region Policy Optimization. arXiv:1502.05477. - :param actor: the actor network following the rules in BasePolicy. (s -> logits) + :param actor: the actor network following the rules: + If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`). + If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`). :param critic: the critic network. (s -> V(s)) :param optim: the optimizer for actor and critic network. :param dist_fn: distribution class for computing the action. @@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): def __init__( self, *, - actor: torch.nn.Module, - critic: torch.nn.Module, + actor: torch.nn.Module | ActorProb | DiscreteActor, + critic: torch.nn.Module | Critic | DiscreteCritic, optim: torch.optim.Optimizer, - dist_fn: TDistributionFunction, + dist_fn: TDistFnDiscrOrCont, action_space: gym.Space, max_kl: float = 0.01, backtrack_coeff: float = 0.8, diff --git a/tianshou/utils/net/common.py b/tianshou/utils/net/common.py index 14ec54a..dabe24e 100644 --- a/tianshou/utils/net/common.py +++ b/tianshou/utils/net/common.py @@ -610,6 +610,17 @@ class BaseActor(nn.Module, ABC): def get_output_dim(self) -> int: pass + @abstractmethod + def forward( + self, + obs: np.ndarray | torch.Tensor, + state: Any = None, + info: dict[str, Any] | None = None, + ) -> tuple[Any, Any]: + # TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction. + # Return type needs to be more specific + pass + def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: """Gets the given attribute from the given object or takes the alternative value if it is not present. diff --git a/tianshou/utils/net/continuous.py b/tianshou/utils/net/continuous.py index f257f8a..6cd4a0f 100644 --- a/tianshou/utils/net/continuous.py +++ b/tianshou/utils/net/continuous.py @@ -1,4 +1,5 @@ import warnings +from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any @@ -9,6 +10,7 @@ from torch import nn from tianshou.utils.net.common import ( MLP, BaseActor, + Net, TActionShape, TLinearLayer, get_output_dim, @@ -19,33 +21,27 @@ SIGMA_MAX = 2 class Actor(BaseActor): - """Simple actor network. 
+ """Simple actor network that directly outputs actions for continuous action space. + Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param max_action: the scale for the final action logits. Default to - 1. - :param preprocess_net_output_dim: the output dimension of preprocess_net. + :param max_action: the scale for the final action. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, @@ -77,42 +73,50 @@ class Actor(BaseActor): state: Any = None, info: dict[str, Any] | None = None, ) -> tuple[torch.Tensor, Any]: - """Mapping: obs -> logits -> action.""" - if info is None: - info = {} - logits, hidden = self.preprocess(obs, state) - logits = self.max_action * torch.tanh(self.last(logits)) - return logits, hidden + """Mapping: s_B -> action_values_BA, hidden_state_BH | None. + + Returns a tensor representing the actions directly, i.e, of shape + `(n_actions, )`, and a hidden state (which may be None). + The hidden state is only not None if a recurrent net is used as part of the + learning algorithm (support for RNNs is currently experimental). + """ + action_BA, hidden_BH = self.preprocess(obs, state) + action_BA = self.max_action * torch.tanh(self.last(action_BA)) + return action_BA, hidden_BH -class Critic(nn.Module): +class CriticBase(nn.Module, ABC): + @abstractmethod + def forward( + self, + obs: np.ndarray | torch.Tensor, + act: np.ndarray | torch.Tensor | None = None, + info: dict[str, Any] | None = None, + ) -> torch.Tensor: + """Mapping: (s_B, a_B) -> Q(s, a)_B.""" + + +class Critic(CriticBase): """Simple critic network. It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value). - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param preprocess_net_output_dim: the output dimension of preprocess_net. - :param linear_layer: use this module as linear layer. Default to nn.Linear. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. 
+ :param linear_layer: use this module as linear layer. :param flatten_input: whether to flatten input data for the last layer. - Default to True. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, hidden_sizes: Sequence[int] = (), device: str | int | torch.device = "cpu", preprocess_net_output_dim: int | None = None, @@ -139,9 +143,7 @@ class Critic(nn.Module): act: np.ndarray | torch.Tensor | None = None, info: dict[str, Any] | None = None, ) -> torch.Tensor: - """Mapping: (s, a) -> logits -> Q(s, a).""" - if info is None: - info = {} + """Mapping: (s_B, a_B) -> Q(s, a)_B.""" obs = torch.as_tensor( obs, device=self.device, @@ -154,41 +156,35 @@ class Critic(nn.Module): dtype=torch.float32, ).flatten(1) obs = torch.cat([obs, act], dim=1) - logits, hidden = self.preprocess(obs) - return self.last(logits) + values_B, hidden_BH = self.preprocess(obs) + return self.last(values_B) class ActorProb(BaseActor): - """Simple actor network (output with a Gauss distribution). + """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian). - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`. + + :param preprocess_net: a self-defined preprocess_net, see usage. + Typically, an instance of :class:`~tianshou.utils.net.common.Net`. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after - preprocess_net. Default to empty sequence (where the MLP now contains - only a single linear layer). - :param max_action: the scale for the final action logits. Default to - 1. - :param unbounded: whether to apply tanh activation on final logits. - Default to False. - :param conditioned_sigma: True when sigma is calculated from the - input, False when sigma is an independent parameter. Default to False. - :param preprocess_net_output_dim: the output dimension of preprocess_net. + :param max_action: the scale for the final action logits. + :param unbounded: whether to apply tanh activation on final logits. + :param conditioned_sigma: True when sigma is calculated from the + input, False when sigma is an independent parameter. + :param preprocess_net_output_dim: the output dimension of + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ # TODO: force kwargs, adjust downstream code def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), max_action: float = 1.0, @@ -402,8 +398,7 @@ class Perturbation(nn.Module): flattened hidden state. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. - Default to cpu. - :param phi: max perturbation parameter for BCQ. Default to 0.05. + :param phi: max perturbation parameter for BCQ. 
For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. @@ -449,7 +444,6 @@ class VAE(nn.Module): :param latent_dim: the size of latent layer. :param max_action: the maximum value of each dimension of action. :param device: which device to create this model on. - Default to "cpu". For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. diff --git a/tianshou/utils/net/discrete.py b/tianshou/utils/net/discrete.py index 8a54a07..ab90698 100644 --- a/tianshou/utils/net/discrete.py +++ b/tianshou/utils/net/discrete.py @@ -7,17 +7,14 @@ import torch.nn.functional as F from torch import nn from tianshou.data import Batch, to_torch -from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim +from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim class Actor(BaseActor): - """Simple actor network. + """Simple actor network for discrete action spaces. - Will create an actor operated in discrete action space with structure of - preprocess_net ---> action_shape. - - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net. Typically, an instance of + :class:`~tianshou.utils.net.common.Net`. :param action_shape: a sequence of int for the shape of action. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains @@ -25,20 +22,15 @@ class Actor(BaseActor): :param softmax_output: whether to apply a softmax layer over the last layer's output. :param preprocess_net_output_dim: the output dimension of - preprocess_net. + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. """ def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, action_shape: TActionShape, hidden_sizes: Sequence[int] = (), softmax_output: bool = True, @@ -71,43 +63,44 @@ class Actor(BaseActor): obs: np.ndarray | torch.Tensor, state: Any = None, info: dict[str, Any] | None = None, - ) -> tuple[torch.Tensor, Any]: - r"""Mapping: s -> Q(s, \*).""" - if info is None: - info = {} - logits, hidden = self.preprocess(obs, state) - logits = self.last(logits) + ) -> tuple[torch.Tensor, torch.Tensor | None]: + r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None. + + Returns a tensor representing the values of each action, i.e, of shape + `(n_actions, )`, and + a hidden state (which may be None). If `self.softmax_output` is True, they are the + probabilities for taking each action. Otherwise, they will be action values. + The hidden state is only + not None if a recurrent net is used as part of the learning algorithm. + """ + x, hidden_BH = self.preprocess(obs, state) + x = self.last(x) if self.softmax_output: - logits = F.softmax(logits, dim=-1) - return logits, hidden + x = F.softmax(x, dim=-1) + # If we computed softmax, output is probabilities, otherwise it's the non-normalized action values + output_BA = x + return output_BA, hidden_BH class Critic(nn.Module): - """Simple critic network. + """Simple critic network for discrete action spaces. 
- It will create an actor operated in discrete action space with structure of preprocess_net ---> 1(q value). - - :param preprocess_net: a self-defined preprocess_net which output a - flattened hidden state. + :param preprocess_net: a self-defined preprocess_net. Typically, an instance of + :class:`~tianshou.utils.net.common.Net`. :param hidden_sizes: a sequence of int for constructing the MLP after preprocess_net. Default to empty sequence (where the MLP now contains only a single linear layer). :param last_size: the output dimension of Critic network. Default to 1. :param preprocess_net_output_dim: the output dimension of - preprocess_net. + `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`. For advanced usage (how to customize the network), please refer to - :ref:`build_the_network`. - - .. seealso:: - - Please refer to :class:`~tianshou.utils.net.common.Net` as an instance - of how preprocess_net is suggested to be defined. + :ref:`build_the_network`.. """ def __init__( self, - preprocess_net: nn.Module, + preprocess_net: nn.Module | Net, hidden_sizes: Sequence[int] = (), last_size: int = 1, preprocess_net_output_dim: int | None = None, @@ -120,8 +113,10 @@ class Critic(nn.Module): input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) + # TODO: make a proper interface! def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor: - """Mapping: s -> V(s).""" + """Mapping: s_B -> V(s)_B.""" + # TODO: don't use this mechanism for passing state logits, _ = self.preprocess(obs, state=kwargs.get("state", None)) return self.last(logits)
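For reference, a minimal usage sketch of the `Net` -> `Actor`/`Critic` wiring documented above (a sketch only: the `Net` constructor call, the shapes, and the variable names are illustrative assumptions and not part of this patch):

    import numpy as np

    from tianshou.utils.net.common import Net
    from tianshou.utils.net.discrete import Actor, Critic

    # Shared preprocess network producing a flat hidden representation.
    # (Constructor arguments are assumed from the usual Net interface.)
    net = Net(state_shape=(4,), hidden_sizes=[64, 64], device="cpu")

    actor = Actor(net, action_shape=2)  # s_B -> probs_BA (softmax_output=True by default)
    critic = Critic(net)                # s_B -> V(s)_B

    obs_B = np.random.rand(8, 4).astype(np.float32)  # batch of B=8 observations
    probs_BA, hidden_BH = actor(obs_B)  # hidden_BH is None for non-recurrent preprocess nets
    values_B = critic(obs_B)            # shape (8, 1)

Here the trailing letters follow the shape-suffix convention used throughout this diff: `B` is the batch dimension, `A` the number of actions, and `H` the hidden dimension returned by recurrent preprocess networks.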