Improvements in README and high-level API (#1022)
This makes several largely unrelated improvements in the high-level API and in the README.

Main improvements in the high-level API:
* Improve naming in trainer-related abstractions, moved some classes from examples to the library
* Improve environment factory abstraction
* Some bug-fixes

Main changes in README:
* Add high-level example and update procedural/low-level example
* Improve language/wording
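For reference, the renames behind "improve naming in trainer-related abstractions" can be read off the updated call sites in the diff below; a non-exhaustive summary (old name on the left, new name on the right):

```python
# Builder methods:
#   .with_trainer_epoch_callback_train(...)    ->  .with_epoch_train_callback(...)
#   .with_trainer_epoch_callback_test(...)     ->  .with_epoch_test_callback(...)
#   .with_trainer_stop_callback(...)           ->  .with_epoch_stop_callback(...)
# Callback classes:
#   TrainerStopCallback                        ->  EpochStopCallback
#   AtariStopCallback                          ->  AtariEpochStopCallback
# Moved from examples.atari.atari_callbacks into tianshou.highlevel.trainer:
#   TestEpochCallbackDQNSetEps                 ->  EpochTestCallbackDQNSetEps
#   TrainEpochCallbackNatureDQNEpsLinearDecay  ->  EpochTrainCallbackDQNEpsLinearDecay
```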
commit 6e1ffe58e5

README.md: 219 lines changed
@ -10,20 +10,17 @@
|
||||
> Tianshou no longer supports `gym`, and we recommend that you transition to
|
||||
> [Gymnasium](http://github.com/Farama-Foundation/Gymnasium).
|
||||
> If you absolutely have to use gym, you can try using [Shimmy](https://github.com/Farama-Foundation/Shimmy)
|
||||
> (the compatibility layer), but tianshou provides no guarantees that things will work then.
|
||||
> (the compatibility layer), but Tianshou provides no guarantees that things will work then.
|
||||
|
||||
> ⚠️️ **Current Status**: the tianshou master branch is currently under heavy development,
|
||||
> moving towards more features, improved interfaces, more documentation, and better compatibility with
|
||||
> other RL libraries. You can view the relevant issues in the corresponding
|
||||
> ⚠️️ **Current Status**: the Tianshou master branch is currently under heavy development,
|
||||
> moving towards more features, improved interfaces, more documentation.
|
||||
> You can view the relevant issues in the corresponding
|
||||
> [milestone](https://github.com/thu-ml/tianshou/milestone/1)
|
||||
> Stay tuned! (and expect breaking changes until the release is done)
|
||||
|
||||
> ⚠️️ **Installing PyTorch**: Because of a problem with pytorch packaging and poetry in
|
||||
> current releases, the newest version of pytorch is not included in the tianshou dependencies.
|
||||
> You can still install the newest PyTorch version with `pip` after Tianshou has been installed with `poetry`.
|
||||
> [Here](https://github.com/python-poetry/poetry/issues/7902#issuecomment-1747400255) is a discussion between torch and poetry devs, who are trying to resolve it.
|
||||
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike other reinforcement learning libraries, which are partly based on TensorFlow, have unfriendly APIs, or are not optimized for speed, Tianshou provides a high-performance, modularized framework and user-friendly APIs for building deep reinforcement learning agents, enabling concise implementations without sacrificing flexibility.
|
||||
|
||||
**Tianshou** ([天授](https://baike.baidu.com/item/%E5%A4%A9%E6%8E%88)) is a reinforcement learning platform based on pure PyTorch. Unlike several existing reinforcement learning libraries, which are mainly based on TensorFlow, have many nested classes, unfriendly API, or slow-speed, Tianshou provides a fast-speed modularized framework and pythonic API for building the deep reinforcement learning agent with the least number of lines of code. The supported interface algorithms currently include:
|
||||
The set of supported algorithms includes the following:
|
||||
|
||||
- [Deep Q-Network (DQN)](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf)
|
||||
- [Double DQN](https://arxiv.org/pdf/1509.06461.pdf)
|
||||
@ -58,22 +55,28 @@
|
||||
- [Intrinsic Curiosity Module (ICM)](https://arxiv.org/pdf/1705.05363.pdf)
|
||||
- [Hindsight Experience Replay (HER)](https://arxiv.org/pdf/1707.01495.pdf)
|
||||
|
||||
Here are Tianshou's other features:
|
||||
Other noteworthy features:
|
||||
|
||||
- Elegant framework, using few lines of code in the core abstractions
|
||||
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
|
||||
- Support vectorized environment (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling)
|
||||
- Support super-fast vectorized environment [EnvPool](https://github.com/sail-sg/envpool/) for all algorithms [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#envpool-integration)
|
||||
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training)
|
||||
- Elegant framework with dual APIs:
|
||||
* Tianshou's high-level API maximizes ease of use for application development while still retaining a high degree
|
||||
of flexibility.
|
||||
* The fundamental procedural API provides a maximum of flexibility for algorithm development without being
|
||||
overly verbose.
|
||||
- State-of-the-art results in [MuJoCo benchmarks](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
|
||||
- Support for vectorized environments (synchronous or asynchronous) for all algorithms (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#parallel-sampling) and the short sketch after this list)
|
||||
- Support for super-fast vectorized environments based on [EnvPool](https://github.com/sail-sg/envpool/) for all algorithms (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#envpool-integration))
|
||||
- Support for recurrent state representations in actor networks and critic networks (RNN-style training for POMDPs) (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#rnn-style-training))
|
||||
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)
|
||||
- Support customized training process [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process)
|
||||
- Support n-step returns estimation and prioritized experience replay for all Q-learning based algorithms; GAE, nstep and PER are very fast thanks to numba jit function and vectorized numpy operation
|
||||
- Support multi-agent RL [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning)
|
||||
- Support both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/) log tools
|
||||
- Support multi-GPU training [Usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-gpu)
|
||||
- Support for customized training processes (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#customize-training-process))
|
||||
- Support for n-step return estimation and prioritized experience replay for all Q-learning based algorithms; GAE, n-step and PER are highly optimized thanks to numba's just-in-time compilation and vectorized numpy operations
|
||||
- Support for multi-agent RL (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-agent-reinforcement-learning))
|
||||
- Support for logging based on both [TensorBoard](https://www.tensorflow.org/tensorboard) and [W&B](https://wandb.ai/)
|
||||
- Support for multi-GPU training (see [usage](https://tianshou.readthedocs.io/en/master/tutorials/cheatsheet.html#multi-gpu))
|
||||
- Comprehensive documentation, PEP8 code-style checking, type checking and thorough [tests](https://github.com/thu-ml/tianshou/actions)
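A minimal sketch of the vectorized-environment support listed above (not part of the original feature list; it mirrors the Quick Start code further below, and the printed shape is what one would expect for CartPole):

```python
import gymnasium as gym

import tianshou as ts

# Ten CartPole instances behind a single vectorized interface; SubprocVectorEnv or
# ShmemVectorEnv can be substituted for asynchronous / shared-memory execution.
train_envs = ts.env.DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(10)])
obs, info = train_envs.reset()
print(obs.shape)  # (10, 4): one observation per parallel environment
```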
|
||||
|
||||
In Chinese, Tianshou means divinely ordained and is derived to the gift of being born with. Tianshou is a reinforcement learning platform, and the RL algorithm does not learn from humans. So taking "Tianshou" means that there is no teacher to study with, but rather to learn by themselves through constant interaction with the environment.
|
||||
In Chinese, Tianshou means divinely ordained and refers to a gift that one is born with.
|
||||
Tianshou is a reinforcement learning platform, and by its nature RL does not learn from humans.
|
||||
The name "Tianshou" thus conveys that there is no teacher to learn from; instead, the agent learns by itself through constant interaction with the environment.
|
||||
|
||||
“天授”意指上天所授,引申为与生具有的天赋。天授是强化学习平台,而强化学习算法并不是向人类学习的,所以取“天授”意思是没有老师来教,而是自己通过跟环境不断交互来进行学习。
|
||||
|
||||
@ -87,32 +90,32 @@ You can simply install Tianshou from PyPI with the following command:
|
||||
$ pip install tianshou
|
||||
```
|
||||
|
||||
If you use Anaconda or Miniconda, you can install Tianshou from conda-forge through the following command:
|
||||
If you are using Anaconda or Miniconda, you can install Tianshou from conda-forge:
|
||||
|
||||
```bash
|
||||
$ conda install tianshou -c conda-forge
|
||||
```
|
||||
|
||||
You can also install with the newest version through GitHub:
|
||||
Alternatively, you can install the latest version from source via GitHub:
|
||||
|
||||
```bash
|
||||
$ pip install git+https://github.com/thu-ml/tianshou.git@master --upgrade
|
||||
```
|
||||
|
||||
After installation, open your python console and type
|
||||
Finally, you may check the installation via your Python console as follows:
|
||||
|
||||
```python
|
||||
import tianshou
|
||||
print(tianshou.__version__)
|
||||
```
|
||||
|
||||
If no error occurs, you have successfully installed Tianshou.
|
||||
If no errors are reported, you have successfully installed Tianshou.
|
||||
|
||||
## Documentation
|
||||
|
||||
The tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
|
||||
Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
|
||||
|
||||
The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/master/test) folder and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folder.
|
||||
Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
|
||||
|
||||
中文文档位于 [https://tianshou.readthedocs.io/zh/master/](https://tianshou.readthedocs.io/zh/master/)。
|
||||
|
||||
@ -166,35 +169,149 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
|
||||
|
||||
<sup>(1): it has continuous integration but the coverage rate is not available</sup>
|
||||
|
||||
### Reproducible and High Quality Result
|
||||
### Reproducible, High-Quality Results
|
||||
|
||||
Tianshou has its tests. Different from other platforms, **the tests include the full agent training procedure for all of the implemented algorithms**. It would be failed once if it could not train an agent to perform well enough on limited epochs on toy scenarios. The tests secure the reproducibility of our platform. Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page for more detail.
|
||||
Tianshou is rigorously tested. In contrast to other RL platforms, **our tests include the full agent training procedure for all of the implemented algorithms**. The tests fail if an agent cannot reach a consistent level of performance within a limited number of epochs on toy scenarios.
|
||||
Our tests thus ensure reproducibility.
|
||||
Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page for more detail.
|
||||
|
||||
The Atari/Mujoco benchmark results are under [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders. **Our Mujoco result can beat most of existing benchmarks.**
|
||||
Atari and MuJoCo benchmark results can be found in the [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders, respectively. **Our MuJoCo results reach or exceed the level of performance of most existing benchmarks.**
|
||||
|
||||
### Modularized Policy
|
||||
### Policy Interface
|
||||
|
||||
We decouple all algorithms roughly into the following parts:
|
||||
All algorithms implement the following, highly general API:
|
||||
|
||||
- `__init__`: initialize the policy;
|
||||
- `forward`: to compute actions over given observations;
|
||||
- `process_buffer`: process initial buffer, useful for some offline learning algorithms
|
||||
- `process_fn`: to preprocess data from replay buffer (since we have reformulated all algorithms to replay-buffer based algorithms);
|
||||
- `learn`: to learn from a given batch data;
|
||||
- `post_process_fn`: to update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
|
||||
- `forward`: compute actions based on given observations;
|
||||
- `process_buffer`: process the initial buffer, which is useful for some offline learning algorithms;
|
||||
- `process_fn`: preprocess data from the replay buffer (since we have reformulated *all* algorithms to replay buffer-based algorithms);
|
||||
- `learn`: learn from a given batch of data;
|
||||
- `post_process_fn`: update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
|
||||
- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.
|
||||
|
||||
Within this API, we can interact with different policies conveniently.
|
||||
Implementing this API suffices for a new algorithm to be applicable within Tianshou,
|
||||
making experimentation with new approaches particularly straightforward.
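To make this concrete, here is a minimal sketch (not taken from the README) that exercises the interface directly; it assembles the same DQN components as the Quick Start example below, so all class and argument names come from the procedural example shown there:

```python
import gymnasium as gym
import torch

import tianshou as ts
from tianshou.utils.net.common import Net

env = gym.make("CartPole-v1")
net = Net(
    state_shape=env.observation_space.shape,
    action_shape=env.action_space.n,
    hidden_sizes=[64, 64],
)
policy = ts.policy.DQNPolicy(
    model=net,
    optim=torch.optim.Adam(net.parameters(), lr=1e-3),
    discount_factor=0.9,
    action_space=env.action_space,
    estimation_step=3,
    target_update_freq=320,
)
buffer = ts.data.ReplayBuffer(size=2000)
collector = ts.data.Collector(policy, env, buffer, exploration_noise=True)
collector.collect(n_step=1000)  # fill the buffer with transitions

# `forward` computes actions for a batch of observations
batch, indices = buffer.sample(16)
act = policy(batch).act

# `update` is the main training entry point; it internally runs
# process_fn -> learn -> post_process_fn on a sample drawn from the buffer
losses = policy.update(sample_size=64, buffer=buffer)
```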
|
||||
|
||||
## Quick Start
|
||||
|
||||
This is an example of Deep Q Network. You can also run the full script at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
|
||||
Tianshou provides two API levels:
|
||||
* the high-level interface, which provides ease of use for end users seeking to run deep reinforcement learning applications
|
||||
* the procedural interface, which provides a maximum of control, especially for very advanced users and developers of reinforcement learning algorithms.
|
||||
|
||||
In the following, let us consider an example application using the *CartPole* gymnasium environment.
|
||||
We shall apply the deep Q network (DQN) learning algorithm using both APIs.
|
||||
|
||||
### High-Level API
|
||||
|
||||
To get started, we need some imports.
|
||||
|
||||
```python
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryRegistered,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
||||
from tianshou.highlevel.params.policy_params import DQNParams
|
||||
from tianshou.highlevel.trainer import (
|
||||
EpochStopCallbackRewardThreshold,
EpochTestCallbackDQNSetEps,
EpochTrainCallbackDQNSetEps,
|
||||
)
|
||||
```
|
||||
|
||||
In the high-level API, the basis for an RL experiment is an `ExperimentBuilder`
|
||||
with which we can build the experiment we then seek to run.
|
||||
Since we want to use DQN, we use the specialization `DQNExperimentBuilder`.
|
||||
The other imports serve to provide configuration options for our experiment.
|
||||
|
||||
The high-level API provides largely declarative semantics, i.e. the code is
|
||||
almost exclusively concerned with configuration that controls what to do
|
||||
(rather than how to do it).
|
||||
|
||||
```python
|
||||
experiment = (
|
||||
DQNExperimentBuilder(
|
||||
EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
|
||||
ExperimentConfig(
|
||||
persistence_enabled=False,
|
||||
watch=True,
|
||||
watch_render=1 / 35,
|
||||
watch_num_episodes=100,
|
||||
),
|
||||
SamplingConfig(
|
||||
num_epochs=10,
|
||||
step_per_epoch=10000,
|
||||
batch_size=64,
|
||||
num_train_envs=10,
|
||||
num_test_envs=100,
|
||||
buffer_size=20000,
|
||||
step_per_collect=10,
|
||||
update_per_step=1 / 10,
|
||||
),
|
||||
)
|
||||
.with_dqn_params(
|
||||
DQNParams(
|
||||
lr=1e-3,
|
||||
discount_factor=0.9,
|
||||
estimation_step=3,
|
||||
target_update_freq=320,
|
||||
),
|
||||
)
|
||||
.with_model_factory_default(hidden_sizes=(64, 64))
|
||||
.with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
|
||||
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
|
||||
.with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
|
||||
.build()
|
||||
)
|
||||
experiment.run()
|
||||
```
|
||||
|
||||
The experiment builder takes three arguments:
|
||||
* the environment factory for the creation of environments. In this case,
|
||||
we use an existing factory implementation for gymnasium environments.
|
||||
* the experiment configuration, which controls persistence and the overall
|
||||
experiment flow. In this case, we have configured that we want to observe
|
||||
the agent's behavior after it is trained (`watch=True`) for a number of
|
||||
episodes (`watch_num_episodes=100`). We have disabled persistence, because
|
||||
we do not want to save training logs, the agent or its configuration for
|
||||
future use.
|
||||
* the sampling configuration, which controls fundamental training parameters,
|
||||
such as the total number of epochs we run the experiment for (`num_epochs=10`)
|
||||
and the number of environment steps each epoch shall consist of
|
||||
(`step_per_epoch=10000`).
|
||||
Every epoch consists of a series of data collection (rollout) steps and
|
||||
training steps.
|
||||
The parameter `step_per_collect` controls the amount of data that is
|
||||
collected in each collection step. After each collection step, we
|
||||
perform a training step, applying a gradient-based update based on a sample
|
||||
of data (`batch_size=64`) taken from the buffer of data that has been
|
||||
collected (see the short calculation below). For further details, see the documentation of `SamplingConfig`.
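To make the bookkeeping concrete, here is a rough calculation for the values used above (an illustration derived from the configuration shown, not output produced by the library):

```python
step_per_epoch, step_per_collect, update_per_step, batch_size = 10_000, 10, 1 / 10, 64

collect_steps_per_epoch = step_per_epoch / step_per_collect  # 1000 collection steps per epoch
updates_per_collect = step_per_collect * update_per_step     # 1 gradient update after each collection step
updates_per_epoch = step_per_epoch * update_per_step         # ~1000 updates on batches of 64 per epoch
print(collect_steps_per_epoch, updates_per_collect, updates_per_epoch)  # 1000.0 1.0 1000.0
```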
|
||||
|
||||
We then proceed to configure some of the parameters of the DQN algorithm itself
|
||||
and of the neural network model we want to use.
|
||||
A DQN-specific detail is the use of callbacks to configure the algorithm's
|
||||
epsilon parameter for exploration. We want to use random exploration during rollouts
|
||||
(train callback), but not when evaluating the agent's performance in the test
|
||||
environments (test callback).
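Roughly speaking, these callbacks just set the policy's epsilon at the start of each epoch. A sketch of the underlying logic, written as plain functions rather than the library's callback classes (adapted from the callback implementations changed elsewhere in this commit, and assuming a DQN policy that exposes `set_eps`):

```python
from tianshou.highlevel.trainer import TrainingContext


def set_train_eps(epoch: int, env_step: int, context: TrainingContext) -> None:
    # epsilon-greedy exploration while collecting training rollouts
    context.policy.set_eps(0.3)


def set_test_eps(epoch: int, env_step: int, context: TrainingContext) -> None:
    # act greedily (no exploration) when evaluating in the test environments
    context.policy.set_eps(0.0)
```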
|
||||
|
||||
Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py).
|
||||
Here's a run (with the training time cut short):
|
||||
|
||||
<p align="center" style="text-align:center">
|
||||
<img src="docs/_static/images/discrete_dqn_hl.gif">
|
||||
</p>
|
||||
|
||||
|
||||
### Procedural API
|
||||
|
||||
Let us now consider an analogous example in the procedural API.
|
||||
Find the full script from which the snippets below were derived at [test/discrete/test_dqn.py](https://github.com/thu-ml/tianshou/blob/master/test/discrete/test_dqn.py).
|
||||
|
||||
First, import some relevant packages:
|
||||
|
||||
```python
|
||||
import gymnasium as gym
|
||||
import torch, numpy as np, torch.nn as nn
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import tianshou as ts
|
||||
```
|
||||
@ -202,7 +319,7 @@ import tianshou as ts
|
||||
Define some hyper-parameters:
|
||||
|
||||
```python
|
||||
task = 'CartPole-v0'
|
||||
task = 'CartPole-v1'
|
||||
lr, epoch, batch_size = 1e-3, 10, 64
|
||||
train_num, test_num = 10, 100
|
||||
gamma, n_step, target_freq = 0.9, 3, 320
|
||||
@ -227,7 +344,7 @@ Define the network:
|
||||
from tianshou.utils.net.common import Net
|
||||
# you can define other networks by following the API:
|
||||
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
|
||||
env = gym.make(task)
|
||||
env = gym.make(task, render_mode="human")
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
action_shape = env.action_space.shape or env.action_space.n
|
||||
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
|
||||
@ -267,7 +384,7 @@ result = ts.trainer.OffpolicyTrainer(
|
||||
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
|
||||
logger=logger,
|
||||
).run()
|
||||
print(f'Finished training! Use {result["duration"]}')
|
||||
print(f"Finished training in {result.timing.total_time} seconds")
|
||||
```
|
||||
|
||||
Save / load the trained policy (it works exactly like saving/loading a PyTorch `nn.Module`):
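The corresponding snippet is elided by the diff hunk here; a standard PyTorch state-dict round trip of the kind referred to looks like this (continuing the example above, with an illustrative file name):

```python
torch.save(policy.state_dict(), "dqn.pth")
policy.load_state_dict(torch.load("dqn.pth"))
```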
|
||||
@ -294,19 +411,11 @@ $ tensorboard --logdir log/dqn
|
||||
|
||||
You can check out the [documentation](https://tianshou.readthedocs.io) for advanced usage.
|
||||
|
||||
It's worth a try: here is a test on a laptop (i7-8750H + GTX1060). It only uses **3** seconds for training an agent based on vanilla policy gradient on the CartPole-v0 task: (seed may be different across different platform and device)
|
||||
|
||||
```bash
|
||||
$ python3 test/discrete/test_pg.py --seed 0 --render 0.03
|
||||
```
|
||||
|
||||
<div align="center">
|
||||
<img src="https://github.com/thu-ml/tianshou/raw/master/docs/_static/images/testpg.gif"></a>
|
||||
</div>
|
||||
|
||||
## Contributing
|
||||
|
||||
Tianshou is still under development. More algorithms and features are going to be added and we always welcome contributions to help make Tianshou better. If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).
|
||||
Tianshou is still under development.
|
||||
Further algorithms and features are continuously being added, and we always welcome contributions to help make Tianshou better.
|
||||
If you would like to contribute, please check out [this link](https://tianshou.readthedocs.io/en/master/contributing.html).
|
||||
|
||||
## Citing Tianshou
|
||||
|
||||
@ -325,7 +434,7 @@ If you find Tianshou useful, please cite it in your publications.
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgment
|
||||
## Acknowledgments
|
||||
|
||||
Tianshou is supported by [appliedAI Institute for Europe](https://www.appliedai-institute.de/en/),
|
||||
which is committed to providing long-term support and development.
|
||||
|
||||
docs/_static/images/discrete_dqn_hl.gif (new binary file, 600 KiB; not shown)
@ -1,33 +0,0 @@
|
||||
from tianshou.highlevel.trainer import (
|
||||
TrainerEpochCallbackTest,
|
||||
TrainerEpochCallbackTrain,
|
||||
TrainingContext,
|
||||
)
|
||||
from tianshou.policy import DQNPolicy
|
||||
|
||||
|
||||
class TestEpochCallbackDQNSetEps(TrainerEpochCallbackTest):
|
||||
def __init__(self, eps_test: float):
|
||||
self.eps_test = eps_test
|
||||
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
policy: DQNPolicy = context.policy
|
||||
policy.set_eps(self.eps_test)
|
||||
|
||||
|
||||
class TrainEpochCallbackNatureDQNEpsLinearDecay(TrainerEpochCallbackTrain):
|
||||
def __init__(self, eps_train: float, eps_train_final: float):
|
||||
self.eps_train = eps_train
|
||||
self.eps_train_final = eps_train_final
|
||||
|
||||
def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
|
||||
policy: DQNPolicy = context.policy
|
||||
logger = context.logger
|
||||
# nature DQN setting, linear decay in the first 1M steps
|
||||
if env_step <= 1e6:
|
||||
eps = self.eps_train - env_step / 1e6 * (self.eps_train - self.eps_train_final)
|
||||
else:
|
||||
eps = self.eps_train_final
|
||||
policy.set_eps(eps)
|
||||
if env_step % 1000 == 0:
|
||||
logger.write("train/env_step", env_step, {"train/eps": eps})
|
||||
@ -2,15 +2,11 @@
|
||||
|
||||
import os
|
||||
|
||||
from examples.atari.atari_callbacks import (
|
||||
TestEpochCallbackDQNSetEps,
|
||||
TrainEpochCallbackNatureDQNEpsLinearDecay,
|
||||
)
|
||||
from examples.atari.atari_network import (
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
DQNExperimentBuilder,
|
||||
@ -20,6 +16,10 @@ from tianshou.highlevel.params.policy_params import DQNParams
|
||||
from tianshou.highlevel.params.policy_wrapper import (
|
||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||
)
|
||||
from tianshou.highlevel.trainer import (
|
||||
EpochTestCallbackDQNSetEps,
|
||||
EpochTrainCallbackDQNEpsLinearDecay,
|
||||
)
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
|
||||
@ -27,7 +27,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "PongNoFrameskip-v4",
|
||||
scale_obs: int = 0,
|
||||
scale_obs: bool = False,
|
||||
eps_test: float = 0.005,
|
||||
eps_train: float = 1.0,
|
||||
eps_train_final: float = 0.05,
|
||||
@ -79,11 +79,11 @@ def main(
|
||||
),
|
||||
)
|
||||
.with_model_factory(IntermediateModuleFactoryAtariDQN())
|
||||
.with_trainer_epoch_callback_train(
|
||||
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
.with_epoch_train_callback(
|
||||
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
)
|
||||
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
|
||||
@ -3,20 +3,20 @@
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
|
||||
from examples.atari.atari_callbacks import (
|
||||
TestEpochCallbackDQNSetEps,
|
||||
TrainEpochCallbackNatureDQNEpsLinearDecay,
|
||||
)
|
||||
from examples.atari.atari_network import (
|
||||
IntermediateModuleFactoryAtariDQN,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
ExperimentConfig,
|
||||
IQNExperimentBuilder,
|
||||
)
|
||||
from tianshou.highlevel.params.policy_params import IQNParams
|
||||
from tianshou.highlevel.trainer import (
|
||||
EpochTestCallbackDQNSetEps,
|
||||
EpochTrainCallbackDQNEpsLinearDecay,
|
||||
)
|
||||
from tianshou.utils import logging
|
||||
from tianshou.utils.logging import datetime_tag
|
||||
|
||||
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "PongNoFrameskip-v4",
|
||||
scale_obs: int = 0,
|
||||
scale_obs: bool = False,
|
||||
eps_test: float = 0.005,
|
||||
eps_train: float = 1.0,
|
||||
eps_train_final: float = 0.05,
|
||||
@ -83,11 +83,11 @@ def main(
|
||||
),
|
||||
)
|
||||
.with_preprocess_network_factory(IntermediateModuleFactoryAtariDQN(features_only=True))
|
||||
.with_trainer_epoch_callback_train(
|
||||
TrainEpochCallbackNatureDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
.with_epoch_train_callback(
|
||||
EpochTrainCallbackDQNEpsLinearDecay(eps_train, eps_train_final),
|
||||
)
|
||||
.with_trainer_epoch_callback_test(TestEpochCallbackDQNSetEps(eps_test))
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(eps_test))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
.build()
|
||||
)
|
||||
experiment.run(log_name)
|
||||
|
||||
@ -23,19 +23,27 @@ def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.
|
||||
return layer
|
||||
|
||||
|
||||
def scale_obs(module: type[nn.Module], denom: float = 255.0) -> type[nn.Module]:
|
||||
class scaled_module(module):
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
state: Any | None = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> tuple[torch.Tensor, Any]:
|
||||
if info is None:
|
||||
info = {}
|
||||
return super().forward(obs / denom, state, info)
|
||||
class ScaledObsInputModule(torch.nn.Module):
|
||||
def __init__(self, module: torch.nn.Module, denom: float = 255.0):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.denom = denom
|
||||
# This is required such that the value can be retrieved by downstream modules (see usages of get_output_dim)
|
||||
self.output_dim = module.output_dim
|
||||
|
||||
return scaled_module
|
||||
def forward(
|
||||
self,
|
||||
obs: np.ndarray | torch.Tensor,
|
||||
state: Any | None = None,
|
||||
info: dict[str, Any] | None = None,
|
||||
) -> tuple[torch.Tensor, Any]:
|
||||
if info is None:
|
||||
info = {}
|
||||
return self.module.forward(obs / self.denom, state, info)
|
||||
|
||||
|
||||
def scale_obs(module: nn.Module, denom: float = 255.0) -> nn.Module:
|
||||
return ScaledObsInputModule(module, denom=denom)
|
||||
|
||||
|
||||
class DQN(nn.Module):
|
||||
@ -238,8 +246,7 @@ class ActorFactoryAtariDQN(ActorFactory):
|
||||
self.features_only = features_only
|
||||
|
||||
def create_module(self, envs: Environments, device: TDevice) -> Actor:
|
||||
net_cls = scale_obs(DQN) if self.scale_obs else DQN
|
||||
net = net_cls(
|
||||
net = DQN(
|
||||
*envs.get_observation_shape(),
|
||||
envs.get_action_shape(),
|
||||
device=device,
|
||||
@ -247,6 +254,8 @@ class ActorFactoryAtariDQN(ActorFactory):
|
||||
output_dim=self.hidden_size,
|
||||
layer_init=layer_init,
|
||||
)
|
||||
if self.scale_obs:
|
||||
net = scale_obs(net)
|
||||
return Actor(net, envs.get_action_shape(), device=device, softmax_output=False).to(device)
|
||||
|
||||
|
||||
|
||||
@ -109,8 +109,7 @@ def test_ppo(args=get_args()):
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
# define model
|
||||
net_cls = scale_obs(DQN) if args.scale_obs else DQN
|
||||
net = net_cls(
|
||||
net = DQN(
|
||||
*args.state_shape,
|
||||
args.action_shape,
|
||||
device=args.device,
|
||||
@ -118,6 +117,8 @@ def test_ppo(args=get_args()):
|
||||
output_dim=args.hidden_size,
|
||||
layer_init=layer_init,
|
||||
)
|
||||
if args.scale_obs:
|
||||
net = scale_obs(net)
|
||||
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
|
||||
critic = Critic(net, device=args.device)
|
||||
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr, eps=1e-5)
|
||||
|
||||
@ -7,7 +7,7 @@ from examples.atari.atari_network import (
|
||||
ActorFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
ExperimentConfig,
|
||||
@ -95,7 +95,7 @@ def main(
|
||||
)
|
||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_sizes, scale_obs, features_only=True))
|
||||
.with_critic_factory_use_actor()
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
|
||||
@ -6,7 +6,7 @@ from examples.atari.atari_network import (
|
||||
ActorFactoryAtariDQN,
|
||||
IntermediateModuleFactoryAtariDQNFeatures,
|
||||
)
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariStopCallback
|
||||
from examples.atari.atari_wrapper import AtariEnvFactory, AtariEpochStopCallback
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.experiment import (
|
||||
DiscreteSACExperimentBuilder,
|
||||
@ -24,7 +24,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "PongNoFrameskip-v4",
|
||||
scale_obs: int = 0,
|
||||
scale_obs: bool = False,
|
||||
buffer_size: int = 100000,
|
||||
actor_lr: float = 1e-5,
|
||||
critic_lr: float = 1e-5,
|
||||
@ -82,7 +82,7 @@ def main(
|
||||
)
|
||||
.with_actor_factory(ActorFactoryAtariDQN(hidden_size, scale_obs=False, features_only=True))
|
||||
.with_common_critic_factory_use_actor()
|
||||
.with_trainer_stop_callback(AtariStopCallback(task))
|
||||
.with_epoch_stop_callback(AtariEpochStopCallback(task))
|
||||
)
|
||||
if icm_lr_scale > 0:
|
||||
builder.with_policy_wrapper_factory(
|
||||
|
||||
@ -1,21 +1,29 @@
|
||||
# Borrow a lot from openai baselines:
|
||||
# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from collections import deque
|
||||
|
||||
import cv2
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
from gymnasium import Env
|
||||
|
||||
from tianshou.env import ShmemVectorEnv
|
||||
from tianshou.highlevel.env import DiscreteEnvironments, EnvFactory
|
||||
from tianshou.highlevel.trainer import TrainerStopCallback, TrainingContext
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryRegistered,
|
||||
EnvMode,
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.trainer import EpochStopCallback, TrainingContext
|
||||
|
||||
envpool_is_available = True
|
||||
try:
|
||||
import envpool
|
||||
except ImportError:
|
||||
envpool_is_available = False
|
||||
envpool = None
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_reset_result(reset_result):
|
||||
@ -282,7 +290,7 @@ class FrameStack(gym.Wrapper):
|
||||
|
||||
|
||||
def wrap_deepmind(
|
||||
env_id,
|
||||
env: Env,
|
||||
episode_life=True,
|
||||
clip_rewards=True,
|
||||
frame_stack=4,
|
||||
@ -293,7 +301,7 @@ def wrap_deepmind(
|
||||
|
||||
The observation is channel-first: (c, h, w) instead of (h, w, c).
|
||||
|
||||
:param str env_id: the atari environment id.
|
||||
:param env: the Atari environment to wrap.
|
||||
:param bool episode_life: wrap the episode life wrapper.
|
||||
:param bool clip_rewards: wrap the reward clipping wrapper.
|
||||
:param int frame_stack: wrap the frame stacking wrapper.
|
||||
@ -301,8 +309,6 @@ def wrap_deepmind(
|
||||
:param bool warp_frame: wrap the grayscale + resize observation wrapper.
|
||||
:return: the wrapped atari environment.
|
||||
"""
|
||||
assert "NoFrameskip" in env_id
|
||||
env = gym.make(env_id)
|
||||
env = NoopResetEnv(env, noop_max=30)
|
||||
env = MaxAndSkipEnv(env, skip=4)
|
||||
if episode_life:
|
||||
@ -320,79 +326,91 @@ def wrap_deepmind(
|
||||
return env
|
||||
|
||||
|
||||
def make_atari_env(task, seed, training_num, test_num, **kwargs):
|
||||
def make_atari_env(
|
||||
task,
|
||||
seed,
|
||||
training_num,
|
||||
test_num,
|
||||
scale: int | bool = False,
|
||||
frame_stack: int = 4,
|
||||
):
|
||||
"""Wrapper function for Atari env.
|
||||
|
||||
If EnvPool is installed, it will automatically switch to EnvPool's Atari env.
|
||||
|
||||
:return: a tuple of (single env, training envs, test envs).
|
||||
"""
|
||||
if envpool is not None:
|
||||
if kwargs.get("scale", 0):
|
||||
warnings.warn(
|
||||
"EnvPool does not include ScaledFloatFrame wrapper, "
|
||||
"please set `x = x / 255.0` inside CNN network's forward function.",
|
||||
)
|
||||
# parameters convertion
|
||||
train_envs = env = envpool.make_gymnasium(
|
||||
task.replace("NoFrameskip-v4", "-v5"),
|
||||
num_envs=training_num,
|
||||
seed=seed,
|
||||
episodic_life=True,
|
||||
reward_clip=True,
|
||||
stack_num=kwargs.get("frame_stack", 4),
|
||||
)
|
||||
test_envs = envpool.make_gymnasium(
|
||||
task.replace("NoFrameskip-v4", "-v5"),
|
||||
num_envs=test_num,
|
||||
seed=seed,
|
||||
episodic_life=False,
|
||||
reward_clip=False,
|
||||
stack_num=kwargs.get("frame_stack", 4),
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Recommend using envpool (pip install envpool) to run Atari games more efficiently.",
|
||||
)
|
||||
env = wrap_deepmind(task, **kwargs)
|
||||
train_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: wrap_deepmind(task, episode_life=True, clip_rewards=True, **kwargs)
|
||||
for _ in range(training_num)
|
||||
],
|
||||
)
|
||||
test_envs = ShmemVectorEnv(
|
||||
[
|
||||
lambda: wrap_deepmind(task, episode_life=False, clip_rewards=False, **kwargs)
|
||||
for _ in range(test_num)
|
||||
],
|
||||
)
|
||||
env.seed(seed)
|
||||
train_envs.seed(seed)
|
||||
test_envs.seed(seed)
|
||||
return env, train_envs, test_envs
|
||||
env_factory = AtariEnvFactory(task, seed, frame_stack, scale=bool(scale))
|
||||
envs = env_factory.create_envs(training_num, test_num)
|
||||
return envs.env, envs.train_envs, envs.test_envs
|
||||
|
||||
|
||||
class AtariEnvFactory(EnvFactory):
|
||||
def __init__(self, task: str, seed: int, frame_stack: int, scale: int = 0):
|
||||
self.task = task
|
||||
self.seed = seed
|
||||
class AtariEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(
|
||||
self,
|
||||
task: str,
|
||||
seed: int,
|
||||
frame_stack: int,
|
||||
scale: bool = False,
|
||||
use_envpool_if_available: bool = True,
|
||||
):
|
||||
assert "NoFrameskip" in task
|
||||
self.frame_stack = frame_stack
|
||||
self.scale = scale
|
||||
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> DiscreteEnvironments:
|
||||
env, train_envs, test_envs = make_atari_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
training_num=num_training_envs,
|
||||
test_num=num_test_envs,
|
||||
scale=self.scale,
|
||||
frame_stack=self.frame_stack,
|
||||
envpool_factory = None
|
||||
if use_envpool_if_available:
|
||||
if envpool_is_available:
|
||||
envpool_factory = self.EnvPoolFactory(self)
|
||||
log.info("Using envpool, because it is available")
|
||||
else:
|
||||
log.info("Not using envpool, because it is not available")
|
||||
super().__init__(
|
||||
task=task,
|
||||
seed=seed,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
envpool_factory=envpool_factory,
|
||||
)
|
||||
return DiscreteEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
|
||||
def create_env(self, mode: EnvMode) -> Env:
|
||||
env = super().create_env(mode)
|
||||
is_train = mode == EnvMode.TRAIN
|
||||
return wrap_deepmind(
|
||||
env,
|
||||
episode_life=is_train,
|
||||
clip_rewards=is_train,
|
||||
frame_stack=self.frame_stack,
|
||||
scale=self.scale,
|
||||
)
|
||||
|
||||
class EnvPoolFactory(EnvPoolFactory):
|
||||
"""Atari-specific envpool creation.
|
||||
Since envpool internally handles the functions that are implemented through the wrappers in `wrap_deepmind`,
|
||||
it sets the creation keyword arguments accordingly.
|
||||
"""
|
||||
|
||||
def __init__(self, parent: "AtariEnvFactory"):
|
||||
self.parent = parent
|
||||
if self.parent.scale:
|
||||
warnings.warn(
|
||||
"EnvPool does not include ScaledFloatFrame wrapper, "
|
||||
"please compensate by scaling inside your network's forward function (e.g. `x = x / 255.0` for Atari)",
|
||||
)
|
||||
|
||||
def _transform_task(self, task: str) -> str:
|
||||
task = super()._transform_task(task)
|
||||
# TODO: Maybe warn user, explain why this is needed
|
||||
return task.replace("NoFrameskip-v4", "-v5")
|
||||
|
||||
def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
|
||||
kwargs = super()._transform_kwargs(kwargs, mode)
|
||||
is_train = mode == EnvMode.TRAIN
|
||||
kwargs["reward_clip"] = is_train
|
||||
kwargs["episodic_life"] = is_train
|
||||
kwargs["stack_num"] = self.parent.frame_stack
|
||||
return kwargs
|
||||
|
||||
|
||||
class AtariStopCallback(TrainerStopCallback):
|
||||
class AtariEpochStopCallback(EpochStopCallback):
|
||||
def __init__(self, task: str):
|
||||
self.task = task
|
||||
|
||||
|
||||
examples/discrete/discrete_dqn.py (new file, 78 lines)
@ -0,0 +1,78 @@
|
||||
import gymnasium as gym
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
import tianshou as ts
|
||||
|
||||
|
||||
def main():
|
||||
task = "CartPole-v1"
|
||||
lr, epoch, batch_size = 1e-3, 10, 64
|
||||
train_num, test_num = 10, 100
|
||||
gamma, n_step, target_freq = 0.9, 3, 320
|
||||
buffer_size = 20000
|
||||
eps_train, eps_test = 0.1, 0.05
|
||||
step_per_epoch, step_per_collect = 10000, 10
|
||||
logger = ts.utils.TensorboardLogger(SummaryWriter("log/dqn")) # TensorBoard is supported!
|
||||
# For other loggers: https://tianshou.readthedocs.io/en/master/tutorials/logger.html
|
||||
|
||||
# you can also try with SubprocVectorEnv
|
||||
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
|
||||
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
|
||||
|
||||
from tianshou.utils.net.common import Net
|
||||
|
||||
# you can define other networks by following the API:
|
||||
# https://tianshou.readthedocs.io/en/master/tutorials/dqn.html#build-the-network
|
||||
env = gym.make(task, render_mode="human")
|
||||
state_shape = env.observation_space.shape or env.observation_space.n
|
||||
action_shape = env.action_space.shape or env.action_space.n
|
||||
net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
|
||||
optim = torch.optim.Adam(net.parameters(), lr=lr)
|
||||
|
||||
policy = ts.policy.DQNPolicy(
|
||||
model=net,
|
||||
optim=optim,
|
||||
discount_factor=gamma,
|
||||
action_space=env.action_space,
|
||||
estimation_step=n_step,
|
||||
target_update_freq=target_freq,
|
||||
)
|
||||
train_collector = ts.data.Collector(
|
||||
policy,
|
||||
train_envs,
|
||||
ts.data.VectorReplayBuffer(buffer_size, train_num),
|
||||
exploration_noise=True,
|
||||
)
|
||||
test_collector = ts.data.Collector(
|
||||
policy,
|
||||
test_envs,
|
||||
exploration_noise=True,
|
||||
) # because DQN uses epsilon-greedy method
|
||||
|
||||
result = ts.trainer.OffpolicyTrainer(
|
||||
policy=policy,
|
||||
train_collector=train_collector,
|
||||
test_collector=test_collector,
|
||||
max_epoch=epoch,
|
||||
step_per_epoch=step_per_epoch,
|
||||
step_per_collect=step_per_collect,
|
||||
episode_per_test=test_num,
|
||||
batch_size=batch_size,
|
||||
update_per_step=1 / step_per_collect,
|
||||
train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
|
||||
test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
|
||||
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
|
||||
logger=logger,
|
||||
).run()
|
||||
print(f"Finished training in {result.timing.total_time} seconds")
|
||||
|
||||
# watch performance
|
||||
policy.eval()
|
||||
policy.set_eps(eps_test)
|
||||
collector = ts.data.Collector(policy, env, exploration_noise=True)
|
||||
collector.collect(n_episode=100, render=1 / 35)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
examples/discrete/discrete_dqn_hl.py (new file, 55 lines)
@ -0,0 +1,55 @@
|
||||
from tianshou.highlevel.config import SamplingConfig
|
||||
from tianshou.highlevel.env import (
|
||||
EnvFactoryRegistered,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
|
||||
from tianshou.highlevel.params.policy_params import DQNParams
|
||||
from tianshou.highlevel.trainer import (
|
||||
EpochStopCallbackRewardThreshold,
|
||||
EpochTestCallbackDQNSetEps,
|
||||
EpochTrainCallbackDQNSetEps,
|
||||
)
|
||||
from tianshou.utils.logging import run_main
|
||||
|
||||
|
||||
def main():
|
||||
experiment = (
|
||||
DQNExperimentBuilder(
|
||||
EnvFactoryRegistered(task="CartPole-v1", seed=0, venv_type=VectorEnvType.DUMMY),
|
||||
ExperimentConfig(
|
||||
persistence_enabled=False,
|
||||
watch=True,
|
||||
watch_render=1 / 35,
|
||||
watch_num_episodes=100,
|
||||
),
|
||||
SamplingConfig(
|
||||
num_epochs=10,
|
||||
step_per_epoch=10000,
|
||||
batch_size=64,
|
||||
num_train_envs=10,
|
||||
num_test_envs=100,
|
||||
buffer_size=20000,
|
||||
step_per_collect=10,
|
||||
update_per_step=1 / 10,
|
||||
),
|
||||
)
|
||||
.with_dqn_params(
|
||||
DQNParams(
|
||||
lr=1e-3,
|
||||
discount_factor=0.9,
|
||||
estimation_step=3,
|
||||
target_update_freq=320,
|
||||
),
|
||||
)
|
||||
.with_model_factory_default(hidden_sizes=(64, 64))
|
||||
.with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
|
||||
.with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
|
||||
.with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
|
||||
.build()
|
||||
)
|
||||
experiment.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_main(main)
|
||||
@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=4096)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
|
||||
|
||||
@ -21,7 +21,7 @@ from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=1000000)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||
|
||||
@ -17,7 +17,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "Ant-v3",
|
||||
task: str = "Ant-v4",
|
||||
buffer_size: int = 1000000,
|
||||
hidden_sizes: Sequence[int] = (256, 256),
|
||||
actor_lr: float = 1e-3,
|
||||
|
||||
@ -1,17 +1,21 @@
|
||||
import logging
|
||||
import pickle
|
||||
import warnings
|
||||
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
|
||||
from tianshou.env import VectorEnvNormObs
|
||||
from tianshou.highlevel.env import (
|
||||
ContinuousEnvironments,
|
||||
EnvFactoryRegistered,
|
||||
EnvPoolFactory,
|
||||
VectorEnvType,
|
||||
)
|
||||
from tianshou.highlevel.persistence import Persistence, PersistEvent, RestoreEvent
|
||||
from tianshou.highlevel.world import World
|
||||
|
||||
envpool_is_available = True
|
||||
try:
|
||||
import envpool
|
||||
except ImportError:
|
||||
envpool_is_available = False
|
||||
envpool = None
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -24,25 +28,11 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
|
||||
|
||||
:return: a tuple of (single env, training envs, test envs).
|
||||
"""
|
||||
if envpool is not None:
|
||||
train_envs = env = envpool.make_gymnasium(task, num_envs=num_train_envs, seed=seed)
|
||||
test_envs = envpool.make_gymnasium(task, num_envs=num_test_envs, seed=seed)
|
||||
else:
|
||||
warnings.warn(
|
||||
"Recommend using envpool (pip install envpool) "
|
||||
"to run Mujoco environments more efficiently.",
|
||||
)
|
||||
env = gym.make(task)
|
||||
train_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
|
||||
test_envs = ShmemVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
train_envs.seed(seed)
|
||||
test_envs.seed(seed)
|
||||
if obs_norm:
|
||||
# obs norm wrapper
|
||||
train_envs = VectorEnvNormObs(train_envs)
|
||||
test_envs = VectorEnvNormObs(test_envs, update_obs_rms=False)
|
||||
test_envs.set_obs_rms(train_envs.get_obs_rms())
|
||||
return env, train_envs, test_envs
|
||||
envs = MujocoEnvFactory(task, seed, obs_norm=obs_norm).create_envs(
|
||||
num_train_envs,
|
||||
num_test_envs,
|
||||
)
|
||||
return envs.env, envs.train_envs, envs.test_envs
|
||||
|
||||
|
||||
class MujocoEnvObsRmsPersistence(Persistence):
|
||||
@ -68,21 +58,25 @@ class MujocoEnvObsRmsPersistence(Persistence):
|
||||
world.envs.test_envs.set_obs_rms(obs_rms)
|
||||
|
||||
|
||||
class MujocoEnvFactory(EnvFactory):
|
||||
class MujocoEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self, task: str, seed: int, obs_norm=True):
|
||||
self.task = task
|
||||
self.seed = seed
|
||||
super().__init__(
|
||||
task=task,
|
||||
seed=seed,
|
||||
venv_type=VectorEnvType.SUBPROC_SHARED_MEM,
|
||||
envpool_factory=EnvPoolFactory() if envpool_is_available else None,
|
||||
)
|
||||
self.obs_norm = obs_norm
|
||||
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> ContinuousEnvironments:
|
||||
env, train_envs, test_envs = make_mujoco_env(
|
||||
task=self.task,
|
||||
seed=self.seed,
|
||||
num_train_envs=num_training_envs,
|
||||
num_test_envs=num_test_envs,
|
||||
obs_norm=self.obs_norm,
|
||||
)
|
||||
envs = ContinuousEnvironments(env=env, train_envs=train_envs, test_envs=test_envs)
|
||||
envs = super().create_envs(num_training_envs, num_test_envs)
|
||||
assert isinstance(envs, ContinuousEnvironments)
|
||||
|
||||
# obs norm wrapper
|
||||
if self.obs_norm:
|
||||
envs.train_envs = VectorEnvNormObs(envs.train_envs)
|
||||
envs.test_envs = VectorEnvNormObs(envs.test_envs, update_obs_rms=False)
|
||||
envs.test_envs.set_obs_rms(envs.train_envs.get_obs_rms())
|
||||
envs.set_persistence(MujocoEnvObsRmsPersistence())
|
||||
|
||||
return envs
|
||||
|
||||
@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=4096)
|
||||
parser.add_argument(
|
||||
|
||||
@ -23,7 +23,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "Ant-v3",
|
||||
task: str = "Ant-v4",
|
||||
buffer_size: int = 4096,
|
||||
hidden_sizes: Sequence[int] = (64, 64),
|
||||
lr: float = 1e-3,
|
||||
|
||||
@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=4096)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
|
||||
|
||||
@ -20,7 +20,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=1000000)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||
|
||||
@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=4096)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
|
||||
|
||||
@ -20,7 +20,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "Ant-v3",
|
||||
task: str = "Ant-v4",
|
||||
buffer_size: int = 4096,
|
||||
hidden_sizes: Sequence[int] = (64, 64),
|
||||
lr: float = 1e-3,
|
||||
|
||||
@ -20,7 +20,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=1000000)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||
|
||||
@ -21,7 +21,7 @@ from tianshou.utils.net.continuous import Actor, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=1000000)
|
||||
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
|
||||
|
||||
@ -23,7 +23,7 @@ from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--task", type=str, default="Ant-v3")
|
||||
parser.add_argument("--task", type=str, default="Ant-v4")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--buffer-size", type=int, default=4096)
|
||||
parser.add_argument(
|
||||
|
||||
@ -23,7 +23,7 @@ from tianshou.utils.logging import datetime_tag
|
||||
|
||||
def main(
|
||||
experiment_config: ExperimentConfig,
|
||||
task: str = "Ant-v3",
|
||||
task: str = "Ant-v4",
|
||||
buffer_size: int = 4096,
|
||||
hidden_sizes: Sequence[int] = (64, 64),
|
||||
lr: float = 1e-3,
|
||||
|
||||
@ -140,7 +140,9 @@ ignore = [
|
||||
# Logging statement uses f-string warning
|
||||
"G004",
|
||||
# Unnecessary `elif` after `return` statement
|
||||
"RET505"
|
||||
"RET505",
|
||||
"D106", # undocumented public nested class
|
||||
"D205", # blank line after summary (prevents summary-only docstrings, which makes no sense)
|
||||
]
|
||||
unfixable = [
|
||||
"F841", # unused variable. ruff keeps the call, but mostly we want to get rid of it all
|
||||
|
||||
@ -1,27 +1,14 @@
|
||||
import gymnasium as gym
|
||||
|
||||
from tianshou.env import DummyVectorEnv
|
||||
from tianshou.highlevel.env import (
|
||||
ContinuousEnvironments,
|
||||
DiscreteEnvironments,
|
||||
EnvFactory,
|
||||
Environments,
|
||||
EnvFactoryRegistered,
|
||||
VectorEnvType,
|
||||
)
|
||||
|
||||
|
||||
class DiscreteTestEnvFactory(EnvFactory):
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
task = "CartPole-v0"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
return DiscreteEnvironments(env, train_envs, test_envs)
|
||||
class DiscreteTestEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self):
|
||||
super().__init__(task="CartPole-v0", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
|
||||
|
||||
class ContinuousTestEnvFactory(EnvFactory):
|
||||
def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
|
||||
task = "Pendulum-v1"
|
||||
env = gym.make(task)
|
||||
train_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
|
||||
test_envs = DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
|
||||
return ContinuousEnvironments(env, train_envs, test_envs)
|
||||
class ContinuousTestEnvFactory(EnvFactoryRegistered):
|
||||
def __init__(self):
|
||||
super().__init__(task="Pendulum-v1", seed=42, venv_type=VectorEnvType.DUMMY)
|
||||
|
||||
@ -97,7 +97,6 @@ class Collector:
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if isinstance(env, gym.Env) and not hasattr(env, "__len__"):
|
||||
warnings.warn("Single environment detected, wrap to DummyVectorEnv.")
|
||||
self.env = DummyVectorEnv([lambda: env])
|
||||
else:
|
||||
self.env = env # type: ignore

@@ -163,17 +163,19 @@ class OnPolicyAgentFactory(AgentFactory, ABC):
        callbacks = self.trainer_callbacks
        context = TrainingContext(world.policy, world.envs, world.logger)
        train_fn = (
            callbacks.epoch_callback_train.get_trainer_fn(context)
            if callbacks.epoch_callback_train
            callbacks.epoch_train_callback.get_trainer_fn(context)
            if callbacks.epoch_train_callback
            else None
        )
        test_fn = (
            callbacks.epoch_callback_test.get_trainer_fn(context)
            if callbacks.epoch_callback_test
            callbacks.epoch_test_callback.get_trainer_fn(context)
            if callbacks.epoch_test_callback
            else None
        )
        stop_fn = (
            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
            callbacks.epoch_stop_callback.get_trainer_fn(context)
            if callbacks.epoch_stop_callback
            else None
        )
        return OnpolicyTrainer(
            policy=world.policy,

@@ -205,17 +207,19 @@ class OffPolicyAgentFactory(AgentFactory, ABC):
        callbacks = self.trainer_callbacks
        context = TrainingContext(world.policy, world.envs, world.logger)
        train_fn = (
            callbacks.epoch_callback_train.get_trainer_fn(context)
            if callbacks.epoch_callback_train
            callbacks.epoch_train_callback.get_trainer_fn(context)
            if callbacks.epoch_train_callback
            else None
        )
        test_fn = (
            callbacks.epoch_callback_test.get_trainer_fn(context)
            if callbacks.epoch_callback_test
            callbacks.epoch_test_callback.get_trainer_fn(context)
            if callbacks.epoch_test_callback
            else None
        )
        stop_fn = (
            callbacks.stop_callback.get_trainer_fn(context) if callbacks.stop_callback else None
            callbacks.epoch_stop_callback.get_trainer_fn(context)
            if callbacks.epoch_stop_callback
            else None
        )
        return OffpolicyTrainer(
            policy=world.policy,

@@ -1,9 +1,12 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from enum import Enum
from typing import Any, TypeAlias, cast

import gymnasium as gym
import gymnasium.spaces
from gymnasium import Env

from tianshou.env import (
    BaseVectorEnv,

@@ -18,6 +21,8 @@ from tianshou.utils.string import ToStringMixin

TObservationShape: TypeAlias = int | Sequence[int]

log = logging.getLogger(__name__)


class EnvType(Enum):
    """Enumeration of environment types."""

@@ -39,6 +44,23 @@ class EnvType(Enum):
        if not self.is_discrete():
            raise AssertionError(f"{requiring_entity} requires discrete environments")

    @staticmethod
    def from_env(env: Env) -> "EnvType":
        if isinstance(env.action_space, gymnasium.spaces.Discrete):
            return EnvType.DISCRETE
        elif isinstance(env.action_space, gymnasium.spaces.Box):
            return EnvType.CONTINUOUS
        else:
            raise Exception(f"Unsupported environment type with action space {env.action_space}")


class EnvMode(Enum):
    """Indicates the purpose for which an environment is created."""

    TRAIN = "train"
    TEST = "test"
    WATCH = "watch"


class VectorEnvType(Enum):
    DUMMY = "dummy"

@@ -65,7 +87,7 @@ class VectorEnvType(Enum):


class Environments(ToStringMixin, ABC):
    """Represents (vectorized) environments."""
    """Represents (vectorized) environments for a learning process."""

    def __init__(self, env: gym.Env, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv):
        self.env = env

@@ -75,12 +97,11 @@ class Environments(ToStringMixin, ABC):

    @staticmethod
    def from_factory_and_type(
        factory_fn: Callable[[], gym.Env],
        factory_fn: Callable[[EnvMode], gym.Env],
        env_type: EnvType,
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        test_factory_fn: Callable[[], gym.Env] | None = None,
    ) -> "Environments":
        """Creates a suitable subtype instance from a factory function that creates a single instance, given the type of environment (continuous/discrete).

@@ -89,15 +110,11 @@ class Environments(ToStringMixin, ABC):
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param test_factory_fn: the factory to use for the creation of test environment instances;
            if None, use `factory_fn` for all environments (train and test)
        :return: the instance
        """
        if test_factory_fn is None:
            test_factory_fn = factory_fn
        train_envs = venv_type.create_venv([factory_fn] * num_training_envs)
        test_envs = venv_type.create_venv([test_factory_fn] * num_test_envs)
        env = factory_fn()
        train_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TRAIN)] * num_training_envs)
        test_envs = venv_type.create_venv([lambda: factory_fn(EnvMode.TEST)] * num_test_envs)
        env = factory_fn(EnvMode.TRAIN)
        match env_type:
            case EnvType.CONTINUOUS:
                return ContinuousEnvironments(env, train_envs, test_envs)

@@ -153,11 +170,10 @@ class ContinuousEnvironments(Environments):

    @staticmethod
    def from_factory(
        factory_fn: Callable[[], gym.Env],
        factory_fn: Callable[[EnvMode], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        test_factory_fn: Callable[[], gym.Env] | None = None,
    ) -> "ContinuousEnvironments":
        """Creates an instance from a factory function that creates a single instance.

@@ -165,8 +181,6 @@ class ContinuousEnvironments(Environments):
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param test_factory_fn: the factory to use for the creation of test environment instances;
            if None, use `factory_fn` for all environments (train and test)
        :return: the instance
        """
        return cast(

@@ -177,7 +191,6 @@ class ContinuousEnvironments(Environments):
                venv_type,
                num_training_envs,
                num_test_envs,
                test_factory_fn=test_factory_fn,
            ),
        )
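For illustration, a sketch (assumed usage) of the reworked factory signature, whose callable now receives the `EnvMode`:

```python
import gymnasium as gym

from tianshou.highlevel.env import ContinuousEnvironments, EnvMode, VectorEnvType

envs = ContinuousEnvironments.from_factory(
    # the mode can be used to vary creation parameters between training and test instances
    lambda mode: gym.make("Pendulum-v1", render_mode="rgb_array" if mode == EnvMode.TEST else None),
    VectorEnvType.DUMMY,
    num_training_envs=4,
    num_test_envs=2,
)
```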

@@ -222,11 +235,10 @@ class DiscreteEnvironments(Environments):

    @staticmethod
    def from_factory(
        factory_fn: Callable[[], gym.Env],
        factory_fn: Callable[[EnvMode], gym.Env],
        venv_type: VectorEnvType,
        num_training_envs: int,
        num_test_envs: int,
        test_factory_fn: Callable[[], gym.Env] | None = None,
    ) -> "DiscreteEnvironments":
        """Creates an instance from a factory function that creates a single instance.

@@ -234,19 +246,16 @@ class DiscreteEnvironments(Environments):
        :param venv_type: the vector environment type to use for parallelization
        :param num_training_envs: the number of training environments to create
        :param num_test_envs: the number of test environments to create
        :param test_factory_fn: the factory to use for the creation of test environment instances;
            if None, use `factory_fn` for all environments (train and test)
        :return: the instance
        """
        return cast(
            DiscreteEnvironments,
            Environments.from_factory_and_type(
                factory_fn,
                EnvType.CONTINUOUS,
                EnvType.DISCRETE,
                venv_type,
                num_training_envs,
                num_test_envs,
                test_factory_fn=test_factory_fn,
            ),
        )

@@ -260,7 +269,153 @@ class DiscreteEnvironments(Environments):
        return EnvType.DISCRETE


class EnvPoolFactory:
    """A factory for the creation of envpool-based vectorized environments, which can be used in conjunction
    with :class:`EnvFactoryRegistered`.
    """

    def _transform_task(self, task: str) -> str:
        return task

    def _transform_kwargs(self, kwargs: dict, mode: EnvMode) -> dict:
        """Transforms gymnasium keyword arguments to be envpool-compatible.

        :param kwargs: keyword arguments that would normally be passed to `gymnasium.make`.
        :param mode: the environment mode
        :return: the transformed keyword arguments
        """
        kwargs = dict(kwargs)
        if "render_mode" in kwargs:
            del kwargs["render_mode"]
        return kwargs

    def create_venv(
        self,
        task: str,
        num_envs: int,
        mode: EnvMode,
        seed: int,
        kwargs: dict,
    ) -> BaseVectorEnv:
        import envpool

        envpool_task = self._transform_task(task)
        envpool_kwargs = self._transform_kwargs(kwargs, mode)
        return envpool.make_gymnasium(
            envpool_task,
            num_envs=num_envs,
            seed=seed,
            **envpool_kwargs,
        )
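A sketch of how `EnvPoolFactory` could be specialized via its hooks; the task-name mapping shown is purely hypothetical:

```python
class MyEnvPoolFactory(EnvPoolFactory):  # hypothetical subclass
    def _transform_task(self, task: str) -> str:
        # hypothetical: envpool registers the task under a different version tag
        return task.replace("-v4", "-v5")
```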


class EnvFactory(ToStringMixin, ABC):
    """Main interface for the creation of environments (in various forms)."""

    def __init__(self, venv_type: VectorEnvType):
        """:param venv_type: the type of vectorized environment to use"""
        self.venv_type = venv_type

    @abstractmethod
    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
    def create_env(self, mode: EnvMode) -> Env:
        pass

    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
        """Create vectorized environments.

        :param num_envs: the number of environments
        :param mode: the mode for which to create
        :return: the vectorized environments
        """
        return self.venv_type.create_venv([lambda: self.create_env(mode)] * num_envs)

    def create_envs(self, num_training_envs: int, num_test_envs: int) -> Environments:
        """Create environments for learning.

        :param num_training_envs: the number of training environments
        :param num_test_envs: the number of test environments
        :return: the environments
        """
        env = self.create_env(EnvMode.TRAIN)
        train_envs = self.create_venv(num_training_envs, EnvMode.TRAIN)
        test_envs = self.create_venv(num_test_envs, EnvMode.TEST)
        match EnvType.from_env(env):
            case EnvType.DISCRETE:
                return DiscreteEnvironments(env, train_envs, test_envs)
            case EnvType.CONTINUOUS:
                return ContinuousEnvironments(env, train_envs, test_envs)
            case _:
                raise ValueError
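For illustration, a hypothetical custom factory using the reworked interface: only `create_env(mode)` has to be implemented, while vectorization and the train/test split are inherited:

```python
import gymnasium as gym
from gymnasium import Env

from tianshou.highlevel.env import EnvFactory, EnvMode, VectorEnvType


class MyEnvFactory(EnvFactory):  # hypothetical subclass
    def __init__(self) -> None:
        super().__init__(VectorEnvType.DUMMY)

    def create_env(self, mode: EnvMode) -> Env:
        return gym.make("Pendulum-v1")


envs = MyEnvFactory().create_envs(num_training_envs=4, num_test_envs=2)
```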


class EnvFactoryRegistered(EnvFactory):
    """Factory for environments that are registered with gymnasium and thus can be created via `gymnasium.make`
    (or via `envpool.make_gymnasium`).
    """

    def __init__(
        self,
        *,
        task: str,
        seed: int,
        venv_type: VectorEnvType,
        envpool_factory: EnvPoolFactory | None = None,
        render_mode_train: str | None = None,
        render_mode_test: str | None = None,
        render_mode_watch: str = "human",
        **make_kwargs: Any,
    ):
        """:param task: the gymnasium task/environment identifier
        :param seed: the random seed
        :param venv_type: the type of vectorized environment to use (if `envpool_factory` is not specified)
        :param envpool_factory: the factory to use for vectorized environment creation based on envpool; envpool must be installed.
        :param render_mode_train: the render mode to use for training environments
        :param render_mode_test: the render mode to use for test environments
        :param render_mode_watch: the render mode to use for environments that are used to watch agent performance
        :param make_kwargs: additional keyword arguments to pass on to `gymnasium.make`.
            If envpool is used, the gymnasium parameters will be appropriately translated for use with
            `envpool.make_gymnasium`.
        """
        super().__init__(venv_type)
        self.task = task
        self.envpool_factory = envpool_factory
        self.seed = seed
        self.render_modes = {
            EnvMode.TRAIN: render_mode_train,
            EnvMode.TEST: render_mode_test,
            EnvMode.WATCH: render_mode_watch,
        }
        self.make_kwargs = make_kwargs

    def _create_kwargs(self, mode: EnvMode) -> dict:
        """Adapts the keyword arguments for the given mode.

        :param mode: the mode
        :return: adapted keyword arguments
        """
        kwargs = dict(self.make_kwargs)
        kwargs["render_mode"] = self.render_modes.get(mode)
        return kwargs

    def create_env(self, mode: EnvMode) -> Env:
        """Creates a single environment for the given mode.

        :param mode: the mode
        :return: an environment
        """
        kwargs = self._create_kwargs(mode)
        return gymnasium.make(self.task, **kwargs)

    def create_venv(self, num_envs: int, mode: EnvMode) -> BaseVectorEnv:
        if self.envpool_factory is not None:
            return self.envpool_factory.create_venv(
                self.task,
                num_envs,
                mode,
                self.seed,
                self._create_kwargs(mode),
            )
        else:
            venv = super().create_venv(num_envs, mode)
            venv.seed(self.seed)
            return venv
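A usage sketch (assumed): for environments registered with gymnasium, no subclass is needed at all:

```python
from tianshou.highlevel.env import EnvFactoryRegistered, EnvMode, VectorEnvType

factory = EnvFactoryRegistered(
    task="Ant-v4",
    seed=42,
    venv_type=VectorEnvType.DUMMY,
    render_mode_watch="human",
)
envs = factory.create_envs(num_training_envs=10, num_test_envs=10)
watch_env = factory.create_env(EnvMode.WATCH)  # rendered for watching the agent
```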

@@ -26,7 +26,7 @@ from tianshou.highlevel.agent import (
    TRPOAgentFactory,
)
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.env import EnvFactory
from tianshou.highlevel.env import EnvFactory, EnvMode
from tianshou.highlevel.logger import LoggerFactory, LoggerFactoryDefault, TLogger
from tianshou.highlevel.module.actor import (
    ActorFactory,

@@ -70,10 +70,10 @@ from tianshou.highlevel.persistence import (
    PolicyPersistence,
)
from tianshou.highlevel.trainer import (
    EpochStopCallback,
    EpochTestCallback,
    EpochTrainCallback,
    TrainerCallbacks,
    TrainerEpochCallbackTest,
    TrainerEpochCallbackTrain,
    TrainerStopCallback,
)
from tianshou.highlevel.world import World
from tianshou.policy import BasePolicy

@@ -99,7 +99,7 @@ class ExperimentConfig:
    """Whether to perform training"""
    watch: bool = True
    """Whether to watch agent performance (after training)"""
    watch_num_episodes = 10
    watch_num_episodes: int = 10
    """Number of episodes for which to watch performance (if `watch` is enabled)"""
    watch_render: float = 0.0
    """Milliseconds between rendered frames when watching agent performance (if `watch` is enabled)"""

@@ -293,7 +293,7 @@ class Experiment(ToStringMixin):
            self._watch_agent(
                self.config.watch_num_episodes,
                policy,
                test_collector,
                self.env_factory,
                self.config.watch_render,
            )

@@ -303,15 +303,18 @@ class Experiment(ToStringMixin):
    def _watch_agent(
        num_episodes: int,
        policy: BasePolicy,
        test_collector: Collector,
        env_factory: EnvFactory,
        render: float,
    ) -> None:
        policy.eval()
        test_collector.reset()
        result = test_collector.collect(n_episode=num_episodes, render=render)
        env = env_factory.create_env(EnvMode.WATCH)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=num_episodes, render=render)
        assert result.returns_stat is not None  # for mypy
        assert result.lens_stat is not None  # for mypy
        print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
        log.info(
            f"Watched episodes: mean reward={result.returns_stat.mean}, mean episode length={result.lens_stat.mean}",
        )


class ExperimentBuilder:

@@ -380,25 +383,25 @@ class ExperimentBuilder:
        self._optim_factory = OptimizerFactoryAdam(betas=betas, eps=eps, weight_decay=weight_decay)
        return self

    def with_trainer_epoch_callback_train(self, callback: TrainerEpochCallbackTrain) -> Self:
    def with_epoch_train_callback(self, callback: EpochTrainCallback) -> Self:
        """Allows defining a callback function which is called at the beginning of every epoch during training.

        :param callback: the callback
        :return: the builder
        """
        self._trainer_callbacks.epoch_callback_train = callback
        self._trainer_callbacks.epoch_train_callback = callback
        return self

    def with_trainer_epoch_callback_test(self, callback: TrainerEpochCallbackTest) -> Self:
    def with_epoch_test_callback(self, callback: EpochTestCallback) -> Self:
        """Allows defining a callback function which is called at the beginning of testing in each epoch.

        :param callback: the callback
        :return: the builder
        """
        self._trainer_callbacks.epoch_callback_test = callback
        self._trainer_callbacks.epoch_test_callback = callback
        return self

    def with_trainer_stop_callback(self, callback: TrainerStopCallback) -> Self:
    def with_epoch_stop_callback(self, callback: EpochStopCallback) -> Self:
        """Allows defining a callback that decides whether training shall stop early.

        The callback receives the undiscounted returns of the testing result.

@@ -406,7 +409,7 @@ class ExperimentBuilder:
        :param callback: the callback
        :return: the builder
        """
        self._trainer_callbacks.stop_callback = callback
        self._trainer_callbacks.epoch_stop_callback = callback
        return self
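For illustration, a sketch of chaining the renamed builder hooks; the `env_factory`, `experiment_config` and `sampling_config` objects are assumed to have been created beforehand, and the DQN-specific callbacks are those defined elsewhere in this change set:

```python
from tianshou.highlevel.trainer import (
    EpochStopCallbackRewardThreshold,
    EpochTestCallbackDQNSetEps,
    EpochTrainCallbackDQNEpsLinearDecay,
)

builder = (
    DQNExperimentBuilder(env_factory, experiment_config, sampling_config)
    .with_epoch_train_callback(EpochTrainCallbackDQNEpsLinearDecay(0.5, 0.05))
    .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
)
experiment = builder.build()  # assumption: the builder exposes a build() method
```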

    @abstractmethod

@@ -903,7 +906,7 @@ class DQNExperimentBuilder(
        super().__init__(env_factory, experiment_config, sampling_config)
        self._params: DQNParams = DQNParams()
        self._model_factory: IntermediateModuleFactory = IntermediateModuleFactoryFromActorFactory(
            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
            ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
        )

    def with_dqn_params(self, params: DQNParams) -> Self:

@@ -911,9 +914,34 @@ class DQNExperimentBuilder(
        return self

    def with_model_factory(self, module_factory: IntermediateModuleFactory) -> Self:
        """:param module_factory: factory for a module which maps environment observations to a vector of Q-values (one for each action)
        :return: the builder
        """
        self._model_factory = module_factory
        return self

    def with_model_factory_default(
        self,
        hidden_sizes: Sequence[int],
        hidden_activation: ModuleType = torch.nn.ReLU,
    ) -> Self:
        """Allows configuring the default factory for the model of the Q function, which maps environment observations to a vector of
        Q-values (one for each action). The default model is a multi-layer perceptron.

        :param hidden_sizes: the sequence of dimensions used for hidden layers
        :param hidden_activation: the activation function to use for hidden layers (not used for the output layer)
        :return: the builder
        """
        self._model_factory = IntermediateModuleFactoryFromActorFactory(
            ActorFactoryDefault(
                ContinuousActorType.UNSUPPORTED,
                hidden_sizes=hidden_sizes,
                hidden_activation=hidden_activation,
                discrete_softmax=False,
            ),
        )
        return self
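A brief sketch (assumed usage) of the new default-model configuration, continuing the builder example above:

```python
# configure the default MLP Q-network; hidden_activation defaults to torch.nn.ReLU
builder = builder.with_model_factory_default(hidden_sizes=(256, 256))
```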

    def _create_agent_factory(self) -> AgentFactory:
        return DQNAgentFactory(
            self._params,

@@ -934,7 +962,7 @@ class IQNExperimentBuilder(ExperimentBuilder):
        self._params: IQNParams = IQNParams()
        self._preprocess_network_factory: IntermediateModuleFactory = (
            IntermediateModuleFactoryFromActorFactory(
                ActorFactoryDefault(ContinuousActorType.UNSUPPORTED),
                ActorFactoryDefault(ContinuousActorType.UNSUPPORTED, discrete_softmax=False),
            )
        )

@@ -1,14 +1,16 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from typing import TypeVar
from typing import TypeVar, cast

from tianshou.highlevel.env import Environments
from tianshou.highlevel.logger import TLogger
from tianshou.policy import BasePolicy
from tianshou.policy import BasePolicy, DQNPolicy
from tianshou.utils.string import ToStringMixin

TPolicy = TypeVar("TPolicy", bound=BasePolicy)
log = logging.getLogger(__name__)


class TrainingContext:

@@ -18,8 +20,10 @@ class TrainingContext:
        self.logger = logger


class TrainerEpochCallbackTrain(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch."""
class EpochTrainCallback(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch, i.e. prior to the data collection phase
    of each epoch.
    """

    @abstractmethod
    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:

@@ -32,8 +36,8 @@ class TrainerEpochCallbackTrain(ToStringMixin, ABC):
        return fn


class TrainerEpochCallbackTest(ToStringMixin, ABC):
    """Callback which is called at the beginning of each epoch."""
class EpochTestCallback(ToStringMixin, ABC):
    """Callback which is called at the beginning of the test phase of each epoch."""

    @abstractmethod
    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:

@@ -46,8 +50,10 @@ class TrainerEpochCallbackTest(ToStringMixin, ABC):
        return fn


class TrainerStopCallback(ToStringMixin, ABC):
    """Callback indicating whether training should stop."""
class EpochStopCallback(ToStringMixin, ABC):
    """Callback which is called after the test phase of each epoch in order to determine
    whether training should stop early.
    """

    @abstractmethod
    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:

@@ -69,6 +75,77 @@ class TrainerStopCallback(ToStringMixin, ABC):
class TrainerCallbacks:
    """Container for callbacks used during training."""

    epoch_callback_train: TrainerEpochCallbackTrain | None = None
    epoch_callback_test: TrainerEpochCallbackTest | None = None
    stop_callback: TrainerStopCallback | None = None
    epoch_train_callback: EpochTrainCallback | None = None
    epoch_test_callback: EpochTestCallback | None = None
    epoch_stop_callback: EpochStopCallback | None = None


class EpochTrainCallbackDQNSetEps(EpochTrainCallback):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch.
    """

    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy = cast(DQNPolicy, context.policy)
        policy.set_eps(self.eps_test)


class EpochTrainCallbackDQNEpsLinearDecay(EpochTrainCallback):
    """Sets the epsilon value for DQN-based policies at the beginning of the training
    stage in each epoch, using a linear decay in the first `decay_steps` steps.
    """

    def __init__(self, eps_train: float, eps_train_final: float, decay_steps: int = 1000000):
        self.eps_train = eps_train
        self.eps_train_final = eps_train_final
        self.decay_steps = decay_steps

    def callback(self, epoch: int, env_step: int, context: TrainingContext) -> None:
        policy = cast(DQNPolicy, context.policy)
        logger = context.logger
        if env_step <= self.decay_steps:
            eps = self.eps_train - env_step / self.decay_steps * (
                self.eps_train - self.eps_train_final
            )
        else:
            eps = self.eps_train_final
        policy.set_eps(eps)
        logger.write("train/env_step", env_step, {"train/eps": eps})


class EpochTestCallbackDQNSetEps(EpochTestCallback):
    """Sets the epsilon value for DQN-based policies at the beginning of the test
    stage in each epoch.
    """

    def __init__(self, eps_test: float):
        self.eps_test = eps_test

    def callback(self, epoch: int, env_step: int | None, context: TrainingContext) -> None:
        policy = cast(DQNPolicy, context.policy)
        policy.set_eps(self.eps_test)


class EpochStopCallbackRewardThreshold(EpochStopCallback):
    """Stops training once the mean rewards exceed the given reward threshold or the threshold that
    is specified in the gymnasium environment (i.e. `env.spec.reward_threshold`).
    """

    def __init__(self, threshold: float | None = None):
        """:param threshold: the reward threshold beyond which to stop training.
        If it is None, use threshold given by the environment, i.e. `env.spec.reward_threshold`.
        """
        self.threshold = threshold

    def should_stop(self, mean_rewards: float, context: TrainingContext) -> bool:
        threshold = self.threshold
        if threshold is None:
            threshold = context.envs.env.spec.reward_threshold  # type: ignore
        assert threshold is not None
        is_reached = mean_rewards >= threshold
        if is_reached:
            log.info(f"Reward threshold ({threshold}) exceeded")
        return is_reached
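A usage sketch (assumed): if no threshold is given, the environment's own specification is consulted; for example, CartPole-v0 defines `reward_threshold == 195.0`:

```python
stop_cb = EpochStopCallbackRewardThreshold()            # falls back to env.spec.reward_threshold
explicit_cb = EpochStopCallbackRewardThreshold(195.0)   # or stop at an explicit value
```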

@@ -7,7 +7,7 @@ from typing import Any

import numpy as np

VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray
VALID_LOG_VALS_TYPE = int | Number | np.number | np.ndarray | float
VALID_LOG_VALS = typing.get_args(
    VALID_LOG_VALS_TYPE,
)  # I know it's stupid, but we can't use Union type in isinstance

@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import Any, TypeAlias, no_type_check
from typing import Any, TypeAlias, TypeVar, no_type_check

import numpy as np
import torch

@@ -13,6 +13,7 @@ ModuleType = type[nn.Module]
ArgsType = tuple[Any, ...] | dict[Any, Any] | Sequence[tuple[Any, ...]] | Sequence[dict[Any, Any]]
TActionShape: TypeAlias = Sequence[int] | int
TLinearLayer: TypeAlias = Callable[[int, int], nn.Module]
T = TypeVar("T")


def miniblock(

@@ -608,3 +609,39 @@ class BaseActor(nn.Module, ABC):
    @abstractmethod
    def get_output_dim(self) -> int:
        pass


def getattr_with_matching_alt_value(obj: Any, attr_name: str, alt_value: T | None) -> T:
    """Gets the given attribute from the given object or takes the alternative value if it is not present.
    If both are present, they are required to match.

    :param obj: the object from which to obtain the attribute value
    :param attr_name: the attribute name
    :param alt_value: the alternative value to use if the attribute is not present; it must not be None
        in that case
    :return: the value
    """
    v = getattr(obj, attr_name)
    if v is not None:
        if alt_value is not None and v != alt_value:
            raise ValueError(
                f"Attribute '{attr_name}' of {obj} is defined ({v}) but does not match alt. value ({alt_value})",
            )
        return v
    else:
        if alt_value is None:
            raise ValueError(
                f"Attribute '{attr_name}' of {obj} is not defined and no fallback given",
            )
        return alt_value


def get_output_dim(module: nn.Module, alt_value: int | None) -> int:
    """Retrieves the value of the `output_dim` attribute of the given module or uses the given alternative value if the attribute is not present.
    If both are present, they must match.

    :param module: the module
    :param alt_value: the alternative value
    :return: the value
    """
    return getattr_with_matching_alt_value(module, "output_dim", alt_value)
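For illustration, a minimal sketch (assumed usage) of the new helper that replaces the scattered `getattr(..., "output_dim", ...)` calls; the preprocessing network here is hypothetical:

```python
from torch import nn


class Preprocess(nn.Module):  # hypothetical preprocessing network
    def __init__(self) -> None:
        super().__init__()
        self.output_dim = 128


assert get_output_dim(Preprocess(), None) == 128  # taken from the attribute
assert get_output_dim(Preprocess(), 128) == 128   # both present and matching
# get_output_dim(Preprocess(), 64) would raise a ValueError due to the mismatch
```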

@@ -1,12 +1,18 @@
import warnings
from collections.abc import Sequence
from typing import Any, cast
from typing import Any

import numpy as np
import torch
from torch import nn

from tianshou.utils.net.common import MLP, BaseActor, TActionShape, TLinearLayer
from tianshou.utils.net.common import (
    MLP,
    BaseActor,
    TActionShape,
    TLinearLayer,
    get_output_dim,
)

SIGMA_MIN = -20
SIGMA_MAX = 2

@@ -50,8 +56,7 @@ class Actor(BaseActor):
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = cast(int, input_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,
            self.output_dim,

@@ -118,9 +123,9 @@ class Critic(nn.Module):
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = 1
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,  # type: ignore
            input_dim,
            1,
            hidden_sizes,
            device=self.device,

@@ -199,12 +204,12 @@ class ActorProb(BaseActor):
        self.preprocess = preprocess_net
        self.device = device
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)  # type: ignore
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.mu = MLP(input_dim, self.output_dim, hidden_sizes, device=self.device)
        self._c_sigma = conditioned_sigma
        if conditioned_sigma:
            self.sigma = MLP(
                input_dim,  # type: ignore
                input_dim,
                self.output_dim,
                hidden_sizes,
                device=self.device,

@@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Any, cast
from typing import Any

import numpy as np
import torch

@@ -7,7 +7,7 @@ import torch.nn.functional as F
from torch import nn

from tianshou.data import Batch, to_torch
from tianshou.utils.net.common import MLP, BaseActor, TActionShape
from tianshou.utils.net.common import MLP, BaseActor, TActionShape, get_output_dim


class Actor(BaseActor):

@@ -51,8 +51,7 @@ class Actor(BaseActor):
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = int(np.prod(action_shape))
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        input_dim = cast(int, input_dim)
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(
            input_dim,
            self.output_dim,

@@ -118,8 +117,8 @@ class Critic(nn.Module):
        self.device = device
        self.preprocess = preprocess_net
        self.output_dim = last_size
        input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)  # type: ignore
        input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.last = MLP(input_dim, last_size, hidden_sizes, device=self.device)

    def forward(self, obs: np.ndarray | torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Mapping: s -> V(s)."""

@@ -197,8 +196,8 @@ class ImplicitQuantileNetwork(Critic):
    ) -> None:
        last_size = int(np.prod(action_shape))
        super().__init__(preprocess_net, hidden_sizes, last_size, preprocess_net_output_dim, device)
        self.input_dim = getattr(preprocess_net, "output_dim", preprocess_net_output_dim)
        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(  # type: ignore
        self.input_dim = get_output_dim(preprocess_net, preprocess_net_output_dim)
        self.embed_model = CosineEmbeddingNetwork(num_cosines, self.input_dim).to(
            device,
        )