Tianshou/test/discrete/test_il_bcq.py
ChenDRAG 7036073649
Trainer refactor : some definition change (#293)
This PR focuses on changing some trainer definitions to make the trainer friendlier to use and consistent with typical usage in research papers: it changes `collect-per-step` to `step-per-collect`, adds `update-per-step` / `episode-per-collect` accordingly, and modifies the documentation.
2021-02-21 13:06:02 +08:00

113 lines
4.1 KiB
Python

import os
import gym
import torch
import pickle
import pprint
import argparse
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector
from tianshou.env import DummyVectorEnv
from tianshou.utils.net.common import Net
from tianshou.trainer import offline_trainer
from tianshou.policy import DiscreteBCQPolicy
def get_args():
    """Build the command-line options for the discrete BCQ test and parse them.

    Parsing uses ``parse_known_args``, so unrecognised command-line tokens
    (for example the ones a test runner puts on the command line) are
    ignored instead of raising an error.

    Returns:
        argparse.Namespace: the parsed options, with every option falling
        back to the default listed below when it is not supplied.
    """
    # Resolve the device default once, up front, so the table below stays flat.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, add_argument keyword arguments), kept in declaration order so the
    # generated help text lists options in the same sequence.
    option_table = (
        ("--task", {"type": str, "default": "CartPole-v0"}),
        ("--seed", {"type": int, "default": 1626}),
        ("--eps-test", {"type": float, "default": 0.001}),
        ("--lr", {"type": float, "default": 3e-4}),
        ("--gamma", {"type": float, "default": 0.9}),
        ("--n-step", {"type": int, "default": 3}),
        ("--target-update-freq", {"type": int, "default": 320}),
        ("--unlikely-action-threshold", {"type": float, "default": 0.3}),
        ("--imitation-logits-penalty", {"type": float, "default": 0.01}),
        ("--epoch", {"type": int, "default": 5}),
        ("--update-per-epoch", {"type": int, "default": 1000}),
        ("--batch-size", {"type": int, "default": 64}),
        ("--hidden-sizes",
         {"type": int, "nargs": "*", "default": [128, 128, 128]}),
        ("--test-num", {"type": int, "default": 100}),
        ("--logdir", {"type": str, "default": "log"}),
        ("--render", {"type": float, "default": 0.}),
        ("--load-buffer-name",
         {"type": str, "default": "./expert_DQN_CartPole-v0.pkl"}),
        ("--device", {"type": str, "default": fallback_device}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    # parse_known_args returns (namespace, leftover_tokens); only the
    # namespace is of interest here.
    known, _leftover = cli.parse_known_args()
    return known
def test_discrete_bcq(args=get_args()):
    """Train a DiscreteBCQPolicy offline from a pickled buffer and check the result.

    The function builds a policy network and an imitation network, loads a
    replay buffer from ``args.load_buffer_name``, runs ``offline_trainer``
    and asserts that the best reward reported by the trainer reaches the
    environment's ``reward_threshold``.  When the module is executed as a
    script it additionally prints the training result and plays one episode.

    Args:
        args: parsed options as produced by ``get_args()``.  Note that the
            default is evaluated once, when the function is defined, and that
            ``state_shape`` / ``action_shape`` are written onto this object.

    Raises:
        AssertionError: if the buffer file does not exist, or if the best
            reward does not satisfy ``stop_fn``.
    """
    # envs
    env = gym.make(args.task)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    test_envs = DummyVectorEnv(
        [lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    policy_net = Net(
        args.state_shape, args.action_shape,
        hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
    imitation_net = Net(
        args.state_shape, args.action_shape,
        hidden_sizes=args.hidden_sizes, device=args.device).to(args.device)
    # Optimizers expect a collection with a deterministic order; a set of
    # tensors iterates in a different order from run to run.  dict.fromkeys
    # keeps first-seen order while still collapsing any parameter that shows
    # up in both networks (tensors hash by identity, exactly as in a set).
    unique_params = list(dict.fromkeys(
        [*policy_net.parameters(), *imitation_net.parameters()]))
    optim = torch.optim.Adam(unique_params, lr=args.lr)
    policy = DiscreteBCQPolicy(
        policy_net, imitation_net, optim, args.gamma, args.n_step,
        args.target_update_freq, args.eps_test,
        args.unlikely_action_threshold, args.imitation_logits_penalty,
    )
    # buffer
    assert os.path.exists(args.load_buffer_name), \
        "Please run test_dqn.py first to get expert's data buffer."
    # NOTE: unpickling executes arbitrary code from the file; only load
    # buffers that were produced locally / come from a trusted source.
    # The context manager guarantees the file handle is closed again.
    with open(args.load_buffer_name, "rb") as buffer_file:
        buffer = pickle.load(buffer_file)

    # collector
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    log_path = os.path.join(args.logdir, args.task, 'discrete_bcq')
    writer = SummaryWriter(log_path)

    def save_fn(policy):
        """Write the policy's state dict to ``policy.pth`` under the log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    def stop_fn(mean_rewards):
        """Return True once the mean reward reaches the env's reward threshold."""
        return mean_rewards >= env.spec.reward_threshold

    result = offline_trainer(
        policy, buffer, test_collector,
        args.epoch, args.update_per_epoch, args.test_num, args.batch_size,
        stop_fn=stop_fn, save_fn=save_fn, writer=writer)

    assert stop_fn(result['best_reward'])

    # Only when run as a script: show the training summary and one rollout.
    if __name__ == '__main__':
        pprint.pprint(result)
        # Let's watch its performance!
        env = gym.make(args.task)
        policy.eval()
        policy.set_eps(args.eps_test)
        collector = Collector(policy, env)
        result = collector.collect(n_episode=1, render=args.render)
        rews, lens = result["rews"], result["lens"]
        print(f"Final reward: {rews.mean()}, length: {lens.mean()}")
if __name__ == "__main__":
    # Script entry point: parse the command line, then run the BCQ test
    # with the freshly parsed options.
    cli_args = get_args()
    test_discrete_bcq(cli_args)