NPG Mujoco benchmark release (#347)

ChenDRAG 2021-04-21 16:31:20 +08:00 committed by GitHub
parent 1dcf65fe21
commit 844d7703c3
12 changed files with 204 additions and 15 deletions

@@ -16,6 +16,7 @@ Supported algorithms are listed below:
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
- [Natural Policy Gradient](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/1dcf65fe21dc7636966796b6099ede1f4bd775e1)
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)
@@ -242,17 +243,17 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
### TRPO
| Environment | Tianshou (1M) | [ACKTR paper](https://arxiv.org/pdf/1708.05144.pdf) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) |
| :--------------------: | :---------------: | :-------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| Ant | **2866.7±707.9** | ~0 | N | N | ~150 |
| HalfCheetah | **4471.2±804.9** | ~400 | ~0 | ~1350 | ~850 |
| Hopper | 2046.0±1037.9 | ~1400 | ~2100 | **~2200** | ~1200 |
| Walker2d | **3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 |
| Swimmer | 40.9±19.6 | ~40 | **~121** | ~95 | ~85 |
| Humanoid | **810.1±126.1** | N | N | N | N |
| Reacher | **-5.1±0.8** | -8 | ~-115 | **~-5** | N |
| InvertedPendulum | **1000.0±0.0** | **~1000** | **~1000** | ~910 | N |
| InvertedDoublePendulum | **8435.2±1073.3** | ~800 | ~200 | ~7000 | N |
\* details<sup>[[4]](#footnote4)</sup><sup>[[5]](#footnote5)</sup>
@@ -266,6 +267,26 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
7. In contrast to the statement made in the TRPO paper, we found that the backtracking step of the line search is rarely triggered, at least in Mujoco settings, so it is actually unimportant. This makes TRPO effectively the same algorithm as TNPG (described in this [paper](http://proceedings.mlr.press/v48/duan16.html)), which also explains why the TNPG and TRPO plots look so similar in that paper. A minimal sketch of the shared natural-gradient update is given after this list.
8. "recompute advantage" is helpful in PPO but doesn't help in TRPO.
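
To make hint 7 concrete, here is a minimal, self-contained sketch of the natural-gradient update shared by NPG/TNPG and TRPO: approximately solve F x = g with conjugate gradient, where F is the Fisher information matrix accessed only through matrix-vector products, then scale the step so the quadratic approximation of the KL divergence equals a target `max_kl`. This is an illustrative sketch, not Tianshou's implementation; `mvp`, `max_kl`, and the toy usage at the end are placeholders.

```python
# Illustrative natural-gradient (NPG/TNPG) step; not Tianshou's code.
import torch


def conjugate_gradient(mvp, b, n_iters=10, eps=1e-8):
    """Approximately solve F x = b, where mvp(v) returns F @ v."""
    x = torch.zeros_like(b)
    r, p = b.clone(), b.clone()
    rdotr = r.dot(r)
    for _ in range(n_iters):
        Ap = mvp(p)
        alpha = rdotr / (p.dot(Ap) + eps)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = r.dot(r)
        p = r + (new_rdotr / (rdotr + eps)) * p
        rdotr = new_rdotr
    return x


def natural_gradient_step(g, mvp, max_kl=0.01):
    """Direction F^-1 g, scaled so that 0.5 * step^T F step == max_kl.
    TRPO would additionally halve this step until the surrogate improves
    and the exact KL stays below max_kl; per hint 7, that backtracking
    loop almost never triggers in Mujoco, so TRPO behaves like TNPG."""
    x = conjugate_gradient(mvp, g)  # x ~= F^-1 g
    x_fisher_x = x.dot(mvp(x))
    return torch.sqrt(2 * max_kl / (x_fisher_x + 1e-8)) * x


# Toy usage with an explicit (symmetric positive definite) "Fisher" matrix:
fisher = 2.0 * torch.eye(3)
grad = torch.tensor([1.0, 0.0, -1.0])
print(natural_gradient_step(grad, lambda v: fisher @ v))
```
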
### NPG
| Environment | Tianshou (1M) |
| :--------------------: | :--------------: |
| Ant | **2358.0±517.5** |
| HalfCheetah | **3485.2±716.6** |
| Hopper | **1915.2±550.5** |
| Walker2d | **2503.2±963.3** |
| Swimmer | **31.5±8.0** |
| Humanoid | **765.1±91.3** |
| Reacher | **-4.5±0.5** |
| InvertedPendulum | **1000.0±0.0** |
| InvertedDoublePendulum | **9243.2±276.0** |
\* details<sup>[[4]](#footnote4)</sup><sup>[[5]](#footnote5)</sup>
#### Hints for NPG
1. All shared hyperparameters are exactly the same as in TRPO, given how similar the two algorithms are.
2. We found that different Mujoco games may require quite different values of `actor-step-size`: Reacher and Swimmer are insensitive to step sizes in the range 0.1~1.0, while InvertedDoublePendulum, InvertedPendulum, and Humanoid are quite sensitive to the step size, and even 0.1 is too large for them. Other games may require an `actor-step-size` in the range 0.1~0.4, but are generally not that sensitive. A sketch of such a step-size sweep is shown below.
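
To probe this sensitivity, one can sweep `--actor-step-size` on one of the sensitive tasks. The snippet below is a hypothetical sketch: it assumes the training script added later in this commit is saved as `mujoco_npg.py` in the current directory (the actual path in the repository may differ), and the step sizes, epoch count, and log directories are only example values.

```python
# Hypothetical step-size sweep for NPG; file name and values are examples.
import subprocess

for step_size in (0.02, 0.05, 0.1, 0.25):
    subprocess.run(
        [
            "python", "mujoco_npg.py",
            "--task", "InvertedPendulum-v2",
            "--actor-step-size", str(step_size),
            "--epoch", "10",  # a short run suffices for a rough sensitivity check
            "--logdir", f"log_step_size_{step_size}",
        ],
        check=True,
    )
```
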
## Note
<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2, and InvertedDoublePendulum-v2. Pusher, Thrower, Striker, and HumanoidStandup are not supported because they are not commonly seen in the literature.

@@ -0,0 +1,168 @@
#!/usr/bin/env python3
import os
import gym
import torch
import pprint
import datetime
import argparse
import numpy as np
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import Independent, Normal
from tianshou.policy import NPGPolicy
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=4096)
parser.add_argument('--hidden-sizes', type=int, nargs='*',
default=[64, 64]) # baselines [32, 32]
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--step-per-epoch', type=int, default=30000)
parser.add_argument('--step-per-collect', type=int, default=1024)
parser.add_argument('--repeat-per-collect', type=int, default=1)
# batch-size >> step-per-collect means calculating all data in one single forward pass.
parser.add_argument('--batch-size', type=int, default=99999)
parser.add_argument('--training-num', type=int, default=16)
parser.add_argument('--test-num', type=int, default=10)
# npg special
parser.add_argument('--rew-norm', type=int, default=True)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--bound-action-method', type=str, default="clip")
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--optim-critic-iters', type=int, default=20)
parser.add_argument('--actor-step-size', type=float, default=0.1)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
parser.add_argument('--watch', default=False, action='store_true',
help='watch the play of pre-trained policy only')
return parser.parse_args()
def test_npg(args=get_args()):
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
train_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.training_num)],
norm_obs=True)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)],
norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
unbounded=True, device=args.device).to(args.device)
net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
activation=nn.Tanh, device=args.device)
critic = Critic(net_c, device=args.device).to(args.device)
torch.nn.init.constant_(actor.sigma_param, -0.5)
for m in list(actor.modules()) + list(critic.modules()):
if isinstance(m, torch.nn.Linear):
# orthogonal initialization
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
torch.nn.init.zeros_(m.bias)
# do last policy layer scaling, this will make initial actions have (close to)
# 0 mean and std, and will help boost performances,
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
for m in actor.mu.modules():
if isinstance(m, torch.nn.Linear):
torch.nn.init.zeros_(m.bias)
m.weight.data.copy_(0.01 * m.weight.data)
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
args.step_per_epoch / args.step_per_collect) * args.epoch
lr_scheduler = LambdaLR(
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
def dist(*logits):
return Independent(Normal(*logits), 1)
policy = NPGPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
gae_lambda=args.gae_lambda,
reward_normalization=args.rew_norm, action_scaling=True,
action_bound_method=args.bound_action_method,
lr_scheduler=lr_scheduler, action_space=env.action_space,
advantage_normalization=args.norm_adv,
optim_critic_iters=args.optim_critic_iters,
actor_step_size=args.actor_step_size)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
# log
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_npg'
log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=100, train_interval=100)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
if not args.watch:
# trainer
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
args.repeat_per_collect, args.test_num, args.batch_size,
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
test_in_train=False)
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num, render=args.render)
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == '__main__':
test_npg()

@@ -84,10 +84,10 @@ class NPGPolicy(A2CPolicy):
for b in batch.split(batch_size, merge_last=True):
# optimize actor
# direction: calculate vanilla gradient
dist = self(b).dist # TODO could come from batch
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
actor_loss = -(ratio * b.adv).mean()
dist = self(b).dist
log_prob = dist.log_prob(b.act)
log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
actor_loss = -(log_prob * b.adv).mean()
flat_grads = self._get_flat_grad(
actor_loss, self.actor, retain_graph=True).detach()
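
The last line above flattens the gradient of the actor loss into a single vector. As context, here is a generic sketch of what a flat-gradient helper like `_get_flat_grad` typically computes (illustrative only, not necessarily Tianshou's exact implementation):

```python
# Generic flat-gradient helper (sketch): concatenate d(loss)/d(theta) over
# all parameters of a module into one 1-D tensor.
import torch
from torch import nn


def get_flat_grad(loss: torch.Tensor, model: nn.Module, **kwargs) -> torch.Tensor:
    grads = torch.autograd.grad(loss, list(model.parameters()), **kwargs)
    return torch.cat([g.reshape(-1) for g in grads])
```

This flat gradient then serves as the right-hand side g of the natural-gradient system F x = g, which NPG solves approximately with conjugate gradient (as sketched in the benchmark hints above) before scaling the resulting direction by `actor_step_size` and writing the updated parameters back into the actor.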