Tianshou/examples/mujoco/fetch_her_ddpg.py
Daniel Plop eb0215cf76
Refactoring/mypy issues test (#1017)
Improves typing in examples and tests, towards mypy passing there.

Introduces the SpaceInfo utility
2024-02-06 14:24:30 +01:00

236 lines
8.5 KiB
Python

#!/usr/bin/env python3
# isort: skip_file
import argparse
import datetime
import os
import pprint
import gymnasium as gym
import numpy as np
import torch
import wandb
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import (
Collector,
HERReplayBuffer,
HERVectorReplayBuffer,
ReplayBuffer,
VectorReplayBuffer,
)
from tianshou.env import ShmemVectorEnv, TruncatedAsTerminated
from tianshou.exploration import GaussianNoise
from tianshou.policy import DDPGPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.net.common import Net, get_dict_state_decorator
from tianshou.utils.net.continuous import Actor, Critic
def get_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, default="FetchReach-v3")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--buffer-size", type=int, default=100000)
parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[256, 256])
parser.add_argument("--actor-lr", type=float, default=1e-3)
parser.add_argument("--critic-lr", type=float, default=3e-3)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--tau", type=float, default=0.005)
parser.add_argument("--exploration-noise", type=float, default=0.1)
parser.add_argument("--start-timesteps", type=int, default=25000)
parser.add_argument("--epoch", type=int, default=10)
parser.add_argument("--step-per-epoch", type=int, default=5000)
parser.add_argument("--step-per-collect", type=int, default=1)
parser.add_argument("--update-per-step", type=int, default=1)
parser.add_argument("--n-step", type=int, default=1)
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument("--replay-buffer", type=str, default="her", choices=["normal", "her"])
parser.add_argument("--her-horizon", type=int, default=50)
parser.add_argument("--her-future-k", type=int, default=8)
parser.add_argument("--training-num", type=int, default=1)
parser.add_argument("--test-num", type=int, default=10)
parser.add_argument("--logdir", type=str, default="log")
parser.add_argument("--render", type=float, default=0.0)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
)
parser.add_argument("--resume-path", type=str, default=None)
parser.add_argument("--resume-id", type=str, default=None)
parser.add_argument(
"--logger",
type=str,
default="tensorboard",
choices=["tensorboard", "wandb"],
)
parser.add_argument("--wandb-project", type=str, default="HER-benchmark")
parser.add_argument(
"--watch",
default=False,
action="store_true",
help="watch the play of pre-trained policy only",
)
return parser.parse_args()
def make_fetch_env(task, training_num, test_num):
env = TruncatedAsTerminated(gym.make(task))
train_envs = ShmemVectorEnv(
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(training_num)],
)
test_envs = ShmemVectorEnv(
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)],
)
return env, train_envs, test_envs
def test_ddpg(args: argparse.Namespace = get_args()) -> None:
# log
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
args.algo_name = "ddpg"
log_name = os.path.join(args.task, args.algo_name, str(args.seed), now)
log_path = os.path.join(args.logdir, log_name)
# logger
if args.logger == "wandb":
logger = WandbLogger(
save_interval=1,
name=log_name.replace(os.path.sep, "__"),
run_id=args.resume_id,
config=args,
project=args.wandb_project,
)
logger.wandb_run.config.setdefaults(vars(args))
args = argparse.Namespace(**wandb.config)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
if args.logger == "tensorboard":
logger = TensorboardLogger(writer)
else: # wandb
logger.load(writer)
env, train_envs, test_envs = make_fetch_env(args.task, args.training_num, args.test_num)
args.state_shape = {
"observation": env.observation_space["observation"].shape,
"achieved_goal": env.observation_space["achieved_goal"].shape,
"desired_goal": env.observation_space["desired_goal"].shape,
}
args.action_shape = env.action_space.shape or env.action_space.n
args.max_action = env.action_space.high[0]
args.exploration_noise = args.exploration_noise * args.max_action
print("Observations shape:", args.state_shape)
print("Actions shape:", args.action_shape)
print("Action range:", np.min(env.action_space.low), np.max(env.action_space.high))
# seed
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# model
dict_state_dec, flat_state_shape = get_dict_state_decorator(
state_shape=args.state_shape,
keys=["observation", "achieved_goal", "desired_goal"],
)
net_a = dict_state_dec(Net)(
flat_state_shape,
hidden_sizes=args.hidden_sizes,
device=args.device,
)
actor = dict_state_dec(Actor)(
net_a,
args.action_shape,
max_action=args.max_action,
device=args.device,
).to(args.device)
actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
net_c = dict_state_dec(Net)(
flat_state_shape,
action_shape=args.action_shape,
hidden_sizes=args.hidden_sizes,
concat=True,
device=args.device,
)
critic = dict_state_dec(Critic)(net_c, device=args.device).to(args.device)
critic_optim = torch.optim.Adam(critic.parameters(), lr=args.critic_lr)
policy: DDPGPolicy = DDPGPolicy(
actor=actor,
actor_optim=actor_optim,
critic=critic,
critic_optim=critic_optim,
tau=args.tau,
gamma=args.gamma,
exploration_noise=GaussianNoise(sigma=args.exploration_noise),
estimation_step=args.n_step,
action_space=env.action_space,
)
# load a previous policy
if args.resume_path:
policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
print("Loaded agent from: ", args.resume_path)
# collector
def compute_reward_fn(ag: np.ndarray, g: np.ndarray):
return env.compute_reward(ag, g, {})
buffer: VectorReplayBuffer | ReplayBuffer | HERReplayBuffer | HERVectorReplayBuffer
if args.replay_buffer == "normal":
if args.training_num > 1:
buffer = VectorReplayBuffer(args.buffer_size, len(train_envs))
else:
buffer = ReplayBuffer(args.buffer_size)
else:
if args.training_num > 1:
buffer = HERVectorReplayBuffer(
args.buffer_size,
len(train_envs),
compute_reward_fn=compute_reward_fn,
horizon=args.her_horizon,
future_k=args.her_future_k,
)
else:
buffer = HERReplayBuffer(
args.buffer_size,
compute_reward_fn=compute_reward_fn,
horizon=args.her_horizon,
future_k=args.her_future_k,
)
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs)
train_collector.collect(n_step=args.start_timesteps, random=True)
def save_best_fn(policy: BasePolicy) -> None:
torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))
if not args.watch:
# trainer
result = OffpolicyTrainer(
policy=policy,
train_collector=train_collector,
test_collector=test_collector,
max_epoch=args.epoch,
step_per_epoch=args.step_per_epoch,
step_per_collect=args.step_per_collect,
episode_per_test=args.test_num,
batch_size=args.batch_size,
save_best_fn=save_best_fn,
logger=logger,
update_per_step=args.update_per_step,
test_in_train=False,
).run()
pprint.pprint(result)
# Let's watch its performance!
policy.eval()
test_envs.seed(args.seed)
test_collector.reset()
collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
print(collector_stats)
if __name__ == "__main__":
test_ddpg()