TRPO benchmark release (#340)

This commit is contained in:
ChenDRAG 2021-04-19 17:05:06 +08:00 committed by GitHub
parent f68cb78ed7
commit a57503c0aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 213 additions and 16 deletions

View File

@ -26,7 +26,7 @@
- [Quantile Regression DQN (QRDQN)](https://arxiv.org/pdf/1710.10044.pdf)
- [Policy Gradient (PG)](https://papers.nips.cc/paper/1713-policy-gradient-methods-for-reinforcement-learning-with-function-approximation.pdf)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/)
- [Trust Region Policy Optimization](https://arxiv.org/pdf/1502.05477.pdf)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf)
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf)
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf)
@ -41,7 +41,7 @@
Here is Tianshou's other features:
- Elegant framework, using only ~3000 lines of code
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/PPO/DDPG/TD3/SAC algorithms
- State-of-the-art [MuJoCo benchmark](https://github.com/thu-ml/tianshou/tree/master/examples/mujoco) for REINFORCE/A2C/TRPO/PPO/DDPG/TD3/SAC algorithms
- Support parallel environment simulation (synchronous or asynchronous) for all algorithms [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#parallel-sampling)
- Support recurrent state representation in actor network and critic network (RNN-style training for POMDP) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#rnn-style-training)
- Support any type of environment state/action (e.g. a dict, a self-defined class, ...) [Usage](https://tianshou.readthedocs.io/en/latest/tutorials/cheatsheet.html#user-defined-environment-and-different-state-representation)

View File

@ -17,7 +17,8 @@ Supported algorithms are listed below:
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5d580c36624df0548818edf1f9b111b318dd7fd8)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)
#### Usage
@ -65,7 +66,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
### DDPG
| Environment | Tianshou (1M) | [SpinningUp (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) |
| Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) |
| :--------------------: | :---------------: | :----------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: |
| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 |
| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 |
@ -81,7 +82,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
### TD3
| Environment | Tianshou (1M) | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) |
| Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) |
| :--------------------: | :---------------: | :----------------------------------------------------------: | :-------------------------------------------: |
| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 |
| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 |
@ -100,7 +101,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
### SAC
| Environment | Tianshou (1M) | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) |
| Environment | Tianshou (1M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) |
| :--------------------: | :----------------: | :----------------------------------------------------------: | :-------------------------------------------: |
| Ant | **5850.2±475.7** | ~3980 | ~3720 |
| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 |
@ -141,7 +142,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
| InvertedDoublePendulum | **7726.2±1287.3** |
| Environment | Tianshou (3M) | [SpinningUp (VPG Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html)<sup>[[7]](#footnote7)</sup> |
| Environment | Tianshou (3M) | [Spinning Up (VPG PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html)<sup>[[7]](#footnote7)</sup> |
| :--------------------: | :---------------: | :----------------------------------------------------------: |
| Ant | **474.9+-133.5** | ~5 |
| HalfCheetah | **884.0+-41.0** | ~600 |
@ -167,7 +168,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
### A2C
| Environment | Tianshou (3M) | [Spinning Up(Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html) |
| Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html) |
| :--------------------: | :----------------: | :----------------------------------------------------------: |
| Ant | **5236.8+-236.7** | ~5 |
| HalfCheetah | **2377.3+-1363.7** | ~600 |
@ -196,7 +197,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
#### Hints for A2C
1. We choose `clip` action method in A2C instead of `tanh` option as used in REINFORCE simply to be consistent with original implementation. `tanh` may be better or equally well but we didn't have a try.
2. (Initial) learning rate, lr decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extend. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents have been trained). Below are our findings.
2. (Initial) learning rate, lr\_decay, `step-per-collect` and `training-num` affect the performance of A2C to a great extend. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents have been trained). Below are our findings.
3. `step-per-collect` / `training-num` are equal to `bootstrap-lenghth`, which is the max length of an "episode" used in GAE estimator and 80/16=5 in default settings. When `bootstrap-lenghth` is small, (maybe) because GAE can look forward at most 5 steps and use bootstrap strategy very often, the critic is less well-trained leading the actor to a not very high score. However, if we increase `step-per-collect` to increase `bootstrap-lenghth` (e.g. 256/16=16), actor/critic will be updated less often, resulting in low sample efficiency and slow training process. To conclude, If you don't restrict env timesteps, you can try using larger `bootstrap-lenghth` and train with more steps to get a better converged score. Train slower, achieve higher.
4. The learning rate 7e-4 with decay strategy is appropriate for `step-per-collect=80` and `training-num=16`. But if you use a larger `step-per-collect`(e.g. 256 - 2048), 7e-4 is a little bit small for `lr` because each update will have more data, less noise and thus smaller deviation in this case. So it is more appropriate to use a higher learning rate (e.g. 1e-3) to boost performance in this setting. If plotting results arise fast in early stages and become unstable later, consider lr decay first before decreasing lr.
5. `max-grad-norm` didn't really help in our experiments. We simply keep it for consistency with other open-source implementations (e.g. SB3).
@ -206,7 +207,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
### PPO
| Environment | Tianshou (1M) | [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [baselines](http://htmlpreview.github.io/?https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [spinningup(pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) |
| Environment | Tianshou (1M) | [ikostrikov/pytorch-a2c-ppo-acktr-gail](https://github.com/ikostrikov/pytorch-a2c-ppo-acktr-gail) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) |
| :--------------------: | :----------------: | :----------------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| Ant | **3258.4+-1079.3** | N | N | N | ~650 |
| HalfCheetah | **5783.9+-1244.0** | ~3120 | ~1800 | ~1700 | ~1670 |
@ -218,7 +219,7 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
| InvertedPendulum | **1000.0+-0.0** | N | **~1000** | ~940 | N |
| InvertedDoublePendulum | **9231.3+-270.4** | N | ~8000 | ~7350 | N |
| Environment | Tianshou (3M) | [Spinning Up(Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) |
| Environment | Tianshou (3M) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench_ppo.html) |
| :--------------------: | :----------------: | :----------------------------------------------------------: |
| Ant | **4079.3+-880.2** | ~3000 |
| HalfCheetah | **7337.4+-1508.2** | ~3130 |
@ -234,12 +235,36 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
#### Hints for PPO
1. Following [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990) Sec 3.5, we use "recompute advantage" strategy, which contributes a lot to our SOTA benchmark. However, I personally don't quite agree with their explanation about why "recompute advantage" helps. They stated that it's because old strategy "makes it impossible to compute advantages as the temporal structure is broken", but PPO's update equation is designed to learn from slightly-outdated advantages. I think the only reason "recompute advantage" works is that it update the critic several times rather than just one time per update, which leads to a better value function estimation.
2. We have done full scale ablation studies of PPO algorithm's hyperparameters. Here are our findings: In mujoco settings, `value-clip` and `norm-adv` may help a litte bit in some games (e.g. `norm-adv` helps stabilize training in InvertedPendulum-v2), but they make no difference to overall performance. So in our benchmark we do not use such tricks. We validate that setting `ent-coef` to 0.0 rather than 0.01 will increase overall performance in mujoco environments. `max-grad-norm` still offers no help for PPO algorithm, but we still keep it for consistency.
2. We have done full scale ablation studies of PPO algorithm's hyperparameters. Here are our findings: In Mujoco settings, `value-clip` and `norm-adv` may help a litte bit in some games (e.g. `norm-adv` helps stabilize training in InvertedPendulum-v2), but they make no difference to overall performance. So in our benchmark we do not use such tricks. We validate that setting `ent-coef` to 0.0 rather than 0.01 will increase overall performance in mujoco environments. `max-grad-norm` still offers no help for PPO algorithm, but we still keep it for consistency.
3. [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990)'s work indicates that using `gae-lambda` 0.9 and changing policy network's width based on which game you play (e.g. use [16, 16] `hidden-sizes` for `actor` network in HalfCheetah and [256, 256] for Ant) may help boost performance. Our ablation studies say otherwise: both options may lead to equal or lower performance overall in our experiments. We are not confident about this claim because we didn't change learning rate and other maybe-correlated factors in our experiments. So if you want, you can still have a try.
4. `batch-size` 128 and 64 (default) work equally well. Changing `training-num` alone slightly (maybe in range [8, 128]) won't affect performance. For bound action method, both `clip` and `tanh` work quite well.
5. In OPENAI implementations of PPO, they multiply value loss with a factor of 0.5 for no good reason (see this [issue](https://github.com/openai/baselines/issues/445#issuecomment-777988738)). We do not do so and therefore make our `vf-coef` 0.25 (half of standard 0.5). However, since value loss is only used to optimize `critic` network, setting different `vf-coef` should in theory make no difference if using Adam optimizer.
### TRPO
|Environment| Tianshou| [ACKTR paper](https://arxiv.org/pdf/1708.05144.pdf)| [PPO paper](https://arxiv.org/pdf/1707.06347.pdf)|[OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm)|[Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html)|
| :---------------: | :---------------: | :---------------: | :---------------: | :---------------: |:---------------: |
|Ant|**2866.7±707.9** | ~0 | N | N | ~150 |
|HalfCheetah|**4471.2±804.9** | ~400 | ~0| ~1350 | ~850 |
|Hopper| 2046.0±1037.9| ~1400 | ~2100 | **~2200** | ~1200 |
|Walker2d|**3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 |
|Swimmer|40.9±19.6 | ~40 | **~121** | ~95| ~85 |
|Humanoid|**810.1±126.1**| N | N | N | N |
|Reacher| **-5.1±0.8** | -8 | ~-115 | **~-5** | N |
|InvertedPendulum|**1000.0±0.0** | **~1000** | **~1000** | ~910 | N |
|InvertedDoublePendulum|**8435.2±1073.3**| ~800 | ~200 | ~7000 | N |
\* details<sup>[[4]](#footnote4)</sup><sup>[[5]](#footnote5)</sup>
#### Hints for TRPO
1. We have tried `step-per-collect` in (80, 1024, 2048, 4096), and `training-num` in (4, 16, 32, 64), and found out 1024 for `step-per-collect` (same as OpenAI Baselines) and smaller `training-num` (below 16) are good choices. Set `training-num` to 4 is actually better but we still use 16 considering the boost of training speed.
2. Advantage normalization is a standard trick in TRPO, but we found it of minor help, just like in PPO.
3. Larger `optim-critic-iters` (than 5, as used in OpenAI Baselines) helps in most environments. Smaller lr and lr\_decay strategy also help a tiny little bit for performance.
4. `gae-lambda` 0.98 and 0.95 work equally well.
5. We use GAE returns (GAE advantage + value) as the target of critic network when updating, while people usually tend to use reward to go (lambda = 0.) as target. We found that they work equally well although using GAE returns is a little bit inaccurate (biased) by math.
6. Empirically, Swimmer-v3 usually requires larger bootstrap lengths and learning rate. Humanoid-v3 and InvertedPendulum-v2, however, are on the opposite.
7. In contrast, with the statement made in TRPO paper, we found that backtracking in line search is rarely used at least in Mujoco settings, which is actually unimportant. This makes TRPO algorithm actually the same as TNPG algorithm (described in this [paper](http://proceedings.mlr.press/v48/duan16.html)). This also explains why TNPG and TRPO's plotting results look so similar in that paper.
8. "recompute advantage" is helpful in PPO but doesn't help in TRPO.
## Note

Binary file not shown.

After

Width:  |  Height:  |  Size: 105 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 115 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 148 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 89 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 110 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 118 KiB

View File

@ -33,7 +33,7 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=80)
parser.add_argument('--repeat-per-collect', type=int, default=1)
# batch-size >> step-per-collect means caculating all data in one singe forward.
# batch-size >> step-per-collect means calculating all data in one singe forward.
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)

View File

@ -33,7 +33,7 @@ def get_args():
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=2048)
parser.add_argument('--repeat-per-collect', type=int, default=1)
# batch-size >> step-per-collect means caculating all data in one singe forward.
# batch-size >> step-per-collect means calculating all data in one singe forward.
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=64)
parser.add_argument('--test-num', type=int, default=10)

View File

@ -0,0 +1,173 @@
#!/usr/bin/env python3
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import TRPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=4096)
parser.add_argument('--hidden-sizes', type=int, nargs='*',
default=[64, 64]) # baselines [32, 32]
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=1024)
parser.add_argument('--repeat-per-collect', type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one singe forward.
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
# trpo special
parser.add_argument('--rew-norm', type=int, default=True)
parser.add_argument('--gae-lambda', type=float, default=0.95)
# TODO tanh support
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--optim-critic-iters', type=int, default=20)
parser.add_argument('--max-kl', type=float, default=0.01)
parser.add_argument('--backtrack-coeff', type=float, default=0.8)
parser.add_argument('--max-backtracks', type=int, default=10)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()
def test_trpo(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
norm_obs=True)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
unbounded=True, device=args.device).to(args.device)
net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
args.step_per_epoch / args.step_per_collect) * args.epoch
lr_scheduler = LambdaLR(
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = TRPOPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
gae_lambda=args.gae_lambda,
reward_normalization=args.rew_norm, action_scaling=True,
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler, action_space=env.action_space,
advantage_normalization=args.norm_adv,
optim_critic_iters=args.optim_critic_iters,
max_kl=args.max_kl,
backtrack_coeff=args.backtrack_coeff,
max_backtracks=args.max_backtracks)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_trpo'
log_path = os.path.join(args.logdir, args.task, 'trpo', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == '__main__':
test_trpo()

View File

@ -178,8 +178,7 @@ class TRPOPolicy(A2CPolicy):
if kl < self._delta and new_actor_loss < actor_loss:
if i > 0:
warnings.warn(f"Backtracking to step {i}. "
"Hyperparamters aren't good enough.")
warnings.warn(f"Backtracking to step {i}.")
break
elif i < self._max_backtracks - 1:
step_size = step_size * self._backtrack_coeff