Tianshou/examples/box2d/bipedal_hardcore_sac.py
maxhuettenrauch 5fe9aea798
Update and fix dependencies related to mac install (#1044)
Addresses part of #1015 

### Dependencies

- move jsonargparse and docstring-parser to the main dependencies so that the
high-level (hl) examples can run without the dev dependencies
- create mujoco-py extra for legacy mujoco envs
- updated atari extra
    - removed atari-py and gym dependencies
    - added ALE-py, autorom, and shimmy
- created robotics extra for HER-DDPG

### Mac specific

- only install envpool when not on mac
- mujoco-py not working on macOS newer than Monterey
(https://github.com/openai/mujoco-py/issues/777)
- D4RL also fails due to dependency on mujoco-py
(https://github.com/Farama-Foundation/D4RL/issues/232)

### Other

- reduced training-num/test-num in example files to a number ≤ 20
(examples with 100 led to too many open files)
- rendering for Mujoco envs needs to be fixed on gymnasium side
(https://github.com/Farama-Foundation/Gymnasium/issues/749)

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
2024-02-06 17:06:38 +01:00

214 lines
7.4 KiB
Python

import argparse
import os
import pprint
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from gymnasium.core import WrapperActType, WrapperObsType
from torch.utils.tensorboard import SummaryWriter
from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import SubprocVectorEnv
from tianshou.policy import SACPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import Net
from tianshou.utils.net.continuous import ActorProb, Critic
from tianshou.utils.space_info import SpaceInfo
def get_args() -> argparse.Namespace:
    """Parse the command-line options of the BipedalWalkerHardcore SAC example.

    Returns:
        The namespace produced by ``parse_args()``. Every option falls back to
        the default given in the table below when it is not supplied.
    """
    # Resolve the device default once, before building the option table.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, keyword arguments for ``add_argument``), in registration order.
    option_table = (
        ("--task", {"type": str, "default": "BipedalWalkerHardcore-v3"}),
        ("--seed", {"type": int, "default": 0}),
        ("--buffer-size", {"type": int, "default": 1000000}),
        ("--actor-lr", {"type": float, "default": 3e-4}),
        ("--critic-lr", {"type": float, "default": 1e-3}),
        ("--gamma", {"type": float, "default": 0.99}),
        ("--tau", {"type": float, "default": 0.005}),
        ("--alpha", {"type": float, "default": 0.1}),
        ("--auto-alpha", {"type": int, "default": 1}),
        ("--alpha-lr", {"type": float, "default": 3e-4}),
        ("--epoch", {"type": int, "default": 100}),
        ("--step-per-epoch", {"type": int, "default": 100000}),
        ("--step-per-collect", {"type": int, "default": 10}),
        ("--update-per-step", {"type": float, "default": 0.1}),
        ("--batch-size", {"type": int, "default": 128}),
        ("--hidden-sizes", {"type": int, "nargs": "*", "default": [128, 128]}),
        ("--training-num", {"type": int, "default": 10}),
        ("--test-num", {"type": int, "default": 10}),
        ("--logdir", {"type": str, "default": "log"}),
        ("--render", {"type": float, "default": 0.0}),
        ("--n-step", {"type": int, "default": 4}),
        ("--device", {"type": str, "default": default_device}),
        ("--resume-path", {"type": str, "default": None}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
class Wrapper(gym.Wrapper):
    """Env wrapper for reward scale, action repeat and removing done penalty.

    Each call to :meth:`step` applies the same action to the wrapped env up to
    ``action_repeat`` times, sums the per-step rewards and multiplies the sum
    by ``reward_scale``. When ``rm_done`` is true, the reward of the step on
    which the episode ends (terminated or truncated) is left out of the sum.
    """

    def __init__(
        self,
        env: gym.Env,
        action_repeat: int = 3,
        reward_scale: int = 5,
        rm_done: bool = True,
    ) -> None:
        """Store the wrapper settings.

        Args:
            env: The environment to wrap.
            action_repeat: Maximum number of times one action is applied per
                ``step`` call; must be at least 1.
            reward_scale: Factor applied to the summed reward.
            rm_done: Whether to drop the reward of the episode-ending step.

        Raises:
            ValueError: If ``action_repeat`` is smaller than 1.
        """
        # With fewer than one repeat the loop in ``step`` would never call the
        # wrapped env, leaving nothing to return; reject that up front.
        if action_repeat < 1:
            raise ValueError(f"action_repeat must be >= 1, got {action_repeat}")
        super().__init__(env)
        self.action_repeat = action_repeat
        self.reward_scale = reward_scale
        self.rm_done = rm_done

    def step(
        self,
        action: WrapperActType,
    ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
        """Repeat ``action`` and return the last transition with the scaled reward sum.

        Returns:
            The observation, termination flag, truncation flag and info of the
            last inner step, together with ``reward_scale`` times the
            accumulated reward.
        """
        rew_sum = 0.0
        for _ in range(self.action_repeat):
            obs, rew, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            # remove done reward penalty
            if not (done and self.rm_done):
                rew_sum += float(rew)
            # Stop repeating as soon as the episode has ended.
            if done:
                break
        # scale reward
        return obs, self.reward_scale * rew_sum, terminated, truncated, info
def test_sac_bipedal(args: argparse.Namespace = get_args()) -> None:
    """Train a SAC agent on the configured task and optionally evaluate it.

    The function builds wrapped training and test environments, seeds them,
    creates the actor and the two critics, assembles a ``SACPolicy``, wires up
    collectors and a TensorBoard logger, and runs ``OffpolicyTrainer``. When
    the module is executed as a script, the training result is printed and the
    policy is evaluated once more on the test environments.

    Args:
        args: Parsed options as returned by ``get_args``. Note that the default
            is evaluated once, when the function is defined, so the command
            line is parsed at import time. The namespace is extended in place
            with ``state_shape``, ``action_shape`` and ``max_action``; with
            ``auto_alpha`` enabled, ``alpha`` is replaced by a
            ``(target_entropy, log_alpha, alpha_optim)`` tuple.
    """
    # A single env instance is only used to read space info and the spec.
    env = Wrapper(gym.make(args.task))
    space_info = SpaceInfo.from_env(env)
    args.state_shape = space_info.observation_info.obs_shape
    args.action_shape = space_info.action_info.action_shape
    args.max_action = space_info.action_info.max_action

    train_envs = SubprocVectorEnv(
        [lambda: Wrapper(gym.make(args.task)) for _ in range(args.training_num)],
    )
    # Test envs report the unscaled reward and keep the final-step reward.
    test_envs = SubprocVectorEnv(
        [
            lambda: Wrapper(gym.make(args.task), reward_scale=1, rm_done=False)
            for _ in range(args.test_num)
        ],
    )

    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)

    # model
    net_a = Net(args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device)
    actor = ActorProb(net_a, args.action_shape, device=args.device, unbounded=True).to(args.device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)

    net_c1 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic1 = Critic(net_c1, device=args.device).to(args.device)
    critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)

    net_c2 = Net(
        args.state_shape,
        args.action_shape,
        hidden_sizes=args.hidden_sizes,
        concat=True,
        device=args.device,
    )
    critic2 = Critic(net_c2, device=args.device).to(args.device)
    critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

    action_dim = space_info.action_info.action_dim
    if args.auto_alpha:
        # Learn alpha: hand the policy the target entropy, the trainable
        # log-alpha tensor and its optimizer instead of a fixed float.
        target_entropy = -action_dim
        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
        args.alpha = (target_entropy, log_alpha, alpha_optim)

    policy: SACPolicy = SACPolicy(
        actor=actor,
        actor_optim=actor_optim,
        critic=critic1,
        critic_optim=critic1_optim,
        critic2=critic2,
        critic2_optim=critic2_optim,
        tau=args.tau,
        gamma=args.gamma,
        alpha=args.alpha,
        estimation_step=args.n_step,
        action_space=env.action_space,
    )

    # load a previous policy
    if args.resume_path:
        # map_location lets a checkpoint saved on another device (e.g. CUDA)
        # be restored onto the device selected for this run.
        policy.load_state_dict(torch.load(args.resume_path, map_location=args.device))
        print("Loaded agent from: ", args.resume_path)

    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs)

    # log
    log_path = os.path.join(args.logdir, args.task, "sac")
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)

    def save_best_fn(policy: BasePolicy) -> None:
        """Write the policy's state dict to ``policy.pth`` in the log directory."""
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        """Return True once the mean reward reaches the env spec's threshold.

        Always False when the env has no spec or no (truthy) reward threshold.
        """
        if not env.spec or not env.spec.reward_threshold:
            return False
        return mean_rewards >= env.spec.reward_threshold

    # trainer
    result = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        update_per_step=args.update_per_step,
        test_in_train=False,
        stop_fn=stop_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    ).run()

    if __name__ == "__main__":
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        test_envs.seed(args.seed)
        test_collector.reset()
        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
        print(collector_stats)
# Script entry point: run the SAC example when this file is executed directly
# rather than imported.
if __name__ == "__main__":
    test_sac_bipedal()