Tianshou/test/discrete/test_ppo.py

130 lines
5.1 KiB
Python
Raw Normal View History

2020-04-11 16:54:27 +08:00
import os
2020-03-20 19:52:29 +08:00
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import PPOPolicy
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
2020-03-20 19:52:29 +08:00
from tianshou.trainer import onpolicy_trainer
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
2020-03-20 19:52:29 +08:00
def get_args():
    """Return the configuration namespace for the discrete PPO test.

    Options are registered in a fixed order (which is also the ``--help``
    order). Arguments the parser does not recognise, such as those a test
    runner places on the command line, are ignored because
    ``parse_known_args`` is used.

    Returns:
        argparse.Namespace: the parsed options.
    """
    preferred_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # (flag, type, default) in registration order.
    option_table = [
        ('--task', str, 'CartPole-v0'),
        ('--seed', int, 0),
        ('--buffer-size', int, 20000),
        ('--lr', float, 1e-3),
        ('--gamma', float, 0.99),
        ('--epoch', int, 10),
        ('--step-per-epoch', int, 50000),
        ('--episode-per-collect', int, 20),
        ('--repeat-per-collect', int, 2),
        ('--batch-size', int, 64),
        ('--hidden-sizes', int, [64, 64]),
        ('--training-num', int, 20),
        ('--test-num', int, 100),
        ('--logdir', str, 'log'),
        ('--render', float, 0.),
        ('--device', str, preferred_device),
        # ppo special
        ('--vf-coef', float, 0.5),
        ('--ent-coef', float, 0.0),
        ('--eps-clip', float, 0.2),
        ('--max-grad-norm', float, 0.5),
        ('--gae-lambda', float, 0.95),
        ('--rew-norm', int, 1),
        ('--dual-clip', float, None),
        ('--value-clip', int, 1),
    ]
    # Only the hidden layer sizes accept a variable number of values.
    variadic = {'--hidden-sizes': '*'}
    parser = argparse.ArgumentParser()
    for flag, converter, default_value in option_table:
        parser.add_argument(flag, type=converter, default=default_value,
                            nargs=variadic.get(flag))
    known, _ignored = parser.parse_known_args()
    return known
def test_ppo(args=None):
    """Train a PPO agent on a discrete-action task and check it succeeds.

    Args:
        args: configuration namespace as produced by ``get_args()``. When
            ``None`` (the default), a fresh namespace is parsed for this call.
            The namespace is mutated: ``state_shape`` and ``action_shape`` are
            filled in from the environment's observation/action spaces.

    Raises:
        AssertionError: if the best reward reported by the trainer does not
            reach ``env.spec.reward_threshold``.
    """
    # Parse per call instead of in the default-argument expression: a default
    # of ``get_args()`` is evaluated once at import time, and that single
    # shared namespace would then be mutated by every call below.
    if args is None:
        args = get_args()
    torch.set_num_threads(1)  # for poor CPU
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # you can also use tianshou.env.SubprocVectorEnv
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model: actor and critic are both built on the same ``net`` instance
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
              device=args.device)
    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
    critic = Critic(net, device=args.device).to(args.device)
    # orthogonal initialization
    for m in list(actor.modules()) + list(critic.modules()):
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            torch.nn.init.zeros_(m.bias)
    # ModuleList.parameters() yields each parameter of the shared ``net`` only
    # once and in a deterministic order; iterating a ``set`` of tensors gives
    # an order that may differ between runs (and so in optim.state_dict()).
    optim = torch.optim.Adam(
        torch.nn.ModuleList([actor, critic]).parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = PPOPolicy(
        actor, critic, optim, dist,
        discount_factor=args.gamma,
        max_grad_norm=args.max_grad_norm,
        eps_clip=args.eps_clip,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        gae_lambda=args.gae_lambda,
        reward_normalization=args.rew_norm,
        dual_clip=args.dual_clip,
        value_clip=args.value_clip,
        action_space=env.action_space)
    # collector
    train_collector = Collector(
        policy, train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'ppo')
    writer = SummaryWriter(log_path)
    logger = BasicLogger(writer)

    def save_fn(policy):
        # Persist the policy weights under the run's log directory.
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        # Training is considered solved once the environment's threshold is met.
        return mean_rewards >= env.spec.reward_threshold

    # trainer
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.repeat_per_collect, args.test_num,
        args.batch_size, episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
    assert stop_fn(result['best_reward'])
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == '__main__':
    # Script entry point: parse the command line and run the training test.
    test_ppo(get_args())