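"""Train A2C on CartPole-v0, then train an ImitationPolicy that mimics the
trained A2C agent from data it collects.

This file doubles as a test: each stage must reach the environment's reward
threshold (lowered to 190 for the imitation stage), otherwise the trainer's
best reward fails the assertion.
"""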
import os
import gym
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.utils.net.discrete import Actor, Critic
from tianshou.policy import A2CPolicy, ImitationPolicy
from tianshou.trainer import onpolicy_trainer, offpolicy_trainer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--buffer-size', type=int, default=20000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--il-lr', type=float, default=1e-3)
    parser.add_argument('--gamma', type=float, default=0.9)
    parser.add_argument('--epoch', type=int, default=10)
    parser.add_argument('--step-per-epoch', type=int, default=50000)
    parser.add_argument('--il-step-per-epoch', type=int, default=1000)
    parser.add_argument('--episode-per-collect', type=int, default=16)
    parser.add_argument('--step-per-collect', type=int, default=16)
    parser.add_argument('--update-per-step', type=float, default=1 / 16)
    parser.add_argument('--repeat-per-collect', type=int, default=1)
    parser.add_argument('--batch-size', type=int, default=64)
    parser.add_argument('--hidden-sizes', type=int,
                        nargs='*', default=[64, 64])
    parser.add_argument('--imitation-hidden-sizes', type=int,
                        nargs='*', default=[128])
    parser.add_argument('--training-num', type=int, default=16)
    parser.add_argument('--test-num', type=int, default=100)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    # a2c special
    parser.add_argument('--vf-coef', type=float, default=0.5)
    parser.add_argument('--ent-coef', type=float, default=0.0)
    parser.add_argument('--max-grad-norm', type=float, default=None)
    parser.add_argument('--gae-lambda', type=float, default=1.)
    parser.add_argument('--rew-norm', action="store_true", default=False)
    args = parser.parse_known_args()[0]
    return args


def test_a2c_with_il(args=get_args()):
    torch.set_num_threads(1)  # for poor CPU
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # you can also use tianshou.env.SubprocVectorEnv
    # train_envs = gym.make(args.task)
    train_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)])
    # test_envs = gym.make(args.task)
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
              device=args.device)
    actor = Actor(net, args.action_shape, device=args.device).to(args.device)
    critic = Critic(net, device=args.device).to(args.device)
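    # actor and critic share the feature network `net`; the set union below
    # deduplicates the shared parameters so Adam sees each tensor only once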
    optim = torch.optim.Adam(
        set(actor.parameters()).union(critic.parameters()), lr=args.lr)
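    # `dist` is the distribution class itself; the policy constructs a
    # Categorical from the actor's action probabilities to sample actions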
    dist = torch.distributions.Categorical
    policy = A2CPolicy(
        actor, critic, optim, dist,
        discount_factor=args.gamma, gae_lambda=args.gae_lambda,
        vf_coef=args.vf_coef, ent_coef=args.ent_coef,
        max_grad_norm=args.max_grad_norm, reward_normalization=args.rew_norm,
        action_space=env.action_space)
    # collector
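    # VectorReplayBuffer keeps one sub-buffer per training env so that
    # trajectories from different envs are never interleaved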
    train_collector = Collector(
        policy, train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)))
    test_collector = Collector(policy, test_envs)
    # log
    log_path = os.path.join(args.logdir, args.task, 'a2c')
    writer = SummaryWriter(log_path)
    logger = BasicLogger(writer)

    def save_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        return mean_rewards >= env.spec.reward_threshold

    # trainer
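    # each collect step gathers `episode_per_collect` complete episodes;
    # the policy is then updated `repeat_per_collect` times on that batch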
    result = onpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.repeat_per_collect, args.test_num,
        args.batch_size, episode_per_collect=args.episode_per_collect,
        stop_fn=stop_fn, save_fn=save_fn, logger=logger)
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")

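    # imitation learning: train a second policy to mimic the A2C agent;
    # train_collector still wraps the trained A2C policy, so the data it
    # gathers serves as expert demonstrations for ImitationPolicy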
    policy.eval()
    # here we define an imitation collector with a trivial policy
    if args.task == 'CartPole-v0':
        env.spec.reward_threshold = 190  # lower the goal
    net = Net(args.state_shape, hidden_sizes=args.hidden_sizes,
              device=args.device)
    net = Actor(net, args.action_shape, device=args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
    il_policy = ImitationPolicy(net, optim, action_space=env.action_space)
    il_test_collector = Collector(
        il_policy,
        DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    )
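    # reuse train_collector (still driven by the trained A2C policy) to
    # gather expert data for the imitation learner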
    train_collector.reset()
    result = offpolicy_trainer(
        il_policy, train_collector, il_test_collector, args.epoch,
        args.il_step_per_epoch, args.step_per_collect, args.test_num,
        args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
    assert stop_fn(result['best_reward'])

    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        il_policy.eval()
        collector = Collector(il_policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")


if __name__ == '__main__':
    test_a2c_with_il()