{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"editable": true,
"id": "r7aE6Rq3cAEE",
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"# Overview\n",
"In this tutorial, we will guide you step by step through how the most basic modules in Tianshou work and how they collaborate with each other to conduct a classic DRL experiment."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1_mLTSEIcY2c"
},
"source": [
"## Run the code\n",
"Before we get started, we must first install Tianshou's library and the Gym environment by running the commands below. Here we pin a specific version of Tianshou (0.4.8), which was the latest at the time of writing this tutorial. APIs in different versions may vary slightly, but most are the same. Feel free to use other versions in your own project."
]
},
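{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal install sketch: the original install commands were not preserved in\n",
"# this copy of the notebook. The package names (tianshou, gymnasium) match the\n",
"# imports used below; pinning an exact version is left to you, since the APIs\n",
"# differ slightly across releases.\n",
"%pip install tianshou gymnasium"
]
},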
{
"cell_type": "markdown",
"metadata": {
"id": "IcFNmCjYeIIU"
},
"source": [
"Below is a short script that uses a certain DRL algorithm (PPO) to solve the classic CartPole-v1\n",
"problem in Gym. Simply run it and **don't worry** if you can't understand the code very well. That is\n",
"exactly what this tutorial is for.\n",
"\n",
"If the script ends normally, you will see the evaluation result printed out before the first\n",
"epoch is done."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"is_executing": true,
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-cell",
"remove-output"
]
},
"outputs": [],
"source": [
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PPOPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import ActorCritic, Net\n",
"from tianshou.utils.net.discrete import Actor, Critic\n",
"\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"is_executing": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"outputs": [],
"source": [
"# environments\n",
"env = gym.make(\"CartPole-v1\")\n",
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(20)])\n",
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(10)])\n",
"\n",
"# model & optimizer\n",
"net = Net(env.observation_space.shape, hidden_sizes=[64, 64], device=device)\n",
"actor = Actor(net, env.action_space.n, device=device).to(device)\n",
"critic = Critic(net, device=device).to(device)\n",
"actor_critic = ActorCritic(actor, critic)\n",
"optim = torch.optim.Adam(actor_critic.parameters(), lr=0.0003)\n",
"\n",
"# PPO policy\n",
"dist = torch.distributions.Categorical\n",
"policy = PPOPolicy(\n",
"    actor=actor,\n",
"    critic=critic,\n",
"    optim=optim,\n",
"    dist_fn=dist,\n",
"    action_space=env.action_space,\n",
"    action_scaling=False,\n",
")\n",
"\n",
"# collector\n",
"train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)))\n",
"test_collector = Collector(policy, test_envs)\n",
"\n",
"# trainer: .run() executes the training loop and returns the result\n",
"result = OnpolicyTrainer(\n",
"    policy=policy,\n",
"    batch_size=256,\n",
"    train_collector=train_collector,\n",
"    test_collector=test_collector,\n",
"    max_epoch=10,\n",
"    step_per_epoch=50000,\n",
"    repeat_per_collect=10,\n",
"    episode_per_test=10,\n",
"    step_per_collect=2000,\n",
"    stop_fn=lambda mean_reward: mean_reward >= 195,\n",
").run()\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "G9YEQptYvCgx",
"is_executing": true,
"outputId": "2a9b5b22-be50-4bb7-ae93-af7e65e7442a"
},
"outputs": [],
"source": [
"# Let's watch its performance!\n",
"policy.eval()\n",
"result = test_collector.collect(n_episode=1, render=False)\n",
"print(\"Final reward: {}, length: {}\".format(result[\"rews\"].mean(), result[\"lens\"].mean()))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xFYlcPo8fpPU"
},
"source": [
"## Tutorial Introduction\n",
"\n",
"A common DRL experiment, as shown above, may require many components to work together. The agent, the\n",
"environment (possibly parallelized ones), the replay buffer and the trainer all work together to complete a\n",
"training task.\n",
"\n",
"<div align=center>\n",
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\">\n",
"\n",
"</div>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kV_uOyimj-bk"
},
"source": [
"In Tianshou, all of these main components are factored out as different building blocks, which you\n",
"can use to create your own algorithms and run your own experiments.\n",
"\n",
"Building blocks may include:\n",
"- Batch\n",
"- Replay Buffer\n",
"- Vectorized Environment Wrapper\n",
"- Policy (the agent and the training algorithm)\n",
"- Data Collector\n",
"- Trainer\n",
"- Logger\n",
"\n",
"To get a first feel for the smallest of these blocks, see the `Batch` sketch in the cell below.\n",
"\n",
"Check this [webpage](https://tianshou.readthedocs.io/en/master/tutorials/dqn.html) to find jupyter-notebook-style tutorials that will guide you through all these\n",
"modules one by one. You can also read the [documentation](https://tianshou.readthedocs.io/en/master/) of Tianshou for more detailed explanations and\n",
"advanced usage."
]
},
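{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal, illustrative sketch of `Batch`, Tianshou's attribute-style data container. The keys used here (`obs`, `act`, `info`) are just example field names chosen for this sketch, not a required schema; the dedicated tutorials cover the real conventions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from tianshou.data import Batch\n",
"\n",
"# A Batch stores (possibly nested) arrays and scalars with attribute access.\n",
"b = Batch(obs=np.zeros((2, 4)), act=np.array([0, 1]), info={})\n",
"print(b.obs.shape, b.act)\n",
"\n",
"# Batches support array-style indexing and concatenation.\n",
"print(b[0])\n",
"print(Batch.cat([b, b]).act)"
]
},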
{
"cell_type": "markdown",
"metadata": {
"id": "S0mNKwH9i6Ek"
},
"source": [
"## Further reading"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M3NPSUnAov4L"
},
"source": [
"### What if I am not familiar with the PPO algorithm itself?\n",
"For the DRL algorithms themselves, we refer you to the [Spinning Up documentation](https://spinningup.openai.com/en/latest/algorithms/ppo.html), which provides\n",
"plenty of resources and guides for studying DRL algorithms. In Tianshou's tutorials, we will\n",
"focus on the usage of the different modules, not on the algorithms themselves."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}