Merge branch 'thuml_master' into feature/algo-eval

Maximilian Huettenrauch 2024-04-02 11:03:38 +02:00
commit f2e10b04bb
83 changed files with 1488 additions and 886 deletions

View File

@ -1,4 +1,38 @@
# Changelog
## Release 1.1.0
### Api Extensions
- `Batch` received two new methods: `to_dict` and `to_list_of_dicts` (see the sketch below). #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
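The two new `Batch` methods from the list above can be sketched as follows (an illustrative sketch only; argument defaults and edge cases are as documented in the library and not shown here):

```python
import numpy as np

from tianshou.data import Batch

batch = Batch(obs=np.zeros((3, 2)), info=Batch(env_id=np.arange(3)))

# Convert the (possibly nested) Batch into a plain dictionary of its contents.
as_dict = batch.to_dict()

# Split the batch along its first dimension into one plain dict per entry.
as_list_of_dicts = batch.to_list_of_dicts()
assert len(as_list_of_dicts) == 3
```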
### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063
- Better variable names related to model outputs (logits, dist input etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032
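The last point needs no Tianshou internals to illustrate; it relies only on the `.mode` property exposed by `torch.distributions` objects in recent PyTorch releases, so the following is just an illustrative sketch:

```python
import torch
from torch.distributions import Categorical, Independent, Normal

# Continuous case: the mode of a diagonal Gaussian is its mean.
cont = Independent(Normal(loc=torch.zeros(2), scale=torch.ones(2)), 1)
assert torch.allclose(cont.mode, torch.zeros(2))

# Discrete case: the mode is the argmax over the logits.
disc = Categorical(logits=torch.tensor([0.1, 2.0, -1.0]))
assert disc.mode.item() == 1

# Code that computes deterministic actions can therefore use `dist.mode`
# uniformly instead of branching on the concrete distribution type.
```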
### Breaking Changes
- Removed `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
  explicitly or pass `reset_before_collect=True` (see the usage sketch at the end of this changelog). #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases. #1032
### Tests
- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081
Started after v1.0.0
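A hedged usage sketch of the new collector reset behaviour described in the breaking changes above; the environment, policy, and buffer sizes are illustrative stand-ins and not part of this commit:

```python
import gymnasium as gym
import numpy as np

from tianshou.data import Batch, Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import BasePolicy


class ZeroPolicy(BasePolicy):
    """Illustrative stand-in policy that always picks action 0."""

    def forward(self, batch, state=None, **kwargs):
        return Batch(act=np.zeros(len(batch.obs), dtype=np.int64))

    def learn(self, batch, *args, **kwargs):  # never exercised in this sketch
        return None


envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
policy = ZeroPolicy(action_space=gym.spaces.Discrete(2))
collector = Collector(policy, envs, VectorReplayBuffer(total_size=2000, buffer_num=2))

# The collector no longer resets its environments on construction, so either
# reset explicitly before the first collection ...
collector.reset()
collector.collect(n_step=16)

# ... or ask collect() to do the reset.
collector.collect(n_step=16, reset_before_collect=True)
```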

View File

@ -147,9 +147,6 @@ Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://
Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
The Chinese documentation is available at [https://tianshou.readthedocs.io/zh/master/](https://tianshou.readthedocs.io/zh/master/).
<!-- A short Chinese-language introduction to Tianshou: https://www.zhihu.com/question/377263715 -->
## Why Tianshou?

View File

@ -164,7 +164,7 @@
"source": [ "source": [
"# Let's watch its performance!\n", "# Let's watch its performance!\n",
"policy.eval()\n", "policy.eval()\n",
"eval_result = test_collector.collect(n_episode=1, render=False)\n", "eval_result = test_collector.collect(n_episode=3, render=False)\n",
"print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")" "print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")"
] ]
}, },

View File

@ -69,7 +69,7 @@
"from tianshou.policy import BasePolicy\n", "from tianshou.policy import BasePolicy\n",
"from tianshou.policy.modelfree.pg import (\n", "from tianshou.policy.modelfree.pg import (\n",
" PGTrainingStats,\n", " PGTrainingStats,\n",
" TDistributionFunction,\n", " TDistFnDiscrOrCont,\n",
" TPGTrainingStats,\n", " TPGTrainingStats,\n",
")\n", ")\n",
"from tianshou.utils import RunningMeanStd\n", "from tianshou.utils import RunningMeanStd\n",
@ -339,7 +339,7 @@
" *,\n", " *,\n",
" actor: torch.nn.Module,\n", " actor: torch.nn.Module,\n",
" optim: torch.optim.Optimizer,\n", " optim: torch.optim.Optimizer,\n",
" dist_fn: TDistributionFunction,\n", " dist_fn: TDistFnDiscrOrCont,\n",
" action_space: gym.Space,\n", " action_space: gym.Space,\n",
" discount_factor: float = 0.99,\n", " discount_factor: float = 0.99,\n",
" observation_space: gym.Space | None = None,\n", " observation_space: gym.Space | None = None,\n",

View File

@ -119,7 +119,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"collect_result = test_collector.collect(n_episode=9)\n", "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n",
"\n", "\n",
"collect_result.pprint_asdict()" "collect_result.pprint_asdict()"
] ]
@ -146,8 +146,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Reset the collector\n", "# Reset the collector\n",
"test_collector.reset()\n", "collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n",
"collect_result = test_collector.collect(n_episode=9, random=True)\n",
"\n", "\n",
"collect_result.pprint_asdict()" "collect_result.pprint_asdict()"
] ]

View File

@ -92,8 +92,6 @@ To compile documentation into webpage, run
The generated webpage is in ``docs/_build`` and can be viewed with a browser (http://0.0.0.0:8000/).
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.
Documentation Generation Test
-----------------------------

View File

@ -57,9 +57,6 @@ Here is Tianshou's other features:
* Support multi-GPU training :ref:`multi_gpu`
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
The Chinese documentation is available at `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_
Installation
------------

View File

@ -257,3 +257,8 @@ macOS
joblib
master
Panchenko
BA
BH
BO
BD

View File

@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
# expert replay buffer # expert replay buffer
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task)) dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))
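For contrast with the continuous `dist` function above, the discrete counterpart under the same single-argument `dist_fn` convention would look roughly like this (a sketch, not part of this commit):

```python
import torch
from torch.distributions import Categorical, Distribution


def dist_discrete(logits: torch.Tensor) -> Distribution:
    # The unified dist_fn interface receives a single model output; for a
    # discrete actor this is simply the logits tensor.
    return Categorical(logits=logits)
```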

View File

@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: A2CPolicy = A2CPolicy( policy: A2CPolicy = A2CPolicy(
actor=actor, actor=actor,

View File

@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy = NPGPolicy( policy: NPGPolicy = NPGPolicy(
actor=actor, actor=actor,

View File

@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy = PPOPolicy( policy: PPOPolicy = PPOPolicy(
actor=actor, actor=actor,

View File

@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PGPolicy = PGPolicy( policy: PGPolicy = PGPolicy(
actor=actor, actor=actor,

View File

@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: TRPOPolicy = TRPOPolicy( policy: TRPOPolicy = TRPOPolicy(
actor=actor, actor=actor,

View File

@ -171,6 +171,7 @@ ignore = [
"RET505", "RET505",
"D106", # undocumented public nested class "D106", # undocumented public nested class
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense) "D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
"PLW2901", # overwrite vars in loop
] ]
unfixable = [ unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all "F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all

View File

@ -9,13 +9,24 @@ import numpy as np
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple
class MyTestEnv(gym.Env): class MoveToRightEnv(gym.Env):
"""A task for "going right". The task is to go right ``size`` steps.""" """A task for "going right". The task is to go right ``size`` steps.
The observation is the current index, and the action is to go left or right.
Action 0 is to go left, and action 1 is to go right.
Taking action 0 at index 0 will keep the index at 0.
Arriving at index ``size`` means the task is done.
In the current implementation, stepping after the task is done is possible, which will
lead the index to be larger than ``size``.
Index 0 is the starting point. If reset is called with default options, the index will
be reset to 0.
"""
def __init__( def __init__(
self, self,
size: int, size: int,
sleep: int = 0, sleep: float = 0.0,
dict_state: bool = False, dict_state: bool = False,
recurse_state: bool = False, recurse_state: bool = False,
ma_rew: int = 0, ma_rew: int = 0,
@ -74,8 +85,13 @@ class MyTestEnv(gym.Env):
def reset( def reset(
self, self,
seed: int | None = None, seed: int | None = None,
# TODO: passing a dict here doesn't make any sense
options: dict[str, Any] | None = None, options: dict[str, Any] | None = None,
) -> tuple[dict[str, Any] | np.ndarray, dict]: ) -> tuple[dict[str, Any] | np.ndarray, dict]:
""":param seed:
:param options: the start index is provided in options["state"]
:return:
"""
if options is None: if options is None:
options = {"state": 0} options = {"state": 0}
super().reset(seed=seed) super().reset(seed=seed)
@ -188,7 +204,7 @@ class NXEnv(gym.Env):
return self._encode_obs(), 1.0, False, False, {} return self._encode_obs(), 1.0, False, False, {}
class MyGoalEnv(MyTestEnv): class MyGoalEnv(MoveToRightEnv):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
assert ( assert (
kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0 kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0

View File

@ -22,13 +22,13 @@ from tianshou.data import (
from tianshou.data.utils.converter import to_hdf5 from tianshou.data.utils.converter import to_hdf5
if __name__ == "__main__": if __name__ == "__main__":
from env import MyGoalEnv, MyTestEnv from env import MoveToRightEnv, MyGoalEnv
else: # pytest else: # pytest
from test.base.env import MyGoalEnv, MyTestEnv from test.base.env import MoveToRightEnv, MyGoalEnv
def test_replaybuffer(size=10, bufsize=20) -> None: def test_replaybuffer(size=10, bufsize=20) -> None:
env = MyTestEnv(size) env = MoveToRightEnv(size)
buf = ReplayBuffer(bufsize) buf = ReplayBuffer(bufsize)
buf.update(buf) buf.update(buf)
assert str(buf) == buf.__class__.__name__ + "()" assert str(buf) == buf.__class__.__name__ + "()"
@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None:
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None: def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
env = MyTestEnv(size) env = MoveToRightEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num) buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True) buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True) buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
def test_priortized_replaybuffer(size=32, bufsize=15) -> None: def test_priortized_replaybuffer(size=32, bufsize=15) -> None:
env = MyTestEnv(size) env = MoveToRightEnv(size)
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5) buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5) buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
obs, info = env.reset() obs, info = env.reset()
@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None:
bufsize = 9 bufsize = 9
stack_num = 4 stack_num = 4
cached_num = 3 cached_num = 3
env = MyTestEnv(size) env = MoveToRightEnv(size)
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next # test if CachedReplayBuffer can handle stack_num + ignore_obs_next
buf4 = CachedReplayBuffer( buf4 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True), ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),

View File

@ -2,7 +2,6 @@ import gymnasium as gym
import numpy as np import numpy as np
import pytest import pytest
import tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import ( from tianshou.data import (
AsyncCollector, AsyncCollector,
@ -22,12 +21,12 @@ except ImportError:
envpool = None envpool = None
if __name__ == "__main__": if __name__ == "__main__":
from env import MyTestEnv, NXEnv from env import MoveToRightEnv, NXEnv
else: # pytest else: # pytest
from test.base.env import MyTestEnv, NXEnv from test.base.env import MoveToRightEnv, NXEnv
class MyPolicy(BasePolicy): class MaxActionPolicy(BasePolicy):
def __init__( def __init__(
self, self,
action_space: gym.spaces.Space | None = None, action_space: gym.spaces.Space | None = None,
@ -35,7 +34,9 @@ class MyPolicy(BasePolicy):
need_state=True, need_state=True,
action_shape=None, action_shape=None,
) -> None: ) -> None:
"""Mock policy for testing. """Mock policy for testing, will always return an array of ones of the shape of the action space.
Note that this doesn't make much sense for discrete action spaces (the output is then interpreted as
logits, meaning all actions would be equally likely).
:param action_space: the action space of the environment. If None, a dummy Box space will be used. :param action_space: the action space of the environment. If None, a dummy Box space will be used.
:param bool dict_state: if the observation of the environment is a dict :param bool dict_state: if the observation of the environment is a dict
@ -63,177 +64,253 @@ class MyPolicy(BasePolicy):
pass pass
class Logger: def test_collector() -> None:
def __init__(self, writer) -> None: env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
self.cnt = 0
self.writer = writer
def preprocess_fn(self, **kwargs): subproc_venv_4_envs = SubprocVectorEnv(env_fns)
# modify info before adding into the buffer, and recorded into tfb dummy_venv_4_envs = DummyVectorEnv(env_fns)
# if obs && env_id exist -> reset policy = MaxActionPolicy()
# if obs_next/rew/done/info/env_id exist -> normal step single_env = env_fns[0]()
if "rew" in kwargs: c_single_env = Collector(
info = kwargs["info"]
info.rew = kwargs["rew"]
if "key" in info:
self.writer.add_scalar("key", np.mean(info.key), global_step=self.cnt)
self.cnt += 1
return Batch(info=info)
return Batch()
@staticmethod
def single_preprocess_fn(**kwargs):
# same as above, without tfb
if "rew" in kwargs:
info = kwargs["info"]
info.rew = kwargs["rew"]
return Batch(info=info)
return Batch()
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
def test_collector(gym_reset_kwargs) -> None:
writer = SummaryWriter("log/collector")
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
venv = SubprocVectorEnv(env_fns)
dum = DummyVectorEnv(env_fns)
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(
policy, policy,
env, single_env,
ReplayBuffer(size=100), ReplayBuffer(size=100),
logger.preprocess_fn,
) )
c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs) c_single_env.reset()
assert len(c0.buffer) == 3 c_single_env.collect(n_step=3)
assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0]) assert len(c_single_env.buffer) == 3
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1]) # TODO: direct attr access is an arcane way of using the buffer, it should never be done
# The placeholders for entries are all zeros, so buffer.obs is an array filled with 3
# observations, and 97 zeros.
# However, buffer[:] will have all attributes with length three... The non-filled entries are removed there
# See above. For the single env, we start with obs=0, obs_next=1.
# We move to obs=1, obs_next=2,
# then the env is reset and we move to obs=0
# Making one more step results in obs_next=1
# The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access exposes these zero placeholders
assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0])
assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1])
keys = np.zeros(100) keys = np.zeros(100)
keys[:3] = 1 keys[:3] = 1
assert np.allclose(c0.buffer.info["key"], keys) assert np.allclose(c_single_env.buffer.info["key"], keys)
for e in c0.buffer.info["env"][:3]: for e in c_single_env.buffer.info["env"][:3]:
assert isinstance(e, MyTestEnv) assert isinstance(e, MoveToRightEnv)
assert np.allclose(c0.buffer.info["env_id"], 0) assert np.allclose(c_single_env.buffer.info["env_id"], 0)
rews = np.zeros(100) rews = np.zeros(100)
rews[:3] = [0, 1, 0] rews[:3] = [0, 1, 0]
assert np.allclose(c0.buffer.info["rew"], rews) assert np.allclose(c_single_env.buffer.rew, rews)
c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs) # At this point, the buffer contains obs 0 -> 1 -> 0
assert len(c0.buffer) == 8
assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
assert np.allclose(c0.buffer.info["key"][:8], 1)
for e in c0.buffer.info["env"][:8]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c0.buffer.info["env_id"][:8], 0)
assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)
c1 = Collector( # At start we have 3 entries in the buffer
# We collect 3 episodes, in addition to the transitions we have collected before
# 0 -> 1 -> 0 -> 0 (reset at collection start) -> 1 -> done (0) -> 1 -> done(0)
# obs_next: 1 -> 2 -> 1 -> 1 (reset at collection start) -> 2 -> 1 -> 2 -> 1 -> 2
# In total, we will have 3 + 6 = 9 entries in the buffer
c_single_env.collect(n_episode=3)
assert len(c_single_env.buffer) == 8
assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
assert np.allclose(c_single_env.buffer.info["key"][:8], 1)
for e in c_single_env.buffer.info["env"][:8]:
assert isinstance(e, MoveToRightEnv)
assert np.allclose(c_single_env.buffer.info["env_id"][:8], 0)
assert np.allclose(c_single_env.buffer.rew[:8], [0, 1, 0, 1, 0, 1, 0, 1])
c_single_env.collect(n_step=3, random=True)
c_subproc_venv_4_envs = Collector(
policy, policy,
venv, subproc_venv_4_envs,
VectorReplayBuffer(total_size=100, buffer_num=4), VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn,
) )
c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs) c_subproc_venv_4_envs.reset()
# Collect some steps
c_subproc_venv_4_envs.collect(n_step=8)
obs = np.zeros(100) obs = np.zeros(100)
valid_indices = [0, 1, 25, 26, 50, 51, 75, 76] valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1] obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
assert np.allclose(c1.buffer.obs[:, 0], obs) assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2]) assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
keys = np.zeros(100) keys = np.zeros(100)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c1.buffer.info["key"], keys) assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys)
for e in c1.buffer.info["env"][valid_indices]: for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv) assert isinstance(e, MoveToRightEnv)
env_ids = np.zeros(100) env_ids = np.zeros(100)
env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3] env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
assert np.allclose(c1.buffer.info["env_id"], env_ids) assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids)
rews = np.zeros(100) rews = np.zeros(100)
rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0] rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
assert np.allclose(c1.buffer.info["rew"], rews) assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews)
c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
assert len(c1.buffer) == 16 # we previously collected 8 steps, 2 from each env, now we collect 4 episodes
# each env will contribute an episode, which will be of lens 2 (first env was reset), 1, 2, 3
# So we get 8 + 2+1+2+3 = 16 steps
c_subproc_venv_4_envs.collect(n_episode=4)
assert len(c_subproc_venv_4_envs.buffer) == 16
valid_indices = [2, 3, 27, 52, 53, 77, 78, 79] valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4] obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4]
assert np.allclose(c1.buffer.obs[:, 0], obs) assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
assert np.allclose( assert np.allclose(
c1.buffer[:].obs_next[..., 0], c_subproc_venv_4_envs.buffer[:].obs_next[..., 0],
[1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5], [1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5],
) )
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1] keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c1.buffer.info["key"], keys) assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys)
for e in c1.buffer.info["env"][valid_indices]: for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv) assert isinstance(e, MoveToRightEnv)
env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3] env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
assert np.allclose(c1.buffer.info["env_id"], env_ids) assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids)
rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1] rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
assert np.allclose(c1.buffer.info["rew"], rews) assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews)
c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) c_subproc_venv_4_envs.collect(n_episode=4, random=True)
c2 = Collector( c_dummy_venv_4_envs = Collector(
policy, policy,
dum, dummy_venv_4_envs,
VectorReplayBuffer(total_size=100, buffer_num=4), VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn,
) )
c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs) c_dummy_venv_4_envs.reset()
c_dummy_venv_4_envs.collect(n_episode=7)
obs1 = obs.copy() obs1 = obs.copy()
obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2] obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
obs2 = obs.copy() obs2 = obs.copy()
obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3] obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
c2obs = c2.buffer.obs[:, 0] c2obs = c_dummy_venv_4_envs.buffer.obs[:, 0]
assert np.all(c2obs == obs1) or np.all(c2obs == obs2) assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
c2.reset_env(gym_reset_kwargs=gym_reset_kwargs) c_dummy_venv_4_envs.reset_env()
c2.reset_buffer() c_dummy_venv_4_envs.reset_buffer()
assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs).n_collected_episodes == 8 assert c_dummy_venv_4_envs.collect(n_episode=8).n_collected_episodes == 8
valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57] valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3] obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
assert np.all(c2.buffer.obs[:, 0] == obs) assert np.all(c_dummy_venv_4_envs.buffer.obs[:, 0] == obs)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1] keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c2.buffer.info["key"], keys) assert np.allclose(c_dummy_venv_4_envs.buffer.info["key"], keys)
for e in c2.buffer.info["env"][valid_indices]: for e in c_dummy_venv_4_envs.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv) assert isinstance(e, MoveToRightEnv)
env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2] env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
assert np.allclose(c2.buffer.info["env_id"], env_ids) assert np.allclose(c_dummy_venv_4_envs.buffer.info["env_id"], env_ids)
rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1] rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
assert np.allclose(c2.buffer.info["rew"], rews) assert np.allclose(c_dummy_venv_4_envs.buffer.rew, rews)
c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs) c_dummy_venv_4_envs.collect(n_episode=4, random=True)
# test corner case # test corner case
with pytest.raises(TypeError): with pytest.raises(TypeError):
Collector(policy, dum, ReplayBuffer(10)) Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
with pytest.raises(TypeError): with pytest.raises(TypeError):
Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5)) Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
with pytest.raises(TypeError): with pytest.raises(TypeError):
c2.collect() c_dummy_venv_4_envs.collect()
# test NXEnv # test NXEnv
for obs_type in ["array", "object"]: for obs_type in ["array", "object"]:
envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]]) envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]])
c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs) c_suproc_new.reset()
assert c3.buffer.obs.dtype == object c_suproc_new.collect(n_step=6)
assert c_suproc_new.buffer.obs.dtype == object
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}]) @pytest.fixture()
def test_collector_with_async(gym_reset_kwargs) -> None: def get_AsyncCollector():
env_lens = [2, 3, 4, 5] env_lens = [2, 3, 4, 5]
writer = SummaryWriter("log/async_collector") env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens]
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens]
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1) venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
policy = MyPolicy() policy = MaxActionPolicy()
bufsize = 60 bufsize = 60
c1 = AsyncCollector( c1 = AsyncCollector(
policy, policy,
venv, venv,
VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4), VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
logger.preprocess_fn,
) )
c1.reset()
return c1, env_lens
class TestAsyncCollector:
def test_collect_without_argument_gives_error(self, get_AsyncCollector):
c1, env_lens = get_AsyncCollector
with pytest.raises(TypeError):
c1.collect()
def test_collect_one_episode_async(self, get_AsyncCollector):
c1, env_lens = get_AsyncCollector
result = c1.collect(n_episode=1)
assert result.n_collected_episodes >= 1
def test_enough_episodes_two_collection_cycles_n_episode_without_reset(
self,
get_AsyncCollector,
):
c1, env_lens = get_AsyncCollector
n_episode = 2
result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False)
assert result_c1.n_collected_episodes >= n_episode
result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False)
assert result_c2.n_collected_episodes >= n_episode
def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector):
c1, env_lens = get_AsyncCollector
n_episode = 2
result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True)
assert result_c1.n_collected_episodes >= n_episode
result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=True)
assert result_c2.n_collected_episodes >= n_episode
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode(
self,
get_AsyncCollector,
):
c1, env_lens = get_AsyncCollector
ptr = [0, 0, 0, 0]
bufsize = 60
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
assert result.n_collected_episodes >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step(
self,
get_AsyncCollector,
):
c1, env_lens = get_AsyncCollector
bufsize = 60
ptr = [0, 0, 0, 0]
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step)
assert result.n_collected_steps >= n_step
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step(
self,
get_AsyncCollector,
gym_reset_kwargs,
):
c1, env_lens = get_AsyncCollector
bufsize = 60
ptr = [0, 0, 0, 0] ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"): for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs) result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
@ -249,7 +326,7 @@ def test_collector_with_async(gym_reset_kwargs) -> None:
assert np.all(buf.info.env_id[indices] == i) assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq) assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1) assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data # test async n_step, for now the buffer should be full of data, thus no bincount stuff as above
for n_step in tqdm.trange(1, 15, desc="test async n_step"): for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs) result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
assert result.n_collected_steps >= n_step assert result.n_collected_steps >= n_step
@ -260,18 +337,17 @@ def test_collector_with_async(gym_reset_kwargs) -> None:
assert np.all(buf.info.env_id == i) assert np.all(buf.info.env_id == i)
assert np.all(buf.obs.reshape(-1, env_len) == seq) assert np.all(buf.obs.reshape(-1, env_len) == seq)
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1) assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
with pytest.raises(TypeError):
c1.collect()
def test_collector_with_dict_state() -> None: def test_collector_with_dict_state() -> None:
env = MyTestEnv(size=5, sleep=0, dict_state=True) env = MoveToRightEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True) policy = MaxActionPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) c0 = Collector(policy, env, ReplayBuffer(size=100))
c0.reset()
c0.collect(n_step=3) c0.collect(n_step=3)
c0.collect(n_episode=2) c0.collect(n_episode=2)
assert len(c0.buffer) == 10 assert len(c0.buffer) == 10 # 3 steps + the 2 remaining steps of the unfinished episode + one full 5-step episode
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]] env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns) envs = DummyVectorEnv(env_fns)
envs.seed(666) envs.seed(666)
obs, info = envs.reset() obs, info = envs.reset()
@ -280,8 +356,8 @@ def test_collector_with_dict_state() -> None:
policy, policy,
envs, envs,
VectorReplayBuffer(total_size=100, buffer_num=4), VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn,
) )
c1.reset()
c1.collect(n_step=12) c1.collect(n_step=12)
result = c1.collect(n_episode=8) result = c1.collect(n_episode=8)
assert result.n_collected_episodes == 8 assert result.n_collected_episodes == 8
@ -396,41 +472,47 @@ def test_collector_with_dict_state() -> None:
policy, policy,
envs, envs,
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn,
) )
c2.reset()
c2.collect(n_episode=10) c2.collect(n_episode=10)
batch, _ = c2.buffer.sample(10) batch, _ = c2.buffer.sample(10)
def test_collector_with_ma() -> None: def test_collector_with_multi_agent() -> None:
env = MyTestEnv(size=5, sleep=0, ma_rew=4) multi_agent_env = MoveToRightEnv(size=5, sleep=0, ma_rew=4)
policy = MyPolicy() policy = MaxActionPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn) c_single_env = Collector(policy, multi_agent_env, ReplayBuffer(size=100))
# n_step=3 will collect a full episode c_single_env.reset()
rew = c0.collect(n_step=3).returns multi_env_returns = c_single_env.collect(n_step=3).returns
assert len(rew) == 0 # c_single_env has length 3
rew = c0.collect(n_episode=2).returns # We have no full episodes, so no returns yet
assert rew.shape == (2, 4) assert len(multi_env_returns) == 0
assert np.all(rew == 1)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]] single_env_returns = c_single_env.collect(n_episode=2).returns
# now two episodes. Since we have 4 agents, the returns have shape (2, 4)
assert single_env_returns.shape == (2, 4)
assert np.all(single_env_returns == 1)
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns) envs = DummyVectorEnv(env_fns)
c1 = Collector( c_multi_env_ma = Collector(
policy, policy,
envs, envs,
VectorReplayBuffer(total_size=100, buffer_num=4), VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn,
) )
rew = c1.collect(n_step=12).returns c_multi_env_ma.reset()
assert rew.shape == (2, 4) and np.all(rew == 1), rew multi_env_returns = c_multi_env_ma.collect(n_step=12).returns
rew = c1.collect(n_episode=8).returns # each env makes 3 steps, the first two envs are done and result in two finished episodes
assert rew.shape == (8, 4) assert multi_env_returns.shape == (2, 4) and np.all(multi_env_returns == 1), multi_env_returns
assert np.all(rew == 1) multi_env_returns = c_multi_env_ma.collect(n_episode=8).returns
batch, _ = c1.buffer.sample(10) assert multi_env_returns.shape == (8, 4)
assert np.all(multi_env_returns == 1)
batch, _ = c_multi_env_ma.buffer.sample(10)
print(batch) print(batch)
c0.buffer.update(c1.buffer) c_single_env.buffer.update(c_multi_env_ma.buffer)
assert len(c0.buffer) in [42, 43] assert len(c_single_env.buffer) in [42, 43]
if len(c0.buffer) == 42: if len(c_single_env.buffer) == 42:
rew = [ multi_env_returns = [
0, 0,
0, 0,
0, 0,
@ -475,7 +557,7 @@ def test_collector_with_ma() -> None:
1, 1,
] ]
else: else:
rew = [ multi_env_returns = [
0, 0,
0, 0,
0, 0,
@ -520,17 +602,17 @@ def test_collector_with_ma() -> None:
0, 0,
1, 1,
] ]
assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew]) assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns])
assert np.all(c0.buffer[:].done == rew) assert np.all(c_single_env.buffer[:].done == multi_env_returns)
c2 = Collector( c2 = Collector(
policy, policy,
envs, envs,
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4), VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn,
) )
rew = c2.collect(n_episode=10).returns c2.reset()
assert rew.shape == (10, 4) multi_env_returns = c2.collect(n_episode=10).returns
assert np.all(rew == 1) assert multi_env_returns.shape == (10, 4)
assert np.all(multi_env_returns == 1)
batch, _ = c2.buffer.sample(10) batch, _ = c2.buffer.sample(10)
@ -543,20 +625,21 @@ def test_collector_with_atari_setting() -> None:
reference_obs[i, 0] = i reference_obs[i, 0] = i
# atari single buffer # atari single buffer
env = MyTestEnv(size=5, sleep=0, array_state=True) env = MoveToRightEnv(size=5, sleep=0, array_state=True)
policy = MyPolicy() policy = MaxActionPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100)) c0 = Collector(policy, env, ReplayBuffer(size=100))
c0.reset()
c0.collect(n_step=6) c0.collect(n_step=6)
c0.collect(n_episode=2) c0.collect(n_episode=2)
assert c0.buffer.obs.shape == (100, 4, 84, 84) assert c0.buffer.obs.shape == (100, 4, 84, 84)
assert c0.buffer.obs_next.shape == (100, 4, 84, 84) assert c0.buffer.obs_next.shape == (100, 4, 84, 84)
assert len(c0.buffer) == 15 assert len(c0.buffer) == 15 # 6 steps + the 4 remaining steps of the unfinished episode + one full 5-step episode
obs = np.zeros_like(c0.buffer.obs) obs = np.zeros_like(c0.buffer.obs)
obs[np.arange(15)] = reference_obs[np.arange(15) % 5] obs[np.arange(15)] = reference_obs[np.arange(15) % 5]
assert np.all(obs == c0.buffer.obs) assert np.all(obs == c0.buffer.obs)
c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True)) c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
c1.collect(n_episode=3) c1.collect(n_episode=3, reset_before_collect=True)
assert np.allclose(c0.buffer.obs, c1.buffer.obs) assert np.allclose(c0.buffer.obs, c1.buffer.obs)
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
c1.buffer.obs_next # noqa: B018 c1.buffer.obs_next # noqa: B018
@ -567,6 +650,7 @@ def test_collector_with_atari_setting() -> None:
env, env,
ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True), ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True),
) )
c2.reset()
c2.collect(n_step=8) c2.collect(n_step=8)
assert c2.buffer.obs.shape == (100, 84, 84) assert c2.buffer.obs.shape == (100, 84, 84)
obs = np.zeros_like(c2.buffer.obs) obs = np.zeros_like(c2.buffer.obs)
@ -575,9 +659,10 @@ def test_collector_with_atari_setting() -> None:
assert np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1]) assert np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])
# atari multi buffer # atari multi buffer
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]] env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns) envs = DummyVectorEnv(env_fns)
c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4)) c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c3.reset()
c3.collect(n_step=12) c3.collect(n_step=12)
result = c3.collect(n_episode=9) result = c3.collect(n_episode=9)
assert result.n_collected_episodes == 9 assert result.n_collected_episodes == 9
@ -606,6 +691,7 @@ def test_collector_with_atari_setting() -> None:
save_only_last_obs=True, save_only_last_obs=True,
), ),
) )
c4.reset()
c4.collect(n_step=12) c4.collect(n_step=12)
result = c4.collect(n_episode=9) result = c4.collect(n_episode=9)
assert result.n_collected_episodes == 9 assert result.n_collected_episodes == 9
@ -672,6 +758,7 @@ def test_collector_with_atari_setting() -> None:
buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True) buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True)
c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10)) c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
c5.reset()
result_ = c5.collect(n_step=12) result_ = c5.collect(n_step=12)
assert len(buf) == 5 assert len(buf) == 5
assert len(c5.buffer) == 12 assert len(c5.buffer) == 12
@ -767,6 +854,7 @@ def test_collector_with_atari_setting() -> None:
# test buffer=None # test buffer=None
c6 = Collector(policy, envs) c6 = Collector(policy, envs)
c6.reset()
result1 = c6.collect(n_step=12) result1 = c6.collect(n_step=12)
for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]: for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]:
assert np.allclose(getattr(result1, key), getattr(result_, key)) assert np.allclose(getattr(result1, key), getattr(result_, key))
@ -778,7 +866,7 @@ def test_collector_with_atari_setting() -> None:
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform") @pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_collector_envpool_gym_reset_return_info() -> None: def test_collector_envpool_gym_reset_return_info() -> None:
envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True) envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True)
policy = MyPolicy(action_shape=(len(envs), 1)) policy = MaxActionPolicy(action_shape=(len(envs), 1))
c0 = Collector( c0 = Collector(
policy, policy,
@ -786,18 +874,59 @@ def test_collector_envpool_gym_reset_return_info() -> None:
VectorReplayBuffer(len(envs) * 10, len(envs)), VectorReplayBuffer(len(envs) * 10, len(envs)),
exploration_noise=True, exploration_noise=True,
) )
c0.reset()
c0.collect(n_step=8) c0.collect(n_step=8)
env_ids = np.zeros(len(envs) * 10) env_ids = np.zeros(len(envs) * 10)
env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3] env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3]
assert np.allclose(c0.buffer.info["env_id"], env_ids) assert np.allclose(c0.buffer.info["env_id"], env_ids)
def test_collector_with_vector_env():
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
dum = DummyVectorEnv(env_fns)
policy = MaxActionPolicy()
c2 = Collector(
policy,
dum,
VectorReplayBuffer(total_size=100, buffer_num=4),
)
c2.reset()
c1r = c2.collect(n_episode=2)
assert np.array_equal(np.array([1, 8]), c1r.lens)
c2r = c2.collect(n_episode=10)
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 8, 9, 10]), c2r.lens)
c3r = c2.collect(n_step=20)
assert np.array_equal(np.array([1, 1, 1, 1, 1]), c3r.lens)
c4r = c2.collect(n_step=20)
assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens)
def test_async_collector_with_vector_env():
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
dum = DummyVectorEnv(env_fns)
policy = MaxActionPolicy()
c1 = AsyncCollector(
policy,
dum,
VectorReplayBuffer(total_size=100, buffer_num=4),
)
c1r = c1.collect(n_episode=10, reset_before_collect=True)
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens)
c2r = c1.collect(n_step=20)
assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens)
if __name__ == "__main__": if __name__ == "__main__":
test_collector(gym_reset_kwargs=None) test_collector()
test_collector(gym_reset_kwargs={})
test_collector_with_dict_state() test_collector_with_dict_state()
test_collector_with_ma() test_collector_with_multi_agent()
test_collector_with_atari_setting() test_collector_with_atari_setting()
test_collector_with_async(gym_reset_kwargs=None)
test_collector_with_async(gym_reset_kwargs={"return_info": True})
test_collector_envpool_gym_reset_return_info() test_collector_envpool_gym_reset_return_info()
test_collector_with_vector_env()
test_async_collector_with_vector_env()

View File

@ -20,9 +20,9 @@ from tianshou.env.gym_wrappers import TruncatedAsTerminated
from tianshou.utils import RunningMeanStd from tianshou.utils import RunningMeanStd
if __name__ == "__main__": if __name__ == "__main__":
from env import MyTestEnv, NXEnv from env import MoveToRightEnv, NXEnv
else: # pytest else: # pytest
from test.base.env import MyTestEnv, NXEnv from test.base.env import MoveToRightEnv, NXEnv
try: try:
import envpool import envpool
@ -56,7 +56,7 @@ def recurse_comp(a, b):
def test_async_env(size=10000, num=8, sleep=0.1) -> None: def test_async_env(size=10000, num=8, sleep=0.1) -> None:
# simplify the test case, just keep stepping # simplify the test case, just keep stepping
env_fns = [ env_fns = [
lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True) lambda i=i: MoveToRightEnv(size=i, sleep=sleep, random_sleep=True)
for i in range(size, size + num) for i in range(size, size + num)
] ]
test_cls = [SubprocVectorEnv, ShmemVectorEnv] test_cls = [SubprocVectorEnv, ShmemVectorEnv]
@ -108,10 +108,10 @@ def test_async_env(size=10000, num=8, sleep=0.1) -> None:
def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None: def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None:
env_fns = [ env_fns = [
lambda: MyTestEnv(size=size, sleep=sleep * 2), lambda: MoveToRightEnv(size=size, sleep=sleep * 2),
lambda: MyTestEnv(size=size, sleep=sleep * 3), lambda: MoveToRightEnv(size=size, sleep=sleep * 3),
lambda: MyTestEnv(size=size, sleep=sleep * 5), lambda: MoveToRightEnv(size=size, sleep=sleep * 5),
lambda: MyTestEnv(size=size, sleep=sleep * 7), lambda: MoveToRightEnv(size=size, sleep=sleep * 7),
] ]
test_cls = [SubprocVectorEnv, ShmemVectorEnv] test_cls = [SubprocVectorEnv, ShmemVectorEnv]
if has_ray(): if has_ray():
@ -156,7 +156,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None:
def test_vecenv(size=10, num=8, sleep=0.001) -> None: def test_vecenv(size=10, num=8, sleep=0.001) -> None:
env_fns = [ env_fns = [
lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True) lambda i=i: MoveToRightEnv(size=i, sleep=sleep, recurse_state=True)
for i in range(size, size + num) for i in range(size, size + num)
] ]
venv = [ venv = [
@ -237,7 +237,7 @@ def test_env_obs_dtype() -> None:
def test_env_reset_optional_kwargs(size=10000, num=8) -> None: def test_env_reset_optional_kwargs(size=10000, num=8) -> None:
env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)] env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)]
test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv] test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
if has_ray(): if has_ray():
test_cls += [RayVectorEnv] test_cls += [RayVectorEnv]
@ -257,7 +257,7 @@ def test_venv_wrapper_gym(num_envs: int = 4) -> None:
except ValueError: except ValueError:
obs, info = envs.reset(return_info=True) obs, info = envs.reset(return_info=True)
assert isinstance(obs, np.ndarray) assert isinstance(obs, np.ndarray)
assert isinstance(info, list) assert isinstance(info, np.ndarray)
assert isinstance(info[0], dict) assert isinstance(info[0], dict)
assert obs.shape[0] == len(info) == num_envs assert obs.shape[0] == len(info) == num_envs
@ -334,7 +334,7 @@ def test_venv_norm_obs() -> None:
action = np.array([1, 1, 1, 1]) action = np.array([1, 1, 1, 1])
total_step = 30 total_step = 30
action_list = [action] * total_step action_list = [action] * total_step
env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes] env_fns = [lambda i=x: MoveToRightEnv(size=i, array_state=True) for x in sizes]
raw = DummyVectorEnv(env_fns) raw = DummyVectorEnv(env_fns)
train_env = VectorEnvNormObs(DummyVectorEnv(env_fns)) train_env = VectorEnvNormObs(DummyVectorEnv(env_fns))
print(train_env.observation_space) print(train_env.observation_space)
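The `isinstance(info, np.ndarray)` assertion earlier in this file reflects the breaking change from the changelog (vector environments now return an array of info-dicts on reset). A small standalone sketch, assuming gymnasium's CartPole-v1 is available:

```python
import gymnasium as gym
import numpy as np

from tianshou.env import DummyVectorEnv

envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
obs, info = envs.reset()
assert isinstance(info, np.ndarray)  # previously a list of dicts
assert isinstance(info[0], dict)
assert obs.shape[0] == len(info) == 4
```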

View File

@ -90,20 +90,20 @@ class FiniteVectorEnv(BaseVectorEnv):
# END # END
def reset(self, id=None): def reset(self, env_id=None):
id = self._wrap_id(id) env_id = self._wrap_id(env_id)
self._reset_alive_envs() self._reset_alive_envs()
# ask super to reset alive envs and remap to current index # ask super to reset alive envs and remap to current index
request_id = list(filter(lambda i: i in self._alive_env_ids, id)) request_id = list(filter(lambda i: i in self._alive_env_ids, env_id))
obs = [None] * len(id) obs = [None] * len(env_id)
infos = [None] * len(id) infos = [None] * len(env_id)
id2idx = {i: k for k, i in enumerate(id)} id2idx = {i: k for k, i in enumerate(env_id)}
if request_id: if request_id:
for k, o, info in zip(request_id, *super().reset(request_id), strict=True): for k, o, info in zip(request_id, *super().reset(request_id), strict=True):
obs[id2idx[k]] = o obs[id2idx[k]] = o
infos[id2idx[k]] = info infos[id2idx[k]] = info
for i, o in zip(id, obs, strict=True): for i, o in zip(env_id, obs, strict=True):
if o is None and i in self._alive_env_ids: if o is None and i in self._alive_env_ids:
self._alive_env_ids.remove(i) self._alive_env_ids.remove(i)
@ -121,7 +121,7 @@ class FiniteVectorEnv(BaseVectorEnv):
self.reset() self.reset()
raise StopIteration raise StopIteration
return np.stack(obs), infos return np.stack(obs), np.array(infos)
def step(self, action, id=None): def step(self, action, id=None):
id = self._wrap_id(id) id = self._wrap_id(id)
@ -204,10 +204,12 @@ def test_finite_dummy_vector_env() -> None:
envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
policy = AnyPolicy() policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True) test_collector = Collector(policy, envs, exploration_noise=True)
test_collector.reset()
for _ in range(3): for _ in range(3):
envs.tracker = MetricTracker() envs.tracker = MetricTracker()
try: try:
# TODO: why on earth 10**18?
test_collector.collect(n_step=10**18) test_collector.collect(n_step=10**18)
except StopIteration: except StopIteration:
envs.tracker.validate() envs.tracker.validate()
@ -218,6 +220,7 @@ def test_finite_subproc_vector_env() -> None:
envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)]) envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
policy = AnyPolicy() policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True) test_collector = Collector(policy, envs, exploration_noise=True)
test_collector.reset()
for _ in range(3): for _ in range(3):
envs.tracker = MetricTracker() envs.tracker = MetricTracker()

View File

@ -2,7 +2,7 @@ import gymnasium as gym
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from torch.distributions import Categorical, Independent, Normal from torch.distributions import Categorical, Distribution, Independent, Normal
from tianshou.policy import PPOPolicy from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.common import ActorCritic, Net
@ -25,7 +25,11 @@ def policy(request):
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape), Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
action_shape=action_space.shape, action_shape=action_space.shape,
) )
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
elif action_type == "discrete": elif action_type == "discrete":
action_space = gym.spaces.Discrete(3) action_space = gym.spaces.Discrete(3)
actor = Actor( actor = Actor(

View File

@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
# replace DiagGuassian with Independent(Normal) which is equivalent # replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy[NPGTrainingStats] = NPGPolicy( policy: NPGPolicy[NPGTrainingStats] = NPGPolicy(
actor=actor, actor=actor,

View File

@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGuassian with Independent(Normal) which is equivalent # replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy[PPOTrainingStats] = PPOPolicy( policy: PPOPolicy[PPOTrainingStats] = PPOPolicy(
actor=actor, actor=actor,

View File

@ -136,6 +136,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
exploration_noise=True, exploration_noise=True,
) )
test_collector = Collector(policy, test_envs) test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True) train_collector.collect(n_step=args.start_timesteps, random=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "redq") log_path = os.path.join(args.logdir, args.task, "redq")

View File

@ -162,6 +162,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
env = gym.make(args.task) env = gym.make(args.task)
policy.eval() policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render) collector_stats = collector.collect(n_episode=1, render=args.render)
print(collector_stats) print(collector_stats)

View File

@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGuassian with Independent(Normal) which is equivalent # replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = TRPOPolicy( policy: BasePolicy = TRPOPolicy(
actor=actor, actor=actor,

View File

@ -109,7 +109,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
train_envs, train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)), VectorReplayBuffer(args.buffer_size, len(train_envs)),
) )
train_collector.reset()
test_collector = Collector(policy, test_envs) test_collector = Collector(policy, test_envs)
test_collector.reset()
# log # log
log_path = os.path.join(args.logdir, args.task, "a2c") log_path = os.path.join(args.logdir, args.task, "a2c")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -108,7 +108,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
) )
test_collector = Collector(policy, test_envs, exploration_noise=False) test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
def train_fn(epoch: int, env_step: int) -> None: # exp decay def train_fn(epoch: int, env_step: int) -> None: # exp decay
eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test) eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
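
The warm-up collection in these off-policy scripts now passes `reset_before_collect=True`, so the first `collect` call resets the train envs itself instead of relying on an implicit reset. A compact sketch with illustrative values (two envs, buffer size 2000, a 128-step warm-up, and `eps=1.0` are assumptions, not the scripts' settings):

```python
import gymnasium as gym
import torch

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import DQNPolicy
from tianshou.utils.net.common import Net

env = gym.make("CartPole-v1")
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
net = Net(state_shape=env.observation_space.shape, action_shape=env.action_space.n, hidden_sizes=[64])
policy = DQNPolicy(
    model=net,
    optim=torch.optim.Adam(net.parameters(), lr=1e-3),
    action_space=env.action_space,
)
policy.set_eps(1.0)  # fully random warm-up actions
buf = VectorReplayBuffer(2000, buffer_num=len(train_envs))
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
# One call both resets the train envs and gathers the initial transitions:
train_collector.collect(n_step=128, reset_before_collect=True)
print(len(buf))  # number of stored warm-up transitions
```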

View File

@ -120,7 +120,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "c51") log_path = os.path.join(args.logdir, args.task, "c51")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -111,7 +111,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "dqn") log_path = os.path.join(args.logdir, args.task, "dqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -95,7 +95,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
# the stack_num is for RNN training: sample framestack obs # the stack_num is for RNN training: sample framestack obs
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "drqn") log_path = os.path.join(args.logdir, args.task, "drqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -128,7 +128,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "fqf") log_path = os.path.join(args.logdir, args.task, "fqf")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -124,7 +124,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "iqn") log_path = os.path.join(args.logdir, args.task, "iqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -113,7 +113,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "qrdqn") log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -128,7 +128,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "rainbow") log_path = os.path.join(args.logdir, args.task, "rainbow")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -154,7 +154,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "dqn_icm") log_path = os.path.join(args.logdir, args.task, "dqn_icm")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -81,7 +81,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
VectorReplayBuffer(args.buffer_size, len(train_envs)), VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True, exploration_noise=True,
) )
train_collector.reset()
test_collector = Collector(policy, test_envs) test_collector = Collector(policy, test_envs)
test_collector.reset()
# Logger # Logger
log_path = os.path.join(args.logdir, args.task, "psrl") log_path = os.path.join(args.logdir, args.task, "psrl")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
@ -120,7 +122,6 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance! # Let's watch its performance!
policy.eval() policy.eval()
test_envs.seed(args.seed) test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render) result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f"Final reward: {result.rew_mean}, length: {result.len_mean}") print(f"Final reward: {result.rew_mean}, length: {result.len_mean}")
elif env.spec.reward_threshold: elif env.spec.reward_threshold:

View File

@ -115,9 +115,11 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs)) buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector # collector
train_collector = Collector(policy, train_envs, buf, exploration_noise=True) train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
train_collector.reset()
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
test_collector.reset()
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, args.task, "qrdqn") log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
@ -165,6 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs)) buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.2) policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True) collector = Collector(policy, test_envs, buf, exploration_noise=True)
collector.reset()
collector_stats = collector.collect(n_step=args.buffer_size) collector_stats = collector.collect(n_step=args.buffer_size)
if args.save_buffer_name.endswith(".hdf5"): if args.save_buffer_name.endswith(".hdf5"):
buf.save_hdf5(args.save_buffer_name) buf.save_hdf5(args.save_buffer_name)

View File

@ -178,6 +178,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None: def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None:
test_discrete_bcq()
args.resume = True args.resume = True
test_discrete_bcq(args) test_discrete_bcq(args)

View File

@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
# replace DiagGuassian with Independent(Normal) which is equivalent # replace DiagGuassian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward # pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = GAILPolicy( policy: BasePolicy = GAILPolicy(
actor=actor, actor=actor,

View File

@ -83,8 +83,8 @@ def get_agents(
if isinstance(env.observation_space, gym.spaces.Dict) if isinstance(env.observation_space, gym.spaces.Dict)
else env.observation_space else env.observation_space
) )
args.state_shape = observation_space.shape or observation_space.n args.state_shape = observation_space.shape or int(observation_space.n)
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or int(env.action_space.n)
if agents is None: if agents is None:
agents = [] agents = []
optims = [] optims = []
@ -135,7 +135,7 @@ def train_agent(
exploration_noise=True, exploration_noise=True,
) )
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, "pistonball", "dqn") log_path = os.path.join(args.logdir, "pistonball", "dqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)
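
The `int(...)` casts above exist because `Discrete.n` is a NumPy integer while downstream code expects a plain Python int (or a shape tuple) for `state_shape`/`action_shape`. A tiny illustration (the space size 3 is arbitrary):

```python
import gymnasium as gym

space = gym.spaces.Discrete(3)
# Discrete.shape is the empty tuple (falsy), so the `or` falls through to the cast:
state_shape = space.shape or int(space.n)
print(type(space.n).__name__, state_shape)  # int64 3
```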

View File

@ -181,8 +181,9 @@ def get_agents(
torch.nn.init.zeros_(m.bias) torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr) optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr)
def dist(*logits: torch.Tensor) -> Distribution: def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
return Independent(Normal(*logits), 1) loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
agent: PPOPolicy = PPOPolicy( agent: PPOPolicy = PPOPolicy(
actor, actor,
@ -234,7 +235,7 @@ def train_agent(
exploration_noise=False, # True exploration_noise=False, # True
) )
test_collector = Collector(policy, test_envs) test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.batch_size * args.training_num) # train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, "pistonball", "dqn") log_path = os.path.join(args.logdir, "pistonball", "dqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -102,8 +102,8 @@ def get_agents(
if isinstance(env.observation_space, gymnasium.spaces.Dict) if isinstance(env.observation_space, gymnasium.spaces.Dict)
else env.observation_space else env.observation_space
) )
args.state_shape = observation_space.shape or observation_space.n args.state_shape = observation_space.shape or int(observation_space.n)
args.action_shape = env.action_space.shape or env.action_space.n args.action_shape = env.action_space.shape or int(env.action_space.n)
if agent_learn is None: if agent_learn is None:
# model # model
net = Net( net = Net(
@ -170,7 +170,7 @@ def train_agent(
) )
test_collector = Collector(policy, test_envs, exploration_noise=True) test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1) # policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num) train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log # log
log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn") log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
writer = SummaryWriter(log_path) writer = SummaryWriter(log_path)

View File

@ -263,6 +263,9 @@ class BatchProtocol(Protocol):
def __repr__(self) -> str: def __repr__(self) -> str:
... ...
def __iter__(self) -> Iterator[Self]:
...
def to_numpy(self) -> None: def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place.""" """Change all torch.Tensor to numpy.ndarray in-place."""
... ...
@ -391,6 +394,12 @@ class BatchProtocol(Protocol):
""" """
... ...
def to_dict(self) -> dict[str, Any]:
...
def to_list_of_dicts(self) -> list[dict[str, Any]]:
...
class Batch(BatchProtocol): class Batch(BatchProtocol):
"""See :class:`~tianshou.data.batch.BatchProtocol`.""" """See :class:`~tianshou.data.batch.BatchProtocol`."""
@ -422,6 +431,17 @@ class Batch(BatchProtocol):
# Feels like kwargs could be just merged into batch_dict in the beginning # Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore self.__init__(kwargs, copy=copy) # type: ignore
def to_dict(self) -> dict[str, Any]:
result = {}
for k, v in self.__dict__.items():
if isinstance(v, Batch):
v = v.to_dict()
result[k] = v
return result
def to_list_of_dicts(self) -> list[dict[str, Any]]:
return [entry.to_dict() for entry in self]
def __setattr__(self, key: str, value: Any) -> None: def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value.""" """Set self.key = value."""
self.__dict__[key] = _parse_value(value) self.__dict__[key] = _parse_value(value)
@ -478,6 +498,14 @@ class Batch(BatchProtocol):
return new_batch return new_batch
raise IndexError("Cannot access item from empty Batch object.") raise IndexError("Cannot access item from empty Batch object.")
def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
yield from []
else:
for i in range(len(self)):
yield self[i]
def __setitem__(self, index: str | IndexType, value: Any) -> None: def __setitem__(self, index: str | IndexType, value: Any) -> None:
"""Assign value to self[index].""" """Assign value to self[index]."""
value = _parse_value(value) value = _parse_value(value)
@ -601,10 +629,10 @@ class Batch(BatchProtocol):
else: else:
# ndarray or scalar # ndarray or scalar
if not isinstance(obj, np.ndarray): if not isinstance(obj, np.ndarray):
obj = np.asanyarray(obj) # noqa: PLW2901 obj = np.asanyarray(obj)
obj = torch.from_numpy(obj).to(device) # noqa: PLW2901 obj = torch.from_numpy(obj).to(device)
if dtype is not None: if dtype is not None:
obj = obj.type(dtype) # noqa: PLW2901 obj = obj.type(dtype)
self.__dict__[batch_key] = obj self.__dict__[batch_key] = obj
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None: def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
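
A quick usage sketch of the helpers added to `Batch` above (`to_dict`, `to_list_of_dicts`, and `__iter__`), using arbitrary dummy data:

```python
import numpy as np

from tianshou.data import Batch

b = Batch(obs=np.zeros((3, 4)), info=Batch(env_id=np.arange(3)))

print(b.to_dict())           # nested Batch objects become plain dicts; arrays stay arrays
print(len(list(b)))          # iterating yields one Batch per leading index -> 3
print(b.to_list_of_dicts())  # three per-entry dicts, one for each row of the batch
```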

View File

@ -200,7 +200,7 @@ class ReplayBufferManager(ReplayBuffer):
return np.concatenate( return np.concatenate(
[ [
buf.sample_indices(bsz) + offset buf.sample_indices(int(bsz)) + offset
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True) for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True)
], ],
) )

File diff suppressed because it is too large

View File

@ -12,6 +12,7 @@ from tianshou.data.batch import Batch, _parse_value
# TODO: confusing name, could actually return a batch... # TODO: confusing name, could actually return a batch...
# Overrides and generic types should be added # Overrides and generic types should be added
# todo check for ActBatchProtocol
@no_type_check @no_type_check
def to_numpy(x: Any) -> Batch | np.ndarray: def to_numpy(x: Any) -> Batch | np.ndarray:
"""Return an object without torch.Tensor.""" """Return an object without torch.Tensor."""

View File

@ -44,14 +44,14 @@ class VectorEnvWrapper(BaseVectorEnv):
def reset( def reset(
self, self,
id: int | list[int] | np.ndarray | None = None, env_id: int | list[int] | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> tuple[np.ndarray, dict | list[dict]]: ) -> tuple[np.ndarray, np.ndarray]:
return self.venv.reset(id, **kwargs) return self.venv.reset(env_id, **kwargs)
def step( def step(
self, self,
action: np.ndarray | torch.Tensor, action: np.ndarray | torch.Tensor | None,
id: int | list[int] | np.ndarray | None = None, id: int | list[int] | np.ndarray | None = None,
) -> gym_new_venv_step_type: ) -> gym_new_venv_step_type:
return self.venv.step(action, id) return self.venv.step(action, id)
@ -80,10 +80,10 @@ class VectorEnvNormObs(VectorEnvWrapper):
def reset( def reset(
self, self,
id: int | list[int] | np.ndarray | None = None, env_id: int | list[int] | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> tuple[np.ndarray, dict | list[dict]]: ) -> tuple[np.ndarray, np.ndarray]:
obs, info = self.venv.reset(id, **kwargs) obs, info = self.venv.reset(env_id, **kwargs)
if isinstance(obs, tuple): # type: ignore if isinstance(obs, tuple): # type: ignore
raise TypeError( raise TypeError(
@ -98,7 +98,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
def step( def step(
self, self,
action: np.ndarray | torch.Tensor, action: np.ndarray | torch.Tensor | None,
id: int | list[int] | np.ndarray | None = None, id: int | list[int] | np.ndarray | None = None,
) -> gym_new_venv_step_type: ) -> gym_new_venv_step_type:
step_results = self.venv.step(action, id) step_results = self.venv.step(action, id)
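
The wrappers keep the renamed `env_id` keyword and the new return contract of the wrapped venv (observations plus an ndarray of info dicts). A small sketch with `VectorEnvNormObs` around two Pendulum workers (the task and env count are illustrative):

```python
import gymnasium as gym
import numpy as np

from tianshou.env import DummyVectorEnv, VectorEnvNormObs

venv = VectorEnvNormObs(DummyVectorEnv([lambda: gym.make("Pendulum-v1") for _ in range(2)]))
obs, info = venv.reset(env_id=[0, 1])  # obs is normalized; info is now an ndarray of dicts
assert isinstance(info, np.ndarray)
obs, rew, terminated, truncated, info = venv.step(np.zeros((2, 1), dtype=np.float32))
print(obs.shape)  # (2, 3)
```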

tianshou/env/venvs.py
View File

@ -190,11 +190,13 @@ class BaseVectorEnv:
), f"Cannot interact with environment {i} which is stepping now." ), f"Cannot interact with environment {i} which is stepping now."
assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}." assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}."
# TODO: for now, has to be kept in sync with reset in EnvPoolMixin
# In particular, can't rename env_id to env_ids
def reset( def reset(
self, self,
id: int | list[int] | np.ndarray | None = None, env_id: int | list[int] | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> tuple[np.ndarray, dict | list[dict]]: ) -> tuple[np.ndarray, np.ndarray]:
"""Reset the state of some envs and return initial observations. """Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return If id is None, reset the state of all the environments and return
@ -202,14 +204,14 @@ class BaseVectorEnv:
the given id, either an int or a list. the given id, either an int or a list.
""" """
self._assert_is_not_closed() self._assert_is_not_closed()
id = self._wrap_id(id) env_id = self._wrap_id(env_id)
if self.is_async: if self.is_async:
self._assert_id(id) self._assert_id(env_id)
# send(None) == reset() in worker # send(None) == reset() in worker
for i in id: for id in env_id:
self.workers[i].send(None, **kwargs) self.workers[id].send(None, **kwargs)
ret_list = [self.workers[i].recv() for i in id] ret_list = [self.workers[id].recv() for id in env_id]
assert ( assert (
isinstance(ret_list[0], tuple | list) isinstance(ret_list[0], tuple | list)
@ -229,12 +231,12 @@ class BaseVectorEnv:
except ValueError: # different len(obs) except ValueError: # different len(obs)
obs = np.array(obs_list, dtype=object) obs = np.array(obs_list, dtype=object)
infos = [r[1] for r in ret_list] infos = np.array([r[1] for r in ret_list])
return obs, infos # type: ignore return obs, infos
def step( def step(
self, self,
action: np.ndarray | torch.Tensor, action: np.ndarray | torch.Tensor | None,
id: int | list[int] | np.ndarray | None = None, id: int | list[int] | np.ndarray | None = None,
) -> gym_new_venv_step_type: ) -> gym_new_venv_step_type:
"""Run one timestep of some environments' dynamics. """Run one timestep of some environments' dynamics.
@ -248,6 +250,8 @@ class BaseVectorEnv:
batch_done, batch_info) in numpy format. batch_done, batch_info) in numpy format.
:param numpy.ndarray action: a batch of action provided by the agent. :param numpy.ndarray action: a batch of action provided by the agent.
If the venv is async, the action can be None, which will result
in all arrays in the returned tuple being empty.
:return: A tuple consisting of either: :return: A tuple consisting of either:
@ -271,6 +275,8 @@ class BaseVectorEnv:
self._assert_is_not_closed() self._assert_is_not_closed()
id = self._wrap_id(id) id = self._wrap_id(id)
if not self.is_async: if not self.is_async:
if action is None:
raise ValueError("action must be not-None for non-async")
assert len(action) == len(id) assert len(action) == len(id)
for i, j in enumerate(id): for i, j in enumerate(id):
self.workers[j].send(action[i]) self.workers[j].send(action[i])
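
Summarizing the changed vector-env contract above: `reset` takes `env_id`, returns an ndarray of per-env info dicts, and `step` only accepts `action=None` in the async case. A runnable sketch with two CartPole workers (the task is an arbitrary choice):

```python
import gymnasium as gym
import numpy as np

from tianshou.env import DummyVectorEnv

venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
obs, info = venv.reset(env_id=[0, 1])  # info is a numpy array of dicts, not a list
assert isinstance(info, np.ndarray)
obs, rew, terminated, truncated, info = venv.step(np.array([0, 1]), id=[0, 1])
print(obs.shape, rew.shape)  # (2, 4) (2,)
# For a synchronous venv, step(None, ...) raises ValueError; only async venvs accept it.
```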

View File

@ -93,7 +93,14 @@ class AgentFactory(ABC, ToStringMixin):
self, self,
policy: BasePolicy, policy: BasePolicy,
envs: Environments, envs: Environments,
reset_collectors: bool = True,
) -> tuple[Collector, Collector]: ) -> tuple[Collector, Collector]:
""":param policy:
:param envs:
:param reset_collectors: Whether to reset the collectors before returning them.
Setting to True means that the envs will be reset as well.
:return:
"""
buffer_size = self.sampling_config.buffer_size buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs train_envs = envs.train_envs
buffer: ReplayBuffer buffer: ReplayBuffer
@ -114,6 +121,10 @@ class AgentFactory(ABC, ToStringMixin):
) )
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs) test_collector = Collector(policy, envs.test_envs)
if reset_collectors:
train_collector.reset()
test_collector.reset()
if self.sampling_config.start_timesteps > 0: if self.sampling_config.start_timesteps > 0:
log.info( log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})", f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",

View File

@ -312,7 +312,7 @@ class Experiment(ToStringMixin):
) -> None: ) -> None:
policy.eval() policy.eval()
collector = Collector(policy, env) collector = Collector(policy, env)
result = collector.collect(n_episode=num_episodes, render=render) result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True)
assert result.returns_stat is not None # for mypy assert result.returns_stat is not None # for mypy
assert result.lens_stat is not None # for mypy assert result.lens_stat is not None # for mypy
log.info( log.info(

View File

@ -1,40 +1,47 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch import torch
from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC): class DistributionFunctionFactory(ToStringMixin, ABC):
# True return type defined in subclasses
@abstractmethod @abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(
self,
envs: Environments,
) -> Callable[[Any], torch.distributions.Distribution]:
pass pass
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory): class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self) envs.get_type().assert_discrete(self)
return self._dist_fn return self._dist_fn
@staticmethod @staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution: def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p) return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory): class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self) envs.get_type().assert_continuous(self)
return self._dist_fn return self._dist_fn
@staticmethod @staticmethod
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution: def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
return torch.distributions.Independent(torch.distributions.Normal(*p), 1) loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
class DistributionFunctionFactoryDefault(DistributionFunctionFactory): class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type(): match envs.get_type():
case EnvType.DISCRETE: case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs) return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
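
The factories above now return differently-typed callables: `TDistFnDiscrete` (one logits tensor) for the categorical factory and `TDistFnDiscrOrCont` with a `(loc, scale)` tuple for the independent-Gaussians one. A standalone sketch of the two shapes involved (tensor sizes are arbitrary):

```python
import torch
from torch.distributions import Categorical, Independent, Normal


def categorical_dist_fn(logits: torch.Tensor) -> Categorical:
    # Discrete case: a single tensor of logits per batch element.
    return Categorical(logits=logits)


def gaussian_dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Independent:
    # Continuous case: a (loc, scale) pair, e.g. as produced by a probabilistic actor.
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)


print(categorical_dist_fn(torch.zeros(4, 3)).sample().shape)                   # torch.Size([4])
print(gaussian_dist_fn((torch.zeros(4, 2), torch.ones(4, 2))).sample().shape)  # torch.Size([4, 2])
```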

View File

@ -19,7 +19,7 @@ from tianshou.highlevel.params.dist_fn import (
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils import MultipleLRSchedulers from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin from tianshou.utils.string import ToStringMixin
@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation. whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
Does not affect training. Does not affect training.
""" """
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default" dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
""" """
This can either be a function which maps the model output to a torch distribution or a This can either be a function which maps the model output to a torch distribution or a
factory for the creation of such a function. factory for the creation of such a function.

View File

@ -18,6 +18,7 @@ from tianshou.data.batch import Batch, BatchProtocol, arr_type
from tianshou.data.buffer.base import TBuffer from tianshou.data.buffer.base import TBuffer
from tianshou.data.types import ( from tianshou.data.types import (
ActBatchProtocol, ActBatchProtocol,
ActStateBatchProtocol,
BatchWithReturnsProtocol, BatchWithReturnsProtocol,
ObsBatchProtocol, ObsBatchProtocol,
RolloutBatchProtocol, RolloutBatchProtocol,
@ -212,10 +213,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
super().__init__() super().__init__()
self.observation_space = observation_space self.observation_space = observation_space
self.action_space = action_space self.action_space = action_space
self._action_type: Literal["discrete", "continuous"]
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary): if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
self.action_type = "discrete" self._action_type = "discrete"
elif isinstance(action_space, Box): elif isinstance(action_space, Box):
self.action_type = "continuous" self._action_type = "continuous"
else: else:
raise ValueError(f"Unsupported action space: {action_space}.") raise ValueError(f"Unsupported action space: {action_space}.")
self.agent_id = 0 self.agent_id = 0
@ -225,6 +227,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
self.lr_scheduler = lr_scheduler self.lr_scheduler = lr_scheduler
self._compile() self._compile()
@property
def action_type(self) -> Literal["discrete", "continuous"]:
return self._action_type
def set_agent_id(self, agent_id: int) -> None: def set_agent_id(self, agent_id: int) -> None:
"""Set self.agent_id = agent_id, for MARL.""" """Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id self.agent_id = agent_id
@ -233,11 +239,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
# have a method to add noise to action. # have a method to add noise to action.
# So we add the default behavior here. It's a little messy, maybe one can # So we add the default behavior here. It's a little messy, maybe one can
# find a better way to do this. # find a better way to do this.
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise( def exploration_noise(
self, self,
act: np.ndarray | BatchProtocol, act: _TArrOrActBatch,
batch: RolloutBatchProtocol, batch: ObsBatchProtocol,
) -> np.ndarray | BatchProtocol: ) -> _TArrOrActBatch:
"""Modify the action from policy.forward with exploration noise. """Modify the action from policy.forward with exploration noise.
NOTE: currently does not add any noise! Needs to be overridden by subclasses NOTE: currently does not add any noise! Needs to be overridden by subclasses
@ -287,7 +296,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
batch: ObsBatchProtocol, batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None, state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> ActBatchProtocol: ) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing
"""Compute action over the given batch data. """Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys: :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
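
The new `_TArrOrActBatch` TypeVar makes `exploration_noise` preserve the type of the incoming action (ndarray in, ndarray out; `ActBatchProtocol` in, same protocol out), and the second argument is now typed as an observation batch rather than a rollout batch. A signature-only sketch of the default no-op behaviour (written as a free function for brevity):

```python
from typing import TypeVar

import numpy as np

from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol

_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")


def exploration_noise(act: _TArrOrActBatch, batch: ObsBatchProtocol) -> _TArrOrActBatch:
    # BasePolicy's default adds no noise; subclasses such as DQN override this.
    return act
```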

View File

@ -16,6 +16,12 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
@dataclass(kw_only=True) @dataclass(kw_only=True)
class ImitationTrainingStats(TrainingStats): class ImitationTrainingStats(TrainingStats):
@ -72,9 +78,20 @@ class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTra
state: dict | BatchProtocol | np.ndarray | None = None, state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> ModelOutputBatchProtocol: ) -> ModelOutputBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) # TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits if self.action_type == "discrete":
result = Batch(logits=logits, act=act, state=hidden) # If it's discrete, the "actor" is usually a critic that maps obs to action_values
# which then could be turned into logits or a Categorical
action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
act_B = action_values_BA.argmax(dim=1)
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
elif self.action_type == "continuous":
# If it's continuous, the actor would usually deliver something like loc, scale determining a
# Gaussian dist
dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
else:
raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
return cast(ModelOutputBatchProtocol, result) return cast(ModelOutputBatchProtocol, result)
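
A shape-only illustration of the B/A/D/H suffix convention introduced above (all sizes are made up; no Tianshou modules are involved):

```python
import torch

B, A, H = 8, 4, 16                      # batch, action and hidden sizes
action_values_BA = torch.randn(B, A)    # discrete case: one value per action
act_B = action_values_BA.argmax(dim=1)  # chosen action per batch element
dist_input_BD = torch.randn(B, 2)       # continuous case: raw dist inputs (D is usually 2)
hidden_BH = torch.zeros(B, H)           # optional recurrent state
print(act_B.shape, dist_input_BD.shape, hidden_BH.shape)
```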
def learn( def learn(

View File

@ -34,8 +34,7 @@ TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteB
class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]): class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708. """Implementation of discrete BCQ algorithm. arXiv:1910.01708.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
:param imitator: a model following the rules in :param imitator: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits) :class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.

View File

@ -25,8 +25,7 @@ TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteC
class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]): class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]):
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779. """Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space. :param action_space: Env's action space.
:param min_q_weight: the weight for the cql loss. :param min_q_weight: the weight for the cql loss.

View File

@ -11,6 +11,7 @@ from tianshou.data import to_torch, to_torch_as
from tianshou.data.types import RolloutBatchProtocol from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass @dataclass
@ -26,8 +27,9 @@ TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteC
class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]): class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134. r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
:param actor: the actor network following the rules in :param actor: the actor network following the rules:
:class:`~tianshou.policy.BasePolicy`. (s -> logits) If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the action-value critic (i.e., Q function) :param critic: the action-value critic (i.e., Q function)
network. (s -> Q(s, \*)) network. (s -> Q(s, \*))
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | Actor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete, action_space: gym.spaces.Discrete,
discount_factor: float = 0.99, discount_factor: float = 0.99,

View File

@ -15,8 +15,11 @@ from tianshou.data import (
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import PPOPolicy from tianshou.policy import PPOPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.policy.modelfree.ppo import PPOTrainingStats from tianshou.policy.modelfree.ppo import PPOTrainingStats
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -32,7 +35,9 @@ TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats)
class GAILPolicy(PPOPolicy[TGailTrainingStats]): class GAILPolicy(PPOPolicy[TGailTrainingStats]):
r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476. r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
expert_buffer: ReplayBuffer, expert_buffer: ReplayBuffer,
disc_net: torch.nn.Module, disc_net: torch.nn.Module,

View File

@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]):
"""Implementation of TD3+BC. arXiv:2106.06860. """Implementation of TD3+BC. arXiv:2106.06860.
:param actor: the actor network following the rules in :param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits) :class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Self from typing import Any, Literal, Self, TypeVar
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
@ -105,11 +105,13 @@ class ICMPolicy(BasePolicy[ICMTrainingStats]):
""" """
return self.policy.forward(batch, state, **kwargs) return self.policy.forward(batch, state, **kwargs)
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise( def exploration_noise(
self, self,
act: np.ndarray | BatchProtocol, act: _TArrOrActBatch,
batch: RolloutBatchProtocol, batch: ObsBatchProtocol,
) -> np.ndarray | BatchProtocol: ) -> _TArrOrActBatch:
return self.policy.exploration_noise(act, batch) return self.policy.exploration_noise(act, batch)
def set_eps(self, eps: float) -> None: def set_eps(self, eps: float) -> None:

View File

@ -11,8 +11,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import PGPolicy from tianshou.policy import PGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -30,7 +33,9 @@ TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats)
class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var] class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var]
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783. """Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
vf_coef: float = 0.5, vf_coef: float = 0.5,
ent_coef: float = 0.01, ent_coef: float = 0.01,

View File

@ -8,6 +8,7 @@ import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as
from tianshou.data.batch import BatchProtocol from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ( from tianshou.data.types import (
ActBatchProtocol,
BatchWithReturnsProtocol, BatchWithReturnsProtocol,
ModelOutputBatchProtocol, ModelOutputBatchProtocol,
ObsBatchProtocol, ObsBatchProtocol,
@ -30,7 +31,7 @@ TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats)
class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]): class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
"""Implementation of the Branching dual Q network arXiv:1711.08946. """Implementation of the Branching dual Q network arXiv:1711.08946.
:param model: BranchingNet mapping (obs, state, info) -> logits. :param model: BranchingNet mapping (obs, state, info) -> action_values_BA.
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead. :param estimation_step: the number of steps to look ahead.
@ -155,10 +156,10 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
model = getattr(self, model) model = getattr(self, model)
obs = batch.obs obs = batch.obs
# TODO: this is very contrived, see also iqn.py # TODO: this is very contrived, see also iqn.py
obs_next = obs.obs if hasattr(obs, "obs") else obs obs_next_BO = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info) action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info)
act = to_numpy(logits.max(dim=-1)[1]) act_B = to_numpy(action_values_BA.argmax(dim=-1))
result = Batch(logits=logits, act=act, state=hidden) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result) return cast(ModelOutputBatchProtocol, result)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats: def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
@ -182,11 +183,13 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value] return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value]
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise( def exploration_noise(
self, self,
act: np.ndarray | BatchProtocol, act: _TArrOrActBatch,
batch: RolloutBatchProtocol, batch: ObsBatchProtocol,
) -> np.ndarray | BatchProtocol: ) -> _TArrOrActBatch:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act) bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps rand_mask = np.random.rand(bsz) < self.eps

View File

@ -23,8 +23,7 @@ TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats)
class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]): class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
"""Implementation of Categorical Deep Q-Network. arXiv:1707.06887. """Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param num_atoms: the number of atoms in the support set of the :param num_atoms: the number of atoms in the support set of the

View File

@ -10,6 +10,7 @@ import torch
from tianshou.data import Batch, ReplayBuffer from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ( from tianshou.data.types import (
ActBatchProtocol,
ActStateBatchProtocol, ActStateBatchProtocol,
BatchWithReturnsProtocol, BatchWithReturnsProtocol,
ObsBatchProtocol, ObsBatchProtocol,
@ -18,6 +19,7 @@ from tianshou.data.types import (
from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.continuous import Actor, Critic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -32,8 +34,7 @@ TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats)
class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]): class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971. """Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
:param actor: The actor network following the rules in :param actor: The actor network following the rules (s -> actions)
:class:`~tianshou.policy.BasePolicy`. (s -> model_output)
:param actor_optim: The optimizer for actor network. :param actor_optim: The optimizer for actor network.
:param critic: The critic network. (s, a -> Q(s, a)) :param critic: The critic network. (s, a -> Q(s, a))
:param critic_optim: The optimizer for critic network. :param critic_optim: The optimizer for critic network.
@ -59,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
action_space: gym.Space, action_space: gym.Space,
tau: float = 0.005, tau: float = 0.005,
@ -208,11 +209,13 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value] return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value]
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise( def exploration_noise(
self, self,
act: np.ndarray | BatchProtocol, act: _TArrOrActBatch,
batch: RolloutBatchProtocol, batch: ObsBatchProtocol,
) -> np.ndarray | BatchProtocol: ) -> _TArrOrActBatch:
if self._exploration_noise is None: if self._exploration_noise is None:
return act return act
if isinstance(act, np.ndarray): if isinstance(act, np.ndarray):

View File

@ -8,11 +8,11 @@ from overrides import override
from torch.distributions import Categorical from torch.distributions import Categorical
from tianshou.data import Batch, ReplayBuffer, to_torch from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data.batch import BatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import SACPolicy from tianshou.policy import SACPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.sac import SACTrainingStats from tianshou.policy.modelfree.sac import SACTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass @dataclass
@ -26,8 +26,7 @@ TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteS
class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]): class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
"""Implementation of SAC for Discrete Action Settings. arXiv:1910.07207. """Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
:param actor: the actor network following the rules in :param actor: the actor network following the rules (s_B -> dist_input_BD)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.
@ -55,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete, action_space: gym.spaces.Discrete,
critic2: torch.nn.Module | None = None, critic2: torch.nn.Module | Critic | None = None,
critic2_optim: torch.optim.Optimizer | None = None, critic2_optim: torch.optim.Optimizer | None = None,
tau: float = 0.005, tau: float = 0.005,
gamma: float = 0.99, gamma: float = 0.99,
@ -106,13 +105,13 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
state: dict | Batch | np.ndarray | None = None, state: dict | Batch | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> Batch: ) -> Batch:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits) dist = Categorical(logits=logits_BA)
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.sample() act_B = dist.sample()
return Batch(logits=logits, act=act, state=hidden, dist=dist) return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch( obs_next_batch = Batch(
@ -184,9 +183,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(), alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(),
) )
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise( def exploration_noise(
self, self,
act: np.ndarray | BatchProtocol, act: _TArrOrActBatch,
batch: RolloutBatchProtocol, batch: ObsBatchProtocol,
) -> np.ndarray | BatchProtocol: ) -> _TArrOrActBatch:
return act return act
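
The deterministic-eval branch above now relies on the distribution's `.mode` instead of an explicit argmax over logits. A quick check with dummy logits:

```python
import torch
from torch.distributions import Categorical

logits_BA = torch.tensor([[0.1, 2.0, -1.0], [0.5, 0.2, 3.0]])
dist = Categorical(logits=logits_BA)
print(dist.mode)      # tensor([1, 2]) - greedy action used when deterministic_eval is on
print(dist.sample())  # stochastic action used during training
```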

View File

@ -9,6 +9,7 @@ import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ( from tianshou.data.types import (
ActBatchProtocol,
BatchWithReturnsProtocol, BatchWithReturnsProtocol,
ModelOutputBatchProtocol, ModelOutputBatchProtocol,
ObsBatchProtocol, ObsBatchProtocol,
@ -16,6 +17,7 @@ from tianshou.data.types import (
) )
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.common import Net
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -34,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
implemented in the network side, not here). implemented in the network side, not here).
:param model: a model following the rules in :param model: a model following the rules (s -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead. :param estimation_step: the number of steps to look ahead.
@ -59,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
model: torch.nn.Module, model: torch.nn.Module | Net,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
# TODO: type violates Liskov substitution principle # TODO: type violates Liskov substitution principle
action_space: gym.spaces.Discrete, action_space: gym.spaces.Discrete,
@ -200,12 +201,12 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
obs = batch.obs obs = batch.obs
# TODO: this is convoluted! See also other places where this is done. # TODO: this is convoluted! See also other places where this is done.
obs_next = obs.obs if hasattr(obs, "obs") else obs obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info) action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None)) q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None))
if self.max_action_num is None: if self.max_action_num is None:
self.max_action_num = q.shape[1] self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1]) act_B = to_numpy(q.argmax(dim=1))
result = Batch(logits=logits, act=act, state=hidden) result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result) return cast(ModelOutputBatchProtocol, result)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats: def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
@ -232,11 +233,13 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value] return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value]
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise( def exploration_noise(
self, self,
act: np.ndarray | BatchProtocol, act: _TArrOrActBatch,
batch: RolloutBatchProtocol, batch: ObsBatchProtocol,
) -> np.ndarray | BatchProtocol: ) -> _TArrOrActBatch:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0): if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act) bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps rand_mask = np.random.rand(bsz) < self.eps
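The epsilon-greedy branch in this hunk can be illustrated with a small self-contained numpy sketch; the values of eps and max_action_num are made up, and this is a simplified illustration rather than the policy's exact implementation.

import numpy as np

eps, max_action_num = 0.1, 4                            # illustrative values
act = np.array([2, 0, 3, 1])                            # greedy actions for a batch of size 4
bsz = len(act)
rand_mask = np.random.rand(bsz) < eps                   # entries selected for exploration
rand_act = np.random.randint(max_action_num, size=bsz)  # uniformly random replacement actions
act = np.where(rand_mask, rand_act, act)                # epsilon-greedy actions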

View File

@ -27,8 +27,7 @@ TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats)
class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]): class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140. """Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param fraction_model: a FractionProposalNetwork for :param fraction_model: a FractionProposalNetwork for
proposing fractions/quantiles given state. proposing fractions/quantiles given state.

View File

@ -29,8 +29,7 @@ TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats)
class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]): class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
"""Implementation of Implicit Quantile Network. arXiv:1806.06923. """Implementation of Implicit Quantile Network. arXiv:1806.06923.
:param model: a model following the rules in :param model: a model following the rules (s_B -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].
:param sample_size: the number of samples for policy evaluation. :param sample_size: the number of samples for policy evaluation.

View File

@ -12,7 +12,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
optim_critic_iters: int = 5, optim_critic_iters: int = 5,
actor_step_size: float = 0.5, actor_step_size: float = 0.5,

View File

@ -1,7 +1,7 @@
import warnings import warnings
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast from typing import Any, Generic, Literal, TypeVar, cast
import gymnasium as gym import gymnasium as gym
import numpy as np import numpy as np
@ -24,9 +24,22 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils import RunningMeanStd from tianshou.utils import RunningMeanStd
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.discrete import Actor
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution] # Dimension Naming Convention
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution] # B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
TDistFnContinuous = Callable[
[tuple[torch.Tensor, torch.Tensor]],
torch.distributions.Distribution,
]
TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical]
TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -40,8 +53,9 @@ TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats)
class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]): class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
"""Implementation of REINFORCE algorithm. """Implementation of REINFORCE algorithm.
:param actor: mapping (s->model_output), should follow the rules in :param actor: the actor network following the rules:
:class:`~tianshou.policy.BasePolicy`. If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param optim: optimizer for actor network. :param optim: optimizer for actor network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
Maps model_output -> distribution. Typically a Gaussian distribution Maps model_output -> distribution. Typically a Gaussian distribution
@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | Actor,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
discount_factor: float = 0.99, discount_factor: float = 0.99,
# TODO: rename to return_normalization? # TODO: rename to return_normalization?
@ -175,20 +189,20 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation. more detailed explanation.
""" """
# TODO: rename? It's not really logits and there are particular # TODO - ALGO: marked for algorithm refactoring
# assumptions about the order of the output and on distribution type action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) # in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
if isinstance(logits, tuple): # therefore action_dist_input_BD is equivalent to logits_BA
dist = self.dist_fn(*logits) # If continuous, dist_fn will typically map loc, scale to a distribution (usually a Gaussian)
else: # the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
dist = self.dist_fn(logits) dist = self.dist_fn(action_dist_input_BD)
# in this case, the dist is unused!
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.sample() act_B = dist.sample()
result = Batch(logits=logits, act=act, state=hidden, dist=dist) # act is of dimension BA in continuous case and of dimension B in discrete
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
return cast(DistBatchProtocol, result) return cast(DistBatchProtocol, result)
# TODO: why does mypy complain? # TODO: why does mypy complain?

View File

@ -10,8 +10,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -29,7 +32,9 @@ TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var] class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var]
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347. r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
eps_clip: float = 0.2, eps_clip: float = 0.2,
dual_clip: float | None = None, dual_clip: float | None = None,

View File

@ -25,8 +25,7 @@ TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats)
class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]): class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
"""Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044. """Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
:param model: a model following the rules in :param model: a model following the rules (s -> action_values_BA)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param optim: a torch.optim for optimizing the model. :param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space. :param action_space: Env's action space.
:param discount_factor: in [0, 1]. :param discount_factor: in [0, 1].

View File

@ -12,6 +12,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.ddpg import DDPGTrainingStats from tianshou.policy.modelfree.ddpg import DDPGTrainingStats
from tianshou.utils.net.continuous import ActorProb
@dataclass @dataclass
@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
@ -150,23 +151,28 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
state: dict | Batch | np.ndarray | None = None, state: dict | Batch | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> Batch: ) -> Batch:
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info) (loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale dist = Independent(Normal(loc_B, scale_B), 1)
dist = Independent(Normal(loc, scale), 1)
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.rsample() act_B = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1) log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian # apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation. # in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act) squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1, -1,
keepdim=True, keepdim=True,
) )
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob) return Batch(
logits=(loc_B, scale_B),
act=squashed_action,
state=h_BH,
dist=dist,
log_prob=log_prob,
)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor: def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch( obs_next_batch = Batch(

View File

@ -17,6 +17,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.conversion import to_optional_float from tianshou.utils.conversion import to_optional_float
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.optim import clone_optimizer from tianshou.utils.optim import clone_optimizer
@ -36,8 +37,7 @@ TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats)
class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var] class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var]
"""Implementation of Soft Actor-Critic. arXiv:1812.05905. """Implementation of Soft Actor-Critic. arXiv:1812.05905.
:param actor: the actor network following the rules in :param actor: the actor network following the rules (s -> dist_input_BD)
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.
@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer, actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module, critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer, critic_optim: torch.optim.Optimizer,
@ -173,26 +173,25 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
state: dict | Batch | np.ndarray | None = None, state: dict | Batch | np.ndarray | None = None,
**kwargs: Any, **kwargs: Any,
) -> DistLogProbBatchProtocol: ) -> DistLogProbBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info) (loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple) dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
dist = Independent(Normal(*logits), 1)
if self.deterministic_eval and not self.training: if self.deterministic_eval and not self.training:
act = dist.mode act_B = dist.mode
else: else:
act = dist.rsample() act_B = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1) log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian # apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21. # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation. # in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act) squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum( log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1, -1,
keepdim=True, keepdim=True,
) )
result = Batch( result = Batch(
logits=logits, logits=(loc_B, scale_B),
act=squashed_action, act=squashed_action,
state=hidden, state=hidden_BH,
dist=dist, dist=dist,
log_prob=log_prob, log_prob=log_prob,
) )
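The tanh-squashing correction applied to log_prob above (SAC paper arXiv:1801.01290, Eq. 21 in appendix C) can be reproduced in isolation with the following sketch; loc_B, scale_B and eps are placeholders standing in for the actor outputs and the policy's internal numerical-stability constant.

import torch
from torch.distributions import Independent, Normal

loc_B, scale_B = torch.zeros(4, 2), torch.ones(4, 2)  # stand-ins for actor outputs (B=4, action dim 2)
dist = Independent(Normal(loc_B, scale_B), 1)
act_B = dist.rsample()                                # reparameterized sample in pre-tanh space
log_prob = dist.log_prob(act_B).unsqueeze(-1)
squashed_action = torch.tanh(act_B)
eps = 1e-6                                            # small constant, analogous to the policy's __eps
log_prob = log_prob - torch.log(1 - squashed_action.pow(2) + eps).sum(-1, keepdim=True)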

View File

@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t
"""Implementation of TD3, arXiv:1802.09477. """Implementation of TD3, arXiv:1802.09477.
:param actor: the actor network following the rules in :param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits) :class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network. :param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a)) :param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network. :param critic_optim: the optimizer for the first critic network.

View File

@ -11,7 +11,10 @@ from tianshou.data import Batch, SequenceSummaryStats
from tianshou.policy import NPGPolicy from tianshou.policy import NPGPolicy
from tianshou.policy.base import TLearningRateScheduler from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.npg import NPGTrainingStats from tianshou.policy.modelfree.npg import NPGTrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -25,7 +28,9 @@ TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats)
class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]): class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477. """Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
:param actor: the actor network following the rules in BasePolicy. (s -> logits) :param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s)) :param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network. :param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action. :param dist_fn: distribution class for computing the action.
@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
def __init__( def __init__(
self, self,
*, *,
actor: torch.nn.Module, actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module, critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer, optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction, dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space, action_space: gym.Space,
max_kl: float = 0.01, max_kl: float = 0.01,
backtrack_coeff: float = 0.8, backtrack_coeff: float = 0.8,

View File

@ -1,11 +1,11 @@
from typing import Any, Literal, Protocol, Self, cast, overload from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload
import numpy as np import numpy as np
from overrides import override from overrides import override
from tianshou.data import Batch, ReplayBuffer from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol, IndexType from tianshou.data.batch import BatchProtocol, IndexType
from tianshou.data.types import RolloutBatchProtocol from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats from tianshou.policy.base import TLearningRateScheduler, TrainingStats
@ -160,16 +160,18 @@ class MultiAgentPolicyManager(BasePolicy):
buffer._meta.rew = save_rew buffer._meta.rew = save_rew
return Batch(results) return Batch(results)
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise( def exploration_noise(
self, self,
act: np.ndarray | BatchProtocol, act: _TArrOrActBatch,
batch: RolloutBatchProtocol, batch: ObsBatchProtocol,
) -> np.ndarray | BatchProtocol: ) -> _TArrOrActBatch:
"""Add exploration noise from sub-policy onto act.""" """Add exploration noise from sub-policy onto act."""
assert isinstance( if not isinstance(batch.obs, Batch):
batch.obs, raise TypeError(
BatchProtocol, f"here only observations of type Batch are permitted, but got {type(batch.obs)}",
), f"here only observations of type Batch are permitted, but got {type(batch.obs)}" )
for agent_id, policy in self.policies.items(): for agent_id, policy in self.policies.items():
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0] agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0: if len(agent_index) == 0:
@ -223,7 +225,7 @@ class MultiAgentPolicyManager(BasePolicy):
results.append((False, np.array([-1]), Batch(), Batch(), Batch())) results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
continue continue
tmp_batch = batch[agent_index] tmp_batch = batch[agent_index]
if isinstance(tmp_batch.rew, np.ndarray): if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray. # reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]] tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
if not hasattr(tmp_batch.obs, "mask"): if not hasattr(tmp_batch.obs, "mask"):

View File

@ -237,7 +237,13 @@ class BaseTrainer(ABC):
self.stop_fn_flag = False self.stop_fn_flag = False
self.iter_num = 0 self.iter_num = 0
def reset(self) -> None: def _reset_collectors(self, reset_buffer: bool = False) -> None:
if self.train_collector is not None:
self.train_collector.reset(reset_buffer=reset_buffer)
if self.test_collector is not None:
self.test_collector.reset(reset_buffer=reset_buffer)
def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None:
"""Initialize or reset the instance to yield a new iterator from zero.""" """Initialize or reset the instance to yield a new iterator from zero."""
self.is_run = False self.is_run = False
self.env_step = 0 self.env_step = 0
@ -250,16 +256,18 @@ class BaseTrainer(ABC):
self.last_rew, self.last_len = 0.0, 0.0 self.last_rew, self.last_len = 0.0, 0.0
self.start_time = time.time() self.start_time = time.time()
if self.train_collector is not None:
self.train_collector.reset_stat()
if self.train_collector.policy != self.policy or self.test_collector is None: if reset_collectors:
self._reset_collectors(reset_buffer=reset_buffer)
if self.train_collector is not None and (
self.train_collector.policy != self.policy or self.test_collector is None
):
self.test_in_train = False self.test_in_train = False
if self.test_collector is not None: if self.test_collector is not None:
assert self.episode_per_test is not None assert self.episode_per_test is not None
assert not isinstance(self.test_collector, AsyncCollector) # Issue 700 assert not isinstance(self.test_collector, AsyncCollector) # Issue 700
self.test_collector.reset_stat()
test_result = test_episode( test_result = test_episode(
self.policy, self.policy,
self.test_collector, self.test_collector,
@ -284,7 +292,7 @@ class BaseTrainer(ABC):
self.iter_num = 0 self.iter_num = 0
def __iter__(self): # type: ignore def __iter__(self): # type: ignore
self.reset() self.reset(reset_collectors=True, reset_buffer=False)
return self return self
def __next__(self) -> EpochStats: def __next__(self) -> EpochStats:
@ -308,8 +316,8 @@ class BaseTrainer(ABC):
# perform n step_per_epoch # perform n step_per_epoch
with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t: with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
while t.n < t.total and not self.stop_fn_flag:
train_stat: CollectStatsBase train_stat: CollectStatsBase
while t.n < t.total and not self.stop_fn_flag:
if self.train_collector is not None: if self.train_collector is not None:
train_stat, self.stop_fn_flag = self.train_step() train_stat, self.stop_fn_flag = self.train_step()
pbar_data_dict = { pbar_data_dict = {
@ -515,12 +523,14 @@ class BaseTrainer(ABC):
stats of the whole dataset stats of the whole dataset
""" """
def run(self) -> InfoStats: def run(self, reset_prior_to_run: bool = True) -> InfoStats:
"""Consume iterator. """Consume iterator.
See itertools - recipes. Use functions that consume iterators at C speed See itertools - recipes. Use functions that consume iterators at C speed
(feed the entire iterator into a zero-length deque). (feed the entire iterator into a zero-length deque).
""" """
if reset_prior_to_run:
self.reset()
try: try:
self.is_run = True self.is_run = True
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
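Together with the reset changes above, the new reset_prior_to_run flag makes it possible, for example, to resume iteration without wiping collector state; the sketch below assumes `trainer` is an already constructed trainer instance and only illustrates the call pattern.

def run_twice(trainer):
    # default: run() resets the collectors (but not their buffers) before consuming the iterator
    first_info = trainer.run()
    # skip the reset to continue from the collectors' current state
    second_info = trainer.run(reset_prior_to_run=False)
    return first_info, second_info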

View File

@ -26,8 +26,7 @@ def test_episode(
reward_metric: Callable[[np.ndarray], np.ndarray] | None = None, reward_metric: Callable[[np.ndarray], np.ndarray] | None = None,
) -> CollectStats: ) -> CollectStats:
"""A simple wrapper of testing policy in collector.""" """A simple wrapper of testing policy in collector."""
collector.reset_env() collector.reset(reset_stats=False)
collector.reset_buffer()
policy.eval() policy.eval()
if test_fn: if test_fn:
test_fn(epoch, global_step) test_fn(epoch, global_step)

View File

@ -610,6 +610,17 @@ class BaseActor(nn.Module, ABC):
def get_output_dim(self) -> int: def get_output_dim(self) -> int:
pass pass
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[Any, Any]:
# TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction.
# Return type needs to be more specific
pass
def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T: def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
"""Gets the given attribute from the given object or takes the alternative value if it is not present. """Gets the given attribute from the given object or takes the alternative value if it is not present.

View File

@ -1,4 +1,5 @@
import warnings import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any from typing import Any
@ -9,6 +10,7 @@ from torch import nn
from tianshou.utils.net.common import ( from tianshou.utils.net.common import (
MLP, MLP,
BaseActor, BaseActor,
Net,
TActionShape, TActionShape,
TLinearLayer, TLinearLayer,
get_output_dim, get_output_dim,
@ -19,33 +21,27 @@ SIGMA_MAX = 2
class Actor(BaseActor): class Actor(BaseActor):
"""Simple actor network. """Simple actor network that directly outputs actions for continuous action space.
Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`.
It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape. It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.
:param preprocess_net: a self-defined preprocess_net which output a :param preprocess_net: a self-defined preprocess_net, see usage.
flattened hidden state. Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action. :param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net. preprocess_net.
:param max_action: the scale for the final action.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
action_shape: TActionShape, action_shape: TActionShape,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
max_action: float = 1.0, max_action: float = 1.0,
@ -77,42 +73,50 @@ class Actor(BaseActor):
state: Any = None, state: Any = None,
info: dict[str, Any] | None = None, info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]: ) -> tuple[torch.Tensor, Any]:
"""Mapping: obs -> logits -> action.""" """Mapping: s_B -> action_values_BA, hidden_state_BH | None.
if info is None:
info = {} Returns a tensor representing the actions directly, i.e., of shape
logits, hidden = self.preprocess(obs, state) `(n_actions, )`, and a hidden state (which may be None).
logits = self.max_action * torch.tanh(self.last(logits)) The hidden state is only not None if a recurrent net is used as part of the
return logits, hidden learning algorithm (support for RNNs is currently experimental).
"""
action_BA, hidden_BH = self.preprocess(obs, state)
action_BA = self.max_action * torch.tanh(self.last(action_BA))
return action_BA, hidden_BH
class Critic(nn.Module): class CriticBase(nn.Module, ABC):
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
class Critic(CriticBase):
"""Simple critic network. """Simple critic network.
It will create a critic operated in continuous action space with structure of preprocess_net ---> 1(q value). It will create a critic operated in continuous action space with structure of preprocess_net ---> 1(q value).
:param preprocess_net: a self-defined preprocess_net which output a :param preprocess_net: a self-defined preprocess_net, see usage.
flattened hidden state. Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param preprocess_net_output_dim: the output dimension of
preprocess_net. preprocess_net.
:param linear_layer: use this module as linear layer. Default to nn.Linear. :param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
:param linear_layer: use this module as linear layer.
:param flatten_input: whether to flatten input data for the last layer. :param flatten_input: whether to flatten input data for the last layer.
Default to True.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
device: str | int | torch.device = "cpu", device: str | int | torch.device = "cpu",
preprocess_net_output_dim: int | None = None, preprocess_net_output_dim: int | None = None,
@ -139,9 +143,7 @@ class Critic(nn.Module):
act: np.ndarray | torch.Tensor | None = None, act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None, info: dict[str, Any] | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a).""" """Mapping: (s_B, a_B) -> Q(s, a)_B."""
if info is None:
info = {}
obs = torch.as_tensor( obs = torch.as_tensor(
obs, obs,
device=self.device, device=self.device,
@ -154,41 +156,35 @@ class Critic(nn.Module):
dtype=torch.float32, dtype=torch.float32,
).flatten(1) ).flatten(1)
obs = torch.cat([obs, act], dim=1) obs = torch.cat([obs, act], dim=1)
logits, hidden = self.preprocess(obs) values_B, hidden_BH = self.preprocess(obs)
return self.last(logits) return self.last(values_B)
class ActorProb(BaseActor): class ActorProb(BaseActor):
"""Simple actor network (output with a Gauss distribution). """Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian).
:param preprocess_net: a self-defined preprocess_net which output a Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`.
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action. :param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param unbounded: whether to apply tanh activation on final logits.
Default to False.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter. Default to False.
:param preprocess_net_output_dim: the output dimension of
preprocess_net. preprocess_net.
:param max_action: the scale for the final action logits.
:param unbounded: whether to apply tanh activation on final logits.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
# TODO: force kwargs, adjust downstream code # TODO: force kwargs, adjust downstream code
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
action_shape: TActionShape, action_shape: TActionShape,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
max_action: float = 1.0, max_action: float = 1.0,
@ -402,8 +398,7 @@ class Perturbation(nn.Module):
flattened hidden state. flattened hidden state.
:param max_action: the maximum value of each dimension of action. :param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on. :param device: which device to create this model on.
Default to cpu. :param phi: max perturbation parameter for BCQ.
:param phi: max perturbation parameter for BCQ. Default to 0.05.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
@ -449,7 +444,6 @@ class VAE(nn.Module):
:param latent_dim: the size of latent layer. :param latent_dim: the size of latent layer.
:param max_action: the maximum value of each dimension of action. :param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on. :param device: which device to create this model on.
Default to "cpu".
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
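As a usage note for the continuous heads documented in this file, the typical preprocess_net wiring looks roughly like the sketch below; the keyword names are assumed from common Tianshou usage rather than taken from this diff.

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

obs_dim, act_dim = 8, 2
actor_preprocess = Net(state_shape=obs_dim, hidden_sizes=[64, 64])
actor = ActorProb(actor_preprocess, action_shape=act_dim, unbounded=True)
critic_preprocess = Net(state_shape=obs_dim, action_shape=act_dim, hidden_sizes=[64, 64], concat=True)
critic = Critic(critic_preprocess)

obs_B = torch.zeros(4, obs_dim)
(loc_B, scale_B), hidden = actor(obs_B)       # s_B -> dist_input_BD
q_B = critic(obs_B, torch.zeros(4, act_dim))  # (s_B, a_B) -> Q(s, a)_B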

View File

@ -7,17 +7,14 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from tianshou.data import Batch, to_torch from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim
class Actor(BaseActor): class Actor(BaseActor):
"""Simple actor network. """Simple actor network for discrete action spaces.
Will create an actor operated in discrete action space with structure of :param preprocess_net: a self-defined preprocess_net. Typically, an instance of
preprocess_net ---> action_shape. :class:`~tianshou.utils.net.common.Net`.
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param action_shape: a sequence of int for the shape of action. :param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains preprocess_net. Default to empty sequence (where the MLP now contains
@ -25,20 +22,15 @@ class Actor(BaseActor):
:param softmax_output: whether to apply a softmax layer over the last :param softmax_output: whether to apply a softmax layer over the last
layer's output. layer's output.
:param preprocess_net_output_dim: the output dimension of :param preprocess_net_output_dim: the output dimension of
preprocess_net. `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
action_shape: TActionShape, action_shape: TActionShape,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
softmax_output: bool = True, softmax_output: bool = True,
@ -71,43 +63,44 @@ class Actor(BaseActor):
obs: np.ndarray | torch.Tensor, obs: np.ndarray | torch.Tensor,
state: Any = None, state: Any = None,
info: dict[str, Any] | None = None, info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
r"""Mapping: s -> Q(s, \*).""" r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
if info is None:
info = {} Returns a tensor representing the values of each action, i.e., of shape
logits, hidden = self.preprocess(obs, state) `(n_actions, )`, and
logits = self.last(logits) a hidden state (which may be None). If `self.softmax_output` is True, they are the
probabilities for taking each action. Otherwise, they will be action values.
The hidden state is only
not None if a recurrent net is used as part of the learning algorithm.
"""
x, hidden_BH = self.preprocess(obs, state)
x = self.last(x)
if self.softmax_output: if self.softmax_output:
logits = F.softmax(logits, dim=-1) x = F.softmax(x, dim=-1)
return logits, hidden # If we computed softmax, output is probabilities, otherwise it's the non-normalized action values
output_BA = x
return output_BA, hidden_BH
class Critic(nn.Module): class Critic(nn.Module):
"""Simple critic network. """Simple critic network for discrete action spaces.
It will create an actor operated in discrete action space with structure of preprocess_net ---> 1(q value). :param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param hidden_sizes: a sequence of int for constructing the MLP after :param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer). only a single linear layer).
:param last_size: the output dimension of Critic network. Default to 1. :param last_size: the output dimension of Critic network. Default to 1.
:param preprocess_net_output_dim: the output dimension of :param preprocess_net_output_dim: the output dimension of
preprocess_net. `preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`. :ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
""" """
def __init__( def __init__(
self, self,
preprocess_net: nn.Module, preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (), hidden_sizes: Sequence[int] = (),
last_size: int = 1, last_size: int = 1,
preprocess_net_output_dim: int | None = None, preprocess_net_output_dim: int | None = None,
@ -120,8 +113,10 @@ class Critic(nn.Module):
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim) input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
# TODO: make a proper interface!
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor: def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Mapping: s -> V(s).""" """Mapping: s_B -> V(s)_B."""
# TODO: don't use this mechanism for passing state
logits, _ = self.preprocess(obs, state=kwargs.get("state", None)) logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits) return self.last(logits)
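For the discrete counterparts, a comparable hedged sketch of the mappings documented above (s_B -> output_BA for the actor, s_B -> V(s)_B for the critic); the keyword names again follow common Tianshou usage and are not taken from this diff.

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

obs_dim, n_actions = 4, 3
preprocess = Net(state_shape=obs_dim, hidden_sizes=[32])
actor = Actor(preprocess, action_shape=n_actions, softmax_output=True)
critic = Critic(preprocess, last_size=1)

obs_B = torch.zeros(5, obs_dim)
probs_BA, hidden = actor(obs_B)  # probabilities over actions because softmax_output=True
v_B = critic(obs_B)              # state values, shape (5, 1)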