diff --git a/examples/mujoco/mujoco_env.py b/examples/mujoco/mujoco_env.py index 7e3c38c..15d9393 100644 --- a/examples/mujoco/mujoco_env.py +++ b/examples/mujoco/mujoco_env.py @@ -2,9 +2,9 @@ import warnings import gymnasium as gym -from tianshou.config import RLSamplingConfig, BasicExperimentConfig +from tianshou.config import BasicExperimentConfig, RLSamplingConfig from tianshou.env import ShmemVectorEnv, VectorEnvNormObs -from tianshou.highlevel.env import EnvFactory, Environments, ContinuousEnvironments +from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory try: import envpool @@ -12,9 +12,7 @@ except ImportError: envpool = None -def make_mujoco_env( - task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool -): +def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: int, obs_norm: bool): """Wrapper function for Mujoco env. If EnvPool is installed, it will automatically switch to EnvPool's Mujoco env. diff --git a/examples/mujoco/mujoco_ppo_cfg.py b/examples/mujoco/mujoco_ppo_cfg.py deleted file mode 100644 index bd4628a..0000000 --- a/examples/mujoco/mujoco_ppo_cfg.py +++ /dev/null @@ -1,359 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import datetime -import os -import pprint -from collections.abc import Sequence -from typing import Literal, Optional, Tuple, Union - -import gymnasium as gym -import numpy as np -import torch -from jsonargparse import CLI -from torch import nn -from torch.distributions import Independent, Normal -from torch.optim.lr_scheduler import LambdaLR -from torch.utils.tensorboard import SummaryWriter - -from mujoco_env import make_mujoco_env -from tianshou.config import ( - BasicExperimentConfig, - LoggerConfig, - NNConfig, - PGConfig, - PPOConfig, - RLAgentConfig, - RLSamplingConfig, -) -from tianshou.config.utils import collect_configs -from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer -from tianshou.env import VectorEnvNormObs -from tianshou.policy import BasePolicy, PPOPolicy -from tianshou.trainer import OnpolicyTrainer -from tianshou.utils import TensorboardLogger, WandbLogger -from tianshou.utils.net.common import ActorCritic, Net -from tianshou.utils.net.continuous import ActorProb, Critic - - -def set_seed(seed=42): - np.random.seed(seed) - torch.manual_seed(seed) - - -def get_logger_for_run( - algo_name: str, - task: str, - logger_config: LoggerConfig, - config: dict, - seed: int, - resume_id: Optional[Union[str, int]], -) -> Tuple[str, Union[WandbLogger, TensorboardLogger]]: - """ - - :param algo_name: - :param task: - :param logger_config: - :param config: the experiment config - :param seed: - :param resume_id: used as run_id by wandb, unused for tensorboard - :return: - """ - """Returns the log_path and logger.""" - now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") - log_name = os.path.join(task, algo_name, str(seed), now) - log_path = os.path.join(logger_config.logdir, log_name) - - logger = get_logger( - logger_config.logger, - log_path, - log_name=log_name, - run_id=resume_id, - config=config, - wandb_project=logger_config.wandb_project, - ) - return log_path, logger - - -def get_continuous_env_info( - env: gym.Env, -) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]: - if not isinstance(env.action_space, gym.spaces.Box): - raise ValueError( - "Only environments with continuous action space are supported here. " - f"But got env with action space: {env.action_space.__class__}." 
- ) - state_shape = env.observation_space.shape or env.observation_space.n - if not state_shape: - raise ValueError("Observation space shape is not defined") - action_shape = env.action_space.shape - max_action = env.action_space.high[0] - return state_shape, action_shape, max_action - - -def resume_from_checkpoint( - path: str, - policy: BasePolicy, - train_envs: VectorEnvNormObs | None = None, - test_envs: VectorEnvNormObs | None = None, - device: str | int | torch.device | None = None, -): - ckpt = torch.load(path, map_location=device) - policy.load_state_dict(ckpt["model"]) - if train_envs: - train_envs.set_obs_rms(ckpt["obs_rms"]) - if test_envs: - test_envs.set_obs_rms(ckpt["obs_rms"]) - print("Loaded agent and obs. running means from: ", path) - - -def watch_agent(n_episode, policy: BasePolicy, test_collector: Collector, render=0.0): - policy.eval() - test_collector.reset() - result = test_collector.collect(n_episode=n_episode, render=render) - print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') - - -def get_train_test_collector( - buffer_size: int, - policy: BasePolicy, - train_envs: VectorEnvNormObs, - test_envs: VectorEnvNormObs, -): - if len(train_envs) > 1: - buffer = VectorReplayBuffer(buffer_size, len(train_envs)) - else: - buffer = ReplayBuffer(buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, test_envs) - return test_collector, train_collector - - -TShape = Union[int, Sequence[int]] - - -def get_actor_critic( - state_shape: TShape, - hidden_sizes: Sequence[int], - action_shape: TShape, - device: str | int | torch.device = "cpu", -): - net_a = Net( - state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device - ) - actor = ActorProb(net_a, action_shape, unbounded=True, device=device).to(device) - net_c = Net( - state_shape, hidden_sizes=hidden_sizes, activation=nn.Tanh, device=device - ) - # TODO: twice device? - critic = Critic(net_c, device=device).to(device) - return actor, critic - - -def get_logger( - kind: Literal["wandb", "tensorboard"], - log_path: str, - log_name="", - run_id: Optional[Union[str, int]] = None, - config: Optional[Union[dict, argparse.Namespace]] = None, - wandb_project: Optional[str] = None, -): - writer = SummaryWriter(log_path) - writer.add_text("args", str(config)) - if kind == "wandb": - logger = WandbLogger( - save_interval=1, - name=log_name.replace(os.path.sep, "__"), - run_id=run_id, - config=config, - project=wandb_project, - ) - logger.load(writer) - elif kind == "tensorboard": - logger = TensorboardLogger(writer) - else: - raise ValueError(f"Unknown logger: {kind}") - return logger - - -def get_lr_scheduler(optim, step_per_epoch: int, step_per_collect: int, epochs: int): - """Decay learning rate to 0 linearly.""" - max_update_num = np.ceil(step_per_epoch / step_per_collect) * epochs - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - return lr_scheduler - - -def init_and_get_optim(actor: nn.Module, critic: nn.Module, lr: float): - """Initializes layers of actor and critic. 
- - :param actor: - :param critic: - :param lr: - :return: - """ - actor_critic = ActorCritic(actor, critic) - torch.nn.init.constant_(actor.sigma_param, -0.5) - for m in actor_critic.modules(): - if isinstance(m, torch.nn.Linear): - # orthogonal initialization - torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) - torch.nn.init.zeros_(m.bias) - if hasattr(actor, "mu"): - # For continuous action spaces with Gaussian policies - # do last policy layer scaling, this will make initial actions have (close to) - # 0 mean and std, and will help boost performances, - # see https://arxiv.org/abs/2006.05990, Fig.24 for details - for m in actor.mu.modules(): - # TODO: seems like biases are initialized twice for the actor - if isinstance(m, torch.nn.Linear): - torch.nn.init.zeros_(m.bias) - m.weight.data.copy_(0.01 * m.weight.data) - optim = torch.optim.Adam(actor_critic.parameters(), lr=lr) - return optim - - -def main( - experiment_config: BasicExperimentConfig, - logger_config: LoggerConfig, - sampling_config: RLSamplingConfig, - general_config: RLAgentConfig, - pg_config: PGConfig, - ppo_config: PPOConfig, - nn_config: NNConfig, -): - """ - Run the PPO test on the provided parameters. - - :param experiment_config: BasicExperimentConfig - not ML or RL specific - :param logger_config: LoggerConfig - :param sampling_config: SamplingConfig - - sampling, epochs, parallelization, buffers, collectors, and batching. - :param general_config: RLAgentConfig - general RL agent config - :param pg_config: PGConfig: common to most policy gradient algorithms - :param ppo_config: PPOConfig - PPO specific config - :param nn_config: NNConfig - NN-training specific config - - :return: None - """ - full_config = collect_configs(*locals().values()) - set_seed(experiment_config.seed) - - # create test and train envs, add env info to config - env, train_envs, test_envs = make_mujoco_env( - task=experiment_config.task, - seed=experiment_config.seed, - num_train_envs=sampling_config.num_train_envs, - num_test_envs=sampling_config.num_test_envs, - obs_norm=True, - ) - - # adding env_info to logged config - state_shape, action_shape, max_action = get_continuous_env_info(env) - full_config["env_info"] = { - "state_shape": state_shape, - "action_shape": action_shape, - "max_action": max_action, - } - log_path, logger = get_logger_for_run( - "ppo", - experiment_config.task, - logger_config, - full_config, - experiment_config.seed, - experiment_config.resume_id, - ) - - # Setup NNs - actor, critic = get_actor_critic( - state_shape, nn_config.hidden_sizes, action_shape, experiment_config.device - ) - optim = init_and_get_optim(actor, critic, nn_config.lr) - - lr_scheduler = None - if nn_config.lr_decay: - lr_scheduler = get_lr_scheduler( - optim, - sampling_config.step_per_epoch, - sampling_config.step_per_collect, - sampling_config.num_epochs, - ) - - # Create policy - def dist_fn(*logits): - return Independent(Normal(*logits), 1) - - policy = PPOPolicy( - # nn-stuff - actor, - critic, - optim, - dist_fn=dist_fn, - lr_scheduler=lr_scheduler, - # env-stuff - action_space=train_envs.action_space, - action_scaling=True, - # general_config - discount_factor=general_config.gamma, - gae_lambda=general_config.gae_lambda, - reward_normalization=general_config.rew_norm, - action_bound_method=general_config.action_bound_method, - # pg_config - max_grad_norm=pg_config.max_grad_norm, - vf_coef=pg_config.vf_coef, - ent_coef=pg_config.ent_coef, - # ppo_config - eps_clip=ppo_config.eps_clip, - value_clip=ppo_config.value_clip, - 
dual_clip=ppo_config.dual_clip, - advantage_normalization=ppo_config.norm_adv, - recompute_advantage=ppo_config.recompute_adv, - ) - - if experiment_config.resume_path: - resume_from_checkpoint( - experiment_config.resume_path, - policy, - train_envs=train_envs, - test_envs=test_envs, - device=experiment_config.device, - ) - - test_collector, train_collector = get_train_test_collector( - sampling_config.buffer_size, policy, test_envs, train_envs - ) - - # TODO: test num is the number of test envs but used as episode_per_test - # here and in watch_agent - if not experiment_config.watch: - # RL training - def save_best_fn(pol: nn.Module): - state = {"model": pol.state_dict(), "obs_rms": train_envs.get_obs_rms()} - torch.save(state, os.path.join(log_path, "policy.pth")) - - trainer = OnpolicyTrainer( - policy=policy, - train_collector=train_collector, - test_collector=test_collector, - max_epoch=sampling_config.num_epochs, - step_per_epoch=sampling_config.step_per_epoch, - repeat_per_collect=sampling_config.repeat_per_collect, - episode_per_test=sampling_config.num_test_envs, - batch_size=sampling_config.batch_size, - step_per_collect=sampling_config.step_per_collect, - save_best_fn=save_best_fn, - logger=logger, - test_in_train=False, - ) - result = trainer.run() - pprint.pprint(result) - - watch_agent( - sampling_config.num_test_envs, - policy, - test_collector, - render=experiment_config.render, - ) - - -if __name__ == "__main__": - CLI(main) diff --git a/examples/mujoco/mujoco_ppo_hl.py b/examples/mujoco/mujoco_ppo_hl.py index 7356afe..7cba814 100644 --- a/examples/mujoco/mujoco_ppo_hl.py +++ b/examples/mujoco/mujoco_ppo_hl.py @@ -2,6 +2,8 @@ import datetime import os +from collections.abc import Sequence +from dataclasses import dataclass from jsonargparse import CLI from torch.distributions import Independent, Normal @@ -10,17 +12,26 @@ from examples.mujoco.mujoco_env import MujocoEnvFactory from tianshou.config import ( BasicExperimentConfig, LoggerConfig, - NNConfig, PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig, ) from tianshou.highlevel.agent import PPOAgentFactory -from tianshou.highlevel.logger import DefaultLoggerFactory -from tianshou.highlevel.module import ContinuousActorProbFactory, ContinuousNetCriticFactory -from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory from tianshou.highlevel.experiment import RLExperiment +from tianshou.highlevel.logger import DefaultLoggerFactory +from tianshou.highlevel.module import ( + ContinuousActorProbFactory, + ContinuousNetCriticFactory, +) +from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory + + +@dataclass +class NNConfig: + hidden_sizes: Sequence[int] = (64, 64) + lr: float = 3e-4 + lr_decay: bool = True def main( @@ -43,15 +54,22 @@ def main( actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes) critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes) - optim_factory = AdamOptimizerFactory(lr=nn_config.lr) - lr_scheduler_factory = LinearLRSchedulerFactory(nn_config, sampling_config) - agent_factory = PPOAgentFactory(general_config, pg_config, ppo_config, sampling_config, nn_config, - actor_factory, critic_factory, optim_factory, dist_fn, lr_scheduler_factory) + optim_factory = AdamOptimizerFactory() + lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None + agent_factory = PPOAgentFactory( + general_config, + pg_config, + ppo_config, + sampling_config, + actor_factory, + critic_factory, + 
optim_factory, + dist_fn, + nn_config.lr, + lr_scheduler_factory, + ) - experiment = RLExperiment(experiment_config, logger_config, general_config, sampling_config, - env_factory, - logger_factory, - agent_factory) + experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory) experiment.run(log_name) diff --git a/examples/mujoco/mujoco_sac_hl.py b/examples/mujoco/mujoco_sac_hl.py new file mode 100644 index 0000000..24a8f20 --- /dev/null +++ b/examples/mujoco/mujoco_sac_hl.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +import datetime +import os +from collections.abc import Sequence + +from jsonargparse import CLI + +from examples.mujoco.mujoco_env import MujocoEnvFactory +from tianshou.config import ( + BasicExperimentConfig, + LoggerConfig, + RLSamplingConfig, +) +from tianshou.highlevel.agent import SACAgentFactory +from tianshou.highlevel.experiment import RLExperiment +from tianshou.highlevel.logger import DefaultLoggerFactory +from tianshou.highlevel.module import ( + ContinuousActorProbFactory, + ContinuousNetCriticFactory, +) +from tianshou.highlevel.optim import AdamOptimizerFactory + + +def main( + experiment_config: BasicExperimentConfig, + logger_config: LoggerConfig, + sampling_config: RLSamplingConfig, + sac_config: SACAgentFactory.Config, + hidden_sizes: Sequence[int] = (256, 256), +): + now = datetime.datetime.now().strftime("%y%m%d-%H%M%S") + log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now) + logger_factory = DefaultLoggerFactory(logger_config) + + env_factory = MujocoEnvFactory(experiment_config, sampling_config) + + actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True) + critic_factory = ContinuousNetCriticFactory(hidden_sizes) + optim_factory = AdamOptimizerFactory() + agent_factory = SACAgentFactory( + sac_config, + sampling_config, + actor_factory, + critic_factory, + critic_factory, + optim_factory, + ) + + experiment = RLExperiment(experiment_config, env_factory, logger_factory, agent_factory) + + experiment.run(log_name) + + +if __name__ == "__main__": + CLI(main) diff --git a/pyproject.toml b/pyproject.toml index 243e0e6..4601db5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,7 +63,9 @@ envpool = ["envpool"] optional = true [tool.poetry.group.dev.dependencies] black = "^23.7.0" +docstring-parser = "^0.15" jinja2 = "*" +jsonargparse = "^4.24.1" mypy = "^1.4.1" # networkx is used in a test networkx = "*" @@ -83,8 +85,6 @@ sphinx_rtd_theme = "*" sphinxcontrib-bibtex = "*" sphinxcontrib-spelling = "^8.0.0" wandb = "^0.12.0" -jsonargparse = "^4.24.1" -docstring-parser = "^0.15" [tool.mypy] allow_redefinition = true diff --git a/tianshou/config/__init__.py b/tianshou/config/__init__.py index 27c9ec6..9f9107c 100644 --- a/tianshou/config/__init__.py +++ b/tianshou/config/__init__.py @@ -1 +1,10 @@ -from .config import * +__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"] + +from .config import ( + BasicExperimentConfig, + PGConfig, + PPOConfig, + RLAgentConfig, + RLSamplingConfig, + LoggerConfig, +) diff --git a/tianshou/config/config.py b/tianshou/config/config.py index c7f9a6c..25ea63a 100644 --- a/tianshou/config/config.py +++ b/tianshou/config/config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Literal, Optional, Sequence +from typing import Literal import torch from jsonargparse import set_docstring_parse_options @@ -14,12 +14,12 @@ class BasicExperimentConfig: seed: int = 
42 task: str = "Ant-v4" """Mujoco specific""" - render: Optional[float] = 0.0 + render: float | None = 0.0 """Milliseconds between rendered frames; if None, no rendering""" device: str = "cuda" if torch.cuda.is_available() else "cpu" - resume_id: Optional[int] = None + resume_id: str | None = None """For restoring a model and running means of env-specifics from a checkpoint""" - resume_path: str = None + resume_path: str | None = None """For restoring a model and running means of env-specifics from a checkpoint""" watch: bool = False """If True, will not perform training and only watch the restored policy""" @@ -28,7 +28,7 @@ class BasicExperimentConfig: @dataclass class LoggerConfig: - """Logging config""" + """Logging config.""" logdir: str = "log" logger: Literal["tensorboard", "wandb"] = "tensorboard" @@ -48,17 +48,18 @@ class RLSamplingConfig: buffer_size: int = 4096 step_per_collect: int = 2048 repeat_per_collect: int = 10 + update_per_step: int = 1 @dataclass class RLAgentConfig: - """Config common to most RL algorithms""" + """Config common to most RL algorithms.""" gamma: float = 0.99 """Discount factor""" gae_lambda: float = 0.95 """For Generalized Advantage Estimate (equivalent to TD(lambda))""" - action_bound_method: Optional[Literal["clip", "tanh"]] = "clip" + action_bound_method: Literal["clip", "tanh"] | None = "clip" """How to map original actions in range (-inf, inf) to [-1, 1]""" rew_norm: bool = True """Whether to normalize rewards""" @@ -66,7 +67,7 @@ class RLAgentConfig: @dataclass class PGConfig: - """Config of general policy-gradient algorithms""" + """Config of general policy-gradient algorithms.""" ent_coef: float = 0.0 vf_coef: float = 0.25 @@ -75,18 +76,11 @@ class PGConfig: @dataclass class PPOConfig: - """PPO specific config""" + """PPO specific config.""" value_clip: bool = False norm_adv: bool = False """Whether to normalize advantages""" eps_clip: float = 0.2 - dual_clip: Optional[float] = None + dual_clip: float | None = None recompute_adv: bool = True - - -@dataclass -class NNConfig: - hidden_sizes: Sequence[int] = (64, 64) - lr: float = 3e-4 - lr_decay: bool = True diff --git a/tianshou/config/utils.py b/tianshou/config/utils.py index e3a0998..0b130a6 100644 --- a/tianshou/config/utils.py +++ b/tianshou/config/utils.py @@ -2,11 +2,10 @@ from dataclasses import asdict, is_dataclass def collect_configs(*confs): - """ - Collect instances of dataclasses to a single dict mapping the - classname to the values. If any of the passed objects is not a - dataclass or if two instances of the same config class are passed, - an error will be raised. + """Collect instances of dataclasses to a single dict mapping the classname to the values. + + If any of the passed objects is not a ddataclass or if two instances + of the same config class are passed, an error will be raised. :param confs: dataclasses :return: Dictionary mapping class names to their instances. 
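
For orientation, here is a hedged sketch (not part of this diff) of how the per-concern config dataclasses above are meant to be instantiated and then aggregated for logging via `collect_configs`, mirroring what the deleted `mujoco_ppo_cfg.py` did with `collect_configs(*locals().values())`. It assumes that the `RLSamplingConfig` fields not shown in these hunks also carry defaults.

```python
from tianshou.config import (
    BasicExperimentConfig,
    LoggerConfig,
    PGConfig,
    PPOConfig,
    RLAgentConfig,
    RLSamplingConfig,
)
from tianshou.config.utils import collect_configs

# One small dataclass per concern, exactly as the jsonargparse CLI would build them.
experiment_config = BasicExperimentConfig(seed=42, task="Ant-v4")
logger_config = LoggerConfig(logdir="log", logger="tensorboard")
sampling_config = RLSamplingConfig()   # buffer_size, step_per_collect, update_per_step, ... (defaults assumed)
general_config = RLAgentConfig(gamma=0.99, gae_lambda=0.95)
pg_config = PGConfig()
ppo_config = PPOConfig(eps_clip=0.2)

# collect_configs maps each dataclass' class name to its contents and raises if a
# non-dataclass or two instances of the same config class are passed.
full_config = collect_configs(
    experiment_config,
    logger_config,
    sampling_config,
    general_config,
    pg_config,
    ppo_config,
)
```
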
diff --git a/tianshou/highlevel/agent.py b/tianshou/highlevel/agent.py index 8314aeb..e95c799 100644 --- a/tianshou/highlevel/agent.py +++ b/tianshou/highlevel/agent.py @@ -1,33 +1,51 @@ import os -from abc import abstractmethod, ABC -from typing import Callable +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass import torch -from tianshou.config import RLSamplingConfig, PGConfig, PPOConfig, RLAgentConfig, NNConfig -from tianshou.data import VectorReplayBuffer, ReplayBuffer, Collector +from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig +from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer +from tianshou.exploration import BaseNoise from tianshou.highlevel.env import Environments from tianshou.highlevel.logger import Logger from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice -from tianshou.highlevel.optim import OptimizerFactory, LRSchedulerFactory -from tianshou.policy import BasePolicy, PPOPolicy -from tianshou.trainer import BaseTrainer, OnpolicyTrainer +from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory +from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy +from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.utils.net.common import ActorCritic - CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" class AgentFactory(ABC): + def __init__(self, sampling_config: RLSamplingConfig): + self.sampling_config = sampling_config + + def create_train_test_collector(self, policy: BasePolicy, envs: Environments): + buffer_size = self.sampling_config.buffer_size + train_envs = envs.train_envs + if len(train_envs) > 1: + buffer = VectorReplayBuffer(buffer_size, len(train_envs)) + else: + buffer = ReplayBuffer(buffer_size) + train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) + test_collector = Collector(policy, envs.test_envs) + return train_collector, test_collector + @abstractmethod def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: pass @staticmethod def _create_save_best_fn(envs: Environments, log_path: str) -> Callable: - def save_best_fn(pol: torch.nn.Module): - state = {"model": pol.state_dict(), "obs_rms": envs.train_envs.get_obs_rms()} + def save_best_fn(pol: torch.nn.Module) -> None: + state = { + CHECKPOINT_DICT_KEY_MODEL: pol.state_dict(), + CHECKPOINT_DICT_KEY_OBS_RMS: envs.train_envs.get_obs_rms(), + } torch.save(state, os.path.join(log_path, "policy.pth")) return save_best_fn @@ -43,36 +61,26 @@ class AgentFactory(ABC): print("Loaded agent and obs. 
running means from: ", path) # TODO logging @abstractmethod - def create_train_test_collector(self, - policy: BasePolicy, - envs: Environments): - pass - - @abstractmethod - def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector, - envs: Environments, logger: Logger) -> BaseTrainer: + def create_trainer( + self, + policy: BasePolicy, + train_collector: Collector, + test_collector: Collector, + envs: Environments, + logger: Logger, + ) -> BaseTrainer: pass class OnpolicyAgentFactory(AgentFactory, ABC): - def __init__(self, sampling_config: RLSamplingConfig): - self.sampling_config = sampling_config - - def create_train_test_collector(self, - policy: BasePolicy, - envs: Environments): - buffer_size = self.sampling_config.buffer_size - train_envs = envs.train_envs - if len(train_envs) > 1: - buffer = VectorReplayBuffer(buffer_size, len(train_envs)) - else: - buffer = ReplayBuffer(buffer_size) - train_collector = Collector(policy, train_envs, buffer, exploration_noise=True) - test_collector = Collector(policy, envs.test_envs) - return train_collector, test_collector - - def create_trainer(self, policy: BasePolicy, train_collector: Collector, test_collector: Collector, - envs: Environments, logger: Logger) -> OnpolicyTrainer: + def create_trainer( + self, + policy: BasePolicy, + train_collector: Collector, + test_collector: Collector, + envs: Environments, + logger: Logger, + ) -> OnpolicyTrainer: sampling_config = self.sampling_config return OnpolicyTrainer( policy=policy, @@ -90,17 +98,46 @@ class OnpolicyAgentFactory(AgentFactory, ABC): ) +class OffpolicyAgentFactory(AgentFactory, ABC): + def create_trainer( + self, + policy: BasePolicy, + train_collector: Collector, + test_collector: Collector, + envs: Environments, + logger: Logger, + ) -> OffpolicyTrainer: + sampling_config = self.sampling_config + return OffpolicyTrainer( + policy=policy, + train_collector=train_collector, + test_collector=test_collector, + max_epoch=sampling_config.num_epochs, + step_per_epoch=sampling_config.step_per_epoch, + step_per_collect=sampling_config.step_per_collect, + episode_per_test=sampling_config.num_test_envs, + batch_size=sampling_config.batch_size, + save_best_fn=self._create_save_best_fn(envs, logger.log_path), + logger=logger.logger, + update_per_step=sampling_config.update_per_step, + test_in_train=False, + ) + + class PPOAgentFactory(OnpolicyAgentFactory): - def __init__(self, general_config: RLAgentConfig, - pg_config: PGConfig, - ppo_config: PPOConfig, - sampling_config: RLSamplingConfig, - nn_config: NNConfig, - actor_factory: ActorFactory, - critic_factory: CriticFactory, - optimizer_factory: OptimizerFactory, - dist_fn, - lr_scheduler_factory: LRSchedulerFactory): + def __init__( + self, + general_config: RLAgentConfig, + pg_config: PGConfig, + ppo_config: PPOConfig, + sampling_config: RLSamplingConfig, + actor_factory: ActorFactory, + critic_factory: CriticFactory, + optimizer_factory: OptimizerFactory, + dist_fn, + lr: float, + lr_scheduler_factory: LRSchedulerFactory | None = None, + ): super().__init__(sampling_config) self.optimizer_factory = optimizer_factory self.critic_factory = critic_factory @@ -108,16 +145,19 @@ class PPOAgentFactory(OnpolicyAgentFactory): self.ppo_config = ppo_config self.pg_config = pg_config self.general_config = general_config + self.lr = lr self.lr_scheduler_factory = lr_scheduler_factory self.dist_fn = dist_fn - self.nn_config = nn_config def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy: 
actor = self.actor_factory.create_module(envs, device) - critic = self.critic_factory.create_module(envs, device) + critic = self.critic_factory.create_module(envs, device, use_action=False) actor_critic = ActorCritic(actor, critic) - optim = self.optimizer_factory.create_optimizer(actor_critic) - lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim) + optim = self.optimizer_factory.create_optimizer(actor_critic, self.lr) + if self.lr_scheduler_factory is not None: + lr_scheduler = self.lr_scheduler_factory.create_scheduler(optim) + else: + lr_scheduler = None return PPOPolicy( # nn-stuff actor, @@ -144,3 +184,60 @@ class PPOAgentFactory(OnpolicyAgentFactory): advantage_normalization=self.ppo_config.norm_adv, recompute_advantage=self.ppo_config.recompute_adv, ) + + +class SACAgentFactory(OffpolicyAgentFactory): + def __init__( + self, + config: "SACAgentFactory.Config", + sampling_config: RLSamplingConfig, + actor_factory: ActorFactory, + critic1_factory: CriticFactory, + critic2_factory: CriticFactory, + optim_factory: OptimizerFactory, + exploration_noise: BaseNoise | None = None, + ): + super().__init__(sampling_config) + self.critic2_factory = critic2_factory + self.critic1_factory = critic1_factory + self.actor_factory = actor_factory + self.exploration_noise = exploration_noise + self.optim_factory = optim_factory + self.config = config + + def create_policy(self, envs: Environments, device: TDevice) -> BasePolicy: + actor = self.actor_factory.create_module(envs, device) + critic1 = self.critic1_factory.create_module(envs, device, use_action=True) + critic2 = self.critic2_factory.create_module(envs, device, use_action=True) + actor_optim = self.optim_factory.create_optimizer(actor, lr=self.config.actor_lr) + critic1_optim = self.optim_factory.create_optimizer(critic1, lr=self.config.critic1_lr) + critic2_optim = self.optim_factory.create_optimizer(critic2, lr=self.config.critic2_lr) + return SACPolicy( + actor, + actor_optim, + critic1, + critic1_optim, + critic2, + critic2_optim, + tau=self.config.tau, + gamma=self.config.gamma, + alpha=self.config.alpha, + estimation_step=self.config.estimation_step, + action_space=envs.get_action_space(), + deterministic_eval=self.config.deterministic_eval, + exploration_noise=self.exploration_noise, + ) + + @dataclass + class Config: + """SAC configuration.""" + + tau: float = 0.005 + gamma: float = 0.99 + alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2 + reward_normalization: bool = False + estimation_step: int = 1 + deterministic_eval: bool = True + actor_lr: float = 1e-3 + critic1_lr: float = 1e-3 + critic2_lr: float = 1e-3 diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 3869476..90a8044 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -1,24 +1,22 @@ from abc import ABC, abstractmethod -from typing import Tuple, Optional, Dict, Any, Union, Sequence +from collections.abc import Sequence +from typing import Any import gymnasium as gym from tianshou.env import BaseVectorEnv -TShape = Union[int, Sequence[int]] +TShape = int | Sequence[int] class Environments(ABC): - def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): + def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): self.env = env self.train_envs = train_envs self.test_envs = test_envs - def info(self) -> Dict[str, Any]: - return { - "action_shape": self.get_action_shape(), - "state_shape": self.get_state_shape() - } + def 
info(self) -> dict[str, Any]: + return {"action_shape": self.get_action_shape(), "state_shape": self.get_state_shape()} @abstractmethod def get_action_shape(self) -> TShape: @@ -33,7 +31,7 @@ class Environments(ABC): class ContinuousEnvironments(Environments): - def __init__(self, env: Optional[gym.Env], train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): + def __init__(self, env: gym.Env | None, train_envs: BaseVectorEnv, test_envs: BaseVectorEnv): super().__init__(env, train_envs, test_envs) self.state_shape, self.action_shape, self.max_action = self._get_continuous_env_info(env) @@ -44,12 +42,12 @@ class ContinuousEnvironments(Environments): @staticmethod def _get_continuous_env_info( - env: gym.Env, - ) -> Tuple[Tuple[int, ...], Tuple[int, ...], float]: + env: gym.Env, + ) -> tuple[tuple[int, ...], tuple[int, ...], float]: if not isinstance(env.action_space, gym.spaces.Box): raise ValueError( "Only environments with continuous action space are supported here. " - f"But got env with action space: {env.action_space.__class__}." + f"But got env with action space: {env.action_space.__class__}.", ) state_shape = env.observation_space.shape or env.observation_space.n if not state_shape: @@ -68,4 +66,4 @@ class ContinuousEnvironments(Environments): class EnvFactory(ABC): @abstractmethod def create_envs(self) -> Environments: - pass \ No newline at end of file + pass diff --git a/tianshou/highlevel/experiment.py b/tianshou/highlevel/experiment.py index b175e42..19f0e2c 100644 --- a/tianshou/highlevel/experiment.py +++ b/tianshou/highlevel/experiment.py @@ -4,7 +4,9 @@ from typing import Generic, TypeVar import numpy as np import torch -from tianshou.config import BasicExperimentConfig, LoggerConfig, RLAgentConfig, RLSamplingConfig +from tianshou.config import ( + BasicExperimentConfig, +) from tianshou.data import Collector from tianshou.highlevel.agent import AgentFactory from tianshou.highlevel.env import EnvFactory @@ -17,23 +19,19 @@ TTrainer = TypeVar("TTrainer", bound=BaseTrainer) class RLExperiment(Generic[TPolicy, TTrainer]): - def __init__(self, - config: BasicExperimentConfig, - logger_config: LoggerConfig, - general_config: RLAgentConfig, - sampling_config: RLSamplingConfig, - env_factory: EnvFactory, - logger_factory: LoggerFactory, - agent_factory: AgentFactory): + def __init__( + self, + config: BasicExperimentConfig, + env_factory: EnvFactory, + logger_factory: LoggerFactory, + agent_factory: AgentFactory, + ): self.config = config - self.logger_config = logger_config - self.general_config = general_config - self.sampling_config = sampling_config self.env_factory = env_factory self.logger_factory = logger_factory self.agent_factory = agent_factory - def _set_seed(self): + def _set_seed(self) -> None: seed = self.config.seed np.random.seed(seed) torch.manual_seed(seed) @@ -43,7 +41,7 @@ class RLExperiment(Generic[TPolicy, TTrainer]): # TODO } - def run(self, log_name: str): + def run(self, log_name: str) -> None: self._set_seed() envs = self.env_factory.create_envs() @@ -52,25 +50,47 @@ class RLExperiment(Generic[TPolicy, TTrainer]): full_config.update(envs.info()) run_id = self.config.resume_id - logger = self.logger_factory.create_logger(log_name=log_name, run_id=run_id, config_dict=full_config) + logger = self.logger_factory.create_logger( + log_name=log_name, + run_id=run_id, + config_dict=full_config, + ) policy = self.agent_factory.create_policy(envs, self.config.device) if self.config.resume_path: - self.agent_factory.load_checkpoint(policy, self.config.resume_path, 
envs, self.config.device) + self.agent_factory.load_checkpoint( + policy, + self.config.resume_path, + envs, + self.config.device, + ) - train_collector, test_collector = self.agent_factory.create_train_test_collector(policy, envs) + train_collector, test_collector = self.agent_factory.create_train_test_collector( + policy, + envs, + ) if not self.config.watch: - trainer = self.agent_factory.create_trainer(policy, train_collector, test_collector, envs, logger) + trainer = self.agent_factory.create_trainer( + policy, + train_collector, + test_collector, + envs, + logger, + ) result = trainer.run() pprint(result) # TODO logging - self._watch_agent(self.config.watch_num_episodes, policy, test_collector, self.config.render) + self._watch_agent( + self.config.watch_num_episodes, + policy, + test_collector, + self.config.render, + ) @staticmethod - def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render): + def _watch_agent(num_episodes, policy: BasePolicy, test_collector: Collector, render) -> None: policy.eval() test_collector.reset() result = test_collector.collect(n_episode=num_episodes, render=render) print(f'Final reward: {result["rews"].mean()}, length: {result["lens"].mean()}') - diff --git a/tianshou/highlevel/logger.py b/tianshou/highlevel/logger.py index c1c0c31..098072b 100644 --- a/tianshou/highlevel/logger.py +++ b/tianshou/highlevel/logger.py @@ -1,15 +1,13 @@ -from abc import ABC, abstractmethod import os +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Union, Optional from torch.utils.tensorboard import SummaryWriter from tianshou.config import LoggerConfig from tianshou.utils import TensorboardLogger, WandbLogger - -TLogger = Union[TensorboardLogger, WandbLogger] +TLogger = TensorboardLogger | WandbLogger @dataclass @@ -20,7 +18,7 @@ class Logger: class LoggerFactory(ABC): @abstractmethod - def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger: + def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger: pass @@ -28,7 +26,7 @@ class DefaultLoggerFactory(LoggerFactory): def __init__(self, config: LoggerConfig): self.config = config - def create_logger(self, log_name: str, run_id: Optional[int], config_dict: dict) -> Logger: + def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger: writer = SummaryWriter(self.config.logdir) writer.add_text("args", str(self.config)) if self.config.logger == "wandb": diff --git a/tianshou/highlevel/module.py b/tianshou/highlevel/module.py index bf11934..e85a4db 100644 --- a/tianshou/highlevel/module.py +++ b/tianshou/highlevel/module.py @@ -1,24 +1,24 @@ -from abc import abstractmethod, ABC -from typing import Sequence +from abc import ABC, abstractmethod +from collections.abc import Sequence +import numpy as np import torch from torch import nn -import numpy as np from tianshou.highlevel.env import Environments from tianshou.utils.net.common import Net -from tianshou.utils.net.continuous import ActorProb, Critic as ContinuousCritic +from tianshou.utils.net.continuous import ActorProb +from tianshou.utils.net.continuous import Critic as ContinuousCritic TDevice = str | int | torch.device -def init_linear_orthogonal(m: torch.nn.Module): - """ - Applies orthogonal initialization to linear layers of the given module and sets bias weights to 0 +def init_linear_orthogonal(module: torch.nn.Module): + """Applies orthogonal initialization to linear layers of the given module and sets bias 
weights to 0. - :param m: the module whose submodules are to be processed + :param module: the module whose submodules are to be processed """ - for m in m.modules(): + for m in module.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2)) torch.nn.init.zeros_(m.bias) @@ -31,9 +31,9 @@ class ActorFactory(ABC): @staticmethod def _init_linear(actor: torch.nn.Module): - """ - Initializes linear layers of an actor module using default mechanisms - :param module: the actor module + """Initializes linear layers of an actor module using default mechanisms. + + :param module: the actor module. """ init_linear_orthogonal(actor) if hasattr(actor, "mu"): @@ -51,17 +51,29 @@ class ContinuousActorFactory(ActorFactory, ABC): class ContinuousActorProbFactory(ContinuousActorFactory): - def __init__(self, hidden_sizes: Sequence[int]): + def __init__(self, hidden_sizes: Sequence[int], unbounded=True, conditioned_sigma=False): self.hidden_sizes = hidden_sizes + self.unbounded = unbounded + self.conditioned_sigma = conditioned_sigma def create_module(self, envs: Environments, device: TDevice) -> nn.Module: net_a = Net( - envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device + envs.get_state_shape(), + hidden_sizes=self.hidden_sizes, + activation=nn.Tanh, + device=device, ) - actor = ActorProb(net_a, envs.get_action_shape(), unbounded=True, device=device).to(device) + actor = ActorProb( + net_a, + envs.get_action_shape(), + unbounded=True, + device=device, + conditioned_sigma=self.conditioned_sigma, + ).to(device) # init params - torch.nn.init.constant_(actor.sigma_param, -0.5) + if not self.conditioned_sigma: + torch.nn.init.constant_(actor.sigma_param, -0.5) self._init_linear(actor) return actor @@ -69,7 +81,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory): class CriticFactory(ABC): @abstractmethod - def create_module(self, envs: Environments, device: TDevice) -> nn.Module: + def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: pass @@ -78,12 +90,19 @@ class ContinuousCriticFactory(CriticFactory, ABC): class ContinuousNetCriticFactory(ContinuousCriticFactory): - def __init__(self, hidden_sizes: Sequence[int]): + def __init__(self, hidden_sizes: Sequence[int], action_shape=0): + self.action_shape = action_shape self.hidden_sizes = hidden_sizes - def create_module(self, envs: Environments, device: TDevice) -> nn.Module: + def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: + action_shape = envs.get_action_shape() if use_action else 0 net_c = Net( - envs.get_state_shape(), hidden_sizes=self.hidden_sizes, activation=nn.Tanh, device=device + envs.get_state_shape(), + action_shape=action_shape, + hidden_sizes=self.hidden_sizes, + concat=use_action, + activation=nn.Tanh, + device=device, ) critic = ContinuousCritic(net_c, device=device).to(device) init_linear_orthogonal(critic) diff --git a/tianshou/highlevel/optim.py b/tianshou/highlevel/optim.py index 43cbb53..5ac660a 100644 --- a/tianshou/highlevel/optim.py +++ b/tianshou/highlevel/optim.py @@ -1,54 +1,51 @@ from abc import ABC, abstractmethod -from typing import Union, Iterable, Dict, Any, Optional +from collections.abc import Iterable +from typing import Any, Type import numpy as np import torch from torch import Tensor from torch.optim import Adam -from torch.optim.lr_scheduler import LRScheduler, LambdaLR +from torch.optim.lr_scheduler import LambdaLR, LRScheduler -from tianshou.config 
import RLSamplingConfig, NNConfig +from tianshou.config import RLSamplingConfig -TParams = Union[Iterable[Tensor], Iterable[Dict[str, Any]]] +TParams = Iterable[Tensor] | Iterable[dict[str, Any]] class OptimizerFactory(ABC): @abstractmethod - def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer: + def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer: pass class TorchOptimizerFactory(OptimizerFactory): - def __init__(self, optim_class, **kwargs): + def __init__(self, optim_class: Any, **kwargs): self.optim_class = optim_class self.kwargs = kwargs - def create_optimizer(self, module: torch.nn.Module) -> torch.optim.Optimizer: - return self.optim_class(module.parameters(), **self.kwargs) + def create_optimizer(self, module: torch.nn.Module, lr: float) -> torch.optim.Optimizer: + return self.optim_class(module.parameters(), lr=lr, **self.kwargs) class AdamOptimizerFactory(OptimizerFactory): - def __init__(self, lr): - self.lr = lr - - def create_optimizer(self, module: torch.nn.Module) -> Adam: - return Adam(module.parameters(), lr=self.lr) + def create_optimizer(self, module: torch.nn.Module, lr: float) -> Adam: + return Adam(module.parameters(), lr=lr) class LRSchedulerFactory(ABC): @abstractmethod - def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]: + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: pass class LinearLRSchedulerFactory(LRSchedulerFactory): - def __init__(self, nn_config: NNConfig, sampling_config: RLSamplingConfig): - self.nn_config = nn_config + def __init__(self, sampling_config: RLSamplingConfig): self.sampling_config = sampling_config - def create_scheduler(self, optim: torch.optim.Optimizer) -> Optional[LRScheduler]: - lr_scheduler = None - if self.nn_config.lr_decay: - max_update_num = np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) * self.sampling_config.num_epochs - lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num) - return lr_scheduler + def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: + max_update_num = ( + np.ceil(self.sampling_config.step_per_epoch / self.sampling_config.step_per_collect) + * self.sampling_config.num_epochs + ) + return LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
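
To make the new optimizer/scheduler calling convention concrete, a hedged sketch (not part of this diff, assuming the factory signatures shown above): the learning rate is now supplied at `create_optimizer` time rather than stored in the factory, and LR decay is expressed by passing (or omitting) a `LinearLRSchedulerFactory` instead of consulting `NNConfig.lr_decay` inside the factory, as `mujoco_ppo_hl.py` now does.

```python
import torch

from tianshou.config import RLSamplingConfig
from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory

# Stand-ins for illustration: `module` represents the actor-critic built by the
# module factories; RLSamplingConfig is assumed to be default-constructible here.
module = torch.nn.Linear(8, 2)
sampling_config = RLSamplingConfig()

optim_factory = AdamOptimizerFactory()                  # no lr stored in the factory anymore
optim = optim_factory.create_optimizer(module, lr=3e-4)

lr_decay = True                                         # previously NNConfig.lr_decay
scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
lr_scheduler = scheduler_factory.create_scheduler(optim) if scheduler_factory else None
```
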