# Repository web-viewer header (not part of the program):
# 164 lines, 6.4 KiB, Python -- Raw / Normal View / History
# 2020-03-17 20:22:37 +08:00
import gym
import time
import tqdm
import torch
import argparse
import numpy as np
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tianshou.policy import A2CPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils import tqdm_config, MovAvg
from tianshou.data import Collector, ReplayBuffer
class Net(nn.Module):
    """Actor-critic MLP whose two heads are built on the same hidden layers.

    The trunk is ``Linear(prod(state_shape), 128) -> ReLU`` followed by
    ``layer_num`` additional ``Linear(128, 128) -> ReLU`` pairs. The actor
    head appends ``Linear(128, prod(action_shape))`` and the critic head
    appends ``Linear(128, 1)``. Both ``nn.Sequential`` heads wrap the *same*
    trunk layer objects, so the trunk parameters are shared between them.

    :param layer_num: number of extra hidden ``Linear``/``ReLU`` pairs.
    :param state_shape: observation shape (tuple) or size (int); flattened
        with ``np.prod`` to obtain the input width.
    :param action_shape: action shape (tuple) or count (int); flattened
        with ``np.prod`` to obtain the number of logits.
    :param device: device on which ``forward`` places its input.
    """

    def __init__(self, layer_num, state_shape, action_shape, device='cpu'):
        super().__init__()
        self.device = device
        # np.prod returns a NumPy scalar; hand plain ints to nn.Linear.
        in_dim = int(np.prod(state_shape))
        out_dim = int(np.prod(action_shape))
        # Kept as a plain list attribute (as before); the layers are
        # registered as sub-modules through the two Sequential heads below.
        self.model = [nn.Linear(in_dim, 128), nn.ReLU(inplace=True)]
        for _ in range(layer_num):
            self.model += [nn.Linear(128, 128), nn.ReLU(inplace=True)]
        # Head layers are created actor-first, then critic.
        self.actor = nn.Sequential(*self.model, nn.Linear(128, out_dim))
        self.critic = nn.Sequential(*self.model, nn.Linear(128, 1))

    def forward(self, s, **kwargs):
        """Return ``(logits, value, None)`` for a batch of observations.

        :param s: batch of observations (array-like or tensor) whose first
            dimension is the batch; remaining dimensions are flattened.
        :param kwargs: accepted and ignored.
        :return: actor output, critic output, and ``None`` as the third
            element (this network carries no hidden state).
        """
        # as_tensor avoids the copy (and the UserWarning) that torch.tensor
        # incurs when ``s`` is already a tensor.
        s = torch.as_tensor(s, device=self.device, dtype=torch.float)
        # reshape (not view) so non-contiguous inputs are also accepted.
        s = s.reshape(s.shape[0], -1)
        return self.actor(s), self.critic(s), None
