NPG Mujoco benchmark release (#347)
@ -16,6 +16,7 @@ Supported algorithms are listed below:
|
||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
|
||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
|
||||
- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
|
||||
- [Natural Policy Gradient](https://proceedings.neurips.cc/paper/2001/file/4b86abe48d358ecf194c56c69108433e-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/1dcf65fe21dc7636966796b6099ede1f4bd775e1)
|
||||
- [Advantage Actor-Critic (A2C)](https://openai.com/blog/baselines-acktr-a2c/), [commit id](https://github.com/thu-ml/tianshou/tree/1730a9008ad6bb67cac3b21347bed33b532b17bc)
|
||||
- [Proximal Policy Optimization (PPO)](https://arxiv.org/pdf/1707.06347.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/6426a39796db052bafb7cabe85c764db20a722b0)
|
||||
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/pdf/1502.05477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/5057b5c89e6168220272c9c28a15b758a72efc32)
|
||||
@ -242,17 +243,17 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
|
||||
|
||||
### TRPO
|
||||
|
||||
|Environment| Tianshou| [ACKTR paper](https://arxiv.org/pdf/1708.05144.pdf)| [PPO paper](https://arxiv.org/pdf/1707.06347.pdf)|[OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm)|[Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html)|
|
||||
| :---------------: | :---------------: | :---------------: | :---------------: | :---------------: |:---------------: |
|
||||
|Ant|**2866.7±707.9** | ~0 | N | N | ~150 |
|
||||
|HalfCheetah|**4471.2±804.9** | ~400 | ~0| ~1350 | ~850 |
|
||||
|Hopper| 2046.0±1037.9| ~1400 | ~2100 | **~2200** | ~1200 |
|
||||
|Walker2d|**3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 |
|
||||
|Swimmer|40.9±19.6 | ~40 | **~121** | ~95| ~85 |
|
||||
|Humanoid|**810.1±126.1**| N | N | N | N |
|
||||
|Reacher| **-5.1±0.8** | -8 | ~-115 | **~-5** | N |
|
||||
|InvertedPendulum|**1000.0±0.0** | **~1000** | **~1000** | ~910 | N |
|
||||
|InvertedDoublePendulum|**8435.2±1073.3**| ~800 | ~200 | ~7000 | N |
|
||||
| Environment | Tianshou (1M) | [ACKTR paper](https://arxiv.org/pdf/1708.05144.pdf) | [PPO paper](https://arxiv.org/pdf/1707.06347.pdf) | [OpenAI Baselines](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm) | [Spinning Up (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) |
|
||||
| :--------------------: | :---------------: | :-------------------------------------------------: | :-----------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
|
||||
| Ant | **2866.7±707.9** | ~0 | N | N | ~150 |
|
||||
| HalfCheetah | **4471.2±804.9** | ~400 | ~0 | ~1350 | ~850 |
|
||||
| Hopper | 2046.0±1037.9 | ~1400 | ~2100 | **~2200** | ~1200 |
|
||||
| Walker2d | **3826.7±782.7** | ~550 | ~1100 | ~2350 | ~600 |
|
||||
| Swimmer | 40.9±19.6 | ~40 | **~121** | ~95 | ~85 |
|
||||
| Humanoid | **810.1±126.1** | N | N | N | N |
|
||||
| Reacher | **-5.1±0.8** | -8 | ~-115 | **~-5** | N |
|
||||
| InvertedPendulum | **1000.0±0.0** | **~1000** | **~1000** | ~910 | N |
|
||||
| InvertedDoublePendulum | **8435.2±1073.3** | ~800 | ~200 | ~7000 | N |
|
||||
|
||||
\* details<sup>[[4]](#footnote4)</sup><sup>[[5]](#footnote5)</sup>
|
||||
|
||||
@ -266,6 +267,26 @@ For pretrained agents, detailed graphs (single agent, single game) and log detai
|
||||
7. In contrast, with the statement made in TRPO paper, we found that backtracking in line search is rarely used at least in Mujoco settings, which is actually unimportant. This makes TRPO algorithm actually the same as TNPG algorithm (described in this [paper](http://proceedings.mlr.press/v48/duan16.html)). This also explains why TNPG and TRPO's plotting results look so similar in that paper.
|
||||
8. "recompute advantage" is helpful in PPO but doesn't help in TRPO.
|
||||
|
||||
### NPG
|
||||
|
||||
| Environment | Tianshou (1M) |
|
||||
| :--------------------: | :--------------: |
|
||||
| Ant | **2358.0±517.5** |
|
||||
| HalfCheetah | **3485.2±716.6** |
|
||||
| Hopper | **1915.2±550.5** |
|
||||
| Walker2d | **2503.2±963.3** |
|
||||
| Swimmer | **31.5±8.0** |
|
||||
| Humanoid | **765.1±91.3** |
|
||||
| Reacher | **-4.5±0.5** |
|
||||
| InvertedPendulum | **1000.0±0.0** |
|
||||
| InvertedDoublePendulum | **9243.2±276.0** |
|
||||
|
||||
\* details<sup>[[4]](#footnote4)</sup><sup>[[5]](#footnote5)</sup>
|
||||
|
||||
#### Hints for NPG
|
||||
1. All shared hyperparameters are exactly the same as TRPO, regarding how similar these two algorithms are.
|
||||
2. We found different games in Mujoco may require quite different `actor-step-size`: Reacher/Swimmer are insensitive to step-size in range (0.1~1.0), while InvertedDoublePendulum / InvertedPendulum / Humanoid are quite sensitive to step size, and even 0.1 is too large. Other games may require `actor-step-size` in range (0.1~0.4), but aren't that sensitive in general.
|
||||
|
||||
## Note
|
||||
|
||||
<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
|
||||
|
BIN
examples/mujoco/benchmark/Ant-v3/npg/figure.png
Normal file
After Width: | Height: | Size: 103 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/npg/figure.png
Normal file
After Width: | Height: | Size: 94 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/npg/figure.png
Normal file
After Width: | Height: | Size: 109 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/npg/figure.png
Normal file
After Width: | Height: | Size: 92 KiB |
After Width: | Height: | Size: 158 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/npg/figure.png
Normal file
After Width: | Height: | Size: 81 KiB |
BIN
examples/mujoco/benchmark/Reacher-v2/npg/figure.png
Normal file
After Width: | Height: | Size: 86 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/npg/figure.png
Normal file
After Width: | Height: | Size: 94 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/npg/figure.png
Normal file
After Width: | Height: | Size: 105 KiB |
168
examples/mujoco/mujoco_npg.py
Normal file
@ -0,0 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import datetime
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.distributions import Independent, Normal
|
||||
|
||||
from tianshou.policy import NPGPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=4096)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*',
|
||||
default=[64, 64]) # baselines [32, 32]
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=30000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=1024)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=1)
|
||||
# batch-size >> step-per-collect means calculating all data in one singe forward.
|
||||
parser.add_argument('--batch-size', type=int, default=99999)
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
# npg special
|
||||
parser.add_argument('--rew-norm', type=int, default=True)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--bound-action-method', type=str, default="clip")
|
||||
parser.add_argument('--lr-decay', type=int, default=True)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument('--norm-adv', type=int, default=1)
|
||||
parser.add_argument('--optim-critic-iters', type=int, default=20)
|
||||
parser.add_argument('--actor-step-size', type=float, default=0.1)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument('--watch', default=False, action='store_true',
|
||||
help='watch the play of pre-trained policy only')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_npg(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low),
|
||||
np.max(env.action_space.high))
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)],
|
||||
norm_obs=True)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)],
|
||||
norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
|
||||
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||
activation=nn.Tanh, device=args.device)
|
||||
actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
|
||||
unbounded=True, device=args.device).to(args.device)
|
||||
net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||
activation=nn.Tanh, device=args.device)
|
||||
critic = Critic(net_c, device=args.device).to(args.device)
|
||||
torch.nn.init.constant_(actor.sigma_param, -0.5)
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
# orthogonal initialization
|
||||
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
# do last policy layer scaling, this will make initial actions have (close to)
|
||||
# 0 mean and std, and will help boost performances,
|
||||
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
|
||||
for m in actor.mu.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
optim = torch.optim.Adam(critic.parameters(), lr=args.lr)
|
||||
lr_scheduler = None
|
||||
if args.lr_decay:
|
||||
# decay learning rate to 0 linearly
|
||||
max_update_num = np.ceil(
|
||||
args.step_per_epoch / args.step_per_collect) * args.epoch
|
||||
|
||||
lr_scheduler = LambdaLR(
|
||||
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits):
|
||||
return Independent(Normal(*logits), 1)
|
||||
|
||||
policy = NPGPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
|
||||
gae_lambda=args.gae_lambda,
|
||||
reward_normalization=args.rew_norm, action_scaling=True,
|
||||
action_bound_method=args.bound_action_method,
|
||||
lr_scheduler=lr_scheduler, action_space=env.action_space,
|
||||
advantage_normalization=args.norm_adv,
|
||||
optim_critic_iters=args.optim_critic_iters,
|
||||
actor_step_size=args.actor_step_size)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_npg'
|
||||
log_path = os.path.join(args.logdir, args.task, 'npg', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer, update_interval=100, train_interval=100)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
if not args.watch:
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
|
||||
args.repeat_per_collect, args.test_num, args.batch_size,
|
||||
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
|
||||
test_in_train=False)
|
||||
pprint.pprint(result)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_npg()
|
@ -84,10 +84,10 @@ class NPGPolicy(A2CPolicy):
|
||||
for b in batch.split(batch_size, merge_last=True):
|
||||
# optimize actor
|
||||
# direction: calculate villia gradient
|
||||
dist = self(b).dist # TODO could come from batch
|
||||
ratio = (dist.log_prob(b.act) - b.logp_old).exp().float()
|
||||
ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
|
||||
actor_loss = -(ratio * b.adv).mean()
|
||||
dist = self(b).dist
|
||||
log_prob = dist.log_prob(b.act)
|
||||
log_prob = log_prob.reshape(log_prob.size(0), -1).transpose(0, 1)
|
||||
actor_loss = -(log_prob * b.adv).mean()
|
||||
flat_grads = self._get_flat_grad(
|
||||
actor_loss, self.actor, retain_graph=True).detach()
|
||||
|
||||
|