Merge branch 'thuml_master' into feature/algo-eval
commit f2e10b04bb

CHANGELOG.md
@@ -1,4 +1,38 @@

# Changelog

## Release 1.1.0

### Api Extensions

- `Batch` received two new methods: `to_dict` and `to_list_of_dicts` (see the sketch after this list). #1063
- `Collector`s can now be closed, and their reset is more granular. #1063
- Trainers can control whether collectors should be reset prior to training. #1063
- Convenience constructor for `CollectStats` called `with_autogenerated_stats`. #1063
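A minimal sketch of the two new `Batch` conversion helpers. The field names (`obs`, `act`) and the exact return layout are illustrative assumptions; only the method names come from the changelog entry above.

```python
import numpy as np
from tianshou.data import Batch

# A small batch with three entries; the field names are made up for illustration.
batch = Batch(obs=np.arange(6).reshape(3, 2), act=np.array([0, 1, 0]))

# The whole batch as a plain (possibly nested) dict of arrays.
as_dict = batch.to_dict()

# One dict per entry, i.e. a list of three dicts here.
rows = batch.to_list_of_dicts()
print(len(rows), rows[0])
```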
### Internal Improvements

- `Collector`s rely less on state; the few stateful things are stored explicitly instead of through a `.data` attribute. #1063
- Introduced a first iteration of a naming convention for variables in `Collector`s. #1063
- Generally improved readability of `Collector` code and associated tests (still quite some way to go). #1063
- Improved typing for `exploration_noise` and within `Collector`. #1063
- Better variable names related to model outputs (logits, dist input, etc.). #1032
- Improved typing for actors and critics, using Tianshou classes like `Actor`, `ActorProb`, etc.,
  instead of just `nn.Module`. #1032
- Added interfaces for most `Actor` and `Critic` classes to enforce the presence of `forward` methods. #1032
- Simplified `PGPolicy` forward by unifying the `dist_fn` interface (see associated breaking change). #1032
- Use `.mode` of the distribution instead of relying on knowledge of the distribution type (see the sketch after this list). #1032
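A small illustration of the `.mode` point; this is not code from this commit and assumes a PyTorch version in which `Distribution.mode` is implemented (2.0 and later).

```python
import torch
from torch.distributions import Categorical, Independent, Normal

def greedy_action(dist: torch.distributions.Distribution) -> torch.Tensor:
    # Instead of branching on isinstance(dist, Categorical) vs. Normal,
    # ask the distribution itself for its most likely value.
    return dist.mode

print(greedy_action(Categorical(logits=torch.tensor([0.1, 2.0, -1.0]))))     # tensor(1)
print(greedy_action(Independent(Normal(torch.zeros(3), torch.ones(3)), 1)))  # tensor([0., 0., 0.])
```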
### Breaking Changes

- Removed `.data` attribute from `Collector` and its child classes. #1063
- Collectors no longer reset the environment on initialization. Instead, the user might have to call `reset`
  explicitly or pass `reset_before_collect=True`. #1063
- VectorEnvs now return an array of info-dicts on reset instead of a list (see the short check after this changelog section). #1063
- Fixed `iter(Batch(...))`, which now behaves the same way as `Batch(...).__iter__()`. Can be considered a bugfix. #1063
- Changed interface of `dist_fn` in `PGPolicy` and all subclasses to take a single argument in both
  continuous and discrete cases (see the sketch after this list). #1032
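The `dist_fn` change is easiest to see side by side. The continuous-action form is taken verbatim from the test scripts updated in this commit; the discrete variant is an assumption based on the stated unification (a single argument in both cases).

```python
import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

# Old interface: the model outputs were unpacked into dist_fn as *logits.
def dist_old(*logits: torch.Tensor) -> Distribution:
    return Independent(Normal(*logits), 1)

# New interface (#1032): dist_fn takes a single argument.
def dist_continuous(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale
    return Independent(Normal(loc, scale), 1)

def dist_discrete(logits: torch.Tensor) -> Distribution:
    # Assumed discrete counterpart: one tensor of logits in, one distribution out.
    return Categorical(logits=logits)
```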
### Tests

- Fixed env seeding in test_sac_with_il.py so that the test doesn't fail randomly. #1081


Started after v1.0.0
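The vector-env breaking change can be verified with a short check; the assertions mirror the updated `test_venv_wrapper_gym` test further down in this diff.

```python
import gymnasium as gym
import numpy as np
from tianshou.env import DummyVectorEnv

envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(2)])
obs, info = envs.reset()
# As of this release, reset() returns an array of info dicts rather than a list.
assert isinstance(info, np.ndarray)
assert isinstance(info[0], dict)
assert obs.shape[0] == len(info) == 2
```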
@@ -147,9 +147,6 @@ Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://

Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.

The Chinese documentation is at [https://tianshou.readthedocs.io/zh/master/](https://tianshou.readthedocs.io/zh/master/).

<!-- A short Chinese introduction to the Tianshou platform: https://www.zhihu.com/question/377263715 -->

## Why Tianshou?


@@ -164,7 +164,7 @@

"source": [
"# Let's watch its performance!\n",
"policy.eval()\n",
"eval_result = test_collector.collect(n_episode=1, render=False)\n",
"eval_result = test_collector.collect(n_episode=3, render=False)\n",
"print(f\"Final reward: {eval_result.returns.mean()}, length: {eval_result.lens.mean()}\")"
]
},

@@ -69,7 +69,7 @@

"from tianshou.policy import BasePolicy\n",
"from tianshou.policy.modelfree.pg import (\n",
" PGTrainingStats,\n",
" TDistributionFunction,\n",
" TDistFnDiscrOrCont,\n",
" TPGTrainingStats,\n",
")\n",
"from tianshou.utils import RunningMeanStd\n",

@@ -339,7 +339,7 @@

" *,\n",
" actor: torch.nn.Module,\n",
" optim: torch.optim.Optimizer,\n",
" dist_fn: TDistributionFunction,\n",
" dist_fn: TDistFnDiscrOrCont,\n",
" action_space: gym.Space,\n",
" discount_factor: float = 0.99,\n",
" observation_space: gym.Space | None = None,\n",

@@ -119,7 +119,7 @@

},
"outputs": [],
"source": [
"collect_result = test_collector.collect(n_episode=9)\n",
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9)\n",
"\n",
"collect_result.pprint_asdict()"
]

@@ -146,8 +146,7 @@

"outputs": [],
"source": [
"# Reset the collector\n",
"test_collector.reset()\n",
"collect_result = test_collector.collect(n_episode=9, random=True)\n",
"collect_result = test_collector.collect(reset_before_collect=True, n_episode=9, random=True)\n",
"\n",
"collect_result.pprint_asdict()"
]
@@ -92,8 +92,6 @@ To compile documentation into webpage, run

The generated webpage is in ``docs/_build`` and can be viewed with a browser (http://0.0.0.0:8000/).

Chinese documentation is at https://tianshou.readthedocs.io/zh/latest/.


Documentation Generation Test
-----------------------------


@@ -57,9 +57,6 @@ Here is Tianshou's other features:

* Support multi-GPU training :ref:`multi_gpu`
* Comprehensive `unit tests <https://github.com/thu-ml/tianshou/actions>`_, including functional checking, RL pipeline checking, documentation checking, PEP8 code-style checking, and type checking

The Chinese documentation is at `https://tianshou.readthedocs.io/zh/master/ <https://tianshou.readthedocs.io/zh/master/>`_


Installation
------------


@@ -257,3 +257,8 @@ macOS

joblib
master
Panchenko
BA
BH
BO
BD
@@ -167,8 +167,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

# expert replay buffer
dataset = d4rl.qlearning_dataset(gym.make(args.expert_data_task))

@@ -137,8 +137,9 @@ def test_a2c(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: A2CPolicy = A2CPolicy(
actor=actor,

@@ -134,8 +134,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: NPGPolicy = NPGPolicy(
actor=actor,

@@ -137,8 +137,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: PPOPolicy = PPOPolicy(
actor=actor,

@@ -119,8 +119,9 @@ def test_reinforce(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: PGPolicy = PGPolicy(
actor=actor,

@@ -137,8 +137,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:

lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

def dist(*logits: torch.Tensor) -> Distribution:
return Independent(Normal(*logits), 1)
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
loc, scale = loc_scale
return Independent(Normal(loc, scale), 1)

policy: TRPOPolicy = TRPOPolicy(
actor=actor,
@@ -171,6 +171,7 @@ ignore = [

"RET505",
"D106", # undocumented public nested class
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
"PLW2901", # overwrite vars in loop
]
unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
@ -9,13 +9,24 @@ import numpy as np
|
||||
from gymnasium.spaces import Box, Dict, Discrete, MultiDiscrete, Space, Tuple
|
||||
|
||||
|
||||
class MyTestEnv(gym.Env):
|
||||
"""A task for "going right". The task is to go right ``size`` steps."""
|
||||
class MoveToRightEnv(gym.Env):
|
||||
"""A task for "going right". The task is to go right ``size`` steps.
|
||||
|
||||
The observation is the current index, and the action is to go left or right.
|
||||
Action 0 is to go left, and action 1 is to go right.
|
||||
Taking action 0 at index 0 will keep the index at 0.
|
||||
Arriving at index ``size`` means the task is done.
|
||||
In the current implementation, stepping after the task is done is possible, which will
|
||||
lead the index to be larger than ``size``.
|
||||
|
||||
Index 0 is the starting point. If reset is called with default options, the index will
|
||||
be reset to 0.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
sleep: int = 0,
|
||||
sleep: float = 0.0,
|
||||
dict_state: bool = False,
|
||||
recurse_state: bool = False,
|
||||
ma_rew: int = 0,
|
||||
@ -74,8 +85,13 @@ class MyTestEnv(gym.Env):
|
||||
def reset(
|
||||
self,
|
||||
seed: int | None = None,
|
||||
# TODO: passing a dict here doesn't make any sense
|
||||
options: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, Any] | np.ndarray, dict]:
|
||||
""":param seed:
|
||||
:param options: the start index is provided in options["state"]
|
||||
:return:
|
||||
"""
|
||||
if options is None:
|
||||
options = {"state": 0}
|
||||
super().reset(seed=seed)
|
||||
@ -188,7 +204,7 @@ class NXEnv(gym.Env):
|
||||
return self._encode_obs(), 1.0, False, False, {}
|
||||
|
||||
|
||||
class MyGoalEnv(MyTestEnv):
|
||||
class MyGoalEnv(MoveToRightEnv):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
assert (
|
||||
kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0
|
||||
|
@ -22,13 +22,13 @@ from tianshou.data import (
|
||||
from tianshou.data.utils.converter import to_hdf5
|
||||
|
||||
if __name__ == "__main__":
|
||||
from env import MyGoalEnv, MyTestEnv
|
||||
from env import MoveToRightEnv, MyGoalEnv
|
||||
else: # pytest
|
||||
from test.base.env import MyGoalEnv, MyTestEnv
|
||||
from test.base.env import MoveToRightEnv, MyGoalEnv
|
||||
|
||||
|
||||
def test_replaybuffer(size=10, bufsize=20) -> None:
|
||||
env = MyTestEnv(size)
|
||||
env = MoveToRightEnv(size)
|
||||
buf = ReplayBuffer(bufsize)
|
||||
buf.update(buf)
|
||||
assert str(buf) == buf.__class__.__name__ + "()"
|
||||
@ -209,7 +209,7 @@ def test_ignore_obs_next(size=10) -> None:
|
||||
|
||||
|
||||
def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
|
||||
env = MyTestEnv(size)
|
||||
env = MoveToRightEnv(size)
|
||||
buf = ReplayBuffer(bufsize, stack_num=stack_num)
|
||||
buf2 = ReplayBuffer(bufsize, stack_num=stack_num, sample_avail=True)
|
||||
buf3 = ReplayBuffer(bufsize, stack_num=stack_num, save_only_last_obs=True)
|
||||
@ -280,7 +280,7 @@ def test_stack(size=5, bufsize=9, stack_num=4, cached_num=3) -> None:
|
||||
|
||||
|
||||
def test_priortized_replaybuffer(size=32, bufsize=15) -> None:
|
||||
env = MyTestEnv(size)
|
||||
env = MoveToRightEnv(size)
|
||||
buf = PrioritizedReplayBuffer(bufsize, 0.5, 0.5)
|
||||
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
|
||||
obs, info = env.reset()
|
||||
@ -1028,7 +1028,7 @@ def test_multibuf_stack() -> None:
|
||||
bufsize = 9
|
||||
stack_num = 4
|
||||
cached_num = 3
|
||||
env = MyTestEnv(size)
|
||||
env = MoveToRightEnv(size)
|
||||
# test if CachedReplayBuffer can handle stack_num + ignore_obs_next
|
||||
buf4 = CachedReplayBuffer(
|
||||
ReplayBuffer(bufsize, stack_num=stack_num, ignore_obs_next=True),
|
||||
|
@ -2,7 +2,6 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import tqdm
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.data import (
|
||||
AsyncCollector,
|
||||
@ -22,12 +21,12 @@ except ImportError:
|
||||
envpool = None
|
||||
|
||||
if __name__ == "__main__":
|
||||
from env import MyTestEnv, NXEnv
|
||||
from env import MoveToRightEnv, NXEnv
|
||||
else: # pytest
|
||||
from test.base.env import MyTestEnv, NXEnv
|
||||
from test.base.env import MoveToRightEnv, NXEnv
|
||||
|
||||
|
||||
class MyPolicy(BasePolicy):
|
||||
class MaxActionPolicy(BasePolicy):
|
||||
def __init__(
|
||||
self,
|
||||
action_space: gym.spaces.Space | None = None,
|
||||
@ -35,7 +34,9 @@ class MyPolicy(BasePolicy):
|
||||
need_state=True,
|
||||
action_shape=None,
|
||||
) -> None:
|
||||
"""Mock policy for testing.
|
||||
"""Mock policy for testing, will always return an array of ones of the shape of the action space.
|
||||
Note that this doesn't make much sense for discrete action space (the output is then intepreted as
|
||||
logits, meaning all actions would be equally likely).
|
||||
|
||||
:param action_space: the action space of the environment. If None, a dummy Box space will be used.
|
||||
:param bool dict_state: if the observation of the environment is a dict
|
||||
@ -63,215 +64,290 @@ class MyPolicy(BasePolicy):
|
||||
pass
|
||||
|
||||
|
||||
class Logger:
|
||||
def __init__(self, writer) -> None:
|
||||
self.cnt = 0
|
||||
self.writer = writer
|
||||
def test_collector() -> None:
|
||||
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
|
||||
|
||||
def preprocess_fn(self, **kwargs):
|
||||
# modify info before adding into the buffer, and recorded into tfb
|
||||
# if obs && env_id exist -> reset
|
||||
# if obs_next/rew/done/info/env_id exist -> normal step
|
||||
if "rew" in kwargs:
|
||||
info = kwargs["info"]
|
||||
info.rew = kwargs["rew"]
|
||||
if "key" in info:
|
||||
self.writer.add_scalar("key", np.mean(info.key), global_step=self.cnt)
|
||||
self.cnt += 1
|
||||
return Batch(info=info)
|
||||
return Batch()
|
||||
|
||||
@staticmethod
|
||||
def single_preprocess_fn(**kwargs):
|
||||
# same as above, without tfb
|
||||
if "rew" in kwargs:
|
||||
info = kwargs["info"]
|
||||
info.rew = kwargs["rew"]
|
||||
return Batch(info=info)
|
||||
return Batch()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
|
||||
def test_collector(gym_reset_kwargs) -> None:
|
||||
writer = SummaryWriter("log/collector")
|
||||
logger = Logger(writer)
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0) for i in [2, 3, 4, 5]]
|
||||
|
||||
venv = SubprocVectorEnv(env_fns)
|
||||
dum = DummyVectorEnv(env_fns)
|
||||
policy = MyPolicy()
|
||||
env = env_fns[0]()
|
||||
c0 = Collector(
|
||||
subproc_venv_4_envs = SubprocVectorEnv(env_fns)
|
||||
dummy_venv_4_envs = DummyVectorEnv(env_fns)
|
||||
policy = MaxActionPolicy()
|
||||
single_env = env_fns[0]()
|
||||
c_single_env = Collector(
|
||||
policy,
|
||||
env,
|
||||
single_env,
|
||||
ReplayBuffer(size=100),
|
||||
logger.preprocess_fn,
|
||||
)
|
||||
c0.collect(n_step=3, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert len(c0.buffer) == 3
|
||||
assert np.allclose(c0.buffer.obs[:4, 0], [0, 1, 0, 0])
|
||||
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1])
|
||||
c_single_env.reset()
|
||||
c_single_env.collect(n_step=3)
|
||||
assert len(c_single_env.buffer) == 3
|
||||
# TODO: direct attr access is an arcane way of using the buffer, it should be never done
|
||||
# The placeholders for entries are all zeros, so buffer.obs is an array filled with 3
|
||||
# observations, and 97 zeros.
|
||||
# However, buffer[:] will have all attributes with length three... The non-filled entries are removed there
|
||||
|
||||
# See above. For the single env, we start with obs=0, obs_next=1.
|
||||
# We move to obs=1, obs_next=2,
|
||||
# then the env is reset and we move to obs=0
|
||||
# Making one more step results in obs_next=1
|
||||
# The final 0 in the buffer.obs is because the buffer is initialized with zeros and the direct attr access
|
||||
assert np.allclose(c_single_env.buffer.obs[:4, 0], [0, 1, 0, 0])
|
||||
assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1])
|
||||
keys = np.zeros(100)
|
||||
keys[:3] = 1
|
||||
assert np.allclose(c0.buffer.info["key"], keys)
|
||||
for e in c0.buffer.info["env"][:3]:
|
||||
assert isinstance(e, MyTestEnv)
|
||||
assert np.allclose(c0.buffer.info["env_id"], 0)
|
||||
assert np.allclose(c_single_env.buffer.info["key"], keys)
|
||||
for e in c_single_env.buffer.info["env"][:3]:
|
||||
assert isinstance(e, MoveToRightEnv)
|
||||
assert np.allclose(c_single_env.buffer.info["env_id"], 0)
|
||||
rews = np.zeros(100)
|
||||
rews[:3] = [0, 1, 0]
|
||||
assert np.allclose(c0.buffer.info["rew"], rews)
|
||||
c0.collect(n_episode=3, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert len(c0.buffer) == 8
|
||||
assert np.allclose(c0.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
|
||||
assert np.allclose(c0.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
|
||||
assert np.allclose(c0.buffer.info["key"][:8], 1)
|
||||
for e in c0.buffer.info["env"][:8]:
|
||||
assert isinstance(e, MyTestEnv)
|
||||
assert np.allclose(c0.buffer.info["env_id"][:8], 0)
|
||||
assert np.allclose(c0.buffer.info["rew"][:8], [0, 1, 0, 1, 0, 1, 0, 1])
|
||||
c0.collect(n_step=3, random=True, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert np.allclose(c_single_env.buffer.rew, rews)
|
||||
# At this point, the buffer contains obs 0 -> 1 -> 0
|
||||
|
||||
c1 = Collector(
|
||||
# At start we have 3 entries in the buffer
|
||||
# We collect 3 episodes, in addition to the transitions we have collected before
|
||||
# 0 -> 1 -> 0 -> 0 (reset at collection start) -> 1 -> done (0) -> 1 -> done(0)
|
||||
# obs_next: 1 -> 2 -> 1 -> 1 (reset at collection start) -> 2 -> 1 -> 2 -> 1 -> 2
|
||||
# In total, we will have 3 + 6 = 9 entries in the buffer
|
||||
c_single_env.collect(n_episode=3)
|
||||
assert len(c_single_env.buffer) == 8
|
||||
assert np.allclose(c_single_env.buffer.obs[:10, 0], [0, 1, 0, 1, 0, 1, 0, 1, 0, 0])
|
||||
assert np.allclose(c_single_env.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
|
||||
assert np.allclose(c_single_env.buffer.info["key"][:8], 1)
|
||||
for e in c_single_env.buffer.info["env"][:8]:
|
||||
assert isinstance(e, MoveToRightEnv)
|
||||
assert np.allclose(c_single_env.buffer.info["env_id"][:8], 0)
|
||||
assert np.allclose(c_single_env.buffer.rew[:8], [0, 1, 0, 1, 0, 1, 0, 1])
|
||||
c_single_env.collect(n_step=3, random=True)
|
||||
|
||||
c_subproc_venv_4_envs = Collector(
|
||||
policy,
|
||||
venv,
|
||||
subproc_venv_4_envs,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4),
|
||||
logger.preprocess_fn,
|
||||
)
|
||||
c1.collect(n_step=8, gym_reset_kwargs=gym_reset_kwargs)
|
||||
c_subproc_venv_4_envs.reset()
|
||||
|
||||
# Collect some steps
|
||||
c_subproc_venv_4_envs.collect(n_step=8)
|
||||
obs = np.zeros(100)
|
||||
valid_indices = [0, 1, 25, 26, 50, 51, 75, 76]
|
||||
obs[valid_indices] = [0, 1, 0, 1, 0, 1, 0, 1]
|
||||
assert np.allclose(c1.buffer.obs[:, 0], obs)
|
||||
assert np.allclose(c1.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer[:].obs_next[..., 0], [1, 2, 1, 2, 1, 2, 1, 2])
|
||||
keys = np.zeros(100)
|
||||
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
|
||||
assert np.allclose(c1.buffer.info["key"], keys)
|
||||
for e in c1.buffer.info["env"][valid_indices]:
|
||||
assert isinstance(e, MyTestEnv)
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys)
|
||||
for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]:
|
||||
assert isinstance(e, MoveToRightEnv)
|
||||
env_ids = np.zeros(100)
|
||||
env_ids[valid_indices] = [0, 0, 1, 1, 2, 2, 3, 3]
|
||||
assert np.allclose(c1.buffer.info["env_id"], env_ids)
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids)
|
||||
rews = np.zeros(100)
|
||||
rews[valid_indices] = [0, 1, 0, 0, 0, 0, 0, 0]
|
||||
assert np.allclose(c1.buffer.info["rew"], rews)
|
||||
c1.collect(n_episode=4, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert len(c1.buffer) == 16
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews)
|
||||
|
||||
# we previously collected 8 steps, 2 from each env, now we collect 4 episodes
|
||||
# each env will contribute an episode, which will be of lens 2 (first env was reset), 1, 2, 3
|
||||
# So we get 8 + 2+1+2+3 = 16 steps
|
||||
c_subproc_venv_4_envs.collect(n_episode=4)
|
||||
assert len(c_subproc_venv_4_envs.buffer) == 16
|
||||
|
||||
valid_indices = [2, 3, 27, 52, 53, 77, 78, 79]
|
||||
obs[[2, 3, 27, 52, 53, 77, 78, 79]] = [0, 1, 2, 2, 3, 2, 3, 4]
|
||||
assert np.allclose(c1.buffer.obs[:, 0], obs)
|
||||
obs[valid_indices] = [0, 1, 2, 2, 3, 2, 3, 4]
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.obs[:, 0], obs)
|
||||
assert np.allclose(
|
||||
c1.buffer[:].obs_next[..., 0],
|
||||
c_subproc_venv_4_envs.buffer[:].obs_next[..., 0],
|
||||
[1, 2, 1, 2, 1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5],
|
||||
)
|
||||
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1]
|
||||
assert np.allclose(c1.buffer.info["key"], keys)
|
||||
for e in c1.buffer.info["env"][valid_indices]:
|
||||
assert isinstance(e, MyTestEnv)
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.info["key"], keys)
|
||||
for e in c_subproc_venv_4_envs.buffer.info["env"][valid_indices]:
|
||||
assert isinstance(e, MoveToRightEnv)
|
||||
env_ids[valid_indices] = [0, 0, 1, 2, 2, 3, 3, 3]
|
||||
assert np.allclose(c1.buffer.info["env_id"], env_ids)
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.info["env_id"], env_ids)
|
||||
rews[valid_indices] = [0, 1, 1, 0, 1, 0, 0, 1]
|
||||
assert np.allclose(c1.buffer.info["rew"], rews)
|
||||
c1.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert np.allclose(c_subproc_venv_4_envs.buffer.rew, rews)
|
||||
c_subproc_venv_4_envs.collect(n_episode=4, random=True)
|
||||
|
||||
c2 = Collector(
|
||||
c_dummy_venv_4_envs = Collector(
|
||||
policy,
|
||||
dum,
|
||||
dummy_venv_4_envs,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4),
|
||||
logger.preprocess_fn,
|
||||
)
|
||||
c2.collect(n_episode=7, gym_reset_kwargs=gym_reset_kwargs)
|
||||
c_dummy_venv_4_envs.reset()
|
||||
c_dummy_venv_4_envs.collect(n_episode=7)
|
||||
obs1 = obs.copy()
|
||||
obs1[[4, 5, 28, 29, 30]] = [0, 1, 0, 1, 2]
|
||||
obs2 = obs.copy()
|
||||
obs2[[28, 29, 30, 54, 55, 56, 57]] = [0, 1, 2, 0, 1, 2, 3]
|
||||
c2obs = c2.buffer.obs[:, 0]
|
||||
c2obs = c_dummy_venv_4_envs.buffer.obs[:, 0]
|
||||
assert np.all(c2obs == obs1) or np.all(c2obs == obs2)
|
||||
c2.reset_env(gym_reset_kwargs=gym_reset_kwargs)
|
||||
c2.reset_buffer()
|
||||
assert c2.collect(n_episode=8, gym_reset_kwargs=gym_reset_kwargs).n_collected_episodes == 8
|
||||
c_dummy_venv_4_envs.reset_env()
|
||||
c_dummy_venv_4_envs.reset_buffer()
|
||||
assert c_dummy_venv_4_envs.collect(n_episode=8).n_collected_episodes == 8
|
||||
valid_indices = [4, 5, 28, 29, 30, 54, 55, 56, 57]
|
||||
obs[valid_indices] = [0, 1, 0, 1, 2, 0, 1, 2, 3]
|
||||
assert np.all(c2.buffer.obs[:, 0] == obs)
|
||||
assert np.all(c_dummy_venv_4_envs.buffer.obs[:, 0] == obs)
|
||||
keys[valid_indices] = [1, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||
assert np.allclose(c2.buffer.info["key"], keys)
|
||||
for e in c2.buffer.info["env"][valid_indices]:
|
||||
assert isinstance(e, MyTestEnv)
|
||||
assert np.allclose(c_dummy_venv_4_envs.buffer.info["key"], keys)
|
||||
for e in c_dummy_venv_4_envs.buffer.info["env"][valid_indices]:
|
||||
assert isinstance(e, MoveToRightEnv)
|
||||
env_ids[valid_indices] = [0, 0, 1, 1, 1, 2, 2, 2, 2]
|
||||
assert np.allclose(c2.buffer.info["env_id"], env_ids)
|
||||
assert np.allclose(c_dummy_venv_4_envs.buffer.info["env_id"], env_ids)
|
||||
rews[valid_indices] = [0, 1, 0, 0, 1, 0, 0, 0, 1]
|
||||
assert np.allclose(c2.buffer.info["rew"], rews)
|
||||
c2.collect(n_episode=4, random=True, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert np.allclose(c_dummy_venv_4_envs.buffer.rew, rews)
|
||||
c_dummy_venv_4_envs.collect(n_episode=4, random=True)
|
||||
|
||||
# test corner case
|
||||
with pytest.raises(TypeError):
|
||||
Collector(policy, dum, ReplayBuffer(10))
|
||||
Collector(policy, dummy_venv_4_envs, ReplayBuffer(10))
|
||||
with pytest.raises(TypeError):
|
||||
Collector(policy, dum, PrioritizedReplayBuffer(10, 0.5, 0.5))
|
||||
Collector(policy, dummy_venv_4_envs, PrioritizedReplayBuffer(10, 0.5, 0.5))
|
||||
with pytest.raises(TypeError):
|
||||
c2.collect()
|
||||
c_dummy_venv_4_envs.collect()
|
||||
|
||||
# test NXEnv
|
||||
for obs_type in ["array", "object"]:
|
||||
envs = SubprocVectorEnv([lambda i=x, t=obs_type: NXEnv(i, t) for x in [5, 10, 15, 20]])
|
||||
c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
|
||||
c3.collect(n_step=6, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert c3.buffer.obs.dtype == object
|
||||
c_suproc_new = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
|
||||
c_suproc_new.reset()
|
||||
c_suproc_new.collect(n_step=6)
|
||||
assert c_suproc_new.buffer.obs.dtype == object
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
|
||||
def test_collector_with_async(gym_reset_kwargs) -> None:
|
||||
@pytest.fixture()
|
||||
def get_AsyncCollector():
|
||||
env_lens = [2, 3, 4, 5]
|
||||
writer = SummaryWriter("log/async_collector")
|
||||
logger = Logger(writer)
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens]
|
||||
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0.001, random_sleep=True) for i in env_lens]
|
||||
|
||||
venv = SubprocVectorEnv(env_fns, wait_num=len(env_fns) - 1)
|
||||
policy = MyPolicy()
|
||||
policy = MaxActionPolicy()
|
||||
bufsize = 60
|
||||
c1 = AsyncCollector(
|
||||
policy,
|
||||
venv,
|
||||
VectorReplayBuffer(total_size=bufsize * 4, buffer_num=4),
|
||||
logger.preprocess_fn,
|
||||
)
|
||||
ptr = [0, 0, 0, 0]
|
||||
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
|
||||
result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert result.n_collected_episodes >= n_episode
|
||||
# check buffer data, obs and obs_next, env_id
|
||||
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
|
||||
env_len = i + 2
|
||||
total = env_len * count
|
||||
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
|
||||
ptr[i] = (ptr[i] + total) % bufsize
|
||||
seq = np.arange(env_len)
|
||||
buf = c1.buffer.buffers[i]
|
||||
assert np.all(buf.info.env_id[indices] == i)
|
||||
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
|
||||
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
|
||||
# test async n_step, for now the buffer should be full of data
|
||||
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
|
||||
result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert result.n_collected_steps >= n_step
|
||||
for i in range(4):
|
||||
env_len = i + 2
|
||||
seq = np.arange(env_len)
|
||||
buf = c1.buffer.buffers[i]
|
||||
assert np.all(buf.info.env_id == i)
|
||||
assert np.all(buf.obs.reshape(-1, env_len) == seq)
|
||||
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
|
||||
with pytest.raises(TypeError):
|
||||
c1.collect()
|
||||
c1.reset()
|
||||
return c1, env_lens
|
||||
|
||||
|
||||
class TestAsyncCollector:
|
||||
def test_collect_without_argument_gives_error(self, get_AsyncCollector):
|
||||
c1, env_lens = get_AsyncCollector
|
||||
with pytest.raises(TypeError):
|
||||
c1.collect()
|
||||
|
||||
def test_collect_one_episode_async(self, get_AsyncCollector):
|
||||
c1, env_lens = get_AsyncCollector
|
||||
result = c1.collect(n_episode=1)
|
||||
assert result.n_collected_episodes >= 1
|
||||
|
||||
def test_enough_episodes_two_collection_cycles_n_episode_without_reset(
|
||||
self,
|
||||
get_AsyncCollector,
|
||||
):
|
||||
c1, env_lens = get_AsyncCollector
|
||||
n_episode = 2
|
||||
result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=False)
|
||||
assert result_c1.n_collected_episodes >= n_episode
|
||||
result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=False)
|
||||
assert result_c2.n_collected_episodes >= n_episode
|
||||
|
||||
def test_enough_episodes_two_collection_cycles_n_episode_with_reset(self, get_AsyncCollector):
|
||||
c1, env_lens = get_AsyncCollector
|
||||
n_episode = 2
|
||||
result_c1 = c1.collect(n_episode=n_episode, reset_before_collect=True)
|
||||
assert result_c1.n_collected_episodes >= n_episode
|
||||
result_c2 = c1.collect(n_episode=n_episode, reset_before_collect=True)
|
||||
assert result_c2.n_collected_episodes >= n_episode
|
||||
|
||||
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_episode(
|
||||
self,
|
||||
get_AsyncCollector,
|
||||
):
|
||||
c1, env_lens = get_AsyncCollector
|
||||
ptr = [0, 0, 0, 0]
|
||||
bufsize = 60
|
||||
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
|
||||
result = c1.collect(n_episode=n_episode)
|
||||
assert result.n_collected_episodes >= n_episode
|
||||
# check buffer data, obs and obs_next, env_id
|
||||
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
|
||||
env_len = i + 2
|
||||
total = env_len * count
|
||||
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
|
||||
ptr[i] = (ptr[i] + total) % bufsize
|
||||
seq = np.arange(env_len)
|
||||
buf = c1.buffer.buffers[i]
|
||||
assert np.all(buf.info.env_id[indices] == i)
|
||||
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
|
||||
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
|
||||
|
||||
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_n_step(
|
||||
self,
|
||||
get_AsyncCollector,
|
||||
):
|
||||
c1, env_lens = get_AsyncCollector
|
||||
bufsize = 60
|
||||
ptr = [0, 0, 0, 0]
|
||||
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
|
||||
result = c1.collect(n_step=n_step)
|
||||
assert result.n_collected_steps >= n_step
|
||||
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
|
||||
env_len = i + 2
|
||||
total = env_len * count
|
||||
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
|
||||
ptr[i] = (ptr[i] + total) % bufsize
|
||||
seq = np.arange(env_len)
|
||||
buf = c1.buffer.buffers[i]
|
||||
assert np.all(buf.info.env_id[indices] == i)
|
||||
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
|
||||
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
|
||||
|
||||
@pytest.mark.parametrize("gym_reset_kwargs", [None, {}])
|
||||
def test_enough_episodes_and_correct_obs_indices_and_obs_next_iterative_collection_cycles_first_n_episode_then_n_step(
|
||||
self,
|
||||
get_AsyncCollector,
|
||||
gym_reset_kwargs,
|
||||
):
|
||||
c1, env_lens = get_AsyncCollector
|
||||
bufsize = 60
|
||||
ptr = [0, 0, 0, 0]
|
||||
for n_episode in tqdm.trange(1, 30, desc="test async n_episode"):
|
||||
result = c1.collect(n_episode=n_episode, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert result.n_collected_episodes >= n_episode
|
||||
# check buffer data, obs and obs_next, env_id
|
||||
for i, count in enumerate(np.bincount(result.lens, minlength=6)[2:]):
|
||||
env_len = i + 2
|
||||
total = env_len * count
|
||||
indices = np.arange(ptr[i], ptr[i] + total) % bufsize
|
||||
ptr[i] = (ptr[i] + total) % bufsize
|
||||
seq = np.arange(env_len)
|
||||
buf = c1.buffer.buffers[i]
|
||||
assert np.all(buf.info.env_id[indices] == i)
|
||||
assert np.all(buf.obs[indices].reshape(count, env_len) == seq)
|
||||
assert np.all(buf.obs_next[indices].reshape(count, env_len) == seq + 1)
|
||||
# test async n_step, for now the buffer should be full of data, thus no bincount stuff as above
|
||||
for n_step in tqdm.trange(1, 15, desc="test async n_step"):
|
||||
result = c1.collect(n_step=n_step, gym_reset_kwargs=gym_reset_kwargs)
|
||||
assert result.n_collected_steps >= n_step
|
||||
for i in range(4):
|
||||
env_len = i + 2
|
||||
seq = np.arange(env_len)
|
||||
buf = c1.buffer.buffers[i]
|
||||
assert np.all(buf.info.env_id == i)
|
||||
assert np.all(buf.obs.reshape(-1, env_len) == seq)
|
||||
assert np.all(buf.obs_next.reshape(-1, env_len) == seq + 1)
|
||||
|
||||
|
||||
def test_collector_with_dict_state() -> None:
|
||||
env = MyTestEnv(size=5, sleep=0, dict_state=True)
|
||||
policy = MyPolicy(dict_state=True)
|
||||
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
|
||||
env = MoveToRightEnv(size=5, sleep=0, dict_state=True)
|
||||
policy = MaxActionPolicy(dict_state=True)
|
||||
c0 = Collector(policy, env, ReplayBuffer(size=100))
|
||||
c0.reset()
|
||||
c0.collect(n_step=3)
|
||||
c0.collect(n_episode=2)
|
||||
assert len(c0.buffer) == 10
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]]
|
||||
assert len(c0.buffer) == 10 # 3 + two episodes with 5 steps each
|
||||
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, dict_state=True) for i in [2, 3, 4, 5]]
|
||||
envs = DummyVectorEnv(env_fns)
|
||||
envs.seed(666)
|
||||
obs, info = envs.reset()
|
||||
@ -280,8 +356,8 @@ def test_collector_with_dict_state() -> None:
|
||||
policy,
|
||||
envs,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4),
|
||||
Logger.single_preprocess_fn,
|
||||
)
|
||||
c1.reset()
|
||||
c1.collect(n_step=12)
|
||||
result = c1.collect(n_episode=8)
|
||||
assert result.n_collected_episodes == 8
|
||||
@ -396,41 +472,47 @@ def test_collector_with_dict_state() -> None:
|
||||
policy,
|
||||
envs,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
|
||||
Logger.single_preprocess_fn,
|
||||
)
|
||||
c2.reset()
|
||||
c2.collect(n_episode=10)
|
||||
batch, _ = c2.buffer.sample(10)
|
||||
|
||||
|
||||
def test_collector_with_ma() -> None:
|
||||
env = MyTestEnv(size=5, sleep=0, ma_rew=4)
|
||||
policy = MyPolicy()
|
||||
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
|
||||
# n_step=3 will collect a full episode
|
||||
rew = c0.collect(n_step=3).returns
|
||||
assert len(rew) == 0
|
||||
rew = c0.collect(n_episode=2).returns
|
||||
assert rew.shape == (2, 4)
|
||||
assert np.all(rew == 1)
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
|
||||
def test_collector_with_multi_agent() -> None:
|
||||
multi_agent_env = MoveToRightEnv(size=5, sleep=0, ma_rew=4)
|
||||
policy = MaxActionPolicy()
|
||||
c_single_env = Collector(policy, multi_agent_env, ReplayBuffer(size=100))
|
||||
c_single_env.reset()
|
||||
multi_env_returns = c_single_env.collect(n_step=3).returns
|
||||
# c_single_env has length 3
|
||||
# We have no full episodes, so no returns yet
|
||||
assert len(multi_env_returns) == 0
|
||||
|
||||
single_env_returns = c_single_env.collect(n_episode=2).returns
|
||||
# now two episodes. Since we have 4 agents, the returns have shape (2, 4)
|
||||
assert single_env_returns.shape == (2, 4)
|
||||
assert np.all(single_env_returns == 1)
|
||||
|
||||
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
|
||||
envs = DummyVectorEnv(env_fns)
|
||||
c1 = Collector(
|
||||
c_multi_env_ma = Collector(
|
||||
policy,
|
||||
envs,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4),
|
||||
Logger.single_preprocess_fn,
|
||||
)
|
||||
rew = c1.collect(n_step=12).returns
|
||||
assert rew.shape == (2, 4) and np.all(rew == 1), rew
|
||||
rew = c1.collect(n_episode=8).returns
|
||||
assert rew.shape == (8, 4)
|
||||
assert np.all(rew == 1)
|
||||
batch, _ = c1.buffer.sample(10)
|
||||
c_multi_env_ma.reset()
|
||||
multi_env_returns = c_multi_env_ma.collect(n_step=12).returns
|
||||
# each env makes 3 steps, the first two envs are done and result in two finished episodes
|
||||
assert multi_env_returns.shape == (2, 4) and np.all(multi_env_returns == 1), multi_env_returns
|
||||
multi_env_returns = c_multi_env_ma.collect(n_episode=8).returns
|
||||
assert multi_env_returns.shape == (8, 4)
|
||||
assert np.all(multi_env_returns == 1)
|
||||
batch, _ = c_multi_env_ma.buffer.sample(10)
|
||||
print(batch)
|
||||
c0.buffer.update(c1.buffer)
|
||||
assert len(c0.buffer) in [42, 43]
|
||||
if len(c0.buffer) == 42:
|
||||
rew = [
|
||||
c_single_env.buffer.update(c_multi_env_ma.buffer)
|
||||
assert len(c_single_env.buffer) in [42, 43]
|
||||
if len(c_single_env.buffer) == 42:
|
||||
multi_env_returns = [
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
@ -475,7 +557,7 @@ def test_collector_with_ma() -> None:
|
||||
1,
|
||||
]
|
||||
else:
|
||||
rew = [
|
||||
multi_env_returns = [
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
@ -520,17 +602,17 @@ def test_collector_with_ma() -> None:
|
||||
0,
|
||||
1,
|
||||
]
|
||||
assert np.all(c0.buffer[:].rew == [[x] * 4 for x in rew])
|
||||
assert np.all(c0.buffer[:].done == rew)
|
||||
assert np.all(c_single_env.buffer[:].rew == [[x] * 4 for x in multi_env_returns])
|
||||
assert np.all(c_single_env.buffer[:].done == multi_env_returns)
|
||||
c2 = Collector(
|
||||
policy,
|
||||
envs,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
|
||||
Logger.single_preprocess_fn,
|
||||
)
|
||||
rew = c2.collect(n_episode=10).returns
|
||||
assert rew.shape == (10, 4)
|
||||
assert np.all(rew == 1)
|
||||
c2.reset()
|
||||
multi_env_returns = c2.collect(n_episode=10).returns
|
||||
assert multi_env_returns.shape == (10, 4)
|
||||
assert np.all(multi_env_returns == 1)
|
||||
batch, _ = c2.buffer.sample(10)
|
||||
|
||||
|
||||
@ -543,20 +625,21 @@ def test_collector_with_atari_setting() -> None:
|
||||
reference_obs[i, 0] = i
|
||||
|
||||
# atari single buffer
|
||||
env = MyTestEnv(size=5, sleep=0, array_state=True)
|
||||
policy = MyPolicy()
|
||||
env = MoveToRightEnv(size=5, sleep=0, array_state=True)
|
||||
policy = MaxActionPolicy()
|
||||
c0 = Collector(policy, env, ReplayBuffer(size=100))
|
||||
c0.reset()
|
||||
c0.collect(n_step=6)
|
||||
c0.collect(n_episode=2)
|
||||
assert c0.buffer.obs.shape == (100, 4, 84, 84)
|
||||
assert c0.buffer.obs_next.shape == (100, 4, 84, 84)
|
||||
assert len(c0.buffer) == 15
|
||||
assert len(c0.buffer) == 15 # 6 + 2 episodes with 5 steps each
|
||||
obs = np.zeros_like(c0.buffer.obs)
|
||||
obs[np.arange(15)] = reference_obs[np.arange(15) % 5]
|
||||
assert np.all(obs == c0.buffer.obs)
|
||||
|
||||
c1 = Collector(policy, env, ReplayBuffer(size=100, ignore_obs_next=True))
|
||||
c1.collect(n_episode=3)
|
||||
c1.collect(n_episode=3, reset_before_collect=True)
|
||||
assert np.allclose(c0.buffer.obs, c1.buffer.obs)
|
||||
with pytest.raises(AttributeError):
|
||||
c1.buffer.obs_next # noqa: B018
|
||||
@ -567,6 +650,7 @@ def test_collector_with_atari_setting() -> None:
|
||||
env,
|
||||
ReplayBuffer(size=100, ignore_obs_next=True, save_only_last_obs=True),
|
||||
)
|
||||
c2.reset()
|
||||
c2.collect(n_step=8)
|
||||
assert c2.buffer.obs.shape == (100, 84, 84)
|
||||
obs = np.zeros_like(c2.buffer.obs)
|
||||
@ -575,9 +659,10 @@ def test_collector_with_atari_setting() -> None:
|
||||
assert np.allclose(c2.buffer[:].obs_next, reference_obs[[1, 2, 3, 4, 4, 1, 2, 2], -1])
|
||||
|
||||
# atari multi buffer
|
||||
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]]
|
||||
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0, array_state=True) for i in [2, 3, 4, 5]]
|
||||
envs = DummyVectorEnv(env_fns)
|
||||
c3 = Collector(policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4))
|
||||
c3.reset()
|
||||
c3.collect(n_step=12)
|
||||
result = c3.collect(n_episode=9)
|
||||
assert result.n_collected_episodes == 9
|
||||
@ -606,6 +691,7 @@ def test_collector_with_atari_setting() -> None:
|
||||
save_only_last_obs=True,
|
||||
),
|
||||
)
|
||||
c4.reset()
|
||||
c4.collect(n_step=12)
|
||||
result = c4.collect(n_episode=9)
|
||||
assert result.n_collected_episodes == 9
|
||||
@ -672,6 +758,7 @@ def test_collector_with_atari_setting() -> None:
|
||||
|
||||
buf = ReplayBuffer(100, stack_num=4, ignore_obs_next=True, save_only_last_obs=True)
|
||||
c5 = Collector(policy, envs, CachedReplayBuffer(buf, 4, 10))
|
||||
c5.reset()
|
||||
result_ = c5.collect(n_step=12)
|
||||
assert len(buf) == 5
|
||||
assert len(c5.buffer) == 12
|
||||
@ -767,6 +854,7 @@ def test_collector_with_atari_setting() -> None:
|
||||
|
||||
# test buffer=None
|
||||
c6 = Collector(policy, envs)
|
||||
c6.reset()
|
||||
result1 = c6.collect(n_step=12)
|
||||
for key in ["n_collected_episodes", "n_collected_steps", "returns", "lens"]:
|
||||
assert np.allclose(getattr(result1, key), getattr(result_, key))
|
||||
@ -778,7 +866,7 @@ def test_collector_with_atari_setting() -> None:
|
||||
@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
|
||||
def test_collector_envpool_gym_reset_return_info() -> None:
|
||||
envs = envpool.make_gymnasium("Pendulum-v1", num_envs=4, gym_reset_return_info=True)
|
||||
policy = MyPolicy(action_shape=(len(envs), 1))
|
||||
policy = MaxActionPolicy(action_shape=(len(envs), 1))
|
||||
|
||||
c0 = Collector(
|
||||
policy,
|
||||
@ -786,18 +874,59 @@ def test_collector_envpool_gym_reset_return_info() -> None:
|
||||
VectorReplayBuffer(len(envs) * 10, len(envs)),
|
||||
exploration_noise=True,
|
||||
)
|
||||
c0.reset()
|
||||
c0.collect(n_step=8)
|
||||
env_ids = np.zeros(len(envs) * 10)
|
||||
env_ids[[0, 1, 10, 11, 20, 21, 30, 31]] = [0, 0, 1, 1, 2, 2, 3, 3]
|
||||
assert np.allclose(c0.buffer.info["env_id"], env_ids)
|
||||
|
||||
|
||||
def test_collector_with_vector_env():
|
||||
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
|
||||
|
||||
dum = DummyVectorEnv(env_fns)
|
||||
policy = MaxActionPolicy()
|
||||
|
||||
c2 = Collector(
|
||||
policy,
|
||||
dum,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4),
|
||||
)
|
||||
|
||||
c2.reset()
|
||||
|
||||
c1r = c2.collect(n_episode=2)
|
||||
assert np.array_equal(np.array([1, 8]), c1r.lens)
|
||||
c2r = c2.collect(n_episode=10)
|
||||
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 8, 9, 10]), c2r.lens)
|
||||
c3r = c2.collect(n_step=20)
|
||||
assert np.array_equal(np.array([1, 1, 1, 1, 1]), c3r.lens)
|
||||
c4r = c2.collect(n_step=20)
|
||||
assert np.array_equal(np.array([1, 1, 1, 8, 1, 9, 1, 10]), c4r.lens)
|
||||
|
||||
|
||||
def test_async_collector_with_vector_env():
|
||||
env_fns = [lambda x=i: MoveToRightEnv(size=x, sleep=0) for i in [1, 8, 9, 10]]
|
||||
|
||||
dum = DummyVectorEnv(env_fns)
|
||||
policy = MaxActionPolicy()
|
||||
c1 = AsyncCollector(
|
||||
policy,
|
||||
dum,
|
||||
VectorReplayBuffer(total_size=100, buffer_num=4),
|
||||
)
|
||||
|
||||
c1r = c1.collect(n_episode=10, reset_before_collect=True)
|
||||
assert np.array_equal(np.array([1, 1, 1, 1, 1, 1, 1, 1, 8, 1, 9]), c1r.lens)
|
||||
c2r = c1.collect(n_step=20)
|
||||
assert np.array_equal(np.array([1, 10, 1, 1, 1, 1]), c2r.lens)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_collector(gym_reset_kwargs=None)
|
||||
test_collector(gym_reset_kwargs={})
|
||||
test_collector()
|
||||
test_collector_with_dict_state()
|
||||
test_collector_with_ma()
|
||||
test_collector_with_multi_agent()
|
||||
test_collector_with_atari_setting()
|
||||
test_collector_with_async(gym_reset_kwargs=None)
|
||||
test_collector_with_async(gym_reset_kwargs={"return_info": True})
|
||||
test_collector_envpool_gym_reset_return_info()
|
||||
test_collector_with_vector_env()
|
||||
test_async_collector_with_vector_env()
|
||||
|
@ -20,9 +20,9 @@ from tianshou.env.gym_wrappers import TruncatedAsTerminated
|
||||
from tianshou.utils import RunningMeanStd
|
||||
|
||||
if __name__ == "__main__":
|
||||
from env import MyTestEnv, NXEnv
|
||||
from env import MoveToRightEnv, NXEnv
|
||||
else: # pytest
|
||||
from test.base.env import MyTestEnv, NXEnv
|
||||
from test.base.env import MoveToRightEnv, NXEnv
|
||||
|
||||
try:
|
||||
import envpool
|
||||
@ -56,7 +56,7 @@ def recurse_comp(a, b):
|
||||
def test_async_env(size=10000, num=8, sleep=0.1) -> None:
|
||||
# simplify the test case, just keep stepping
|
||||
env_fns = [
|
||||
lambda i=i: MyTestEnv(size=i, sleep=sleep, random_sleep=True)
|
||||
lambda i=i: MoveToRightEnv(size=i, sleep=sleep, random_sleep=True)
|
||||
for i in range(size, size + num)
|
||||
]
|
||||
test_cls = [SubprocVectorEnv, ShmemVectorEnv]
|
||||
@ -108,10 +108,10 @@ def test_async_env(size=10000, num=8, sleep=0.1) -> None:
|
||||
|
||||
def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None:
|
||||
env_fns = [
|
||||
lambda: MyTestEnv(size=size, sleep=sleep * 2),
|
||||
lambda: MyTestEnv(size=size, sleep=sleep * 3),
|
||||
lambda: MyTestEnv(size=size, sleep=sleep * 5),
|
||||
lambda: MyTestEnv(size=size, sleep=sleep * 7),
|
||||
lambda: MoveToRightEnv(size=size, sleep=sleep * 2),
|
||||
lambda: MoveToRightEnv(size=size, sleep=sleep * 3),
|
||||
lambda: MoveToRightEnv(size=size, sleep=sleep * 5),
|
||||
lambda: MoveToRightEnv(size=size, sleep=sleep * 7),
|
||||
]
|
||||
test_cls = [SubprocVectorEnv, ShmemVectorEnv]
|
||||
if has_ray():
|
||||
@ -156,7 +156,7 @@ def test_async_check_id(size=100, num=4, sleep=0.2, timeout=0.7) -> None:
|
||||
|
||||
def test_vecenv(size=10, num=8, sleep=0.001) -> None:
|
||||
env_fns = [
|
||||
lambda i=i: MyTestEnv(size=i, sleep=sleep, recurse_state=True)
|
||||
lambda i=i: MoveToRightEnv(size=i, sleep=sleep, recurse_state=True)
|
||||
for i in range(size, size + num)
|
||||
]
|
||||
venv = [
|
||||
@ -237,7 +237,7 @@ def test_env_obs_dtype() -> None:
|
||||
|
||||
|
||||
def test_env_reset_optional_kwargs(size=10000, num=8) -> None:
|
||||
env_fns = [lambda i=i: MyTestEnv(size=i) for i in range(size, size + num)]
|
||||
env_fns = [lambda i=i: MoveToRightEnv(size=i) for i in range(size, size + num)]
|
||||
test_cls = [DummyVectorEnv, SubprocVectorEnv, ShmemVectorEnv]
|
||||
if has_ray():
|
||||
test_cls += [RayVectorEnv]
|
||||
@ -257,7 +257,7 @@ def test_venv_wrapper_gym(num_envs: int = 4) -> None:
|
||||
except ValueError:
|
||||
obs, info = envs.reset(return_info=True)
|
||||
assert isinstance(obs, np.ndarray)
|
||||
assert isinstance(info, list)
|
||||
assert isinstance(info, np.ndarray)
|
||||
assert isinstance(info[0], dict)
|
||||
assert obs.shape[0] == len(info) == num_envs
|
||||
|
||||
@ -334,7 +334,7 @@ def test_venv_norm_obs() -> None:
|
||||
action = np.array([1, 1, 1, 1])
|
||||
total_step = 30
|
||||
action_list = [action] * total_step
|
||||
env_fns = [lambda i=x: MyTestEnv(size=i, array_state=True) for x in sizes]
|
||||
env_fns = [lambda i=x: MoveToRightEnv(size=i, array_state=True) for x in sizes]
|
||||
raw = DummyVectorEnv(env_fns)
|
||||
train_env = VectorEnvNormObs(DummyVectorEnv(env_fns))
|
||||
print(train_env.observation_space)
|
||||
|
@ -90,20 +90,20 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
|
||||
# END
|
||||
|
||||
def reset(self, id=None):
|
||||
id = self._wrap_id(id)
|
||||
def reset(self, env_id=None):
|
||||
env_id = self._wrap_id(env_id)
|
||||
self._reset_alive_envs()
|
||||
|
||||
# ask super to reset alive envs and remap to current index
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, id))
|
||||
obs = [None] * len(id)
|
||||
infos = [None] * len(id)
|
||||
id2idx = {i: k for k, i in enumerate(id)}
|
||||
request_id = list(filter(lambda i: i in self._alive_env_ids, env_id))
|
||||
obs = [None] * len(env_id)
|
||||
infos = [None] * len(env_id)
|
||||
id2idx = {i: k for k, i in enumerate(env_id)}
|
||||
if request_id:
|
||||
for k, o, info in zip(request_id, *super().reset(request_id), strict=True):
|
||||
obs[id2idx[k]] = o
|
||||
infos[id2idx[k]] = info
|
||||
for i, o in zip(id, obs, strict=True):
|
||||
for i, o in zip(env_id, obs, strict=True):
|
||||
if o is None and i in self._alive_env_ids:
|
||||
self._alive_env_ids.remove(i)
|
||||
|
||||
@ -121,7 +121,7 @@ class FiniteVectorEnv(BaseVectorEnv):
|
||||
self.reset()
|
||||
raise StopIteration
|
||||
|
||||
return np.stack(obs), infos
|
||||
return np.stack(obs), np.array(infos)
|
||||
|
||||
def step(self, action, id=None):
|
||||
id = self._wrap_id(id)
|
||||
@ -204,10 +204,12 @@ def test_finite_dummy_vector_env() -> None:
|
||||
envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
|
||||
policy = AnyPolicy()
|
||||
test_collector = Collector(policy, envs, exploration_noise=True)
|
||||
test_collector.reset()
|
||||
|
||||
for _ in range(3):
|
||||
envs.tracker = MetricTracker()
|
||||
try:
|
||||
# TODO: why on earth 10**18?
|
||||
test_collector.collect(n_step=10**18)
|
||||
except StopIteration:
|
||||
envs.tracker.validate()
|
||||
@ -218,6 +220,7 @@ def test_finite_subproc_vector_env() -> None:
|
||||
envs = FiniteSubprocVectorEnv([_finite_env_factory(dataset, 5, i) for i in range(5)])
|
||||
policy = AnyPolicy()
|
||||
test_collector = Collector(policy, envs, exploration_noise=True)
|
||||
test_collector.reset()
|
||||
|
||||
for _ in range(3):
|
||||
envs.tracker = MetricTracker()
|
||||
|
@ -2,7 +2,7 @@ import gymnasium as gym
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch.distributions import Categorical, Independent, Normal
|
||||
from torch.distributions import Categorical, Distribution, Independent, Normal
|
||||
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.utils.net.common import ActorCritic, Net
|
||||
@ -25,7 +25,11 @@ def policy(request):
|
||||
Net(state_shape=obs_shape, hidden_sizes=[64, 64], action_shape=action_space.shape),
|
||||
action_shape=action_space.shape,
|
||||
)
|
||||
dist_fn = lambda *logits: Independent(Normal(*logits), 1)
|
||||
|
||||
def dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
elif action_type == "discrete":
|
||||
action_space = gym.spaces.Discrete(3)
|
||||
actor = Actor(
|
||||
|
@ -103,8 +103,9 @@ def test_npg(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: NPGPolicy[NPGTrainingStats] = NPGPolicy(
|
||||
actor=actor,
|
||||
|
@ -100,8 +100,9 @@ def test_ppo(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: PPOPolicy[PPOTrainingStats] = PPOPolicy(
|
||||
actor=actor,
|
||||
|
@ -136,6 +136,7 @@ def test_redq(args: argparse.Namespace = get_args()) -> None:
|
||||
exploration_noise=True,
|
||||
)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.reset()
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "redq")
|
||||
|
@ -162,6 +162,7 @@ def test_td3(args: argparse.Namespace = get_args()) -> None:
|
||||
env = gym.make(args.task)
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
collector.reset()
|
||||
collector_stats = collector.collect(n_episode=1, render=args.render)
|
||||
print(collector_stats)
|
||||
|
||||
|
@ -102,8 +102,9 @@ def test_trpo(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGuassian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: BasePolicy = TRPOPolicy(
|
||||
actor=actor,
|
||||
|
@ -109,7 +109,9 @@ def test_a2c_with_il(args: argparse.Namespace = get_args()) -> None:
|
||||
train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
)
|
||||
train_collector.reset()
|
||||
test_collector = Collector(policy, test_envs)
|
||||
test_collector.reset()
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "a2c")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -108,7 +108,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=False)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
|
||||
def train_fn(epoch: int, env_step: int) -> None: # exp decay
|
||||
eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
|
||||
|
@ -120,7 +120,7 @@ def test_c51(args: argparse.Namespace = get_args()) -> None:
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "c51")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -111,7 +111,7 @@ def test_dqn(args: argparse.Namespace = get_args()) -> None:
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "dqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -95,7 +95,7 @@ def test_drqn(args: argparse.Namespace = get_args()) -> None:
|
||||
# the stack_num is for RNN training: sample framestack obs
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "drqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -128,7 +128,7 @@ def test_fqf(args: argparse.Namespace = get_args()) -> None:
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "fqf")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -124,7 +124,7 @@ def test_iqn(args: argparse.Namespace = get_args()) -> None:
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "iqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -113,7 +113,7 @@ def test_qrdqn(args: argparse.Namespace = get_args()) -> None:
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "qrdqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -128,7 +128,7 @@ def test_rainbow(args: argparse.Namespace = get_args()) -> None:
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "rainbow")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -154,7 +154,7 @@ def test_dqn_icm(args: argparse.Namespace = get_args()) -> None:
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "dqn_icm")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -81,7 +81,9 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True,
|
||||
)
|
||||
train_collector.reset()
|
||||
test_collector = Collector(policy, test_envs)
|
||||
test_collector.reset()
|
||||
# Logger
|
||||
log_path = os.path.join(args.logdir, args.task, "psrl")
|
||||
writer = SummaryWriter(log_path)
|
||||
@ -120,7 +122,6 @@ def test_psrl(args: argparse.Namespace = get_args()) -> None:
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f"Final reward: {result.rew_mean}, length: {result.len_mean}")
|
||||
elif env.spec.reward_threshold:
|
||||
|
@ -115,9 +115,11 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
|
||||
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(train_envs))
|
||||
# collector
|
||||
train_collector = Collector(policy, train_envs, buf, exploration_noise=True)
|
||||
train_collector.reset()
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
test_collector.reset()
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, "qrdqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
@ -165,6 +167,7 @@ def gather_data() -> VectorReplayBuffer | PrioritizedVectorReplayBuffer:
|
||||
buf = VectorReplayBuffer(args.buffer_size, buffer_num=len(test_envs))
|
||||
policy.set_eps(0.2)
|
||||
collector = Collector(policy, test_envs, buf, exploration_noise=True)
|
||||
collector.reset()
|
||||
collector_stats = collector.collect(n_step=args.buffer_size)
|
||||
if args.save_buffer_name.endswith(".hdf5"):
|
||||
buf.save_hdf5(args.save_buffer_name)
|
||||
|
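A small sketch of the buffer round-trip used in gather_data above (policy and test_envs are assumed; the file name is illustrative, and ReplayBuffer.load_hdf5 is used here as the reading counterpart of save_hdf5):

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer

buf = VectorReplayBuffer(20000, buffer_num=len(test_envs))
collector = Collector(policy, test_envs, buf, exploration_noise=True)
collector.reset()
collector.collect(n_step=20000)
buf.save_hdf5("qrdqn_buffer.hdf5")               # illustrative path
restored = ReplayBuffer.load_hdf5("qrdqn_buffer.hdf5")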
@ -178,6 +178,7 @@ def test_discrete_bcq(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
|
||||
def test_discrete_bcq_resume(args: argparse.Namespace = get_args()) -> None:
|
||||
test_discrete_bcq()
|
||||
args.resume = True
|
||||
test_discrete_bcq(args)
|
||||
|
||||
|
@ -133,8 +133,9 @@ def test_gail(args: argparse.Namespace = get_args()) -> None:
|
||||
|
||||
# replace DiagGaussian with Independent(Normal) which is equivalent
|
||||
# pass *logits to be consistent with policy.forward
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
policy: BasePolicy = GAILPolicy(
|
||||
actor=actor,
|
||||
|
@ -83,8 +83,8 @@ def get_agents(
|
||||
if isinstance(env.observation_space, gym.spaces.Dict)
|
||||
else env.observation_space
|
||||
)
|
||||
args.state_shape = observation_space.shape or observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.state_shape = observation_space.shape or int(observation_space.n)
|
||||
args.action_shape = env.action_space.shape or int(env.action_space.n)
|
||||
if agents is None:
|
||||
agents = []
|
||||
optims = []
|
||||
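A short illustration of why the int(...) casts above help: gymnasium's Discrete space typically stores n as a numpy integer, while downstream network-shape code expects a plain Python int (values below are illustrative):

import gymnasium as gym

space = gym.spaces.Discrete(4)
print(type(space.n))   # typically <class 'numpy.int64'>, not a plain int
print(int(space.n))    # 4 -- the cast hands network builders a plain Python int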
@ -135,7 +135,7 @@ def train_agent(
|
||||
exploration_noise=True,
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, "pistonball", "dqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -181,8 +181,9 @@ def get_agents(
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
optim = torch.optim.Adam(set(actor.parameters()).union(critic.parameters()), lr=args.lr)
|
||||
|
||||
def dist(*logits: torch.Tensor) -> Distribution:
|
||||
return Independent(Normal(*logits), 1)
|
||||
def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
|
||||
loc, scale = loc_scale
|
||||
return Independent(Normal(loc, scale), 1)
|
||||
|
||||
agent: PPOPolicy = PPOPolicy(
|
||||
actor,
|
||||
@ -234,7 +235,7 @@ def train_agent(
|
||||
exploration_noise=False, # True
|
||||
)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
# train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, "pistonball", "dqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -102,8 +102,8 @@ def get_agents(
|
||||
if isinstance(env.observation_space, gymnasium.spaces.Dict)
|
||||
else env.observation_space
|
||||
)
|
||||
args.state_shape = observation_space.shape or observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.state_shape = observation_space.shape or int(observation_space.n)
|
||||
args.action_shape = env.action_space.shape or int(env.action_space.n)
|
||||
if agent_learn is None:
|
||||
# model
|
||||
net = Net(
|
||||
@ -170,7 +170,7 @@ def train_agent(
|
||||
)
|
||||
test_collector = Collector(policy, test_envs, exploration_noise=True)
|
||||
# policy.set_eps(1)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num)
|
||||
train_collector.collect(n_step=args.batch_size * args.training_num, reset_before_collect=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, "tic_tac_toe", "dqn")
|
||||
writer = SummaryWriter(log_path)
|
||||
|
@ -263,6 +263,9 @@ class BatchProtocol(Protocol):
|
||||
def __repr__(self) -> str:
|
||||
...
|
||||
|
||||
def __iter__(self) -> Iterator[Self]:
|
||||
...
|
||||
|
||||
def to_numpy(self) -> None:
|
||||
"""Change all torch.Tensor to numpy.ndarray in-place."""
|
||||
...
|
||||
@ -391,6 +394,12 @@ class BatchProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
...
|
||||
|
||||
def to_list_of_dicts(self) -> list[dict[str, Any]]:
|
||||
...
|
||||
|
||||
|
||||
class Batch(BatchProtocol):
|
||||
"""See :class:`~tianshou.data.batch.BatchProtocol`."""
|
||||
@ -422,6 +431,17 @@ class Batch(BatchProtocol):
|
||||
# Feels like kwargs could be just merged into batch_dict in the beginning
|
||||
self.__init__(kwargs, copy=copy) # type: ignore
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if isinstance(v, Batch):
|
||||
v = v.to_dict()
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
def to_list_of_dicts(self) -> list[dict[str, Any]]:
|
||||
return [entry.to_dict() for entry in self]
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
"""Set self.key = value."""
|
||||
self.__dict__[key] = _parse_value(value)
|
||||
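A usage sketch for the two new Batch methods above (values are illustrative):

import numpy as np
from tianshou.data import Batch

b = Batch(obs=np.zeros((2, 3)), info=Batch(env_id=np.arange(2)))
b.to_dict()           # nested Batch values are converted recursively
b.to_list_of_dicts()  # one dict per entry, i.e. [b[0].to_dict(), b[1].to_dict()]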
@ -478,6 +498,14 @@ class Batch(BatchProtocol):
|
||||
return new_batch
|
||||
raise IndexError("Cannot access item from empty Batch object.")
|
||||
|
||||
def __iter__(self) -> Iterator[Self]:
|
||||
# TODO: empty batch raises an error on len and needs separate treatment, that's probably not a good idea
|
||||
if len(self.__dict__) == 0:
|
||||
yield from []
|
||||
else:
|
||||
for i in range(len(self)):
|
||||
yield self[i]
|
||||
|
||||
def __setitem__(self, index: str | IndexType, value: Any) -> None:
|
||||
"""Assign value to self[index]."""
|
||||
value = _parse_value(value)
|
||||
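A small sketch of the iteration behaviour implemented in __iter__ above; per the changelog, iter(Batch(...)) now behaves the same way as Batch(...).__iter__() (values are illustrative):

import numpy as np
from tianshou.data import Batch

b = Batch(act=np.arange(3), rew=np.zeros(3))
for entry in b:                       # yields b[0], b[1], b[2]
    print(entry.act)
assert list(iter(Batch())) == []      # an empty Batch yields nothing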
@ -601,10 +629,10 @@ class Batch(BatchProtocol):
|
||||
else:
|
||||
# ndarray or scalar
|
||||
if not isinstance(obj, np.ndarray):
|
||||
obj = np.asanyarray(obj) # noqa: PLW2901
|
||||
obj = torch.from_numpy(obj).to(device) # noqa: PLW2901
|
||||
obj = np.asanyarray(obj)
|
||||
obj = torch.from_numpy(obj).to(device)
|
||||
if dtype is not None:
|
||||
obj = obj.type(dtype) # noqa: PLW2901
|
||||
obj = obj.type(dtype)
|
||||
self.__dict__[batch_key] = obj
|
||||
|
||||
def __cat(self, batches: Sequence[dict | Self], lens: list[int]) -> None:
|
||||
|
@ -200,7 +200,7 @@ class ReplayBufferManager(ReplayBuffer):
|
||||
|
||||
return np.concatenate(
|
||||
[
|
||||
buf.sample_indices(bsz) + offset
|
||||
buf.sample_indices(int(bsz)) + offset
|
||||
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True)
|
||||
],
|
||||
)
|
||||
|
File diff suppressed because it is too large
@ -12,6 +12,7 @@ from tianshou.data.batch import Batch, _parse_value
|
||||
|
||||
# TODO: confusing name, could actually return a batch...
|
||||
# Overrides and generic types should be added
|
||||
# todo check for ActBatchProtocol
|
||||
@no_type_check
|
||||
def to_numpy(x: Any) -> Batch | np.ndarray:
|
||||
"""Return an object without torch.Tensor."""
|
||||
|
16
tianshou/env/venv_wrappers.py
vendored
@ -44,14 +44,14 @@ class VectorEnvWrapper(BaseVectorEnv):
|
||||
|
||||
def reset(
|
||||
self,
|
||||
id: int | list[int] | np.ndarray | None = None,
|
||||
env_id: int | list[int] | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[np.ndarray, dict | list[dict]]:
|
||||
return self.venv.reset(id, **kwargs)
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
return self.venv.reset(env_id, **kwargs)
|
||||
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray | torch.Tensor,
|
||||
action: np.ndarray | torch.Tensor | None,
|
||||
id: int | list[int] | np.ndarray | None = None,
|
||||
) -> gym_new_venv_step_type:
|
||||
return self.venv.step(action, id)
|
||||
@ -80,10 +80,10 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
||||
|
||||
def reset(
|
||||
self,
|
||||
id: int | list[int] | np.ndarray | None = None,
|
||||
env_id: int | list[int] | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[np.ndarray, dict | list[dict]]:
|
||||
obs, info = self.venv.reset(id, **kwargs)
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
obs, info = self.venv.reset(env_id, **kwargs)
|
||||
|
||||
if isinstance(obs, tuple): # type: ignore
|
||||
raise TypeError(
|
||||
@ -98,7 +98,7 @@ class VectorEnvNormObs(VectorEnvWrapper):
|
||||
|
||||
def step(
|
||||
self,
|
||||
action: np.ndarray | torch.Tensor,
|
||||
action: np.ndarray | torch.Tensor | None,
|
||||
id: int | list[int] | np.ndarray | None = None,
|
||||
) -> gym_new_venv_step_type:
|
||||
step_results = self.venv.step(action, id)
|
||||
|
26
tianshou/env/venvs.py
vendored
@ -190,11 +190,13 @@ class BaseVectorEnv:
|
||||
), f"Cannot interact with environment {i} which is stepping now."
|
||||
assert i in self.ready_id, f"Can only interact with ready environments {self.ready_id}."
|
||||
|
||||
# TODO: for now, has to be kept in sync with reset in EnvPoolMixin
|
||||
# In particular, can't rename env_id to env_ids
|
||||
def reset(
|
||||
self,
|
||||
id: int | list[int] | np.ndarray | None = None,
|
||||
env_id: int | list[int] | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> tuple[np.ndarray, dict | list[dict]]:
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Reset the state of some envs and return initial observations.
|
||||
|
||||
If id is None, reset the state of all the environments and return
|
||||
@ -202,14 +204,14 @@ class BaseVectorEnv:
|
||||
the given id, either an int or a list.
|
||||
"""
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
env_id = self._wrap_id(env_id)
|
||||
if self.is_async:
|
||||
self._assert_id(id)
|
||||
self._assert_id(env_id)
|
||||
|
||||
# send(None) == reset() in worker
|
||||
for i in id:
|
||||
self.workers[i].send(None, **kwargs)
|
||||
ret_list = [self.workers[i].recv() for i in id]
|
||||
for id in env_id:
|
||||
self.workers[id].send(None, **kwargs)
|
||||
ret_list = [self.workers[id].recv() for id in env_id]
|
||||
|
||||
assert (
|
||||
isinstance(ret_list[0], tuple | list)
|
||||
@ -229,12 +231,12 @@ class BaseVectorEnv:
|
||||
except ValueError: # different len(obs)
|
||||
obs = np.array(obs_list, dtype=object)
|
||||
|
||||
infos = [r[1] for r in ret_list]
|
||||
return obs, infos # type: ignore
|
||||
infos = np.array([r[1] for r in ret_list])
|
||||
return obs, infos
|
||||
|
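A sketch of the reset signature after this change: the keyword is now env_id, and the returned infos are a numpy array of info-dicts rather than a list (the CartPole environments below are only for illustration):

import gymnasium as gym
import numpy as np
from tianshou.env import DummyVectorEnv

venv = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)])
obs, infos = venv.reset(env_id=[0, 2])   # reset a subset of the workers
assert isinstance(infos, np.ndarray)     # one info-dict per reset environment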
||||
def step(
|
||||
self,
|
||||
action: np.ndarray | torch.Tensor,
|
||||
action: np.ndarray | torch.Tensor | None,
|
||||
id: int | list[int] | np.ndarray | None = None,
|
||||
) -> gym_new_venv_step_type:
|
||||
"""Run one timestep of some environments' dynamics.
|
||||
@ -248,6 +250,8 @@ class BaseVectorEnv:
|
||||
batch_done, batch_info) in numpy format.
|
||||
|
||||
:param numpy.ndarray action: a batch of action provided by the agent.
|
||||
If the venv is async, the action can be None, which will result
|
||||
in all arrays in the returned tuple being empty.
|
||||
|
||||
:return: A tuple consisting of either:
|
||||
|
||||
@ -271,6 +275,8 @@ class BaseVectorEnv:
|
||||
self._assert_is_not_closed()
|
||||
id = self._wrap_id(id)
|
||||
if not self.is_async:
|
||||
if action is None:
|
||||
raise ValueError("action must be not-None for non-async")
|
||||
assert len(action) == len(id)
|
||||
for i, j in enumerate(id):
|
||||
self.workers[j].send(action[i])
|
||||
|
@ -93,7 +93,14 @@ class AgentFactory(ABC, ToStringMixin):
|
||||
self,
|
||||
policy: BasePolicy,
|
||||
envs: Environments,
|
||||
reset_collectors: bool = True,
|
||||
) -> tuple[Collector, Collector]:
|
||||
""":param policy:
|
||||
:param envs:
|
||||
:param reset_collectors: Whether to reset the collectors before returning them.
|
||||
Setting to True means that the envs will be reset as well.
|
||||
:return:
|
||||
"""
|
||||
buffer_size = self.sampling_config.buffer_size
|
||||
train_envs = envs.train_envs
|
||||
buffer: ReplayBuffer
|
||||
@ -114,6 +121,10 @@ class AgentFactory(ABC, ToStringMixin):
|
||||
)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, envs.test_envs)
|
||||
if reset_collectors:
|
||||
train_collector.reset()
|
||||
test_collector.reset()
|
||||
|
||||
if self.sampling_config.start_timesteps > 0:
|
||||
log.info(
|
||||
f"Collecting {self.sampling_config.start_timesteps} initial environment steps before training (random={self.sampling_config.start_timesteps_random})",
|
||||
|
@ -312,7 +312,7 @@ class Experiment(ToStringMixin):
|
||||
) -> None:
|
||||
policy.eval()
|
||||
collector = Collector(policy, env)
|
||||
result = collector.collect(n_episode=num_episodes, render=render)
|
||||
result = collector.collect(n_episode=num_episodes, render=render, reset_before_collect=True)
|
||||
assert result.returns_stat is not None # for mypy
|
||||
assert result.lens_stat is not None # for mypy
|
||||
log.info(
|
||||
|
@ -1,40 +1,47 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrete, TDistFnDiscrOrCont
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class DistributionFunctionFactory(ToStringMixin, ABC):
|
||||
# True return type defined in subclasses
|
||||
@abstractmethod
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(
|
||||
self,
|
||||
envs: Environments,
|
||||
) -> Callable[[Any], torch.distributions.Distribution]:
|
||||
pass
|
||||
|
||||
|
||||
class DistributionFunctionFactoryCategorical(DistributionFunctionFactory):
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrete:
|
||||
envs.get_type().assert_discrete(self)
|
||||
return self._dist_fn
|
||||
|
||||
@staticmethod
|
||||
def _dist_fn(p: torch.Tensor) -> torch.distributions.Distribution:
|
||||
def _dist_fn(p: torch.Tensor) -> torch.distributions.Categorical:
|
||||
return torch.distributions.Categorical(logits=p)
|
||||
|
||||
|
||||
class DistributionFunctionFactoryIndependentGaussians(DistributionFunctionFactory):
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
|
||||
envs.get_type().assert_continuous(self)
|
||||
return self._dist_fn
|
||||
|
||||
@staticmethod
|
||||
def _dist_fn(*p: torch.Tensor) -> torch.distributions.Distribution:
|
||||
return torch.distributions.Independent(torch.distributions.Normal(*p), 1)
|
||||
def _dist_fn(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> torch.distributions.Distribution:
|
||||
loc, scale = loc_scale
|
||||
return torch.distributions.Independent(torch.distributions.Normal(loc, scale), 1)
|
||||
|
||||
|
||||
class DistributionFunctionFactoryDefault(DistributionFunctionFactory):
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
def create_dist_fn(self, envs: Environments) -> TDistFnDiscrOrCont:
|
||||
match envs.get_type():
|
||||
case EnvType.DISCRETE:
|
||||
return DistributionFunctionFactoryCategorical().create_dist_fn(envs)
|
||||
|
@ -19,7 +19,7 @@ from tianshou.highlevel.params.dist_fn import (
|
||||
from tianshou.highlevel.params.env_param import EnvValueFactory, FloatEnvValueFactory
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
||||
from tianshou.highlevel.params.noise import NoiseFactory
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils import MultipleLRSchedulers
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
@ -322,7 +322,7 @@ class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithSche
|
||||
whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
|
||||
Does not affect training.
|
||||
"""
|
||||
dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
|
||||
dist_fn: TDistFnDiscrOrCont | DistributionFunctionFactory | Literal["default"] = "default"
|
||||
"""
|
||||
This can either be a function which maps the model output to a torch distribution or a
|
||||
factory for the creation of such a function.
|
||||
|
@ -18,6 +18,7 @@ from tianshou.data.batch import Batch, BatchProtocol, arr_type
|
||||
from tianshou.data.buffer.base import TBuffer
|
||||
from tianshou.data.types import (
|
||||
ActBatchProtocol,
|
||||
ActStateBatchProtocol,
|
||||
BatchWithReturnsProtocol,
|
||||
ObsBatchProtocol,
|
||||
RolloutBatchProtocol,
|
||||
@ -212,10 +213,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
super().__init__()
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
self._action_type: Literal["discrete", "continuous"]
|
||||
if isinstance(action_space, Discrete | MultiDiscrete | MultiBinary):
|
||||
self.action_type = "discrete"
|
||||
self._action_type = "discrete"
|
||||
elif isinstance(action_space, Box):
|
||||
self.action_type = "continuous"
|
||||
self._action_type = "continuous"
|
||||
else:
|
||||
raise ValueError(f"Unsupported action space: {action_space}.")
|
||||
self.agent_id = 0
|
||||
@ -225,6 +227,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self._compile()
|
||||
|
||||
@property
|
||||
def action_type(self) -> Literal["discrete", "continuous"]:
|
||||
return self._action_type
|
||||
|
||||
def set_agent_id(self, agent_id: int) -> None:
|
||||
"""Set self.agent_id = agent_id, for MARL."""
|
||||
self.agent_id = agent_id
|
||||
@ -233,11 +239,14 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
# have a method to add noise to action.
|
||||
# So we add the default behavior here. It's a little messy, maybe one can
|
||||
# find a better way to do this.
|
||||
|
||||
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
|
||||
|
||||
def exploration_noise(
|
||||
self,
|
||||
act: np.ndarray | BatchProtocol,
|
||||
batch: RolloutBatchProtocol,
|
||||
) -> np.ndarray | BatchProtocol:
|
||||
act: _TArrOrActBatch,
|
||||
batch: ObsBatchProtocol,
|
||||
) -> _TArrOrActBatch:
|
||||
"""Modify the action from policy.forward with exploration noise.
|
||||
|
||||
NOTE: currently does not add any noise! Needs to be overridden by subclasses
|
||||
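A sketch of how a subclass can override the (now generically typed) hook above while keeping its array-in/array-out contract; the noise scheme and class name are illustrative, not part of the library:

import numpy as np
from tianshou.data.types import ObsBatchProtocol
from tianshou.policy import DQNPolicy

class EpsNoisyDQNPolicy(DQNPolicy):  # hypothetical subclass for illustration
    def exploration_noise(self, act, batch: ObsBatchProtocol):
        # same input/output type as the TypeVar-based signature above
        if isinstance(act, np.ndarray) and self.max_action_num is not None:
            flip = np.random.rand(len(act)) < 0.05
            random_act = np.random.randint(0, self.max_action_num, size=len(act))
            act = np.where(flip, random_act, act)
        return act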
@ -287,7 +296,7 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
|
||||
batch: ObsBatchProtocol,
|
||||
state: dict | BatchProtocol | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ActBatchProtocol:
|
||||
) -> ActBatchProtocol | ActStateBatchProtocol: # TODO: make consistent typing
|
||||
"""Compute action over the given batch data.
|
||||
|
||||
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
|
||||
|
@ -16,6 +16,12 @@ from tianshou.data.types import (
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
|
||||
# Dimension Naming Convention
|
||||
# B - Batch Size
|
||||
# A - Action
|
||||
# D - Dist input (usually 2, loc and scale)
|
||||
# H - Dimension of hidden, can be None
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ImitationTrainingStats(TrainingStats):
|
||||
@ -72,9 +78,20 @@ class ImitationPolicy(BasePolicy[TImitationTrainingStats], Generic[TImitationTra
|
||||
state: dict | BatchProtocol | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ModelOutputBatchProtocol:
|
||||
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
act = logits.max(dim=1)[1] if self.action_type == "discrete" else logits
|
||||
result = Batch(logits=logits, act=act, state=hidden)
|
||||
# TODO - ALGO-REFACTORING: marked for refactoring when Algorithm abstraction is introduced
|
||||
if self.action_type == "discrete":
|
||||
# If it's discrete, the "actor" is usually a critic that maps obs to action_values
|
||||
# which then could be turned into logits or a Categorical
|
||||
action_values_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
act_B = action_values_BA.argmax(dim=1)
|
||||
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
|
||||
elif self.action_type == "continuous":
|
||||
# If it's continuous, the actor would usually deliver something like loc, scale determining a
|
||||
# Gaussian dist
|
||||
dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
result = Batch(logits=dist_input_BD, act=dist_input_BD, state=hidden_BH)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown {self.action_type=}, this shouldn't have happened!")
|
||||
return cast(ModelOutputBatchProtocol, result)
|
||||
|
||||
def learn(
|
||||
|
@ -34,8 +34,7 @@ TDiscreteBCQTrainingStats = TypeVar("TDiscreteBCQTrainingStats", bound=DiscreteB
|
||||
class DiscreteBCQPolicy(DQNPolicy[TDiscreteBCQTrainingStats]):
|
||||
"""Implementation of discrete BCQ algorithm. arXiv:1910.01708.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> q_value)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param imitator: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> imitation_logits)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
|
@ -25,8 +25,7 @@ TDiscreteCQLTrainingStats = TypeVar("TDiscreteCQLTrainingStats", bound=DiscreteC
|
||||
class DiscreteCQLPolicy(QRDQNPolicy[TDiscreteCQLTrainingStats]):
|
||||
"""Implementation of discrete Conservative Q-Learning algorithm. arXiv:2006.04779.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param action_space: Env's action space.
|
||||
:param min_q_weight: the weight for the cql loss.
|
||||
|
@ -11,6 +11,7 @@ from tianshou.data import to_torch, to_torch_as
|
||||
from tianshou.data.types import RolloutBatchProtocol
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.pg import PGPolicy, PGTrainingStats
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -26,8 +27,9 @@ TDiscreteCRRTrainingStats = TypeVar("TDiscreteCRRTrainingStats", bound=DiscreteC
|
||||
class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
|
||||
r"""Implementation of discrete Critic Regularized Regression. arXiv:2006.15134.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param critic: the action-value critic (i.e., Q function)
|
||||
network. (s -> Q(s, \*))
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
@ -55,8 +57,8 @@ class DiscreteCRRPolicy(PGPolicy[TDiscreteCRRTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | Actor,
|
||||
critic: torch.nn.Module | Critic,
|
||||
optim: torch.optim.Optimizer,
|
||||
action_space: gym.spaces.Discrete,
|
||||
discount_factor: float = 0.99,
|
||||
|
@ -15,8 +15,11 @@ from tianshou.data import (
|
||||
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import PPOPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.policy.modelfree.ppo import PPOTrainingStats
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -32,7 +35,9 @@ TGailTrainingStats = TypeVar("TGailTrainingStats", bound=GailTrainingStats)
|
||||
class GAILPolicy(PPOPolicy[TGailTrainingStats]):
|
||||
r"""Implementation of Generative Adversarial Imitation Learning. arXiv:1606.03476.
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -75,10 +80,10 @@ class GAILPolicy(PPOPolicy[TGailTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
expert_buffer: ReplayBuffer,
|
||||
disc_net: torch.nn.Module,
|
||||
|
@ -25,7 +25,7 @@ class TD3BCPolicy(TD3Policy[TTD3BCTrainingStats]):
|
||||
"""Implementation of TD3+BC. arXiv:2106.06860.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
|
||||
:param actor_optim: the optimizer for actor network.
|
||||
:param critic: the first critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: the optimizer for the first critic network.
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Any, Literal, Self
|
||||
from typing import Any, Literal, Self, TypeVar
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
@ -105,11 +105,13 @@ class ICMPolicy(BasePolicy[ICMTrainingStats]):
|
||||
"""
|
||||
return self.policy.forward(batch, state, **kwargs)
|
||||
|
||||
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
|
||||
|
||||
def exploration_noise(
|
||||
self,
|
||||
act: np.ndarray | BatchProtocol,
|
||||
batch: RolloutBatchProtocol,
|
||||
) -> np.ndarray | BatchProtocol:
|
||||
act: _TArrOrActBatch,
|
||||
batch: ObsBatchProtocol,
|
||||
) -> _TArrOrActBatch:
|
||||
return self.policy.exploration_noise(act, batch)
|
||||
|
||||
def set_eps(self, eps: float) -> None:
|
||||
|
@ -11,8 +11,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
|
||||
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import PGPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -30,7 +33,9 @@ TA2CTrainingStats = TypeVar("TA2CTrainingStats", bound=A2CTrainingStats)
|
||||
class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # type: ignore[type-var]
|
||||
"""Implementation of Synchronous Advantage Actor-Critic. arXiv:1602.01783.
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -59,10 +64,10 @@ class A2CPolicy(PGPolicy[TA2CTrainingStats], Generic[TA2CTrainingStats]): # typ
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
vf_coef: float = 0.5,
|
||||
ent_coef: float = 0.01,
|
||||
|
@ -8,6 +8,7 @@ import torch
|
||||
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch, to_torch_as
|
||||
from tianshou.data.batch import BatchProtocol
|
||||
from tianshou.data.types import (
|
||||
ActBatchProtocol,
|
||||
BatchWithReturnsProtocol,
|
||||
ModelOutputBatchProtocol,
|
||||
ObsBatchProtocol,
|
||||
@ -30,7 +31,7 @@ TBDQNTrainingStats = TypeVar("TBDQNTrainingStats", bound=BDQNTrainingStats)
|
||||
class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
|
||||
"""Implementation of the Branching dual Q network arXiv:1711.08946.
|
||||
|
||||
:param model: BranchingNet mapping (obs, state, info) -> logits.
|
||||
:param model: BranchingNet mapping (obs, state, info) -> action_values_BA.
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param estimation_step: the number of steps to look ahead.
|
||||
@ -155,10 +156,10 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
|
||||
model = getattr(self, model)
|
||||
obs = batch.obs
|
||||
# TODO: this is very contrived, see also iqn.py
|
||||
obs_next = obs.obs if hasattr(obs, "obs") else obs
|
||||
logits, hidden = model(obs_next, state=state, info=batch.info)
|
||||
act = to_numpy(logits.max(dim=-1)[1])
|
||||
result = Batch(logits=logits, act=act, state=hidden)
|
||||
obs_next_BO = obs.obs if hasattr(obs, "obs") else obs
|
||||
action_values_BA, hidden_BH = model(obs_next_BO, state=state, info=batch.info)
|
||||
act_B = to_numpy(action_values_BA.argmax(dim=-1))
|
||||
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
|
||||
return cast(ModelOutputBatchProtocol, result)
|
||||
|
||||
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TBDQNTrainingStats:
|
||||
@ -182,11 +183,13 @@ class BranchingDQNPolicy(DQNPolicy[TBDQNTrainingStats]):
|
||||
|
||||
return BDQNTrainingStats(loss=loss.item()) # type: ignore[return-value]
|
||||
|
||||
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
|
||||
|
||||
def exploration_noise(
|
||||
self,
|
||||
act: np.ndarray | BatchProtocol,
|
||||
batch: RolloutBatchProtocol,
|
||||
) -> np.ndarray | BatchProtocol:
|
||||
act: _TArrOrActBatch,
|
||||
batch: ObsBatchProtocol,
|
||||
) -> _TArrOrActBatch:
|
||||
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
|
||||
bsz = len(act)
|
||||
rand_mask = np.random.rand(bsz) < self.eps
|
||||
|
@ -23,8 +23,7 @@ TC51TrainingStats = TypeVar("TC51TrainingStats", bound=C51TrainingStats)
|
||||
class C51Policy(DQNPolicy[TC51TrainingStats], Generic[TC51TrainingStats]):
|
||||
"""Implementation of Categorical Deep Q-Network. arXiv:1707.06887.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param num_atoms: the number of atoms in the support set of the
|
||||
|
@ -10,6 +10,7 @@ import torch
|
||||
from tianshou.data import Batch, ReplayBuffer
|
||||
from tianshou.data.batch import BatchProtocol
|
||||
from tianshou.data.types import (
|
||||
ActBatchProtocol,
|
||||
ActStateBatchProtocol,
|
||||
BatchWithReturnsProtocol,
|
||||
ObsBatchProtocol,
|
||||
@ -18,6 +19,7 @@ from tianshou.data.types import (
|
||||
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -32,8 +34,7 @@ TDDPGTrainingStats = TypeVar("TDDPGTrainingStats", bound=DDPGTrainingStats)
|
||||
class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
|
||||
"""Implementation of Deep Deterministic Policy Gradient. arXiv:1509.02971.
|
||||
|
||||
:param actor: The actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> model_output)
|
||||
:param actor: The actor network following the rules (s -> actions)
|
||||
:param actor_optim: The optimizer for actor network.
|
||||
:param critic: The critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: The optimizer for critic network.
|
||||
@ -59,9 +60,9 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | Actor,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic: torch.nn.Module | Critic,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
action_space: gym.Space,
|
||||
tau: float = 0.005,
|
||||
@ -208,11 +209,13 @@ class DDPGPolicy(BasePolicy[TDDPGTrainingStats], Generic[TDDPGTrainingStats]):
|
||||
|
||||
return DDPGTrainingStats(actor_loss=actor_loss.item(), critic_loss=critic_loss.item()) # type: ignore[return-value]
|
||||
|
||||
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
|
||||
|
||||
def exploration_noise(
|
||||
self,
|
||||
act: np.ndarray | BatchProtocol,
|
||||
batch: RolloutBatchProtocol,
|
||||
) -> np.ndarray | BatchProtocol:
|
||||
act: _TArrOrActBatch,
|
||||
batch: ObsBatchProtocol,
|
||||
) -> _TArrOrActBatch:
|
||||
if self._exploration_noise is None:
|
||||
return act
|
||||
if isinstance(act, np.ndarray):
|
||||
|
@ -8,11 +8,11 @@ from overrides import override
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from tianshou.data import Batch, ReplayBuffer, to_torch
|
||||
from tianshou.data.batch import BatchProtocol
|
||||
from tianshou.data.types import ObsBatchProtocol, RolloutBatchProtocol
|
||||
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import SACPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler
|
||||
from tianshou.policy.modelfree.sac import SACTrainingStats
|
||||
from tianshou.utils.net.discrete import Actor, Critic
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -26,8 +26,7 @@ TDiscreteSACTrainingStats = TypeVar("TDiscreteSACTrainingStats", bound=DiscreteS
|
||||
class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
"""Implementation of SAC for Discrete Action Settings. arXiv:1910.07207.
|
||||
|
||||
:param actor: the actor network following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param actor: the actor network following the rules (s_B -> dist_input_BD)
|
||||
:param actor_optim: the optimizer for actor network.
|
||||
:param critic: the first critic network. (s, a -> Q(s, a))
|
||||
:param critic_optim: the optimizer for the first critic network.
|
||||
@ -55,12 +54,12 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | Actor,
|
||||
actor_optim: torch.optim.Optimizer,
|
||||
critic: torch.nn.Module,
|
||||
critic: torch.nn.Module | Critic,
|
||||
critic_optim: torch.optim.Optimizer,
|
||||
action_space: gym.spaces.Discrete,
|
||||
critic2: torch.nn.Module | None = None,
|
||||
critic2: torch.nn.Module | Critic | None = None,
|
||||
critic2_optim: torch.optim.Optimizer | None = None,
|
||||
tau: float = 0.005,
|
||||
gamma: float = 0.99,
|
||||
@ -106,13 +105,13 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
state: dict | Batch | np.ndarray | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Batch:
|
||||
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = Categorical(logits=logits)
|
||||
logits_BA, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
dist = Categorical(logits=logits_BA)
|
||||
if self.deterministic_eval and not self.training:
|
||||
act = dist.mode
|
||||
act_B = dist.mode
|
||||
else:
|
||||
act = dist.sample()
|
||||
return Batch(logits=logits, act=act, state=hidden, dist=dist)
|
||||
act_B = dist.sample()
|
||||
return Batch(logits=logits_BA, act=act_B, state=hidden_BH, dist=dist)
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
|
||||
obs_next_batch = Batch(
|
||||
@ -184,9 +183,11 @@ class DiscreteSACPolicy(SACPolicy[TDiscreteSACTrainingStats]):
|
||||
alpha_loss=None if not self.is_auto_alpha else alpha_loss.item(),
|
||||
)
|
||||
|
||||
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
|
||||
|
||||
def exploration_noise(
|
||||
self,
|
||||
act: np.ndarray | BatchProtocol,
|
||||
batch: RolloutBatchProtocol,
|
||||
) -> np.ndarray | BatchProtocol:
|
||||
act: _TArrOrActBatch,
|
||||
batch: ObsBatchProtocol,
|
||||
) -> _TArrOrActBatch:
|
||||
return act
|
||||
|
@ -9,6 +9,7 @@ import torch
|
||||
from tianshou.data import Batch, ReplayBuffer, to_numpy, to_torch_as
|
||||
from tianshou.data.batch import BatchProtocol
|
||||
from tianshou.data.types import (
|
||||
ActBatchProtocol,
|
||||
BatchWithReturnsProtocol,
|
||||
ModelOutputBatchProtocol,
|
||||
ObsBatchProtocol,
|
||||
@ -16,6 +17,7 @@ from tianshou.data.types import (
|
||||
)
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -34,8 +36,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
|
||||
Implementation of Dueling DQN. arXiv:1511.06581 (the dueling DQN is
|
||||
implemented in the network side, not here).
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param estimation_step: the number of steps to look ahead.
|
||||
@ -59,7 +60,7 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: torch.nn.Module,
|
||||
model: torch.nn.Module | Net,
|
||||
optim: torch.optim.Optimizer,
|
||||
# TODO: type violates Liskov substitution principle
|
||||
action_space: gym.spaces.Discrete,
|
||||
@ -200,12 +201,12 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
|
||||
obs = batch.obs
|
||||
# TODO: this is convoluted! See also other places where this is done.
|
||||
obs_next = obs.obs if hasattr(obs, "obs") else obs
|
||||
logits, hidden = model(obs_next, state=state, info=batch.info)
|
||||
q = self.compute_q_value(logits, getattr(obs, "mask", None))
|
||||
action_values_BA, hidden_BH = model(obs_next, state=state, info=batch.info)
|
||||
q = self.compute_q_value(action_values_BA, getattr(obs, "mask", None))
|
||||
if self.max_action_num is None:
|
||||
self.max_action_num = q.shape[1]
|
||||
act = to_numpy(q.max(dim=1)[1])
|
||||
result = Batch(logits=logits, act=act, state=hidden)
|
||||
act_B = to_numpy(q.argmax(dim=1))
|
||||
result = Batch(logits=action_values_BA, act=act_B, state=hidden_BH)
|
||||
return cast(ModelOutputBatchProtocol, result)
|
||||
|
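The act computation above is a pure refactor: for a 2-D tensor, q.argmax(dim=1) returns the same indices as q.max(dim=1)[1]. A quick check with assumed values:

import torch

q = torch.tensor([[0.1, 0.9], [0.7, 0.2]])
assert torch.equal(q.argmax(dim=1), q.max(dim=1)[1])  # both give tensor([1, 0])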
||||
def learn(self, batch: RolloutBatchProtocol, *args: Any, **kwargs: Any) -> TDQNTrainingStats:
|
||||
@ -232,11 +233,13 @@ class DQNPolicy(BasePolicy[TDQNTrainingStats], Generic[TDQNTrainingStats]):
|
||||
|
||||
return DQNTrainingStats(loss=loss.item()) # type: ignore[return-value]
|
||||
|
||||
_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")
|
||||
|
||||
def exploration_noise(
|
||||
self,
|
||||
act: np.ndarray | BatchProtocol,
|
||||
batch: RolloutBatchProtocol,
|
||||
) -> np.ndarray | BatchProtocol:
|
||||
act: _TArrOrActBatch,
|
||||
batch: ObsBatchProtocol,
|
||||
) -> _TArrOrActBatch:
|
||||
if isinstance(act, np.ndarray) and not np.isclose(self.eps, 0.0):
|
||||
bsz = len(act)
|
||||
rand_mask = np.random.rand(bsz) < self.eps
|
||||
|
@ -27,8 +27,7 @@ TFQFTrainingStats = TypeVar("TFQFTrainingStats", bound=FQFTrainingStats)
|
||||
class FQFPolicy(QRDQNPolicy[TFQFTrainingStats]):
|
||||
"""Implementation of Fully-parameterized Quantile Function. arXiv:1911.02140.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param fraction_model: a FractionProposalNetwork for
|
||||
proposing fractions/quantiles given state.
|
||||
|
@ -29,8 +29,7 @@ TIQNTrainingStats = TypeVar("TIQNTrainingStats", bound=IQNTrainingStats)
|
||||
class IQNPolicy(QRDQNPolicy[TIQNTrainingStats]):
|
||||
"""Implementation of Implicit Quantile Network. arXiv:1806.06923.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s_B -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param discount_factor: in [0, 1].
|
||||
:param sample_size: the number of samples for policy evaluation.
|
||||
|
@ -12,7 +12,10 @@ from tianshou.data import Batch, ReplayBuffer, SequenceSummaryStats
|
||||
from tianshou.data.types import BatchWithAdvantagesProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -31,7 +34,9 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
|
||||
|
||||
https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -55,10 +60,10 @@ class NPGPolicy(A2CPolicy[TNPGTrainingStats], Generic[TNPGTrainingStats]): # ty
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
optim_critic_iters: int = 5,
|
||||
actor_step_size: float = 0.5,
|
||||
|
@ -1,7 +1,7 @@
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Generic, Literal, TypeAlias, TypeVar, cast
|
||||
from typing import Any, Generic, Literal, TypeVar, cast
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
@ -24,9 +24,22 @@ from tianshou.data.types import (
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.utils import RunningMeanStd
|
||||
from tianshou.utils.net.continuous import ActorProb
|
||||
from tianshou.utils.net.discrete import Actor
|
||||
|
||||
# TODO: Is there a better way to define this type? mypy doesn't like Callable[[torch.Tensor, ...], torch.distributions.Distribution]
|
||||
TDistributionFunction: TypeAlias = Callable[..., torch.distributions.Distribution]
|
||||
# Dimension Naming Convention
|
||||
# B - Batch Size
|
||||
# A - Action
|
||||
# D - Dist input (usually 2, loc and scale)
|
||||
# H - Dimension of hidden, can be None
|
||||
|
||||
TDistFnContinuous = Callable[
|
||||
[tuple[torch.Tensor, torch.Tensor]],
|
||||
torch.distributions.Distribution,
|
||||
]
|
||||
TDistFnDiscrete = Callable[[torch.Tensor], torch.distributions.Categorical]
|
||||
|
||||
TDistFnDiscrOrCont = TDistFnContinuous | TDistFnDiscrete
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
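A short sketch of callables that satisfy the two aliases defined above (shapes and names are illustrative):

import torch
from torch.distributions import Categorical, Distribution, Independent, Normal

def continuous_dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
    loc, scale = loc_scale                   # matches TDistFnContinuous
    return Independent(Normal(loc, scale), 1)

def discrete_dist(logits_BA: torch.Tensor) -> Categorical:
    return Categorical(logits=logits_BA)     # matches TDistFnDiscrete

# either callable is a valid TDistFnDiscrOrCont, e.g. for PGPolicy(dist_fn=...)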
@ -40,8 +53,9 @@ TPGTrainingStats = TypeVar("TPGTrainingStats", bound=PGTrainingStats)
|
||||
class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
|
||||
"""Implementation of REINFORCE algorithm.
|
||||
|
||||
:param actor: mapping (s->model_output), should follow the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`.
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
|
||||
:param optim: optimizer for actor network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
Maps model_output -> distribution. Typically a Gaussian distribution
|
||||
@ -71,9 +85,9 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | Actor,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
discount_factor: float = 0.99,
|
||||
# TODO: rename to return_normalization?
|
||||
@ -175,20 +189,20 @@ class PGPolicy(BasePolicy[TPGTrainingStats], Generic[TPGTrainingStats]):
|
||||
Please refer to :meth:`~tianshou.policy.BasePolicy.forward` for
|
||||
more detailed explanation.
|
||||
"""
|
||||
# TODO: rename? It's not really logits and there are particular
|
||||
# assumptions about the order of the output and on distribution type
|
||||
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
|
||||
if isinstance(logits, tuple):
|
||||
dist = self.dist_fn(*logits)
|
||||
else:
|
||||
dist = self.dist_fn(logits)
|
||||
# TODO - ALGO: marked for algorithm refactoring
|
||||
action_dist_input_BD, hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
|
||||
# in the case that self.action_type == "discrete", the dist should always be Categorical, and D=A
|
||||
# therefore action_dist_input_BD is equivalent to logits_BA
|
||||
# If continuous, dist_fn will typically map loc, scale to a distribution (usually a Gaussian)
|
||||
# the action_dist_input_BD in that case is a tuple of loc_B, scale_B and needs to be unpacked
|
||||
dist = self.dist_fn(action_dist_input_BD)
|
||||
|
||||
# in this case, the dist is unused!
|
||||
if self.deterministic_eval and not self.training:
|
||||
act = dist.mode
|
||||
act_B = dist.mode
|
||||
else:
|
||||
act = dist.sample()
|
||||
result = Batch(logits=logits, act=act, state=hidden, dist=dist)
|
||||
act_B = dist.sample()
|
||||
# act is of dimension BA in continuous case and of dimension B in discrete
|
||||
result = Batch(logits=action_dist_input_BD, act=act_B, state=hidden_BH, dist=dist)
|
||||
return cast(DistBatchProtocol, result)
|
||||
|
||||
# TODO: why does mypy complain?
|
||||
|
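One reason the deterministic-eval branch above can stay distribution-agnostic is that torch distributions expose .mode uniformly (per the changelog: use .mode instead of relying on the distribution type). A small sketch with assumed values, requiring a torch version where Distribution.mode is available:

import torch
from torch.distributions import Categorical, Independent, Normal

Categorical(logits=torch.tensor([[0.1, 2.0, -1.0]])).mode         # tensor([1])
Independent(Normal(torch.zeros(1, 2), torch.ones(1, 2)), 1).mode  # tensor([[0., 0.]])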
@ -10,8 +10,11 @@ from tianshou.data import ReplayBuffer, SequenceSummaryStats, to_torch_as
|
||||
from tianshou.data.types import LogpOldProtocol, RolloutBatchProtocol
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
|
||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.utils.net.discrete import Actor as DiscreteActor
|
||||
from tianshou.utils.net.discrete import Critic as DiscreteCritic
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
@ -29,7 +32,9 @@ TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
|
||||
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # type: ignore[type-var]
|
||||
r"""Implementation of Proximal Policy Optimization. arXiv:1707.06347.
|
||||
|
||||
:param actor: the actor network following the rules in BasePolicy. (s -> logits)
|
||||
:param actor: the actor network following the rules:
|
||||
If `self.action_type == "discrete"`: (`s` ->`action_values_BA`).
|
||||
If `self.action_type == "continuous"`: (`s` -> `dist_input_BD`).
|
||||
:param critic: the critic network. (s -> V(s))
|
||||
:param optim: the optimizer for actor and critic network.
|
||||
:param dist_fn: distribution class for computing the action.
|
||||
@ -67,10 +72,10 @@ class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]): # ty
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
actor: torch.nn.Module,
|
||||
critic: torch.nn.Module,
|
||||
actor: torch.nn.Module | ActorProb | DiscreteActor,
|
||||
critic: torch.nn.Module | Critic | DiscreteCritic,
|
||||
optim: torch.optim.Optimizer,
|
||||
dist_fn: TDistributionFunction,
|
||||
dist_fn: TDistFnDiscrOrCont,
|
||||
action_space: gym.Space,
|
||||
eps_clip: float = 0.2,
|
||||
dual_clip: float | None = None,
|
||||
|
@ -25,8 +25,7 @@ TQRDQNTrainingStats = TypeVar("TQRDQNTrainingStats", bound=QRDQNTrainingStats)
|
||||
class QRDQNPolicy(DQNPolicy[TQRDQNTrainingStats], Generic[TQRDQNTrainingStats]):
|
||||
"""Implementation of Quantile Regression Deep Q-Network. arXiv:1710.10044.
|
||||
|
||||
:param model: a model following the rules in
|
||||
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
|
||||
:param model: a model following the rules (s -> action_values_BA)
|
||||
:param optim: a torch.optim for optimizing the model.
|
||||
:param action_space: Env's action space.
|
||||
:param discount_factor: in [0, 1].
|
||||
|
@ -12,6 +12,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.ddpg import DDPGTrainingStats
from tianshou.utils.net.continuous import ActorProb


@dataclass

@ -61,7 +62,7 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,

@ -150,23 +151,28 @@ class REDQPolicy(DDPGPolicy[TREDQTrainingStats]):
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> Batch:
loc_scale, h = self.actor(batch.obs, state=state, info=batch.info)
loc, scale = loc_scale
dist = Independent(Normal(loc, scale), 1)
(loc_B, scale_B), h_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc_B, scale_B), 1)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
act_B = dist.rsample()
log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act)
squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1,
keepdim=True,
)
return Batch(logits=loc_scale, act=squashed_action, state=h, dist=dist, log_prob=log_prob)
return Batch(
logits=(loc_B, scale_B),
act=squashed_action,
state=h_BH,
dist=dist,
log_prob=log_prob,
)

def _target_q(self, buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
obs_next_batch = Batch(
@ -17,6 +17,7 @@ from tianshou.exploration import BaseNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats
from tianshou.utils.conversion import to_optional_float
from tianshou.utils.net.continuous import ActorProb
from tianshou.utils.optim import clone_optimizer

@ -36,8 +37,7 @@ TSACTrainingStats = TypeVar("TSACTrainingStats", bound=SACTrainingStats)
class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # type: ignore[type-var]
"""Implementation of Soft Actor-Critic. arXiv:1812.05905.

:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:param actor: the actor network following the rules (s -> dist_input_BD)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.

@ -76,7 +76,7 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
def __init__(
self,
*,
actor: torch.nn.Module,
actor: torch.nn.Module | ActorProb,
actor_optim: torch.optim.Optimizer,
critic: torch.nn.Module,
critic_optim: torch.optim.Optimizer,

@ -173,26 +173,25 @@ class SACPolicy(DDPGPolicy[TSACTrainingStats], Generic[TSACTrainingStats]): # t
state: dict | Batch | np.ndarray | None = None,
**kwargs: Any,
) -> DistLogProbBatchProtocol:
logits, hidden = self.actor(batch.obs, state=state, info=batch.info)
assert isinstance(logits, tuple)
dist = Independent(Normal(*logits), 1)
(loc_B, scale_B), hidden_BH = self.actor(batch.obs, state=state, info=batch.info)
dist = Independent(Normal(loc=loc_B, scale=scale_B), 1)
if self.deterministic_eval and not self.training:
act = dist.mode
act_B = dist.mode
else:
act = dist.rsample()
log_prob = dist.log_prob(act).unsqueeze(-1)
act_B = dist.rsample()
log_prob = dist.log_prob(act_B).unsqueeze(-1)
# apply correction for Tanh squashing when computing logprob from Gaussian
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
# in appendix C to get some understanding of this equation.
squashed_action = torch.tanh(act)
squashed_action = torch.tanh(act_B)
log_prob = log_prob - torch.log((1 - squashed_action.pow(2)) + self.__eps).sum(
-1,
keepdim=True,
)
result = Batch(
logits=logits,
logits=(loc_B, scale_B),
act=squashed_action,
state=hidden,
state=hidden_BH,
dist=dist,
log_prob=log_prob,
)
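The squashing correction used in both `REDQPolicy.forward` and `SACPolicy.forward` above follows from the change of variables a = tanh(u): log pi(a|s) = log N(u|mu, sigma) - sum_i log(1 - tanh(u_i)^2). A standalone sketch of the same computation (variable names are illustrative; `eps` plays the role of the policies' private `__eps` attribute):

```python
import torch
from torch.distributions import Independent, Normal


def squashed_gaussian_sample(loc: torch.Tensor, scale: torch.Tensor, eps: float = 1e-6):
    """Sample a tanh-squashed Gaussian action and its log-probability."""
    dist = Independent(Normal(loc, scale), 1)
    u = dist.rsample()                          # pre-squash sample, shape (B, A)
    log_prob = dist.log_prob(u).unsqueeze(-1)   # shape (B, 1)
    act = torch.tanh(u)                         # squashed action in (-1, 1)
    # change-of-variables correction; eps guards against log(0)
    log_prob = log_prob - torch.log(1 - act.pow(2) + eps).sum(-1, keepdim=True)
    return act, log_prob
```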
@ -29,7 +29,7 @@ class TD3Policy(DDPGPolicy[TTD3TrainingStats], Generic[TTD3TrainingStats]): # t
"""Implementation of TD3, arXiv:1802.09477.

:param actor: the actor network following the rules in
:class:`~tianshou.policy.BasePolicy`. (s -> logits)
:class:`~tianshou.policy.BasePolicy`. (s -> actions)
:param actor_optim: the optimizer for actor network.
:param critic: the first critic network. (s, a -> Q(s, a))
:param critic_optim: the optimizer for the first critic network.
@ -11,7 +11,10 @@ from tianshou.data import Batch, SequenceSummaryStats
from tianshou.policy import NPGPolicy
from tianshou.policy.base import TLearningRateScheduler
from tianshou.policy.modelfree.npg import NPGTrainingStats
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.policy.modelfree.pg import TDistFnDiscrOrCont
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.net.discrete import Actor as DiscreteActor
from tianshou.utils.net.discrete import Critic as DiscreteCritic


@dataclass(kw_only=True)

@ -25,7 +28,9 @@ TTRPOTrainingStats = TypeVar("TTRPOTrainingStats", bound=TRPOTrainingStats)
class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
"""Implementation of Trust Region Policy Optimization. arXiv:1502.05477.

:param actor: the actor network following the rules in BasePolicy. (s -> logits)
:param actor: the actor network following the rules:
If `self.action_type == "discrete"`: (`s_B` ->`action_values_BA`).
If `self.action_type == "continuous"`: (`s_B` -> `dist_input_BD`).
:param critic: the critic network. (s -> V(s))
:param optim: the optimizer for actor and critic network.
:param dist_fn: distribution class for computing the action.

@ -53,10 +58,10 @@ class TRPOPolicy(NPGPolicy[TTRPOTrainingStats]):
def __init__(
self,
*,
actor: torch.nn.Module,
critic: torch.nn.Module,
actor: torch.nn.Module | ActorProb | DiscreteActor,
critic: torch.nn.Module | Critic | DiscreteCritic,
optim: torch.optim.Optimizer,
dist_fn: TDistributionFunction,
dist_fn: TDistFnDiscrOrCont,
action_space: gym.Space,
max_kl: float = 0.01,
backtrack_coeff: float = 0.8,
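For the discrete branch documented above (`s_B -> action_values_BA`), the corresponding `dist_fn` likewise takes a single argument. A minimal sketch (illustrative only, not part of the commit):

```python
import torch
from torch.distributions import Categorical


# Hypothetical discrete counterpart of the unified dist_fn interface: the
# actor's per-action values are mapped to a Categorical distribution.
def discrete_dist_fn(action_values_BA: torch.Tensor) -> Categorical:
    return Categorical(logits=action_values_BA)
```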
@ -1,11 +1,11 @@
from typing import Any, Literal, Protocol, Self, cast, overload
from typing import Any, Literal, Protocol, Self, TypeVar, cast, overload

import numpy as np
from overrides import override

from tianshou.data import Batch, ReplayBuffer
from tianshou.data.batch import BatchProtocol, IndexType
from tianshou.data.types import RolloutBatchProtocol
from tianshou.data.types import ActBatchProtocol, ObsBatchProtocol, RolloutBatchProtocol
from tianshou.policy import BasePolicy
from tianshou.policy.base import TLearningRateScheduler, TrainingStats

@ -160,16 +160,18 @@ class MultiAgentPolicyManager(BasePolicy):
buffer._meta.rew = save_rew
return Batch(results)

_TArrOrActBatch = TypeVar("_TArrOrActBatch", bound="np.ndarray | ActBatchProtocol")

def exploration_noise(
self,
act: np.ndarray | BatchProtocol,
batch: RolloutBatchProtocol,
) -> np.ndarray | BatchProtocol:
act: _TArrOrActBatch,
batch: ObsBatchProtocol,
) -> _TArrOrActBatch:
"""Add exploration noise from sub-policy onto act."""
assert isinstance(
batch.obs,
BatchProtocol,
), f"here only observations of type Batch are permitted, but got {type(batch.obs)}"
if not isinstance(batch.obs, Batch):
raise TypeError(
f"here only observations of type Batch are permitted, but got {type(batch.obs)}",
)
for agent_id, policy in self.policies.items():
agent_index = np.nonzero(batch.obs.agent_id == agent_id)[0]
if len(agent_index) == 0:

@ -223,7 +225,7 @@ class MultiAgentPolicyManager(BasePolicy):
results.append((False, np.array([-1]), Batch(), Batch(), Batch()))
continue
tmp_batch = batch[agent_index]
if isinstance(tmp_batch.rew, np.ndarray):
if "rew" in tmp_batch.keys() and isinstance(tmp_batch.rew, np.ndarray):
# reward can be empty Batch (after initial reset) or nparray.
tmp_batch.rew = tmp_batch.rew[:, self.agent_idx[agent_id]]
if not hasattr(tmp_batch.obs, "mask"):
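The `_TArrOrActBatch` TypeVar above ties the return type of `exploration_noise` to the type of `act`, so a caller that passes an `np.ndarray` is annotated as getting an `np.ndarray` back. A minimal sketch of the same pattern outside Tianshou (names are illustrative; the runtime behaviour is unchanged, only static typing benefits):

```python
from typing import TypeVar

import numpy as np

_TArr = TypeVar("_TArr", bound=np.ndarray)


def add_noise(act: _TArr, scale: float = 0.1) -> _TArr:
    # The TypeVar preserves the concrete input type for static checkers;
    # at runtime this simply returns a perturbed copy of the input array.
    return act + np.random.normal(0.0, scale, size=act.shape)
```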
@ -237,7 +237,13 @@ class BaseTrainer(ABC):
self.stop_fn_flag = False
self.iter_num = 0

def reset(self) -> None:
def _reset_collectors(self, reset_buffer: bool = False) -> None:
if self.train_collector is not None:
self.train_collector.reset(reset_buffer=reset_buffer)
if self.test_collector is not None:
self.test_collector.reset(reset_buffer=reset_buffer)

def reset(self, reset_collectors: bool = True, reset_buffer: bool = False) -> None:
"""Initialize or reset the instance to yield a new iterator from zero."""
self.is_run = False
self.env_step = 0

@ -250,16 +256,18 @@ class BaseTrainer(ABC):

self.last_rew, self.last_len = 0.0, 0.0
self.start_time = time.time()
if self.train_collector is not None:
self.train_collector.reset_stat()

if self.train_collector.policy != self.policy or self.test_collector is None:
self.test_in_train = False
if reset_collectors:
self._reset_collectors(reset_buffer=reset_buffer)

if self.train_collector is not None and (
self.train_collector.policy != self.policy or self.test_collector is None
):
self.test_in_train = False

if self.test_collector is not None:
assert self.episode_per_test is not None
assert not isinstance(self.test_collector, AsyncCollector) # Issue 700
self.test_collector.reset_stat()
test_result = test_episode(
self.policy,
self.test_collector,

@ -284,7 +292,7 @@ class BaseTrainer(ABC):
self.iter_num = 0

def __iter__(self): # type: ignore
self.reset()
self.reset(reset_collectors=True, reset_buffer=False)
return self

def __next__(self) -> EpochStats:

@ -308,8 +316,8 @@ class BaseTrainer(ABC):

# perform n step_per_epoch
with progress(total=self.step_per_epoch, desc=f"Epoch #{self.epoch}", **tqdm_config) as t:
train_stat: CollectStatsBase
while t.n < t.total and not self.stop_fn_flag:
train_stat: CollectStatsBase
if self.train_collector is not None:
train_stat, self.stop_fn_flag = self.train_step()
pbar_data_dict = {

@ -515,12 +523,14 @@ class BaseTrainer(ABC):
stats of the whole dataset
"""

def run(self) -> InfoStats:
def run(self, reset_prior_to_run: bool = True) -> InfoStats:
"""Consume iterator.

See itertools - recipes. Use functions that consume iterators at C speed
(feed the entire iterator into a zero-length deque).
"""
if reset_prior_to_run:
self.reset()
try:
self.is_run = True
deque(self, maxlen=0) # feed the entire iterator into a zero-length deque
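With the new `reset` signature, the trainer decides explicitly whether collectors (and optionally their buffers) are reset before iteration, and `run()` gains a matching `reset_prior_to_run` flag. A usage sketch, assuming `trainer` is an already configured trainer instance built elsewhere:

```python
# `trainer` is assumed to be a configured trainer from tianshou.trainer.

# Default behaviour: reset collectors (but not their buffers), then consume the iterator.
result = trainer.run()

# Equivalent explicit form: reset once up front, then run without resetting again.
trainer.reset(reset_collectors=True, reset_buffer=False)
result = trainer.run(reset_prior_to_run=False)
```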
@ -26,8 +26,7 @@ def test_episode(
reward_metric: Callable[[np.ndarray], np.ndarray] | None = None,
) -> CollectStats:
"""A simple wrapper of testing policy in collector."""
collector.reset_env()
collector.reset_buffer()
collector.reset(reset_stats=False)
policy.eval()
if test_fn:
test_fn(epoch, global_step)
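`test_episode` now relies on the more granular `Collector.reset` instead of separate `reset_env`/`reset_buffer` calls; passing `reset_stats=False` appears to keep the running collection statistics while environments and buffer are reset. A sketch of the equivalent calls on a collector (assuming only the keyword shown in the hunk above; `collector` is assumed to exist):

```python
# Reset environments and buffer, but keep accumulated statistics:
collector.reset(reset_stats=False)

# A full reset (environments, buffer and statistics):
collector.reset()
```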
@ -610,6 +610,17 @@ class BaseActor(nn.Module, ABC):
def get_output_dim(self) -> int:
pass

@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[Any, Any]:
# TODO: ALGO-REFACTORING. Marked to be addressed as part of Algorithm abstraction.
# Return type needs to be more specific
pass


def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
"""Gets the given attribute from the given object or takes the alternative value if it is not present.
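With `forward` made abstract on `BaseActor`, custom actors must implement both the dimension accessors and the forward mapping. A minimal sketch of a conforming subclass (purely illustrative; it assumes `BaseActor` also declares a `get_preprocess_net` accessor alongside the `get_output_dim` shown above):

```python
from typing import Any

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import BaseActor


class LinearActor(BaseActor):
    """Toy actor mapping observations through one linear layer; shows the required interface."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self._net = nn.Linear(obs_dim, act_dim)

    def get_preprocess_net(self) -> nn.Module:  # assumed part of the BaseActor interface
        return self._net

    def get_output_dim(self) -> int:
        return self._net.out_features

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        state: Any = None,
        info: dict[str, Any] | None = None,
    ) -> tuple[Any, Any]:
        obs_t = torch.as_tensor(obs, dtype=torch.float32)
        return self._net(obs_t), state
```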
@ -1,4 +1,5 @@
import warnings
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any

@ -9,6 +10,7 @@ from torch import nn
from tianshou.utils.net.common import (
MLP,
BaseActor,
Net,
TActionShape,
TLinearLayer,
get_output_dim,

@ -19,33 +21,27 @@ SIGMA_MAX = 2


class Actor(BaseActor):
"""Simple actor network.
"""Simple actor network that directly outputs actions for continuous action space.
Used primarily in DDPG and its variants. For probabilistic policies, see :class:`~ActorProb`.

It will create an actor operated in continuous action space with structure of preprocess_net ---> action_shape.

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param max_action: the scale for the final action.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""

def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,

@ -77,42 +73,50 @@ class Actor(BaseActor):
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
"""Mapping: obs -> logits -> action."""
if info is None:
info = {}
logits, hidden = self.preprocess(obs, state)
logits = self.max_action * torch.tanh(self.last(logits))
return logits, hidden
"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.

Returns a tensor representing the actions directly, i.e, of shape
`(n_actions, )`, and a hidden state (which may be None).
The hidden state is only not None if a recurrent net is used as part of the
learning algorithm (support for RNNs is currently experimental).
"""
action_BA, hidden_BH = self.preprocess(obs, state)
action_BA = self.max_action * torch.tanh(self.last(action_BA))
return action_BA, hidden_BH
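The deterministic continuous `Actor` above now documents the batched naming convention (`_B` for the batch dimension, `_BA` for batch x action). A construction sketch, assuming the usual `Net` preprocessor, 8-dimensional observations and a 3-dimensional action space:

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor

# Assumed shapes: 8-dim observations, 3-dim continuous actions.
preprocess_net = Net(state_shape=(8,), hidden_sizes=[64, 64])
actor = Actor(preprocess_net, action_shape=(3,), max_action=1.0)

obs_B = torch.randn(5, 8)            # batch of 5 observations
action_BA, hidden = actor(obs_B)     # actions scaled into [-max_action, max_action]
print(action_BA.shape)               # torch.Size([5, 3])
```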
class Critic(nn.Module):
class CriticBase(nn.Module, ABC):
@abstractmethod
def forward(
self,
obs: np.ndarray | torch.Tensor,
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""


class Critic(CriticBase):
"""Simple critic network.

It will create an actor operated in continuous action space with structure of preprocess_net ---> 1(q value).

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param linear_layer: use this module as linear layer. Default to nn.Linear.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.
:param linear_layer: use this module as linear layer.
:param flatten_input: whether to flatten input data for the last layer.
Default to True.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""

def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (),
device: str | int | torch.device = "cpu",
preprocess_net_output_dim: int | None = None,

@ -139,9 +143,7 @@ class Critic(nn.Module):
act: np.ndarray | torch.Tensor | None = None,
info: dict[str, Any] | None = None,
) -> torch.Tensor:
"""Mapping: (s, a) -> logits -> Q(s, a)."""
if info is None:
info = {}
"""Mapping: (s_B, a_B) -> Q(s, a)_B."""
obs = torch.as_tensor(
obs,
device=self.device,

@ -154,41 +156,35 @@ class Critic(nn.Module):
dtype=torch.float32,
).flatten(1)
obs = torch.cat([obs, act], dim=1)
logits, hidden = self.preprocess(obs)
return self.last(logits)
values_B, hidden_BH = self.preprocess(obs)
return self.last(values_B)
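The new `CriticBase` pins down the `(s_B, a_B) -> Q(s, a)_B` contract, so alternative critics only need to honour that forward signature. A minimal sketch of a conforming critic (illustrative only, not part of the commit):

```python
from typing import Any

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.continuous import CriticBase


class TinyQCritic(CriticBase):
    """Concatenates observation and action and maps them to a scalar Q-value."""

    def __init__(self, obs_dim: int, act_dim: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim + act_dim, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(
        self,
        obs: np.ndarray | torch.Tensor,
        act: np.ndarray | torch.Tensor | None = None,
        info: dict[str, Any] | None = None,
    ) -> torch.Tensor:
        obs = torch.as_tensor(obs, dtype=torch.float32).flatten(1)
        act = torch.as_tensor(act, dtype=torch.float32).flatten(1)
        return self.net(torch.cat([obs, act], dim=1))
```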
class ActorProb(BaseActor):
"""Simple actor network (output with a Gauss distribution).
"""Simple actor network that outputs `mu` and `sigma` to be used as input for a `dist_fn` (typically, a Gaussian).

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
Used primarily in SAC, PPO and variants thereof. For deterministic policies, see :class:`~Actor`.

:param preprocess_net: a self-defined preprocess_net, see usage.
Typically, an instance of :class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param max_action: the scale for the final action logits. Default to
1.
:param unbounded: whether to apply tanh activation on final logits.
Default to False.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter. Default to False.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
:param max_action: the scale for the final action logits.
:param unbounded: whether to apply tanh activation on final logits.
:param conditioned_sigma: True when sigma is calculated from the
input, False when sigma is an independent parameter.
:param preprocess_net_output_dim: the output dimension of
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""

# TODO: force kwargs, adjust downstream code
def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
max_action: float = 1.0,

@ -402,8 +398,7 @@ class Perturbation(nn.Module):
flattened hidden state.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
Default to cpu.
:param phi: max perturbation parameter for BCQ. Default to 0.05.
:param phi: max perturbation parameter for BCQ.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

@ -449,7 +444,6 @@ class VAE(nn.Module):
:param latent_dim: the size of latent layer.
:param max_action: the maximum value of each dimension of action.
:param device: which device to create this model on.
Default to "cpu".

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.
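`ActorProb` now documents that it emits `(mu, sigma)` for a downstream `dist_fn`; together with the continuous `Critic` this is the typical SAC/PPO setup. A construction sketch (shapes and keyword choices are assumptions, not taken from the commit):

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic

obs_dim, act_dim = 8, 3
actor = ActorProb(Net(state_shape=(obs_dim,), hidden_sizes=[64, 64]), action_shape=(act_dim,))
critic = Critic(
    Net(state_shape=(obs_dim,), action_shape=(act_dim,), hidden_sizes=[64, 64], concat=True)
)

obs_B = torch.randn(5, obs_dim)
(mu_B, sigma_B), hidden = actor(obs_B)          # inputs for a Gaussian dist_fn
q_B = critic(obs_B, torch.randn(5, act_dim))    # Q(s, a), shape (5, 1)
```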
@ -7,17 +7,14 @@ import torch.nn.functional as F
from torch import nn

from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
from tianshou.utils.net.common import MLP, BaseActor, Net, TActionShape, get_output_dim


class Actor(BaseActor):
"""Simple actor network.
"""Simple actor network for discrete action spaces.

Will create an actor operated in discrete action space with structure of
preprocess_net ---> action_shape.

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param action_shape: a sequence of int for the shape of action.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains

@ -25,20 +22,15 @@ class Actor(BaseActor):
:param softmax_output: whether to apply a softmax layer over the last
layer's output.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
"""

def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
action_shape: TActionShape,
hidden_sizes: Sequence[int] = (),
softmax_output: bool = True,

@ -71,43 +63,44 @@ class Actor(BaseActor):
obs: np.ndarray | torch.Tensor,
state: Any = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
r"""Mapping: s -> Q(s, \*)."""
if info is None:
info = {}
logits, hidden = self.preprocess(obs, state)
logits = self.last(logits)
) -> tuple[torch.Tensor, torch.Tensor | None]:
r"""Mapping: s_B -> action_values_BA, hidden_state_BH | None.

Returns a tensor representing the values of each action, i.e, of shape
`(n_actions, )`, and
a hidden state (which may be None). If `self.softmax_output` is True, they are the
probabilities for taking each action. Otherwise, they will be action values.
The hidden state is only
not None if a recurrent net is used as part of the learning algorithm.
"""
x, hidden_BH = self.preprocess(obs, state)
x = self.last(x)
if self.softmax_output:
logits = F.softmax(logits, dim=-1)
return logits, hidden
x = F.softmax(x, dim=-1)
# If we computed softmax, output is probabilities, otherwise it's the non-normalized action values
output_BA = x
return output_BA, hidden_BH
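As in the continuous case, the discrete `Actor` keeps the `(output_BA, hidden)` return convention; whether `output_BA` holds probabilities or raw action values depends on `softmax_output`. A small sketch (shapes assumed):

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor

actor = Actor(Net(state_shape=(4,), hidden_sizes=[64]), action_shape=4, softmax_output=True)

obs_B = torch.randn(2, 4)
probs_BA, hidden = actor(obs_B)
print(probs_BA.sum(dim=-1))   # each row sums to 1 because softmax_output=True
```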
class Critic(nn.Module):
"""Simple critic network.
"""Simple critic network for discrete action spaces.

It will create an actor operated in discrete action space with structure of preprocess_net ---> 1(q value).

:param preprocess_net: a self-defined preprocess_net which output a
flattened hidden state.
:param preprocess_net: a self-defined preprocess_net. Typically, an instance of
:class:`~tianshou.utils.net.common.Net`.
:param hidden_sizes: a sequence of int for constructing the MLP after
preprocess_net. Default to empty sequence (where the MLP now contains
only a single linear layer).
:param last_size: the output dimension of Critic network. Default to 1.
:param preprocess_net_output_dim: the output dimension of
preprocess_net.
`preprocess_net`. Only used when `preprocess_net` does not have the attribute `output_dim`.

For advanced usage (how to customize the network), please refer to
:ref:`build_the_network`.

.. seealso::

Please refer to :class:`~tianshou.utils.net.common.Net` as an instance
of how preprocess_net is suggested to be defined.
:ref:`build_the_network`..
"""

def __init__(
self,
preprocess_net: nn.Module,
preprocess_net: nn.Module | Net,
hidden_sizes: Sequence[int] = (),
last_size: int = 1,
preprocess_net_output_dim: int | None = None,

@ -120,8 +113,10 @@ class Critic(nn.Module):
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)

# TODO: make a proper interface!
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Mapping: s -> V(s)."""
"""Mapping: s_B -> V(s)_B."""
# TODO: don't use this mechanism for passing state
logits, _ = self.preprocess(obs, state=kwargs.get("state", None))
return self.last(logits)
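The discrete `Critic` maps `s_B -> V(s)_B` (or to `last_size` outputs per observation); a construction sketch to round off the interface changes in this file (shapes assumed):

```python
import torch

from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Critic

critic = Critic(Net(state_shape=(4,), hidden_sizes=[64]), last_size=1)

obs_B = torch.randn(2, 4)
value_B = critic(obs_B)   # shape (2, 1): one value per observation in the batch
```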