Tianshou/examples/mujoco/mujoco_reinforce_hl.py
maxhuettenrauch 5fe9aea798
Update and fix dependencies related to mac install (#1044)
Addresses part of #1015 

### Dependencies

- move jsonargparse and docstring-parser to core dependencies so the high-level
examples run without the dev extras (see the sketch after this list)
- create a mujoco-py extra for the legacy MuJoCo envs
- update the atari extra
    - remove the atari-py and gym dependencies
    - add ALE-py, autorom, and shimmy
- create a robotics extra for HER-DDPG
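
The high-level entry points build their CLI from the `main` function signature via jsonargparse (with docstring-parser supplying per-parameter help), which is why both packages now need to be importable outside the dev group. A minimal standalone sketch of that mechanism (the `train` function and its parameters are made up for illustration, not taken from the Tianshou codebase):

```python
# Illustrative sketch, not code from the Tianshou repo: jsonargparse derives a CLI
# from a function signature and pulls per-parameter help from the docstring via
# docstring-parser.
from jsonargparse import CLI


def train(task: str = "Ant-v4", epoch: int = 100) -> None:
    """Toy entry point.

    :param task: Gymnasium task name.
    :param epoch: number of training epochs.
    """
    print(f"would train on {task} for {epoch} epochs")


if __name__ == "__main__":
    CLI(train)  # exposes --task and --epoch, e.g. `--task HalfCheetah-v4 --epoch 10`
```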

### Mac specific

- only install envpool when not on macOS (runtime-guard sketch after this list)
- mujoco-py does not work on macOS newer than Monterey
(https://github.com/openai/mujoco-py/issues/777)
- D4RL also fails due to its dependency on mujoco-py
(https://github.com/Farama-Foundation/D4RL/issues/232)
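
The envpool exclusion itself is an install-time platform marker; as a purely illustrative runtime-side counterpart (the guard below is an assumption, not code from this PR), example scripts can tolerate envpool being absent on macOS:

```python
# Illustrative assumption, not part of this PR: tolerate a missing envpool install
# (as on macOS) and let callers fall back to regular vectorized Gymnasium envs.
import sys

try:
    import envpool
except ImportError:  # envpool is not installed on macOS
    envpool = None

USE_ENVPOOL = envpool is not None and sys.platform != "darwin"
```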

### Other

- reduced training-num/test-num in the example files to at most 20
(runs with 100 environments hit the "too many open files" limit)
- rendering for MuJoCo envs still needs to be fixed on the Gymnasium side
(https://github.com/Farama-Foundation/Gymnasium/issues/749)

---------

Co-authored-by: Maximilian Huettenrauch <m.huettenrauch@appliedai.de>
Co-authored-by: Michael Panchenko <35432522+MischaPanch@users.noreply.github.com>
2024-02-06 17:06:38 +01:00


#!/usr/bin/env python3

import functools
import os
from collections.abc import Sequence
from typing import Literal

import torch

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.config import SamplingConfig
from tianshou.highlevel.experiment import (
    ExperimentConfig,
    PGExperimentBuilder,
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PGParams
from tianshou.utils import logging
from tianshou.utils.logging import datetime_tag


def main(
    experiment_config: ExperimentConfig,
    task: str = "Ant-v4",
    buffer_size: int = 4096,
    hidden_sizes: Sequence[int] = (64, 64),
    lr: float = 1e-3,
    gamma: float = 0.99,
    epoch: int = 100,
    step_per_epoch: int = 30000,
    step_per_collect: int = 2048,
    repeat_per_collect: int = 1,
    batch_size: int | None = None,
    training_num: int = 10,
    test_num: int = 10,
    rew_norm: bool = True,
    action_bound_method: Literal["clip", "tanh"] = "tanh",
    lr_decay: bool = True,
) -> None:
    # Log directory: <task>/reinforce/<seed>/<timestamp>
    log_name = os.path.join(task, "reinforce", str(experiment_config.seed), datetime_tag())

    sampling_config = SamplingConfig(
        num_epochs=epoch,
        step_per_epoch=step_per_epoch,
        batch_size=batch_size,
        num_train_envs=training_num,
        num_test_envs=test_num,
        buffer_size=buffer_size,
        step_per_collect=step_per_collect,
        repeat_per_collect=repeat_per_collect,
    )

    # Environment factory with observation normalization enabled.
    env_factory = MujocoEnvFactory(task, experiment_config.seed, obs_norm=True)

    experiment = (
        PGExperimentBuilder(env_factory, experiment_config, sampling_config)
        .with_pg_params(
            PGParams(
                discount_factor=gamma,
                action_bound_method=action_bound_method,
                reward_normalization=rew_norm,
                lr=lr,
                # Optionally decay the learning rate linearly over the training run.
                lr_scheduler_factory=LRSchedulerFactoryLinear(sampling_config)
                if lr_decay
                else None,
            ),
        )
        # Default continuous actor with tanh activations and unbounded outputs;
        # actions are bounded via `action_bound_method` above instead.
        .with_actor_factory_default(hidden_sizes, torch.nn.Tanh, continuous_unbounded=True)
        .build()
    )

    experiment.run(log_name)


if __name__ == "__main__":
    # Expose main()'s keyword arguments as CLI flags; experiment_config keeps its defaults.
    run_with_default_config = functools.partial(main, experiment_config=ExperimentConfig())
    logging.run_cli(run_with_default_config)
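
The script above is normally launched from the command line (`run_cli` turns the keyword arguments of `main` into flags); it can also be driven programmatically. A usage sketch with illustrative values (not the example's defaults):

```python
# Illustrative smoke-run invocation; parameter values are examples, not defaults.
from examples.mujoco.mujoco_reinforce_hl import main
from tianshou.highlevel.experiment import ExperimentConfig

main(
    ExperimentConfig(seed=42),
    task="HalfCheetah-v4",
    epoch=1,
    step_per_epoch=2048,
    training_num=4,
    test_num=4,
)
```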