Log full experiment configuration, adding string representations to relevant classes

This commit is contained in:
Dominik Jain 2023-10-03 21:14:22 +02:00
parent 58bd20f882
commit 9f0a410bb1
13 changed files with 33 additions and 15 deletions

View File

@ -21,6 +21,7 @@ from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity,
)
from tianshou.utils import logging
def main(
@ -114,4 +115,4 @@ def main(
if __name__ == "__main__":
CLI(main)
logging.run_main(lambda: CLI(main))

View File

@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import A2CParams
from tianshou.utils import logging
def main(
@ -82,4 +83,4 @@ def main(
if __name__ == "__main__":
CLI(main)
logging.run_main(lambda: CLI(main))

View File

@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
)
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
from tianshou.highlevel.params.policy_params import DDPGParams
from tianshou.utils import logging
def main(
@ -75,4 +76,4 @@ def main(
if __name__ == "__main__":
CLI(main)
logging.run_main(lambda: CLI(main))

View File

@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
)
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
def main(
@ -95,4 +96,4 @@ def main(
if __name__ == "__main__":
CLI(main)
logging.run_main(lambda: CLI(main))

View File

@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
)
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils import logging
def main(
@ -81,4 +82,4 @@ def main(
if __name__ == "__main__":
CLI(main)
logging.run_main(lambda: CLI(main))

View File

@ -17,6 +17,7 @@ from tianshou.highlevel.params.noise import (
MaxActionScaledGaussian,
)
from tianshou.highlevel.params.policy_params import TD3Params
from tianshou.utils import logging
def main(
@ -84,4 +85,4 @@ def main(
if __name__ == "__main__":
CLI(main)
logging.run_main(lambda: CLI(main))

View File

@ -41,6 +41,7 @@ from tianshou.policy import (
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
@ -48,7 +49,7 @@ TParams = TypeVar("TParams", bound=Params)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
class AgentFactory(ABC):
class AgentFactory(ABC, ToStringMixin):
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
self.sampling_config = sampling_config
self.optim_factory = optim_factory

View File

@ -1,3 +1,4 @@
import logging
from abc import abstractmethod
from collections.abc import Callable, Sequence
from dataclasses import dataclass
@ -37,7 +38,9 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer
from tianshou.utils.string import ToStringMixin
log = logging.getLogger(__name__)
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
@ -59,7 +62,7 @@ class RLExperimentConfig:
watch_num_episodes = 10
class RLExperiment(Generic[TPolicy, TTrainer]):
class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
def __init__(
self,
config: RLExperimentConfig,
@ -204,13 +207,15 @@ class RLExperimentBuilder:
agent_factory = self._create_agent_factory()
if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
return RLExperiment(
experiment = RLExperiment(
self._config,
self._env_factory,
agent_factory,
self._logger_factory,
env_config=self._env_config,
)
log.info(f"Created experiment:\n{experiment.pprints()}")
return experiment
class _BuilderMixinActorFactory:

View File

@ -6,6 +6,7 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.string import ToStringMixin
TLogger: TypeAlias = TensorboardLogger | WandbLogger
@ -16,7 +17,7 @@ class Logger:
log_path: str
class LoggerFactory(ABC):
class LoggerFactory(ToStringMixin, ABC):
@abstractmethod
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
pass

View File

@ -8,6 +8,7 @@ from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, Net
from tianshou.utils.string import ToStringMixin
class ContinuousActorType:
@ -15,7 +16,7 @@ class ContinuousActorType:
DETERMINISTIC = "deterministic"
class ActorFactory(ABC):
class ActorFactory(ToStringMixin, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
pass

View File

@ -7,9 +7,10 @@ from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous
from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin
class CriticFactory(ABC):
class CriticFactory(ToStringMixin, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass

View File

@ -8,6 +8,7 @@ from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin
@dataclass
@ -30,7 +31,7 @@ class ActorCriticModuleOpt:
return self.actor_critic_module.critic
class ActorModuleOptFactory:
class ActorModuleOptFactory(ToStringMixin):
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory
self.optim_factory = optim_factory
@ -41,7 +42,7 @@ class ActorModuleOptFactory:
return ModuleOpt(actor, opt)
class CriticModuleOptFactory:
class CriticModuleOptFactory(ToStringMixin):
def __init__(
self,
critic_factory: CriticFactory,

View File

@ -4,8 +4,10 @@ from typing import Any
import torch
from torch.optim import Adam, RMSprop
from tianshou.utils.string import ToStringMixin
class OptimizerFactory(ABC):
class OptimizerFactory(ABC, ToStringMixin):
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
# Right now, the learning rate is typically a configuration parameter.
# If we drop the assumption, we can't have that and will need to move the parameter