diff --git a/README.md b/README.md index 777240e..41e1e41 100644 --- a/README.md +++ b/README.md @@ -103,27 +103,29 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma ### Comprehensive Functionality -| RL Platform | GitHub Stars | # of RL Alg. (1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend | -| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------- | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- | -| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | -| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | -| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch | -| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch | -| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) 
| :heavy_minus_sign: (2) | :x: | :x: | PyTorch | -| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX | -| [ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX | -| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras | -| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | -| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer | -| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1(3) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | -| | | | | | | | | -| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | +| RL Platform | GitHub Stars | # of Alg. 
(1) | Custom Env | Batch Training | RNN Support | Nested Observation | Backend | +| ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------ | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- | +| [Baselines](https://github.com/openai/baselines) | [![GitHub stars](https://img.shields.io/github/stars/openai/baselines)](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | +| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [![GitHub stars](https://img.shields.io/github/stars/hill-a/stable-baselines)](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :heavy_check_mark: | :x: | TF1 | +| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [![GitHub stars](https://img.shields.io/github/stars/DLR-RM/stable-baselines3)](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 (3) | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :heavy_check_mark: | PyTorch | +| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [![GitHub stars](https://img.shields.io/github/stars/ray-project/ray)](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch | +| [SpinningUp](https://github.com/openai/spinningup) | [![GitHub stars](https://img.shields.io/github/stars/openai/spinningup)](https://github.com/openai/spinningup/stargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: (2) | :x: | :x: | PyTorch | +| [Dopamine](https://github.com/google/dopamine) | [![GitHub stars](https://img.shields.io/github/stars/google/dopamine)](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX | +| 
[ACME](https://github.com/deepmind/acme) | [![GitHub stars](https://img.shields.io/github/stars/deepmind/acme)](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX | +| [keras-rl](https://github.com/keras-rl/keras-rl) | [![GitHub stars](https://img.shields.io/github/stars/keras-rl/keras-rl)](https://github.com/keras-rl/keras-rl/stargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras | +| [rlpyt](https://github.com/astooke/rlpyt) | [![GitHub stars](https://img.shields.io/github/stars/astooke/rlpyt)](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | +| [ChainerRL](https://github.com/chainer/chainerrl) | [![GitHub stars](https://img.shields.io/github/stars/chainer/chainerrl)](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer | +| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [![GitHub stars](https://img.shields.io/github/stars/alex-petrenko/sample-factory)](https://github.com/alex-petrenko/sample-factory/stargazers) | 1 (4) | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | +| | | | | | | | | +| [Tianshou](https://github.com/thu-ml/tianshou) | [![GitHub stars](https://img.shields.io/github/stars/thu-ml/tianshou)](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch | (1): access date: 2021-08-08 (2): not all algorithms support this feature -(3): super fast APPO! +(3): TQC and QR-DQN in [sb3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) instead of main repo + +(4): super fast APPO! 
### High quality software engineering standard diff --git a/test/discrete/test_ppo.py b/test/discrete/test_ppo.py index 9fbd7d3..ee63b9b 100644 --- a/test/discrete/test_ppo.py +++ b/test/discrete/test_ppo.py @@ -20,12 +20,12 @@ def get_args(): parser.add_argument('--task', type=str, default='CartPole-v0') parser.add_argument('--seed', type=int, default=0) parser.add_argument('--buffer-size', type=int, default=20000) - parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr', type=float, default=3e-4) parser.add_argument('--gamma', type=float, default=0.99) parser.add_argument('--epoch', type=int, default=10) parser.add_argument('--step-per-epoch', type=int, default=50000) - parser.add_argument('--episode-per-collect', type=int, default=20) - parser.add_argument('--repeat-per-collect', type=int, default=2) + parser.add_argument('--step-per-collect', type=int, default=2000) + parser.add_argument('--repeat-per-collect', type=int, default=10) parser.add_argument('--batch-size', type=int, default=64) parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64]) parser.add_argument('--training-num', type=int, default=20) @@ -41,15 +41,16 @@ def get_args(): parser.add_argument('--eps-clip', type=float, default=0.2) parser.add_argument('--max-grad-norm', type=float, default=0.5) parser.add_argument('--gae-lambda', type=float, default=0.95) - parser.add_argument('--rew-norm', type=int, default=1) + parser.add_argument('--rew-norm', type=int, default=0) + parser.add_argument('--norm-adv', type=int, default=0) + parser.add_argument('--recompute-adv', type=int, default=0) parser.add_argument('--dual-clip', type=float, default=None) - parser.add_argument('--value-clip', type=int, default=1) + parser.add_argument('--value-clip', type=int, default=0) args = parser.parse_known_args()[0] return args def test_ppo(args=get_args()): - torch.set_num_threads(1) # for poor CPU env = gym.make(args.task) args.state_shape = 
env.observation_space.shape or env.observation_space.n args.action_shape = env.action_space.shape or env.action_space.n @@ -90,7 +91,9 @@ def test_ppo(args=get_args()): dual_clip=args.dual_clip, value_clip=args.value_clip, action_space=env.action_space, - deterministic_eval=True) + deterministic_eval=True, + advantage_normalization=args.norm_adv, + recompute_advantage=args.recompute_adv) # collector train_collector = Collector( policy, train_envs, @@ -111,7 +114,7 @@ def test_ppo(args=get_args()): result = onpolicy_trainer( policy, train_collector, test_collector, args.epoch, args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size, - episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn, + step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn, logger=logger) assert stop_fn(result['best_reward'])