{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"S3-tJZy35Ck_",
"XfsuU2AAE52C",
"p-7U_cwgF5Ej",
"_j3aUJZQ7nml"
]
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wDZlC0v348Ym"
},
"outputs": [],
"source": [
"# Remember to install tianshou first\n",
"!pip install tianshou==0.4.8\n",
"!pip install gym"
]
},
{
"cell_type": "markdown",
"source": [
"# Overview\n",
"Trainer is the highest-level encapsulation in Tianshou. It controls the training loop and the evaluation method. It also controls the interaction between the Collector and the Policy, with the ReplayBuffer serving as the medium.\n"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"import gym\n",
"import torch\n",
"\n",
"from tianshou.data import Collector, VectorReplayBuffer\n",
"from tianshou.env import DummyVectorEnv\n",
"from tianshou.policy import PGPolicy\n",
"from tianshou.utils.net.common import Net\n",
"from tianshou.utils.net.discrete import Actor\n",
"\n",
"train_env_num = 4\n",
"buffer_size = 2000  # Since REINFORCE is an on-policy algorithm, we don't need a very large buffer size\n",
"\n",
"# Create the environments, used for training and evaluation\n",
"env = gym.make(\"CartPole-v0\")\n",
"test_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(2)])\n",
"train_envs = DummyVectorEnv([lambda: gym.make(\"CartPole-v0\") for _ in range(train_env_num)])\n",
"\n",
"# Create the Policy instance\n",
"net = Net(env.observation_space.shape, hidden_sizes=[16,])\n",
"# CartPole has a Discrete action space, so fall back to action_space.n for the output dimension\n",
"actor = Actor(net, env.action_space.shape or env.action_space.n)\n",
"optim = torch.optim.Adam(actor.parameters(), lr=0.001)\n",
"policy = PGPolicy(actor, optim, dist_fn=torch.distributions.Categorical)\n",
"\n",
"# Create the replay buffer and the collector\n",
"replaybuffer = VectorReplayBuffer(buffer_size, train_env_num)\n",
"test_collector = Collector(policy, test_envs)\n",
"train_collector = Collector(policy, train_envs, replaybuffer)"
],
"metadata": {
"id": "do-xZ-8B7nVH"
},
"execution_count": null,
"outputs": []
},
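{
"cell_type": "markdown",
"source": [
"Before writing the training loop ourselves, here is an optional sanity check of how the Collector feeds data into the ReplayBuffer. This is a minimal sketch added for illustration; it only assumes the `train_collector` created above and uses `Collector.collect` and `ReplayBuffer.sample`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional sketch: collect a few transitions and peek at what landed in the buffer\n",
"collect_result = train_collector.collect(n_step=50)\n",
"print(\"steps collected:\", collect_result[\"n/st\"])\n",
"print(\"transitions stored:\", len(train_collector.buffer))\n",
"\n",
"# sample(0) returns all stored data as a Batch together with the sampled indices\n",
"batch, indices = train_collector.buffer.sample(0)\n",
"print(batch.obs.shape, batch.act.shape, batch.rew.shape)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},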
{
"cell_type": "markdown",
"source": [
"Now we can try training our policy network. The logic is simple: we collect some data into the buffer and then use that data to train our policy."
],
"metadata": {
"id": "wiEGiBgQIiFM"
}
},
{
"cell_type": "code",
"source": [
"train_collector.reset()\n",
"train_envs.reset()\n",
"test_collector.reset()\n",
"test_envs.reset()\n",
"replaybuffer.reset()\n",
"for i in range(10):\n",
"    evaluation_result = test_collector.collect(n_episode=10)\n",
"    print(\"Evaluation reward is {}\".format(evaluation_result[\"rew\"]))\n",
"    train_collector.collect(n_step=2000)\n",
"    # 0 means taking all data stored in train_collector.buffer\n",
"    policy.update(0, train_collector.buffer, batch_size=512, repeat=1)\n",
"    # On-policy: discard the just-used data but keep the episode statistics\n",
"    train_collector.reset_buffer(keep_statistics=True)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "JMUNPN5SI_kd",
"outputId": "7d68323c-0322-4b82-dafb-7c7f63e7a26d"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Evaluation reward is 9.6\n",
"Evaluation reward is 9.6\n",
"Evaluation reward is 9.2\n",
"Evaluation reward is 9.1\n",
"Evaluation reward is 9.5\n",
"Evaluation reward is 9.7\n",
"Evaluation reward is 9.6\n",
"Evaluation reward is 9.4\n",
"Evaluation reward is 9.3\n",
"Evaluation reward is 9.1\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"The evaluation reward doesn't seem to improve. That is simply because we haven't trained the policy for long enough. Besides, the network is very small and the REINFORCE algorithm is not very stable. Don't worry, we will solve this problem in the end. Still, this gives us an idea of how to write a training loop."
],
"metadata": {
"id": "QXBHIBckMs_2"
}
},
{
"cell_type": "markdown",
"source": [
"## Training with trainer\n",
"The trainer does almost the same thing. The difference is that it takes care of many details of the loop for you and is more modular."
],
"metadata": {
"id": "p-7U_cwgF5Ej"
}
},
{
"cell_type": "code",
"source": [
"from tianshou.trainer import onpolicy_trainer\n",
"\n",
"train_collector.reset()\n",
"train_envs.reset()\n",
"test_collector.reset()\n",
"test_envs.reset()\n",
"replaybuffer.reset()\n",
"\n",
"result = onpolicy_trainer(\n",
"    policy,\n",
"    train_collector,\n",
"    test_collector,\n",
"    max_epoch=10,\n",
"    step_per_epoch=1,       # transitions collected per epoch; a single collect already covers it\n",
"    repeat_per_collect=1,   # like repeat=1 in policy.update() above\n",
"    episode_per_test=10,    # like test_collector.collect(n_episode=10) above\n",
"    step_per_collect=2000,  # like train_collector.collect(n_step=2000) above\n",
"    batch_size=512,\n",
")\n",
"print(result)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vcvw9J8RNtFE",
"outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #1: 2000it [00:00, 4144.84it/s, env_step=2000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.34]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #1: test_reward: 9.500000 ± 0.500000, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #2: 2000it [00:00, 4208.58it/s, env_step=4000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.41]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #2: test_reward: 9.400000 ± 0.489898, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #3: 2000it [00:00, 4472.80it/s, env_step=6000, len=9, loss=0.000, n/ep=212, n/st=2000, rew=9.39]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #3: test_reward: 9.100000 ± 0.700000, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #4: 2000it [00:00, 4340.62it/s, env_step=8000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.38]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #4: test_reward: 9.400000 ± 0.800000, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #5: 2000it [00:00, 4483.35it/s, env_step=10000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.42]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #5: test_reward: 9.400000 ± 1.019804, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #6: 2000it [00:00, 4068.51it/s, env_step=12000, len=9, loss=0.000, n/ep=212, n/st=2000, rew=9.42]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #6: test_reward: 9.400000 ± 0.663325, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #7: 2000it [00:00, 4091.46it/s, env_step=14000, len=9, loss=0.000, n/ep=214, n/st=2000, rew=9.32]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #7: test_reward: 9.300000 ± 0.640312, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #8: 2000it [00:00, 4042.49it/s, env_step=16000, len=9, loss=0.000, n/ep=215, n/st=2000, rew=9.34]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #8: test_reward: 9.600000 ± 0.800000, best_reward: 9.900000 ± 0.700000 in #0\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Epoch #9: 2000it [00:00, 4400.16it/s, env_step=18000, len=9, loss=0.000, n/ep=213, n/st=2000, rew=9.38]"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch #9: test_reward: 9.000000 ± 0.632456, best_reward: 9.900000 ± 0.700000 in #0\n",
"{'duration': '4.79s', 'train_time/model': '0.22s', 'test_step': 940, 'test_episode': 100, 'test_time': '0.46s', 'test_speed': '2026.40 step/s', 'best_reward': 9.9, 'best_result': '9.90 ± 0.70', 'train_step': 18000, 'train_episode': 1918, 'train_time/collector': '4.11s', 'train_speed': '4156.80 step/s'}\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
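{
"cell_type": "markdown",
"source": [
"One of the details the trainer handles for us is early stopping. The following cell is a small optional sketch (not part of the original tutorial): it reuses the policy and collectors defined above and passes a `stop_fn` to `onpolicy_trainer`, so training stops once the mean test reward reaches the environment's reward threshold (195 for CartPole-v0)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional sketch: stop training early once the test reward is good enough\n",
"threshold = env.spec.reward_threshold  # 195 for CartPole-v0\n",
"\n",
"result = onpolicy_trainer(\n",
"    policy,\n",
"    train_collector,\n",
"    test_collector,\n",
"    max_epoch=10,\n",
"    step_per_epoch=1,\n",
"    repeat_per_collect=1,\n",
"    episode_per_test=10,\n",
"    step_per_collect=2000,\n",
"    batch_size=512,\n",
"    stop_fn=lambda mean_rewards: mean_rewards >= threshold,\n",
")\n",
"print(result[\"best_reward\"])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},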
{
"cell_type": "markdown",
"source": [
"# Further Reading\n",
"## Logger usages\n",
"Tianshou provides experiment loggers that are compatible with both TensorBoard and Weights & Biases. It also has a BaseLogger class which allows you to define your own logger. Check the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#tianshou.utils.BaseLogger) for details. A minimal TensorBoard example is sketched in the cell below.\n",
"\n",
"## Learn more about the APIs of Trainers\n",
"See the [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.trainer.html)."
],
"metadata": {
"id": "_j3aUJZQ7nml"
}
},
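{
"cell_type": "markdown",
"source": [
"As a minimal sketch of the logger API (added for illustration): it uses `TensorboardLogger` from `tianshou.utils` together with a standard `SummaryWriter`, and the log directory name is just an example."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Optional sketch: attach a TensorBoard logger to the trainer\n",
"# (the log directory \"log/reinforce\" is an arbitrary example path)\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"from tianshou.utils import TensorboardLogger\n",
"\n",
"writer = SummaryWriter(\"log/reinforce\")\n",
"logger = TensorboardLogger(writer)\n",
"\n",
"result = onpolicy_trainer(\n",
"    policy,\n",
"    train_collector,\n",
"    test_collector,\n",
"    max_epoch=10,\n",
"    step_per_epoch=1,\n",
"    repeat_per_collect=1,\n",
"    episode_per_test=10,\n",
"    step_per_collect=2000,\n",
"    batch_size=512,\n",
"    logger=logger,  # training/testing statistics are written to TensorBoard\n",
")"
],
"metadata": {},
"execution_count": null,
"outputs": []
}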
]
}