Remove LoggerConfig
This commit is contained in:
parent
997b520580
commit
adc324038a
@ -15,7 +15,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
RLExperimentConfig,
|
RLExperimentConfig,
|
||||||
RLSamplingConfig,
|
RLSamplingConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
|
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||||
from tianshou.highlevel.module import (
|
from tianshou.highlevel.module import (
|
||||||
ContinuousActorProbFactory,
|
ContinuousActorProbFactory,
|
||||||
ContinuousNetCriticFactory,
|
ContinuousNetCriticFactory,
|
||||||
@ -32,7 +32,6 @@ class NNConfig:
|
|||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: RLExperimentConfig,
|
||||||
logger_config: LoggerConfig,
|
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
general_config: RLAgentConfig,
|
general_config: RLAgentConfig,
|
||||||
pg_config: PGConfig,
|
pg_config: PGConfig,
|
||||||
@ -42,7 +41,7 @@ def main(
|
|||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "ppo", str(experiment_config.seed), now)
|
||||||
logger_factory = DefaultLoggerFactory(logger_config)
|
logger_factory = DefaultLoggerFactory()
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
RLExperimentConfig,
|
RLExperimentConfig,
|
||||||
RLSamplingConfig,
|
RLSamplingConfig,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.logger import DefaultLoggerFactory, LoggerConfig
|
from tianshou.highlevel.logger import DefaultLoggerFactory
|
||||||
from tianshou.highlevel.module import (
|
from tianshou.highlevel.module import (
|
||||||
ContinuousActorProbFactory,
|
ContinuousActorProbFactory,
|
||||||
ContinuousNetCriticFactory,
|
ContinuousNetCriticFactory,
|
||||||
@ -23,7 +23,6 @@ from tianshou.highlevel.optim import AdamOptimizerFactory
|
|||||||
|
|
||||||
def main(
|
def main(
|
||||||
experiment_config: RLExperimentConfig,
|
experiment_config: RLExperimentConfig,
|
||||||
logger_config: LoggerConfig,
|
|
||||||
sampling_config: RLSamplingConfig,
|
sampling_config: RLSamplingConfig,
|
||||||
sac_config: SACConfig,
|
sac_config: SACConfig,
|
||||||
hidden_sizes: Sequence[int] = (256, 256),
|
hidden_sizes: Sequence[int] = (256, 256),
|
||||||
@ -31,7 +30,7 @@ def main(
|
|||||||
):
|
):
|
||||||
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
now = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
|
||||||
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
|
log_name = os.path.join(task, "sac", str(experiment_config.seed), now)
|
||||||
logger_factory = DefaultLoggerFactory(logger_config)
|
logger_factory = DefaultLoggerFactory()
|
||||||
|
|
||||||
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
env_factory = MujocoEnvFactory(task, experiment_config.seed, sampling_config)
|
||||||
|
|
||||||
|
@ -22,36 +22,30 @@ class LoggerFactory(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LoggerConfig:
|
|
||||||
"""Logging config."""
|
|
||||||
|
|
||||||
logdir: str = "log"
|
|
||||||
logger: Literal["tensorboard", "wandb"] = "tensorboard"
|
|
||||||
wandb_project: str = "mujoco.benchmark"
|
|
||||||
"""Only used if logger is wandb."""
|
|
||||||
|
|
||||||
|
|
||||||
class DefaultLoggerFactory(LoggerFactory):
|
class DefaultLoggerFactory(LoggerFactory):
|
||||||
def __init__(self, config: LoggerConfig):
|
def __init__(self, log_dir: str = "log", logger_type: Literal["tensorboard", "wandb"] = "tensorboard", wandb_project: str | None = None):
|
||||||
self.config = config
|
if logger_type == "wandb" and wandb_project is None:
|
||||||
|
raise ValueError("Must provide 'wand_project'")
|
||||||
|
self.log_dir = log_dir
|
||||||
|
self.logger_type = logger_type
|
||||||
|
self.wandb_project = wandb_project
|
||||||
|
|
||||||
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
|
def create_logger(self, log_name: str, run_id: str | None, config_dict: dict) -> Logger:
|
||||||
writer = SummaryWriter(self.config.logdir)
|
writer = SummaryWriter(self.log_dir)
|
||||||
writer.add_text("args", str(self.config))
|
writer.add_text("args", str(dict(log_dir=self.log_dir, logger_type=self.logger_type, wandb_project=self.wandb_project)))
|
||||||
if self.config.logger == "wandb":
|
if self.logger_type == "wandb":
|
||||||
logger = WandbLogger(
|
logger = WandbLogger(
|
||||||
save_interval=1,
|
save_interval=1,
|
||||||
name=log_name.replace(os.path.sep, "__"),
|
name=log_name.replace(os.path.sep, "__"),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
config=config_dict,
|
config=config_dict,
|
||||||
project=self.config.wandb_project,
|
project=self.wandb_project,
|
||||||
)
|
)
|
||||||
logger.load(writer)
|
logger.load(writer)
|
||||||
elif self.config.logger == "tensorboard":
|
elif self.logger_type == "tensorboard":
|
||||||
logger = TensorboardLogger(writer)
|
logger = TensorboardLogger(writer)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown logger: {self.config.logger}")
|
raise ValueError(f"Unknown logger type '{self.logger_type}'")
|
||||||
log_path = os.path.join(self.config.logdir, log_name)
|
log_path = os.path.join(self.log_dir, log_name)
|
||||||
os.makedirs(log_path, exist_ok=True)
|
os.makedirs(log_path, exist_ok=True)
|
||||||
return Logger(logger=logger, log_path=log_path)
|
return Logger(logger=logger, log_path=log_path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user