Tianshou/docs/02_notebooks/L6_Trainer.ipynb
Daniel Plop 8a0629ded6
Fix mypy issues in tests and examples (#1077)
Closes #952 

- `SamplingConfig` supports `batch_size=None`. #1077
- tests and examples are covered by `mypy`. #1077
- `NetBase` is more used, stricter typing by making it generic. #1077
- `utils.net.common.Recurrent` now receives and returns a
`RecurrentStateBatch` instead of a dict. #1077

---------

Co-authored-by: Michael Panchenko <m.panchenko@appliedai.de>
2024-04-03 18:07:51 +02:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "S3-tJZy35Ck_"
},
"source": [
"# Trainer\n",
"Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the medium.\n",
"\n",
"<center>\n",
"<img src=../_static/images/structure.svg></img>\n",
"</center>\n",
"\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ifsEQMzZ6mmz"
},
"source": [
"## Usages\n",
"In Tianshou v0.5.1, there are three types of Trainer, designed for on-policy, off-policy, and offline training respectively. We will use the on-policy trainer as an example and leave the other two for further reading."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XfsuU2AAE52C"
},
"source": [
"### Pseudocode\n",
"<center>\n",
"<img src=../_static/images/pseudocode_off_policy.svg></img>\n",
"</center>\n",
"\n",
"The pseudocode above describes the off-policy trainer. For the on-policy trainer, the main difference is that we clear the buffer after Line 10."
]
},
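{
"cell_type": "markdown",
"metadata": {},
"source": [
"The effect of clearing the buffer can be illustrated with a toy loop. This is a minimal sketch in plain Python, not Tianshou code: the function `run`, its parameters, and the use of a list as a stand-in for the ReplayBuffer are all assumptions made for illustration.\n",
"\n",
"```python\n",
"import random\n",
"\n",
"\n",
"def run(on_policy, n_rounds=3, steps_per_collect=4, seed=0):\n",
"    '''Toy collect/update loop; returns the buffer size seen by each update.'''\n",
"    rng = random.Random(seed)\n",
"    buffer = []  # hypothetical stand-in for a ReplayBuffer\n",
"    sizes = []\n",
"    for _ in range(n_rounds):\n",
"        # Collect: append fake transitions produced by the current policy.\n",
"        buffer.extend(rng.random() for _ in range(steps_per_collect))\n",
"        # Update: a real policy would compute a loss from this data.\n",
"        sizes.append(len(buffer))\n",
"        if on_policy:\n",
"            # The data came from the old policy, so it is discarded.\n",
"            buffer.clear()\n",
"    return sizes\n",
"\n",
"\n",
"print(run(on_policy=True))   # [4, 4, 4]\n",
"print(run(on_policy=False))  # [4, 8, 12]\n",
"```\n",
"\n",
"In the on-policy case every update only sees data collected since the previous update, while in the off-policy case the data accumulates across rounds."
]
},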
{
"cell_type": "markdown",
"metadata": {
"id": "Hcp_o0CCFz12"
},
"source": [
"### Training without trainer\n",
"Having learned how to use the Collector and the Policy, we can now write our own training logic.\n",
"\n",
"First, let us create the instances of Environment, ReplayBuffer, Policy and Collector."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"editable": true,
"id": "do-xZ-8B7nVH",
"slideshow": {
"slide_type": ""
},
"tags": [
"hide-cell",
"remove-output"
]
},
"outputs": [],
"source": [
"%%capture\n",
"\n",
"import gymnasium as gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.trainer import OnpolicyTrainer\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_env_num = 4\n",
"# Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
"buffer_size = 2000\n",
"\n",
"# Create the environments, used for training and evaluation\n",
"env = gym.make(\"CartPole-v1\")\n",
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(2)])\n",
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v1\") for _ in range(train_env_num)])\n",
"\n",
"# Create the Policy instance\n",
"assert env.observation_space.shape is not None\n",
"net = Net(env.observation_space.shape, hidden_sizes=[16])\n",
"\n",
"assert isinstance(env.action_space, gym.spaces.Discrete)\n",
"actor = Actor(net, env.action_space.n)\n",
"optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
"\n",
"# We choose to use REINFORCE algorithm, also known as Policy Gradient\n",
"policy: PGPolicy = PGPolicy(\n",
"    actor=actor,\n",
"    optim=optim,\n",
"    dist_fn=torch.distributions.Categorical,\n",
"    action_space=env.action_space,\n",
"    action_scaling=False,\n",
")\n",
"\n",
"# Create the replay buffer and the collector\n",
"replayBuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replayBuffer)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wiEGiBgQIiFM"
},
"source": [
"Now, we can try training our policy network. The logic is simple. We collect some data into the buffer and then we use the data to train our policy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JMUNPN5SI_kd",
"outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d"
},
"outputs": [],
"source": [
"train_collector.reset()\n",
"train_envs.reset()\n",
"test_collector.reset()\n",
"test_envs.reset()\n",
"replayBuffer.reset()\n",
"\n",
"n_epoch = 10\n",
"episode_per_test = 10\n",
"for _ in range(n_epoch):\n",
"    # Evaluate the current policy before each round of training\n",
"    evaluation_result = test_collector.collect(n_episode=episode_per_test)\n",
"    print(f\"Evaluation mean episodic reward is: {evaluation_result.returns.mean()}\")\n",
"    # Collect 2000 environment steps with the current policy\n",
"    train_collector.collect(n_step=2000)\n",
"    # sample_size=None means taking all data stored in train_collector.buffer\n",
"    policy.update(sample_size=None, buffer=train_collector.buffer, batch_size=512, repeat=1)\n",
"    # On-policy: discard the used data before collecting with the updated policy\n",
"    train_collector.reset_buffer(keep_statistics=True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QXBHIBckMs_2"
},
"source": [
"The evaluation reward doesn't seem to improve. That is simply because we haven't trained for long enough. In addition, the network is very small and the REINFORCE algorithm is not very stable. Don't worry, we will address this problem at the end. Still, we now have an idea of how to write a training loop."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "p-7U_cwgF5Ej"
},
"source": [
"### Training with trainer\n",
"The trainer does almost the same thing. The only difference is that it takes care of many details for us and is more modular."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vcvw9J8RNtFE",
"outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5",
"tags": [
"remove-output"
]
},
"outputs": [],
"source": [
"train_collector.reset()\n",
"train_envs.reset()\n",
"test_collector.reset()\n",
"test_envs.reset()\n",
"replayBuffer.reset()\n",
"\n",
"result = OnpolicyTrainer(\n",
"    policy=policy,\n",
"    train_collector=train_collector,\n",
"    test_collector=test_collector,\n",
"    max_epoch=10,\n",
"    step_per_epoch=1,\n",
"    repeat_per_collect=1,\n",
"    episode_per_test=10,\n",
"    step_per_collect=2000,\n",
"    batch_size=512,\n",
").run()"
]
},
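{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `step_per_epoch=1` and `step_per_collect=2000`, each epoch should end after a single collect, which is why this configuration mirrors the hand-written loop from the previous section. The sketch below is a simplified model of that bookkeeping, not Tianshou's actual implementation. The helper names `collects_per_epoch` and `total_env_steps` are hypothetical, and the sketch assumes that an epoch ends once at least `step_per_epoch` environment steps have been collected and that every collect gathers exactly `step_per_collect` steps.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"\n",
"def collects_per_epoch(step_per_epoch, step_per_collect):\n",
"    # Assumption: keep collecting until at least step_per_epoch\n",
"    # environment steps have been gathered in the current epoch.\n",
"    return max(1, math.ceil(step_per_epoch / step_per_collect))\n",
"\n",
"\n",
"def total_env_steps(max_epoch, step_per_epoch, step_per_collect):\n",
"    # Assumption: every collect gathers exactly step_per_collect steps.\n",
"    n_collects = collects_per_epoch(step_per_epoch, step_per_collect)\n",
"    return max_epoch * n_collects * step_per_collect\n",
"\n",
"\n",
"print(collects_per_epoch(step_per_epoch=1, step_per_collect=2000))  # 1\n",
"print(total_env_steps(max_epoch=10, step_per_epoch=1, step_per_collect=2000))  # 20000\n",
"```\n",
"\n",
"Under these assumptions the trainer gathers the same amount of training data as the manual loop: 10 rounds of 2000 steps each."
]
},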
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"result.pprint_asdict()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_j3aUJZQ7nml"
},
"source": [
"## Further Reading\n",
"### Logger usages\n",
"Tianshou provides experiment loggers that are compatible with both TensorBoard and wandb. It also has a BaseLogger class that allows you to define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details.\n",
"\n",
"### Learn more about the APIs of Trainers\n",
"See the trainer API [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"S3-tJZy35Ck_",
"XfsuU2AAE52C",
"p-7U_cwgF5Ej",
"_j3aUJZQ7nml"
],
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}