Log full experiment configuration, adding string representations to relevant classes
This commit is contained in:
parent
58bd20f882
commit
9f0a410bb1
@ -21,6 +21,7 @@ from tianshou.highlevel.params.policy_params import PPOParams
|
||||
from tianshou.highlevel.params.policy_wrapper import (
|
||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||
)
|
||||
from tianshou.utils import logging
|
||||
|
||||
|
||||
def main(
|
||||
@ -114,4 +115,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI(main)
|
||||
logging.run_main(lambda: CLI(main))
|
||||
|
@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
|
||||
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import A2CParams
|
||||
from tianshou.utils import logging
|
||||
|
||||
|
||||
def main(
|
||||
@ -82,4 +83,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI(main)
|
||||
logging.run_main(lambda: CLI(main))
|
||||
|
@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
|
||||
)
|
||||
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
||||
from tianshou.highlevel.params.policy_params import DDPGParams
|
||||
from tianshou.utils import logging
|
||||
|
||||
|
||||
def main(
|
||||
@ -75,4 +76,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI(main)
|
||||
logging.run_main(lambda: CLI(main))
|
||||
|
@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
|
||||
)
|
||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||
from tianshou.highlevel.params.policy_params import PPOParams
|
||||
from tianshou.utils import logging
|
||||
|
||||
|
||||
def main(
|
||||
@ -95,4 +96,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI(main)
|
||||
logging.run_main(lambda: CLI(main))
|
||||
|
@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
|
||||
)
|
||||
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
||||
from tianshou.highlevel.params.policy_params import SACParams
|
||||
from tianshou.utils import logging
|
||||
|
||||
|
||||
def main(
|
||||
@ -81,4 +82,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI(main)
|
||||
logging.run_main(lambda: CLI(main))
|
||||
|
@ -17,6 +17,7 @@ from tianshou.highlevel.params.noise import (
|
||||
MaxActionScaledGaussian,
|
||||
)
|
||||
from tianshou.highlevel.params.policy_params import TD3Params
|
||||
from tianshou.utils import logging
|
||||
|
||||
|
||||
def main(
|
||||
@ -84,4 +85,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
CLI(main)
|
||||
logging.run_main(lambda: CLI(main))
|
||||
|
@ -41,6 +41,7 @@ from tianshou.policy import (
|
||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||
from tianshou.utils.net import continuous, discrete
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||
@ -48,7 +49,7 @@ TParams = TypeVar("TParams", bound=Params)
|
||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||
|
||||
|
||||
class AgentFactory(ABC):
|
||||
class AgentFactory(ABC, ToStringMixin):
|
||||
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
|
||||
self.sampling_config = sampling_config
|
||||
self.optim_factory = optim_factory
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
@ -37,7 +38,9 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||
from tianshou.policy import BasePolicy
|
||||
from tianshou.trainer import BaseTrainer
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
|
||||
|
||||
@ -59,7 +62,7 @@ class RLExperimentConfig:
|
||||
watch_num_episodes = 10
|
||||
|
||||
|
||||
class RLExperiment(Generic[TPolicy, TTrainer]):
|
||||
class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||
def __init__(
|
||||
self,
|
||||
config: RLExperimentConfig,
|
||||
@ -204,13 +207,15 @@ class RLExperimentBuilder:
|
||||
agent_factory = self._create_agent_factory()
|
||||
if self._policy_wrapper_factory:
|
||||
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
||||
return RLExperiment(
|
||||
experiment = RLExperiment(
|
||||
self._config,
|
||||
self._env_factory,
|
||||
agent_factory,
|
||||
self._logger_factory,
|
||||
env_config=self._env_config,
|
||||
)
|
||||
log.info(f"Created experiment:\n{experiment.pprints()}")
|
||||
return experiment
|
||||
|
||||
|
||||
class _BuilderMixinActorFactory:
|
||||
|
@ -6,6 +6,7 @@ from typing import Literal, TypeAlias
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TLogger: TypeAlias = TensorboardLogger | WandbLogger
|
||||
|
||||
@ -16,7 +17,7 @@ class Logger:
|
||||
log_path: str
|
||||
|
||||
|
||||
class LoggerFactory(ABC):
|
||||
class LoggerFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
|
||||
pass
|
||||
|
@ -8,6 +8,7 @@ from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
||||
from tianshou.utils.net import continuous, discrete
|
||||
from tianshou.utils.net.common import BaseActor, Net
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class ContinuousActorType:
|
||||
@ -15,7 +16,7 @@ class ContinuousActorType:
|
||||
DETERMINISTIC = "deterministic"
|
||||
|
||||
|
||||
class ActorFactory(ABC):
|
||||
class ActorFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||
pass
|
||||
|
@ -7,9 +7,10 @@ from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
||||
from tianshou.utils.net import continuous
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class CriticFactory(ABC):
|
||||
class CriticFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||
pass
|
||||
|
@ -8,6 +8,7 @@ from tianshou.highlevel.module.core import TDevice
|
||||
from tianshou.highlevel.module.critic import CriticFactory
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.utils.net.common import ActorCritic
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -30,7 +31,7 @@ class ActorCriticModuleOpt:
|
||||
return self.actor_critic_module.critic
|
||||
|
||||
|
||||
class ActorModuleOptFactory:
|
||||
class ActorModuleOptFactory(ToStringMixin):
|
||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
||||
self.actor_factory = actor_factory
|
||||
self.optim_factory = optim_factory
|
||||
@ -41,7 +42,7 @@ class ActorModuleOptFactory:
|
||||
return ModuleOpt(actor, opt)
|
||||
|
||||
|
||||
class CriticModuleOptFactory:
|
||||
class CriticModuleOptFactory(ToStringMixin):
|
||||
def __init__(
|
||||
self,
|
||||
critic_factory: CriticFactory,
|
||||
|
@ -4,8 +4,10 @@ from typing import Any
|
||||
import torch
|
||||
from torch.optim import Adam, RMSprop
|
||||
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
class OptimizerFactory(ABC):
|
||||
|
||||
class OptimizerFactory(ABC, ToStringMixin):
|
||||
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
|
||||
# Right now, the learning rate is typically a configuration parameter.
|
||||
# If we drop the assumption, we can't have that and will need to move the parameter
|
||||
|
Loading…
x
Reference in New Issue
Block a user