Tianshou/examples/atari/atari_cql.py
Yi Su b5c3ddabfa
Add discrete Conservative Q-Learning for offline RL (#359)
Co-authored-by: Yi Su <yi.su@antgroup.com>
Co-authored-by: Yi Su <yi.su@antfin.com>
2021-05-12 09:24:48 +08:00

148 lines
5.4 KiB
Python

import os
import torch
import pickle
import pprint
import datetime
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger
from tianshou.env import SubprocVectorEnv
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteCQLPolicy
from tianshou.data import Collector, VectorReplayBuffer
from atari_network import QRDQN
from atari_wrapper import wrap_deepmind
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="PongNoFrameskip-v4")
parser.add_argument("--seed", type=int, default=1626)
parser.add_argument("--eps-test", type=float, default=0.001)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument('--num-quantiles', type=int, default=200)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--target-update-freq", type=int, default=500)
parser.add_argument("--min-q-weight", type=float, default=10.)
parser.add_argument("--epoch", type=int, default=100)
parser.add_argument("--update-per-epoch", type=int, default=10000)
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument('--frames-stack', type=int, default=4)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--watch", default=False, action="store_true",
help="watch the play of pre-trained policy only")
parser.add_argument("--log-interval", type=int, default=100)
parser.add_argument(
"--load-buffer-name", type=str,
default="./expert_DQN_PongNoFrameskip-v4.hdf5")
parser.add_argument(
"--device", type=str,
default="cuda" if torch.cuda.is_available() else "cpu")
args = parser.parse_known_args()[0]
return args
def make_atari_env(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack)
def make_atari_env_watch(args):
return wrap_deepmind(args.task, frame_stack=args.frames_stack,
episode_life=False, clip_rewards=False)
def test_discrete_cql(args=get_args()):
# envs
env = make_atari_env(args)
args.state_shape = env.observation_space.shape or env.observation_space.n
args.action_shape = env.action_space.shape or env.action_space.n
# should be N_FRAMES x H x W
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
# make environments
test_envs = SubprocVectorEnv([lambda: make_atari_env_watch(args)
for _ in range(args.test_num)])
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
test_envs.seed(args.seed)
# model
net = QRDQN(*args.state_shape, args.action_shape,
args.num_quantiles, args.device)
optim = torch.optim.Adam(net.parameters(), lr=args.lr)
# define policy
policy = DiscreteCQLPolicy(
net, optim, args.gamma, args.num_quantiles, args.n_step,
args.target_update_freq, min_q_weight=args.min_q_weight
).to(args.device)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(
args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# buffer
assert os.path.exists(args.load_buffer_name), \
"Please run atari_qrdqn.py first to get expert's data buffer."
if args.load_buffer_name.endswith('.pkl'):
buffer = pickle.load(open(args.load_buffer_name, "rb"))
elif args.load_buffer_name.endswith('.hdf5'):
buffer = VectorReplayBuffer.load_hdf5(args.load_buffer_name)
else:
print(f"Unknown buffer format: {args.load_buffer_name}")
exit(0)
# collector
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(
args.logdir, args.task, 'cql',
f'seed_{args.seed}_{datetime.datetime.now().strftime("%m%d-%H%M%S")}')
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = BasicLogger(writer, update_interval=args.log_interval)
def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
def stop_fn(mean_rewards):
return False
# watch agent's performance
def watch():
print("Setup test envs ...")
policy.eval()
policy.set_eps(args.eps_test)
test_envs.seed(args.seed)
print("Testing agent ...")
test_collector.reset()
result = test_collector.collect(n_episode=args.test_num,
render=args.render)
pprint.pprint(result)
rew = result["rews"].mean()
print(f'Mean reward (over {result["n/ep"]} episodes): {rew}')
if args.watch:
watch()
exit(0)
result = offline_trainer(
policy, buffer, test_collector, args.epoch,
args.update_per_epoch, args.test_num, args.batch_size,
stop_fn=stop_fn, save_fn=save_fn, logger=logger)
pprint.pprint(result)
watch()
if __name__ == "__main__":
test_discrete_cql(get_args())