Tianshou/examples/mujoco/mujoco_ppo.py
maxhuettenrauch 5fe9aea798
Update and fix dependencies related to mac install (#1044)
Addresses part of #1015 

### Dependencies

- move jsonargparse and docstring-parser to the main dependencies so that the
high-level (hl) examples run without installing the dev extra
- create mujoco-py extra for legacy mujoco envs
- updated atari extra
    - removed atari-py and gym dependencies
    - added ALE-py, autorom, and shimmy
- created robotics extra for HER-DDPG

### Mac specific

- only install envpool when not on mac
- mujoco-py does not work on macOS versions newer than Monterey
(https://github.com/openai/mujoco-py/issues/777)
- D4RL also fails due to dependency on mujoco-py
(https://github.com/Farama-Foundation/D4RL/issues/232)

### Other

- reduced training-num/test-num in example files to a number ≤ 20
(examples with 100 led to too many open files)
- rendering for MuJoCo envs needs to be fixed on the Gymnasium side
(https://github.com/Farama-Foundation/Gymnasium/issues/749)

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
2024-02-06 17:06:38 +01:00

234 lines
8.4 KiB
Python
Executable File

#!/usr/bin/env python3
import argparse
import datetime
import os
import pprint
import numpy as np
import torch
from mujoco_env import make_mujoco_env
from torch import nn
from torch.distributions import Distribution, Independent, Normal
from torch.optim.lr_scheduler import LambdaLR
from examples.common import logger_factory
from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
from tianshou.policy import PPOPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.continuous import ActorProb, Critic
def get_args() -> argparse.Namespace:
    """Declare this script's command-line options and parse ``sys.argv``.

    The options are listed in an ordered table and registered in that order.
    Both the ``--help`` listing and the attribute order of the returned
    namespace (which is later handed to the logger as its config) follow the
    table.

    Returns:
        The parsed :class:`argparse.Namespace`.

    Notes:
        The PPO switches declared with ``type=int`` (``--rew-norm``,
        ``--lr-decay``, ``--value-clip``, ``--norm-adv``, ``--recompute-adv``)
        take integers such as ``0``/``1`` on the command line. When they are
        not given, ``--rew-norm`` and ``--lr-decay`` keep the boolean default
        ``True``. ``--device`` defaults to ``"cuda"`` when
        ``torch.cuda.is_available()`` and to ``"cpu"`` otherwise.
    """
    option_table: list[tuple[str, dict]] = [
        # experiment / environment / training-loop sizes
        ("--task", dict(type=str, default="Ant-v4")),
        ("--seed", dict(type=int, default=0)),
        ("--buffer-size", dict(type=int, default=4096)),
        ("--hidden-sizes", dict(type=int, nargs="*", default=[64, 64])),
        ("--lr", dict(type=float, default=3e-4)),
        ("--gamma", dict(type=float, default=0.99)),
        ("--epoch", dict(type=int, default=100)),
        ("--step-per-epoch", dict(type=int, default=30000)),
        ("--step-per-collect", dict(type=int, default=2048)),
        ("--repeat-per-collect", dict(type=int, default=10)),
        ("--batch-size", dict(type=int, default=64)),
        ("--training-num", dict(type=int, default=8)),
        ("--test-num", dict(type=int, default=10)),
        # PPO-specific options
        ("--rew-norm", dict(type=int, default=True)),
        # In theory, `vf-coef` will not make any difference if using Adam optimizer.
        ("--vf-coef", dict(type=float, default=0.25)),
        ("--ent-coef", dict(type=float, default=0.0)),
        ("--gae-lambda", dict(type=float, default=0.95)),
        ("--bound-action-method", dict(type=str, default="clip")),
        ("--lr-decay", dict(type=int, default=True)),
        ("--max-grad-norm", dict(type=float, default=0.5)),
        ("--eps-clip", dict(type=float, default=0.2)),
        ("--dual-clip", dict(type=float, default=None)),
        ("--value-clip", dict(type=int, default=0)),
        ("--norm-adv", dict(type=int, default=0)),
        ("--recompute-adv", dict(type=int, default=1)),
        # logging, device and checkpoint handling
        ("--logdir", dict(type=str, default="log")),
        ("--render", dict(type=float, default=0.0)),
        (
            "--device",
            dict(type=str, default="cuda" if torch.cuda.is_available() else "cpu"),
        ),
        ("--resume-path", dict(type=str, default=None)),
        ("--resume-id", dict(type=str, default=None)),
        (
            "--logger",
            dict(type=str, default="tensorboard", choices=["tensorboard", "wandb"]),
        ),
        ("--wandb-project", dict(type=str, default="mujoco.benchmark")),
        (
            "--watch",
            dict(
                default=False,
                action="store_true",
                help="watch the play of pre-trained policy only",
            ),
        ),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def test_ppo(args: argparse.Namespace | None = None) -> None:
    """Train a PPO agent on a MuJoCo task, then run an evaluation rollout.

    With ``--watch`` the training phase is skipped and only the final
    evaluation rollout on the test environments is performed.

    Args:
        args: Options as returned by :func:`get_args`. Defaults to ``None``,
            in which case :func:`get_args` is called here. ``sys.argv`` is
            therefore parsed when the function runs, not once at import time,
            and repeated calls do not share one mutated default namespace.
            The namespace is extended in place with derived fields
            (``state_shape``, ``action_shape``, ``max_action``,
            ``algo_name``), which are also forwarded to the logger config.
    """
    if args is None:
        args = get_args()

    env, train_envs, test_envs = make_mujoco_env(
        args.task,
        args.seed,
        args.training_num,
        args.test_num,
        obs_norm=True,
    )
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.max_action = env.action_space.high[0]
    print("Observations shape:", args.state_shape)
    print("Actions shape:", args.action_shape)
    print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # model: the actor and the critic each get their own tanh MLP trunk
    net_a = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    actor = ActorProb(
        net_a,
        args.action_shape,
        unbounded=True,
        device=args.device,
    ).to(args.device)
    net_c = Net(
        args.state_shape,
        hidden_sizes=args.hidden_sizes,
        activation=nn.Tanh,
        device=args.device,
    )
    critic = Critic(net_c, device=args.device).to(args.device)
    actor_critic = ActorCritic(actor, critic)

    # NOTE(review): sigma_param is presumably a state-independent log-std, so
    # this would start with std = exp(-0.5); confirm against ActorProb.
    torch.nn.init.constant_(actor.sigma_param, -0.5)
    for m in actor_critic.modules():
        if isinstance(m, torch.nn.Linear):
            # orthogonal initialization
            torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
            torch.nn.init.zeros_(m.bias)
    # do last policy layer scaling, this will make initial actions have (close to)
    # 0 mean and std, and will help boost performances,
    # see https://arxiv.org/abs/2006.05990, Fig.24 for details
    for m in actor.mu.modules():
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.zeros_(m.bias)
            m.weight.data.copy_(0.01 * m.weight.data)

    # one optimizer over both actor and critic parameters
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    lr_scheduler = None
    if args.lr_decay:
        # Decay the learning rate linearly to 0. max_update_num is the number
        # of collect rounds over the whole run; this assumes the policy steps
        # the scheduler once per round (TODO confirm against PPOPolicy).
        max_update_num = np.ceil(args.step_per_epoch / args.step_per_collect) * args.epoch
        lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)

    def dist(*logits: torch.Tensor) -> Distribution:
        # Diagonal Gaussian over the action vector. Independent(..., 1) treats
        # the last dimension as the event dimension, so log-probs are summed
        # over action components.
        return Independent(Normal(*logits), 1)

    policy: PPOPolicy = PPOPolicy(
        actor=actor,
        critic=critic,
        optim=optim,
        dist_fn=dist,
        discount_factor=args.gamma,
        gae_lambda=args.gae_lambda,
        max_grad_norm=args.max_grad_norm,
        vf_coef=args.vf_coef,
        ent_coef=args.ent_coef,
        reward_normalization=args.rew_norm,
        action_scaling=True,
        action_bound_method=args.bound_action_method,
        lr_scheduler=lr_scheduler,
        action_space=env.action_space,
        eps_clip=args.eps_clip,
        value_clip=args.value_clip,
        dual_clip=args.dual_clip,
        advantage_normalization=args.norm_adv,
        recompute_advantage=args.recompute_adv,
    )

    # load a previous policy together with its observation-normalization stats
    if args.resume_path:
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        ckpt = torch.load(args.resume_path, map_location=args.device)
        policy.load_state_dict(ckpt["model"])
        train_envs.set_obs_rms(ckpt["obs_rms"])
        test_envs.set_obs_rms(ckpt["obs_rms"])
        print("Loaded agent from: ", args.resume_path)

    # collector
    buffer: VectorReplayBuffer | ReplayBuffer
    if args.training_num > 1:
        buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
    else:
        buffer = ReplayBuffer(args.buffer_size)
    train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
    test_collector = Collector(policy, test_envs)

    # log: <logdir>/<task>/ppo/<seed>/<timestamp>
    now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    args.algo_name = "ppo"
    log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
    log_path = os.path.join(args.logdir, log_name)

    # logger
    if args.logger == "wandb":
        logger_factory.logger_type = "wandb"
        logger_factory.wandb_project = args.wandb_project
    else:
        logger_factory.logger_type = "tensorboard"
    logger = logger_factory.create_logger(
        log_dir=log_path,
        experiment_name=log_name,
        run_id=args.resume_id,
        config_dict=vars(args),
    )

    def save_best_fn(policy: BasePolicy) -> None:
        # Store the observation-normalization stats next to the weights; the
        # resume branch above restores both from the same checkpoint keys.
        state = {"model": policy.state_dict(), "obs_rms": train_envs.get_obs_rms()}
        torch.save(state, os.path.join(log_path, "policy.pth"))

    if not args.watch:
        # trainer
        result = OnpolicyTrainer(
            policy=policy,
            train_collector=train_collector,
            test_collector=test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            repeat_per_collect=args.repeat_per_collect,
            episode_per_test=args.test_num,
            batch_size=args.batch_size,
            step_per_collect=args.step_per_collect,
            save_best_fn=save_best_fn,
            logger=logger,
            test_in_train=False,
        ).run()
        pprint.pprint(result)

    # Let's watch its performance!
    # NOTE(review): with --watch and no --resume-path this evaluates the freshly
    # initialized (untrained) policy.
    policy.eval()
    test_envs.seed(args.seed)
    test_collector.reset()
    collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
    print(collector_stats)
if __name__ == "__main__":
    # Script entry point: parse the command line explicitly and hand the
    # resulting namespace to the training/evaluation routine.
    test_ppo(get_args())