This PR provides scripts for the Atari DQN setting:

- A speedrun of PongNoFrameskip-v4 (finishes in about half an hour on an i7-8750 + GTX 1060 with 1M environment steps)
- A general script for all Atari games

Since we use multiple environments for simulation, the results differ slightly from the original paper, but are considered acceptable. This PR also adds a new `save_only_last_obs` parameter to the replay buffer in order to save memory.

Co-authored-by: Trinkle23897 <463003665@qq.com>
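For reference, a minimal sketch of how the new buffer option is used, mirroring the script below (the exact `ReplayBuffer` signature may differ between Tianshou versions):

```python
from tianshou.data import ReplayBuffer

# keep only the newest frame of each stacked observation and drop obs_next;
# the last `stack_num` frames are re-stacked when the buffer is sampled
buffer = ReplayBuffer(100000, ignore_obs_next=True,
                      save_only_last_obs=True, stack_num=4)
```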
148 lines · 5.6 KiB · Python
import os
import torch
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter

from tianshou.policy import DQNPolicy
from tianshou.env import SubprocVectorEnv
from tianshou.utils.net.discrete import DQN
from tianshou.trainer import offpolicy_trainer
from tianshou.data import Collector, ReplayBuffer

from atari_wrapper import wrap_deepmind


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, default='PongNoFrameskip-v4')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--eps_test', type=float, default=0.005)
    parser.add_argument('--eps_train', type=float, default=1.)
    parser.add_argument('--eps_train_final', type=float, default=0.05)
    parser.add_argument('--buffer-size', type=int, default=100000)
    parser.add_argument('--lr', type=float, default=0.0001)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--n_step', type=int, default=3)
    parser.add_argument('--target_update_freq', type=int, default=500)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--step_per_epoch', type=int, default=10000)
    parser.add_argument('--collect_per_step', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--training_num', type=int, default=16)
    parser.add_argument('--test_num', type=int, default=10)
    parser.add_argument('--logdir', type=str, default='log')
    parser.add_argument('--render', type=float, default=0.)
    parser.add_argument(
        '--device', type=str,
        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--frames_stack', type=int, default=4)
    parser.add_argument('--resume_path', type=str, default=None)
    parser.add_argument('--watch', default=False, action='store_true',
                        help='watch the play of pre-trained policy only')
    return parser.parse_args()


def make_atari_env(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack)
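

# the evaluation env keeps full episodes and raw (unclipped) rewards, so test
# scores reflect the real game score rather than the clipped training reward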
def make_atari_env_watch(args):
    return wrap_deepmind(args.task, frame_stack=args.frames_stack,
                         episode_life=False, clip_rewards=False)


def test_dqn(args=get_args()):
    env = make_atari_env(args)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.env.action_space.shape or env.env.action_space.n
    # should be N_FRAMES x H x W
    print("Observations shape: ", args.state_shape)
    print("Actions shape: ", args.action_shape)
    # make environments
    train_envs = SubprocVectorEnv([lambda: make_atari_env(args)
                                   for _ in range(args.training_num)])
    test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
                                  for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
    # define model
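    # the network maps the stacked frames (c, h, w) to one Q-value per action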
    net = DQN(*args.state_shape,
              args.action_shape, args.device).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    # define policy
    policy = DQNPolicy(net, optim, args.gamma, args.n_step,
                       target_update_freq=args.target_update_freq)
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path))
        print("Loaded agent from: ", args.resume_path)
    # replay buffer: `save_only_last_obs` and `stack_num` can be removed
    # together when you have enough RAM
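    # `save_only_last_obs` stores only the newest frame of each stacked
    # observation (this is what saves the memory), and `stack_num` re-stacks
    # the last `frames_stack` frames when the buffer is sampled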
buffer = ReplayBuffer(args.buffer_size, ignore_obs_next=True,
|
|
save_last_obs=True, stack_num=args.frames_stack)
|
|
# collector
|
|
train_collector = Collector(policy, train_envs, buffer)
|
|
test_collector = Collector(policy, test_envs)
|
|
# log
|
|
log_path = os.path.join(args.logdir, args.task, 'dqn')
|
|
writer = SummaryWriter(log_path)
|
|
|
|
def save_fn(policy):
|
|
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
|
|
|
|
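
    # Atari tasks usually have no reward_threshold in their spec; fall back to
    # a score of 20 for Pong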
    def stop_fn(x):
        if env.env.spec.reward_threshold:
            return x >= env.spec.reward_threshold
        elif 'Pong' in args.task:
            return x >= 20

    def train_fn(x):
        # nature DQN setting, linear decay in the first 1M steps
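        # `x` is the current epoch, so `now` roughly equals the number of
        # environment steps collected so far (collect_per_step transitions per
        # training step, step_per_epoch training steps per epoch)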
        now = x * args.collect_per_step * args.step_per_epoch
        if now <= 1e6:
            eps = args.eps_train - now / 1e6 * \
                (args.eps_train - args.eps_train_final)
            policy.set_eps(eps)
        else:
            policy.set_eps(args.eps_train_final)
        print("set eps =", policy.eps)

    def test_fn(x):
        policy.set_eps(args.eps_test)

    # watch agent's performance
    def watch():
        print("Testing agent ...")
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        result = test_collector.collect(n_episode=[1] * args.test_num,
                                        render=args.render)
        pprint.pprint(result)

    if args.watch:
        watch()
        exit(0)

    # test train_collector and start filling replay buffer
    train_collector.collect(n_step=args.batch_size * 4)
    # trainer
    result = offpolicy_trainer(
        policy, train_collector, test_collector, args.epoch,
        args.step_per_epoch, args.collect_per_step, args.test_num,
        args.batch_size, train_fn=train_fn, test_fn=test_fn,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer, test_in_train=False)

    pprint.pprint(result)
    watch()


if __name__ == '__main__':
    test_dqn(get_args())