def get_args():
    """Build the command-line parser and return the parsed options.

    Parsing uses ``parse_known_args`` so unrecognised command-line
    arguments are ignored instead of raising an error.

    :return: an ``argparse.Namespace`` holding the options below.
    """
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # (flag, type, default) -- registered in exactly this order.
    option_table = (
        ('--task', str, 'CartPole-v0'),
        ('--seed', int, 1626),
        ('--buffer-size', int, 20000),
        ('--lr', float, 3e-4),
        ('--gamma', float, 0.9),
        ('--epoch', int, 100),
        ('--step-per-epoch', int, 320),
        ('--collect-per-step', int, 10),
        ('--batch-size', int, 64),
        ('--layer-num', int, 2),
        ('--training-num', int, 32),
        ('--test-num', int, 100),
        ('--logdir', str, 'log'),
        ('--device', str, default_device),
        # a2c special
        ('--vf-coef', float, 0.5),
        ('--entropy-coef', float, 0.001),
        ('--max-grad-norm', float, None),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in option_table:
        parser.add_argument(flag, type=value_type, default=default_value)
    known, _ignored = parser.parse_known_args()
    return known
def test_a2c(args=None):
    """Train an A2C policy on ``args.task`` and assert it reaches the
    environment's ``reward_threshold``.

    Each epoch alternates a training phase (collect episodes, run
    ``policy.learn``, log to TensorBoard) with an evaluation phase over
    ``args.test_num`` episodes. Training stops early once the best test
    reward reaches ``env.spec.reward_threshold``.

    :param args: parsed options as returned by ``get_args``. When ``None``
        (the default) they are parsed on call, so the function no longer
        shares one namespace object, created at definition time, across
        calls.
    :raises AssertionError: if the reward threshold is not reached within
        ``args.epoch`` epochs.
    """
    if args is None:
        args = get_args()
    env = gym.make(args.task)
    # Spaces without a shape fall back to ``.n``.
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    # train_envs = gym.make(args.task)
    train_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.training_num)],
        reset_after_done=True)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)],
        reset_after_done=False)
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # model
    net = Net(args.layer_num, args.state_shape, args.action_shape, args.device)
    net = net.to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    dist = torch.distributions.Categorical
    policy = A2CPolicy(
        net, optim, dist, args.gamma,
        vf_coef=args.vf_coef,
        entropy_coef=args.entropy_coef,
        max_grad_norm=args.max_grad_norm)
    # collector
    training_collector = Collector(
        policy, train_envs, ReplayBuffer(args.buffer_size))
    test_collector = Collector(policy, test_envs, stat_size=args.test_num)
    # log
    stat_loss = MovAvg()
    global_step = 0
    writer = SummaryWriter(args.logdir)
    best_epoch = -1
    best_reward = -1e10
    start_time = time.time()
    try:
        for epoch in range(1, 1 + args.epoch):
            desc = f'Epoch #{epoch}'
            # train
            policy.train()
            with tqdm.tqdm(
                    total=args.step_per_epoch, desc=desc, **tqdm_config) as t:
                while t.n < t.total:
                    result = training_collector.collect(
                        n_episode=args.collect_per_step)
                    losses = policy.learn(
                        training_collector.sample(0), args.batch_size)
                    # Drop the data that was just used for learning.
                    training_collector.reset_buffer()
                    global_step += len(losses)
                    t.update(len(losses))
                    stat_loss.add(losses)
                    writer.add_scalar(
                        'reward', result['reward'], global_step=global_step)
                    writer.add_scalar(
                        'length', result['length'], global_step=global_step)
                    writer.add_scalar(
                        'loss', stat_loss.get(), global_step=global_step)
                    writer.add_scalar(
                        'speed', result['speed'], global_step=global_step)
                    t.set_postfix(loss=f'{stat_loss.get():.6f}',
                                  reward=f'{result["reward"]:.6f}',
                                  length=f'{result["length"]:.2f}',
                                  speed=f'{result["speed"]:.2f}')
                # NOTE(review): the loop above exits only once
                # t.n >= t.total, so this fires only when t.n == t.total and
                # advances the bar by one extra step -- presumably meant to
                # force a final redraw; confirm before changing.
                if t.n <= t.total:
                    t.update()
            # eval
            test_collector.reset_env()
            test_collector.reset_buffer()
            policy.eval()
            result = test_collector.collect(n_episode=args.test_num)
            if best_reward < result['reward']:
                best_reward = result['reward']
                best_epoch = epoch
            print(f'Epoch #{epoch}: test_reward: {result["reward"]:.6f}, '
                  f'best_reward: {best_reward:.6f} in #{best_epoch}')
            if best_reward >= env.spec.reward_threshold:
                break
        assert best_reward >= env.spec.reward_threshold
    finally:
        # Release resources even when training raises or the assertion
        # fails. (Presumably Collector.close() also shuts down the
        # sub-process environments -- verify against tianshou.)
        writer.close()
        training_collector.close()
        test_collector.close()
    if __name__ == '__main__':
        train_cnt = training_collector.collect_step
        test_cnt = test_collector.collect_step
        duration = time.time() - start_time
        print(f'Collect {train_cnt} training frame and {test_cnt} test frame '
              f'in {duration:.2f}s, '
              f'speed: {(train_cnt + test_cnt) / duration:.2f}it/s')
        # Let's watch its performance!
        env = gym.make(args.task)
        test_collector = Collector(policy, env)
        result = test_collector.collect(n_episode=1, render=1 / 35)
        print(f'Final reward: {result["reward"]}, length: {result["length"]}')
        test_collector.close()
# Script entry point: run the A2C training test with command-line options.
if __name__ == "__main__":
    test_a2c()