Refactoring, dropping package config
This commit is contained in:
parent 316eb3c579
commit 997b520580
@@ -2,9 +2,9 @@ import warnings
 
 import gymnasium as gym
 
-from tianshou.config import BasicExperimentConfig, RLSamplingConfig
 from tianshou.env import ShmemVectorEnv, VectorEnvNormObs
 from tianshou.highlevel.env import ContinuousEnvironments, EnvFactory
+from tianshou.highlevel.experiment import RLSamplingConfig
 
 try:
     import envpool
@@ -41,14 +41,15 @@ def make_mujoco_env(task: str, seed: int, num_train_envs: int, num_test_envs: in
 
 
 class MujocoEnvFactory(EnvFactory):
-    def __init__(self, experiment_config: BasicExperimentConfig, sampling_config: RLSamplingConfig):
+    def __init__(self, task: str, seed: int, sampling_config: RLSamplingConfig):
+        self.task = task
         self.sampling_config = sampling_config
-        self.experiment_config = experiment_config
+        self.seed = seed
 
     def create_envs(self) -> ContinuousEnvironments:
         env, train_envs, test_envs = make_mujoco_env(
-            task=self.experiment_config.task,
-            seed=self.experiment_config.seed,
+            task=self.task,
+            seed=self.seed,
             num_train_envs=self.sampling_config.num_train_envs,
             num_test_envs=self.sampling_config.num_test_envs,
             obs_norm=True,
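The hunks above appear to belong to examples/mujoco/mujoco_env.py (later hunks import MujocoEnvFactory from that module). The factory no longer receives the whole experiment config; callers now pass the task name and seed explicitly. A minimal usage sketch under that assumption — the task name and sampling values are illustrative, not taken from the commit:

from examples.mujoco.mujoco_env import MujocoEnvFactory
from tianshou.highlevel.experiment import RLSamplingConfig

# Only the fields the factory actually reads matter here; defaults cover the rest.
sampling_config = RLSamplingConfig(num_train_envs=8, num_test_envs=2)
env_factory = MujocoEnvFactory(task="Ant-v4", seed=42, sampling_config=sampling_config)
envs = env_factory.create_envs()  # ContinuousEnvironments wrapping train/test vector envs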
@@ -9,17 +9,13 @@ from jsonargparse import CLI
 from torch.distributions import Independent, Normal
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
+from tianshou.highlevel.agent import PGConfig, PPOAgentFactory, PPOConfig, RLAgentConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import PPOAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -35,19 +31,20 @@ class NNConfig:
 
 
 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
     general_config: RLAgentConfig,
     pg_config: PGConfig,
     ppo_config: PPOConfig,
     nn_config: NNConfig,
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "ppo", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory(logger_config)
 
-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
 
     def dist_fn(*logits):
         return Independent(Normal(*logits), 1)
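In the PPO example script, BasicExperimentConfig becomes RLExperimentConfig and the task name is promoted to a plain main() parameter defaulting to "Ant-v4". Since the script drives main() through jsonargparse, the nested config dataclasses surface as dotted command-line flags. A hedged sketch of that wiring — the CLI(main) call is assumed to match the rest of the script, which is not shown in this diff:

from jsonargparse import CLI

if __name__ == "__main__":
    # Dataclass parameters become grouped flags, e.g.
    #   --task HalfCheetah-v4 --experiment_config.seed 0 --sampling_config.num_epochs 100
    # The attribute docstrings become help text (see the new highlevel package __init__ later in this diff).
    CLI(main)  # `main` is the function defined in this PPO example script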
@@ -7,14 +7,13 @@ from collections.abc import Sequence
 from jsonargparse import CLI
 
 from examples.mujoco.mujoco_env import MujocoEnvFactory
-from tianshou.config import (
-    BasicExperimentConfig,
-    LoggerConfig,
+from tianshou.highlevel.agent import SACAgentFactory, SACConfig
+from tianshou.highlevel.experiment import (
+    RLExperiment,
+    RLExperimentConfig,
     RLSamplingConfig,
 )
-from tianshou.highlevel.agent import SACAgentFactory
-from tianshou.highlevel.experiment import RLExperiment
-from tianshou.highlevel.logger import DefaultLoggerFactory
+from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
 from tianshou.highlevel.module import (
     ContinuousActorProbFactory,
     ContinuousNetCriticFactory,
@@ -23,17 +22,18 @@ from tianshou.highlevel.optim import AdamOptimizerFactory
 
 
 def main(
-    experiment_config: BasicExperimentConfig,
+    experiment_config: RLExperimentConfig,
     logger_config: LoggerConfig,
     sampling_config: RLSamplingConfig,
-    sac_config: SACAgentFactory.Config,
+    sac_config: SACConfig,
     hidden_sizes: Sequence[int] = (256, 256),
+    task: str = "Ant-v4",
 ):
     now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
-    log_name = os.path.join(experiment_config.task, "sac", str(experiment_config.seed), now)
+    log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
     logger_factory = DefaultLoggerFactory(logger_config)
 
-    env_factory = MujocoEnvFactory(experiment_config, sampling_config)
+    env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
 
     actor_factory = ContinuousActorProbFactory(hidden_sizes, conditioned_sigma=True)
     critic_factory = ContinuousNetCriticFactory(hidden_sizes)
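The SAC script follows the same pattern: the nested SACAgentFactory.Config annotation is replaced by the new top-level SACConfig, and the task arrives as an explicit argument. For illustration, a direct programmatic call of this main() with the relocated config classes; every value below is a stand-in chosen for the sketch, not taken from the commit:

from tianshou.highlevel.agent import SACConfig
from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig
from tianshou.highlevel.logger import LoggerConfig

# `main` refers to the function defined in the SAC example script shown above.
main(
    experiment_config=RLExperimentConfig(seed=1, device="cpu"),
    logger_config=LoggerConfig(logdir="log"),
    sampling_config=RLSamplingConfig(num_epochs=1, step_per_epoch=1000),
    sac_config=SACConfig(),
    hidden_sizes=(128, 128),
    task="Walker2d-v4",
)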
@@ -1,10 +0,0 @@
-__all__ = ["PGConfig", "PPOConfig", "RLAgentConfig", "RLSamplingConfig", "BasicExperimentConfig", "LoggerConfig"]
-
-from .config import (
-    BasicExperimentConfig,
-    PGConfig,
-    PPOConfig,
-    RLAgentConfig,
-    RLSamplingConfig,
-    LoggerConfig,
-)
@@ -1,86 +0,0 @@
-from dataclasses import dataclass
-from typing import Literal
-
-import torch
-from jsonargparse import set_docstring_parse_options
-
-set_docstring_parse_options(attribute_docstrings=True)
-
-
-@dataclass
-class BasicExperimentConfig:
-    """Generic config for setting up the experiment, not RL or training specific."""
-
-    seed: int = 42
-    task: str = "Ant-v4"
-    """Mujoco specific"""
-    render: float | None = 0.0
-    """Milliseconds between rendered frames; if None, no rendering"""
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-    resume_id: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    resume_path: str | None = None
-    """For restoring a model and running means of env-specifics from a checkpoint"""
-    watch: bool = False
-    """If True, will not perform training and only watch the restored policy"""
-    watch_num_episodes = 10
-
-
-@dataclass
-class LoggerConfig:
-    """Logging config."""
-
-    logdir: str = "log"
-    logger: Literal["tensorboard", "wandb"] = "tensorboard"
-    wandb_project: str = "mujoco.benchmark"
-    """Only used if logger is wandb."""
-
-
-@dataclass
-class RLSamplingConfig:
-    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
-
-    num_epochs: int = 100
-    step_per_epoch: int = 30000
-    batch_size: int = 64
-    num_train_envs: int = 64
-    num_test_envs: int = 10
-    buffer_size: int = 4096
-    step_per_collect: int = 2048
-    repeat_per_collect: int = 10
-    update_per_step: int = 1
-
-
-@dataclass
-class RLAgentConfig:
-    """Config common to most RL algorithms."""
-
-    gamma: float = 0.99
-    """Discount factor"""
-    gae_lambda: float = 0.95
-    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
-    action_bound_method: Literal["clip", "tanh"] | None = "clip"
-    """How to map original actions in range (-inf, inf) to [-1, 1]"""
-    rew_norm: bool = True
-    """Whether to normalize rewards"""
-
-
-@dataclass
-class PGConfig:
-    """Config of general policy-gradient algorithms."""
-
-    ent_coef: float = 0.0
-    vf_coef: float = 0.25
-    max_grad_norm: float = 0.5
-
-
-@dataclass
-class PPOConfig:
-    """PPO specific config."""
-
-    value_clip: bool = False
-    norm_adv: bool = False
-    """Whether to normalize advantages"""
-    eps_clip: float = 0.2
-    dual_clip: float | None = None
-    recompute_adv: bool = True
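The two deleted files above are the tianshou.config package (__init__.py and config.py). Its classes are not moved wholesale; they are redistributed across the new tianshou.highlevel modules. The mapping below is inferred from the import changes elsewhere in this commit:

# Where the symbols formerly exported by tianshou.config now live after this commit:
from tianshou.highlevel.agent import PGConfig, PPOConfig, RLAgentConfig, SACConfig
from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig  # RLExperimentConfig replaces BasicExperimentConfig
from tianshou.highlevel.logger import LoggerConfig
# BasicExperimentConfig.task no longer exists as a config field; the task name is now
# passed explicitly (e.g. to MujocoEnvFactory and the example scripts' main()).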
@@ -0,0 +1,3 @@
+from jsonargparse import set_docstring_parse_options
+
+set_docstring_parse_options(attribute_docstrings=True)
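This new package __init__ (the three-line file added above) only enables jsonargparse's attribute-docstring parsing, which is what lets the string literals placed under the config fields show up as CLI help text. A small self-contained illustration; DemoConfig and run are hypothetical names, not part of the commit:

from dataclasses import dataclass

from jsonargparse import CLI, set_docstring_parse_options

set_docstring_parse_options(attribute_docstrings=True)


@dataclass
class DemoConfig:
    gamma: float = 0.99
    """Discount factor"""  # surfaces in --help because attribute docstrings are parsed


def run(config: DemoConfig) -> None:
    print(config)


if __name__ == "__main__":
    CLI(run)  # `python demo.py --help` lists --config.gamma with "Discount factor" as help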
@@ -2,13 +2,14 @@ import os
 from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
+from typing import Literal
 
 import torch
 
-from tianshou.config import PGConfig, PPOConfig, RLAgentConfig, RLSamplingConfig
 from tianshou.data import Collector, ReplayBuffer, VectorReplayBuffer
 from tianshou.exploration import BaseNoise
 from tianshou.highlevel.env import Environments
+from tianshou.highlevel.experiment import RLSamplingConfig
 from tianshou.highlevel.logger import Logger
 from tianshou.highlevel.module import ActorFactory, CriticFactory, TDevice
 from tianshou.highlevel.optim import LRSchedulerFactory, OptimizerFactory
@@ -124,6 +125,41 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
         )
 
 
+@dataclass
+class RLAgentConfig:
+    """Config common to most RL algorithms."""
+
+    gamma: float = 0.99
+    """Discount factor"""
+    gae_lambda: float = 0.95
+    """For Generalized Advantage Estimate (equivalent to TD(lambda))"""
+    action_bound_method: Literal["clip", "tanh"] | None = "clip"
+    """How to map original actions in range (-inf, inf) to [-1, 1]"""
+    rew_norm: bool = True
+    """Whether to normalize rewards"""
+
+
+@dataclass
+class PGConfig:
+    """Config of general policy-gradient algorithms."""
+
+    ent_coef: float = 0.0
+    vf_coef: float = 0.25
+    max_grad_norm: float = 0.5
+
+
+@dataclass
+class PPOConfig:
+    """PPO specific config."""
+
+    value_clip: bool = False
+    norm_adv: bool = False
+    """Whether to normalize advantages"""
+    eps_clip: float = 0.2
+    dual_clip: float | None = None
+    recompute_adv: bool = True
+
+
 class PPOAgentFactory(OnpolicyAgentFactory):
     def __init__(
         self,
@@ -186,10 +222,22 @@ class PPOAgentFactory(OnpolicyAgentFactory):
         )
 
 
+class SACConfig:
+    tau: float = 0.005
+    gamma: float = 0.99
+    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
+    reward_normalization: bool = False
+    estimation_step: int = 1
+    deterministic_eval: bool = True
+    actor_lr: float = 1e-3
+    critic1_lr: float = 1e-3
+    critic2_lr: float = 1e-3
+
+
 class SACAgentFactory(OffpolicyAgentFactory):
     def __init__(
         self,
-        config: "SACAgentFactory.Config",
+        config: SACConfig,
         sampling_config: RLSamplingConfig,
         actor_factory: ActorFactory,
         critic1_factory: CriticFactory,
@@ -227,17 +275,3 @@ class SACAgentFactory(OffpolicyAgentFactory):
             deterministic_eval=self.config.deterministic_eval,
             exploration_noise=self.exploration_noise,
         )
-
-    @dataclass
-    class Config:
-        """SAC configuration."""
-
-        tau: float = 0.005
-        gamma: float = 0.99
-        alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] = 0.2
-        reward_normalization: bool = False
-        estimation_step: int = 1
-        deterministic_eval: bool = True
-        actor_lr: float = 1e-3
-        critic1_lr: float = 1e-3
-        critic2_lr: float = 1e-3
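After this change the algorithm-level configs (RLAgentConfig, PGConfig, PPOConfig, SACConfig) live in tianshou.highlevel.agent beside the factories that consume them. Note that SACConfig is declared without the @dataclass decorator in the hunk above, so unlike the others it cannot be built with keyword arguments. A usage sketch under that constraint; the learning-rate values are illustrative:

from tianshou.highlevel.agent import PPOConfig, SACConfig

ppo_config = PPOConfig(eps_clip=0.1, recompute_adv=True)  # dataclass: keyword construction works

sac_config = SACConfig()  # plain class: override attributes after construction instead
sac_config.actor_lr = 3e-4
sac_config.critic1_lr = 3e-4
sac_config.critic2_lr = 3e-4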
@@ -1,12 +1,10 @@
+from dataclasses import dataclass
 from pprint import pprint
 from typing import Generic, TypeVar
 
 import numpy as np
 import torch
 
-from tianshou.config import (
-    BasicExperimentConfig,
-)
 from tianshou.data import Collector
 from tianshou.highlevel.agent import AgentFactory
 from tianshou.highlevel.env import EnvFactory
@@ -18,10 +16,42 @@ TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
 
 
+@dataclass
+class RLExperimentConfig:
+    """Generic config for setting up the experiment, not RL or training specific."""
+
+    seed: int = 42
+    render: float | None = 0.0
+    """Milliseconds between rendered frames; if None, no rendering"""
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+    resume_id: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    resume_path: str | None = None
+    """For restoring a model and running means of env-specifics from a checkpoint"""
+    watch: bool = False
+    """If True, will not perform training and only watch the restored policy"""
+    watch_num_episodes = 10
+
+
+@dataclass
+class RLSamplingConfig:
+    """Sampling, epochs, parallelization, buffers, collectors, and batching."""
+
+    num_epochs: int = 100
+    step_per_epoch: int = 30000
+    batch_size: int = 64
+    num_train_envs: int = 64
+    num_test_envs: int = 10
+    buffer_size: int = 4096
+    step_per_collect: int = 2048
+    repeat_per_collect: int = 10
+    update_per_step: int = 1
+
+
 class RLExperiment(Generic[TPolicy, TTrainer]):
     def __init__(
         self,
-        config: BasicExperimentConfig,
+        config: RLExperimentConfig,
         env_factory: EnvFactory,
         logger_factory: LoggerFactory,
         agent_factory: AgentFactory,
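RLExperimentConfig is essentially BasicExperimentConfig minus the task field, now defined next to RLExperiment and RLSamplingConfig. A minimal construction sketch; the concrete values are illustrative:

from tianshou.highlevel.experiment import RLExperimentConfig, RLSamplingConfig

experiment_config = RLExperimentConfig(seed=7, device="cpu", watch=False)
sampling_config = RLSamplingConfig(num_epochs=10, step_per_epoch=5000, batch_size=64)
# RLExperiment(config=experiment_config, env_factory=..., logger_factory=..., agent_factory=..., ...)
# The remaining constructor parameters are not shown in this hunk, so they are elided here.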
@@ -1,10 +1,10 @@
 import os
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import Literal
 
 from torch.utils.tensorboard import SummaryWriter
 
-from tianshou.config import LoggerConfig
 from tianshou.utils import TensorboardLogger, WandbLogger
 
 TLogger = TensorboardLogger | WandbLogger
@@ -22,11 +22,21 @@ class LoggerFactory(ABC):
         pass
 
 
+@dataclass
+class LoggerConfig:
+    """Logging config."""
+
+    logdir: str = "log"
+    logger: Literal["tensorboard", "wandb"] = "tensorboard"
+    wandb_project: str = "mujoco.benchmark"
+    """Only used if logger is wandb."""
+
+
 class DefaultLoggerFactory(LoggerFactory):
     def __init__(self, config: LoggerConfig):
         self.config = config
 
-    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
+    def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
         writer = SummaryWriter(self.config.logdir)
         writer.add_text("args", str(self.config))
         if self.config.logger == "wandb":
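LoggerConfig moves next to the logger factory, and the run_id parameter of create_logger is retyped from int | None to str | None (matching, for example, a wandb run id). A usage sketch with illustrative arguments:

from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig

logger_factory = DefaultLoggerFactory(LoggerConfig(logdir="log", logger="tensorboard"))
logger = logger_factory.create_logger(log_name="Ant-v4/ppo/42/run", run_id=None, config_dict={})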
@@ -66,7 +66,7 @@ class ContinuousActorProbFactory(ContinuousActorFactory):
         actor = ActorProb(
             net_a,
             envs.get_action_shape(),
-            unbounded=True,
+            unbounded=self.unbounded,
             device=device,
             conditioned_sigma=self.conditioned_sigma,
         ).to(device)
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from collections.abc import Iterable
-from typing import Any, Type
+from typing import Any
 
 import numpy as np
 import torch
@@ -8,7 +8,7 @@ from torch import Tensor
 from torch.optim import Adam
 from torch.optim.lr_scheduler import LambdaLR, LRScheduler
 
-from tianshou.config import RLSamplingConfig
+from tianshou.highlevel.experiment import RLSamplingConfig
 
 TParams = Iterable[Tensor] | Iterable[dict[str, Any]]
 