Improvements in README and high-level API (#1022)

This makes several largely unrelated improvements in the high-level API
and in the README.

Main improvements in high-level API:
  * Improve naming in trainer-related abstractions and move some classes
    from the examples to the library
  * Improve environment factory abstraction
  * Some bug fixes

Main changes in README:
  * Add high-level example and update procedural/low-level example
  * Improve language/wording
Commit 6e1ffe58e5 by Michael Panchenko, 2024-01-16 15:24:41 +01:00 (committed by GitHub)
37 changed files with 872 additions and 348 deletions

README.md (219 changed lines)

@ -10,20 +10,17 @@
> Tianshou no longer supports `gym`, and we recommend that you transition to
> [Gymnasium](http://github.com/Farama-Foundation/Gymnasium).
> If you absolutely have to use gym, you can try using [Shimmy](https://github.com/Farama-Foundation/Shimmy)
> (the compatibility layer), but tianshou provides no guarantees that things will work then.
> (the compatibility layer), but Tianshou provides no guarantees that things will work then.
> ⚠️️ **Current Status**: the tianshou master branch is currently under heavy development,
> moving towards more features, improved interfaces, more documentation, and better compatibility with
> other RL libraries. You can view the relevant issues in the corresponding
> ⚠️️ **Current Status**: the Tianshou master branch is currently under heavy development,
> moving towards more features, improved interfaces, and more documentation.
> You can view the relevant issues in the corresponding
> [milestone](https://github.com/thu-ml/tianshou/milestone/1)
> Stay tuned! (and expect breaking changes until the release is done)
> ⚠️️ **Installing PyTorch**: Because of a problem with pytorch packaging and poetry in
> current releases, the newest version of pytorch is not included in the tianshou dependencies.
> You can still install the newest pytorch with `pip` after Tianshou has been installed with `poetry`.
> [Here](https://github.com/python-poetry/poetry/issues/7902#issuecomment-1747400255) is a discussion between torch and poetry devs, who are trying to resolve it.
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike other reinforcement learning libraries, which are partly based on TensorFlow, have unfriendly APIs or are not optimized for speed, Tianshou provides a high-performance, modularized framework and user-friendly APIs for building deep reinforcement learning agents, enabling concise implementations without sacrificing flexibility.
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike several existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed modularized framework and pythonic API for building the deep reinforcement learning agent with the least number of lines of code. The supported interface algorithms currently include:
The set of supported algorithms includes the following:
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
@ -58,22 +55,28 @@
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)
- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf)
Here are Tianshou's other features:
Other noteworthy features:
- Elegant framework, using few lines of code in the core abstractions
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
- Support vectorized environment (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling)
- Support super-fast vectorized environment [EnvPool](https://github.com/sail-sg/envpool/) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#envpool-integration)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training)
- Elegant framework with dual APIs:
  * Tianshou's high-level API maximizes ease of use for application development while still retaining a high degree
    of flexibility.
  * The fundamental procedural API provides a maximum of flexibility for algorithm development without being
    overly verbose.
- State-of-the-art results in [MuJoCo benchmarks](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
- Support for vectorized environments (synchronous or asynchronous) for all algorithms (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) and the short construction sketch below this feature list)
- Support for super-fast vectorized environments based on [EnvPool](https://github.com/sail-sg/envpool/) for all algorithms (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#envpool-integration))
- Support for recurrent state representations in actor networks and critic networks (RNN-style training for POMDPs) (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training))
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
- Support multi-GPU training [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-gpu)
- Support for customized training processes (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process))
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are highly optimized thanks to numba's just-in-time compilation and vectorized numpy operations
- Support for multi-agent RL (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning))
- Support for logging based on both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/)
- Support for multi-GPU training (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-gpu))
- Comprehensive documentation, PEP8 code-style checking, type checking and thorough [tests](https://github.com/thu-ml/tianshou/actions)
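To make the vectorized-environment feature above concrete, here is a minimal construction sketch. It only uses classes that appear elsewhere in this document (`DummyVectorEnv`; `SubprocVectorEnv`/`ShmemVectorEnv` are the subprocess-based variants) and the gymnasium CartPole task used later in this README:

```python
import gymnasium as gym

import tianshou as ts

# Four CartPole instances behind a single vectorized interface (stepped in-process).
# ts.env.SubprocVectorEnv / ts.env.ShmemVectorEnv provide asynchronous/subprocess variants.
train_envs = ts.env.DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])
print(len(train_envs))  # -> 4
```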
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
In Chinese, Tianshou means divinely ordained, derived from the notion of a gift one is born with.
Tianshou is a reinforcement learning platform, and the nature of RL is not to learn from humans.
So taking "Tianshou" means that there is no teacher to learn from, but rather to learn by oneself through constant interaction with the environment.
“天授”意指上天所授,引申为与生具有的天赋。天授是强化学习平台,而强化学习算法并不是向人类学习的,所以取“天授”意思是没有老师来教,而是自己通过跟环境不断交互来进行学习。
@ -87,32 +90,32 @@ You can simply install Tianshou from PyPI with the following command:
$ pip install tianshou
```
If you use Anaconda or Miniconda, you can install Tianshou from conda-forge through the following command:
If you are using Anaconda or Miniconda, you can install Tianshou from conda-forge:
```bash
$ conda install tianshou -c conda-forge
```
You can also install with the newest version through GitHub:
Alternatively, you can install the latest source version through GitHub:
```bash
$ pip install git+https://github.com/thu-ml/tianshou.git@master --upgrade
```
After installation, open your python console and type
Finally, you may check the installation via your Python console as follows:
```python
import tianshou
print(tianshou.__version__)
```
If no error occurs, you have successfully installed Tianshou.
If no errors are reported, you have successfully installed Tianshou.
## Documentation
The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder.
Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
中文文档位于 [https://tianshou.readthedocs.io/zh/master/](https://tianshou.readthedocs.io/zh/master/)。
@ -166,35 +169,149 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
<sup>(1): it has continuous integration but the coverage rate is not available</sup>
### Reproducible and High Quality Result
### Reproducible, High-Quality Results
Tianshou has its tests. Different from other platforms, **the tests include the full agent training procedure for all of the implemented algorithms**. It would be failed once if it could not train an agent to perform well enough on limited epochs on toy scenarios. The tests secure the reproducibility of our platform. Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page for more detail.
Tianshou is rigorously tested. In contrast to other RL platforms, **our tests include the full agent training procedure for all of the implemented algorithms**. A test fails if any of the agents cannot achieve a consistent level of performance within a limited number of epochs.
Our tests thus ensure reproducibility.
Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page for more detail.
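As an illustration of what such a test amounts to, consider the following hedged sketch (`train_dqn_agent` is a hypothetical helper; the real tests, such as test/discrete/test_dqn.py, build the trainer explicitly as shown later in this README):

```python
def test_dqn_reaches_reward_threshold():
    # Run the full training loop on CartPole with a small epoch budget.
    result = train_dqn_agent(task="CartPole-v1", max_epoch=10)  # hypothetical helper
    # The test fails if the trained agent does not reach the task's reward threshold.
    assert result.best_reward >= 195
```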
The Atari/Mujoco benchmark results are under [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders. **Our Mujoco result can beat most of existing benchmarks.**
Atari and MuJoCo benchmark results can be found in the [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders respectively. **Our MuJoCo results reach or exceed the level of performance of most existing benchmarks.**
### Modularized Policy
### Policy Interface
We decouple all algorithms roughly into the following parts:
All algorithms implement the following, highly general API:
- `__init__`: initialize the policy;
- `forward`: to compute actions over given observations;
- `process_buffer`: process initial buffer, useful for some offline learning algorithms
- `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
- `learn`: to learn from a given batch data;
- `post_process_fn`: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
- `forward`: compute actions based on given observations;
- `process_buffer`: process the initial buffer, which is useful for some offline learning algorithms;
- `process_fn`: preprocess data from the replay buffer (since we have reformulated *all* algorithms to replay buffer-based algorithms);
- `learn`: learn from a given batch of data;
- `post_process_fn`: update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.
Within this API, we can interact with different policies conveniently.
The implementation of this API suffices for a new algorithm to be applicable within Tianshou,
making experimentation with new approaches particularly straightforward.
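Schematically, a new algorithm therefore takes roughly the following shape. This is a sketch only: the exact base-class signatures are documented in `BasePolicy`, and the method bodies are intentionally left out.

```python
from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy


class MyAlgorithmPolicy(BasePolicy):
    """Skeleton of the policy interface described above."""

    def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
        # map the observations in batch.obs to actions (returned as a Batch with an "act" entry)
        ...

    def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices) -> Batch:
        # preprocess sampled transitions, e.g. compute n-step or GAE returns
        return batch

    def learn(self, batch: Batch, **kwargs) -> dict:
        # perform a gradient update on the given batch and return scalars to be logged
        return {}
```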
## Quick Start
This is an example of Deep Q Network. You can also run the full script at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
Tianshou provides two API levels:
* the high-level interface, which provides ease of use for end users seeking to run deep reinforcement learning applications
* the procedural interface, which provides a maximum of control, especially for very advanced users and developers of reinforcement learning algorithms.
In the following, let us consider an example application using the *CartPole* gymnasium environment.
We shall apply the deep Q network (DQN) learning algorithm using both APIs.
### High-Level API
To get started, we need some imports.
```python
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import (
    EnvFactoryRegistered,
    VectorEnvType,
)
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.trainer import (
    EpochStopCallbackRewardThreshold,
    EpochTestCallbackDQNSetEps,
    EpochTrainCallbackDQNSetEps,
)
```
In the high-level API, the basis for an RL experiment is an `ExperimentBuilder`
with which we can build the experiment we then seek to run.
Since we want to use DQN, we use the specialization `DQNExperimentBuilder`.
The other imports serve to provide configuration options for our experiment.
The high-level API provides largely declarative semantics, i.e. the code is
almost exclusively concerned with configuration that controls what to do
(rather than how to do it).
```python
experiment = (
    DQNExperimentBuilder(
        EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
        ExperimentConfig(
            persistence_enabled=False,
            watch=True,
            watch_render=1 / 35,
            watch_num_episodes=100,
        ),
        SamplingConfig(
            num_epochs=10,
            step_per_epoch=10000,
            batch_size=64,
            num_train_envs=10,
            num_test_envs=100,
            buffer_size=20000,
            step_per_collect=10,
            update_per_step=1 / 10,
        ),
    )
    .with_dqn_params(
        DQNParams(
            lr=1e-3,
            discount_factor=0.9,
            estimation_step=3,
            target_update_freq=320,
        ),
    )
    .with_model_factory_default(hidden_sizes=(64, 64))
    .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
    .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
    .build()
)
experiment.run()
```
The experiment builder takes three arguments:
* the environment factory for the creation of environments. In this case,
we use an existing factory implementation for gymnasium environments.
* the experiment configuration, which controls persistence and the overall
experiment flow. In this case, we have configured that we want to observe
the agent's behavior after it is trained (`watch=True`) for a number of
episodes (`watch_num_episodes=100`). We have disabled persistence, because
we do not want to save training logs, the agent or its configuration for
future use.
* the sampling configuration, which controls fundamental training parameters,
such as the total number of epochs we run the experiment for (`num_epochs=10`)
and the number of environment steps each epoch shall consist of
(`step_per_epoch=10000`).
Every epoch consists of a series of data collection (rollout) steps and
training steps.
The parameter `step_per_collect` controls the amount of data that is
collected in each collection step; after each collection step, we
perform a training step, applying a gradient-based update based on a sample
of data (`batch_size=64`) taken from the buffer of data that has been
collected. For further details, see the documentation of `SamplingConfig`
and the short arithmetic sketch below this list.
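To make these numbers concrete, here is the rough bookkeeping implied by the configuration above. This is a back-of-the-envelope sketch, not library code, and it assumes that `update_per_step` counts gradient updates per collected environment step:

```python
step_per_epoch, step_per_collect = 10_000, 10
update_per_step, batch_size = 1 / 10, 64

collect_steps_per_epoch = step_per_epoch // step_per_collect  # 1000 collection steps per epoch
updates_per_collect = step_per_collect * update_per_step      # 1 gradient update after each collection step
updates_per_epoch = int(step_per_epoch * update_per_step)     # 1000 updates per epoch, each on a batch of 64
```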
We then proceed to configure some of the parameters of the DQN algorithm itself
and of the neural network model we want to use.
A DQN-specific detail is the use of callbacks to configure the algorithm's
epsilon parameter for exploration. We want to use random exploration during rollouts
(train callback), but we do not want it when evaluating the agent's performance in the test
environments (test callback).
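The callbacks used here simply set a constant epsilon. Writing your own callback is equally straightforward; the sketch below mirrors the linear-decay callback shipped with the library (`EpochTrainCallbackDQNEpsLinearDecay`) and assumes the `callback(epoch, env_step, context)` hook that the Atari callbacks elsewhere in this commit implement:

```python
from tianshou.highlevel.trainer import EpochTrainCallback, TrainingContext
from tianshou.policy import DQNPolicy


class EpochTrainCallbackEpsAnneal(EpochTrainCallback):
    """Illustrative callback: linearly anneal epsilon over the first `decay_steps` env steps."""

    def __init__(self, eps_start: float, eps_end: float, decay_steps: int = 100_000):
        self.eps_start, self.eps_end, self.decay_steps = eps_start, eps_end, decay_steps

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy: DQNPolicy = context.policy
        frac = min(env_step / self.decay_steps, 1.0)
        policy.set_eps(self.eps_start + frac * (self.eps_end - self.eps_start))
```

Passing an instance of such a callback to `.with_epoch_train_callback(...)` is all that is needed to activate it.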
Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py).
Here's a run (with the training time cut short):
<p align="center" style="text-align:center">
<img src="docs/_static/images/discrete_dqn_hl.gif">
</p>
### Procedural API
Let us now consider an analogous example in the procedural API.
Find the full script from which the snippets below were derived at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
First, import some relevant packages:
```python
import gymnasium as gym
import torch, numpy as np, torch.nn as nn
import torch
from torch.utils.tensorboard import SummaryWriter
import tianshou as ts
```
@ -202,7 +319,7 @@ import tianshou as ts
Define some hyper-parameters:
```python
task = 'CartPole-v0'
task = 'CartPole-v1'
lr, epoch, batch_size = 1e-3, 10, 64
train_num, test_num = 10, 100
gamma, n_step, target_freq = 0.9, 3, 320
@ -227,7 +344,7 @@ Define the network:
from tianshou.utils.net.common import Net
# you can define other net by following the API:
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
env = gym.make(task)
env = gym.make(task, render_mode="human")
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
@ -267,7 +384,7 @@ result = ts.trainer.OffpolicyTrainer(
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
logger=logger,
).run()
print(f'Finished training! Use {result["duration"]}')
print(f"Finished training in {result.timing.total_time} seconds")
```
Save / load the trained policy (it's exactly the same as PyTorch `nn.module`):
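For reference, this is the standard PyTorch state-dict pattern (the file name here is illustrative):

```python
# Save and restore the trained policy like any other nn.Module.
torch.save(policy.state_dict(), "dqn.pth")
policy.load_state_dict(torch.load("dqn.pth"))
```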
@ -294,19 +411,11 @@ $ tensorboard --logdir log/dqn
You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It only uses **3** seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device)
```bash
$ python3 test/discrete/test_pg.py --seed 0 --render 0.03
```
<div align="center">
<img src="https://github.com/thu-ml/tianshou/raw/master/docs/_static/images/testpg.gif"></a>
</div>
## Contributing
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).
Tianshou is still under development.
Further algorithms and features are continuously being added, and we always welcome contributions to help make Tianshou better.
If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).
## Citing Tianshou
@ -325,7 +434,7 @@ If you find Tianshou useful, please cite it in your publications.
}
```
## Acknowledgment
## Acknowledgments
Tianshou is supported by [appliedAI Institute for Europe](https://www.appliedai-institute.de/en/),
which is committed to providing long-term support and development.

docs/_static/images/discrete_dqn_hl.gif (new binary file, 600 KiB)


@ -1,33 +0,0 @@
from tianshou.highlevel.trainer import (
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainingContext,
)
from tianshou.policy import DQNPolicy
class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
def __init__(self, eps_test: float):
self.eps_test = eps_test
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy: DQNPolicy = context.policy
policy.set_eps(self.eps_test)
class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
def __init__(self, eps_train: float, eps_train_final: float):
self.eps_train = eps_train
self.eps_train_final = eps_train_final
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy: DQNPolicy = context.policy
logger = context.logger
# nature DQN setting, linear decay in the first 1M steps
if env_step <= 1e6:
eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
else:
eps = self.eps_train_final
policy.set_eps(eps)
if env_step % 1000 == 0:
logger.write("train/env_step", env_step, {"train/eps": eps})


@ -2,15 +2,11 @@
import os
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
DQNExperimentBuilder,
@ -20,6 +16,10 @@ from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.highlevel.trainer import (
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@ -27,7 +27,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: int = 0,
scale_obs: bool = False,
eps_test: float = 0.005,
eps_train: float = 1.0,
eps_train_final: float = 0.05,
@ -79,11 +79,11 @@ def main(
),
)
.with_model_factory(IntermediateModuleFactoryAtariDQN())
.with_trainer_epoch_callback_train(
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
.with_epoch_train_callback(
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(


@ -3,20 +3,20 @@
import os
from collections.abc import Sequence
from examples.atari.atari_callbacks import (
TestEpochCallbackDQNSetEps,
TrainEpochCallbackNatureDQNEpsLinearDecay,
)
from examples.atari.atari_network import (
IntermediateModuleFactoryAtariDQN,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
ExperimentConfig,
IQNExperimentBuilder,
)
from tianshou.highlevel.params.policy_params import IQNParams
from tianshou.highlevel.trainer import (
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNEpsLinearDecay,
)
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: int = 0,
scale_obs: bool = False,
eps_test: float = 0.005,
eps_train: float = 1.0,
eps_train_final: float = 0.05,
@ -83,11 +83,11 @@ def main(
),
)
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
.with_trainer_epoch_callback_train(
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
.with_epoch_train_callback(
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
)
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
.build()
)
experiment.run(log_name)


@ -23,19 +23,27 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
return layer
def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
class scaled_module(module):
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any | None = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
if info is None:
info = {}
return super().forward(obs / denom, state, info)
class ScaledObsInputModule(torch.nn.Module):
def __init__(self, module: torch.nn.Module, denom: float = 255.0):
super().__init__()
self.module = module
self.denom = denom
# This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
self.output_dim = module.output_dim
return scaled_module
def forward(
self,
obs: np.ndarray | torch.Tensor,
state: Any | None = None,
info: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, Any]:
if info is None:
info = {}
return self.module.forward(obs / self.denom, state, info)
def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
return ScaledObsInputModule(module, denom=denom)
class DQN(nn.Module):
@ -238,8 +246,7 @@ class ActorFactoryAtariDQN(ActorFactory):
self.features_only = features_only
def create_module(self, envs: Environments, device: TDevice) -> Actor:
net_cls = scale_obs(DQN) if self.scale_obs else DQN
net = net_cls(
net = DQN(
*envs.get_observation_shape(),
envs.get_action_shape(),
device=device,
@ -247,6 +254,8 @@ class ActorFactoryAtariDQN(ActorFactory):
output_dim=self.hidden_size,
layer_init=layer_init,
)
if self.scale_obs:
net = scale_obs(net)
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)


@ -109,8 +109,7 @@ def test_ppo(args=get_args()):
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# define model
net_cls = scale_obs(DQN) if args.scale_obs else DQN
net = net_cls(
net = DQN(
*args.state_shape,
args.action_shape,
device=args.device,
@ -118,6 +117,8 @@ def test_ppo(args=get_args()):
output_dim=args.hidden_size,
layer_init=layer_init,
)
if args.scale_obs:
net = scale_obs(net)
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
critic = Critic(net, device=args.device)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)


@ -7,7 +7,7 @@ from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
ExperimentConfig,
@ -95,7 +95,7 @@ def main(
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
.with_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(


@ -6,7 +6,7 @@ from examples.atari.atari_network import (
ActorFactoryAtariDQN,
IntermediateModuleFactoryAtariDQNFeatures,
)
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
DiscreteSACExperimentBuilder,
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "PongNoFrameskip-v4",
scale_obs: int = 0,
scale_obs: bool = False,
buffer_size: int = 100000,
actor_lr: float = 1e-5,
critic_lr: float = 1e-5,
@ -82,7 +82,7 @@ def main(
)
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
.with_common_critic_factory_use_actor()
.with_trainer_stop_callback(AtariStopCallback(task))
.with_epoch_stop_callback(AtariEpochStopCallback(task))
)
if icm_lr_scale > 0:
builder.with_policy_wrapper_factory(


@ -1,21 +1,29 @@
# Borrow a lot from openai baselines:
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
import logging
import warnings
from collections import deque
import cv2
import gymnasium as gym
import numpy as np
from gymnasium import Env
from tianshou.env import ShmemVectorEnv
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
from tianshou.highlevel.env import (
EnvFactoryRegistered,
EnvMode,
EnvPoolFactory,
VectorEnvType,
)
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
envpool_is_available = True
try:
import envpool
except ImportError:
envpool_is_available = False
envpool = None
log = logging.getLogger(__name__)
def _parse_reset_result(reset_result):
@ -282,7 +290,7 @@ class FrameStack(gym.Wrapper):
def wrap_deepmind(
env_id,
env: Env,
episode_life=True,
clip_rewards=True,
frame_stack=4,
@ -293,7 +301,7 @@ def wrap_deepmind(
The observation is channel-first: (c, h, w) instead of (h, w, c).
:param str env_id: the atari environment id.
:param env: the Atari environment to wrap.
:param bool episode_life: wrap the episode life wrapper.
:param bool clip_rewards: wrap the reward clipping wrapper.
:param int frame_stack: wrap the frame stacking wrapper.
@ -301,8 +309,6 @@ def wrap_deepmind(
:param bool warp_frame: wrap the grayscale + resize observation wrapper.
:return: the wrapped atari environment.
"""
assert "NoFrameskip" in env_id
env = gym.make(env_id)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if episode_life:
@ -320,79 +326,91 @@ def wrap_deepmind(
return env
def make_atari_env(task, seed, training_num, test_num, **kwargs):
def make_atari_env(
task,
seed,
training_num,
test_num,
scale: int | bool = False,
frame_stack: int = 4,
):
"""Wrapper function for Atari env.
If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
:return: a tuple of (single env, training envs, test envs).
"""
if envpool is not None:
if kwargs.get("scale", 0):
warnings.warn(
"EnvPool does not include ScaledFloatFrame wrapper, "
"please set `x = x / 255.0` inside CNN network's forward function.",
)
# parameters convertion
train_envs = env = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=training_num,
seed=seed,
episodic_life=True,
reward_clip=True,
stack_num=kwargs.get("frame_stack", 4),
)
test_envs = envpool.make_gymnasium(
task.replace("NoFrameskip-v4", "-v5"),
num_envs=test_num,
seed=seed,
episodic_life=False,
reward_clip=False,
stack_num=kwargs.get("frame_stack", 4),
)
else:
warnings.warn(
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
)
env = wrap_deepmind(task, **kwargs)
train_envs = ShmemVectorEnv(
[
lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
for _ in range(training_num)
],
)
test_envs = ShmemVectorEnv(
[
lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
for _ in range(test_num)
],
)
env.seed(seed)
train_envs.seed(seed)
test_envs.seed(seed)
return env, train_envs, test_envs
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
envs = env_factory.create_envs(training_num, test_num)
return envs.env, envs.train_envs, envs.test_envs
class AtariEnvFactory(EnvFactory):
def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
self.task = task
self.seed = seed
class AtariEnvFactory(EnvFactoryRegistered):
def __init__(
self,
task: str,
seed: int,
frame_stack: int,
scale: bool = False,
use_envpool_if_available: bool = True,
):
assert "NoFrameskip" in task
self.frame_stack = frame_stack
self.scale = scale
def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
env, train_envs, test_envs = make_atari_env(
task=self.task,
seed=self.seed,
training_num=num_training_envs,
test_num=num_test_envs,
scale=self.scale,
frame_stack=self.frame_stack,
envpool_factory = None
if use_envpool_if_available:
if envpool_is_available:
envpool_factory = self.EnvPoolFactory(self)
log.info("Using envpool, because it available")
else:
log.info("Not using envpool, because it is not available")
super().__init__(
task=task,
seed=seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=envpool_factory,
)
return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
def create_env(self, mode: EnvMode) -> Env:
env = super().create_env(mode)
is_train = mode == EnvMode.TRAIN
return wrap_deepmind(
env,
episode_life=is_train,
clip_rewards=is_train,
frame_stack=self.frame_stack,
scale=self.scale,
)
class EnvPoolFactory(EnvPoolFactory):
"""Atari-specific envpool creation.
Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`,
it sets the creation keyword arguments accordingly.
"""
def __init__(self, parent: "AtariEnvFactory"):
self.parent = parent
if self.parent.scale:
warnings.warn(
"EnvPool does not include ScaledFloatFrame wrapper, "
"please compensate by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)",
)
def _transform_task(self, task: str) -> str:
task = super()._transform_task(task)
# TODO: Maybe warn user, explain why this is needed
return task.replace("NoFrameskip-v4", "-v5")
def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
kwargs = super()._transform_kwargs(kwargs, mode)
is_train = mode == EnvMode.TRAIN
kwargs["reward_clip"] = is_train
kwargs["episodic_life"] = is_train
kwargs["stack_num"] = self.parent.frame_stack
return kwargs
class AtariStopCallback(TrainerStopCallback):
class AtariEpochStopCallback(EpochStopCallback):
def __init__(self, task: str):
self.task = task


@ -0,0 +1,78 @@
import gymnasium as gym
import torch
from torch.utils.tensorboard import SummaryWriter

import tianshou as ts


def main():
    task = "CartPole-v1"
    lr, epoch, batch_size = 1e-3, 10, 64
    train_num, test_num = 10, 100
    gamma, n_step, target_freq = 0.9, 3, 320
    buffer_size = 20000
    eps_train, eps_test = 0.1, 0.05
    step_per_epoch, step_per_collect = 10000, 10

    logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn"))  # TensorBoard is supported!
    # For other loggers: https://tianshou.readthedocs.io/en/master/tutorials/logger.html

    # you can also try with SubprocVectorEnv
    train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
    test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])

    from tianshou.utils.net.common import Net

    # you can define other net by following the API:
    # https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
    env = gym.make(task, render_mode="human")
    state_shape = env.observation_space.shape or env.observation_space.n
    action_shape = env.action_space.shape or env.action_space.n
    net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
    optim = torch.optim.Adam(net.parameters(), lr=lr)

    policy = ts.policy.DQNPolicy(
        model=net,
        optim=optim,
        discount_factor=gamma,
        action_space=env.action_space,
        estimation_step=n_step,
        target_update_freq=target_freq,
    )

    train_collector = ts.data.Collector(
        policy,
        train_envs,
        ts.data.VectorReplayBuffer(buffer_size, train_num),
        exploration_noise=True,
    )
    test_collector = ts.data.Collector(
        policy,
        test_envs,
        exploration_noise=True,
    )  # because DQN uses epsilon-greedy method

    result = ts.trainer.OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=epoch,
        step_per_epoch=step_per_epoch,
        step_per_collect=step_per_collect,
        episode_per_test=test_num,
        batch_size=batch_size,
        update_per_step=1 / step_per_collect,
        train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
        test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
        stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
        logger=logger,
    ).run()
    print(f"Finished training in {result.timing.total_time} seconds")

    # watch performance
    policy.eval()
    policy.set_eps(eps_test)
    collector = ts.data.Collector(policy, env, exploration_noise=True)
    collector.collect(n_episode=100, render=1 / 35)


if __name__ == "__main__":
    main()


@ -0,0 +1,55 @@
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import (
    EnvFactoryRegistered,
    VectorEnvType,
)
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
from tianshou.highlevel.params.policy_params import DQNParams
from tianshou.highlevel.trainer import (
    EpochStopCallbackRewardThreshold,
    EpochTestCallbackDQNSetEps,
    EpochTrainCallbackDQNSetEps,
)
from tianshou.utils.logging import run_main


def main():
    experiment = (
        DQNExperimentBuilder(
            EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
            ExperimentConfig(
                persistence_enabled=False,
                watch=True,
                watch_render=1 / 35,
                watch_num_episodes=100,
            ),
            SamplingConfig(
                num_epochs=10,
                step_per_epoch=10000,
                batch_size=64,
                num_train_envs=10,
                num_test_envs=100,
                buffer_size=20000,
                step_per_collect=10,
                update_per_step=1 / 10,
            ),
        )
        .with_dqn_params(
            DQNParams(
                lr=1e-3,
                discount_factor=0.9,
                estimation_step=3,
                target_update_freq=320,
            ),
        )
        .with_model_factory_default(hidden_sizes=(64, 64))
        .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
        .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
        .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
        .build()
    )
    experiment.run()


if __name__ == "__main__":
    run_main(main)


@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=4096)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])


@ -21,7 +21,7 @@ from tianshou.utils.net.continuous import Actor, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=1000000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])


@ -17,7 +17,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "Ant-v3",
task: str = "Ant-v4",
buffer_size: int = 1000000,
hidden_sizes: Sequence[int] = (256, 256),
actor_lr: float = 1e-3,


@ -1,17 +1,21 @@
import logging
import pickle
import warnings
import gymnasium as gym
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
from tianshou.env import VectorEnvNormObs
from tianshou.highlevel.env import (
ContinuousEnvironments,
EnvFactoryRegistered,
EnvPoolFactory,
VectorEnvType,
)
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
from tianshou.highlevel.world import World
envpool_is_available = True
try:
import envpool
except ImportError:
envpool_is_available = False
envpool = None
log = logging.getLogger(__name__)
@ -24,25 +28,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
:return: a tuple of (single env, training envs, test envs).
"""
if envpool is not None:
train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
else:
warnings.warn(
"Recommend using envpool (pip install envpool) "
"to run Mujoco environments more efficiently.",
)
env = gym.make(task)
train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
train_envs.seed(seed)
test_envs.seed(seed)
if obs_norm:
# obs norm wrapper
train_envs = VectorEnvNormObs(train_envs)
test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
test_envs.set_obs_rms(train_envs.get_obs_rms())
return env, train_envs, test_envs
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
num_train_envs,
num_test_envs,
)
return envs.env, envs.train_envs, envs.test_envs
class MujocoEnvObsRmsPersistence(Persistence):
@ -68,21 +58,25 @@ class MujocoEnvObsRmsPersistence(Persistence):
world.envs.test_envs.set_obs_rms(obs_rms)
class MujocoEnvFactory(EnvFactory):
class MujocoEnvFactory(EnvFactoryRegistered):
def __init__(self, task: str, seed: int, obs_norm=True):
self.task = task
self.seed = seed
super().__init__(
task=task,
seed=seed,
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
)
self.obs_norm = obs_norm
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
env, train_envs, test_envs = make_mujoco_env(
task=self.task,
seed=self.seed,
num_train_envs=num_training_envs,
num_test_envs=num_test_envs,
obs_norm=self.obs_norm,
)
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
envs = super().create_envs(num_training_envs, num_test_envs)
assert isinstance(envs, ContinuousEnvironments)
# obs norm wrapper
if self.obs_norm:
envs.train_envs = VectorEnvNormObs(envs.train_envs)
envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
envs.set_persistence(MujocoEnvObsRmsPersistence())
return envs


@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=4096)
parser.add_argument(


@ -23,7 +23,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "Ant-v3",
task: str = "Ant-v4",
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 1e-3,


@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=4096)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])


@ -20,7 +20,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=1000000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])


@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=4096)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])


@ -20,7 +20,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "Ant-v3",
task: str = "Ant-v4",
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 1e-3,


@ -20,7 +20,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=1000000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])


@ -21,7 +21,7 @@ from tianshou.utils.net.continuous import Actor, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=1000000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])


@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="Ant-v3")
parser.add_argument("--task", type=str, default="Ant-v4")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=4096)
parser.add_argument(


@ -23,7 +23,7 @@ from tianshou.utils.logging import datetime_tag
def main(
experiment_config: ExperimentConfig,
task: str = "Ant-v3",
task: str = "Ant-v4",
buffer_size: int = 4096,
hidden_sizes: Sequence[int] = (64, 64),
lr: float = 1e-3,


@ -140,7 +140,9 @@ ignore = [
# Logging statement uses f-string warning
"G004",
# Unnecessary `elif` after `return` statement
"RET505"
"RET505",
"D106", # undocumented public nested class
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
]
unfixable = [
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all


@ -1,27 +1,14 @@
import gymnasium as gym
from tianshou.env import DummyVectorEnv
from tianshou.highlevel.env import (
ContinuousEnvironments,
DiscreteEnvironments,
EnvFactory,
Environments,
EnvFactoryRegistered,
VectorEnvType,
)
class DiscreteTestEnvFactory(EnvFactory):
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
task = "CartPole-v0"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
return DiscreteEnvironments(env, train_envs, test_envs)
class DiscreteTestEnvFactory(EnvFactoryRegistered):
def __init__(self):
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
class ContinuousTestEnvFactory(EnvFactory):
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
task = "Pendulum-v1"
env = gym.make(task)
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
return ContinuousEnvironments(env, train_envs, test_envs)
class ContinuousTestEnvFactory(EnvFactoryRegistered):
def __init__(self):
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)


@ -97,7 +97,6 @@ class Collector:
) -> None:
super().__init__()
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
self.env = DummyVectorEnv([lambda: env])
else:
self.env = env # type: ignore


@ -163,17 +163,19 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
callbacks = self.trainer_callbacks
context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context)
if callbacks.epoch_callback_train
callbacks.epoch_train_callback.get_trainer_fn(context)
if callbacks.epoch_train_callback
else None
)
test_fn = (
callbacks.epoch_callback_test.get_trainer_fn(context)
if callbacks.epoch_callback_test
callbacks.epoch_test_callback.get_trainer_fn(context)
if callbacks.epoch_test_callback
else None
)
stop_fn = (
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
callbacks.epoch_stop_callback.get_trainer_fn(context)
if callbacks.epoch_stop_callback
else None
)
return OnpolicyTrainer(
policy=world.policy,
@ -205,17 +207,19 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
callbacks = self.trainer_callbacks
context = TrainingContext(world.policy, world.envs, world.logger)
train_fn = (
callbacks.epoch_callback_train.get_trainer_fn(context)
if callbacks.epoch_callback_train
callbacks.epoch_train_callback.get_trainer_fn(context)
if callbacks.epoch_train_callback
else None
)
test_fn = (
callbacks.epoch_callback_test.get_trainer_fn(context)
if callbacks.epoch_callback_test
callbacks.epoch_test_callback.get_trainer_fn(context)
if callbacks.epoch_test_callback
else None
)
stop_fn = (
callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
callbacks.epoch_stop_callback.get_trainer_fn(context)
if callbacks.epoch_stop_callback
else None
)
return OffpolicyTrainer(
policy=world.policy,


@ -1,9 +1,12 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any, TypeAlias, cast
import gymnasium as gym
import gymnasium.spaces
from gymnasium import Env
from tianshou.env import (
BaseVectorEnv,
@ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin
TObservationShape: TypeAlias = int | Sequence[int]
log = logging.getLogger(__name__)
class EnvType(Enum):
"""Enumeration of environment types."""
@ -39,6 +44,23 @@ class EnvType(Enum):
if not self.is_discrete():
raise AssertionError(f"{requiring_entity} requires discrete environments")
@staticmethod
def from_env(env: Env) -> "EnvType":
if isinstance(env.action_space, gymnasium.spaces.Discrete):
return EnvType.DISCRETE
elif isinstance(env.action_space, gymnasium.spaces.Box):
return EnvType.CONTINUOUS
else:
raise Exception(f"Unsupported environment type with action space {env.action_space}")
class EnvMode(Enum):
"""Indicates the purpose for which an environment is created."""
TRAIN = "train"
TEST = "test"
WATCH = "watch"
class VectorEnvType(Enum):
DUMMY = "dummy"
@ -65,7 +87,7 @@ class VectorEnvType(Enum):
class Environments(ToStringMixin, ABC):
"""Represents (vectorized) environments."""
"""Represents (vectorized) environments for a learning process."""
def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
self.env = env
@ -75,12 +97,11 @@ class Environments(ToStringMixin, ABC):
@staticmethod
def from_factory_and_type(
factory_fn: Callable[[], gym.Env],
factory_fn: Callable[[EnvMode], gym.Env],
env_type: EnvType,
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "Environments":
"""Creates a suitable subtype instance from a factory function that creates a single instance and the type of environment (continuous/discrete).
@ -89,15 +110,11 @@ class Environments(ToStringMixin, ABC):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
if test_factory_fn is None:
test_factory_fn = factory_fn
train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
env = factory_fn()
train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
env = factory_fn(EnvMode.TRAIN)
match env_type:
case EnvType.CONTINUOUS:
return ContinuousEnvironments(env, train_envs, test_envs)
@ -153,11 +170,10 @@ class ContinuousEnvironments(Environments):
@staticmethod
def from_factory(
factory_fn: Callable[[], gym.Env],
factory_fn: Callable[[EnvMode], gym.Env],
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "ContinuousEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@ -165,8 +181,6 @@ class ContinuousEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
return cast(
@ -177,7 +191,6 @@ class ContinuousEnvironments(Environments):
venv_type,
num_training_envs,
num_test_envs,
test_factory_fn=test_factory_fn,
),
)
@ -222,11 +235,10 @@ class DiscreteEnvironments(Environments):
@staticmethod
def from_factory(
factory_fn: Callable[[], gym.Env],
factory_fn: Callable[[EnvMode], gym.Env],
venv_type: VectorEnvType,
num_training_envs: int,
num_test_envs: int,
test_factory_fn: Callable[[], gym.Env] | None = None,
) -> "DiscreteEnvironments":
"""Creates an instance from a factory function that creates a single instance.
@ -234,19 +246,16 @@ class DiscreteEnvironments(Environments):
:param venv_type: the vector environment type to use for parallelization
:param num_training_envs: the number of training environments to create
:param num_test_envs: the number of test environments to create
:param test_factory_fn: the factory to use for the creation of test environment instances;
if None, use `factory_fn` for all environments (train and test)
:return: the instance
"""
return cast(
DiscreteEnvironments,
Environments.from_factory_and_type(
factory_fn,
EnvType.CONTINUOUS,
EnvType.DISCRETE,
venv_type,
num_training_envs,
num_test_envs,
test_factory_fn=test_factory_fn,
),
)
@ -260,7 +269,153 @@ class DiscreteEnvironments(Environments):
return EnvType.DISCRETE
class EnvPoolFactory:
"""A factory for the creation of envpool-based vectorized environments, which can be used in conjunction
with :class:`EnvFactoryRegistered`.
"""
def _transform_task(self, task: str) -> str:
return task
def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
"""Transforms gymnasium keyword arguments to be envpool-compatible.
:param kwargs: keyword arguments that would normally be passed to `gymnasium.make`.
:param mode: the environment mode
:return: the transformed keyword arguments
"""
kwargs = dict(kwargs)
if "render_mode" in kwargs:
del kwargs["render_mode"]
return kwargs
def create_venv(
self,
task: str,
num_envs: int,
mode: EnvMode,
seed: int,
kwargs: dict,
) -> BaseVectorEnv:
import envpool
envpool_task = self._transform_task(task)
envpool_kwargs = self._transform_kwargs(kwargs, mode)
return envpool.make_gymnasium(
envpool_task,
num_envs=num_envs,
seed=seed,
**envpool_kwargs,
)
class EnvFactory(ToStringMixin, ABC):
"""Main interface for the creation of environments (in various forms)."""
def __init__(self, venv_type: VectorEnvType):
""":param venv_type: the type of vectorized environment to use"""
self.venv_type = venv_type
@abstractmethod
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
def create_env(self, mode: EnvMode) -> Env:
pass
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
"""Create vectorized environments.
:param num_envs: the number of environments
:param mode: the mode for which to create
:return: the vectorized environments
"""
return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
"""Create environments for learning.
:param num_training_envs: the number of training environments
:param num_test_envs: the number of test environments
:return: the environments
"""
env = self.create_env(EnvMode.TRAIN)
train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
match EnvType.from_env(env):
case EnvType.DISCRETE:
return DiscreteEnvironments(env, train_envs, test_envs)
case EnvType.CONTINUOUS:
return ContinuousEnvironments(env, train_envs, test_envs)
case _:
raise ValueError(f"Unsupported environment type for environment {env}")
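# Hedged sketch (not part of this commit): a minimal custom EnvFactory. Only `create_env`
# needs to be implemented; `create_venv` and `create_envs` are inherited and take care of
# vectorization and of the discrete/continuous dispatch above. Names and counts are
# illustrative.
class PendulumEnvFactory(EnvFactory):
    def __init__(self) -> None:
        super().__init__(VectorEnvType.DUMMY)

    def create_env(self, mode: EnvMode) -> Env:
        render_mode = "human" if mode == EnvMode.WATCH else None
        return gymnasium.make("Pendulum-v1", render_mode=render_mode)


envs = PendulumEnvFactory().create_envs(num_training_envs=8, num_test_envs=2)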
class EnvFactoryRegistered(EnvFactory):
"""Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make`
(or via `envpool.make_gymnasium`).
"""
def __init__(
self,
*,
task: str,
seed: int,
venv_type: VectorEnvType,
envpool_factory: EnvPoolFactory | None = None,
render_mode_train: str | None = None,
render_mode_test: str | None = None,
render_mode_watch: str = "human",
**make_kwargs: Any,
):
""":param task: the gymnasium task/environment identifier
:param seed: the random seed
:param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified)
:param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed.
:param render_mode_train: the render mode to use for training environments
:param render_mode_test: the render mode to use for test environments
:param render_mode_watch: the render mode to use for environments that are used to watch agent performance
:param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`.
If envpool is used, the gymnasium parameters will be appropriately translated for use with
`envpool.make_gymnasium`.
"""
super().__init__(venv_type)
self.task = task
self.envpool_factory = envpool_factory
self.seed = seed
self.render_modes = {
EnvMode.TRAIN: render_mode_train,
EnvMode.TEST: render_mode_test,
EnvMode.WATCH: render_mode_watch,
}
self.make_kwargs = make_kwargs
def _create_kwargs(self, mode: EnvMode) -> dict:
"""Adapts the keyword arguments for the given mode.
:param mode: the mode
:return: adapted keyword arguments
"""
kwargs = dict(self.make_kwargs)
kwargs["render_mode"] = self.render_modes.get(mode)
return kwargs
def create_env(self, mode: EnvMode) -> Env:
"""Creates a single environment for the given mode.
:param mode: the mode
:return: an environment
"""
kwargs = self._create_kwargs(mode)
return gymnasium.make(self.task, **kwargs)
def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
if self.envpool_factory is not None:
return self.envpool_factory.create_venv(
self.task,
num_envs,
mode,
self.seed,
self._create_kwargs(mode),
)
else:
venv = super().create_venv(num_envs, mode)
venv.seed(self.seed)
return venv
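# Hedged usage sketch (not part of this commit): creating environments for a registered
# gymnasium task. Task name, seed and environment counts are illustrative assumptions.
env_factory = EnvFactoryRegistered(
    task="CartPole-v1",
    seed=42,
    venv_type=VectorEnvType.DUMMY,
    # envpool_factory=EnvPoolFactory(),  # optionally vectorize via envpool instead
)
envs = env_factory.create_envs(num_training_envs=10, num_test_envs=5)
watch_env = env_factory.create_env(EnvMode.WATCH)  # rendered with render_mode="human"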

View File

@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
TRPOAgentFactory,
)
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.env import EnvFactory, EnvMode
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import (
ActorFactory,
@ -70,10 +70,10 @@ from tianshou.highlevel.persistence import (
PolicyPersistence,
)
from tianshou.highlevel.trainer import (
EpochStopCallback,
EpochTestCallback,
EpochTrainCallback,
TrainerCallbacks,
TrainerEpochCallbackTest,
TrainerEpochCallbackTrain,
TrainerStopCallback,
)
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy
@ -99,7 +99,7 @@ class ExperimentConfig:
"""Whether to perform training"""
watch: bool = True
"""Whether to watch agent performance (after training)"""
watch_num_episodes = 10
watch_num_episodes: int = 10
"""Number of episodes for which to watch performance (if `watch` is enabled)"""
watch_render: float = 0.0
"""Milliseconds between rendered frames when watching agent performance (if `watch` is enabled)"""
@ -293,7 +293,7 @@ class Experiment(ToStringMixin):
self._watch_agent(
self.config.watch_num_episodes,
policy,
test_collector,
self.env_factory,
self.config.watch_render,
)
@ -303,15 +303,18 @@ class Experiment(ToStringMixin):
def _watch_agent(
num_episodes: int,
policy: BasePolicy,
test_collector: Collector,
env_factory: EnvFactory,
render: float,
) -> None:
policy.eval()
test_collector.reset()
result = test_collector.collect(n_episode=num_episodes, render=render)
env = env_factory.create_env(EnvMode.WATCH)
collector = Collector(policy, env)
result = collector.collect(n_episode=num_episodes, render=render)
assert result.returns_stat is not None # for mypy
assert result.lens_stat is not None # for mypy
print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
log.info(
f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}",
)
class ExperimentBuilder:
@ -380,25 +383,25 @@ class ExperimentBuilder:
self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
return self
def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
"""Allows to define a callback function which is called at the beginning of every epoch during training.
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.epoch_callback_train = callback
self._trainer_callbacks.epoch_train_callback = callback
return self
def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self:
"""Allows to define a callback function which is called at the beginning of testing in each epoch.
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.epoch_callback_test = callback
self._trainer_callbacks.epoch_test_callback = callback
return self
def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self:
"""Allows to define a callback that decides whether training shall stop early.
The callback receives the undiscounted returns of the testing result.
@ -406,7 +409,7 @@ class ExperimentBuilder:
:param callback: the callback
:return: the builder
"""
self._trainer_callbacks.stop_callback = callback
self._trainer_callbacks.epoch_stop_callback = callback
return self
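# Hedged sketch (not part of this commit): a custom stop callback registered through the
# renamed builder method. `builder` is assumed to be a previously constructed
# ExperimentBuilder subclass instance; the threshold value is illustrative.
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext


class StopOnMeanReward(EpochStopCallback):
    def __init__(self, threshold: float):
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        return mean_rewards >= self.threshold


builder = builder.with_epoch_stop_callback(StopOnMeanReward(195.0))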
@abstractmethod
@ -903,7 +906,7 @@ class DQNExperimentBuilder(
super().__init__(env_factory, experiment_config, sampling_config)
self._params: DQNParams = DQNParams()
self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
)
def with_dqn_params(self, params: DQNParams) -> Self:
@ -911,9 +914,34 @@ class DQNExperimentBuilder(
return self
def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
""":param module_factory: factory for a module which maps environment observations to a vector of Q-values (one for each action)
:return: the builder
"""
self._model_factory = module_factory
return self
def with_model_factory_default(
self,
hidden_sizes: Sequence[int],
hidden_activation: ModuleType = torch.nn.ReLU,
) -> Self:
"""Allows to configure the default factory for the model of the Q function, which maps environment observations to a vector of
Q-values (one for each action). The default model is a multi-layer perceptron.
:param hidden_sizes: the sequence of dimensions used for hidden layers
:param hidden_activation: the activation function to use for hidden layers (not used for the output layer)
:return: the builder
"""
self._model_factory = IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(
ContinuousActorType.UNSUPPORTED,
hidden_sizes=hidden_sizes,
hidden_activation=hidden_activation,
discrete_softmax=False,
),
)
return self
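# Hedged usage sketch (not part of this commit): customizing the default Q-network MLP.
# `env_factory`, `experiment_config` and `sampling_config` are assumed to exist; the
# hidden sizes and activation are illustrative.
import torch

builder = DQNExperimentBuilder(env_factory, experiment_config, sampling_config).with_model_factory_default(
    hidden_sizes=(256, 256),
    hidden_activation=torch.nn.Tanh,
)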
def _create_agent_factory(self) -> AgentFactory:
return DQNAgentFactory(
self._params,
@ -934,7 +962,7 @@ class IQNExperimentBuilder(ExperimentBuilder):
self._params: IQNParams = IQNParams()
self._preprocess_network_factory: IntermediateModuleFactory = (
IntermediateModuleFactoryFromActorFactory(
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
)
)

View File

@ -1,14 +1,16 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
from typing import TypeVar, cast
from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.string import ToStringMixin
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)
class TrainingContext:
@ -18,8 +20,10 @@ class TrainingContext:
self.logger = logger
class TrainerEpochCallbackTrain(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch."""
class EpochTrainCallback(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase
of each epoch.
"""
@abstractmethod
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
@ -32,8 +36,8 @@ class TrainerEpochCallbackTrain(ToStringMixin, ABC):
return fn
class TrainerEpochCallbackTest(ToStringMixin, ABC):
"""Callback which is called at the beginning of each epoch."""
class EpochTestCallback(ToStringMixin, ABC):
"""Callback which is called at the beginning of the test phase of each epoch."""
@abstractmethod
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
@ -46,8 +50,10 @@ class TrainerEpochCallbackTest(ToStringMixin, ABC):
return fn
class TrainerStopCallback(ToStringMixin, ABC):
"""Callback indicating whether training should stop."""
class EpochStopCallback(ToStringMixin, ABC):
"""Callback which is called after the test phase of each epoch in order to determine
whether training should stop early.
"""
@abstractmethod
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
@ -69,6 +75,77 @@ class TrainerStopCallback(ToStringMixin, ABC):
class TrainerCallbacks:
"""Container for callbacks used during training."""
epoch_callback_train: TrainerEpochCallbackTrain | None = None
epoch_callback_test: TrainerEpochCallbackTest | None = None
stop_callback: TrainerStopCallback | None = None
epoch_train_callback: EpochTrainCallback | None = None
epoch_test_callback: EpochTestCallback | None = None
epoch_stop_callback: EpochStopCallback | None = None
class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the training
stage in each epoch.
"""
def __init__(self, eps_train: float):
self.eps_train = eps_train
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy = cast(DQNPolicy, context.policy)
policy.set_eps(self.eps_train)
class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the training
stage in each epoch, using a linear decay in the first `decay_steps` steps.
"""
def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
self.eps_train = eps_train
self.eps_train_final = eps_train_final
self.decay_steps = decay_steps
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
policy = cast(DQNPolicy, context.policy)
logger = context.logger
if env_step <= self.decay_steps:
eps = self.eps_train - env_step / self.decay_steps * (
self.eps_train - self.eps_train_final
)
else:
eps = self.eps_train_final
policy.set_eps(eps)
logger.write("train/env_step", env_step, {"train/eps": eps})
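# Worked example (illustrative, not part of this commit): with eps_train=0.5,
# eps_train_final=0.05 and decay_steps=1_000_000, at env_step=250_000 the callback sets
# eps = 0.5 - 250_000 / 1_000_000 * (0.5 - 0.05) = 0.3875; past 1_000_000 steps eps stays at 0.05.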
class EpochTestCallbackDQNSetEps(EpochTestCallback):
"""Sets the epsilon value for DQN-based policies at the beginning of the test
stage in each epoch.
"""
def __init__(self, eps_test: float):
self.eps_test = eps_test
def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
policy = cast(DQNPolicy, context.policy)
policy.set_eps(self.eps_test)
class EpochStopCallbackRewardThreshold(EpochStopCallback):
"""Stops training once the mean rewards exceed the given reward threshold or the threshold that
is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
"""
def __init__(self, threshold: float | None = None):
""":param threshold: the reward threshold beyond which to stop training.
If it is None, the threshold given by the environment is used, i.e. `env.spec.reward_threshold`.
"""
self.threshold = threshold
def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
threshold = self.threshold
if threshold is None:
threshold = context.envs.env.spec.reward_threshold # type: ignore
assert threshold is not None
is_reached = mean_rewards >= threshold
if is_reached:
log.info(f"Reward threshold ({threshold}) exceeded")
return is_reached
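# Hedged usage sketch (not part of this commit): attaching the DQN-specific callbacks
# defined above to a DQN experiment builder. The import path and all numeric values are
# assumptions; `env_factory`, `experiment_config` and `sampling_config` are assumed to exist.
from tianshou.highlevel.experiment import DQNExperimentBuilder

builder = (
    DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
    .with_epoch_train_callback(EpochTrainCallbackDQNEpsLinearDecay(eps_train=0.5, eps_train_final=0.05))
    .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.01))
    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold())  # falls back to env.spec.reward_threshold
)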

View File

@ -7,7 +7,7 @@ from typing import Any
import numpy as np
VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray
VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float
VALID_LOG_VALS = typing.get_args(
VALID_LOG_VALS_TYPE,
)  # extract the tuple of constituent types so it can be used in isinstance checks

View File

@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, no_type_check
from typing import Any, TypeAlias, TypeVar, no_type_check
import numpy as np
import torch
@ -13,6 +13,7 @@ ModuleType = type[nn.Module]
ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
TActionShape: TypeAlias = Sequence[int] | int
TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
T = TypeVar("T")
def miniblock(
@ -608,3 +609,39 @@ class BaseActor(nn.Module, ABC):
@abstractmethod
def get_output_dim(self) -> int:
pass
def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
"""Gets the given attribute from the given object or takes the alternative value if it is not present.
If both are present, they are required to match.
:param obj: the object from which to obtain the attribute value
:param attr_name: the attribute name
:param alt_value: the alternative value to use if the attribute is not present or is None; it must not itself be None in that case
:return: the value
"""
v = getattr(obj, attr_name, None)  # a missing attribute is treated the same as a None value
if v is not None:
if alt_value is not None and v != alt_value:
raise ValueError(
f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
)
return v
else:
if alt_value is None:
raise ValueError(
f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
)
return alt_value
def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
"""Retrieves value the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present.
If both are present, they must match.
:param module: the module
:param alt_value: the alternative value
:return: the value
"""
return getattr_with_matching_alt_value(module, "output_dim", alt_value)
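# Hedged usage sketch (not part of this commit): resolving the input dimension of a head
# module from a preprocessing network, as the Actor/Critic constructors below now do.
# The choice of `Net` and its hidden sizes is illustrative.
from tianshou.utils.net.common import Net, get_output_dim

preprocess_net = Net(state_shape=(4,), hidden_sizes=(64, 64))  # exposes output_dim == 64
assert get_output_dim(preprocess_net, None) == 64  # taken from the module attribute
assert get_output_dim(preprocess_net, 64) == 64    # both defined and matching
# get_output_dim(preprocess_net, 128) would raise ValueError (attribute and alt value disagree)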

View File

@ -1,12 +1,18 @@
import warnings
from collections.abc import Sequence
from typing import Any, cast
from typing import Any
import numpy as np
import torch
from torch import nn
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
from tianshou.utils.net.common import (
MLP,
BaseActor,
TActionShape,
TLinearLayer,
get_output_dim,
)
SIGMA_MIN = -20
SIGMA_MAX = 2
@ -50,8 +56,7 @@ class Actor(BaseActor):
self.device = device
self.preprocess = preprocess_net
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
input_dim = cast(int, input_dim)
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(
input_dim,
self.output_dim,
@ -118,9 +123,9 @@ class Critic(nn.Module):
self.device = device
self.preprocess = preprocess_net
self.output_dim = 1
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(
input_dim, # type: ignore
input_dim,
1,
hidden_sizes,
device=self.device,
@ -199,12 +204,12 @@ class ActorProb(BaseActor):
self.preprocess = preprocess_net
self.device = device
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device) # type: ignore
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
self._c_sigma = conditioned_sigma
if conditioned_sigma:
self.sigma = MLP(
input_dim, # type: ignore
input_dim,
self.output_dim,
hidden_sizes,
device=self.device,

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, cast
from typing import Any
import numpy as np
import torch
@ -7,7 +7,7 @@ import torch.nn.functional as F
from torch import nn
from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim
class Actor(BaseActor):
@ -51,8 +51,7 @@ class Actor(BaseActor):
self.device = device
self.preprocess = preprocess_net
self.output_dim = int(np.prod(action_shape))
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
input_dim = cast(int, input_dim)
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(
input_dim,
self.output_dim,
@ -118,8 +117,8 @@ class Critic(nn.Module):
self.device = device
self.preprocess = preprocess_net
self.output_dim = last_size
input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device) # type: ignore
input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)
def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
"""Mapping: s -> V(s)."""
@ -197,8 +196,8 @@ class ImplicitQuantileNetwork(Critic):
) -> None:
last_size = int(np.prod(action_shape))
super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to( # type: ignore
self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(
device,
)