Improvements in High-Level API and Poe Tasks (#1055)
* Add an option to SamplingConfig which allows configuring the number of test episodes (illustrated in the sketch below)
* Make OptimizerFactory more flexible, adding the method `create_optimizer_for_params`
* Fix AutoAlphaFactoryDefault using a hard-coded Adam optimizer
* Fix mypy issues that were platform/installation-dependent
* Limit the scope of nbqa, resolving issues with files generated by old versions of the build

Fixes #1054
Commit: 9b6cb6903e
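
To illustrate the first item, here is a minimal sketch of the new option (the module path
tianshou.highlevel.config and the concrete values are assumptions for illustration; only
num_test_episodes and the derived property are introduced by this change):

    from tianshou.highlevel.config import SamplingConfig

    # Ask for 10 test episodes spread across 5 test environments (arbitrary values).
    sampling_config = SamplingConfig(
        num_epochs=100,
        num_train_envs=10,
        num_test_envs=5,
        num_test_episodes=10,
    )
    # The trainer's episode_per_test is now derived from this setting:
    assert sampling_config.num_test_episodes_per_test_env == 2  # ceil(10 / 5)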
.gitignore

@@ -111,7 +111,7 @@ celerybeat.pid
 .env
 .venv
 venv/
-ENV/
+/ENV/
 env.bak/
 venv.bak/
pyproject.toml

@@ -143,8 +143,8 @@ line-length = 100
 target-version = ["py311"]
 
 [tool.nbqa.exclude]
-ruff = ".jupyter_cache"
-mypy = ".jupyter_cache"
+ruff = "\\.jupyter_cache|jupyter_execute"
+mypy = "\\.jupyter_cache|jupyter_execute"
 
 [tool.ruff]
 select = [
@@ -203,10 +203,10 @@ test = "pytest test --cov=tianshou --cov-report=xml --cov-report=term-missing --
 test-reduced = "pytest test/base test/continuous --cov=tianshou --durations=0 -v --color=yes"
 _black_check = "black --check ."
 _ruff_check = "ruff check ."
-_ruff_check_nb = "nbqa ruff ."
+_ruff_check_nb = "nbqa ruff docs"
 _black_format = "black ."
 _ruff_format = "ruff --fix ."
-_ruff_format_nb = "nbqa ruff --fix ."
+_ruff_format_nb = "nbqa ruff --fix docs"
 lint = ["_black_check", "_ruff_check", "_ruff_check_nb"]
 _poetry_install_sort_plugin = "poetry self add poetry-plugin-sort"
 _poetry_sort = "poetry sort"
@@ -221,5 +221,5 @@ doc-generate-files = ["_autogen_rst", "_jb_generate_toc", "_jb_generate_config"]
 doc-spellcheck = "sphinx-build -W -b spelling docs docs/_build"
 doc-build = ["doc-generate-files", "doc-spellcheck", "_sphinx_build"]
 _mypy = "mypy tianshou"
-_mypy_nb = "nbqa mypy ."
+_mypy_nb = "nbqa mypy docs"
 type-check = ["_mypy", "_mypy_nb"]
tianshou/env/worker/ray.py

@@ -12,6 +12,9 @@ with contextlib.suppress(ImportError):
     import ray
 
 
+# mypy: disable-error-code="unused-ignore"
+
+
 class _SetAttrWrapper(gym.Wrapper):
     def set_env_attr(self, key: str, value: Any) -> None:
         setattr(self.env.unwrapped, key, value)
tianshou/env/worker/subproc.py

@@ -12,6 +12,9 @@ import numpy as np
 from tianshou.env.utils import CloudpickleWrapper, gym_new_venv_step_type
 from tianshou.env.worker import EnvWorker
 
+# mypy: disable-error-code="unused-ignore"
+
+
 _NP_TO_CT = {
     np.bool_: ctypes.c_bool,
     np.uint8: ctypes.c_uint8,
@@ -179,10 +182,10 @@ class SubprocEnvWorker(EnvWorker):
                 if remain_time <= 0:
                     break
             # connection.wait hangs if the list is empty
-            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)
+            new_ready_conns = connection.wait(remain_conns, timeout=remain_time)  # type: ignore
             ready_conns.extend(new_ready_conns)  # type: ignore
-            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]
-        return [workers[conns.index(con)] for con in ready_conns]
+            remain_conns = [conn for conn in remain_conns if conn not in ready_conns]  # type: ignore
+        return [workers[conns.index(con)] for con in ready_conns]  # type: ignore
 
     def send(self, action: np.ndarray | None, **kwargs: Any) -> None:
         if action is None:
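For context, a brief sketch of the mypy mechanism used in the two workers above: a module-level
"# mypy: disable-error-code=..." comment turns the given error code off for the whole file, so a
"# type: ignore" that is only needed on some platforms/typeshed versions does not trigger
unused-ignore errors elsewhere. The file below is purely illustrative and not part of the commit:

    # mypy: disable-error-code="unused-ignore"
    from multiprocessing import connection

    def wait_ready(conns: list, timeout: float) -> list:
        # Depending on the platform/typeshed version, mypy may or may not require the
        # ignore below; with "unused-ignore" disabled module-wide, both cases are clean.
        return connection.wait(conns, timeout=timeout)  # type: ignore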
tianshou/highlevel/agent.py

@@ -184,7 +184,7 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             repeat_per_collect=sampling_config.repeat_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             step_per_collect=sampling_config.step_per_collect,
             save_best_fn=policy_persistence.get_save_best_fn(world),
@@ -228,7 +228,7 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
             max_epoch=sampling_config.num_epochs,
             step_per_epoch=sampling_config.step_per_epoch,
             step_per_collect=sampling_config.step_per_collect,
-            episode_per_test=sampling_config.num_test_envs,
+            episode_per_test=sampling_config.num_test_episodes_per_test_env,
             batch_size=sampling_config.batch_size,
             save_best_fn=policy_persistence.get_save_best_fn(world),
             logger=world.logger,
tianshou/highlevel/config.py

@@ -1,3 +1,4 @@
+import math
 import multiprocessing
 from dataclasses import dataclass
 
@@ -16,7 +17,10 @@ class SamplingConfig(ToStringMixin):
 
     * collects environment steps/transitions (collection step), adding them to the (replay)
       buffer (see :attr:`step_per_collect`)
-    * performs one or more gradient updates (see :attr:`update_per_step`).
+    * performs one or more gradient updates (see :attr:`update_per_step`),
 
+    and the test step collects :attr:`num_episodes_per_test` test episodes in order to evaluate
+    agent performance.
+
     The number of training steps in each epoch is indirectly determined by
     :attr:`step_per_epoch`: As many training steps will be performed as are required in
@@ -49,6 +53,12 @@ class SamplingConfig(ToStringMixin):
     num_test_envs: int = 1
     """the number of test environments to use"""
 
+    num_test_episodes: int = 1
+    """the total number of episodes to collect in each test step (across all test environments).
+    This should be a multiple of the number of test environments; if it is not, the effective
+    number of episodes collected will be the nearest multiple (rounded up).
+    """
+
     buffer_size: int = 4096
     """the total size of the sample/replay buffer, in which environment steps (transitions) are
     stored"""
@@ -119,3 +129,8 @@
     def __post_init__(self) -> None:
         if self.num_train_envs == -1:
             self.num_train_envs = multiprocessing.cpu_count()
+
+    @property
+    def num_test_episodes_per_test_env(self) -> int:
+        """:return: the number of episodes to collect per test environment in every test step"""
+        return math.ceil(self.num_test_episodes / self.num_test_envs)
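As a worked example of the rounding behaviour described in the docstring above (module path and
values are illustrative only):

    from tianshou.highlevel.config import SamplingConfig

    config = SamplingConfig(num_test_envs=2, num_test_episodes=5)
    # 5 episodes cannot be split evenly across 2 envs, so each env runs ceil(5 / 2) = 3
    # episodes, and the effective number of collected test episodes is 2 * 3 = 6.
    assert config.num_test_episodes_per_test_env == 3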
tianshou/highlevel/optim.py

@@ -1,11 +1,14 @@
 from abc import ABC, abstractmethod
-from typing import Any, Protocol
+from collections.abc import Iterable
+from typing import Any, Protocol, TypeAlias
 
 import torch
 from torch.optim import Adam, RMSprop
 
 from tianshou.utils.string import ToStringMixin
 
+TParams: TypeAlias = Iterable[torch.Tensor] | Iterable[dict[str, Any]]
+
 
 class OptimizerWithLearningRateProtocol(Protocol):
     def __call__(self, parameters: Any, lr: float, **kwargs: Any) -> torch.optim.Optimizer:
@@ -13,8 +16,15 @@ class OptimizerWithLearningRateProtocol(Protocol):
 
 
 class OptimizerFactory(ABC, ToStringMixin):
+    def create_optimizer(
+        self,
+        module: torch.nn.Module,
+        lr: float,
+    ) -> torch.optim.Optimizer:
+        return self.create_optimizer_for_params(module.parameters(), lr)
+
     @abstractmethod
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         pass
 
 
@@ -30,8 +40,8 @@ class OptimizerFactoryTorch(OptimizerFactory):
         self.optim_class = optim_class
         self.kwargs = kwargs
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer:
-        return self.optim_class(module.parameters(), lr=lr, **self.kwargs)
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
+        return self.optim_class(params, lr=lr, **self.kwargs)
 
 
 class OptimizerFactoryAdam(OptimizerFactory):
@@ -45,9 +55,9 @@ class OptimizerFactoryAdam(OptimizerFactory):
         self.eps = eps
         self.betas = betas
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return Adam(
-            module.parameters(),
+            params,
             lr=lr,
             betas=self.betas,
             eps=self.eps,
@@ -70,9 +80,9 @@ class OptimizerFactoryRMSprop(OptimizerFactory):
         self.weight_decay = weight_decay
         self.eps = eps
 
-    def create_optimizer(self, module: torch.nn.Module, lr: float) -> RMSprop:
+    def create_optimizer_for_params(self, params: TParams, lr: float) -> torch.optim.Optimizer:
         return RMSprop(
-            module.parameters(),
+            params,
             lr=lr,
             alpha=self.alpha,
             eps=self.eps,
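A short sketch of how the refactored factory serves both cases, which is what enables the
AutoAlphaFactoryDefault fix below (module path and the concrete module/tensor are assumptions
for illustration):

    import torch
    from tianshou.highlevel.optim import OptimizerFactoryAdam

    factory = OptimizerFactoryAdam()

    # As before: build an optimizer over all parameters of a module.
    net = torch.nn.Linear(4, 2)
    net_optim = factory.create_optimizer(net, lr=1e-3)

    # New: build an optimizer for an explicit parameter list, e.g. a single free tensor
    # such as SAC's log_alpha, without wrapping it in an nn.Module.
    log_alpha = torch.zeros(1, requires_grad=True)
    alpha_optim = factory.create_optimizer_for_params([log_alpha], 3e-4)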
tianshou/highlevel/params/alpha.py

@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass
 
 
-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr
 
@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim