Tianshou/examples/box2d/bipedal_bdq.py
maxhuettenrauch 5fe9aea798
Update and fix dependencies related to mac install (#1044)
Addresses part of #1015 

### Dependencies

- moved jsonargparse and docstring-parser to the core dependencies so the
high-level examples run without the dev extras
- created a mujoco-py extra for legacy MuJoCo envs
- updated the atari extra
    - removed the atari-py and gym dependencies
    - added ale-py, autorom, and shimmy
- created a robotics extra for HER-DDPG (see the pyproject sketch below)
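
Roughly, the new extras are declared as optional dependencies in `pyproject.toml`. A minimal sketch in Poetry syntax — the version pins and exact member lists here are illustrative, not the actual ones:

```toml
[tool.poetry.dependencies]
# optional packages, pulled in only via the extras below (illustrative pins)
mujoco-py = { version = ">=2.1,<2.2", optional = true }
ale-py = { version = "*", optional = true }
autorom = { version = "*", optional = true, extras = ["accept-rom-license"] }
shimmy = { version = "*", optional = true }
gymnasium-robotics = { version = "*", optional = true }

[tool.poetry.extras]
mujoco-py = ["mujoco-py"]
atari = ["ale-py", "autorom", "shimmy"]
robotics = ["gymnasium-robotics"]
```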

### Mac specific

- envpool is only installed when not on macOS (see the marker sketch below)
- mujoco-py does not work on macOS newer than Monterey
(https://github.com/openai/mujoco-py/issues/777)
- D4RL also fails on macOS due to its dependency on mujoco-py
(https://github.com/Farama-Foundation/D4RL/issues/232)
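
Platform-conditional installs like the envpool one can be expressed with a PEP 508 environment marker; a minimal sketch in Poetry syntax (the version pin is illustrative):

```toml
[tool.poetry.dependencies]
# skip envpool on macOS, where recent wheels are unavailable
envpool = { version = ">=0.8.0", optional = true, markers = "sys_platform != 'darwin'" }
```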

### Other

- reduced training-num/test-num in the example files to at most 20
(examples using 100 parallel envs ran into "too many open files" errors)
- rendering for MuJoCo envs still needs to be fixed on the Gymnasium side
(https://github.com/Farama-Foundation/Gymnasium/issues/749)

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
2024-02-06 17:06:38 +01:00


import argparse
import datetime
import os
import pprint

import gymnasium as gym
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ContinuousToDiscrete, SubprocVectorEnv
from tianshou.policy import BranchingDQNPolicy
from tianshou.policy.base import BasePolicy
from tianshou.trainer import OffpolicyTrainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import BranchingNet


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    # task
    parser.add_argument("--task", type=str, default="BipedalWalker-v3")
    # network architecture
    parser.add_argument("--common-hidden-sizes", type=int, nargs="*", default=[512, 256])
    parser.add_argument("--action-hidden-sizes", type=int, nargs="*", default=[128])
    parser.add_argument("--value-hidden-sizes", type=int, nargs="*", default=[128])
    parser.add_argument("--action-per-branch", type=int, default=25)
    # training hyperparameters
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--eps-test", type=float, default=0.0)
    parser.add_argument("--eps-train", type=float, default=0.73)
    parser.add_argument("--eps-decay", type=float, default=5e-6)
    parser.add_argument("--buffer-size", type=int, default=100000)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--target-update-freq", type=int, default=1000)
    parser.add_argument("--epoch", type=int, default=1000)
    parser.add_argument("--step-per-epoch", type=int, default=80000)
    parser.add_argument("--step-per-collect", type=int, default=16)
    parser.add_argument("--update-per-step", type=float, default=0.0625)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--training-num", type=int, default=20)
    parser.add_argument("--test-num", type=int, default=10)
    # other
    parser.add_argument("--logdir", type=str, default="log")
    parser.add_argument("--render", type=float, default=0.0)
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_args()


def test_bdq(args: argparse.Namespace = get_args()) -> None:
    env = gym.make(args.task)
    env = ContinuousToDiscrete(env, args.action_per_branch)
    args.state_shape = env.observation_space.shape or env.observation_space.n
    args.action_shape = env.action_space.shape or env.action_space.n
    args.num_branches = (
        args.action_shape if isinstance(args.action_shape, int) else args.action_shape[0]
    )
    print("Observations shape:", args.state_shape)
    print("Num branches:", args.num_branches)
    print("Actions per branch:", args.action_per_branch)
    # a single (non-vectorized) env would also work, e.g.:
    # train_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
    # here SubprocVectorEnv collects from args.training_num envs in parallel
    train_envs = SubprocVectorEnv(
        [
            lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
            for _ in range(args.training_num)
        ],
    )
    # test_envs = ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
    test_envs = SubprocVectorEnv(
        [
            lambda: ContinuousToDiscrete(gym.make(args.task), args.action_per_branch)
            for _ in range(args.test_num)
        ],
    )
    # seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_envs.seed(args.seed)
    test_envs.seed(args.seed)
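    # The BranchingNet below follows the BDQ paper's dueling decomposition: a shared
    # trunk (common_hidden_sizes) feeds one state-value head (value_hidden_sizes) and
    # one advantage head per action branch (action_hidden_sizes).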
    # model
    net = BranchingNet(
        args.state_shape,
        args.num_branches,
        args.action_per_branch,
        args.common_hidden_sizes,
        args.value_hidden_sizes,
        args.action_hidden_sizes,
        device=args.device,
    ).to(args.device)
    optim = torch.optim.Adam(net.parameters(), lr=args.lr)
    policy = BranchingDQNPolicy(
        model=net,
        optim=optim,
        discount_factor=args.gamma,
        action_space=env.action_space,
        target_update_freq=args.target_update_freq,
    )
    # collector
    train_collector = Collector(
        policy,
        train_envs,
        VectorReplayBuffer(args.buffer_size, len(train_envs)),
        exploration_noise=True,
    )
    test_collector = Collector(policy, test_envs, exploration_noise=False)
    # policy.set_eps(1)
    train_collector.collect(n_step=args.batch_size * args.training_num)
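    # the initial collect above pre-fills the replay buffer with transitions gathered
    # by the untrained policy, so the first gradient updates have data to sample from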
    # log
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = os.path.join(args.logdir, "bdq", args.task, current_time)
    writer = SummaryWriter(log_path)
    logger = TensorboardLogger(writer)
    def save_best_fn(policy: BasePolicy) -> None:
        torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth"))

    def stop_fn(mean_rewards: float) -> bool:
        # reward_threshold is 300 for BipedalWalker-v3
        return mean_rewards >= env.spec.reward_threshold

    def train_fn(epoch: int, env_step: int) -> None:  # exp decay
        # with eps_decay=5e-6, epsilon halves roughly every 139k env steps
        eps = max(args.eps_train * (1 - args.eps_decay) ** env_step, args.eps_test)
        policy.set_eps(eps)

    def test_fn(epoch: int, env_step: int | None) -> None:
        policy.set_eps(args.eps_test)
    # trainer
    result = OffpolicyTrainer(
        policy=policy,
        train_collector=train_collector,
        test_collector=test_collector,
        max_epoch=args.epoch,
        step_per_epoch=args.step_per_epoch,
        step_per_collect=args.step_per_collect,
        episode_per_test=args.test_num,
        batch_size=args.batch_size,
        update_per_step=args.update_per_step,
        # stop_fn=stop_fn,
        train_fn=train_fn,
        test_fn=test_fn,
        save_best_fn=save_best_fn,
        logger=logger,
    ).run()

    # assert stop_fn(result.best_reward)
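    # note: this __main__ guard sits inside the function on purpose; running the file
    # as a script prints stats and watches the trained policy, while importing
    # test_bdq (e.g. from a test suite) skips the rendering part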
    if __name__ == "__main__":
        pprint.pprint(result)
        # Let's watch its performance!
        policy.eval()
        policy.set_eps(args.eps_test)
        test_envs.seed(args.seed)
        test_collector.reset()
        collector_stats = test_collector.collect(n_episode=args.test_num, render=args.render)
        print(collector_stats)


if __name__ == "__main__":
    test_bdq(get_args())