Improvements in High-Level API and Poe Tasks (#1055)

* Add an option to SamplingConfig that allows the number of test episodes
to be configured (a usage sketch follows below)
* Make OptimizerFactory more flexible by adding the method
`create_optimizer_for_params`
* Fix AutoAlphaFactoryDefault using a hard-coded Adam optimizer instead of
the configured optimizer factory
* Fix mypy issues that were platform/installation-dependent
* Limit the scope of nbqa, resolving issues with files generated by old
versions of the build

Fixes #1054
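
A usage sketch (not part of the diff; the field values are chosen purely for
illustration, and SamplingConfig is assumed to be importable from
tianshou.highlevel.config):

    from tianshou.highlevel.config import SamplingConfig

    sampling_config = SamplingConfig(
        num_train_envs=8,
        num_test_envs=4,
        num_test_episodes=20,  # new option: total episodes collected in each test step
    )

    # Trainers now receive episode_per_test = num_test_episodes_per_test_env,
    # i.e. ceil(20 / 4) = 5 episodes per test environment.
    print(sampling_config.num_test_episodes_per_test_env)  # -> 5
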
Authored by Michael Panchenko on 2024-02-15 12:02:16 +01:00, committed by GitHub
commit 9b6cb6903e
8 changed files with 53 additions and 22 deletions

.gitignore

@@ -111,7 +111,7 @@ celerybeat.pid
 .env
 .venv
 venv/
-ENV/
+/ENV/
 env.bak/
 venv.bak/

------------------------------------------------------------

@@ -143,8 +143,8 @@ line-length = 100
 target-version = ["py311"]

 [tool.nbqa.exclude]
-ruff = ".jupyter_cache"
-mypy = ".jupyter_cache"
+ruff = "\\.jupyter_cache|jupyter_execute"
+mypy = "\\.jupyter_cache|jupyter_execute"

 [tool.ruff]
 select = [
@@ -203,10 +203,10 @@ test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --
 test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes"
 _black_check = "black --check ."
 _ruff_check = "ruff check ."
-_ruff_check_nb = "nbqa ruff ."
+_ruff_check_nb = "nbqa ruff docs"
 _black_format = "black ."
 _ruff_format = "ruff --fix ."
-_ruff_format_nb = "nbqa ruff --fix ."
+_ruff_format_nb = "nbqa ruff --fix docs"
 lint = ["_black_check", "_ruff_check", "_ruff_check_nb"]
 _poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
 _poetry_sort = "poetry sort"
@@ -221,5 +221,5 @@ doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"]
 doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build"
 doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"]
 _mypy = "mypy tianshou"
-_mypy_nb = "nbqa mypy ."
+_mypy_nb = "nbqa mypy docs"
 type-check = ["_mypy", "_mypy_nb"]

------------------------------------------------------------

@@ -12,6 +12,9 @@ with contextlib.suppress(ImportError):
     import ray

+# mypy: disable-error-code="unused-ignore"

 class _SetAttrWrapper(gym.Wrapper):
     def set_env_attr(self, key: str, value: Any) -> None:
         setattr(self.env.unwrapped, key, value)

------------------------------------------------------------

@@ -12,6 +12,9 @@ import numpy as np
 from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
 from tianshou.env.worker import EnvWorker

+# mypy: disable-error-code="unused-ignore"

 _NP_TO_CT = {
     np.bool_: ctypes.c_bool,
     np.uint8: ctypes.c_uint8,
@@ -179,10 +182,10 @@ class SubprocEnvWorker(EnvWorker):
                 if remain_time <= 0:
                     break
             # connection.wait hangs if the list is empty
-            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
+            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)  # type: ignore
             ready_conns.extend(new_ready_conns)  # type: ignore
-            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
-        return [workers[conns.index(con)] for con in ready_conns]
+            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]  # type: ignore
+        return [workers[conns.index(con)] for con in ready_conns]  # type: ignore

     def send(self, action: np.ndarray | None, **kwargs: Any) -> None:
         if action is None:

------------------------------------------------------------

@@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
@@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
             logger=world.logger,

------------------------------------------------------------

@@ -1,3 +1,4 @@
+import math
 import multiprocessing
 from dataclasses import dataclass
@@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
     * collects environment steps/transitions (collection step), adding them to the (replay)
       buffer (see :attr:`step_per_collect`)
-    * performs one or more gradient updates (see :attr:`update_per_step`).
+    * performs one or more gradient updates (see :attr:`update_per_step`),
+
+    and the test step collects :attr:`num_test_episodes` test episodes in order to evaluate
+    agent performance.

     The number of training steps in each epoch is indirectly determined by
     :attr:`step_per_epoch`: As many training steps will be performed as are required in
@@ -49,6 +53,12 @@
     num_test_envs: int = 1
     """the number of test environments to use"""

+    num_test_episodes: int = 1
+    """the total number of episodes to collect in each test step (across all test environments).
+    This should be a multiple of the number of test environments; if it is not, the effective
+    number of episodes collected will be the nearest multiple (rounded up).
+    """

     buffer_size: int = 4096
     """the total size of the sample/replay buffer, in which environment steps (transitions) are
     stored"""
@@ -119,3 +129,8 @@
     def __post_init__(self) -> None:
         if self.num_train_envs == -1:
             self.num_train_envs = multiprocessing.cpu_count()
+
+    @property
+    def num_test_episodes_per_test_env(self) -> int:
+        """:return: the number of episodes to collect per test environment in every test step"""
+        return math.ceil(self.num_test_episodes / self.num_test_envs)
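
To make the rounding rule above concrete (a small illustration, not part of the diff):

    import math

    num_test_episodes, num_test_envs = 10, 4                 # not an exact multiple
    per_env = math.ceil(num_test_episodes / num_test_envs)   # 3 episodes per test env
    effective_total = per_env * num_test_envs                # 12, i.e. 10 rounded up to the nearest multiple of 4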

------------------------------------------------------------

@@ -1,11 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Any, Protocol
+from collections.abc import Iterable
+from typing import Any, Protocol, TypeAlias

 import torch
 from torch.optim import Adam, RMSprop

 from tianshou.utils.string import ToStringMixin

+TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]

 class OptimizerWithLearningRateProtocol(Protocol):
     def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
@@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):
 class OptimizerFactory(ABC, ToStringMixin):
+    def create_optimizer(
+        self,
+        module: torch.nn.Module,
+        lr: float,
+    ) -> torch.optim.Optimizer:
+        return self.create_optimizer_for_params(module.parameters(), lr)
+
     @abstractmethod
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         pass
@@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
         self.optim_class = optim_class
         self.kwargs = kwargs

-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
-        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
+        return self.optim_class(params, lr=lr, **self.kwargs)

 class OptimizerFactoryAdam(OptimizerFactory):
@@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
         self.eps = eps
         self.betas = betas

-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return Adam(
-            module.parameters(),
+            params,
             lr=lr,
             betas=self.betas,
             eps=self.eps,
@@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
         self.weight_decay = weight_decay
         self.eps = eps

-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return RMSprop(
-            module.parameters(),
+            params,
             lr=lr,
             alpha=self.alpha,
             eps=self.eps,
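
A rough usage sketch of the refactored interface (illustrative only; it assumes
the factories live in tianshou.highlevel.optim and that OptimizerFactoryAdam can
be constructed with its default arguments):

    import torch
    from tianshou.highlevel.optim import OptimizerFactoryAdam

    factory = OptimizerFactoryAdam()

    # Existing entry point: an optimizer over all parameters of a module
    # (now implemented on top of create_optimizer_for_params).
    net = torch.nn.Linear(4, 2)
    net_optim = factory.create_optimizer(net, lr=1e-3)

    # New entry point: an optimizer over an arbitrary iterable of parameters,
    # e.g. a single tensor that is not owned by a module.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = factory.create_optimizer_for_params([log_alpha], 3e-4)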

------------------------------------------------------------

@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass

-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr
@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim
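
A sketch of what the fix enables (illustrative; it assumes OptimizerFactoryTorch
takes the optimizer class as its first constructor argument, as the assignments
in the diff above suggest, and that it is importable from tianshou.highlevel.optim):

    import torch
    from tianshou.highlevel.optim import OptimizerFactoryTorch

    # Any optimizer factory can now drive the auto-alpha optimizer; RMSprop is just an example.
    # The constructor signature below is assumed from the diff, not confirmed.
    optim_factory = OptimizerFactoryTorch(torch.optim.RMSprop, alpha=0.95)

    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], 3e-4)
    assert isinstance(alpha_optim, torch.optim.RMSprop)  # no longer hard-coded to Adam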