MuJoCo Benchmark - DDPG, TD3, SAC (#305)
Releasing Tianshou's SOTA benchmark of 9 out of 13 environments from the MuJoCo Gym task suite.
@ -1,27 +1,135 @@
|
||||
# Mujoco Result
|
||||
# Tianshou's Mujoco Benchmark
|
||||
|
||||
We benchmarked Tianshou algorithm implementations in 9 out of 13 environments from the MuJoCo Gym task suite<sup>[[1]](#footnote1)</sup>.
|
||||
|
||||
For each supported algorithm and supported mujoco environments, we provide:
|
||||
- Default hyperparameters used for benchmark and scripts to reproduce the benchmark;
|
||||
- A comparison of performance (or code level details) with other open source implementations or classic papers;
|
||||
- Graphs and raw data that can be used for research purposes<sup>[[2]](#footnote2)</sup>;
|
||||
- Log details obtained during training<sup>[[2]](#footnote2)</sup>;
|
||||
- Pretrained agents<sup>[[2]](#footnote2)</sup>;
|
||||
- Some hints on how to tune the algorithm.
|
||||
|
||||
|
||||
## SAC (single run)
|
||||
Supported algorithms are listed below:
|
||||
- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
|
||||
- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
|
||||
- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
|
||||
|
||||
The best reward computes from 100 episodes returns in the test phase.
|
||||
## Offpolicy algorithms
|
||||
|
||||
SAC on Swimmer-v3 always stops at 47\~48.
|
||||
#### Usage
|
||||
|
||||
| task | 3M best reward | parameters | time cost (3M) |
|
||||
| -------------- | ----------------- | ------------------------------------------------------- | -------------- |
|
||||
| HalfCheetah-v3 | 10157.70 ± 171.70 | `python3 mujoco_sac.py --task HalfCheetah-v3` | 2~3h |
|
||||
| Walker2d-v3 | 5143.04 ± 15.57 | `python3 mujoco_sac.py --task Walker2d-v3` | 2~3h |
|
||||
| Hopper-v3 | 3604.19 ± 169.55 | `python3 mujoco_sac.py --task Hopper-v3` | 2~3h |
|
||||
| Humanoid-v3 | 6579.20 ± 1470.57 | `python3 mujoco_sac.py --task Humanoid-v3 --alpha 0.05` | 2~3h |
|
||||
| Ant-v3 | 6281.65 ± 686.28 | `python3 mujoco_sac.py --task Ant-v3` | 2~3h |
|
||||
Run
|
||||
|
||||

|
||||
```bash
|
||||
$ python mujoco_sac.py --task Ant-v3
|
||||
```
|
||||
|
||||
### Which parts are important?
|
||||
Logs is saved in `./log/` and can be monitored with tensorboard.
|
||||
|
||||
```bash
|
||||
$ tensorboard --logdir log
|
||||
```
|
||||
|
||||
You can also reproduce the benchmark (e.g. SAC in Ant-v3) with the example script we provide under `examples/mujoco/`:
|
||||
|
||||
```bash
|
||||
$ ./run_experiments.sh Ant-v3
|
||||
```
|
||||
|
||||
This will start 10 experiments with different seeds.
|
||||
|
||||
#### Example benchmark
|
||||
|
||||
<img src="./benchmark/Ant-v3/figure.png" width="500" height="450">
|
||||
|
||||
Other graphs can be found under `/examples/mujuco/benchmark/`
|
||||
|
||||
#### Hints
|
||||
|
||||
In offpolicy algorithms(DDPG, TD3, SAC), the shared hyperparameters are almost the same<sup>[[8]](#footnote8)</sup>, and most hyperparameters are consistent with those used for benchmark in SpinningUp's implementations<sup>[[9]](#footnote9)</sup>.
|
||||
|
||||
By comparison to both classic literature and open source implementations (e.g., SpinningUp)<sup>[[1]](#footnote1)</sup><sup>[[2]](#footnote2)</sup>, Tianshou's implementations of DDPG, TD3, and SAC are roughly at-parity with or better than the best reported results for these algorithms.
|
||||
|
||||
### DDPG
|
||||
|
||||
| Environment | Tianshou | [SpinningUp (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) |
|
||||
| :--------------------: | :---------------: | :----------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: |
|
||||
| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 |
|
||||
| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 |
|
||||
| Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 |
|
||||
| Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** |
|
||||
| Swimmer | **144.1±6.5** | ~137 | N | N |
|
||||
| Humanoid | **177.3±77.6** | N | N | N |
|
||||
| Reacher | **-3.3±0.3** | N | -6.51 | -4.01 |
|
||||
| InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** |
|
||||
| InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 |
|
||||
|
||||
\* details<sup>[[5]](#footnote5)</sup><sup>[[6]](#footnote6)</sup><sup>[[7]](#footnote7)</sup>
|
||||
|
||||
### TD3
|
||||
|
||||
| Environment | Tianshou | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) |
|
||||
| :--------------------: | :---------------: | :-------------------: | :--------------: |
|
||||
| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 |
|
||||
| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 |
|
||||
| Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** |
|
||||
| Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** |
|
||||
| Swimmer | **104.2±34.2** | ~78 | N |
|
||||
| Humanoid | **5189.5±178.5** | N | N |
|
||||
| Reacher | **-2.7±0.2** | N | -3.6±0.6 |
|
||||
| InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** |
|
||||
| InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** |
|
||||
|
||||
\* details<sup>[[5]](#footnote5)</sup><sup>[[6]](#footnote6)</sup><sup>[[7]](#footnote7)</sup>
|
||||
|
||||
### SAC
|
||||
|
||||
| Environment | Tianshou | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) |
|
||||
| :--------------------: | :----------------: | :-------------------: | :---------: |
|
||||
| Ant | **5850.2±475.7** | ~3980 | ~3720 |
|
||||
| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 |
|
||||
| Hopper | **3542.2±51.5** | ~3150 | ~3370 |
|
||||
| Walker2d | **5007.0±251.5** | ~4250 | ~3740 |
|
||||
| Swimmer | **44.4±0.5** | ~41.7 | N |
|
||||
| Humanoid | **5488.5±81.2** | N | ~5200 |
|
||||
| Reacher | **-2.6±0.2** | N | N |
|
||||
| InvertedPendulum | **1000.0±0.0** | N | N |
|
||||
| InvertedDoublePendulum | **9359.5±0.4** | N | N |
|
||||
|
||||
\* details<sup>[[5]](#footnote5)</sup><sup>[[6]](#footnote6)</sup>
|
||||
|
||||
#### Hints for SAC
|
||||
|
||||
0. DO NOT share the same network with two critic networks.
|
||||
1. The sigma (of the Gaussian policy) MUST be conditioned on input.
|
||||
1. The sigma (of the Gaussian policy) should be conditioned on input.
|
||||
2. The network size should not be less than 256.
|
||||
3. The deterministic evaluation helps a lot :)
|
||||
|
||||
## Onpolicy Algorithms
|
||||
|
||||
TBD
|
||||
|
||||
|
||||
|
||||
|
||||
## Note
|
||||
|
||||
<a name="footnote1">[1]</a> Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
|
||||
|
||||
<a name="footnote2">[2]</a> Pretrained agents, detailed graphs (single agent, single game) and log details can all be found [here](https://cloud.tsinghua.edu.cn/d/356e0f5d1e66426b9828/).
|
||||
|
||||
<a name="footnote3">[3]</a> We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but it's not often the case with other benchmarks. Please check for details yourself in the original paper. (Different version's outcomes are usually similar, though)
|
||||
|
||||
<a name="footnote4">[4]</a> We didn't compare offpolicy algorithms to OpenAI baselines [benchmark](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm), because for now it seems that they haven't provided benchmark for offpolicy algorithms, but in [SpinningUp docs](https://spinningup.openai.com/en/latest/spinningup/bench.html) they stated that "SpinningUp implementations of DDPG, TD3, and SAC are roughly at-parity with the best-reported results for these algorithms", so we think lack of comparisons with OpenAI baselines is okay.
|
||||
|
||||
<a name="footnote5">[5]</a> ~ means the number is approximated from the graph because accurate numbers is not provided in the paper. N means graphs not provided.
|
||||
|
||||
<a name="footnote6">[6]</a> Reward metric: The meaning of the table value is the max average return over 10 trails (different seeds) ± a single standard deviation over trails. Each trial is averaged on another 10 test seeds. Only the first 1M steps data will be considered. The shaded region on the graph also represents a single standard deviation. It is the same as [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34).
|
||||
|
||||
<a name="footnote7">[7]</a> In TD3 paper, shaded region represents only half of standard deviation.
|
||||
|
||||
<a name="footnote8">[8]</a> SAC's start-timesteps is set to 10000 by default while it is 25000 is DDPG/TD3. TD3's learning rate is set to 3e-4 while it is 1e-3 for DDPG/SAC. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because of SpinningUp) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can!
|
||||
|
||||
<a name="footnote9">[9]</a> We use batchsize of 256 in DDPG/TD3/SAC while SpinningUp use 100. Minor difference also lies with `start-timesteps`, data loop method `step_per_collect`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.
|
||||
|
BIN
examples/mujoco/benchmark/Ant-v3/ddpg/figure.png
Normal file
After Width: | Height: | Size: 218 KiB |
BIN
examples/mujoco/benchmark/Ant-v3/figure.png
Normal file
After Width: | Height: | Size: 344 KiB |
BIN
examples/mujoco/benchmark/Ant-v3/sac/figure.png
Normal file
After Width: | Height: | Size: 213 KiB |
BIN
examples/mujoco/benchmark/Ant-v3/td3/figure.png
Normal file
After Width: | Height: | Size: 191 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/ddpg/figure.png
Normal file
After Width: | Height: | Size: 191 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/figure.png
Normal file
After Width: | Height: | Size: 342 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/sac/figure.png
Normal file
After Width: | Height: | Size: 188 KiB |
BIN
examples/mujoco/benchmark/HalfCheetah-v3/td3/figure.png
Normal file
After Width: | Height: | Size: 184 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/ddpg/figure.png
Normal file
After Width: | Height: | Size: 259 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/figure.png
Normal file
After Width: | Height: | Size: 423 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/sac/figure.png
Normal file
After Width: | Height: | Size: 202 KiB |
BIN
examples/mujoco/benchmark/Hopper-v3/td3/figure.png
Normal file
After Width: | Height: | Size: 213 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/ddpg/figure.png
Normal file
After Width: | Height: | Size: 124 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/figure.png
Normal file
After Width: | Height: | Size: 304 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/sac/figure.png
Normal file
After Width: | Height: | Size: 201 KiB |
BIN
examples/mujoco/benchmark/Humanoid-v3/td3/figure.png
Normal file
After Width: | Height: | Size: 190 KiB |
After Width: | Height: | Size: 144 KiB |
BIN
examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png
Normal file
After Width: | Height: | Size: 328 KiB |
After Width: | Height: | Size: 215 KiB |
After Width: | Height: | Size: 230 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/ddpg/figure.png
Normal file
After Width: | Height: | Size: 282 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/figure.png
Normal file
After Width: | Height: | Size: 351 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/sac/figure.png
Normal file
After Width: | Height: | Size: 160 KiB |
BIN
examples/mujoco/benchmark/InvertedPendulum-v2/td3/figure.png
Normal file
After Width: | Height: | Size: 163 KiB |
BIN
examples/mujoco/benchmark/Reacher-v2/ddpg/figure.png
Normal file
After Width: | Height: | Size: 172 KiB |
BIN
examples/mujoco/benchmark/Reacher-v2/figure.png
Normal file
After Width: | Height: | Size: 232 KiB |
BIN
examples/mujoco/benchmark/Reacher-v2/sac/figure.png
Normal file
After Width: | Height: | Size: 174 KiB |
BIN
examples/mujoco/benchmark/Reacher-v2/td3/figure.png
Normal file
After Width: | Height: | Size: 176 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/ddpg/figure.png
Normal file
After Width: | Height: | Size: 182 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/figure.png
Normal file
After Width: | Height: | Size: 302 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/sac/figure.png
Normal file
After Width: | Height: | Size: 183 KiB |
BIN
examples/mujoco/benchmark/Swimmer-v3/td3/figure.png
Normal file
After Width: | Height: | Size: 176 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/ddpg/figure.png
Normal file
After Width: | Height: | Size: 209 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/figure.png
Normal file
After Width: | Height: | Size: 356 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/sac/figure.png
Normal file
After Width: | Height: | Size: 194 KiB |
BIN
examples/mujoco/benchmark/Walker2d-v3/td3/figure.png
Normal file
After Width: | Height: | Size: 188 KiB |
131
examples/mujoco/mujoco_ddpg.py
Executable file
@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import datetime
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import DDPGPolicy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Ant-v3')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=1000000)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--tau', type=float, default=0.005)
|
||||
parser.add_argument('--exploration-noise', type=float, default=0.1)
|
||||
parser.add_argument("--start-timesteps", type=int, default=25000)
|
||||
parser.add_argument('--epoch', type=int, default=200)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=1)
|
||||
parser.add_argument('--update-per-step', type=int, default=1)
|
||||
parser.add_argument('--n-step', type=int, default=1)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
parser.add_argument('--training-num', type=int, default=1)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_ddpg(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
args.exploration_noise = args.exploration_noise * args.max_action
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low),
|
||||
np.max(env.action_space.high))
|
||||
# train_envs = gym.make(args.task)
|
||||
if args.training_num > 1:
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
else:
|
||||
train_envs = gym.make(args.task)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = Actor(
|
||||
net_a, args.action_shape, max_action=args.max_action,
|
||||
device=args.device).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
net_c = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True, device=args.device)
|
||||
critic = Critic(net_c, device=args.device).to(args.device)
|
||||
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
|
||||
policy = DDPGPolicy(
|
||||
actor, actor_optim, critic, critic_optim,
|
||||
action_range=[env.action_space.low[0], env.action_space.high[0]],
|
||||
tau=args.tau, gamma=args.gamma,
|
||||
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
|
||||
estimation_step=args.n_step)
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(
|
||||
args.resume_path, map_location=args.device
|
||||
))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(
|
||||
args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step, test_in_train=False)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_ddpg()
|
92
examples/mujoco/mujoco_sac.py
Normal file → Executable file
@ -1,7 +1,8 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import pprint
|
||||
import datetime
|
||||
import argparse
|
||||
import numpy as np
|
||||
@ -12,42 +13,38 @@ from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.data import Collector, VectorReplayBuffer
|
||||
from tianshou.utils.net.continuous import ActorProb, Critic
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Ant-v3')
|
||||
parser.add_argument('--seed', type=int, default=1626)
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=1000000)
|
||||
parser.add_argument('--actor-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--critic-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
|
||||
parser.add_argument('--actor-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--critic-lr', type=float, default=1e-3)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--tau', type=float, default=0.005)
|
||||
parser.add_argument('--alpha', type=float, default=0.2)
|
||||
parser.add_argument('--auto-alpha', default=False, action='store_true')
|
||||
parser.add_argument('--alpha-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--n-step', type=int, default=2)
|
||||
parser.add_argument('--epoch', type=int, default=100)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=40000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=4)
|
||||
parser.add_argument('--update-per-step', type=float, default=0.25)
|
||||
parser.add_argument('--pre-collect-step', type=int, default=10000)
|
||||
parser.add_argument("--start-timesteps", type=int, default=10000)
|
||||
parser.add_argument('--epoch', type=int, default=200)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=1)
|
||||
parser.add_argument('--update-per-step', type=int, default=1)
|
||||
parser.add_argument('--n-step', type=int, default=1)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
parser.add_argument('--hidden-sizes', type=int,
|
||||
nargs='*', default=[128, 128])
|
||||
parser.add_argument('--training-num', type=int, default=4)
|
||||
parser.add_argument('--test-num', type=int, default=100)
|
||||
parser.add_argument('--training-num', type=int, default=1)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument('--log-interval', type=int, default=1000)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
parser.add_argument('--watch', default=False, action='store_true',
|
||||
help='watch the play of pre-trained policy only')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -61,8 +58,11 @@ def test_sac(args=get_args()):
|
||||
print("Action range:", np.min(env.action_space.low),
|
||||
np.max(env.action_space.high))
|
||||
# train_envs = gym.make(args.task)
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
if args.training_num > 1:
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
else:
|
||||
train_envs = gym.make(args.task)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
@ -72,21 +72,20 @@ def test_sac(args=get_args()):
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
|
||||
device=args.device)
|
||||
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = ActorProb(
|
||||
net, args.action_shape, max_action=args.max_action,
|
||||
net_a, args.action_shape, max_action=args.max_action,
|
||||
device=args.device, unbounded=True, conditioned_sigma=True
|
||||
).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
net_c1 = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True, device=args.device)
|
||||
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
net_c2 = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True, device=args.device)
|
||||
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
@ -109,46 +108,35 @@ def test_sac(args=get_args()):
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
VectorReplayBuffer(args.buffer_size, len(train_envs)),
|
||||
exploration_noise=True)
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
|
||||
args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer, train_interval=args.log_interval)
|
||||
|
||||
def watch():
|
||||
# watch agent's performance
|
||||
print("Testing agent ...")
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
pprint.pprint(result)
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
|
||||
def stop_fn(mean_rewards):
|
||||
return False
|
||||
|
||||
if args.watch:
|
||||
watch()
|
||||
exit(0)
|
||||
|
||||
# trainer
|
||||
train_collector.collect(n_step=args.pre_collect_step, random=True)
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step)
|
||||
pprint.pprint(result)
|
||||
watch()
|
||||
args.batch_size, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step, test_in_train=False)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
144
examples/mujoco/mujoco_td3.py
Executable file
@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import gym
|
||||
import torch
|
||||
import datetime
|
||||
import argparse
|
||||
import numpy as np
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.policy import TD3Policy
|
||||
from tianshou.utils import BasicLogger
|
||||
from tianshou.env import SubprocVectorEnv
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.exploration import GaussianNoise
|
||||
from tianshou.trainer import offpolicy_trainer
|
||||
from tianshou.utils.net.continuous import Actor, Critic
|
||||
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--task', type=str, default='Ant-v3')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=1000000)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
|
||||
parser.add_argument('--actor-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--critic-lr', type=float, default=3e-4)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--tau', type=float, default=0.005)
|
||||
parser.add_argument('--exploration-noise', type=float, default=0.1)
|
||||
parser.add_argument('--policy-noise', type=float, default=0.2)
|
||||
parser.add_argument('--noise-clip', type=float, default=0.5)
|
||||
parser.add_argument('--update-actor-freq', type=int, default=2)
|
||||
parser.add_argument("--start-timesteps", type=int, default=25000)
|
||||
parser.add_argument('--epoch', type=int, default=200)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=5000)
|
||||
parser.add_argument('--step-per-collect', type=int, default=1)
|
||||
parser.add_argument('--update-per-step', type=int, default=1)
|
||||
parser.add_argument('--n-step', type=int, default=1)
|
||||
parser.add_argument('--batch-size', type=int, default=256)
|
||||
parser.add_argument('--training-num', type=int, default=1)
|
||||
parser.add_argument('--test-num', type=int, default=10)
|
||||
parser.add_argument('--logdir', type=str, default='log')
|
||||
parser.add_argument('--render', type=float, default=0.)
|
||||
parser.add_argument(
|
||||
'--device', type=str,
|
||||
default='cuda' if torch.cuda.is_available() else 'cpu')
|
||||
parser.add_argument('--resume-path', type=str, default=None)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def test_td3(args=get_args()):
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
args.max_action = env.action_space.high[0]
|
||||
args.exploration_noise = args.exploration_noise * args.max_action
|
||||
args.policy_noise = args.policy_noise * args.max_action
|
||||
args.noise_clip = args.noise_clip * args.max_action
|
||||
print("Observations shape:", args.state_shape)
|
||||
print("Actions shape:", args.action_shape)
|
||||
print("Action range:", np.min(env.action_space.low),
|
||||
np.max(env.action_space.high))
|
||||
# train_envs = gym.make(args.task)
|
||||
if args.training_num > 1:
|
||||
train_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.training_num)])
|
||||
else:
|
||||
train_envs = gym.make(args.task)
|
||||
# test_envs = gym.make(args.task)
|
||||
test_envs = SubprocVectorEnv(
|
||||
[lambda: gym.make(args.task) for _ in range(args.test_num)])
|
||||
# seed
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
train_envs.seed(args.seed)
|
||||
test_envs.seed(args.seed)
|
||||
# model
|
||||
net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
|
||||
actor = Actor(
|
||||
net_a, args.action_shape, max_action=args.max_action,
|
||||
device=args.device).to(args.device)
|
||||
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
|
||||
net_c1 = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True, device=args.device)
|
||||
net_c2 = Net(args.state_shape, args.action_shape,
|
||||
hidden_sizes=args.hidden_sizes,
|
||||
concat=True, device=args.device)
|
||||
critic1 = Critic(net_c1, device=args.device).to(args.device)
|
||||
critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
|
||||
critic2 = Critic(net_c2, device=args.device).to(args.device)
|
||||
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
|
||||
|
||||
policy = TD3Policy(
|
||||
actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
|
||||
action_range=[env.action_space.low[0], env.action_space.high[0]],
|
||||
tau=args.tau, gamma=args.gamma,
|
||||
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
|
||||
policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
|
||||
noise_clip=args.noise_clip, estimation_step=args.n_step)
|
||||
|
||||
# load a previous policy
|
||||
if args.resume_path:
|
||||
policy.load_state_dict(torch.load(
|
||||
args.resume_path, map_location=args.device
|
||||
))
|
||||
print("Loaded agent from: ", args.resume_path)
|
||||
|
||||
# collector
|
||||
if args.training_num > 1:
|
||||
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
|
||||
else:
|
||||
buffer = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
|
||||
test_collector = Collector(policy, test_envs)
|
||||
train_collector.collect(n_step=args.start_timesteps, random=True)
|
||||
# log
|
||||
log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(
|
||||
args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
|
||||
writer = SummaryWriter(log_path)
|
||||
writer.add_text("args", str(args))
|
||||
logger = BasicLogger(writer)
|
||||
|
||||
def save_fn(policy):
|
||||
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
||||
# trainer
|
||||
result = offpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.step_per_collect, args.test_num,
|
||||
args.batch_size, save_fn=save_fn, logger=logger,
|
||||
update_per_step=args.update_per_step, test_in_train=False)
|
||||
|
||||
# Let's watch its performance!
|
||||
policy.eval()
|
||||
test_envs.seed(args.seed)
|
||||
test_collector.reset()
|
||||
result = test_collector.collect(n_episode=args.test_num, render=args.render)
|
||||
print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_td3()
|
Before Width: | Height: | Size: 126 KiB |
10
examples/mujoco/run_experiments.sh
Executable file
@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
LOGDIR="results"
|
||||
TASK=$1
|
||||
|
||||
echo "Experiments started."
|
||||
for seed in $(seq 1 10)
|
||||
do
|
||||
python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
|
||||
done
|
@ -129,18 +129,31 @@ class DDPGPolicy(BasePolicy):
|
||||
obs = batch[input]
|
||||
actions, h = model(obs, state=state, info=batch.info)
|
||||
actions += self._action_bias
|
||||
actions = actions.clamp(self._range[0], self._range[1])
|
||||
return Batch(act=actions, state=h)
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
current_q = self.critic(batch.obs, batch.act).flatten()
|
||||
@staticmethod
|
||||
def _mse_optimizer(
|
||||
batch: Batch, critic: torch.nn.Module, optimizer: torch.optim.Optimizer
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A simple wrapper script for updating critic network."""
|
||||
weight = getattr(batch, "weight", 1.0)
|
||||
current_q = critic(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
td = current_q - target_q
|
||||
# critic_loss = F.mse_loss(current_q1, target_q)
|
||||
critic_loss = (td.pow(2) * weight).mean()
|
||||
batch.weight = td # prio-buffer
|
||||
self.critic_optim.zero_grad()
|
||||
optimizer.zero_grad()
|
||||
critic_loss.backward()
|
||||
self.critic_optim.step()
|
||||
optimizer.step()
|
||||
return td, critic_loss
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
# critic
|
||||
td, critic_loss = self._mse_optimizer(
|
||||
batch, self.critic, self.critic_optim)
|
||||
batch.weight = td # prio-buffer
|
||||
# actor
|
||||
action = self(batch).act
|
||||
actor_loss = -self.critic(batch.obs, action).mean()
|
||||
self.actor_optim.zero_grad()
|
||||
|
@ -69,7 +69,7 @@ class DQNPolicy(BasePolicy):
|
||||
|
||||
def sync_weight(self) -> None:
|
||||
"""Synchronize the weight for the target network."""
|
||||
self.model_old.load_state_dict(self.model.state_dict())
|
||||
self.model_old.load_state_dict(self.model.state_dict()) # type: ignore
|
||||
|
||||
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
|
||||
batch = buffer[indice] # batch.obs_next: s_{t+n}
|
||||
|
@ -115,7 +115,11 @@ class SACPolicy(DDPGPolicy):
|
||||
x = dist.rsample()
|
||||
y = torch.tanh(x)
|
||||
act = y * self._action_scale + self._action_bias
|
||||
# __eps is used to avoid log of zero/negative number.
|
||||
y = self._action_scale * (1 - y.pow(2)) + self.__eps
|
||||
# Compute logprob from Gaussian, and then apply correction for Tanh squashing.
|
||||
# You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
|
||||
# in appendix C to get some understanding of this equation.
|
||||
log_prob = dist.log_prob(x).unsqueeze(-1)
|
||||
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
|
||||
|
||||
@ -134,26 +138,11 @@ class SACPolicy(DDPGPolicy):
|
||||
return target_q
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
td1 = current_q1 - target_q
|
||||
critic1_loss = (td1.pow(2) * weight).mean()
|
||||
# critic1_loss = F.mse_loss(current_q1, target_q)
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward()
|
||||
self.critic1_optim.step()
|
||||
|
||||
# critic 2
|
||||
current_q2 = self.critic2(batch.obs, batch.act).flatten()
|
||||
td2 = current_q2 - target_q
|
||||
critic2_loss = (td2.pow(2) * weight).mean()
|
||||
# critic2_loss = F.mse_loss(current_q2, target_q)
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
# critic 1&2
|
||||
td1, critic1_loss = self._mse_optimizer(
|
||||
batch, self.critic1, self.critic1_optim)
|
||||
td2, critic2_loss = self._mse_optimizer(
|
||||
batch, self.critic2, self.critic2_optim)
|
||||
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||
|
||||
# actor
|
||||
|
@ -105,25 +105,14 @@ class TD3Policy(DDPGPolicy):
|
||||
return target_q
|
||||
|
||||
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
|
||||
weight = batch.pop("weight", 1.0)
|
||||
# critic 1
|
||||
current_q1 = self.critic1(batch.obs, batch.act).flatten()
|
||||
target_q = batch.returns.flatten()
|
||||
td1 = current_q1 - target_q
|
||||
critic1_loss = (td1.pow(2) * weight).mean()
|
||||
# critic1_loss = F.mse_loss(current_q1, target_q)
|
||||
self.critic1_optim.zero_grad()
|
||||
critic1_loss.backward()
|
||||
self.critic1_optim.step()
|
||||
# critic 2
|
||||
current_q2 = self.critic2(batch.obs, batch.act).flatten()
|
||||
td2 = current_q2 - target_q
|
||||
critic2_loss = (td2.pow(2) * weight).mean()
|
||||
# critic2_loss = F.mse_loss(current_q2, target_q)
|
||||
self.critic2_optim.zero_grad()
|
||||
critic2_loss.backward()
|
||||
self.critic2_optim.step()
|
||||
# critic 1&2
|
||||
td1, critic1_loss = self._mse_optimizer(
|
||||
batch, self.critic1, self.critic1_optim)
|
||||
td2, critic2_loss = self._mse_optimizer(
|
||||
batch, self.critic2, self.critic2_optim)
|
||||
batch.weight = (td1 + td2) / 2.0 # prio-buffer
|
||||
|
||||
# actor
|
||||
if self._cnt % self._freq == 0:
|
||||
actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
|
||||
self.actor_optim.zero_grad()
|
||||
|
@ -93,8 +93,9 @@ def offline_trainer(
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
|
||||
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
|
||||
|
@ -145,8 +145,9 @@ def offpolicy_trainer(
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
|
||||
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(start_time, train_collector, test_collector,
|
||||
|
@ -155,8 +155,9 @@ def onpolicy_trainer(
|
||||
if save_fn:
|
||||
save_fn(policy)
|
||||
if verbose:
|
||||
print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
|
||||
f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
print(
|
||||
f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
|
||||
f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
|
||||
if stop_fn and stop_fn(best_reward):
|
||||
break
|
||||
return gather_info(start_time, train_collector, test_collector,
|
||||
|