diff --git a/.gitignore b/.gitignore
index 98acab5..e63e24b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -111,7 +111,7 @@ celerybeat.pid
 .env
 .venv
 venv/
-ENV/
+/ENV/
 env.bak/
 venv.bak/
diff --git a/pyproject.toml b/pyproject.toml
index aa66eb5..c47b14b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -143,8 +143,8 @@ line-length = 100
 target-version = ["py311"]
 
 [tool.nbqa.exclude]
-ruff = ".jupyter_cache"
-mypy = ".jupyter_cache"
+ruff = "\\.jupyter_cache|jupyter_execute"
+mypy = "\\.jupyter_cache|jupyter_execute"
 
 [tool.ruff]
 select = [
@@ -203,10 +203,10 @@ test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --
 test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes"
 _black_check = "black --check ."
 _ruff_check = "ruff check ."
-_ruff_check_nb = "nbqa ruff ."
+_ruff_check_nb = "nbqa ruff docs"
 _black_format = "black ."
 _ruff_format = "ruff --fix ."
-_ruff_format_nb = "nbqa ruff --fix ."
+_ruff_format_nb = "nbqa ruff --fix docs"
 lint = ["_black_check", "_ruff_check", "_ruff_check_nb"]
 _poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
 _poetry_sort = "poetry sort"
@@ -221,5 +221,5 @@ doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"]
 doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build"
 doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"]
 _mypy = "mypy tianshou"
-_mypy_nb = "nbqa mypy ."
+_mypy_nb = "nbqa mypy docs"
 type-check = ["_mypy", "_mypy_nb"]
diff --git a/tianshou/env/worker/ray.py b/tianshou/env/worker/ray.py
index f465eae..76b8422 100644
--- a/tianshou/env/worker/ray.py
+++ b/tianshou/env/worker/ray.py
@@ -12,6 +12,9 @@ with contextlib.suppress(ImportError):
     import ray
 
+# mypy: disable-error-code="unused-ignore"
+
+
 class _SetAttrWrapper(gym.Wrapper):
     def set_env_attr(self, key: str, value: Any) -> None:
         setattr(self.env.unwrapped, key, value)
 
diff --git a/tianshou/env/worker/subproc.py b/tianshou/env/worker/subproc.py
index 2ca60c2..af5ec4e 100644
--- a/tianshou/env/worker/subproc.py
+++ b/tianshou/env/worker/subproc.py
@@ -12,6 +12,9 @@ import numpy as np
 from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
 from tianshou.env.worker import EnvWorker
 
+# mypy: disable-error-code="unused-ignore"
+
+
 _NP_TO_CT = {
     np.bool_: ctypes.c_bool,
     np.uint8: ctypes.c_uint8,
@@ -179,10 +182,10 @@ class SubprocEnvWorker(EnvWorker):
                 if remain_time <= 0:
                     break
             # connection.wait hangs if the list is empty
-            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
+            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)  # type: ignore
             ready_conns.extend(new_ready_conns)  # type: ignore
-            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
-        return [workers[conns.index(con)] for con in ready_conns]
+            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]  # type: ignore
+        return [workers[conns.index(con)] for con in ready_conns]  # type: ignore
 
     def send(self, action: np.ndarray | None, **kwargs: Any) -> None:
         if action is None:
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py
index b72ab5e..1a1a0bf 100644
--- a/tianshou/highlevel/agent.py
+++ b/tianshou/highlevel/agent.py
@@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
@@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
             logger=world.logger,
diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py
index 80e0476..4982472 100644
--- a/tianshou/highlevel/config.py
+++ b/tianshou/highlevel/config.py
@@ -1,3 +1,4 @@
+import math
 import multiprocessing
 from dataclasses import dataclass
 
@@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
     * collects environment steps/transitions (collection step), adding them to the (replay)
       buffer (see :attr:`step_per_collect`)
-    * performs one or more gradient updates (see :attr:`update_per_step`).
+    * performs one or more gradient updates (see :attr:`update_per_step`),
+
+    and the test step collects :attr:`num_test_episodes` test episodes in order to evaluate
+    agent performance.
 
     The number of training steps in each epoch is indirectly determined by
     :attr:`step_per_epoch`: As many training steps will be performed as are required in
@@ -49,6 +53,12 @@
     num_test_envs: int = 1
     """the number of test environments to use"""
 
+    num_test_episodes: int = 1
+    """the total number of episodes to collect in each test step (across all test environments).
+    This should be a multiple of the number of test environments; if it is not, the effective
+    number of episodes collected will be the nearest multiple (rounded up).
+    """
+
     buffer_size: int = 4096
     """the total size of the sample/replay buffer, in which environment steps (transitions) are
     stored"""
@@ -119,3 +129,8 @@
     def __post_init__(self) -> None:
         if self.num_train_envs == -1:
            self.num_train_envs = multiprocessing.cpu_count()
+
+    @property
+    def num_test_episodes_per_test_env(self) -> int:
+        """:return: the number of episodes to collect per test environment in every test step"""
+        return math.ceil(self.num_test_episodes / self.num_test_envs)
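
Note (not part of the diff): a quick illustration of the ceiling division introduced in config.py above. The concrete numbers are hypothetical, and the snippet assumes the remaining SamplingConfig fields keep their defaults:

    from tianshou.highlevel.config import SamplingConfig

    config = SamplingConfig(num_test_envs=4, num_test_episodes=10)
    # ceil(10 / 4) == 3 episodes are collected per test environment ...
    assert config.num_test_episodes_per_test_env == 3
    # ... so the effective total per test step is 3 * 4 = 12 episodes,
    # the nearest multiple of num_test_envs, rounded up.
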
+ """ + buffer_size: int = 4096 """the total size of the sample/replay buffer, in which environment steps (transitions) are stored""" @@ -119,3 +129,8 @@ class SamplingConfig(ToStringMixin): def __post_init__(self) -> None: if self.num_train_envs == -1: self.num_train_envs = multiprocessing.cpu_count() + + @property + def num_test_episodes_per_test_env(self) -> int: + """:return: the number of episodes to collect per test environment in every test step""" + return math.ceil(self.num_test_episodes / self.num_test_envs) diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index 0e754b1..db5fd90 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -1,11 +1,14 @@ from abc import ABC, abstractmethod -from typing import Any, Protocol +from collections.abc import Iterable +from typing import Any, Protocol, TypeAlias import torch from torch.optim import Adam, RMSprop from tianshou.utils.string import ToStringMixin +TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]] + class OptimizerWithLearningRateProtocol(Protocol): def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer: @@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol): class OptimizerFactory(ABC, ToStringMixin): + def create_optimizer( + self, + module: torch.nn.Module, + lr: float, + ) -> torch.optim.Optimizer: + return self.create_optimizer_for_params(module.parameters(), lr) + @abstractmethod - def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer: + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: pass @@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory): self.optim_class = optim_class self.kwargs = kwargs - def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer: - return self.optim_class(module.parameters(), lr=lr, **self.kwargs) + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: + return self.optim_class(params, lr=lr, **self.kwargs) class OptimizerFactoryAdam(OptimizerFactory): @@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory): self.eps = eps self.betas = betas - def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam: + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: return Adam( - module.parameters(), + params, lr=lr, betas=self.betas, eps=self.eps, @@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory): self.weight_decay = weight_decay self.eps = eps - def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop: + def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer: return RMSprop( - module.parameters(), + params, lr=lr, alpha=self.alpha, eps=self.eps, diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py index 878ae4b..4e8490d 100644 --- a/tianshou/highlevel/params/alpha.py +++ b/tianshou/highlevel/params/alpha.py @@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC): pass -class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name? +class AutoAlphaFactoryDefault(AutoAlphaFactory): def __init__(self, lr: float = 3e-4): self.lr = lr @@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name? 
diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py
index 878ae4b..4e8490d 100644
--- a/tianshou/highlevel/params/alpha.py
+++ b/tianshou/highlevel/params/alpha.py
@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass
 
 
-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr
 
@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim
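
Note (not part of the diff): the practical effect of the alpha.py change, sketched in isolation and assuming OptimizerFactoryRMSprop can be instantiated with its default arguments. The optimizer for log_alpha now follows whatever factory the experiment is configured with, rather than always being Adam:

    import torch

    from tianshou.highlevel.optim import OptimizerFactoryRMSprop

    log_alpha = torch.zeros(1, requires_grad=True)

    # previously: always torch.optim.Adam([log_alpha], lr=3e-4)
    # now: the configured optimizer factory decides, e.g. with an RMSprop factory
    factory = OptimizerFactoryRMSprop()
    alpha_optim = factory.create_optimizer_for_params([log_alpha], 3e-4)
    assert isinstance(alpha_optim, torch.optim.RMSprop)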