Tianshou/examples/offline/d4rl_bcq.py

#!/usr/bin/env python3
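"""Offline RL example: train BCQ (batch-constrained deep Q-learning) on a D4RL dataset.

The policy is trained purely from the offline buffer loaded via ``load_buffer_d4rl``
and evaluated on the corresponding Gym task, e.g.::

    python examples/offline/d4rl_bcq.py --task HalfCheetah-v2 --expert-data-task halfcheetah-expert-v2
"""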
import argparse
import datetime
import os
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from examples.offline.utils import load_buffer_d4rl
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv
from tianshou.policy import BCQPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import MLP, Net
from tianshou.utils.net.continuous import VAE, Critic, Perturbation
from tianshou.utils.space_info import SpaceInfo


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2")
    parser.add_argument("--buffer-size", type=int, default=1000000)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-3)
    parser.add_argument("--critic-lr", type=float, default=1e-3)
    parser.add_argument("--start-timesteps", type=int, default=10000)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--n-step", type=int, default=3)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--test-num", type=int, default=10)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    parser.add_argument("--vae-hidden-sizes", type=int, nargs="*", default=[512, 512])
    # default to 2 * action_dim
    parser.add_argument("--latent-dim", type=int)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    # Weighting for Clipped Double Q-learning in BCQ
    parser.add_argument("--lmbda", type=float, default=0.75)
    # Max perturbation hyper-parameter for BCQ
    parser.add_argument("--phi", type=float, default=0.05)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()


def test_bcq() -> None:
    args = get_args()
    env = gym.make(args.task)
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action
    args.min_action = space_info.action_info.min_action
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", args.min_action, args.max_action)
    args.state_dim = space_info.observation_info.obs_dim
    args.action_dim = space_info.action_info.action_dim
    print("Max_action", args.max_action)
    # test_envs = gym.make(args.task)
    test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)
    # model
    # perturbation network
    net_a = MLP(
        input_dim=args.state_dim + args.action_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Perturbation(net_a, max_action=args.max_action, device=args.device, phi=args.phi).to(
        args.device,
    )
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
    net_c1 = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    net_c2 = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
    # vae
    # output_dim = 0, so the last Module in the encoder is ReLU
    vae_encoder = MLP(
        input_dim=args.state_dim + args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    if not args.latent_dim:
        args.latent_dim = args.action_dim * 2
    vae_decoder = MLP(
        input_dim=args.state_dim + args.latent_dim,
        output_dim=args.action_dim,
        hidden_sizes=args.vae_hidden_sizes,
        device=args.device,
    )
    vae = VAE(
        vae_encoder,
        vae_decoder,
        hidden_dim=args.vae_hidden_sizes[-1],
        latent_dim=args.latent_dim,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    vae_optim = torch.optim.Adam(vae.parameters())
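    # BCQ assembles the pieces built above: the VAE proposes actions close to the
    # dataset, the perturbation network adjusts them within +/- phi, and the twin
    # critics are combined with weight lmbda (clipped double Q-learning).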
    policy: BCQPolicy = BCQPolicy(
        actor_perturbation=actor,
        actor_perturbation_optim=actor_optim,
        critic=critic1,
        critic_optim=critic1_optim,
        action_space=env.action_space,
        critic2=critic2,
        critic2_optim=critic2_optim,
        vae=vae,
        vae_optim=vae_optim,
        device=args.device,
        gamma=args.gamma,
        tau=args.tau,
        lmbda=args.lmbda,
    )
    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "bcq"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger: WandbLogger | TensorboardLogger
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
        logger.load(writer)

    def save_best_fn(policy: BasePolicy) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch() -> None:
        if args.resume_path is None:
            args.resume_path = os.path.join(log_path, "policy.pth")
        policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
        collector = Collector(policy, env)
        collector.collect(n_episode=1, render=1 / 35, eval_mode=True)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        # trainer
        result = OfflineTrainer(
            policy=policy,
            buffer=replay_buffer,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        ).run()
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    test_envs.seed(args.seed)
    test_collector.reset()
    collector_stats = test_collector.collect(
        n_episode=args.test_num,
        render=args.render,
        eval_mode=True,
    )
    print(collector_stats)


if __name__ == "__main__":
    test_bcq()