Docs/fix trainer fct notebooks (#1009)

This PR resolves #1008
Carlo Cagnetta 2023-12-14 19:31:53 +01:00 committed by GitHub
parent ea48cc2989
commit b7df31f2a7
5 changed files with 35 additions and 20 deletions


@@ -41,11 +41,11 @@ First of all, you have to make an environment for your agent to interact with. Y
 import gymnasium as gym
 import tianshou as ts
-env = gym.make('CartPole-v0')
+env = gym.make('CartPole-v1')
-CartPole-v0 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, to which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, can only be applied to continuous action spaces, while almost all other policy gradient methods can be applied to both.
+CartPole-v1 includes a cart carrying a pole moving on a track. This is a simple environment with a discrete action space, to which DQN applies. You have to identify whether the action space is continuous or discrete and apply eligible algorithms. DDPG :cite:`DDPG`, for example, can only be applied to continuous action spaces, while almost all other policy gradient methods can be applied to both.
-Here is the detail of useful fields of CartPole-v0:
+Here is the detail of useful fields of CartPole-v1:
 - ``state``: the position of the cart, the velocity of the cart, the angle of the pole and the angular velocity of the pole;
 - ``action``: can only be one of ``[0, 1]``, for pushing the cart to the left or to the right;
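
As a quick sanity check on the description above (a minimal sketch assuming only ``gymnasium`` is installed), you can inspect the spaces of CartPole-v1 directly::

    import gymnasium as gym

    env = gym.make('CartPole-v1')
    print(env.observation_space)  # Box of shape (4,): cart position/velocity, pole angle/angular velocity
    print(env.action_space)       # Discrete(2): 0 pushes the cart to the left, 1 to the right
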
@@ -62,8 +62,8 @@ Setup Vectorized Environment
 If you want to use the original ``gym.Env``:
 ::
-    train_envs = gym.make('CartPole-v0')
-    test_envs = gym.make('CartPole-v0')
+    train_envs = gym.make('CartPole-v1')
+    test_envs = gym.make('CartPole-v1')
 Tianshou supports vectorized environments for all algorithms. It provides four types of vectorized environment wrappers:
@@ -74,8 +74,8 @@ Tianshou supports vectorized environment for all algorithms. It provides four ty
 ::
-    train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
-    test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
+    train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)])
+    test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(100)])
 Here, we set up 10 environments in ``train_envs`` and 100 environments in ``test_envs``.
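
For intuition, every call on a vectorized environment is batched across its workers; a minimal sketch, assuming tianshou's gymnasium-style vector API::

    import gymnasium as gym
    import numpy as np
    import tianshou as ts

    train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v1') for _ in range(10)])
    obs, info = train_envs.reset()             # observations are stacked: obs.shape == (10, 4)
    acts = np.random.randint(0, 2, size=10)    # one action per sub-environment
    obs, rew, terminated, truncated, info = train_envs.step(acts)
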
@@ -84,8 +84,8 @@ You can also try the super-fast vectorized environment `EnvPool <https://github.
 ::
     import envpool
-    train_envs = envpool.make_gymnasium("CartPole-v0", num_envs=10)
-    test_envs = envpool.make_gymnasium("CartPole-v0", num_envs=100)
+    train_envs = envpool.make_gymnasium("CartPole-v1", num_envs=10)
+    test_envs = envpool.make_gymnasium("CartPole-v1", num_envs=100)
 For this demonstration, we use the second code block.
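
If you take the EnvPool route instead, the resulting object exposes the same batched interface; a minimal sketch, assuming envpool's gymnasium bindings::

    import envpool
    import numpy as np

    envs = envpool.make_gymnasium("CartPole-v1", num_envs=10)
    obs, info = envs.reset()                                # batched reset: obs.shape == (10, 4)
    obs, rew, term, trunc, info = envs.step(np.zeros(10, dtype=np.int32))
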


@@ -353,7 +353,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collecto
 ::
     policy = PGPolicy(...)  # or other policies if you wish
-    env = gym.make("CartPole-v0")
+    env = gym.make("CartPole-v1")
     replay_buffer = ReplayBuffer(size=10000)
@@ -363,7 +363,7 @@ The general explanation is listed in :ref:`pseudocode`. Other usages of collecto
     # the collector supports vectorized environments as well
     vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3)
     # buffer_num should be equal to (suggested) or larger than #envs
-    envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)])
+    envs = DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(3)])
     collector = Collector(policy, envs, buffer=vec_buffer)
     # collect 3 episodes
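
To complete the picture, the collection itself is then a single call; a minimal sketch, assuming tianshou's ``Collector.collect`` API::

    result = collector.collect(n_episode=3)  # roll the policy out for 3 episodes into vec_buffer
    print(result)                            # collection statistics (episode count, returns, ...)
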


@@ -159,7 +159,7 @@ toy_text and classic_control environments. For more information, please refer to
     # install envpool: pip3 install envpool
     import envpool
-    envs = envpool.make_gymnasium("CartPole-v0", num_envs=10)
+    envs = envpool.make_gymnasium("CartPole-v1", num_envs=10)
     collector = Collector(policy, envs, buffer)
 Here are some other `examples <https://github.com/sail-sg/envpool/tree/master/examples/tianshou_examples>`_.


@@ -180,7 +180,10 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "vcvw9J8RNtFE",
-"outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5"
+"outputId": "b483fa8b-2a57-4051-a3d0-6d8162d948c5",
+"tags": [
+"remove-output"
+]
 },
 "outputs": [],
 "source": [
@@ -200,7 +203,17 @@
 " episode_per_test=10,\n",
 " step_per_collect=2000,\n",
 " batch_size=512,\n",
-")\n",
+").run()"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"tags": []
+},
+"outputs": [],
+"source": [
+"print(result)"
+]
+},
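
The notebook change above switches to the chained trainer pattern: the trainer is constructed and immediately started via ``.run()``, whose return value is printed in the new cell. A minimal sketch of the resulting usage, illustrated here with ``OffpolicyTrainer``; the trainer class, the collectors, and the epoch settings are assumptions standing in for the surrounding notebook context::

    from tianshou.trainer import OffpolicyTrainer

    result = OffpolicyTrainer(
        policy=policy,                    # policy and collectors come from earlier cells (assumed)
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=10,                     # assumed values; the diff only shows the three below
        step_per_epoch=10000,
        episode_per_test=10,
        step_per_collect=2000,
        batch_size=512,
    ).run()                               # .run() executes the training loop and returns its stats
    print(result)
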


@@ -59,9 +59,6 @@
 "metadata": {
 "editable": true,
 "id": "ao9gWJDiHgG-",
-"slideshow": {
-"slide_type": ""
-},
 "tags": [
 "hide-cell",
 "remove-output"
@@ -233,8 +230,12 @@
 "colab": {
 "base_uri": "https://localhost:8080/"
 },
+"editable": true,
 "id": "i45EDnpxQ8gu",
-"outputId": "b1666b88-0bfa-4340-868e-58611872d988"
+"outputId": "b1666b88-0bfa-4340-868e-58611872d988",
+"tags": [
+"remove-output"
+]
 },
 "outputs": [],
 "source": [
@@ -249,7 +250,7 @@
 " batch_size=256,\n",
 " step_per_collect=2000,\n",
 " stop_fn=lambda mean_reward: mean_reward >= 195,\n",
-")"
+").run()"
 ]
 },
 {
@@ -270,7 +271,8 @@
 "base_uri": "https://localhost:8080/"
 },
 "id": "tJCPgmiyiaaX",
-"outputId": "40123ae3-3365-4782-9563-46c43812f10f"
+"outputId": "40123ae3-3365-4782-9563-46c43812f10f",
+"tags": []
 },
 "outputs": [],
 "source": [