{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "PNM9wqstBSY_"
|
|
},
|
|
"source": [
|
|
"# Policy\n",
|
|
"In reinforcement learning, the agent interacts with environments to improve itself. In this tutorial we will concentrate on the agent part. In Tianshou, both the agent and the core DRL algorithm are implemented in the Policy module. Tianshou provides more than 20 Policy modules, each representing one DRL algorithm. See supported algorithms [here](https://github.com/thu-ml/tianshou).\n",
|
|
"\n",
|
|
"<div align=center>\n",
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/rl-loop.jpg\" title=\"The agents interacting with the environment\">\n",
|
|
"\n",
|
|
"<a> The agents interacting with the environment </a>\n",
|
|
"</div>\n",
|
|
"\n",
|
|
"All Policy modules inherit from a BasePolicy Class and share the same interface."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ZqdHYdoJJS51"
|
|
},
|
|
"source": [
|
|
"## Creating your own Policy\n",
|
|
"We will use the simple REINFORCE algorithm Policy to show the implementation of a Policy Module. The Policy we implement here will be a highly scaled-down version of [PGPolicy](https://github.com/thu-ml/tianshou/blob/master/tianshou/policy/modelfree/pg.py) in Tianshou."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "PWFBgZ4TJkfz"
|
|
},
|
|
"source": [
|
|
"### Initialization\n",
|
|
"Firstly we create the `REINFORCEPolicy` by inheriting from `BasePolicy` in Tianshou."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"editable": true,
|
|
"id": "cDlSjASbJmy-",
|
|
"slideshow": {
|
|
"slide_type": ""
|
|
},
|
|
"tags": [
|
|
"hide-cell",
|
|
"remove-output"
|
|
]
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import cast\n",
|
|
"\n",
|
|
"import numpy as np\n",
|
|
"import torch\n",
|
|
"import gymnasium as gym\n",
|
|
"\n",
|
|
"from dataclasses import dataclass\n",
|
|
"\n",
|
|
"from tianshou.data import Batch, ReplayBuffer, to_torch, to_torch_as, SequenceSummaryStats\n",
|
|
"from tianshou.policy import BasePolicy, TrainingStats\n",
|
|
"from tianshou.utils.net.common import Net\n",
|
|
"from tianshou.utils.net.discrete import Actor\n",
|
|
"from tianshou.data.types import BatchWithReturnsProtocol, RolloutBatchProtocol"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class REINFORCEPolicy(BasePolicy):\n",
|
|
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
|
|
"\n",
"    def __init__(self, action_space: gym.Space):\n",
"        super().__init__(action_space=action_space)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "qc1RnIBbLCDN"
|
|
},
|
|
"source": [
|
|
"As we have mentioned, the Policy Module mainly does two things:\n",
|
|
"\n",
|
|
"\n",
|
|
"1. `policy.forward()` receives observation and other information (stored in a Batch) from the environment and returns a new Batch containing the action.\n",
|
|
"2. `policy.update()` receives training data sampled from the replay buffer and updates itself, and then returns logging details.\n",
|
|
"\n",
|
|
"\n",
|
|
"<div align=center>\n",
|
|
"<img src=\"https://tianshou.readthedocs.io/en/master/_images/pipeline.png\">\n",
|
|
"\n",
|
|
"<a> policy.forward() and policy.update() </a>\n",
|
|
"</div>\n",
|
|
"\n",
|
|
"We also need to take care of the following things:\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
"1. Since Tianshou is a **Deep** RL library, there should be a policy network in our Policy Module, \n",
"as well as a Torch optimizer.\n",
"2. In Tianshou's BasePolicy, `Policy.update()` first calls `Policy.process_fn()` to \n",
"preprocess training data and compute quantities like episodic returns (gradient-free), \n",
"then it calls `Policy.learn()` to perform the back-propagation.\n",
|
|
"3. Each Policy is accompanied by a dedicated implementation of `TrainingStats` to store details of training.\n",
|
|
"\n",
|
|
"Then we get the implementation below.\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "6j32PSKUQ23w"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"@dataclass(kw_only=True)\n",
|
|
"class REINFORCETrainingStats(TrainingStats):\n",
|
|
" \"\"\"A dedicated class for REINFORCE training statistics.\"\"\"\n",
|
|
"\n",
|
|
" loss: SequenceSummaryStats\n",
|
|
"\n",
|
|
"\n",
|
|
"class REINFORCEPolicy(BasePolicy[REINFORCETrainingStats]):\n",
|
|
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
|
|
"\n",
|
|
" def __init__(\n",
|
|
" self, model: torch.nn.Module, optim: torch.optim.Optimizer, action_space: gym.Space\n",
|
|
" ):\n",
|
|
" super().__init__(action_space=action_space)\n",
|
|
" self.actor = model\n",
|
|
" self.optim = optim\n",
|
|
"\n",
|
|
" def forward(self, batch: Batch) -> Batch:\n",
|
|
" \"\"\"Compute action over the given batch data.\"\"\"\n",
|
|
" act = None\n",
|
|
" return Batch(act=act)\n",
|
|
"\n",
|
|
" def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n",
|
|
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
|
|
" pass\n",
|
|
"\n",
|
|
" def learn(self, batch: Batch, batch_size: int, repeat: int) -> REINFORCETrainingStats:\n",
|
|
" \"\"\"Perform the back-propagation.\"\"\"\n",
|
|
" return"
|
|
]
|
|
},
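{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the interplay between `update()`, `process_fn()` and `learn()` concrete, here is a rough, simplified sketch of what `BasePolicy.update()` does internally. The real method in Tianshou does a bit more (bookkeeping and optional post-processing), so treat this only as a mental model:\n",
"\n",
"```\n",
"def update(self, sample_size, buffer, **kwargs):\n",
"    batch, indices = buffer.sample(sample_size)      # draw training data from the buffer\n",
"    batch = self.process_fn(batch, buffer, indices)  # gradient-free preprocessing, e.g. returns\n",
"    training_stats = self.learn(batch, **kwargs)     # back-propagation\n",
"    return training_stats                            # a TrainingStats dataclass\n",
"```"
]
},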
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "tjtqjt8WRY5e"
|
|
},
|
|
"source": [
|
|
"### Policy.forward()\n",
"According to the equation of the REINFORCE algorithm in Spinning Up's [documentation](https://spinningup.openai.com/en/latest/algorithms/vpg.html), we need to map the observation to an action distribution in the action space using a neural network (`self.actor`).\n",
|
|
"\n",
|
|
"<div align=center>\n",
|
|
"<img src=\"https://spinningup.openai.com/en/latest/_images/math/3d29a18c0f98b1cdb656ecdf261ee37ffe8bb74b.svg\">\n",
|
|
"</div>\n",
|
|
"\n",
|
|
"Let us suppose the action space is discrete, and the distribution is a simple categorical distribution.\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "uE4YDE-_RwgN"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def forward(self, batch: Batch) -> Batch:\n",
|
|
" \"\"\"Compute action over the given batch data.\"\"\"\n",
|
|
" self.dist_fn = torch.distributions.Categorical\n",
"    logits, _ = self.actor(batch.obs)\n",
|
|
" dist = self.dist_fn(logits)\n",
|
|
" act = dist.sample()\n",
|
|
" return Batch(act=act, dist=dist)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "CultfOeuTx2V"
|
|
},
|
|
"source": [
|
|
"### Policy.process_fn()\n",
"Now that we have defined our actor, we can set up a loss function and optimize our neural network given training data. However, before that, we must first calculate the episodic returns for every step in our training data in order to construct the REINFORCE loss function.\n",
"\n",
"Calculating the episodic return is not hard, since `ReplayBuffer.next()` allows us to access every reward-to-go in an episode. A more convenient way is to simply use the built-in method `BasePolicy.compute_episodic_return()` inherited from `BasePolicy`.\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "wPAmOD7zV7n2"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def process_fn(self, batch: Batch, buffer: ReplayBuffer, indices: np.ndarray) -> Batch:\n",
|
|
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
|
|
" returns, _ = self.compute_episodic_return(batch, buffer, indices, gamma=0.99, gae_lambda=1.0)\n",
|
|
" batch.returns = returns\n",
|
|
" return batch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "XA8OF4GnWWr5"
|
|
},
|
|
"source": [
|
|
"`BasePolicy.compute_episodic_return()` could also be used to compute [GAE](https://arxiv.org/abs/1506.02438). Another similar method is `BasePolicy.compute_nstep_return()`. Check the [source code](https://github.com/thu-ml/tianshou/blob/6fc68578127387522424460790cbcb32a2bd43c4/tianshou/policy/base.py#L304) for more details."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "7UsdzNaOXPpC"
|
|
},
|
|
"source": [
|
|
"### Policy.learn()\n",
"The data batch returned by `Policy.process_fn()` will flow into `Policy.learn()`. Finally,\n",
"we can construct our loss function and perform the back-propagation. The method\n",
"should look something like this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "aCO-dLXWXtz9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"from tianshou.utils.optim import optim_step\n",
|
|
"\n",
|
|
"\n",
"def learn(self, batch: Batch, batch_size: int, repeat: int) -> REINFORCETrainingStats:\n",
|
|
" \"\"\"Perform the back-propagation.\"\"\"\n",
|
|
" train_losses = []\n",
|
|
" for _ in range(repeat):\n",
|
|
" for minibatch in batch.split(batch_size, merge_last=True):\n",
|
|
" result = self(minibatch)\n",
|
|
" dist = result.dist\n",
|
|
" act = to_torch_as(minibatch.act, result.act)\n",
|
|
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
|
|
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
|
|
" loss = -(log_prob * ret).mean()\n",
|
|
" optim_step(loss, self.optim)\n",
|
|
" train_losses.append(loss.item())\n",
|
|
"\n",
|
|
" return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))"
|
|
]
|
|
},
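{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here `optim_step` is a small convenience helper from `tianshou.utils.optim`. Roughly speaking it wraps the usual optimization step (simplified; the real helper may offer extras such as gradient clipping):\n",
"\n",
"```\n",
"optim.zero_grad()\n",
"loss.backward()\n",
"optim.step()\n",
"```"
]
},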
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "1BtuV2W0YJTi"
|
|
},
|
|
"source": [
|
|
"## Implementation\n",
"Now we can assemble the methods and form a REINFORCE Policy. The outputs of\n",
"`learn` will be collected into a dedicated dataclass.\n",
|
|
"\n",
|
|
"We will also use protocols to specify what fields are expected and produced inside a `Batch` in\n",
|
|
"each processing step. By using protocols, we can get better type checking and IDE support \n",
|
|
"without having to implement a separate class for each combination of fields."
|
|
]
|
|
},
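{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a schematic illustration of the idea, such a protocol is essentially a `typing.Protocol` that declares which fields a `Batch` is expected to carry at a given stage. The actual definitions live in `tianshou.data.types` and may differ in detail; the sketch below only conveys the shape of the approach:\n",
"\n",
"```\n",
"from typing import Protocol\n",
"\n",
"import numpy as np\n",
"import torch\n",
"\n",
"\n",
"class RolloutBatchProtocol(Protocol):\n",
"    \"\"\"Fields a batch sampled from the replay buffer is expected to carry.\"\"\"\n",
"\n",
"    obs: np.ndarray | torch.Tensor\n",
"    act: np.ndarray | torch.Tensor\n",
"    rew: np.ndarray\n",
"    terminated: np.ndarray\n",
"    truncated: np.ndarray\n",
"\n",
"\n",
"class BatchWithReturnsProtocol(RolloutBatchProtocol, Protocol):\n",
"    \"\"\"The same batch after process_fn() has attached the computed returns.\"\"\"\n",
"\n",
"    returns: np.ndarray | torch.Tensor\n",
"```"
]
},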
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Ab0KNQHTOlGo"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"class REINFORCEPolicy(BasePolicy[REINFORCETrainingStats]):\n",
|
|
" \"\"\"Implementation of REINFORCE algorithm.\"\"\"\n",
|
|
"\n",
|
|
" def __init__(\n",
|
|
" self,\n",
|
|
" *,\n",
|
|
" model: torch.nn.Module,\n",
|
|
" optim: torch.optim.Optimizer,\n",
|
|
" action_space: gym.Space,\n",
|
|
" ):\n",
|
|
" super().__init__(action_space=action_space)\n",
|
|
" self.actor = model\n",
|
|
" self.optim = optim\n",
|
|
" # action distribution\n",
|
|
" self.dist_fn = torch.distributions.Categorical\n",
|
|
"\n",
|
|
" def forward(self, batch: Batch) -> Batch:\n",
|
|
" \"\"\"Compute action over the given batch data.\"\"\"\n",
|
|
" logits, _ = self.actor(batch.obs)\n",
|
|
" dist = self.dist_fn(logits)\n",
|
|
" act = dist.sample()\n",
|
|
" return Batch(act=act, dist=dist)\n",
|
|
"\n",
|
|
" def process_fn(\n",
|
|
" self, batch: RolloutBatchProtocol, buffer: ReplayBuffer, indices: np.ndarray\n",
|
|
" ) -> BatchWithReturnsProtocol:\n",
|
|
" \"\"\"Compute the discounted returns for each transition.\"\"\"\n",
|
|
" returns, _ = self.compute_episodic_return(\n",
|
|
" batch, buffer, indices, gamma=0.99, gae_lambda=1.0\n",
|
|
" )\n",
|
|
" batch.returns = returns\n",
|
|
" return cast(BatchWithReturnsProtocol, batch)\n",
|
|
"\n",
|
|
" def learn(\n",
|
|
" self, batch: BatchWithReturnsProtocol, batch_size: int, repeat: int\n",
|
|
" ) -> REINFORCETrainingStats:\n",
|
|
" \"\"\"Perform the back-propagation.\"\"\"\n",
|
|
" train_losses = []\n",
|
|
" for _ in range(repeat):\n",
|
|
" for minibatch in batch.split(batch_size, merge_last=True):\n",
|
|
" result = self(minibatch)\n",
|
|
" dist = result.dist\n",
|
|
" act = to_torch_as(minibatch.act, result.act)\n",
|
|
" ret = to_torch(minibatch.returns, torch.float, result.act.device)\n",
|
|
" log_prob = dist.log_prob(act).reshape(len(ret), -1).transpose(0, 1)\n",
|
|
" loss = -(log_prob * ret).mean()\n",
|
|
" optim_step(loss, self.optim)\n",
|
|
" train_losses.append(loss.item())\n",
|
|
"\n",
|
|
" return REINFORCETrainingStats(loss=SequenceSummaryStats.from_sequence(train_losses))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "xlPAbh0lKti8"
|
|
},
|
|
"source": [
|
|
"## Use the policy\n",
"Note that `BasePolicy` itself inherits from `torch.nn.Module`. As a result, you can treat any Policy module as a Torch Module; they share similar APIs.\n",
|
|
"\n",
|
|
"Firstly we will initialize a new REINFORCE policy."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "JkLFA9Z1KjuX"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"state_shape = 4\n",
|
|
"action_shape = 2\n",
|
|
"# Usually taken from an env by using env.action_space\n",
"action_space = gym.spaces.Discrete(2)\n",
|
|
"net = Net(state_shape, hidden_sizes=[16, 16], device=\"cpu\")\n",
|
|
"actor = Actor(net, action_shape, device=\"cpu\").to(\"cpu\")\n",
|
|
"optim = torch.optim.Adam(actor.parameters(), lr=0.0003)\n",
|
|
"\n",
|
|
"policy = REINFORCEPolicy(model=actor, optim=optim, action_space=action_space)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "LAo_0t2fekUD"
|
|
},
|
|
"source": [
"The REINFORCE policy shares the same APIs as a Torch Module."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "UiuTc8RhJiEi",
|
|
"outputId": "9b5bc54c-6303-45f3-ba81-2216a44931e8"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(policy)\n",
|
|
"print(\"========================================\")\n",
|
|
"for param in policy.parameters():\n",
|
|
" print(param.shape)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "-RCrsttYgAG-"
|
|
},
|
|
"source": [
|
|
"### Making decision\n",
|
|
"Given a batch of observations, the policy can return a batch of actions and other data."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "0jkBb6AAgUla",
|
|
"outputId": "37948844-cdd8-4567-9481-89453c80a157"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"obs_batch = Batch(obs=np.ones(shape=(256, 4)))\n",
|
|
"action = policy(obs_batch) # forward() method is called\n",
|
|
"print(action)"
|
|
]
|
|
},
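{
"cell_type": "markdown",
"metadata": {},
"source": [
"The returned `Batch` contains whatever `forward()` put into it, in our case the sampled actions and the distribution object, so you can inspect both directly:\n",
"\n",
"```\n",
"print(action.act.shape)  # one sampled action per observation in the batch\n",
"print(action.dist)       # the Categorical distribution the actions were sampled from\n",
"```"
]
},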
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "swikhnuDfKep"
|
|
},
|
|
"source": [
|
|
"### Save and Load models\n",
"Naturally, a Tianshou Policy can be saved and loaded like a normal Torch network."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "tYOoWM_OJRnA"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"torch.save(policy.state_dict(), \"policy.pth\")\n",
|
|
"assert policy.load_state_dict(torch.load(\"policy.pth\"))"
|
|
]
|
|
},
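{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to resume training later, a common pattern (not Tianshou-specific) is to also save the optimizer state alongside the policy, for example:\n",
"\n",
"```\n",
"torch.save({\"policy\": policy.state_dict(), \"optim\": optim.state_dict()}, \"checkpoint.pth\")\n",
"\n",
"checkpoint = torch.load(\"checkpoint.pth\")\n",
"policy.load_state_dict(checkpoint[\"policy\"])\n",
"optim.load_state_dict(checkpoint[\"optim\"])\n",
"```"
]
},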
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "gp8PzOYsg5z-"
|
|
},
|
|
"source": [
|
|
"### Algorithm Updating\n",
"We have to collect some data and save it in the ReplayBuffer before updating our agent (policy). Typically we would use a Collector to collect data, but we leave that part for later, after we have learned about the Collector in Tianshou. For now we generate some **fake** data."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "XrrPxOUAYShR"
|
|
},
|
|
"source": [
|
|
"#### Generating fake data\n",
"Firstly, we need to \"pretend\" that we are using the \"Policy\" to collect data. We plan to collect 10 transitions so that we can update our algorithm."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "a14CmzSfYh5C",
|
|
"outputId": "aaf45a1f-5e21-4bc8-cbe3-8ce798258af0"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# a buffer is initialised with its maxsize set to 12.\n",
|
|
"print(\"========================================\")\n",
|
|
"buf = ReplayBuffer(size=12)\n",
|
|
"print(buf)\n",
|
|
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))\n",
|
|
"env = gym.make(\"CartPole-v1\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "8S94cV7yZITR"
|
|
},
|
|
"source": [
|
|
"Now we are pretending to collect the first episode. The first episode ends at step 3 (perhaps because we are performing too badly)."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "a_mtvbmBZbfs"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"obs, info = env.reset()\n",
|
|
"for i in range(3):\n",
|
|
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
|
|
" obs_next, rew, _, truncated, info = env.step(act)\n",
|
|
" # pretend ending at step 3\n",
|
|
" terminated = True if i == 2 else False\n",
|
|
" info[\"id\"] = i\n",
|
|
" buf.add(\n",
|
|
" Batch(\n",
|
|
" obs=obs,\n",
|
|
" act=act,\n",
|
|
" rew=rew,\n",
|
|
" terminated=terminated,\n",
|
|
" truncated=truncated,\n",
|
|
" obs_next=obs_next,\n",
|
|
" info=info,\n",
|
|
" )\n",
|
|
" )\n",
|
|
" obs = obs_next"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(buf)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "pkxq4gu9bGkt"
|
|
},
|
|
"source": [
"Now we are pretending to collect the second episode. After 7 more steps the second episode still hasn't ended, but we are unwilling to wait any longer, so we stop collecting and update the algorithm."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "pAoKe02ybG68"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"obs, info = env.reset()\n",
|
|
"for i in range(3, 10):\n",
|
|
" act = policy(Batch(obs=obs[np.newaxis, :])).act.item()\n",
|
|
" obs_next, rew, _, truncated, info = env.step(act)\n",
"    # pretend this episode never ends\n",
|
|
" terminated = False\n",
|
|
" info[\"id\"] = i\n",
|
|
" buf.add(\n",
|
|
" Batch(\n",
|
|
" obs=obs,\n",
|
|
" act=act,\n",
|
|
" rew=rew,\n",
|
|
" terminated=terminated,\n",
|
|
" truncated=truncated,\n",
|
|
" obs_next=obs_next,\n",
|
|
" info=info,\n",
|
|
" )\n",
|
|
" )\n",
|
|
" obs = obs_next"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "MKM6aWMucv-M"
|
|
},
|
|
"source": [
|
|
"Our replay buffer looks like this now."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "CSJEEWOqXdTU",
|
|
"outputId": "2b3bb75c-f219-4e82-ca78-0ea6173a91f9"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(buf)\n",
|
|
"print(\"maxsize: {}, data length: {}\".format(buf.maxsize, len(buf)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "55VWhWpkdfEb"
|
|
},
|
|
"source": [
|
|
"#### Updates\n",
"Now we have a replay buffer with 10 transitions in it. We can call `Policy.update()` to train."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "i_O1lJDWdeoc",
|
|
"outputId": "b154741a-d6dc-46cb-898f-6e84fa14e5a7"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# 0 means sample all data from the buffer\n",
|
|
"# batch_size=10 defines the training batch size\n",
"# repeat=6 means repeat the training 6 times\n",
|
|
"policy.update(0, buf, batch_size=10, repeat=6)"
|
|
]
|
|
},
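{
"cell_type": "markdown",
"metadata": {},
"source": [
"`update()` returns the `REINFORCETrainingStats` instance produced by `learn()`, so you can hold on to the return value for your own logging. For example (the `loss` field is the one we defined above; any further fields come from the `TrainingStats` base class):\n",
"\n",
"```\n",
"result = policy.update(0, buf, batch_size=10, repeat=6)\n",
"print(result)        # the full REINFORCETrainingStats dataclass\n",
"print(result.loss)   # SequenceSummaryStats over the minibatch losses\n",
"```"
]
},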
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "enqlFQLSJrQl"
|
|
},
|
|
"source": [
|
|
"Not that difficult, right?"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "QJ5krjrcbuiA"
|
|
},
|
|
"source": [
|
|
"## Further Reading\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "pmWi3HuXWcV8"
|
|
},
|
|
"source": [
|
|
"### Pre-defined Networks\n",
"Tianshou provides numerous pre-defined networks commonly used in DRL so that you don't have to build them yourself. Check this [documentation](https://tianshou.readthedocs.io/en/master/api/tianshou.utils.html#pre-defined-networks) for details."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "UPVl5LBEWJ0t"
|
|
},
|
|
"source": [
|
|
"### How to compute GAE on your own?\n",
"(Note that for this reading you first need to understand how the [GAE](https://arxiv.org/abs/1506.02438) advantage is calculated.)\n",
"\n",
"In terms of code implementation, perhaps the most difficult and annoying part is computing the GAE advantage. Just now, we used the `self.compute_episodic_return()` method inherited from `BasePolicy` to save us from all that trouble. However, it is still important to know the details behind it.\n",
"\n",
"To compute the GAE advantage, the usage of `self.compute_episodic_return()` goes like this:"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "D34GlVvPNz08",
|
|
"outputId": "43a4e5df-59b5-4e4a-c61c-e69090810215"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"batch, indices = buf.sample(0) # 0 means sampling all the data from the buffer\n",
|
|
"returns, advantage = BasePolicy.compute_episodic_return(\n",
|
|
" batch, buf, indices, v_s_=np.zeros(10), v_s=np.zeros(10), gamma=1.0, gae_lambda=1.0\n",
|
|
")\n",
|
|
"print(returns)\n",
|
|
"print(advantage)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "h_5Dt6XwQLXV"
|
|
},
|
|
"source": [
"In the code above, we sample all 10 transitions from the buffer and compute the GAE advantage. As we know, we need to estimate the value function of every observation to compute the GAE advantage, so the passed-in `v_s` is the value of `batch.obs` and `v_s_` is the value of `batch.obs_next`. These are usually computed by:\n",
|
|
"\n",
|
|
"`v_s = critic(batch.obs)`,\n",
|
|
"\n",
|
|
"`v_s_ = critic(batch.obs_next)`,\n",
|
|
"\n",
|
|
"where both `v_s` and `v_s_` are 10 dimensional arrays and `critic` is usually a neural network.\n",
|
|
"\n",
|
|
"After we've got all those values, GAE can be computed following the equation below."
|
|
]
|
|
},
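{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before moving on to the equation, here is a quick sketch of how such value estimates could be produced for our 10 CartPole transitions with a hypothetical (untrained) critic; any network mapping the 4-dimensional observation to a scalar value would do:\n",
"\n",
"```\n",
"# a toy critic: 4-dimensional observation -> scalar state value\n",
"critic = torch.nn.Sequential(\n",
"    torch.nn.Linear(4, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1)\n",
")\n",
"v_s = critic(to_torch(batch.obs, torch.float, \"cpu\")).flatten().detach().numpy()\n",
"v_s_ = critic(to_torch(batch.obs_next, torch.float, \"cpu\")).flatten().detach().numpy()\n",
"```"
]
},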
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "ooHNIICGUO19"
|
|
},
|
|
"source": [
|
|
"\\begin{aligned}\n",
"\\hat{A}_{t}^{\\mathrm{GAE}(\\gamma, \\lambda)} :=& \\sum_{l=0}^{\\infty}(\\gamma \\lambda)^{l} \\delta_{t+l}^{V}\n",
|
|
"\\end{aligned}\n",
|
|
"\n",
|
|
"while\n",
|
|
"\n",
|
|
"\\begin{equation}\n",
"\\delta_{t}^{V} = -V\\left(s_{t}\\right) + r_{t} + \\gamma V\\left(s_{t+1}\\right)\n",
|
|
"\\end{equation}\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "eV6XZaouU7EV"
|
|
},
|
|
"source": [
"However, if you follow this equation from the paper directly, you will probably get slightly lower performance than you expected. There are at least 3 \"bugs\" in this equation."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "FCxD9gNNVYbd"
|
|
},
|
|
"source": [
"**First**, Gym always returns an `obs_next`, even if this is already the last step. The value of that terminal timestep is exactly 0, and you should not let the neural network estimate it."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "rNZNUNgQVvRJ",
|
|
"outputId": "44354595-c25a-4da8-b4d8-cffa31ac4b7d"
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Assume v_s_ was obtained by calling critic(batch.obs_next)\n",
|
|
"v_s_ = np.ones(10)\n",
|
|
"v_s_ *= ~batch.done\n",
|
|
"print(v_s_)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "2EtMi18QWXTN"
|
|
},
|
|
"source": [
"After the fix above, we will perhaps get a more accurate estimate.\n",
"\n",
"**Second**, you must know when to stop bootstrapping. Usually we stop bootstrapping when we encounter a `done` flag. However, in the buffer above, the last (10th) step is not marked done=True, because collecting has not finished. We must identify all those unfinished steps so that we know where to stop bootstrapping.\n",
"\n",
"Luckily, this can be done with the help of the buffer, because buffers in Tianshou not only store data but also help you manage data trajectories."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "saluvX4JU6bC",
|
|
"outputId": "2994d178-2f33-40a0-a6e4-067916b0b5c5"
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"unfinished_indexes = buf.unfinished_index()\n",
|
|
"print(unfinished_indexes)\n",
|
|
"done_indexes = np.where(batch.done)[0]\n",
|
|
"print(done_indexes)\n",
|
|
"stop_bootstrap_ids = np.concatenate([unfinished_indexes, done_indexes])\n",
|
|
"print(stop_bootstrap_ids)"
|
|
]
|
|
},
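{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, you could mark these positions in a boolean \"end flag\" array and make sure the return/advantage recursion never crosses a marked position (a sketch of the bookkeeping only, not a full GAE implementation):\n",
"\n",
"```\n",
"end_flag = np.zeros(10, dtype=bool)\n",
"end_flag[stop_bootstrap_ids] = True\n",
"print(end_flag)  # True where the accumulation of discounted rewards must stop\n",
"```"
]
},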
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "qp6vVE4dYWv1"
|
|
},
|
|
"source": [
"**Third**, there are some special indices that are marked by the done flag, yet whose value for obs_next should not be zero. These are usually the last steps of an episode that stops not because the agent can no longer get any rewards (value=0), but because the episode is too long and we have to truncate it. Such steps are always marked with `info['TimeLimit.truncated']=True` in Gym."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "tWkqXRJfZTvV"
|
|
},
|
|
"source": [
|
|
"As a result, we need to rewrite the equation above\n",
|
|
"\n",
|
|
"`v_s_ *= ~batch.done`"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "kms-QtxKZe-M"
|
|
},
|
|
"source": [
|
|
"to\n",
|
|
"\n",
|
|
"```\n",
|
|
"mask = batch.info['TimeLimit.truncated'] | (~batch.done)\n",
|
|
"v_s_ *= mask\n",
|
|
"\n",
|
|
"```\n",
|
|
"\n",
|
|
"\n",
|
|
"\n"
|
|
]
|
|
},
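{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small sketch using the fake buffer from above: Gymnasium returns the truncation flag separately rather than inside `info` (the `info['TimeLimit.truncated']` key mentioned above is how older Gym versions exposed the same information), and our fake data stored it in the buffer, so here the mask can be built from `batch.truncated` directly. In this toy buffer no step was actually truncated, so the mask coincides with `~batch.done`:\n",
"\n",
"```\n",
"mask = batch.truncated | (~batch.done)\n",
"v_s_ = np.ones(10) * mask\n",
"print(v_s_)\n",
"```"
]
},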
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "u_aPPoKraBu6"
|
|
},
|
|
"source": [
|
|
"### Summary\n",
"If you are already feeling bored by now, simply remember that Tianshou handles all these little details for you, so that you can focus on the algorithm itself. Just call `BasePolicy.compute_episodic_return()`.\n",
"\n",
"If you are still interested, we recommend checking Appendix C of this [paper](https://arxiv.org/abs/2107.14171v2) and the implementations of `BasePolicy.value_mask()` and `BasePolicy.compute_episodic_return()` for details."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {
|
|
"id": "2cPnUXRBWKD9"
|
|
},
|
|
"source": [
"<center>\n",
"<img src=\"../_static/images/timelimit.svg\">\n",
"</center>\n",
"<center>\n",
"<img src=\"../_static/images/policy_table.svg\">\n",
"</center>"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.5"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|