Tianshou/examples/mujoco/mujoco_a2c.py

#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint

import numpy as np
import torch
from mujoco_env import make_mujoco_env
from torch import nn
from torch.distributions import Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.policy import A2CPolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="Ant-v4")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--buffer-size", type=int, default=4096)
    parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64])
    parser.add_argument("--lr", type=float, default=7e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--epoch", type=int, default=100)
    parser.add_argument("--step-per-epoch", type=int, default=30000)
    parser.add_argument("--step-per-collect", type=int, default=80)
    parser.add_argument("--repeat-per-collect", type=int, default=1)
    # batch-size >> step-per-collect means calculating all collected data in one single forward pass.
    parser.add_argument("--batch-size", type=int, default=None)
    parser.add_argument("--training-num", type=int, default=16)
    parser.add_argument("--test-num", type=int, default=10)
    # a2c special
    parser.add_argument("--rew-norm", type=int, default=True)
    parser.add_argument("--vf-coef", type=float, default=0.5)
    parser.add_argument("--ent-coef", type=float, default=0.01)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--bound-action-method", type=str, default="clip")
    parser.add_argument("--lr-decay", type=int, default=True)
    parser.add_argument("--max-grad-norm", type=float, default=0.5)
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    parser.add_argument("--resume-path", type=str, default=None)
    parser.add_argument("--resume-id", type=str, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb"],
    )
    parser.add_argument("--wandb-project", type=str, default="mujoco.benchmark")
    parser.add_argument(
        "--watch",
        default=False,
        action="store_true",
        help="watch the play of a pre-trained policy only",
    )
    return parser.parse_args()


def test_a2c(args=get_args()):
    env, train_envs, test_envs = make_mujoco_env(
        args.task,
        args.seed,
        args.training_num,
        args.test_num,
        obs_norm=True,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # model
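    # the actor and critic are separate tanh MLPs with the same hidden sizes (no shared layers);
    # ActorCritic below only groups their parameters so a single optimizer can update both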
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    actor = ActorProb(
        net_a,
        args.action_shape,
        unbounded=True,
        device=args.device,
    ).to(args.device)
    net_c = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    actor_critic = ActorCritic(actor, critic)
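    # the Gaussian policy uses a state-independent log-std parameter; start it at -0.5 (std ~= 0.61)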
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # scale down the last policy layer so that initial actions have (close to) zero mean and std,
    # which helps boost performance;
    # see https://arxiv.org/abs/2006.05990, Fig. 24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)
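    # A2C is classically trained with RMSprop (as in the original A3C setup) rather than Adam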
    optim = torch.optim.RMSprop(
        actor_critic.parameters(),
        lr=args.lr,
        eps=1e-5,
        alpha=0.99,
    )
    lr_scheduler = None
    if args.lr_decay:
        # decay learning rate to 0 linearly
        max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch
        lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
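
    # the actor outputs (mu, sigma); Independent(..., 1) turns the per-dimension Normals into a
    # single diagonal Gaussian whose log_prob sums over the action dimensions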
    def dist(*logits):
        return Independent(Normal(*logits), 1)
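
    # with action_scaling=True and bound method "clip", the raw Gaussian sample is clipped to
    # [-1, 1] and then rescaled into the environment's action range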
    policy = A2CPolicy(
        actor=actor,
        critic=critic,
        optim=optim,
        dist_fn=dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
    )
    # load a previous policy
    if args.resume_path:
        ckpt = torch.load(args.resume_path, map_location=args.device)
        policy.load_state_dict(ckpt["model"])
        train_envs.set_obs_rms(ckpt["obs_rms"])
        test_envs.set_obs_rms(ckpt["obs_rms"])
        print("Loaded agent from: ", args.resume_path)
    # collector
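    # a VectorReplayBuffer keeps one sub-buffer per parallel env so transitions stay in episode order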
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)
    # log
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "a2c"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)
    # logger
    if args.logger == "wandb":
        logger = WandbLogger(
            save_interval=1,
            name=log_name.replace(os.path.sep, "__"),
            run_id=args.resume_id,
            config=args,
            project=args.wandb_project,
        )
    writer = SummaryWriter(log_path)
    writer.add_text("args", str(args))
    if args.logger == "tensorboard":
        logger = TensorboardLogger(writer)
    else:  # wandb
        logger.load(writer)
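
    # the best checkpoint stores both the model weights and the observation normalization
    # statistics, so evaluation and resuming reuse the same preprocessing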
    def save_best_fn(policy):
        state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()}
        torch.save(state, os.path.join(log_path, "policy.pth"))

    if not args.watch:
        # trainer
        result = OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            repeat_per_collect=args.repeat_per_collect,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            step_per_collect=args.step_per_collect,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_train=False,
        ).run()
        pprint.pprint(result)
    # Let's watch its performance!
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    result = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(f"Final reward: {result.returns_stat.mean}, length: {result.lens_stat.mean}")
if __name__ == "__main__":
test_a2c()