{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "_UaXOSRjDUF9"
},
"source": [
"# Experiment\n",
"Finally, we can assemble building blocks that we have came across in previous tutorials to conduct our first DRL experiment. In this experiment, we will use [PPO](https://arxiv.org/abs/1707.06347) algorithm to solve the classic CartPole task in Gym."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2QRbCJvDHNAd"
},
"source": [
"## Experiment\n",
"To conduct this experiment, we need the following building blocks.\n",
"\n",
"\n",
"* Two vectorized environments, one for training and one for evaluation\n",
"* A PPO agent\n",
"* A replay buffer to store transition data\n",
"* Two collectors to manage the data collecting process, one for training and one for evaluation\n",
"* A trainer to manage the training loop\n",
"\n",
"<div align=center>\n",
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\">\n",
"\n",
"</div>\n",
"\n",
"Let us do this step by step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Hh4E6i0Hj0I"
},
"source": [
"## Preparation\n",
"Firstly, install Tianshou if you haven't installed it before."
]
},
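{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"hide-cell",
"remove-output"
]
},
"outputs": [],
"source": [
"# Optional install step (a sketch assuming a pip-based environment).\n",
"# Uncomment the line below if Tianshou is not installed yet.\n",
"# %pip install tianshou"
]
},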
{
"cell_type": "markdown",
"metadata": {
"id": "7E4EhiBeHxD5"
},
"source": [
"Import libraries we might need later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"id": "ao9gWJDiHgG-",
"tags": [
"hide-cell",
"remove-output"
]
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import ActorCritic, Net\n",
"from tianshou.utils.net.discrete import Actor, Critic\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QnRg5y7THRYw"
},
"source": [
"## Environment"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YZERKCGtH8W1"
},
"source": [
"We create two vectorized environments both for training and testing. Since the execution time of CartPole is extremely short, there is no need to use multi-process wrappers and we simply use DummyVectorEnv."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mpuj5PFnDKVS"
},
"outputs": [],
"source": [
"env = gym.make(\"CartPole-v1\")\n",
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n",
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BJtt_Ya8DTAh"
},
"source": [
"## Policy\n",
"Next we need to initialize our PPO policy. PPO is an actor-critic-style on-policy algorithm, so we have to define the actor and the critic in PPO first.\n",
"\n",
"The actor is a neural network that shares the same network head with the critic. Both networks' input is the environment observation. The output of the actor is the action and the output of the critic is a single value, representing the value of the current policy.\n",
"\n",
"Luckily, Tianshou already provides basic network modules that we can use in this experiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_Vy8uPWXP4m_"
},
"outputs": [],
"source": [
"# net is the shared head of the actor and the critic\n",
"assert env.observation_space.shape is not None # for mypy\n",
"assert isinstance(env.action_space, gym.spaces.Discrete) # for mypy\n",
"net = Net(state_shape=env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
"actor = Actor(preprocess_net=net, action_shape=env.action_space.n, device=device).to(device)\n",
"critic = Critic(preprocess_net=net, device=device).to(device)\n",
"actor_critic = ActorCritic(actor=actor, critic=critic)\n",
"\n",
"# optimizer of the actor and the critic\n",
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Lh2-hwE5Dn9I"
},
"source": [
"Once we have defined the actor, the critic and the optimizer, we can use them to construct our PPO agent. CartPole is a discrete action space problem, so the distribution of our action space can be a categorical distribution."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OiJ2GkT0Qnbr"
},
"outputs": [],
"source": [
"dist = torch.distributions.Categorical\n",
"policy: PPOPolicy = PPOPolicy(\n",
" actor=actor,\n",
" critic=critic,\n",
" optim=optim,\n",
" dist_fn=dist,\n",
" action_space=env.action_space,\n",
" deterministic_eval=True,\n",
" action_scaling=False,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "okxfj6IEQ-r8"
},
"source": [
"`deterministic_eval=True` means that we want to sample actions during training but we would like to always use the best action in evaluation. No randomness included."
]
},
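{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch (plain PyTorch, not Tianshou internals): with a categorical\n",
"# distribution, sampling draws an action according to the action probabilities\n",
"# (used during training), while deterministic evaluation picks the most\n",
"# probable action via argmax (used during evaluation).\n",
"logits = torch.tensor([1.0, 2.0, 0.5])\n",
"action_dist = torch.distributions.Categorical(logits=logits)\n",
"sampled_action = action_dist.sample()  # stochastic\n",
"greedy_action = logits.argmax()  # deterministic\n",
"print(sampled_action.item(), greedy_action.item())"
]
},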
{
"cell_type": "markdown",
"metadata": {
"id": "n5XAAbuBZarO"
},
"source": [
"## Collector\n",
"We can set up the collectors now. Train collector is used to collect and store training data, so an additional replay buffer has to be passed in."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ezwz0qerZhQM"
},
"outputs": [],
"source": [
"train_collector = Collector(\n",
" policy=policy,\n",
" env=train_envs,\n",
" buffer=VectorReplayBuffer(20000, len(train_envs)),\n",
")\n",
"test_collector = Collector(policy=policy, env=test_envs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZaoPxOd2hm0b"
},
"source": [
"We use `VectorReplayBuffer` here because it's more efficient to collaborate with vectorized environments, you can simply consider `VectorReplayBuffer` as a a list of ordinary replay buffers."
]
},
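{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Conceptual sketch only (plain Python, not the real Tianshou implementation):\n",
"# a VectorReplayBuffer with total size 20000 shared by 20 environments behaves\n",
"# like 20 per-environment sub-buffers of size 1000 each.\n",
"from collections import deque\n",
"\n",
"num_envs, total_size = 20, 20000\n",
"sub_buffers = [deque(maxlen=total_size // num_envs) for _ in range(num_envs)]\n",
"# Each environment appends its transitions to its own sub-buffer:\n",
"sub_buffers[0].append({\"obs\": None, \"act\": 0, \"rew\": 1.0, \"terminated\": False})\n",
"print(len(sub_buffers), len(sub_buffers[0]))"
]
},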
{
"cell_type": "markdown",
"metadata": {
"id": "qBoE9pLUiC-8"
},
"source": [
"## Trainer\n",
"Finally, we can use the trainer to help us set up the training loop."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"editable": true,
"id": "i45EDnpxQ8gu",
"outputId": "b1666b88-0bfa-4340-868e-58611872d988",
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
"result = OnpolicyTrainer(\n",
" policy=policy,\n",
" train_collector=train_collector,\n",
" test_collector=test_collector,\n",
" max_epoch=10,\n",
" step_per_epoch=50000,\n",
" repeat_per_collect=10,\n",
" episode_per_test=10,\n",
" batch_size=256,\n",
" step_per_collect=2000,\n",
" stop_fn=lambda mean_reward: mean_reward >= 195,\n",
").run()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ckgINHE2iTFR"
},
"source": [
"## Results\n",
"Print the training result."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tJCPgmiyiaaX",
"outputId": "40123ae3-3365-4782-9563-46c43812f10f",
"tags": []
},
"outputs": [],
"source": [
"result.pprint_asdict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "A-MJ9avMibxN"
},
"source": [
"We can also test our trained agent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mnMANFcciiAQ",
"outputId": "6febcc1e-7265-4a75-c9dd-34e29a3e5d21"
},
"outputs": [],
"source": [
"# Let's watch its performance!\n",
"policy.eval()\n",
"result = test_collector.collect(n_episode=1, render=False)\n",
"print(f\"Final episode reward: {result.returns.mean()}, length: {result.lens.mean()}\")"
]
}
],
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}