Tianshou/examples/mujoco/mujoco_redq.py
maxhuettenrauch 60d1ba1c8f
Fix/reset before collect in procedural examples, tests and hl experiment (#1100)
Needed due to a breaking change in the Collector which was overlooked in some of the examples
2024-04-16 10:30:21 +02:00

208 lines
7.4 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import numpy as np
import torch
from mujoco_env import make_mujoco_env
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.highlevel.logger import LoggerFactoryDefault
from tianshou.policy import REDQPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils.net.common import EnsembleLinear, Net
from tianshou.utils.net.continuous import ActorProb, Critic
def get_args() -> argparse.Namespace:
    """Declare the command-line options of the REDQ MuJoCo example and parse them.

    Returns:
        The namespace obtained by parsing the process command line
        (``parse_args()`` is called without explicit arguments).
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Task, seeding and network/ensemble sizes.
    add("--task", type=str, default="Ant-v4")
    add("--seed", type=int, default=0)
    add("--buffer-size", type=int, default=1000000)
    add("--hidden-sizes", type=int, nargs="*", default=[256, 256])
    add("--ensemble-size", type=int, default=10)
    add("--subset-size", type=int, default=2)

    # Optimisation hyper-parameters.
    add("--actor-lr", type=float, default=1e-3)
    add("--critic-lr", type=float, default=1e-3)
    add("--gamma", type=float, default=0.99)
    add("--tau", type=float, default=0.005)
    add("--alpha", type=float, default=0.2)
    add("--auto-alpha", default=False, action="store_true")
    add("--alpha-lr", type=float, default=3e-4)

    # Collection / training schedule.
    add("--start-timesteps", type=int, default=10000)
    add("--epoch", type=int, default=200)
    add("--step-per-epoch", type=int, default=5000)
    add("--step-per-collect", type=int, default=1)
    add("--update-per-step", type=int, default=20)
    add("--n-step", type=int, default=1)
    add("--batch-size", type=int, default=256)
    add("--target-mode", type=str, choices=("min", "mean"), default="min")
    add("--training-num", type=int, default=1)
    add("--test-num", type=int, default=10)

    # Output, rendering and compute device.
    add("--logdir", type=str, default="log")
    add("--render", type=float, default=0.0)
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
    add("--device", type=str, default=fallback_device)

    # Resuming and experiment tracking.
    add("--resume-path", type=str, default=None)
    add("--resume-id", type=str, default=None)
    add(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    add("--wandb-project", type=str, default="mujoco.benchmark")
    add(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of pre-trained policy only",
    )

    return cli.parse_args()
def test_redq(args: argparse.Namespace | None = None) -> None:
    """Train a REDQ agent on a MuJoCo task, or only evaluate it with ``--watch``.

    The function builds the environments, the actor and the ensemble critic,
    wraps them in a ``REDQPolicy``, pre-fills the replay buffer with random
    transitions, runs ``OffpolicyTrainer`` (skipped when ``args.watch`` is
    set) and finally evaluates the policy for ``args.test_num`` episodes,
    printing the collected statistics.

    Args:
        args: Parsed options as returned by ``get_args()``. When ``None``
            (the default) the command line is parsed at call time. The
            namespace is extended in place with ``state_shape``,
            ``action_shape``, ``max_action`` and ``algo_name``; ``alpha`` is
            replaced by a tuple when ``auto_alpha`` is enabled.
    """
    # Parse lazily: a `get_args()` default would be evaluated once at function
    # definition time, i.e. on import, and the resulting namespace (which is
    # mutated below) would be shared by every call.
    if args is None:
        args = get_args()

    env, train_envs, test_envs = make_mujoco_env(
        args.task,
        args.seed,
        args.training_num,
        args.test_num,
        obs_norm=False,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # model
    net_a = Net(state_shape=args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(
        net_a,
        args.action_shape,
        device=args.device,
        unbounded=True,
        conditioned_sigma=True,
    ).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    def linear(x: int, y: int) -> EnsembleLinear:
        """Create an ensemble linear layer with ``args.ensemble_size`` members."""
        return EnsembleLinear(args.ensemble_size, x, y)

    # The critic network uses `linear` for its layers, so a single module
    # holds the whole ensemble of Q-functions.
    net_c = Net(
        state_shape=args.state_shape,
        action_shape=args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
        linear_layer=linear,
    )
    critics = Critic(
        net_c,
        device=args.device,
        linear_layer=linear,
        flatten_input=False,
    ).to(args.device)
    critics_optim = torch.optim.Adam(critics.parameters(), lr=args.critic_lr)

    if args.auto_alpha:
        # Replace the fixed alpha by the (target entropy, log alpha, optimizer)
        # triple passed on to the policy.
        target_entropy = -np.prod(env.action_space.shape)
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy: REDQPolicy = REDQPolicy(
        actor=actor,
        actor_optim=actor_optim,
        critic=critics,
        critic_optim=critics_optim,
        ensemble_size=args.ensemble_size,
        subset_size=args.subset_size,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        estimation_step=args.n_step,
        # NOTE(review): the actor delay is deliberately tied to
        # --update-per-step here; confirm against REDQPolicy's documentation.
        actor_delay=args.update_per_step,
        target_mode=args.target_mode,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        # NOTE: torch.load unpickles the file; only pass checkpoints from a
        # trusted source via --resume-path.
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    buffer: VectorReplayBuffer | ReplayBuffer
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # The collector is reset explicitly before the first collect call.
    train_collector.reset()
    # Warm up the replay buffer with randomly sampled actions.
    train_collector.collect(n_step=args.start_timesteps, random=True)

    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "redq"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    logger_factory = LoggerFactoryDefault()
    if args.logger == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = args.wandb_project
    else:
        logger_factory.logger_type = "tensorboard"

    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=args.resume_id,
        config_dict=vars(args),
    )

    def save_best_fn(policy: BasePolicy) -> None:
        """Write the policy's state dict to ``policy.pth`` inside the log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    if not args.watch:
        # trainer
        result = OffpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            step_per_collect=args.step_per_collect,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            save_best_fn=save_best_fn,
            logger=logger,
            update_per_step=args.update_per_step,
            test_in_train=False,
        ).run()
        pprint.pprint(result)

    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(collector_stats)
# Script entry point: run the REDQ example only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    test_redq()