Merge branch 'thuml_master' into feature/algo-eval

Maximilian Huettenrauch 2024-04-02 11:03:38 +02:00
commit f2e10b04bb
83 changed files with 1488 additions and 886 deletions

View File

@ -1,4 +1,38 @@
# Changelog
## Release 1.1.0
### Api Extensions
- Batch received two new methods: `to_dict` and `to_list_of_dicts` (see the sketch after this list). #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
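A minimal sketch of the two new `Batch` methods (the array values are purely illustrative):

```python
import numpy as np

from tianshou.data import Batch

b = Batch(a=np.array([1, 2]), nested=Batch(c=np.array([3.0, 4.0])))
b.to_dict()           # {'a': array([1, 2]), 'nested': {'c': array([3., 4.])}} - nested Batch becomes a nested dict
b.to_list_of_dicts()  # one dict per entry, roughly [{'a': 1, 'nested': {'c': 3.0}}, {'a': 2, 'nested': {'c': 4.0}}]
```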
### Internal Improvements
- `Collector`s rely less on state, the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for vars in `Collector`s. #1063
- Generally improved readability of Collector code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within Collector. #1063
- Better variable names related to model outputs (logits, dist input etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc.,
instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of distribution instead of relying on knowledge of the distribution type. #1032
### Breaking Changes
- Removed `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
explicitly or pass `reset_before_collect=True`. #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list. #1063
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
continuous and discrete cases (see the sketch after this list). #1032
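A rough sketch of the new single-argument `dist_fn`; the tensor shapes below are illustrative, not taken from this diff:

```python
import torch
from torch.distributions import Distribution, Independent, Normal


# dist_fn now receives the actor output as one argument, here a (loc, scale) tuple,
# instead of unpacked *logits
def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)


loc, scale = torch.zeros(4, 2), torch.ones(4, 2)  # dummy actor output: batch of 4, 2-dim actions
assert dist_fn((loc, scale)).sample().shape == (4, 2)
```

Similarly, for the collector change above, code that previously relied on an implicit environment reset now calls `collector.reset()` once before the first `collect()`, or passes `reset_before_collect=True` to `collect()`, as the updated tests and examples in this diff do.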
### Tests
- Fixed env seeding in `test_sac_with_il.py` so that the test doesn't fail randomly. #1081
Started after v1.0.0

View File

@ -147,9 +147,6 @@ Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://
Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
The Chinese documentation is available at [https://tianshou.readthedocs.io/zh/master/](https://tianshou.readthedocs.io/zh/master/).
<!-- A short Chinese introduction to the Tianshou platform: https://www.zhihu.com/question/377263715 -->
## Why Tianshou?

View File

@ -164,7 +164,7 @@
"source": [
"# Let's watch its performance!\n",
"policy.eval()\n",
"eval_result = test_collector.collect(n_episode=1, render=False)\n",
"eval_result = test_collector.collect(n_episode=3, render=False)\n",
"print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")"
]
},

View File

@ -69,7 +69,7 @@
"from tianshou.policy import BasePolicy\n",
"from tianshou.policy.modelfree.pg import (\n",
" PGTrainingStats,\n",
" TDistributionFunction,\n",
" TDistFnDiscrOrCont,\n",
" TPGTrainingStats,\n",
")\n",
"from tianshou.utils import RunningMeanStd\n",
@ -339,7 +339,7 @@
" *,\n",
" actor: torch.nn.Module,\n",
" optim: torch.optim.Optimizer,\n",
" dist_fn: TDistributionFunction,\n",
" dist_fn: TDistFnDiscrOrCont,\n",
" action_space: gym.Space,\n",
" discount_factor: float = 0.99,\n",
" observation_space: gym.Space | None = None,\n",

View File

@ -119,7 +119,7 @@
},
"outputs": [],
"source": [
"collect_result = test_collector.collect(n_episode=9)\n",
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n",
"\n",
"collect_result.pprint_asdict()"
]
@ -146,8 +146,7 @@
"outputs": [],
"source": [
"# Reset the collector\n",
"test_collector.reset()\n",
"collect_result = test_collector.collect(n_episode=9, random=True)\n",
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n",
"\n",
"collect_result.pprint_asdict()"
]

View File

@ -92,8 +92,6 @@ To compile the documentation into a webpage, run
The generated webpage is in ``docs/_build`` and can be viewed with a browser (http://0.0.0.0:8000/).
Chinese documentation is at https://tianshou.readthedocs.io/zh/latest/.
Documentation Generation Test
-----------------------------

View File

@ -57,9 +57,6 @@ Here are Tianshou's other features:
* Support multi-GPU training :ref:`multi_gpu`
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking
The Chinese documentation is available at `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_
Installation
------------

View File

@ -257,3 +257,8 @@ macOS
joblib
master
Panchenko
BA
BH
BO
BD

View File

@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
# expert replay buffer
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))

View File

@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: A2CPolicy = A2CPolicy(
actor=actor,

View File

@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy = NPGPolicy(
actor=actor,

View File

@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy = PPOPolicy(
actor=actor,

View File

@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PGPolicy = PGPolicy(
actor=actor,

View File

@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: TRPOPolicy = TRPOPolicy(
actor=actor,

View File

@ -171,6 +171,7 @@ ignore = [
"RET505",
"D106", # undocumented public nested class
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
"PLW2901", # overwrite vars in loop
]
unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all

View File

@ -9,13 +9,24 @@ import numpy as np
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple
class MyTestEnv(gym.Env):
"""A task for "going right". The task is to go right ``size`` steps."""
class MoveToRightEnv(gym.Env):
"""A task for "going right". The task is to go right ``size`` steps.
The observation is the current index, and the action is to go left or right.
Action 0 is to go left, and action 1 is to go right.
Taking action 0 at index 0 will keep the index at 0.
Arriving at index ``size`` means the task is done.
In the current implementation, stepping after the task is done is possible and will
lead to an index larger than ``size``.
Index 0 is the starting point. If reset is called with default options, the index will
be reset to 0.
"""
def __init__(
self,
size: int,
sleep: int = 0,
sleep: float = 0.0,
dict_state: bool = False,
recurse_state: bool = False,
ma_rew: int = 0,
@ -74,8 +85,13 @@ class MyTestEnv(gym.Env):
def reset(
self,
seed: int | None = None,
# TODO: passing a dict here doesn't make any sense
options: dict[str, Any] | None = None,
) -> tuple[dict[str, Any] | np.ndarray, dict]:
""":param seed:
:param options: the start index is provided in options["state"]
:return:
"""
if options is None:
options = {"state": 0}
super().reset(seed=seed)
@ -188,7 +204,7 @@ class NXEnv(gym.Env):
return self._encode_obs(), 1.0, False, False, {}
class MyGoalEnv(MyTestEnv):
class MyGoalEnv(MoveToRightEnv):
def __init__(self, *args: Any, **kwargs: Any) -> None:
assert (
kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0
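# A hypothetical usage sketch (not part of this diff) of MoveToRightEnv, following the behavior
# described in its docstring above; `sketch_env` and the concrete values are illustrative.
sketch_env = MoveToRightEnv(size=3)
obs, info = sketch_env.reset()  # start at index 0
obs, rew, terminated, truncated, info = sketch_env.step(1)  # go right -> index 1, no reward yet
obs, rew, terminated, truncated, info = sketch_env.step(0)  # go left -> back to index 0
obs, info = sketch_env.reset(options={"state": 2})  # the start index can be passed via options
obs, rew, terminated, truncated, info = sketch_env.step(1)  # index reaches size -> done, reward 1.0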

View File

@ -22,13 +22,13 @@ from tianshou.data import (
from tianshou.data.utils.converter import to_hdf5
if __name__ == "__main__":
from env import MyGoalEnv, MyTestEnv
from env import MoveToRightEnv, MyGoalEnv
else: # pytest
from test.base.env import MyGoalEnv, MyTestEnv
from test.base.env import MoveToRightEnv, MyGoalEnv
def test_replaybuffer(size=10, bufsize=20) -> None:
env = MyTestEnv(size)
env = MoveToRightEnv(size)
buf = ReplayBuffer(bufsize)
buf.update(buf)
assert str(buf) == buf.__class__.__name__ + "()"
@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None:
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
env = MyTestEnv(size)
env = MoveToRightEnv(size)
buf = ReplayBuffer(bufsize, stack_num=stack_num)
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
def test_priortized_replaybuffer(size=32, bufsize=15) -> None:
env = MyTestEnv(size)
env = MoveToRightEnv(size)
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
obs, info = env.reset()
@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None:
bufsize = 9
stack_num = 4
cached_num = 3
env = MyTestEnv(size)
env = MoveToRightEnv(size)
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
buf4 = CachedReplayBuffer(
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),

View File

@ -2,7 +2,6 @@ import gymnasium as gym
import numpy as np
import pytest
import tqdm
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import (
AsyncCollector,
@ -22,12 +21,12 @@ except ImportError:
envpool = None
if __name__ == "__main__":
from env import MyTestEnv, NXEnv
from env import MoveToRightEnv, NXEnv
else: # pytest
from test.base.env import MyTestEnv, NXEnv
from test.base.env import MoveToRightEnv, NXEnv
class MyPolicy(BasePolicy):
class MaxActionPolicy(BasePolicy):
def __init__(
self,
action_space: gym.spaces.Space | None = None,
@ -35,7 +34,9 @@ class MyPolicy(BasePolicy):
need_state=True,
action_shape=None,
) -> None:
"""Mock policy for testing.
"""Mock policy for testing, will always return an array of ones of the shape of the action space.
Note that this doesn't make much sense for discrete action space (the output is then intepreted as
logits, meaning all actions would be equally likely).
:param action_space: the action space of the environment. If None, a dummy Box space will be used.
:param bool dict_state: if the observation of the environment is a dict
@ -63,215 +64,290 @@ class MyPolicy(BasePolicy):
pass
class Logger:
def __init__(self, writer) -> None:
self.cnt = 0
self.writer = writer
def test_collector() -> None:
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
def preprocess_fn(self, **kwargs):
# modify info before adding into the buffer, and recorded into tfb
# if obs && env_id exist -> reset
# if obs_next/rew/done/info/env_id exist -> normal step
if "rew" in kwargs:
info = kwargs["info"]
info.rew = kwargs["rew"]
if "key" in info:
self.writer.add_scalar("key", np.mean(info.key), global_step=self.cnt)
self.cnt += 1
return Batch(info=info)
return Batch()
@staticmethod
def single_preprocess_fn(**kwargs):
# same as above, without tfb
if "rew" in kwargs:
info = kwargs["info"]
info.rew = kwargs["rew"]
return Batch(info=info)
return Batch()
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
def test_collector(gym_reset_kwargs) -> None:
writer = SummaryWriter("log/collector")
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
venv = SubprocVectorEnv(env_fns)
dum = DummyVectorEnv(env_fns)
policy = MyPolicy()
env = env_fns[0]()
c0 = Collector(
subproc_venv_4_envs = SubprocVectorEnv(env_fns)
dummy_venv_4_envs = DummyVectorEnv(env_fns)
policy = MaxActionPolicy()
single_env = env_fns[0]()
c_single_env = Collector(
policy,
env,
single_env,
ReplayBuffer(size=100),
logger.preprocess_fn,
)
c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
assert len(c0.buffer) == 3
assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
c_single_env.reset()
c_single_env.collect(n_step=3)
assert len(c_single_env.buffer) == 3
# TODO: direct attr access is an arcane way of using the buffer; it should never be done
# The placeholders for entries are all zeros, so buffer.obs is an array filled with 3
# observations, and 97 zeros.
# However, buffer[:] will have all attributes with length three... The non-filled entries are removed there
# See above. For the single env, we start with obs=0, obs_next=1.
# We move to obs=1, obs_next=2,
# then the env is reset and we move to obs=0
# Making one more step results in obs_next=1
# The final 0 in buffer.obs is an unfilled placeholder (the buffer is zero-initialized) that direct attr access exposes
assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0])
assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1])
keys = np.zeros(100)
keys[:3] = 1
assert np.allclose(c0.buffer.info["key"], keys)
for e in c0.buffer.info["env"][:3]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c0.buffer.info["env_id"], 0)
assert np.allclose(c_single_env.buffer.info["key"], keys)
for e in c_single_env.buffer.info["env"][:3]:
assert isinstance(e, MoveToRightEnv)
assert np.allclose(c_single_env.buffer.info["env_id"], 0)
rews = np.zeros(100)
rews[:3] = [0, 1, 0]
assert np.allclose(c0.buffer.info["rew"], rews)
c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
assert len(c0.buffer) == 8
assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
assert np.allclose(c0.buffer.info["key"][:8], 1)
for e in c0.buffer.info["env"][:8]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c0.buffer.info["env_id"][:8], 0)
assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)
assert np.allclose(c_single_env.buffer.rew, rews)
# At this point, the buffer contains obs 0 -> 1 -> 0
c1 = Collector(
# At start we have 3 entries in the buffer and the env is at obs=1 (no reset happens here)
# Collecting 3 episodes first finishes the running episode (1 -> done), then runs two full
# episodes (0 -> 1 -> done each), adding 1 + 2 + 2 = 5 transitions
# obs:      0 -> 1 -> 0 -> 1 -> 0 -> 1 -> 0 -> 1
# obs_next: 1 -> 2 -> 1 -> 2 -> 1 -> 2 -> 1 -> 2
# In total, we will have 3 + 5 = 8 entries in the buffer
c_single_env.collect(n_episode=3)
assert len(c_single_env.buffer) == 8
assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
assert np.allclose(c_single_env.buffer.info["key"][:8], 1)
for e in c_single_env.buffer.info["env"][:8]:
assert isinstance(e, MoveToRightEnv)
assert np.allclose(c_single_env.buffer.info["env_id"][:8], 0)
assert np.allclose(c_single_env.buffer.rew[:8], [0, 1, 0, 1, 0, 1, 0, 1])
c_single_env.collect(n_step=3, random=True)
c_subproc_venv_4_envs = Collector(
policy,
venv,
subproc_venv_4_envs,
VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn,
)
c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
c_subproc_venv_4_envs.reset()
# Collect some steps
c_subproc_venv_4_envs.collect(n_step=8)
obs = np.zeros(100)
valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
assert np.allclose(c1.buffer.obs[:, 0], obs)
assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
keys = np.zeros(100)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c1.buffer.info["key"], keys)
for e in c1.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys)
for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]:
assert isinstance(e, MoveToRightEnv)
env_ids = np.zeros(100)
env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
assert np.allclose(c1.buffer.info["env_id"], env_ids)
assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids)
rews = np.zeros(100)
rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
assert np.allclose(c1.buffer.info["rew"], rews)
c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
assert len(c1.buffer) == 16
assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews)
# we previously collected 8 steps, 2 from each env, now we collect 4 episodes
# each env contributes one finished episode, adding 2 (first env was reset), 1, 2 and 3 new steps
# So we get 8 + 2 + 1 + 2 + 3 = 16 steps
c_subproc_venv_4_envs.collect(n_episode=4)
assert len(c_subproc_venv_4_envs.buffer) == 16
valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
assert np.allclose(c1.buffer.obs[:, 0], obs)
obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4]
assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
assert np.allclose(
c1.buffer[:].obs_next[..., 0],
c_subproc_venv_4_envs.buffer[:].obs_next[..., 0],
[1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5],
)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c1.buffer.info["key"], keys)
for e in c1.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys)
for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]:
assert isinstance(e, MoveToRightEnv)
env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
assert np.allclose(c1.buffer.info["env_id"], env_ids)
assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids)
rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
assert np.allclose(c1.buffer.info["rew"], rews)
c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)
assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews)
c_subproc_venv_4_envs.collect(n_episode=4, random=True)
c2 = Collector(
c_dummy_venv_4_envs = Collector(
policy,
dum,
dummy_venv_4_envs,
VectorReplayBuffer(total_size=100, buffer_num=4),
logger.preprocess_fn,
)
c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
c_dummy_venv_4_envs.reset()
c_dummy_venv_4_envs.collect(n_episode=7)
obs1 = obs.copy()
obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
obs2 = obs.copy()
obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
c2obs = c2.buffer.obs[:, 0]
c2obs = c_dummy_venv_4_envs.buffer.obs[:, 0]
assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
c2.reset_buffer()
assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs).n_collected_episodes == 8
c_dummy_venv_4_envs.reset_env()
c_dummy_venv_4_envs.reset_buffer()
assert c_dummy_venv_4_envs.collect(n_episode=8).n_collected_episodes == 8
valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
assert np.all(c2.buffer.obs[:, 0] == obs)
assert np.all(c_dummy_venv_4_envs.buffer.obs[:, 0] == obs)
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
assert np.allclose(c2.buffer.info["key"], keys)
for e in c2.buffer.info["env"][valid_indices]:
assert isinstance(e, MyTestEnv)
assert np.allclose(c_dummy_venv_4_envs.buffer.info["key"], keys)
for e in c_dummy_venv_4_envs.buffer.info["env"][valid_indices]:
assert isinstance(e, MoveToRightEnv)
env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
assert np.allclose(c2.buffer.info["env_id"], env_ids)
assert np.allclose(c_dummy_venv_4_envs.buffer.info["env_id"], env_ids)
rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
assert np.allclose(c2.buffer.info["rew"], rews)
c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)
assert np.allclose(c_dummy_venv_4_envs.buffer.rew, rews)
c_dummy_venv_4_envs.collect(n_episode=4, random=True)
# test corner case
with pytest.raises(TypeError):
Collector(policy, dum, ReplayBuffer(10))
Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
with pytest.raises(TypeError):
Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
with pytest.raises(TypeError):
c2.collect()
c_dummy_venv_4_envs.collect()
# test NXEnv
for obs_type in ["array", "object"]:
envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]])
c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
assert c3.buffer.obs.dtype == object
c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c_suproc_new.reset()
c_suproc_new.collect(n_step=6)
assert c_suproc_new.buffer.obs.dtype == object
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
def test_collector_with_async(gym_reset_kwargs) -> None:
@pytest.fixture()
def get_AsyncCollector():
env_lens = [2, 3, 4, 5]
writer = SummaryWriter("log/async_collector")
logger = Logger(writer)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens]
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens]
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
policy = MyPolicy()
policy = MaxActionPolicy()
bufsize = 60
c1 = AsyncCollector(
policy,
venv,
VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
logger.preprocess_fn,
)
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
assert result.n_collected_episodes >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
assert result.n_collected_steps >= n_step
for i in range(4):
env_len = i + 2
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id == i)
assert np.all(buf.obs.reshape(-1, env_len) == seq)
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
with pytest.raises(TypeError):
c1.collect()
c1.reset()
return c1, env_lens
class TestAsyncCollector:
def test_collect_without_argument_gives_error(self, get_AsyncCollector):
c1, env_lens = get_AsyncCollector
with pytest.raises(TypeError):
c1.collect()
def test_collect_one_episode_async(self, get_AsyncCollector):
c1, env_lens = get_AsyncCollector
result = c1.collect(n_episode=1)
assert result.n_collected_episodes >= 1
def test_enough_episodes_two_collection_cycles_n_episode_without_reset(
self,
get_AsyncCollector,
):
c1, env_lens = get_AsyncCollector
n_episode = 2
result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False)
assert result_c1.n_collected_episodes >= n_episode
result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False)
assert result_c2.n_collected_episodes >= n_episode
def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector):
c1, env_lens = get_AsyncCollector
n_episode = 2
result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True)
assert result_c1.n_collected_episodes >= n_episode
result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=True)
assert result_c2.n_collected_episodes >= n_episode
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode(
self,
get_AsyncCollector,
):
c1, env_lens = get_AsyncCollector
ptr = [0, 0, 0, 0]
bufsize = 60
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode)
assert result.n_collected_episodes >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step(
self,
get_AsyncCollector,
):
c1, env_lens = get_AsyncCollector
bufsize = 60
ptr = [0, 0, 0, 0]
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step)
assert result.n_collected_steps >= n_step
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step(
self,
get_AsyncCollector,
gym_reset_kwargs,
):
c1, env_lens = get_AsyncCollector
bufsize = 60
ptr = [0, 0, 0, 0]
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
assert result.n_collected_episodes >= n_episode
# check buffer data, obs and obs_next, env_id
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
env_len = i + 2
total = env_len * count
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
ptr[i] = (ptr[i] + total) % bufsize
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id[indices] == i)
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
# test async n_step, for now the buffer should be full of data, thus no bincount stuff as above
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
assert result.n_collected_steps >= n_step
for i in range(4):
env_len = i + 2
seq = np.arange(env_len)
buf = c1.buffer.buffers[i]
assert np.all(buf.info.env_id == i)
assert np.all(buf.obs.reshape(-1, env_len) == seq)
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
def test_collector_with_dict_state() -> None:
env = MyTestEnv(size=5, sleep=0, dict_state=True)
policy = MyPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
env = MoveToRightEnv(size=5, sleep=0, dict_state=True)
policy = MaxActionPolicy(dict_state=True)
c0 = Collector(policy, env, ReplayBuffer(size=100))
c0.reset()
c0.collect(n_step=3)
c0.collect(n_episode=2)
assert len(c0.buffer) == 10
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]]
assert len(c0.buffer) == 10 # 3 from n_step, plus 2 + 5 more steps to finish two episodes of length 5
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
envs.seed(666)
obs, info = envs.reset()
@ -280,8 +356,8 @@ def test_collector_with_dict_state() -> None:
policy,
envs,
VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn,
)
c1.reset()
c1.collect(n_step=12)
result = c1.collect(n_episode=8)
assert result.n_collected_episodes == 8
@ -396,41 +472,47 @@ def test_collector_with_dict_state() -> None:
policy,
envs,
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn,
)
c2.reset()
c2.collect(n_episode=10)
batch, _ = c2.buffer.sample(10)
def test_collector_with_ma() -> None:
env = MyTestEnv(size=5, sleep=0, ma_rew=4)
policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
# n_step=3 will collect a full episode
rew = c0.collect(n_step=3).returns
assert len(rew) == 0
rew = c0.collect(n_episode=2).returns
assert rew.shape == (2, 4)
assert np.all(rew == 1)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
def test_collector_with_multi_agent() -> None:
multi_agent_env = MoveToRightEnv(size=5, sleep=0, ma_rew=4)
policy = MaxActionPolicy()
c_single_env = Collector(policy, multi_agent_env, ReplayBuffer(size=100))
c_single_env.reset()
multi_env_returns = c_single_env.collect(n_step=3).returns
# after 3 steps the buffer of c_single_env has 3 entries, but the env has size 5,
# so no episode is finished yet and there are no returns
assert len(multi_env_returns) == 0
single_env_returns = c_single_env.collect(n_episode=2).returns
# now two episodes. Since we have 4 agents, the returns have shape (2, 4)
assert single_env_returns.shape == (2, 4)
assert np.all(single_env_returns == 1)
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
c1 = Collector(
c_multi_env_ma = Collector(
policy,
envs,
VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn,
)
rew = c1.collect(n_step=12).returns
assert rew.shape == (2, 4) and np.all(rew == 1), rew
rew = c1.collect(n_episode=8).returns
assert rew.shape == (8, 4)
assert np.all(rew == 1)
batch, _ = c1.buffer.sample(10)
c_multi_env_ma.reset()
multi_env_returns = c_multi_env_ma.collect(n_step=12).returns
# each env makes 3 steps, the first two envs are done and result in two finished episodes
assert multi_env_returns.shape == (2, 4) and np.all(multi_env_returns == 1), multi_env_returns
multi_env_returns = c_multi_env_ma.collect(n_episode=8).returns
assert multi_env_returns.shape == (8, 4)
assert np.all(multi_env_returns == 1)
batch, _ = c_multi_env_ma.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
assert len(c0.buffer) in [42, 43]
if len(c0.buffer) == 42:
rew = [
c_single_env.buffer.update(c_multi_env_ma.buffer)
assert len(c_single_env.buffer) in [42, 43]
if len(c_single_env.buffer) == 42:
multi_env_returns = [
0,
0,
0,
@ -475,7 +557,7 @@ def test_collector_with_ma() -> None:
1,
]
else:
rew = [
multi_env_returns = [
0,
0,
0,
@ -520,17 +602,17 @@ def test_collector_with_ma() -> None:
0,
1,
]
assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew])
assert np.all(c0.buffer[:].done == rew)
assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns])
assert np.all(c_single_env.buffer[:].done == multi_env_returns)
c2 = Collector(
policy,
envs,
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn,
)
rew = c2.collect(n_episode=10).returns
assert rew.shape == (10, 4)
assert np.all(rew == 1)
c2.reset()
multi_env_returns = c2.collect(n_episode=10).returns
assert multi_env_returns.shape == (10, 4)
assert np.all(multi_env_returns == 1)
batch, _ = c2.buffer.sample(10)
@ -543,20 +625,21 @@ def test_collector_with_atari_setting() -> None:
reference_obs[i, 0] = i
# atari single buffer
env = MyTestEnv(size=5, sleep=0, array_state=True)
policy = MyPolicy()
env = MoveToRightEnv(size=5, sleep=0, array_state=True)
policy = MaxActionPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100))
c0.reset()
c0.collect(n_step=6)
c0.collect(n_episode=2)
assert c0.buffer.obs.shape == (100, 4, 84, 84)
assert c0.buffer.obs_next.shape == (100, 4, 84, 84)
assert len(c0.buffer) == 15
assert len(c0.buffer) == 15 # 6 from n_step (one full episode plus one step), plus 4 + 5 more steps to finish two episodes
obs = np.zeros_like(c0.buffer.obs)
obs[np.arange(15)] = reference_obs[np.arange(15) % 5]
assert np.all(obs == c0.buffer.obs)
c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
c1.collect(n_episode=3)
c1.collect(n_episode=3, reset_before_collect=True)
assert np.allclose(c0.buffer.obs, c1.buffer.obs)
with pytest.raises(AttributeError):
c1.buffer.obs_next # noqa: B018
@ -567,6 +650,7 @@ def test_collector_with_atari_setting() -> None:
env,
ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True),
)
c2.reset()
c2.collect(n_step=8)
assert c2.buffer.obs.shape == (100, 84, 84)
obs = np.zeros_like(c2.buffer.obs)
@ -575,9 +659,10 @@ def test_collector_with_atari_setting() -> None:
assert np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])
# atari multi buffer
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]]
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
c3.reset()
c3.collect(n_step=12)
result = c3.collect(n_episode=9)
assert result.n_collected_episodes == 9
@ -606,6 +691,7 @@ def test_collector_with_atari_setting() -> None:
save_only_last_obs=True,
),
)
c4.reset()
c4.collect(n_step=12)
result = c4.collect(n_episode=9)
assert result.n_collected_episodes == 9
@ -672,6 +758,7 @@ def test_collector_with_atari_setting() -> None:
buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True)
c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
c5.reset()
result_ = c5.collect(n_step=12)
assert len(buf) == 5
assert len(c5.buffer) == 12
@ -767,6 +854,7 @@ def test_collector_with_atari_setting() -> None:
# test buffer=None
c6 = Collector(policy, envs)
c6.reset()
result1 = c6.collect(n_step=12)
for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]:
assert np.allclose(getattr(result1, key), getattr(result_, key))
@ -778,7 +866,7 @@ def test_collector_with_atari_setting() -> None:
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_collector_envpool_gym_reset_return_info() -> None:
envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True)
policy = MyPolicy(action_shape=(len(envs), 1))
policy = MaxActionPolicy(action_shape=(len(envs), 1))
c0 = Collector(
policy,
@ -786,18 +874,59 @@ def test_collector_envpool_gym_reset_return_info() -> None:
VectorReplayBuffer(len(envs) * 10, len(envs)),
exploration_noise=True,
)
c0.reset()
c0.collect(n_step=8)
env_ids = np.zeros(len(envs) * 10)
env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3]
assert np.allclose(c0.buffer.info["env_id"], env_ids)
def test_collector_with_vector_env():
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
dum = DummyVectorEnv(env_fns)
policy = MaxActionPolicy()
c2 = Collector(
policy,
dum,
VectorReplayBuffer(total_size=100, buffer_num=4),
)
c2.reset()
c1r = c2.collect(n_episode=2)
assert np.array_equal(np.array([1, 8]), c1r.lens)
c2r = c2.collect(n_episode=10)
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 8, 9, 10]), c2r.lens)
c3r = c2.collect(n_step=20)
assert np.array_equal(np.array([1, 1, 1, 1, 1]), c3r.lens)
c4r = c2.collect(n_step=20)
assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens)
def test_async_collector_with_vector_env():
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
dum = DummyVectorEnv(env_fns)
policy = MaxActionPolicy()
c1 = AsyncCollector(
policy,
dum,
VectorReplayBuffer(total_size=100, buffer_num=4),
)
c1r = c1.collect(n_episode=10, reset_before_collect=True)
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens)
c2r = c1.collect(n_step=20)
assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens)
if __name__ == "__main__":
test_collector(gym_reset_kwargs=None)
test_collector(gym_reset_kwargs={})
test_collector()
test_collector_with_dict_state()
test_collector_with_ma()
test_collector_with_multi_agent()
test_collector_with_atari_setting()
test_collector_with_async(gym_reset_kwargs=None)
test_collector_with_async(gym_reset_kwargs={"return_info": True})
test_collector_envpool_gym_reset_return_info()
test_collector_with_vector_env()
test_async_collector_with_vector_env()

View File

@ -20,9 +20,9 @@ from tianshou.env.gym_wrappers import TruncatedAsTerminated
from tianshou.utils import RunningMeanStd
if __name__ == "__main__":
from env import MyTestEnv, NXEnv
from env import MoveToRightEnv, NXEnv
else: # pytest
from test.base.env import MyTestEnv, NXEnv
from test.base.env import MoveToRightEnv, NXEnv
try:
import envpool
@ -56,7 +56,7 @@ def recurse_comp(a, b):
def test_async_env(size=10000, num=8, sleep=0.1) -> None:
# simplify the test case, just keep stepping
env_fns = [
lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
lambda i=i: MoveToRightEnv(size=i, sleep=sleep, random_sleep=True)
for i in range(size, size + num)
]
test_cls = [SubprocVectorEnv, ShmemVectorEnv]
@ -108,10 +108,10 @@ def test_async_env(size=10000, num=8, sleep=0.1) -> None:
def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None:
env_fns = [
lambda: MyTestEnv(size=size, sleep=sleep * 2),
lambda: MyTestEnv(size=size, sleep=sleep * 3),
lambda: MyTestEnv(size=size, sleep=sleep * 5),
lambda: MyTestEnv(size=size, sleep=sleep * 7),
lambda: MoveToRightEnv(size=size, sleep=sleep * 2),
lambda: MoveToRightEnv(size=size, sleep=sleep * 3),
lambda: MoveToRightEnv(size=size, sleep=sleep * 5),
lambda: MoveToRightEnv(size=size, sleep=sleep * 7),
]
test_cls = [SubprocVectorEnv, ShmemVectorEnv]
if has_ray():
@ -156,7 +156,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None:
def test_vecenv(size=10, num=8, sleep=0.001) -> None:
env_fns = [
lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
lambda i=i: MoveToRightEnv(size=i, sleep=sleep, recurse_state=True)
for i in range(size, size + num)
]
venv = [
@ -237,7 +237,7 @@ def test_env_obs_dtype() -> None:
def test_env_reset_optional_kwargs(size=10000, num=8) -> None:
env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)]
test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
if has_ray():
test_cls += [RayVectorEnv]
@ -257,7 +257,7 @@ def test_venv_wrapper_gym(num_envs: int = 4) -> None:
except ValueError:
obs, info = envs.reset(return_info=True)
assert isinstance(obs, np.ndarray)
assert isinstance(info, list)
assert isinstance(info, np.ndarray)
assert isinstance(info[0], dict)
assert obs.shape[0] == len(info) == num_envs
@ -334,7 +334,7 @@ def test_venv_norm_obs() -> None:
action = np.array([1, 1, 1, 1])
total_step = 30
action_list = [action] * total_step
env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes]
env_fns = [lambda i=x: MoveToRightEnv(size=i, array_state=True) for x in sizes]
raw = DummyVectorEnv(env_fns)
train_env = VectorEnvNormObs(DummyVectorEnv(env_fns))
print(train_env.observation_space)

View File

@ -90,20 +90,20 @@ class FiniteVectorEnv(BaseVectorEnv):
# END
def reset(self, id=None):
id = self._wrap_id(id)
def reset(self, env_id=None):
env_id = self._wrap_id(env_id)
self._reset_alive_envs()
# ask super to reset alive envs and remap to current index
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
obs = [None] * len(id)
infos = [None] * len(id)
id2idx = {i: k for k, i in enumerate(id)}
request_id = list(filter(lambda i: i in self._alive_env_ids, env_id))
obs = [None] * len(env_id)
infos = [None] * len(env_id)
id2idx = {i: k for k, i in enumerate(env_id)}
if request_id:
for k, o, info in zip(request_id, *super().reset(request_id), strict=True):
obs[id2idx[k]] = o
infos[id2idx[k]] = info
for i, o in zip(id, obs, strict=True):
for i, o in zip(env_id, obs, strict=True):
if o is None and i in self._alive_env_ids:
self._alive_env_ids.remove(i)
@ -121,7 +121,7 @@ class FiniteVectorEnv(BaseVectorEnv):
self.reset()
raise StopIteration
return np.stack(obs), infos
return np.stack(obs), np.array(infos)
def step(self, action, id=None):
id = self._wrap_id(id)
@ -204,10 +204,12 @@ def test_finite_dummy_vector_env() -> None:
envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
test_collector.reset()
for _ in range(3):
envs.tracker = MetricTracker()
try:
# TODO: why on earth 10**18?
test_collector.collect(n_step=10**18)
except StopIteration:
envs.tracker.validate()
@ -218,6 +220,7 @@ def test_finite_subproc_vector_env() -> None:
envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
policy = AnyPolicy()
test_collector = Collector(policy, envs, exploration_noise=True)
test_collector.reset()
for _ in range(3):
envs.tracker = MetricTracker()

View File

@ -2,7 +2,7 @@ import gymnasium as gym
import numpy as np
import pytest
import torch
from torch.distributions import Categorical, Independent, Normal
from torch.distributions import Categorical, Distribution, Independent, Normal
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic, Net
@ -25,7 +25,11 @@ def policy(request):
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
action_shape=action_space.shape,
)
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
elif action_type == "discrete":
action_space = gym.spaces.Discrete(3)
actor = Actor(

View File

@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: NPGPolicy[NPGTrainingStats] = NPGPolicy(
actor=actor,

View File

@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: PPOPolicy[PPOTrainingStats] = PPOPolicy(
actor=actor,

View File

@ -136,6 +136,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
exploration_noise=True,
)
test_collector = Collector(policy, test_envs)
train_collector.reset()
train_collector.collect(n_step=args.start_timesteps, random=True)
# log
log_path = os.path.join(args.logdir, args.task, "redq")

View File

@ -162,6 +162,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
env = gym.make(args.task)
policy.eval()
collector = Collector(policy, env)
collector.reset()
collector_stats = collector.collect(n_episode=1, render=args.render)
print(collector_stats)

View File

@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = TRPOPolicy(
actor=actor,

View File

@ -109,7 +109,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
train_envs,
VectorReplayBuffer(args.buffer_size, len(train_envs)),
)
train_collector.reset()
test_collector = Collector(policy, test_envs)
test_collector.reset()
# log
log_path = os.path.join(args.logdir, args.task, "a2c")
writer = SummaryWriter(log_path)

View File

@ -108,7 +108,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
)
test_collector = Collector(policy, test_envs, exploration_noise=False)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
def train_fn(epoch: int, env_step: int) -> None: # exp decay
eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)

View File

@ -120,7 +120,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "c51")
writer = SummaryWriter(log_path)

View File

@ -111,7 +111,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "dqn")
writer = SummaryWriter(log_path)

View File

@ -95,7 +95,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
# the stack_num is for RNN training: sample framestack obs
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "drqn")
writer = SummaryWriter(log_path)

View File

@ -128,7 +128,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "fqf")
writer = SummaryWriter(log_path)

View File

@ -124,7 +124,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "iqn")
writer = SummaryWriter(log_path)

View File

@ -113,7 +113,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path)

View File

@ -128,7 +128,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "rainbow")
writer = SummaryWriter(log_path)

View File

@ -154,7 +154,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "dqn_icm")
writer = SummaryWriter(log_path)

View File

@ -81,7 +81,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
VectorReplayBuffer(args.buffer_size, len(train_envs)),
exploration_noise=True,
)
train_collector.reset()
test_collector = Collector(policy, test_envs)
test_collector.reset()
# Logger
log_path = os.path.join(args.logdir, args.task, "psrl")
writer = SummaryWriter(log_path)
@ -120,7 +122,6 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f"Final reward: {result.rew_mean}, length: {result.len_mean}")
elif env.spec.reward_threshold:

View File

@ -115,9 +115,11 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
# collector
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
train_collector.reset()
test_collector = Collector(policy, test_envs, exploration_noise=True)
test_collector.reset()
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, args.task, "qrdqn")
writer = SummaryWriter(log_path)
@ -165,6 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
policy.set_eps(0.2)
collector = Collector(policy, test_envs, buf, exploration_noise=True)
collector.reset()
collector_stats = collector.collect(n_step=args.buffer_size)
if args.save_buffer_name.endswith(".hdf5"):
buf.save_hdf5(args.save_buffer_name)

View File

@ -178,6 +178,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None:
test_discrete_bcq()
args.resume = True
test_discrete_bcq(args)

View File

@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
# replace DiagGaussian with Independent(Normal) which is equivalent
# pass *logits to be consistent with policy.forward
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
policy: BasePolicy = GAILPolicy(
actor=actor,

View File

@ -83,8 +83,8 @@ def get_agents(
if isinstance(env.observation_space, gym.spaces.Dict)
else env.observation_space
)
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.state_shape = observation_space.shape or int(observation_space.n)
args.action_shape = env.action_space.shape or int(env.action_space.n)
if agents is None:
agents = []
optims = []
@ -135,7 +135,7 @@ def train_agent(
exploration_noise=True,
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, "pistonball", "dqn")
writer = SummaryWriter(log_path)

View File

@ -181,8 +181,9 @@ def get_agents(
torch.nn.init.zeros_(m.bias)
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr)
def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)
agent: PPOPolicy = PPOPolicy(
actor,
@ -234,7 +235,7 @@ def train_agent(
exploration_noise=False, # True
)
test_collector = Collector(policy, test_envs)
# train_collector.collect(n_step=args.batch_size * args.training_num)
# train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, "pistonball", "dqn")
writer = SummaryWriter(log_path)

View File

@ -102,8 +102,8 @@ def get_agents(
if isinstance(env.observation_space, gymnasium.spaces.Dict)
else env.observation_space
)
args.state_shape = observation_space.shape or observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.state_shape = observation_space.shape or int(observation_space.n)
args.action_shape = env.action_space.shape or int(env.action_space.n)
if agent_learn is None:
# model
net = Net(
@ -170,7 +170,7 @@ def train_agent(
)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# policy.set_eps(1)
train_collector.collect(n_step=args.batch_size * args.training_num)
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
# log
log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
writer = SummaryWriter(log_path)

View File

@ -263,6 +263,9 @@ class BatchProtocol(Protocol):
def __repr__(self) -> str:
...
def __iter__(self) -> Iterator[Self]:
...
def to_numpy(self) -> None:
"""Change all torch.Tensor to numpy.ndarray in-place."""
...
@ -391,6 +394,12 @@ class BatchProtocol(Protocol):
"""
...
def to_dict(self) -> dict[str, Any]:
...
def to_list_of_dicts(self) -> list[dict[str, Any]]:
...
class Batch(BatchProtocol):
"""See :class:`~tianshou.data.batch.BatchProtocol`."""
@ -422,6 +431,17 @@ class Batch(BatchProtocol):
# Feels like kwargs could be just merged into batch_dict in the beginning
self.__init__(kwargs, copy=copy) # type: ignore
def to_dict(self) -> dict[str, Any]:
result = {}
for k, v in self.__dict__.items():
if isinstance(v, Batch):
v = v.to_dict()
result[k] = v
return result
def to_list_of_dicts(self) -> list[dict[str, Any]]:
return [entry.to_dict() for entry in self]
def __setattr__(self, key: str, value: Any) -> None:
"""Set self.key = value."""
self.__dict__[key] = _parse_value(value)
@ -478,6 +498,14 @@ class Batch(BatchProtocol):
return new_batch
raise IndexError("Cannot access item from empty Batch object.")
def __iter__(self) -> Iterator[Self]:
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
if len(self.__dict__) == 0:
yield from []
else:
for i in range(len(self)):
yield self[i]
def __setitem__(self, index: str | IndexType, value: Any) -> None:
"""Assign value to self[index]."""
value = _parse_value(value)
@ -601,10 +629,10 @@ class Batch(BatchProtocol):
else:
# ndarray or scalar
if not isinstance(obj, np.ndarray):
obj = np.asanyarray(obj) # noqa: PLW2901
obj = torch.from_numpy(obj).to(device) # noqa: PLW2901
obj = np.asanyarray(obj)
obj = torch.from_numpy(obj).to(device)
if dtype is not None:
obj = obj.type(dtype) # noqa: PLW2901
obj = obj.type(dtype)
self.__dict__[batch_key] = obj
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:

View File

@ -200,7 +200,7 @@ class ReplayBufferManager(ReplayBuffer):
return np.concatenate(
[
buf.sample_indices(bsz) + offset
buf.sample_indices(int(bsz)) + offset
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True)
],
)

File diff suppressed because it is too large

View File

@ -12,6 +12,7 @@ from tianshou.data.batch import Batch, _parse_value
# TODO: confusing name, could actually return a batch...
# Overrides and generic types should be added
# todo check for ActBatchProtocol
@no_type_check
def to_numpy(x: Any) -> Batch | np.ndarray:
"""Return an object without torch.Tensor."""

View File

@ -44,14 +44,14 @@ class VectorEnvWrapper(BaseVectorEnv):
def reset(
self,
id: int | list[int] | np.ndarray | None = None,
env_id: int | list[int] | np.ndarray | None = None,
**kwargs: Any,
) -> tuple[np.ndarray, dict | list[dict]]:
return self.venv.reset(id, **kwargs)
) -> tuple[np.ndarray, np.ndarray]:
return self.venv.reset(env_id, **kwargs)
def step(
self,
action: np.ndarray | torch.Tensor,
action: np.ndarray | torch.Tensor | None,
id: int | list[int] | np.ndarray | None = None,
) -> gym_new_venv_step_type:
return self.venv.step(action, id)
@ -80,10 +80,10 @@ class VectorEnvNormObs(VectorEnvWrapper):
def reset(
self,
id: int | list[int] | np.ndarray | None = None,
env_id: int | list[int] | np.ndarray | None = None,
**kwargs: Any,
) -> tuple[np.ndarray, dict | list[dict]]:
obs, info = self.venv.reset(id, **kwargs)
) -> tuple[np.ndarray, np.ndarray]:
obs, info = self.venv.reset(env_id, **kwargs)
if isinstance(obs, tuple): # type: ignore
raise TypeError(
@ -98,7 +98,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
def step(
self,
action: np.ndarray | torch.Tensor,
action: np.ndarray | torch.Tensor | None,
id: int | list[int] | np.ndarray | None = None,
) -> gym_new_venv_step_type:
step_results = self.venv.step(action, id)

tianshou/env/venvs.py vendored
View File

@ -190,11 +190,13 @@ class BaseVectorEnv:
), f"Cannot interact with environment {i} which is stepping now."
assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}."
# TODO: for now, has to be kept in sync with reset in EnvPoolMixin
# In particular, can't rename env_id to env_ids
def reset(
self,
id: int | list[int] | np.ndarray | None = None,
env_id: int | list[int] | np.ndarray | None = None,
**kwargs: Any,
) -> tuple[np.ndarray, dict | list[dict]]:
) -> tuple[np.ndarray, np.ndarray]:
"""Reset the state of some envs and return initial observations.
If id is None, reset the state of all the environments and return
@ -202,14 +204,14 @@ class BaseVectorEnv:
the given id, either an int or a list.
"""
self._assert_is_not_closed()
id = self._wrap_id(id)
env_id = self._wrap_id(env_id)
if self.is_async:
self._assert_id(id)
self._assert_id(env_id)
# send(None) == reset() in worker
for i in id:
self.workers[i].send(None, **kwargs)
ret_list = [self.workers[i].recv() for i in id]
for id in env_id:
self.workers[id].send(None, **kwargs)
ret_list = [self.workers[id].recv() for id in env_id]
assert (
isinstance(ret_list[0], tuple | list)
@ -229,12 +231,12 @@ class BaseVectorEnv:
except ValueError: # different len(obs)
obs = np.array(obs_list, dtype=object)
infos = [r[1] for r in ret_list]
return obs, infos # type: ignore
infos = np.array([r[1] for r in ret_list])
return obs, infos
def step(
self,
action: np.ndarray | torch.Tensor,
action: np.ndarray | torch.Tensor | None,
id: int | list[int] | np.ndarray | None = None,
) -> gym_new_venv_step_type:
"""Run one timestep of some environments' dynamics.
@ -248,6 +250,8 @@ class BaseVectorEnv:
batch_done, batch_info) in numpy format.
:param numpy.ndarray action: a batch of action provided by the agent.
If the venv is async, the action can be None, which will result
in all arrays in the returned tuple being empty.
:return: A tuple consisting of either:
@ -271,6 +275,8 @@ class BaseVectorEnv:
self._assert_is_not_closed()
id = self._wrap_id(id)
if not self.is_async:
if action is None:
raise ValueError("action must be not-None for non-async")
assert len(action) == len(id)
for i, j in enumerate(id):
self.workers[j].send(action[i])
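A hedged usage sketch of the changed reset interface (the environment name and counts are placeholders):

import gymnasium as gym
import numpy as np
from tianshou.env import DummyVectorEnv

venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
obs, info = venv.reset(env_id=[0, 2])  # the keyword argument is now env_id
assert isinstance(info, np.ndarray)    # an array of info-dicts instead of a list
assert len(obs) == len(info) == 2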

View File

@ -93,7 +93,14 @@ class AgentFactory(ABC, ToStringMixin):
self,
policy: BasePolicy,
envs: Environments,
reset_collectors: bool = True,
) -> tuple[Collector, Collector]:
""":param policy:
:param envs:
:param reset_collectors: Whether to reset the collectors before returning them.
Setting to True means that the envs will be reset as well.
:return:
"""
buffer_size = self.sampling_config.buffer_size
train_envs = envs.train_envs
buffer: ReplayBuffer
@ -114,6 +121,10 @@ class AgentFactory(ABC, ToStringMixin):
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, envs.test_envs)
if reset_collectors:
train_collector.reset()
test_collector.reset()
if self.sampling_config.start_timesteps > 0:
log.info(
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",

View File

@ -312,7 +312,7 @@ class Experiment(ToStringMixin):
) -> None:
policy.eval()
collector = Collector(policy, env)
result = collector.collect(n_episode=num_episodes, render=render)
result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True)
assert result.returns_stat is not None # for mypy
assert result.lens_stat is not None # for mypy
log.info(

View File

@ -1,40 +1,47 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any
import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
from tianshou.utils.string import ToStringMixin
class DistributionFunctionFactory(ToStringMixin, ABC):
# True return type defined in subclasses
@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(
self,
envs: Environments,
) -> Callable[[Any], torch.distributions.Distribution]:
pass
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
envs.get_type().assert_discrete(self)
return self._dist_fn
@staticmethod
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
return torch.distributions.Categorical(logits=p)
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
envs.get_type().assert_continuous(self)
return self._dist_fn
@staticmethod
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
loc, scale = loc_scale
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
match envs.get_type():
case EnvType.DISCRETE:
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)

View File

@ -19,7 +19,7 @@ from tianshou.highlevel.params.dist_fn import (
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin
@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
Does not affect training.
"""
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
"""
This can either be a function which maps the model output to a torch distribution or a
factory for the creation of such a function.

View File

@ -18,6 +18,7 @@ from tianshou.data.batch import Batch, BatchProtocol, arr_type
from tianshou.data.buffer.base import TBuffer
from tianshou.data.types import (
ActBatchProtocol,
ActStateBatchProtocol,
BatchWithReturnsProtocol,
ObsBatchProtocol,
RolloutBatchProtocol,
@ -212,10 +213,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
super().__init__()
self.observation_space = observation_space
self.action_space = action_space
self._action_type: Literal["discrete", "continuous"]
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
self.action_type = "discrete"
self._action_type = "discrete"
elif isinstance(action_space, Box):
self.action_type = "continuous"
self._action_type = "continuous"
else:
raise ValueError(f"Unsupported action space: {action_space}.")
self.agent_id = 0
@ -225,6 +227,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
self.lr_scheduler = lr_scheduler
self._compile()
@property
def action_type(self) -> Literal["discrete", "continuous"]:
return self._action_type
def set_agent_id(self, agent_id: int) -> None:
"""Set self.agent_id = agent_id, for MARL."""
self.agent_id = agent_id
@ -233,11 +239,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
# have a method to add noise to action.
# So we add the default behavior here. It's a little messy, maybe one can
# find a better way to do this.
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
"""Modify the action from policy.forward with exploration noise.
NOTE: currently does not add any noise! Needs to be overridden by subclasses
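A purely illustrative sketch of what the new generic signature expresses: the return type mirrors the type of `act` that was passed in (array in, array out; batch in, batch out). The base implementation simply returns the action unchanged:

from typing import TypeVar

import numpy as np

from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol

_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")

def exploration_noise(act: _TArrOrActBatch, batch: ObsBatchProtocol) -> _TArrOrActBatch:
    # default behaviour: no noise is added; subclasses override this to perturb act
    return act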
@ -287,7 +296,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
batch: ObsBatchProtocol,
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ActBatchProtocol:
) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing
"""Compute action over the given batch data.
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys:

View File

@ -16,6 +16,12 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
@dataclass(kw_only=True)
class ImitationTrainingStats(TrainingStats):
@ -72,9 +78,20 @@ class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTra
state: dict | BatchProtocol | np.ndarray | None = None,
**kwargs: Any,
) -> ModelOutputBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits
result = Batch(logits=logits, act=act, state=hidden)
# TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
if self.action_type == "discrete":
# If it's discrete, the "actor" is usually a critic that maps obs to action_values
# which then could be turned into logits or a Categorical distribution
action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
act_B = action_values_BA.argmax(dim=1)
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
elif self.action_type == "continuous":
# If it's continuous, the actor would usually deliver something like loc, scale determining a
# Gaussian dist
dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
else:
raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
return cast(ModelOutputBatchProtocol, result)
def learn(

View File

@ -34,8 +34,7 @@ TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteB
class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
:param model: a model following the rules (s_B -> action_values_BA)
:param imitator: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
:param optim: a torch.optim for optimizing the model.

View File

@ -25,8 +25,7 @@ TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteC
class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]):
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space.
:param min_q_weight: the weight for the cql loss.

View File

@ -11,6 +11,7 @@ from tianshou.data import to_torch, to_torch_as
from tianshou.data.types import RolloutBatchProtocol
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass
@ -26,8 +27,9 @@ TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteC
class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the action-value critic (i.e., Q function)
network. (s -> Q(s, \*))
:param optim: a torch.optim for optimizing the model.
@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | Actor,
critic: torch.nn.Module | Critic,
optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete,
discount_factor: float = 0.99,

View File

@ -15,8 +15,11 @@ from tianshou.data import (
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import PPOPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.policy.modelfree.ppo import PPOTrainingStats
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -32,7 +35,9 @@ TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats)
class GAILPolicy(PPOPolicy[TGailTrainingStats]):
r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
expert_buffer: ReplayBuffer,
disc_net: torch.nn.Module,

View File

@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]):
"""Implementation of TD3+BC. arXiv:2106.06860.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Self
from typing import Any, Literal, Self, TypeVar
import gymnasium as gym
import numpy as np
@ -105,11 +105,13 @@ class ICMPolicy(BasePolicy[ICMTrainingStats]):
"""
return self.policy.forward(batch, state, **kwargs)
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
return self.policy.exploration_noise(act, batch)
def set_eps(self, eps: float) -> None:

View File

@ -11,8 +11,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import PGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -30,7 +33,9 @@ TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats)
class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var]
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
vf_coef: float = 0.5,
ent_coef: float = 0.01,

View File

@ -8,6 +8,7 @@ import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
ActBatchProtocol,
BatchWithReturnsProtocol,
ModelOutputBatchProtocol,
ObsBatchProtocol,
@ -30,7 +31,7 @@ TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats)
class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
"""Implementation of the Branching dual Q network arXiv:1711.08946.
:param model: BranchingNet mapping (obs, state, info) -> logits.
:param model: BranchingNet mapping (obs, state, info) -> action_values_BA.
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead.
@ -155,10 +156,10 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
model = getattr(self, model)
obs = batch.obs
# TODO: this is very contrived, see also iqn.py
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
act = to_numpy(logits.max(dim=-1)[1])
result = Batch(logits=logits, act=act, state=hidden)
obs_next_BO = obs.obs if hasattr(obs, "obs") else obs
action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info)
act_B = to_numpy(action_values_BA.argmax(dim=-1))
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result)
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
@ -182,11 +183,13 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value]
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps

View File

@ -23,8 +23,7 @@ TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats)
class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
"""Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param num_atoms: the number of atoms in the support set of the

View File

@ -10,6 +10,7 @@ import torch
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
ActBatchProtocol,
ActStateBatchProtocol,
BatchWithReturnsProtocol,
ObsBatchProtocol,
@ -18,6 +19,7 @@ from tianshou.data.types import (
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.continuous import Actor, Critic
@dataclass(kw_only=True)
@ -32,8 +34,7 @@ TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats)
class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
:param actor: The actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> model_output)
:param actor: The actor network following the rules (s -> actions)
:param actor_optim: The optimizer for actor network.
:param critic: The critic network. (s, a -> Q(s, a))
:param critic_optim: The optimizer for critic network.
@ -59,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer,
action_space: gym.Space,
tau: float = 0.005,
@ -208,11 +209,13 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value]
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
if self._exploration_noise is None:
return act
if isinstance(act, np.ndarray):

View File

@ -8,11 +8,11 @@ from overrides import override
from torch.distributions import Categorical
from tianshou.data import Batch, ReplayBuffer, to_torch
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import SACPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.sac import SACTrainingStats
from tianshou.utils.net.discrete import Actor, Critic
@dataclass
@ -26,8 +26,7 @@ TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteS
class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
"""Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules (s_B -> dist_input_BD)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.
@ -55,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | Actor,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic: torch.nn.Module | Critic,
critic_optim: torch.optim.Optimizer,
action_space: gym.spaces.Discrete,
critic2: torch.nn.Module | None = None,
critic2: torch.nn.Module | Critic | None = None,
critic2_optim: torch.optim.Optimizer | None = None,
tau: float = 0.005,
gamma: float = 0.99,
@ -106,13 +105,13 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits)
logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Categorical(logits=logits_BA)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.sample()
return Batch(logits=logits, act=act, state=hidden, dist=dist)
act_B = dist.sample()
return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch(
@ -184,9 +183,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(),
)
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
return act

View File

@ -9,6 +9,7 @@ import torch
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
from tianshou.data.batch import BatchProtocol
from tianshou.data.types import (
ActBatchProtocol,
BatchWithReturnsProtocol,
ModelOutputBatchProtocol,
ObsBatchProtocol,
@ -16,6 +17,7 @@ from tianshou.data.types import (
)
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.net.common import Net
@dataclass(kw_only=True)
@ -34,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
implemented in the network side, not here).
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param estimation_step: the number of steps to look ahead.
@ -59,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
def __init__(
self,
*,
model: torch.nn.Module,
model: torch.nn.Module | Net,
optim: torch.optim.Optimizer,
# TODO: type violates Liskov substitution principle
action_space: gym.spaces.Discrete,
@ -200,12 +201,12 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
obs = batch.obs
# TODO: this is convoluted! See also other places where this is done.
obs_next = obs.obs if hasattr(obs, "obs") else obs
logits, hidden = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(logits, getattr(obs, "mask", None))
action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info)
q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None))
if self.max_action_num is None:
self.max_action_num = q.shape[1]
act = to_numpy(q.max(dim=1)[1])
result = Batch(logits=logits, act=act, state=hidden)
act_B = to_numpy(q.argmax(dim=1))
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
return cast(ModelOutputBatchProtocol, result)
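To illustrate the optional action mask picked up via getattr(obs, "mask", None): compute_q_value pushes masked-out actions below the minimum Q-value, so the argmax never selects them. A hedged sketch with made-up numbers (`policy` stands for an already constructed DQNPolicy):

import numpy as np
import torch

action_values_BA = torch.tensor([[1.0, 5.0, 3.0]])
mask = np.array([[True, False, True]])              # action 1 is illegal in this state
q = policy.compute_q_value(action_values_BA, mask)
greedy_act = q.argmax(dim=1)                        # tensor([2]); the masked action is never chosen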
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
@ -232,11 +233,13 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value]
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
bsz = len(act)
rand_mask = np.random.rand(bsz) < self.eps

View File

@ -27,8 +27,7 @@ TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats)
class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param fraction_model: a FractionProposalNetwork for
proposing fractions/quantiles given state.

View File

@ -29,8 +29,7 @@ TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats)
class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
"""Implementation of Implicit Quantile Network. arXiv:1806.06923.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s_B -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param discount_factor: in [0, 1].
:param sample_size: the number of samples for policy evaluation.

View File

@ -12,7 +12,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
optim_critic_iters: int = 5,
actor_step_size: float = 0.5,

View File

@ -1,7 +1,7 @@
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
from typing import Any, Generic, Literal, TypeVar, cast
import gymnasium as gym
import numpy as np
@ -24,9 +24,22 @@ from tianshou.data.types import (
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils import RunningMeanStd
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.net.discrete import Actor
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
# Dimension Naming Convention
# B - Batch Size
# A - Action
# D - Dist input (usually 2, loc and scale)
# H - Dimension of hidden, can be None
TDistFnContinuous = Callable[
[tuple[torch.Tensor, torch.Tensor]],
torch.distributions.Distribution,
]
TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical]
TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete
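An illustrative pair of callables matching the two aliases, with action shapes following the naming convention above (the values of B and A are made up):

import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

def dist_fn_discrete(logits_BA: torch.Tensor) -> Categorical:  # satisfies TDistFnDiscrete
    return Categorical(logits=logits_BA)

def dist_fn_continuous(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:  # satisfies TDistFnContinuous
    loc_B, scale_B = loc_scale
    return Independent(Normal(loc_B, scale_B), 1)

B, A = 8, 3
act_B = dist_fn_discrete(torch.randn(B, A)).sample()                         # shape (B,)
act_BA = dist_fn_continuous((torch.zeros(B, A), torch.ones(B, A))).sample()  # shape (B, A)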
@dataclass(kw_only=True)
@ -40,8 +53,9 @@ TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats)
class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
"""Implementation of REINFORCE algorithm.
:param actor: mapping (s->model_output), should follow the rules in
:class:`~tianshou.policy.BasePolicy`.
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param optim: optimizer for actor network.
:param dist_fn: distribution class for computing the action.
Maps model_output -> distribution. Typically a Gaussian distribution
@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb | Actor,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
discount_factor: float = 0.99,
# TODO: rename to return_normalization?
@ -175,20 +189,20 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
more detailed explanation.
"""
# TODO: rename? It's not really logits and there are particular
# assumptions about the order of the output and on distribution type
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
if isinstance(logits, tuple):
dist = self.dist_fn(*logits)
else:
dist = self.dist_fn(logits)
# TODO - ALGO: marked for algorithm refactoring
action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
# in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
# therefore action_dist_input_BD is equivalent to logits_BA
# If continuous, dist_fn will typically map a (loc, scale) tuple to a distribution (usually a Gaussian);
# the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
dist = self.dist_fn(action_dist_input_BD)
# in this case, the dist is unused!
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.sample()
result = Batch(logits=logits, act=act, state=hidden, dist=dist)
act_B = dist.sample()
# act is of dimension BA in continuous case and of dimension B in discrete
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
return cast(DistBatchProtocol, result)
# TODO: why does mypy complain?

View File

@ -10,8 +10,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
from tianshou.policy import A2CPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -29,7 +32,9 @@ TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var]
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
eps_clip: float = 0.2,
dual_clip: float | None = None,

View File

@ -25,8 +25,7 @@ TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats)
class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
"""Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
:param model: a model following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param model: a model following the rules (s -> action_values_BA)
:param optim: a torch.optim for optimizing the model.
:param action_space: Env's action space.
:param discount_factor: in [0, 1].

View File

@ -12,6 +12,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.ddpg import DDPGTrainingStats
from tianshou.utils.net.continuous import ActorProb
@dataclass
@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
@ -150,23 +151,28 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale
dist = Independent(Normal(loc, scale), 1)
(loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc_B, scale_B), 1)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
act_B = dist.rsample()
log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act)
squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1,
keepdim=True,
)
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob)
return Batch(
logits=(loc_B, scale_B),
act=squashed_action,
state=h_BH,
dist=dist,
log_prob=log_prob,
)
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch(

View File

@ -17,6 +17,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.conversion import to_optional_float
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.optim import clone_optimizer
@ -36,8 +37,7 @@ TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats)
class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var]
"""Implementation of Soft Actor-Critic. arXiv:1812.05905.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules (s -> dist_input_BD)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.
@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,
@ -173,26 +173,25 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> DistLogProbBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
(loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
act_B = dist.rsample()
log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act)
squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1,
keepdim=True,
)
result = Batch(
logits=logits,
logits=(loc_B, scale_B),
act=squashed_action,
state=hidden,
state=hidden_BH,
dist=dist,
log_prob=log_prob,
)
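The tanh-squashing correction referenced above, shown in isolation as a hedged sketch (shapes and eps are placeholders):

import torch
from torch.distributions import Independent, Normal

loc_B, scale_B, eps = torch.zeros(8, 2), torch.ones(8, 2), 1e-8
dist = Independent(Normal(loc_B, scale_B), 1)
u_B = dist.rsample()                  # pre-squash Gaussian sample
act_B = torch.tanh(u_B)               # squashed action in (-1, 1)
# change-of-variables term from Eq. 21 (appendix C) of arXiv:1801.01290
log_prob = dist.log_prob(u_B).unsqueeze(-1) - torch.log(1 - act_B.pow(2) + eps).sum(-1, keepdim=True)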

View File

@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t
"""Implementation of TD3, arXiv:1802.09477.
:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.

View File

@ -11,7 +11,10 @@ from tianshou.data import Batch, SequenceSummaryStats
from tianshou.policy import NPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.npg import NPGTrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic
@dataclass(kw_only=True)
@ -25,7 +28,9 @@ TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats)
class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477.
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` -> `action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.
@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
max_kl: float = 0.01,
backtrack_coeff: float = 0.8,

View File

@ -1,11 +1,11 @@
from typing import Any, Literal, Protocol, Self, cast, overload
from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload
import numpy as np
from overrides import override
from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol, IndexType
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
@ -160,16 +160,18 @@ class MultiAgentPolicyManager(BasePolicy):
buffer._meta.rew = save_rew
return Batch(results)
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
"""Add exploration noise from sub-policy onto act."""
assert isinstance(
batch.obs,
BatchProtocol,
), f"here only observations of type Batch are permitted, but got {type(batch.obs)}"
if not isinstance(batch.obs, Batch):
raise TypeError(
f"here only observations of type Batch are permitted, but got {type(batch.obs)}",
)
for agent_id, policy in self.policies.items():
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0:
@ -223,7 +225,7 @@ class MultiAgentPolicyManager(BasePolicy):
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
continue
tmp_batch = batch[agent_index]
if isinstance(tmp_batch.rew, np.ndarray):
if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
if not hasattr(tmp_batch.obs, "mask"):

View File

@ -237,7 +237,13 @@ class BaseTrainer(ABC):
self.stop_fn_flag = False
self.iter_num = 0
def reset(self) -> None:
def _reset_collectors(self, reset_buffer: bool = False) -> None:
if self.train_collector is not None:
self.train_collector.reset(reset_buffer=reset_buffer)
if self.test_collector is not None:
self.test_collector.reset(reset_buffer=reset_buffer)
def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None:
"""Initialize or reset the instance to yield a new iterator from zero."""
self.is_run = False
self.env_step = 0
@ -250,16 +256,18 @@ class BaseTrainer(ABC):
self.last_rew, self.last_len = 0.0, 0.0
self.start_time = time.time()
if self.train_collector is not None:
self.train_collector.reset_stat()
if self.train_collector.policy != self.policy or self.test_collector is None:
self.test_in_train = False
if reset_collectors:
self._reset_collectors(reset_buffer=reset_buffer)
if self.train_collector is not None and (
self.train_collector.policy != self.policy or self.test_collector is None
):
self.test_in_train = False
if self.test_collector is not None:
assert self.episode_per_test is not None
assert not isinstance(self.test_collector, AsyncCollector) # Issue 700
self.test_collector.reset_stat()
test_result = test_episode(
self.policy,
self.test_collector,
@ -284,7 +292,7 @@ class BaseTrainer(ABC):
self.iter_num = 0
def __iter__(self): # type: ignore
self.reset()
self.reset(reset_collectors=True, reset_buffer=False)
return self
def __next__(self) -> EpochStats:
@ -308,8 +316,8 @@ class BaseTrainer(ABC):
# perform n step_per_epoch
with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
train_stat: CollectStatsBase
while t.n < t.total and not self.stop_fn_flag:
train_stat: CollectStatsBase
if self.train_collector is not None:
train_stat, self.stop_fn_flag = self.train_step()
pbar_data_dict = {
@ -515,12 +523,14 @@ class BaseTrainer(ABC):
stats of the whole dataset
"""
def run(self) -> InfoStats:
def run(self, reset_prior_to_run: bool = True) -> InfoStats:
"""Consume iterator.
See itertools - recipes. Use functions that consume iterators at C speed
(feed the entire iterator into a zero-length deque).
"""
if reset_prior_to_run:
self.reset()
try:
self.is_run = True
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
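In short, a sketch of the resulting calling convention (`trainer` is assumed to be an already constructed trainer instance):

result = trainer.run()                          # calls trainer.reset() first (the default)
result = trainer.run(reset_prior_to_run=False)  # skips the explicit reset() call before iterating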

View File

@ -26,8 +26,7 @@ def test_episode(
reward_metric: Callable[[np.ndarray], np.ndarray] | None = None,
) -> CollectStats:
"""A simple wrapper of testing policy in collector."""
collector.reset_env()
collector.reset_buffer()
collector.reset(reset_stats=False)
policy.eval()
if test_fn:
test_fn(epoch, global_step)
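The collector reset is now granular; a brief sketch of the keyword combinations that appear in this changeset:

collector.reset()                    # reset environments, buffer and collection statistics
collector.reset(reset_stats=False)   # as in test_episode above: keep the running statistics
collector.reset(reset_buffer=False)  # as in the trainer: keep the collected buffer contents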

View File

@ -610,6 +610,17 @@ class BaseActor(nn.Module, ABC):
def get_output_dim(self) -> int:
pass
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[Any, Any]:
# TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction.
# Return type needs to be more specific
pass
def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
"""Gets the given attribute from the given object or takes the alternative value if it is not present.

View File

@ -1,4 +1,5 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any
@ -9,6 +10,7 @@ from torch import nn
from tianshou.utils.net.common import (
MLP,
BaseActor,
Net,
TActionShape,
TLinearLayer,
get_output_dim,
@ -19,33 +21,27 @@ SIGMA_MAX = 2
class Actor(BaseActor):
"""Simple actor network.
"""Simple actor network that directly outputs actions for continuous action space.
Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`.
It creates an actor operating in a continuous action space, with the structure preprocess_net ---> action_shape.
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param max_action: the scale for the final action.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
@ -77,42 +73,50 @@ class Actor(BaseActor):
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
"""Mapping: obs -> logits -> action."""
if info is None:
info = {}
logits, hidden = self.preprocess(obs, state)
logits = self.max_action * torch.tanh(self.last(logits))
return logits, hidden
"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
Returns a tensor representing the actions directly, i.e., of shape
`(batch_size, action_dim)`, and a hidden state (which may be None).
The hidden state is only not None if a recurrent net is used as part of the
learning algorithm (support for RNNs is currently experimental).
"""
action_BA, hidden_BH = self.preprocess(obs, state)
action_BA = self.max_action * torch.tanh(self.last(action_BA))
return action_BA, hidden_BH
class Critic(nn.Module):
class CriticBase(nn.Module, ABC):
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
class Critic(CriticBase):
"""Simple critic network.
It creates a critic operating in a continuous action space, with the structure preprocess_net ---> 1 (Q value).
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param linear_layer: use this module as linear layer. Default to nn.Linear.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
:param linear_layer: use this module as linear layer.
:param flatten_input: whether to flatten input data for the last layer.
Default to True.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (),
device: str | int | torch.device = "cpu",
preprocess_net_output_dim: int | None = None,
@ -139,9 +143,7 @@ class Critic(nn.Module):
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
if info is None:
info = {}
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
obs = torch.as_tensor(
obs,
device=self.device,
@ -154,41 +156,35 @@ class Critic(nn.Module):
dtype=torch.float32,
).flatten(1)
obs = torch.cat([obs, act], dim=1)
logits, hidden = self.preprocess(obs)
return self.last(logits)
values_B, hidden_BH = self.preprocess(obs)
return self.last(values_B)
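A usage sketch of the (s_B, a_B) -> Q(s, a)_B mapping (sizes are placeholders):

import torch
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Critic

net_c = Net(state_shape=(8,), action_shape=(2,), hidden_sizes=[64, 64], concat=True)
critic = Critic(net_c)
q_B1 = critic(torch.zeros(5, 8), torch.zeros(5, 2))  # Q-values of shape (5, 1)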
class ActorProb(BaseActor):
"""Simple actor network (output with a Gauss distribution).
"""Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian).
:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits.
:param unbounded: whether to apply tanh activation on final logits.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how a preprocess_net can be defined.
"""
# TODO: force kwargs, adjust downstream code
def __init__(
self,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,
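To make the intended use of `ActorProb` concrete, here is a hedged sketch (not part of the diff): the actor emits `(mu, sigma)`, and a single-argument `dist_fn` turns that pair into a Gaussian distribution. The constructor keywords and the single-argument `dist_fn` shape are assumptions based on this changeset and may vary by version.

```python
import torch
from torch.distributions import Independent, Normal

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb

obs_dim, act_dim = 8, 2
net = Net(state_shape=obs_dim, hidden_sizes=[64, 64])
actor = ActorProb(net, action_shape=act_dim, unbounded=True, conditioned_sigma=True)


def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Independent:
    # Single-argument interface: receives the (mu, sigma) pair produced by the actor.
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)


(mu, sigma), _ = actor(torch.randn(32, obs_dim))  # forward returns (mu, sigma) and a hidden state
action_BA = dist_fn((mu, sigma)).sample()
print(action_BA.shape)  # torch.Size([32, 2])
```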
@@ -402,8 +398,7 @@ class Perturbation(nn.Module):
flattened hidden state.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
:param phi: max perturbation parameter for BCQ.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
@@ -449,7 +444,6 @@ class VAE(nn.Module):
:param latent_dim: the size of the latent layer.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

View File

@@ -7,17 +7,14 @@ import torch.nn.functional as F
from torch import nn
from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim
class Actor(BaseActor):
"""Simple actor network.
"""Simple actor network for discrete action spaces.
Will create an actor operating in a discrete action space, with the structure
preprocess_net ---> action_shape.
:param preprocess_net: a self-defined preprocess_net which outputs a flattened hidden state.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
@@ -25,20 +22,15 @@ class Actor(BaseActor):
:param softmax_output: whether to apply a softmax layer over the last
layer's output.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how a preprocess_net can be defined.
"""
def __init__(
self,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
softmax_output: bool = True,
@@ -71,43 +63,44 @@ class Actor(BaseActor):
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.
Returns a tensor representing the values of each action, i.e., of shape
`(n_actions, )`, and a hidden state (which may be None). If `self.softmax_output`
is True, they are the probabilities for taking each action; otherwise, they are
the non-normalized action values. The hidden state is only not None if a
recurrent net is used as part of the learning algorithm.
"""
x, hidden_BH = self.preprocess(obs, state)
x = self.last(x)
if self.softmax_output:
x = F.softmax(x, dim=-1)
# If we computed softmax, the output is probabilities; otherwise it is the non-normalized action values
output_BA = x
return output_BA, hidden_BH
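A hedged usage sketch (not part of the diff) of the discrete `Actor` described above; constructor keywords follow common Tianshou usage and may differ between versions.

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor

obs_dim, n_actions = 4, 3
net = Net(state_shape=obs_dim, hidden_sizes=[64, 64])
actor = Actor(net, action_shape=n_actions, softmax_output=True)

probs_BA, hidden = actor(torch.randn(32, obs_dim))
print(probs_BA.shape)        # torch.Size([32, 3])
print(probs_BA.sum(dim=-1))  # each row sums to 1 because softmax_output=True
```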
class Critic(nn.Module):
"""Simple critic network.
"""Simple critic network for discrete action spaces.
It will create a critic operating in a discrete action space, with the structure preprocess_net ---> 1 (value).
:param preprocess_net: a self-defined preprocess_net which outputs a flattened hidden state.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Defaults to an empty sequence (in which case the MLP contains
only a single linear layer).
:param last_size: the output dimension of the Critic network. Defaults to 1.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
.. seealso::
Please refer to :class:`~tianshou.utils.net.common.Net` as an example
of how a preprocess_net can be defined.
"""
def __init__(
self,
preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (),
last_size: int = 1,
preprocess_net_output_dim: int | None = None,
@@ -120,8 +113,10 @@ class Critic(nn.Module):
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
# TODO: make a proper interface!
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Mapping: s -> V(s)."""
"""Mapping: s_B -> V(s)_B."""
# TODO: don't use this mechanism for passing state
logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits)
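Finally, a hedged usage sketch (not part of the diff) of the discrete `Critic`: with the default `last_size=1` it acts as a state-value function V(s); with `last_size=n_actions` it can serve as a Q-value head. Keyword names follow common Tianshou usage and may vary by version.

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Critic

obs_dim, n_actions = 4, 3
net = Net(state_shape=obs_dim, hidden_sizes=[64, 64])

# Sharing the same preprocess net here only for brevity; in practice each head
# would usually get its own Net instance.
v_critic = Critic(net)                       # V(s): output shape (B, 1)
q_critic = Critic(net, last_size=n_actions)  # Q(s, .): output shape (B, n_actions)

obs = torch.randn(32, obs_dim)
print(v_critic(obs).shape, q_critic(obs).shape)  # torch.Size([32, 1]) torch.Size([32, 3])
```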