diff --git a/README.md b/README.md
index 777240e..41e1e41 100644
--- a/README.md
+++ b/README.md
@@ -103,27 +103,29 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
### Comprehensive Functionality
-| RL Platform | GitHub Stars | # of RL Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
-| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------- | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- |
-| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
-| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
-| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch |
-| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
-| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :x: | PyTorch |
-| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
-| [ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
-| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
-| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
-| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
-| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1(3) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
-| | | | | | | | |
-| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
+| RL Platform | GitHub Stars | # of Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
+| ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------ | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- |
+| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
+| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 |
+| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 (3) | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch |
+| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
+| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningup/stargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :x: | PyTorch |
+| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
+| [ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
+| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rl/stargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
+| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
+| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
+| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1 (4) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
+| | | | | | | | |
+| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
(1): access date: 2021-08-08
(2): not all algorithms support this feature
-(3): super fast APPO!
+(3): TQC and QR-DQN are implemented in [sb3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) rather than in the main repository
+
+(4): super fast APPO!
### High quality software engineering standard
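
Aside (not part of this patch): the "Nested Observation" column above refers to dict-style observations, which Tianshou handles through its `Batch` container. A minimal sketch follows; the key names `position` and `goal` are invented for illustration.

import numpy as np
from tianshou.data import Batch

# Nested dict observations become nested Batch objects with attribute access.
b = Batch(obs={'position': np.zeros((1, 3)), 'goal': np.ones((1, 3))},
          rew=np.array([0.0]))
assert b.obs.position.shape == (1, 3)
# Concatenation works through the nested structure as well.
b2 = Batch.cat([b, b])
assert b2.obs.goal.shape == (2, 3)
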
diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py
index 9fbd7d3..ee63b9b 100644
--- a/test/discrete/test_ppo.py
+++ b/test/discrete/test_ppo.py
@@ -20,12 +20,12 @@ def get_args():
parser.add_argument('--task', type=str, default='CartPole-v0')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=20000)
- parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--step-per-epoch', type=int, default=50000)
- parser.add_argument('--episode-per-collect', type=int, default=20)
- parser.add_argument('--repeat-per-collect', type=int, default=2)
+ parser.add_argument('--step-per-collect', type=int, default=2000)
+ parser.add_argument('--repeat-per-collect', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
parser.add_argument('--training-num', type=int, default=20)
@@ -41,15 +41,16 @@ def get_args():
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--gae-lambda', type=float, default=0.95)
- parser.add_argument('--rew-norm', type=int, default=1)
+ parser.add_argument('--rew-norm', type=int, default=0)
+    parser.add_argument('--norm-adv', type=int, default=0)  # normalize advantages per mini-batch
+    parser.add_argument('--recompute-adv', type=int, default=0)  # recompute GAE on every update repeat
parser.add_argument('--dual-clip', type=float, default=None)
- parser.add_argument('--value-clip', type=int, default=1)
+ parser.add_argument('--value-clip', type=int, default=0)
args = parser.parse_known_args()[0]
return args
def test_ppo(args=get_args()):
- torch.set_num_threads(1) # for poor CPU
env = gym.make(args.task)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
@@ -90,7 +91,9 @@ def test_ppo(args=get_args()):
dual_clip=args.dual_clip,
value_clip=args.value_clip,
action_space=env.action_space,
- deterministic_eval=True)
+ deterministic_eval=True,
+ advantage_normalization=args.norm_adv,
+ recompute_advantage=args.recompute_adv)
# collector
train_collector = Collector(
policy, train_envs,
@@ -111,7 +114,7 @@ def test_ppo(args=get_args()):
result = onpolicy_trainer(
policy, train_collector, test_collector, args.epoch,
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
- episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
+ step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
logger=logger)
assert stop_fn(result['best_reward'])
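
For reference, a minimal sketch of how the two options introduced by this patch (`advantage_normalization`, `recompute_advantage`) reach `PPOPolicy` in user code. It assumes the same `Net`/`Actor`/`Critic` helpers that test_ppo.py imports from `tianshou.utils.net`; everything besides the two flagged arguments is illustrative, not prescriptive.

import gym
import torch
from torch.distributions import Categorical
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import Net
from tianshou.utils.net.discrete import Actor, Critic

env = gym.make('CartPole-v0')
# Shared feature extractor, mirroring the [64, 64] hidden sizes above.
net = Net(env.observation_space.shape, hidden_sizes=[64, 64])
actor = Actor(net, env.action_space.n)
critic = Critic(net)
# Deduplicate shared parameters before handing them to the optimizer.
optim = torch.optim.Adam(
    set(actor.parameters()).union(critic.parameters()), lr=3e-4)
policy = PPOPolicy(
    actor, critic, optim, Categorical,
    discount_factor=0.99, eps_clip=0.2, gae_lambda=0.95,
    reward_normalization=False,     # --rew-norm 0
    value_clip=False,               # --value-clip 0
    advantage_normalization=False,  # --norm-adv 0 (new in this patch)
    recompute_advantage=False,      # --recompute-adv 0 (new in this patch)
    action_space=env.action_space,
    deterministic_eval=True,
)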