A2C benchmark for mujoco (#325)
@ -16,6 +16,7 @@ Supported algorithms are listed below:
|
||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
|
||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e605bdea942b408126ef4fbc740359773259c9ec)
|
||||
- [REINFORCE algorithm](https://papers.nips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/e27b5a26f330de446fe15388bf81c3777f024fb9)
|
||||
- A2C, commit id (TODO)
|
||||
|
||||
## Offpolicy algorithms
|
||||
|
||||
@ -149,6 +150,45 @@ By comparison to both classic literature and open source implementations (e.g.,
|
||||
5. We didn't tune `step-per-collect` option and `training-num` option. Default values are finetuned with PPO algorithm so we assume they are also good for REINFORCE. You can play with them if you want, but remember that `buffer-size` should always be larger than `step-per-collect`, and if `step-per-collect` is too small and `training-num` too large, episodes will be truncated and bootstrapped very often, which will harm performances. If `training-num` is too small (e.g., less than 8), speed will go down.
|
||||
6. Sigma of action is not fixed (normally seen in other implementation) or conditioned on observation, but is an independent parameter which can be updated by gradient descent. We choose this setting because it works well in PPO, and is recommended by [Andrychowicz, Marcin, et al](https://arxiv.org/abs/2006.05990). See Fig. 23.
|
||||
|
||||
### A2C
|
||||
|
||||
| Environment | Tianshou(3M steps) | [Spinning Up(Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench_vpg.html)|
|
||||
| :--------------------: | :----------------: | :--------------------: |
|
||||
| Ant | **5236.8+-236.7** | ~5 |
|
||||
| HalfCheetah | **2377.3+-1363.7** | ~600 |
|
||||
| Hopper | **1608.6+-529.5** | ~800 |
|
||||
| Walker2d | **1805.4+-1055.9** | ~460 |
|
||||
| Swimmer | 40.2+-1.8 | **~51** |
|
||||
| Humanoid | **5316.6+-554.8** | N |
|
||||
| Reacher | **-5.2+-0.5** | N |
|
||||
| InvertedPendulum | **1000.0+-0.0** | N |
|
||||
| InvertedDoublePendulum | **9351.3+-12.8** | N |
|
||||
|
||||
| Environment | Tianshou | [PPO paper](https://arxiv.org/abs/1707.06347) A2C | [PPO paper](https://arxiv.org/abs/1707.06347) A2C + Trust Region |
|
||||
| :--------------------: | :----------------: | :-------------: | :-------------: |
|
||||
| Ant | **3485.4+-433.1** | N | N |
|
||||
| HalfCheetah | **1829.9+-1068.3** | ~1000 | ~930 |
|
||||
| Hopper | **1253.2+-458.0** | ~900 | ~1220 |
|
||||
| Walker2d | **1091.6+-709.2** | ~850 | ~700 |
|
||||
| Swimmer | **36.6+-2.1** | ~31 | **~36** |
|
||||
| Humanoid | **1726.0+-1070.1** | N | N |
|
||||
| Reacher | **-6.7+-2.3** | ~-24 | ~-27 |
|
||||
| InvertedPendulum | **1000.0+-0.0** | **~1000** | **~1000** |
|
||||
| InvertedDoublePendulum | **9257.7+-277.4** | ~7100 | ~8100 |
|
||||
|
||||
\* details<sup>[[5]](#footnote5)</sup><sup>[[6]](#footnote6)</sup>
|
||||
|
||||
#### Hints for A2C
|
||||
|
||||
0. We choose `clip` action method in A2C instead `tanh` option as used in REINFORCE simply to be consistent with original implementation. `tanh` may be better or equally well but we didn't try.
|
||||
1. (Initial) learning rate, lr decay, and `step-per-collect`, `training-num` affect the performance of A2C to a great extend. These 4 hyperparameters also affect each other and should be tuned together. We have done full scale ablation studies on these 4 hyperparameters (more than 800 agents trained), below are our findings.
|
||||
2. `step-per-collect`/`training-num` = `bootstrap-lenghth`, which is max length of an "episode" used in GAE estimator, 80/16=5 in default settings. When `bootstrap-lenghth` is small, (maybe) because GAE can at most looks forward 5 steps, and use bootstrap strategy very often, the critic is less well-trained, so they actor cannot converge to very high scores. However, if we increase `step-per-collect` to increase `bootstrap-lenghth` (e.g. 256/16=16), actor/critic will be updated less often, so sample efficiency is low, which will make training process slow. To conclude, If you don't restrict env timesteps, you can try to use larger `bootstrap-lenghth`, and train for more steps, which perhaps will give you better converged scores. Train slower, achieve higher.
|
||||
3. 7e-4 learning rate with decay strategy if proper for `step-per-collect=80`, `training-num=16`, but if you use larger `step-per-collect`(e.g. 256 - 2048), 7e-4 `lr` is a little bit small, because now you have more data and less noise for each update, and will be more confidence if taking larger steps; so higher learning rate(e.g. 1e-3) is more appropriate and usually boost performance in this setting. If plotting results arises fast in early stages and become unstable later, consider lr decay before decreasing lr.
|
||||
4. `max-grad-norm` doesn't really help in our experiments, we simply keep it for consistency with other open-source implementations (e.g. SB3).
|
||||
5. We original paper of A3C use RMSprop optimizer, we find that Adam with the same learning rate works equally well. We use RMSprop anyway. Again, for consistency.
|
||||
6. We notice that in SB3's implementation of A2C that set `gae-lambda` to 1 by default, we don't know why and after doing some experiments, results show 0.95 is better overall.
|
||||
7. We find out that `step-per-collect=256`, `training-num=8` are also good hyperparameters. You can have a try.
|
||||
|
||||
## Note
|
||||
|
||||
<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
|
||||
|
BIN
examples/mujoco/benchmark/Ant-v3/a2c/figure.png
Normal file
After Width: | Height: | Size: 267 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/a2c/figure.png
Normal file
After Width: | Height: | Size: 212 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/a2c/figure.png
Normal file
After Width: | Height: | Size: 325 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/a2c/figure.png
Normal file
After Width: | Height: | Size: 297 KiB |
After Width: | Height: | Size: 382 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/a2c/figure.png
Normal file
After Width: | Height: | Size: 226 KiB |
BIN
examples/mujoco/benchmark/Reacher-v2/a2c/figure.png
Normal file
After Width: | Height: | Size: 204 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/a2c/figure.png
Normal file
After Width: | Height: | Size: 236 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/a2c/figure.png
Normal file
After Width: | Height: | Size: 299 KiB |
157
examples/mujoco/mujoco_a2c.py
Executable file
@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import datetime
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.distributions import Independent, Normal
|
||||
|
||||
from tianshou.policy import A2CPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import onpolicy_trainer
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='HalfCheetah-v3')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=4096)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--lr', type=float, default=7e-4)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=30000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=80)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=1)
|
||||
# batch-size >> step-per-collect means caculating all data in one singe forward.
|
||||
parser.add_argument('--batch-size', type=int, default=99999)
|
||||
parser.add_argument('--training-num', type=int, default=16)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
# a2c special
|
||||
parser.add_argument('--rew-norm', type=int, default=True)
|
||||
parser.add_argument('--vf-coef', type=float, default=0.5)
|
||||
parser.add_argument('--ent-coef', type=float, default=0.01)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--bound-action-method', type=str, default="clip")
|
||||
parser.add_argument('--lr-decay', type=int, default=True)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_a2c(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low),
|
||||
np.max(env.action_space.high))
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)],
|
||||
norm_obs=True)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)],
|
||||
norm_obs=True, obs_rms=train_envs.obs_rms, update_obs_rms=False)
|
||||
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||
activation=nn.Tanh, device=args.device)
|
||||
actor = ActorProb(net_a, args.action_shape, max_action=args.max_action,
|
||||
unbounded=True, device=args.device).to(args.device)
|
||||
net_c = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||
activation=nn.Tanh, device=args.device)
|
||||
critic = Critic(net_c, device=args.device).to(args.device)
|
||||
torch.nn.init.constant_(actor.sigma_param, -0.5)
|
||||
for m in list(actor.modules()) + list(critic.modules()):
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
# orthogonal initialization
|
||||
torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
# do last policy layer scaling, this will make initial actions have (close to)
|
||||
# 0 mean and std, and will help boost performances,
|
||||
# see https://arxiv.org/abs/2006.05990, Fig.24 for details
|
||||
for m in actor.mu.modules():
|
||||
if isinstance(m, torch.nn.Linear):
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
m.weight.data.copy_(0.01 * m.weight.data)
|
||||
|
||||
optim = torch.optim.RMSprop(set(actor.parameters()).union(critic.parameters()),
|
||||
lr=args.lr, eps=1e-5, alpha=0.99)
|
||||
|
||||
lr_scheduler = None
|
||||
if args.lr_decay:
|
||||
# decay learning rate to 0 linearly
|
||||
max_update_num = np.ceil(
|
||||
args.step_per_epoch / args.step_per_collect) * args.epoch
|
||||
|
||||
lr_scheduler = LambdaLR(
|
||||
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
|
||||
|
||||
def dist(*logits):
|
||||
return Independent(Normal(*logits), 1)
|
||||
|
||||
policy = A2CPolicy(actor, critic, optim, dist, discount_factor=args.gamma,
|
||||
gae_lambda=args.gae_lambda, max_grad_norm=args.max_grad_norm,
|
||||
vf_coef=args.vf_coef, ent_coef=args.ent_coef,
|
||||
reward_normalization=args.rew_norm, action_scaling=True,
|
||||
action_bound_method=args.bound_action_method,
|
||||
lr_scheduler=lr_scheduler, action_space=env.action_space)
|
||||
|
||||
# collector
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
# log
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_a2c'
|
||||
log_path = os.path.join(args.logdir, args.task, 'a2c', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer, update_interval=100, train_interval=100)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
# trainer
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch, args.step_per_epoch,
|
||||
args.repeat_per_collect, args.test_num, args.batch_size,
|
||||
step_per_collect=args.step_per_collect, save_fn=save_fn, logger=logger,
|
||||
test_in_train=False)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_a2c()
|
@ -103,9 +103,9 @@ def test_ddpg(args=get_args()):
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(args.seed) +
|
||||
'_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
|
||||
'-' + args.task.replace('-', '_') + '_ddpg')
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_ddpg'
|
||||
log_path = os.path.join(args.logdir, args.task, 'ddpg', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
@ -123,7 +123,7 @@ def test_reinforce(args=get_args()):
|
||||
log_path = os.path.join(args.logdir, args.task, 'reinforce', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer, update_interval=10)
|
||||
logger = BasicLogger(writer, update_interval=10, train_interval=100)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
@ -115,9 +115,9 @@ def test_sac(args=get_args()):
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(args.seed) +
|
||||
'_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
|
||||
'-' + args.task.replace('-', '_') + '_sac')
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_sac'
|
||||
log_path = os.path.join(args.logdir, args.task, 'sac', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
@ -117,9 +117,9 @@ def test_td3(args=get_args()):
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(args.seed) +
|
||||
'_' + datetime.datetime.now().strftime('%m%d_%H%M%S') +
|
||||
'-' + args.task.replace('-', '_') + '_td3')
|
||||
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
|
||||
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_td3'
|
||||
log_path = os.path.join(args.logdir, args.task, 'td3', log_file)
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
@ -120,7 +120,7 @@ class A2CPolicy(PGPolicy):
|
||||
- self._weight_ent * ent_loss
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
if self._grad_norm is not None: # clip large gradient
|
||||
if self._grad_norm: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
max_norm=self._grad_norm)
|
||||
|
@ -96,7 +96,7 @@ class PPOPolicy(A2CPolicy):
|
||||
np.sqrt(self.ret_rms.var + self._eps)
|
||||
self.ret_rms.update(unnormalized_returns)
|
||||
mean, std = np.mean(advantages), np.std(advantages)
|
||||
advantages = (advantages - mean) / std # per-batch norm
|
||||
advantages = (advantages - mean) / std
|
||||
else:
|
||||
batch.returns = unnormalized_returns
|
||||
batch.act = to_torch_as(batch.act, batch.v_s)
|
||||
@ -139,7 +139,7 @@ class PPOPolicy(A2CPolicy):
|
||||
- self._weight_ent * ent_loss
|
||||
self.optim.zero_grad()
|
||||
loss.backward()
|
||||
if self._grad_norm is not None: # clip large gradient
|
||||
if self._grad_norm: # clip large gradient
|
||||
nn.utils.clip_grad_norm_(
|
||||
list(self.actor.parameters()) + list(self.critic.parameters()),
|
||||
max_norm=self._grad_norm)
|
||||
|