Unify PPO configuration objects, use experiment-specific configuration
in mujoco_ppo_hl
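The gist of the change, sketched below for orientation: the three config objects that PPOAgentFactory used to receive separately (RLAgentConfig, PGConfig, PPOConfig) become a single inheritance chain, so one PPOConfig instance carries every hyperparameter. Field names and defaults in the sketch are taken from the diff; the RLAgentConfig body in particular is an assumption, since its fields appear here only through self.config accesses, and the authoritative definitions live in tianshou.highlevel.agent.

# Illustrative sketch only, not the library's definitions; defaults mirror the
# values used in mujoco_ppo_hl below.
from dataclasses import dataclass
from typing import Literal


@dataclass
class RLAgentConfig:
    # Assumed field set, inferred from the self.config.* accesses in the diff.
    gamma: float = 0.99
    gae_lambda: float = 0.95
    rew_norm: bool = True
    action_bound_method: Literal["clip", "tanh"] | None = "clip"


@dataclass
class PGConfig(RLAgentConfig):
    """Config of general policy-gradient algorithms."""

    ent_coef: float = 0.0
    vf_coef: float = 0.25
    max_grad_norm: float = 0.5


@dataclass
class PPOConfig(PGConfig):
    """PPO specific config."""

    value_clip: bool = False
    norm_adv: bool = False
    eps_clip: float = 0.2
    dual_clip: float | None = None
    recompute_adv: bool = True

With the hierarchy in place, PPOAgentFactory.__init__ only needs config: PPOConfig plus the sampling config and the module/optimizer factories, as the tianshou.highlevel.agent hunks below show.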
parent 8ec42009cb
commit 3fd60f9e70
@@ -3,8 +3,8 @@ import warnings
 import gymnasium as gym

 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
-from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
 from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory

 try:
     import envpool

@@ -3,18 +3,18 @@
 import datetime
 import os
 from collections.abc import Sequence
-from dataclasses import dataclass
+from typing import Literal

 from jsonargparse import CLI
 from torch.distributions import Independent, Normal

 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
+from tianshou.highlevel.agent import PPOAgentFactory, PPOConfig
+from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     RLExperiment,
     RLExperimentConfig,
 )
-from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.logger import DefaultLoggerFactory
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
@@ -23,45 +23,79 @@ from tianshou.highlevel.module import (
 from tianshou.highlevel.optim import AdamOptimizerFactory, LinearLRSchedulerFactory


-@dataclass
-class NNConfig:
-    hidden_sizes: Sequence[int] = (64, 64)
-    lr: float = 3e-4
-    lr_decay: bool = True
-
-
 def main(
     experiment_config: RLExperimentConfig,
-    sampling_config: RLSamplingConfig,
-    general_config: RLAgentConfig,
-    pg_config: PGConfig,
-    ppo_config: PPOConfig,
-    nn_config: NNConfig,
     task: str = "Ant-v4",
+    buffer_size: int = 4096,
+    hidden_sizes: Sequence[int] = (64, 64),
+    lr: float = 3e-4,
+    gamma: float = 0.99,
+    epoch: int = 100,
+    step_per_epoch: int = 30000,
+    step_per_collect: int = 2048,
+    repeat_per_collect: int = 10,
+    batch_size: int = 64,
+    training_num: int = 64,
+    test_num: int = 10,
+    rew_norm: bool = True,
+    vf_coef: float = 0.25,
+    ent_coef: float = 0.0,
+    gae_lambda: float = 0.95,
+    bound_action_method: Literal["clip", "tanh"] | None = "clip",
+    lr_decay: bool = True,
+    max_grad_norm: float = 0.5,
+    eps_clip: float = 0.2,
+    dual_clip: float | None = None,
+    value_clip: bool = False,
+    norm_adv: bool = False,
+    recompute_adv: bool = True,
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
     log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory()

+    sampling_config = RLSamplingConfig(
+        num_epochs=epoch,
+        step_per_epoch=step_per_epoch,
+        batch_size=batch_size,
+        num_train_envs=training_num,
+        num_test_envs=test_num,
+        buffer_size=buffer_size,
+        step_per_collect=step_per_collect,
+        repeat_per_collect=repeat_per_collect,
+    )
+
     env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)

     def dist_fn(*logits):
         return Independent(Normal(*logits), 1)

-    actor_factory = ContinuousActorProbFactory(nn_config.hidden_sizes)
-    critic_factory = ContinuousNetCriticFactory(nn_config.hidden_sizes)
+    ppo_config = PPOConfig(
+        gamma=gamma,
+        gae_lambda=gae_lambda,
+        action_bound_method=bound_action_method,
+        rew_norm=rew_norm,
+        ent_coef=ent_coef,
+        vf_coef=vf_coef,
+        max_grad_norm=max_grad_norm,
+        value_clip=value_clip,
+        norm_adv=norm_adv,
+        eps_clip=eps_clip,
+        dual_clip=dual_clip,
+        recompute_adv=recompute_adv,
+    )
+    actor_factory = ContinuousActorProbFactory(hidden_sizes)
+    critic_factory = ContinuousNetCriticFactory(hidden_sizes)
     optim_factory = AdamOptimizerFactory()
-    lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if nn_config.lr_decay else None
+    lr_scheduler_factory = LinearLRSchedulerFactory(sampling_config) if lr_decay else None
     agent_factory = PPOAgentFactory(
-        general_config,
-        pg_config,
         ppo_config,
         sampling_config,
         actor_factory,
         critic_factory,
         optim_factory,
         dist_fn,
-        nn_config.lr,
+        lr,
         lr_scheduler_factory,
     )

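One practical consequence of the flattened main() signature above is that every hyperparameter becomes an individual command-line option. The snippet below is a hypothetical illustration of the entry point, assuming the script still dispatches through jsonargparse as its import suggests; the bottom of the file is not part of this diff.

# Hypothetical entry point (not shown in the hunks above). With jsonargparse,
# each parameter of main() surfaces as a flag such as --task, --lr or --eps_clip.
if __name__ == "__main__":
    CLI(main)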
@@ -8,11 +8,11 @@ from jsonargparse import CLI

 from examples.mujoco.mujoco_env import MujocoEnvFactory
 from tianshou.highlevel.agent import DefaultAutoAlphaFactory, SACAgentFactory, SACConfig
+from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.experiment import (
     RLExperiment,
     RLExperimentConfig,
 )
-from tianshou.highlevel.config import RLSamplingConfig
 from tianshou.highlevel.logger import DefaultLoggerFactory
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,

@@ -9,8 +9,8 @@ import torch

 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.exploration import BaseNoise
-from tianshou.highlevel.env import Environments
 from tianshou.highlevel.config import RLSamplingConfig
+from tianshou.highlevel.env import Environments
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
 from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
@@ -143,7 +143,7 @@ class RLAgentConfig:


 @dataclass
-class PGConfig:
+class PGConfig(RLAgentConfig):
     """Config of general policy-gradient algorithms."""

     ent_coef: float = 0.0
@@ -152,7 +152,7 @@ class PGConfig:


 @dataclass
-class PPOConfig:
+class PPOConfig(PGConfig):
     """PPO specific config."""

     value_clip: bool = False
@@ -166,9 +166,7 @@ class PPOConfig:
 class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
-        general_config: RLAgentConfig,
-        pg_config: PGConfig,
-        ppo_config: PPOConfig,
+        config: PPOConfig,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic_factory: CriticFactory,
@@ -181,9 +179,7 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         self.optimizer_factory = optimizer_factory
         self.critic_factory = critic_factory
         self.actor_factory = actor_factory
-        self.ppo_config = ppo_config
-        self.pg_config = pg_config
-        self.general_config = general_config
+        self.config = config
         self.lr = lr
         self.lr_scheduler_factory = lr_scheduler_factory
         self.dist_fn = dist_fn
@@ -208,27 +204,30 @@ class PPOAgentFactory(OnpolicyAgentFactory):
             action_space=envs.get_action_space(),
             action_scaling=True,
             # general_config
-            discount_factor=self.general_config.gamma,
-            gae_lambda=self.general_config.gae_lambda,
-            reward_normalization=self.general_config.rew_norm,
-            action_bound_method=self.general_config.action_bound_method,
+            discount_factor=self.config.gamma,
+            gae_lambda=self.config.gae_lambda,
+            reward_normalization=self.config.rew_norm,
+            action_bound_method=self.config.action_bound_method,
             # pg_config
-            max_grad_norm=self.pg_config.max_grad_norm,
-            vf_coef=self.pg_config.vf_coef,
-            ent_coef=self.pg_config.ent_coef,
+            max_grad_norm=self.config.max_grad_norm,
+            vf_coef=self.config.vf_coef,
+            ent_coef=self.config.ent_coef,
             # ppo_config
-            eps_clip=self.ppo_config.eps_clip,
-            value_clip=self.ppo_config.value_clip,
-            dual_clip=self.ppo_config.dual_clip,
-            advantage_normalization=self.ppo_config.norm_adv,
-            recompute_advantage=self.ppo_config.recompute_adv,
+            eps_clip=self.config.eps_clip,
+            value_clip=self.config.value_clip,
+            dual_clip=self.config.dual_clip,
+            advantage_normalization=self.config.norm_adv,
+            recompute_advantage=self.config.recompute_adv,
         )


 class AutoAlphaFactory(ABC):
     @abstractmethod
     def create_auto_alpha(
-        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
+        self,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
     ):
         pass

@@ -238,7 +237,10 @@ class DefaultAutoAlphaFactory(AutoAlphaFactory):  # TODO better name?
         self.lr = lr

     def create_auto_alpha(
-        self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice,
+        self,
+        envs: Environments,
+        optim_factory: OptimizerFactory,
+        device: TDevice,
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = -np.prod(envs.get_action_shape())
         log_alpha = torch.zeros(1, requires_grad=True, device=device)

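To make the net effect concrete, here is a minimal usage sketch, assuming this revision of tianshou is installed; the constructor keywords mirror the PPOConfig(...) call in the mujoco_ppo_hl hunk above. Settings that previously lived on RLAgentConfig, PGConfig and PPOConfig now sit on one object, which is exactly what PPOAgentFactory reads via self.config.

from tianshou.highlevel.agent import PPOConfig

# Mirrors the construction in the example script, with its default values.
config = PPOConfig(
    gamma=0.99,
    gae_lambda=0.95,
    action_bound_method="clip",
    rew_norm=True,
    ent_coef=0.0,
    vf_coef=0.25,
    max_grad_norm=0.5,
    value_clip=False,
    norm_adv=False,
    eps_clip=0.2,
    dual_clip=None,
    recompute_adv=True,
)
# Base-level, policy-gradient and PPO-specific settings on a single object:
assert config.gamma == 0.99 and config.vf_coef == 0.25 and config.eps_clip == 0.2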