2024-04-26 17:39:31 +02:00

179 lines
5.8 KiB
Python

#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from examples.offline.utils import load_buffer_d4rl
from tianshou.data import Collector
from tianshou.env import SubprocVectorEnv
from tianshou.policy import ImitationPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OfflineTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import Actor
from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
    """Parse the command-line options of the offline imitation-learning example.

    Returns:
        The parsed namespace. Every numeric option carries an explicit ``type``
        so that values given on the command line have the same Python type as
        their defaults.
    """
    parser = argparse.ArgumentParser()

    # Environment / dataset selection.
    parser.add_argument("--task", type=str, default="HalfCheetah-v2")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--expert-data-task", type=str, default="halfcheetah-expert-v2")

    # Model and optimisation hyper-parameters.
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--epoch", type=int, default=200)
    parser.add_argument("--step-per-epoch", type=int, default=5000)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--test-num", type=int, default=10)

    # Logging / rendering.
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=1 / 35)
    # Without ``type=float`` a user-supplied value would stay a ``str`` while
    # the default is a ``float``.
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Resuming and experiment tracking.
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="offline_d4rl.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )
    return parser.parse_args()
def test_il() -> None:
    """Train an imitation policy on a D4RL dataset, or only watch a saved one.

    The function parses the CLI arguments, builds the test environments, the
    actor network and an ``ImitationPolicy``, optionally restores weights from
    ``--resume-path``, and sets up a TensorBoard or W&B logger. Unless
    ``--watch`` is given it runs an ``OfflineTrainer`` on the buffer loaded via
    ``load_buffer_d4rl``; with ``--watch`` it loads a checkpoint and renders a
    single episode. In both cases it finishes by collecting ``--test-num``
    evaluation episodes and printing the resulting statistics.
    """
    args = get_args()
    env = gym.make(args.task)

    # Derive the shapes and action bounds from the environment's spaces.
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action
    # min_action is printed below, so it has to be populated here; leaving it
    # unset makes the "Action range" print raise AttributeError.
    args.min_action = space_info.action_info.min_action
    print("device:", args.device)
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", args.min_action, args.max_action)

    args.state_dim = space_info.observation_info.obs_dim
    args.action_dim = space_info.action_info.action_dim
    print("Max_action", args.max_action)

    test_envs = SubprocVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    test_envs.seed(args.seed)

    # model
    net = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        device=args.device,
    )
    actor = Actor(
        net,
        action_shape=args.action_shape,
        max_action=args.max_action,
        device=args.device,
    ).to(args.device)
    optim = torch.optim.Adam(actor.parameters(), lr=args.lr)

    policy: ImitationPolicy = ImitationPolicy(
        actor=actor,
        optim=optim,
        action_space=env.action_space,
        action_scaling=True,
        action_bound_method="clip",
    )

    # load a previous policy
    if args.resume_path:
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    test_collector = Collector(policy, test_envs)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    # NOTE(review): "cql" looks like a copy-paste leftover for an imitation
    # learning script; kept unchanged so existing log directory layouts and
    # W&B run names are not altered. Confirm and rename if appropriate.
    args.algo_name = "cql"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    logger: WandbLogger | TensorboardLogger
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
        logger.load(writer)

    def save_best_fn(policy: BasePolicy) -> None:
        """Write the given policy's state dict to ``policy.pth`` in the log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def watch() -> None:
        """Load a checkpoint into the policy and render one episode on ``env``."""
        if args.resume_path is None:
            # NOTE(review): log_path contains a fresh timestamp, so this default
            # only resolves if a checkpoint was written during this same run.
            args.resume_path = os.path.join(log_path, "policy.pth")

        policy.load_state_dict(torch.load(args.resume_path, map_location=torch.device("cpu")))
        collector = Collector(policy, env)
        # Honour --render instead of a hard-coded frame delay; the CLI default
        # is the same 1 / 35, so default behaviour is unchanged.
        collector.collect(n_episode=1, render=args.render, eval_mode=True)

    if not args.watch:
        replay_buffer = load_buffer_d4rl(args.expert_data_task)
        # trainer
        result = OfflineTrainer(
            policy=policy,
            buffer=replay_buffer,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
        ).run()
        pprint.pprint(result)
    else:
        watch()

    # Let's watch its performance!
    test_envs.seed(args.seed)
    test_collector.reset()
    collector_stats = test_collector.collect(
        n_episode=args.test_num,
        render=args.render,
        eval_mode=True,
    )
    print(collector_stats)
# Script entry point: run the experiment only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    test_il()