make ppo discrete test script more general (#418)
This commit is contained in:
parent
bba30f83d1
commit
5b7732a29b
34
README.md
34
README.md
@ -103,27 +103,29 @@ The example scripts are under [test/](https://github.com/thu-ml/tianshou/blob/ma
|
||||
|
||||
### Comprehensive Functionality
|
||||
|
||||
| RL Platform | GitHub Stars | # of RL Alg. <sup>(1)</sup> | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
|
||||
| ------------------------------------------------------------ | ------------------------------------------------------------ | --------------------------- | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- |
|
||||
| [Baselines](https://github.com/openai/baselines) | [](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
|
||||
| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
|
||||
| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :heavy_check_mark: | PyTorch |
|
||||
| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
|
||||
| [SpinningUp](https://github.com/openai/spinningup) | [](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :x: | PyTorch |
|
||||
| [Dopamine](https://github.com/google/dopamine) | [](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
|
||||
| [ACME](https://github.com/deepmind/acme) | [](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
|
||||
| [keras-rl](https://github.com/keras-rl/keras-rl) | [](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
|
||||
| [rlpyt](https://github.com/astooke/rlpyt) | [](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
|
||||
| [ChainerRL](https://github.com/chainer/chainerrl) | [](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
|
||||
| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [](https://github.com/alex-petrenko/sample-factory/stargazers) | 1<sup>(3)</sup> | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
|
||||
| | | | | | | | |
|
||||
| [Tianshou](https://github.com/thu-ml/tianshou) | [](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
|
||||
| RL Platform | GitHub Stars | # of Alg. <sup>(1)</sup> | Custom Env | Batch Training | RNN Support | Nested Observation | Backend |
|
||||
| ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------ | --------------------------- | --------------------------------- | ------------------ | ------------------ | ---------- |
|
||||
| [Baselines](https://github.com/openai/baselines) | [](https://github.com/openai/baselines/stargazers) | 9 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
|
||||
| [Stable-Baselines](https://github.com/hill-a/stable-baselines) | [](https://github.com/hill-a/stable-baselines/stargazers) | 11 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :heavy_check_mark: | :x: | TF1 |
|
||||
| [Stable-Baselines3](https://github.com/DLR-RM/stable-baselines3) | [](https://github.com/DLR-RM/stable-baselines3/stargazers) | 7<sup> (3)</sup> | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :heavy_check_mark: | PyTorch |
|
||||
| [Ray/RLlib](https://github.com/ray-project/ray/tree/master/rllib/) | [](https://github.com/ray-project/ray/stargazers) | 16 | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/PyTorch |
|
||||
| [SpinningUp](https://github.com/openai/spinningup) | [](https://github.com/openai/spinningupstargazers) | 6 | :heavy_check_mark: (gym) | :heavy_minus_sign: <sup>(2)</sup> | :x: | :x: | PyTorch |
|
||||
| [Dopamine](https://github.com/google/dopamine) | [](https://github.com/google/dopamine/stargazers) | 7 | :x: | :x: | :x: | :x: | TF/JAX |
|
||||
| [ACME](https://github.com/deepmind/acme) | [](https://github.com/deepmind/acme/stargazers) | 14 | :heavy_check_mark: (dm_env) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | TF/JAX |
|
||||
| [keras-rl](https://github.com/keras-rl/keras-rl) | [](https://github.com/keras-rl/keras-rlstargazers) | 7 | :heavy_check_mark: (gym) | :x: | :x: | :x: | Keras |
|
||||
| [rlpyt](https://github.com/astooke/rlpyt) | [](https://github.com/astooke/rlpyt/stargazers) | 11 | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
|
||||
| [ChainerRL](https://github.com/chainer/chainerrl) | [](https://github.com/chainer/chainerrl/stargazers) | 18 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :x: | Chainer |
|
||||
| [Sample Factory](https://github.com/alex-petrenko/sample-factory) | [](https://github.com/alex-petrenko/sample-factory/stargazers) | 1<sup> (4)</sup> | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
|
||||
| | | | | | | | |
|
||||
| [Tianshou](https://github.com/thu-ml/tianshou) | [](https://github.com/thu-ml/tianshou/stargazers) | 20 | :heavy_check_mark: (gym) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | PyTorch |
|
||||
|
||||
<sup>(1): access date: 2021-08-08</sup>
|
||||
|
||||
<sup>(2): not all algorithms support this feature</sup>
|
||||
|
||||
<sup>(3): super fast APPO!</sup>
|
||||
<sup>(3): TQC and QR-DQN in [sb3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) instead of main repo</sup>
|
||||
|
||||
<sup>(4): super fast APPO!</sup>
|
||||
|
||||
|
||||
### High quality software engineering standard
|
||||
|
@ -20,12 +20,12 @@ def get_args():
|
||||
parser.add_argument('--task', type=str, default='CartPole-v0')
|
||||
parser.add_argument('--seed', type=int, default=0)
|
||||
parser.add_argument('--buffer-size', type=int, default=20000)
|
||||
parser.add_argument('--lr', type=float, default=1e-3)
|
||||
parser.add_argument('--lr', type=float, default=3e-4)
|
||||
parser.add_argument('--gamma', type=float, default=0.99)
|
||||
parser.add_argument('--epoch', type=int, default=10)
|
||||
parser.add_argument('--step-per-epoch', type=int, default=50000)
|
||||
parser.add_argument('--episode-per-collect', type=int, default=20)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=2)
|
||||
parser.add_argument('--step-per-collect', type=int, default=2000)
|
||||
parser.add_argument('--repeat-per-collect', type=int, default=10)
|
||||
parser.add_argument('--batch-size', type=int, default=64)
|
||||
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[64, 64])
|
||||
parser.add_argument('--training-num', type=int, default=20)
|
||||
@ -41,15 +41,16 @@ def get_args():
|
||||
parser.add_argument('--eps-clip', type=float, default=0.2)
|
||||
parser.add_argument('--max-grad-norm', type=float, default=0.5)
|
||||
parser.add_argument('--gae-lambda', type=float, default=0.95)
|
||||
parser.add_argument('--rew-norm', type=int, default=1)
|
||||
parser.add_argument('--rew-norm', type=int, default=0)
|
||||
parser.add_argument('--norm-adv', type=int, default=0)
|
||||
parser.add_argument('--recompute-adv', type=int, default=0)
|
||||
parser.add_argument('--dual-clip', type=float, default=None)
|
||||
parser.add_argument('--value-clip', type=int, default=1)
|
||||
parser.add_argument('--value-clip', type=int, default=0)
|
||||
args = parser.parse_known_args()[0]
|
||||
return args
|
||||
|
||||
|
||||
def test_ppo(args=get_args()):
|
||||
torch.set_num_threads(1) # for poor CPU
|
||||
env = gym.make(args.task)
|
||||
args.state_shape = env.observation_space.shape or env.observation_space.n
|
||||
args.action_shape = env.action_space.shape or env.action_space.n
|
||||
@ -90,7 +91,9 @@ def test_ppo(args=get_args()):
|
||||
dual_clip=args.dual_clip,
|
||||
value_clip=args.value_clip,
|
||||
action_space=env.action_space,
|
||||
deterministic_eval=True)
|
||||
deterministic_eval=True,
|
||||
advantage_normalization=args.norm_adv,
|
||||
recompute_advantage=args.recompute_adv)
|
||||
# collector
|
||||
train_collector = Collector(
|
||||
policy, train_envs,
|
||||
@ -111,7 +114,7 @@ def test_ppo(args=get_args()):
|
||||
result = onpolicy_trainer(
|
||||
policy, train_collector, test_collector, args.epoch,
|
||||
args.step_per_epoch, args.repeat_per_collect, args.test_num, args.batch_size,
|
||||
episode_per_collect=args.episode_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||
step_per_collect=args.step_per_collect, stop_fn=stop_fn, save_fn=save_fn,
|
||||
logger=logger)
|
||||
assert stop_fn(result['best_reward'])
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user