Log full experiment configuration, adding string representations to relevant classes

This commit is contained in:
Dominik Jain 2023-10-03 21:14:22 +02:00
parent 58bd20f882
commit 9f0a410bb1
13 changed files with 33 additions and 15 deletions

View File

@ -21,6 +21,7 @@ from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.highlevel.params.policy_wrapper import ( from tianshou.highlevel.params.policy_wrapper import (
PolicyWrapperFactoryIntrinsicCuriosity, PolicyWrapperFactoryIntrinsicCuriosity,
) )
from tianshou.utils import logging
def main( def main(
@ -114,4 +115,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
CLI(main) logging.run_main(lambda: CLI(main))

View File

@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
from tianshou.highlevel.optim import OptimizerFactoryRMSprop from tianshou.highlevel.optim import OptimizerFactoryRMSprop
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import A2CParams from tianshou.highlevel.params.policy_params import A2CParams
from tianshou.utils import logging
def main( def main(
@ -82,4 +83,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
CLI(main) logging.run_main(lambda: CLI(main))

View File

@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
) )
from tianshou.highlevel.params.noise import MaxActionScaledGaussian from tianshou.highlevel.params.noise import MaxActionScaledGaussian
from tianshou.highlevel.params.policy_params import DDPGParams from tianshou.highlevel.params.policy_params import DDPGParams
from tianshou.utils import logging
def main( def main(
@ -75,4 +76,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
CLI(main) logging.run_main(lambda: CLI(main))

View File

@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
) )
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
from tianshou.highlevel.params.policy_params import PPOParams from tianshou.highlevel.params.policy_params import PPOParams
from tianshou.utils import logging
def main( def main(
@ -95,4 +96,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
CLI(main) logging.run_main(lambda: CLI(main))

View File

@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
) )
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
from tianshou.highlevel.params.policy_params import SACParams from tianshou.highlevel.params.policy_params import SACParams
from tianshou.utils import logging
def main( def main(
@ -81,4 +82,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
CLI(main) logging.run_main(lambda: CLI(main))

View File

@ -17,6 +17,7 @@ from tianshou.highlevel.params.noise import (
MaxActionScaledGaussian, MaxActionScaledGaussian,
) )
from tianshou.highlevel.params.policy_params import TD3Params from tianshou.highlevel.params.policy_params import TD3Params
from tianshou.utils import logging
def main( def main(
@ -84,4 +85,4 @@ def main(
if __name__ == "__main__": if __name__ == "__main__":
CLI(main) logging.run_main(lambda: CLI(main))

View File

@ -41,6 +41,7 @@ from tianshou.policy import (
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
from tianshou.utils.net import continuous, discrete from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin
CHECKPOINT_DICT_KEY_MODEL = "model" CHECKPOINT_DICT_KEY_MODEL = "model"
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms" CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
@ -48,7 +49,7 @@ TParams = TypeVar("TParams", bound=Params)
TPolicy = TypeVar("TPolicy", bound=BasePolicy) TPolicy = TypeVar("TPolicy", bound=BasePolicy)
class AgentFactory(ABC): class AgentFactory(ABC, ToStringMixin):
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory): def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
self.sampling_config = sampling_config self.sampling_config = sampling_config
self.optim_factory = optim_factory self.optim_factory = optim_factory

View File

@ -1,3 +1,4 @@
import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from dataclasses import dataclass from dataclasses import dataclass
@ -37,7 +38,9 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
from tianshou.highlevel.persistence import PersistableConfigProtocol from tianshou.highlevel.persistence import PersistableConfigProtocol
from tianshou.policy import BasePolicy from tianshou.policy import BasePolicy
from tianshou.trainer import BaseTrainer from tianshou.trainer import BaseTrainer
from tianshou.utils.string import ToStringMixin
log = logging.getLogger(__name__)
TPolicy = TypeVar("TPolicy", bound=BasePolicy) TPolicy = TypeVar("TPolicy", bound=BasePolicy)
TTrainer = TypeVar("TTrainer", bound=BaseTrainer) TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
@ -59,7 +62,7 @@ class RLExperimentConfig:
watch_num_episodes = 10 watch_num_episodes = 10
class RLExperiment(Generic[TPolicy, TTrainer]): class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
def __init__( def __init__(
self, self,
config: RLExperimentConfig, config: RLExperimentConfig,
@ -204,13 +207,15 @@ class RLExperimentBuilder:
agent_factory = self._create_agent_factory() agent_factory = self._create_agent_factory()
if self._policy_wrapper_factory: if self._policy_wrapper_factory:
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory) agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
return RLExperiment( experiment = RLExperiment(
self._config, self._config,
self._env_factory, self._env_factory,
agent_factory, agent_factory,
self._logger_factory, self._logger_factory,
env_config=self._env_config, env_config=self._env_config,
) )
log.info(f"Created experiment:\n{experiment.pprints()}")
return experiment
class _BuilderMixinActorFactory: class _BuilderMixinActorFactory:

View File

@ -6,6 +6,7 @@ from typing import Literal, TypeAlias
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import TensorboardLogger, WandbLogger from tianshou.utils import TensorboardLogger, WandbLogger
from tianshou.utils.string import ToStringMixin
TLogger: TypeAlias = TensorboardLogger | WandbLogger TLogger: TypeAlias = TensorboardLogger | WandbLogger
@ -16,7 +17,7 @@ class Logger:
log_path: str log_path: str
class LoggerFactory(ABC): class LoggerFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger: def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
pass pass

View File

@ -8,6 +8,7 @@ from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous, discrete from tianshou.utils.net import continuous, discrete
from tianshou.utils.net.common import BaseActor, Net from tianshou.utils.net.common import BaseActor, Net
from tianshou.utils.string import ToStringMixin
class ContinuousActorType: class ContinuousActorType:
@ -15,7 +16,7 @@ class ContinuousActorType:
DETERMINISTIC = "deterministic" DETERMINISTIC = "deterministic"
class ActorFactory(ABC): class ActorFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> BaseActor: def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
pass pass

View File

@ -7,9 +7,10 @@ from tianshou.highlevel.env import Environments, EnvType
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
from tianshou.utils.net import continuous from tianshou.utils.net import continuous
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin
class CriticFactory(ABC): class CriticFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module: def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
pass pass

View File

@ -8,6 +8,7 @@ from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.module.critic import CriticFactory from tianshou.highlevel.module.critic import CriticFactory
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.net.common import ActorCritic from tianshou.utils.net.common import ActorCritic
from tianshou.utils.string import ToStringMixin
@dataclass @dataclass
@ -30,7 +31,7 @@ class ActorCriticModuleOpt:
return self.actor_critic_module.critic return self.actor_critic_module.critic
class ActorModuleOptFactory: class ActorModuleOptFactory(ToStringMixin):
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory): def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
self.actor_factory = actor_factory self.actor_factory = actor_factory
self.optim_factory = optim_factory self.optim_factory = optim_factory
@ -41,7 +42,7 @@ class ActorModuleOptFactory:
return ModuleOpt(actor, opt) return ModuleOpt(actor, opt)
class CriticModuleOptFactory: class CriticModuleOptFactory(ToStringMixin):
def __init__( def __init__(
self, self,
critic_factory: CriticFactory, critic_factory: CriticFactory,

View File

@ -4,8 +4,10 @@ from typing import Any
import torch import torch
from torch.optim import Adam, RMSprop from torch.optim import Adam, RMSprop
from tianshou.utils.string import ToStringMixin
class OptimizerFactory(ABC):
class OptimizerFactory(ABC, ToStringMixin):
# TODO: Is it OK to assume that all optimizers have a learning rate argument? # TODO: Is it OK to assume that all optimizers have a learning rate argument?
# Right now, the learning rate is typically a configuration parameter. # Right now, the learning rate is typically a configuration parameter.
# If we drop the assumption, we can't have that and will need to move the parameter # If we drop the assumption, we can't have that and will need to move the parameter