diff --git a/examples/mujoco/README.md b/examples/mujoco/README.md
index 3f116fd..0f1e7f9 100644
--- a/examples/mujoco/README.md
+++ b/examples/mujoco/README.md
@@ -1,27 +1,135 @@
-# Mujoco Result
+# Tianshou's Mujoco Benchmark
+We benchmarked Tianshou algorithm implementations in 9 out of 13 environments from the MuJoCo Gym task suite[[1]](#footnote1).
+For each supported algorithm and supported mujoco environments, we provide:
+- Default hyperparameters used for benchmark and scripts to reproduce the benchmark;
+- A comparison of performance (or code level details) with other open source implementations or classic papers;
+- Graphs and raw data that can be used for research purposes[[2]](#footnote2);
+- Log details obtained during training[[2]](#footnote2);
+- Pretrained agents[[2]](#footnote2);
+- Some hints on how to tune the algorithm.
+
-## SAC (single run)
+Supported algorithms are listed below:
+- [Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
+- [Twin Delayed DDPG (TD3)](https://arxiv.org/pdf/1802.09477.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
+- [Soft Actor-Critic (SAC)](https://arxiv.org/pdf/1812.05905.pdf), [commit id](https://github.com/thu-ml/tianshou/tree/v0.4.0)
-The best reward computes from 100 episodes returns in the test phase.
+## Offpolicy algorithms
-SAC on Swimmer-v3 always stops at 47\~48.
+#### Usage
-| task | 3M best reward | parameters | time cost (3M) |
-| -------------- | ----------------- | ------------------------------------------------------- | -------------- |
-| HalfCheetah-v3 | 10157.70 ± 171.70 | `python3 mujoco_sac.py --task HalfCheetah-v3` | 2~3h |
-| Walker2d-v3 | 5143.04 ± 15.57 | `python3 mujoco_sac.py --task Walker2d-v3` | 2~3h |
-| Hopper-v3 | 3604.19 ± 169.55 | `python3 mujoco_sac.py --task Hopper-v3` | 2~3h |
-| Humanoid-v3 | 6579.20 ± 1470.57 | `python3 mujoco_sac.py --task Humanoid-v3 --alpha 0.05` | 2~3h |
-| Ant-v3 | 6281.65 ± 686.28 | `python3 mujoco_sac.py --task Ant-v3` | 2~3h |
+Run
-
+```bash
+$ python mujoco_sac.py --task Ant-v3
+```
-### Which parts are important?
+Logs is saved in `./log/` and can be monitored with tensorboard.
+
+```bash
+$ tensorboard --logdir log
+```
+
+You can also reproduce the benchmark (e.g. SAC in Ant-v3) with the example script we provide under `examples/mujoco/`:
+
+```bash
+$ ./run_experiments.sh Ant-v3
+```
+
+This will start 10 experiments with different seeds.
+
+#### Example benchmark
+
+
+
+Other graphs can be found under `/examples/mujuco/benchmark/`
+
+#### Hints
+
+In offpolicy algorithms(DDPG, TD3, SAC), the shared hyperparameters are almost the same[[8]](#footnote8), and most hyperparameters are consistent with those used for benchmark in SpinningUp's implementations[[9]](#footnote9).
+
+By comparison to both classic literature and open source implementations (e.g., SpinningUp)[[1]](#footnote1)[[2]](#footnote2), Tianshou's implementations of DDPG, TD3, and SAC are roughly at-parity with or better than the best reported results for these algorithms.
+
+### DDPG
+
+| Environment | Tianshou | [SpinningUp (PyTorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper (DDPG)](https://arxiv.org/abs/1802.09477) | [TD3 paper (OurDDPG)](https://arxiv.org/abs/1802.09477) |
+| :--------------------: | :---------------: | :----------------------------------------------------------: | :--------------------------------------------------: | :-----------------------------------------------------: |
+| Ant | 990.4±4.3 | ~840 | **1005.3** | 888.8 |
+| HalfCheetah | **11718.7±465.6** | ~11000 | 3305.6 | 8577.3 |
+| Hopper | **2197.0±971.6** | ~1800 | **2020.5** | 1860.0 |
+| Walker2d | 1400.6±905.0 | ~1950 | 1843.6 | **3098.1** |
+| Swimmer | **144.1±6.5** | ~137 | N | N |
+| Humanoid | **177.3±77.6** | N | N | N |
+| Reacher | **-3.3±0.3** | N | -6.51 | -4.01 |
+| InvertedPendulum | **1000.0±0.0** | N | **1000.0** | **1000.0** |
+| InvertedDoublePendulum | 8364.3±2778.9 | N | **9355.5** | 8370.0 |
+
+\* details[[5]](#footnote5)[[6]](#footnote6)[[7]](#footnote7)
+
+### TD3
+
+| Environment | Tianshou | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [TD3 paper](https://arxiv.org/abs/1802.09477) |
+| :--------------------: | :---------------: | :-------------------: | :--------------: |
+| Ant | **5116.4±799.9** | ~3800 | 4372.4±1000.3 |
+| HalfCheetah | **10201.2±772.8** | ~9750 | 9637.0±859.1 |
+| Hopper | 3472.2±116.8 | ~2860 | **3564.1±114.7** |
+| Walker2d | 3982.4±274.5 | ~4000 | **4682.8±539.6** |
+| Swimmer | **104.2±34.2** | ~78 | N |
+| Humanoid | **5189.5±178.5** | N | N |
+| Reacher | **-2.7±0.2** | N | -3.6±0.6 |
+| InvertedPendulum | **1000.0±0.0** | N | **1000.0±0.0** |
+| InvertedDoublePendulum | **9349.2±14.3** | N | **9337.5±15.0** |
+
+\* details[[5]](#footnote5)[[6]](#footnote6)[[7]](#footnote7)
+
+### SAC
+
+| Environment | Tianshou | [SpinningUp (Pytorch)](https://spinningup.openai.com/en/latest/spinningup/bench.html) | [SAC paper](https://arxiv.org/abs/1801.01290) |
+| :--------------------: | :----------------: | :-------------------: | :---------: |
+| Ant | **5850.2±475.7** | ~3980 | ~3720 |
+| HalfCheetah | **12138.8±1049.3** | ~11520 | ~10400 |
+| Hopper | **3542.2±51.5** | ~3150 | ~3370 |
+| Walker2d | **5007.0±251.5** | ~4250 | ~3740 |
+| Swimmer | **44.4±0.5** | ~41.7 | N |
+| Humanoid | **5488.5±81.2** | N | ~5200 |
+| Reacher | **-2.6±0.2** | N | N |
+| InvertedPendulum | **1000.0±0.0** | N | N |
+| InvertedDoublePendulum | **9359.5±0.4** | N | N |
+
+\* details[[5]](#footnote5)[[6]](#footnote6)
+
+#### Hints for SAC
0. DO NOT share the same network with two critic networks.
-1. The sigma (of the Gaussian policy) MUST be conditioned on input.
+1. The sigma (of the Gaussian policy) should be conditioned on input.
2. The network size should not be less than 256.
3. The deterministic evaluation helps a lot :)
+## Onpolicy Algorithms
+
+TBD
+
+
+
+
+## Note
+
+[1] Supported environments include HalfCheetah-v3, Hopper-v3, Swimmer-v3, Walker2d-v3, Ant-v3, Humanoid-v3, Reacher-v2, InvertedPendulum-v2 and InvertedDoublePendulum-v2. Pusher, Thrower, Striker and HumanoidStandup are not supported because they are not commonly seen in literatures.
+
+[2] Pretrained agents, detailed graphs (single agent, single game) and log details can all be found [here](https://cloud.tsinghua.edu.cn/d/356e0f5d1e66426b9828/).
+
+[3] We used the latest version of all mujoco environments in gym (0.17.3 with mujoco==2.0.2.13), but it's not often the case with other benchmarks. Please check for details yourself in the original paper. (Different version's outcomes are usually similar, though)
+
+[4] We didn't compare offpolicy algorithms to OpenAI baselines [benchmark](https://github.com/openai/baselines/blob/master/benchmarks_mujoco1M.htm), because for now it seems that they haven't provided benchmark for offpolicy algorithms, but in [SpinningUp docs](https://spinningup.openai.com/en/latest/spinningup/bench.html) they stated that "SpinningUp implementations of DDPG, TD3, and SAC are roughly at-parity with the best-reported results for these algorithms", so we think lack of comparisons with OpenAI baselines is okay.
+
+[5] ~ means the number is approximated from the graph because accurate numbers is not provided in the paper. N means graphs not provided.
+
+[6] Reward metric: The meaning of the table value is the max average return over 10 trails (different seeds) ± a single standard deviation over trails. Each trial is averaged on another 10 test seeds. Only the first 1M steps data will be considered. The shaded region on the graph also represents a single standard deviation. It is the same as [TD3 evaluation method](https://github.com/sfujim/TD3/issues/34).
+
+[7] In TD3 paper, shaded region represents only half of standard deviation.
+
+[8] SAC's start-timesteps is set to 10000 by default while it is 25000 is DDPG/TD3. TD3's learning rate is set to 3e-4 while it is 1e-3 for DDPG/SAC. However, there is NO enough evidence to support our choice of such hyperparameters (we simply choose them because of SpinningUp) and you can try playing with those hyperparameters to see if you can improve performance. Do tell us if you can!
+
+[9] We use batchsize of 256 in DDPG/TD3/SAC while SpinningUp use 100. Minor difference also lies with `start-timesteps`, data loop method `step_per_collect`, method to deal with/bootstrap truncated steps because of timelimit and unfinished/collecting episodes (contribute to performance improvement), etc.
diff --git a/examples/mujoco/benchmark/Ant-v3/ddpg/figure.png b/examples/mujoco/benchmark/Ant-v3/ddpg/figure.png
new file mode 100644
index 0000000..a732378
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Ant-v3/figure.png b/examples/mujoco/benchmark/Ant-v3/figure.png
new file mode 100644
index 0000000..afb48c2
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Ant-v3/sac/figure.png b/examples/mujoco/benchmark/Ant-v3/sac/figure.png
new file mode 100644
index 0000000..5d4d452
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Ant-v3/td3/figure.png b/examples/mujoco/benchmark/Ant-v3/td3/figure.png
new file mode 100644
index 0000000..3bf9043
Binary files /dev/null and b/examples/mujoco/benchmark/Ant-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/ddpg/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/ddpg/figure.png
new file mode 100644
index 0000000..8df4fd8
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/figure.png
new file mode 100644
index 0000000..6459af8
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/sac/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/sac/figure.png
new file mode 100644
index 0000000..0e692f7
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/HalfCheetah-v3/td3/figure.png b/examples/mujoco/benchmark/HalfCheetah-v3/td3/figure.png
new file mode 100644
index 0000000..76cb4bf
Binary files /dev/null and b/examples/mujoco/benchmark/HalfCheetah-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/ddpg/figure.png b/examples/mujoco/benchmark/Hopper-v3/ddpg/figure.png
new file mode 100644
index 0000000..2ed9d3a
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/figure.png b/examples/mujoco/benchmark/Hopper-v3/figure.png
new file mode 100644
index 0000000..0f41d5c
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/sac/figure.png b/examples/mujoco/benchmark/Hopper-v3/sac/figure.png
new file mode 100644
index 0000000..da077e6
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Hopper-v3/td3/figure.png b/examples/mujoco/benchmark/Hopper-v3/td3/figure.png
new file mode 100644
index 0000000..62ccc22
Binary files /dev/null and b/examples/mujoco/benchmark/Hopper-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/ddpg/figure.png b/examples/mujoco/benchmark/Humanoid-v3/ddpg/figure.png
new file mode 100644
index 0000000..7a84f66
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/figure.png b/examples/mujoco/benchmark/Humanoid-v3/figure.png
new file mode 100644
index 0000000..3d788b8
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/sac/figure.png b/examples/mujoco/benchmark/Humanoid-v3/sac/figure.png
new file mode 100644
index 0000000..b585f59
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Humanoid-v3/td3/figure.png b/examples/mujoco/benchmark/Humanoid-v3/td3/figure.png
new file mode 100644
index 0000000..d919ceb
Binary files /dev/null and b/examples/mujoco/benchmark/Humanoid-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/ddpg/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/ddpg/figure.png
new file mode 100644
index 0000000..128d86e
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png
new file mode 100644
index 0000000..ded29d5
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/sac/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/sac/figure.png
new file mode 100644
index 0000000..fc23e5c
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedDoublePendulum-v2/td3/figure.png b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/td3/figure.png
new file mode 100644
index 0000000..71f0e92
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedDoublePendulum-v2/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/ddpg/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/ddpg/figure.png
new file mode 100644
index 0000000..f0e33a7
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/figure.png
new file mode 100644
index 0000000..f5f1e71
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/sac/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/sac/figure.png
new file mode 100644
index 0000000..11d8ada
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/InvertedPendulum-v2/td3/figure.png b/examples/mujoco/benchmark/InvertedPendulum-v2/td3/figure.png
new file mode 100644
index 0000000..bd8e525
Binary files /dev/null and b/examples/mujoco/benchmark/InvertedPendulum-v2/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/ddpg/figure.png b/examples/mujoco/benchmark/Reacher-v2/ddpg/figure.png
new file mode 100644
index 0000000..baf2b6f
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/figure.png b/examples/mujoco/benchmark/Reacher-v2/figure.png
new file mode 100644
index 0000000..8943139
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/sac/figure.png b/examples/mujoco/benchmark/Reacher-v2/sac/figure.png
new file mode 100644
index 0000000..be2debc
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Reacher-v2/td3/figure.png b/examples/mujoco/benchmark/Reacher-v2/td3/figure.png
new file mode 100644
index 0000000..f883ffe
Binary files /dev/null and b/examples/mujoco/benchmark/Reacher-v2/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/ddpg/figure.png b/examples/mujoco/benchmark/Swimmer-v3/ddpg/figure.png
new file mode 100644
index 0000000..6982db5
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/figure.png b/examples/mujoco/benchmark/Swimmer-v3/figure.png
new file mode 100644
index 0000000..c7345c1
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/sac/figure.png b/examples/mujoco/benchmark/Swimmer-v3/sac/figure.png
new file mode 100644
index 0000000..7a2ac16
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Swimmer-v3/td3/figure.png b/examples/mujoco/benchmark/Swimmer-v3/td3/figure.png
new file mode 100644
index 0000000..f9f8219
Binary files /dev/null and b/examples/mujoco/benchmark/Swimmer-v3/td3/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/ddpg/figure.png b/examples/mujoco/benchmark/Walker2d-v3/ddpg/figure.png
new file mode 100644
index 0000000..bbe52a3
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/ddpg/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/figure.png b/examples/mujoco/benchmark/Walker2d-v3/figure.png
new file mode 100644
index 0000000..5201bad
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/sac/figure.png b/examples/mujoco/benchmark/Walker2d-v3/sac/figure.png
new file mode 100644
index 0000000..44581d1
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/sac/figure.png differ
diff --git a/examples/mujoco/benchmark/Walker2d-v3/td3/figure.png b/examples/mujoco/benchmark/Walker2d-v3/td3/figure.png
new file mode 100644
index 0000000..389a9f2
Binary files /dev/null and b/examples/mujoco/benchmark/Walker2d-v3/td3/figure.png differ
diff --git a/examples/mujoco/mujoco_ddpg.py b/examples/mujoco/mujoco_ddpg.py
new file mode 100755
index 0000000..491d423
--- /dev/null
+++ b/examples/mujoco/mujoco_ddpg.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+
+import os
+import gym
+import torch
+import datetime
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import DDPGPolicy
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.exploration import GaussianNoise
+from tianshou.trainer import offpolicy_trainer
+from tianshou.utils.net.continuous import Actor, Critic
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', type=str, default='Ant-v3')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--buffer-size', type=int, default=1000000)
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
+ parser.add_argument('--actor-lr', type=float, default=1e-3)
+ parser.add_argument('--critic-lr', type=float, default=1e-3)
+ parser.add_argument('--gamma', type=float, default=0.99)
+ parser.add_argument('--tau', type=float, default=0.005)
+ parser.add_argument('--exploration-noise', type=float, default=0.1)
+ parser.add_argument("--start-timesteps", type=int, default=25000)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--step-per-epoch', type=int, default=5000)
+ parser.add_argument('--step-per-collect', type=int, default=1)
+ parser.add_argument('--update-per-step', type=int, default=1)
+ parser.add_argument('--n-step', type=int, default=1)
+ parser.add_argument('--batch-size', type=int, default=256)
+ parser.add_argument('--training-num', type=int, default=1)
+ parser.add_argument('--test-num', type=int, default=10)
+ parser.add_argument('--logdir', type=str, default='log')
+ parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument(
+ '--device', type=str,
+ default='cuda' if torch.cuda.is_available() else 'cpu')
+ parser.add_argument('--resume-path', type=str, default=None)
+ return parser.parse_args()
+
+
+def test_ddpg(args=get_args()):
+ env = gym.make(args.task)
+ args.state_shape = env.observation_space.shape or env.observation_space.n
+ args.action_shape = env.action_space.shape or env.action_space.n
+ args.max_action = env.action_space.high[0]
+ args.exploration_noise = args.exploration_noise * args.max_action
+ print("Observations shape:", args.state_shape)
+ print("Actions shape:", args.action_shape)
+ print("Action range:", np.min(env.action_space.low),
+ np.max(env.action_space.high))
+ # train_envs = gym.make(args.task)
+ if args.training_num > 1:
+ train_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ else:
+ train_envs = gym.make(args.task)
+ # test_envs = gym.make(args.task)
+ test_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ # seed
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ train_envs.seed(args.seed)
+ test_envs.seed(args.seed)
+ # model
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = Actor(
+ net_a, args.action_shape, max_action=args.max_action,
+ device=args.device).to(args.device)
+ actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+ net_c = Net(args.state_shape, args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True, device=args.device)
+ critic = Critic(net_c, device=args.device).to(args.device)
+ critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
+ policy = DDPGPolicy(
+ actor, actor_optim, critic, critic_optim,
+ action_range=[env.action_space.low[0], env.action_space.high[0]],
+ tau=args.tau, gamma=args.gamma,
+ exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+ estimation_step=args.n_step)
+ # load a previous policy
+ if args.resume_path:
+ policy.load_state_dict(torch.load(
+ args.resume_path, map_location=args.device
+ ))
+ print("Loaded agent from: ", args.resume_path)
+
+ # collector
+ if args.training_num > 1:
+ buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+ else:
+ buffer = ReplayBuffer(args.buffer_size)
+ train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+ test_collector = Collector(policy, test_envs)
+ train_collector.collect(n_step=args.start_timesteps, random=True)
+ # log
+ log_path = os.path.join(args.logdir, args.task, 'ddpg', 'seed_' + str(
+ args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
+ writer = SummaryWriter(log_path)
+ writer.add_text("args", str(args))
+ logger = BasicLogger(writer)
+
+ def save_fn(policy):
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+ # trainer
+ result = offpolicy_trainer(
+ policy, train_collector, test_collector, args.epoch,
+ args.step_per_epoch, args.step_per_collect, args.test_num,
+ args.batch_size, save_fn=save_fn, logger=logger,
+ update_per_step=args.update_per_step, test_in_train=False)
+
+ # Let's watch its performance!
+ policy.eval()
+ test_envs.seed(args.seed)
+ test_collector.reset()
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
+ print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+if __name__ == '__main__':
+ test_ddpg()
diff --git a/examples/mujoco/mujoco_sac.py b/examples/mujoco/mujoco_sac.py
old mode 100644
new mode 100755
index c078773..73ef4eb
--- a/examples/mujoco/mujoco_sac.py
+++ b/examples/mujoco/mujoco_sac.py
@@ -1,7 +1,8 @@
+#!/usr/bin/env python3
+
import os
import gym
import torch
-import pprint
import datetime
import argparse
import numpy as np
@@ -12,42 +13,38 @@ from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offpolicy_trainer
-from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.continuous import ActorProb, Critic
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='Ant-v3')
- parser.add_argument('--seed', type=int, default=1626)
+ parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=1000000)
- parser.add_argument('--actor-lr', type=float, default=3e-4)
- parser.add_argument('--critic-lr', type=float, default=3e-4)
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
+ parser.add_argument('--actor-lr', type=float, default=1e-3)
+ parser.add_argument('--critic-lr', type=float, default=1e-3)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--tau', type=float, default=0.005)
parser.add_argument('--alpha', type=float, default=0.2)
parser.add_argument('--auto-alpha', default=False, action='store_true')
parser.add_argument('--alpha-lr', type=float, default=3e-4)
- parser.add_argument('--n-step', type=int, default=2)
- parser.add_argument('--epoch', type=int, default=100)
- parser.add_argument('--step-per-epoch', type=int, default=40000)
- parser.add_argument('--step-per-collect', type=int, default=4)
- parser.add_argument('--update-per-step', type=float, default=0.25)
- parser.add_argument('--pre-collect-step', type=int, default=10000)
+ parser.add_argument("--start-timesteps", type=int, default=10000)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--step-per-epoch', type=int, default=5000)
+ parser.add_argument('--step-per-collect', type=int, default=1)
+ parser.add_argument('--update-per-step', type=int, default=1)
+ parser.add_argument('--n-step', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=256)
- parser.add_argument('--hidden-sizes', type=int,
- nargs='*', default=[128, 128])
- parser.add_argument('--training-num', type=int, default=4)
- parser.add_argument('--test-num', type=int, default=100)
+ parser.add_argument('--training-num', type=int, default=1)
+ parser.add_argument('--test-num', type=int, default=10)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
- parser.add_argument('--log-interval', type=int, default=1000)
parser.add_argument(
'--device', type=str,
default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--resume-path', type=str, default=None)
- parser.add_argument('--watch', default=False, action='store_true',
- help='watch the play of pre-trained policy only')
return parser.parse_args()
@@ -61,8 +58,11 @@ def test_sac(args=get_args()):
print("Action range:", np.min(env.action_space.low),
np.max(env.action_space.high))
# train_envs = gym.make(args.task)
- train_envs = SubprocVectorEnv(
- [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ if args.training_num > 1:
+ train_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ else:
+ train_envs = gym.make(args.task)
# test_envs = gym.make(args.task)
test_envs = SubprocVectorEnv(
[lambda: gym.make(args.task) for _ in range(args.test_num)])
@@ -72,21 +72,20 @@ def test_sac(args=get_args()):
train_envs.seed(args.seed)
test_envs.seed(args.seed)
# model
- net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
- device=args.device)
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
actor = ActorProb(
- net, args.action_shape, max_action=args.max_action,
+ net_a, args.action_shape, max_action=args.max_action,
device=args.device, unbounded=True, conditioned_sigma=True
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c1 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
- critic1 = Critic(net_c1, device=args.device).to(args.device)
- critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
net_c2 = Net(args.state_shape, args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True, device=args.device)
+ critic1 = Critic(net_c1, device=args.device).to(args.device)
+ critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
critic2 = Critic(net_c2, device=args.device).to(args.device)
critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
@@ -109,46 +108,35 @@ def test_sac(args=get_args()):
print("Loaded agent from: ", args.resume_path)
# collector
- train_collector = Collector(
- policy, train_envs,
- VectorReplayBuffer(args.buffer_size, len(train_envs)),
- exploration_noise=True)
+ if args.training_num > 1:
+ buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+ else:
+ buffer = ReplayBuffer(args.buffer_size)
+ train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
+ train_collector.collect(n_step=args.start_timesteps, random=True)
# log
log_path = os.path.join(args.logdir, args.task, 'sac', 'seed_' + str(
args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
- logger = BasicLogger(writer, train_interval=args.log_interval)
-
- def watch():
- # watch agent's performance
- print("Testing agent ...")
- policy.eval()
- test_envs.seed(args.seed)
- test_collector.reset()
- result = test_collector.collect(n_episode=args.test_num, render=args.render)
- pprint.pprint(result)
+ logger = BasicLogger(writer)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
-
- def stop_fn(mean_rewards):
- return False
-
- if args.watch:
- watch()
- exit(0)
-
# trainer
- train_collector.collect(n_step=args.pre_collect_step, random=True)
result = offpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.step_per_collect, args.test_num,
- args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger,
- update_per_step=args.update_per_step)
- pprint.pprint(result)
- watch()
+ args.batch_size, save_fn=save_fn, logger=logger,
+ update_per_step=args.update_per_step, test_in_train=False)
+
+ # Let's watch its performance!
+ policy.eval()
+ test_envs.seed(args.seed)
+ test_collector.reset()
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
+ print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
if __name__ == '__main__':
diff --git a/examples/mujoco/mujoco_td3.py b/examples/mujoco/mujoco_td3.py
new file mode 100755
index 0000000..9066cbe
--- /dev/null
+++ b/examples/mujoco/mujoco_td3.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+
+import os
+import gym
+import torch
+import datetime
+import argparse
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+
+from tianshou.policy import TD3Policy
+from tianshou.utils import BasicLogger
+from tianshou.env import SubprocVectorEnv
+from tianshou.utils.net.common import Net
+from tianshou.exploration import GaussianNoise
+from tianshou.trainer import offpolicy_trainer
+from tianshou.utils.net.continuous import Actor, Critic
+from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--task', type=str, default='Ant-v3')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--buffer-size', type=int, default=1000000)
+ parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[256, 256])
+ parser.add_argument('--actor-lr', type=float, default=3e-4)
+ parser.add_argument('--critic-lr', type=float, default=3e-4)
+ parser.add_argument('--gamma', type=float, default=0.99)
+ parser.add_argument('--tau', type=float, default=0.005)
+ parser.add_argument('--exploration-noise', type=float, default=0.1)
+ parser.add_argument('--policy-noise', type=float, default=0.2)
+ parser.add_argument('--noise-clip', type=float, default=0.5)
+ parser.add_argument('--update-actor-freq', type=int, default=2)
+ parser.add_argument("--start-timesteps", type=int, default=25000)
+ parser.add_argument('--epoch', type=int, default=200)
+ parser.add_argument('--step-per-epoch', type=int, default=5000)
+ parser.add_argument('--step-per-collect', type=int, default=1)
+ parser.add_argument('--update-per-step', type=int, default=1)
+ parser.add_argument('--n-step', type=int, default=1)
+ parser.add_argument('--batch-size', type=int, default=256)
+ parser.add_argument('--training-num', type=int, default=1)
+ parser.add_argument('--test-num', type=int, default=10)
+ parser.add_argument('--logdir', type=str, default='log')
+ parser.add_argument('--render', type=float, default=0.)
+ parser.add_argument(
+ '--device', type=str,
+ default='cuda' if torch.cuda.is_available() else 'cpu')
+ parser.add_argument('--resume-path', type=str, default=None)
+ return parser.parse_args()
+
+
+def test_td3(args=get_args()):
+ env = gym.make(args.task)
+ args.state_shape = env.observation_space.shape or env.observation_space.n
+ args.action_shape = env.action_space.shape or env.action_space.n
+ args.max_action = env.action_space.high[0]
+ args.exploration_noise = args.exploration_noise * args.max_action
+ args.policy_noise = args.policy_noise * args.max_action
+ args.noise_clip = args.noise_clip * args.max_action
+ print("Observations shape:", args.state_shape)
+ print("Actions shape:", args.action_shape)
+ print("Action range:", np.min(env.action_space.low),
+ np.max(env.action_space.high))
+ # train_envs = gym.make(args.task)
+ if args.training_num > 1:
+ train_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.training_num)])
+ else:
+ train_envs = gym.make(args.task)
+ # test_envs = gym.make(args.task)
+ test_envs = SubprocVectorEnv(
+ [lambda: gym.make(args.task) for _ in range(args.test_num)])
+ # seed
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ train_envs.seed(args.seed)
+ test_envs.seed(args.seed)
+ # model
+ net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
+ actor = Actor(
+ net_a, args.action_shape, max_action=args.max_action,
+ device=args.device).to(args.device)
+ actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
+ net_c1 = Net(args.state_shape, args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True, device=args.device)
+ net_c2 = Net(args.state_shape, args.action_shape,
+ hidden_sizes=args.hidden_sizes,
+ concat=True, device=args.device)
+ critic1 = Critic(net_c1, device=args.device).to(args.device)
+ critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
+ critic2 = Critic(net_c2, device=args.device).to(args.device)
+ critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+ policy = TD3Policy(
+ actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
+ action_range=[env.action_space.low[0], env.action_space.high[0]],
+ tau=args.tau, gamma=args.gamma,
+ exploration_noise=GaussianNoise(sigma=args.exploration_noise),
+ policy_noise=args.policy_noise, update_actor_freq=args.update_actor_freq,
+ noise_clip=args.noise_clip, estimation_step=args.n_step)
+
+ # load a previous policy
+ if args.resume_path:
+ policy.load_state_dict(torch.load(
+ args.resume_path, map_location=args.device
+ ))
+ print("Loaded agent from: ", args.resume_path)
+
+ # collector
+ if args.training_num > 1:
+ buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
+ else:
+ buffer = ReplayBuffer(args.buffer_size)
+ train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
+ test_collector = Collector(policy, test_envs)
+ train_collector.collect(n_step=args.start_timesteps, random=True)
+ # log
+ log_path = os.path.join(args.logdir, args.task, 'td3', 'seed_' + str(
+ args.seed) + '_' + datetime.datetime.now().strftime('%m%d-%H%M%S'))
+ writer = SummaryWriter(log_path)
+ writer.add_text("args", str(args))
+ logger = BasicLogger(writer)
+
+ def save_fn(policy):
+ torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
+ # trainer
+ result = offpolicy_trainer(
+ policy, train_collector, test_collector, args.epoch,
+ args.step_per_epoch, args.step_per_collect, args.test_num,
+ args.batch_size, save_fn=save_fn, logger=logger,
+ update_per_step=args.update_per_step, test_in_train=False)
+
+ # Let's watch its performance!
+ policy.eval()
+ test_envs.seed(args.seed)
+ test_collector.reset()
+ result = test_collector.collect(n_episode=args.test_num, render=args.render)
+ print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}')
+
+
+if __name__ == '__main__':
+ test_td3()
diff --git a/examples/mujoco/results/sac/all.png b/examples/mujoco/results/sac/all.png
deleted file mode 100644
index 7f314f4..0000000
Binary files a/examples/mujoco/results/sac/all.png and /dev/null differ
diff --git a/examples/mujoco/run_experiments.sh b/examples/mujoco/run_experiments.sh
new file mode 100755
index 0000000..4de3263
--- /dev/null
+++ b/examples/mujoco/run_experiments.sh
@@ -0,0 +1,10 @@
+#!/bin/bash
+
+LOGDIR="results"
+TASK=$1
+
+echo "Experiments started."
+for seed in $(seq 1 10)
+do
+ python mujoco_sac.py --task $TASK --epoch 200 --seed $seed --logdir $LOGDIR > ${TASK}_`date '+%m-%d-%H-%M-%S'`_seed_$seed.txt 2>&1
+done
diff --git a/tianshou/policy/modelfree/ddpg.py b/tianshou/policy/modelfree/ddpg.py
index 9a4dad0..87a06c5 100644
--- a/tianshou/policy/modelfree/ddpg.py
+++ b/tianshou/policy/modelfree/ddpg.py
@@ -129,18 +129,31 @@ class DDPGPolicy(BasePolicy):
obs = batch[input]
actions, h = model(obs, state=state, info=batch.info)
actions += self._action_bias
+ actions = actions.clamp(self._range[0], self._range[1])
return Batch(act=actions, state=h)
- def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
- weight = batch.pop("weight", 1.0)
- current_q = self.critic(batch.obs, batch.act).flatten()
+ @staticmethod
+ def _mse_optimizer(
+ batch: Batch, critic: torch.nn.Module, optimizer: torch.optim.Optimizer
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """A simple wrapper script for updating critic network."""
+ weight = getattr(batch, "weight", 1.0)
+ current_q = critic(batch.obs, batch.act).flatten()
target_q = batch.returns.flatten()
td = current_q - target_q
+ # critic_loss = F.mse_loss(current_q1, target_q)
critic_loss = (td.pow(2) * weight).mean()
- batch.weight = td # prio-buffer
- self.critic_optim.zero_grad()
+ optimizer.zero_grad()
critic_loss.backward()
- self.critic_optim.step()
+ optimizer.step()
+ return td, critic_loss
+
+ def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
+ # critic
+ td, critic_loss = self._mse_optimizer(
+ batch, self.critic, self.critic_optim)
+ batch.weight = td # prio-buffer
+ # actor
action = self(batch).act
actor_loss = -self.critic(batch.obs, action).mean()
self.actor_optim.zero_grad()
diff --git a/tianshou/policy/modelfree/dqn.py b/tianshou/policy/modelfree/dqn.py
index bd1fea1..1edeea2 100644
--- a/tianshou/policy/modelfree/dqn.py
+++ b/tianshou/policy/modelfree/dqn.py
@@ -69,7 +69,7 @@ class DQNPolicy(BasePolicy):
def sync_weight(self) -> None:
"""Synchronize the weight for the target network."""
- self.model_old.load_state_dict(self.model.state_dict())
+ self.model_old.load_state_dict(self.model.state_dict()) # type: ignore
def _target_q(self, buffer: ReplayBuffer, indice: np.ndarray) -> torch.Tensor:
batch = buffer[indice] # batch.obs_next: s_{t+n}
diff --git a/tianshou/policy/modelfree/sac.py b/tianshou/policy/modelfree/sac.py
index 68bef39..ac81cbf 100644
--- a/tianshou/policy/modelfree/sac.py
+++ b/tianshou/policy/modelfree/sac.py
@@ -115,7 +115,11 @@ class SACPolicy(DDPGPolicy):
x = dist.rsample()
y = torch.tanh(x)
act = y * self._action_scale + self._action_bias
+ # __eps is used to avoid log of zero/negative number.
y = self._action_scale * (1 - y.pow(2)) + self.__eps
+ # Compute logprob from Gaussian, and then apply correction for Tanh squashing.
+ # You can check out the original SAC paper (arXiv 1801.01290): Eq 21.
+ # in appendix C to get some understanding of this equation.
log_prob = dist.log_prob(x).unsqueeze(-1)
log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
@@ -134,26 +138,11 @@ class SACPolicy(DDPGPolicy):
return target_q
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
- weight = batch.pop("weight", 1.0)
-
- # critic 1
- current_q1 = self.critic1(batch.obs, batch.act).flatten()
- target_q = batch.returns.flatten()
- td1 = current_q1 - target_q
- critic1_loss = (td1.pow(2) * weight).mean()
- # critic1_loss = F.mse_loss(current_q1, target_q)
- self.critic1_optim.zero_grad()
- critic1_loss.backward()
- self.critic1_optim.step()
-
- # critic 2
- current_q2 = self.critic2(batch.obs, batch.act).flatten()
- td2 = current_q2 - target_q
- critic2_loss = (td2.pow(2) * weight).mean()
- # critic2_loss = F.mse_loss(current_q2, target_q)
- self.critic2_optim.zero_grad()
- critic2_loss.backward()
- self.critic2_optim.step()
+ # critic 1&2
+ td1, critic1_loss = self._mse_optimizer(
+ batch, self.critic1, self.critic1_optim)
+ td2, critic2_loss = self._mse_optimizer(
+ batch, self.critic2, self.critic2_optim)
batch.weight = (td1 + td2) / 2.0 # prio-buffer
# actor
diff --git a/tianshou/policy/modelfree/td3.py b/tianshou/policy/modelfree/td3.py
index bd65722..09f288f 100644
--- a/tianshou/policy/modelfree/td3.py
+++ b/tianshou/policy/modelfree/td3.py
@@ -105,25 +105,14 @@ class TD3Policy(DDPGPolicy):
return target_q
def learn(self, batch: Batch, **kwargs: Any) -> Dict[str, float]:
- weight = batch.pop("weight", 1.0)
- # critic 1
- current_q1 = self.critic1(batch.obs, batch.act).flatten()
- target_q = batch.returns.flatten()
- td1 = current_q1 - target_q
- critic1_loss = (td1.pow(2) * weight).mean()
- # critic1_loss = F.mse_loss(current_q1, target_q)
- self.critic1_optim.zero_grad()
- critic1_loss.backward()
- self.critic1_optim.step()
- # critic 2
- current_q2 = self.critic2(batch.obs, batch.act).flatten()
- td2 = current_q2 - target_q
- critic2_loss = (td2.pow(2) * weight).mean()
- # critic2_loss = F.mse_loss(current_q2, target_q)
- self.critic2_optim.zero_grad()
- critic2_loss.backward()
- self.critic2_optim.step()
+ # critic 1&2
+ td1, critic1_loss = self._mse_optimizer(
+ batch, self.critic1, self.critic1_optim)
+ td2, critic2_loss = self._mse_optimizer(
+ batch, self.critic2, self.critic2_optim)
batch.weight = (td1 + td2) / 2.0 # prio-buffer
+
+ # actor
if self._cnt % self._freq == 0:
actor_loss = -self.critic1(batch.obs, self(batch, eps=0.0).act).mean()
self.actor_optim.zero_grad()
diff --git a/tianshou/trainer/offline.py b/tianshou/trainer/offline.py
index 13f96fa..802e349 100644
--- a/tianshou/trainer/offline.py
+++ b/tianshou/trainer/offline.py
@@ -93,8 +93,9 @@ def offline_trainer(
if save_fn:
save_fn(policy)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
+ f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, None, test_collector, best_reward, best_reward_std)
diff --git a/tianshou/trainer/offpolicy.py b/tianshou/trainer/offpolicy.py
index 72a243d..444dbc0 100644
--- a/tianshou/trainer/offpolicy.py
+++ b/tianshou/trainer/offpolicy.py
@@ -145,8 +145,9 @@ def offpolicy_trainer(
if save_fn:
save_fn(policy)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
+ f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,
diff --git a/tianshou/trainer/onpolicy.py b/tianshou/trainer/onpolicy.py
index dae20a7..ab73127 100644
--- a/tianshou/trainer/onpolicy.py
+++ b/tianshou/trainer/onpolicy.py
@@ -155,8 +155,9 @@ def onpolicy_trainer(
if save_fn:
save_fn(policy)
if verbose:
- print(f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_rew"
- f"ard: {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
+ print(
+ f"Epoch #{epoch}: test_reward: {rew:.6f} ± {rew_std:.6f}, best_reward:"
+ f" {best_reward:.6f} ± {best_reward_std:.6f} in #{best_epoch}")
if stop_fn and stop_fn(best_reward):
break
return gather_info(start_time, train_collector, test_collector,