Log full experiment configuration, adding string representations to relevant classes
This commit is contained in:
parent
58bd20f882
commit
9f0a410bb1
@ -21,6 +21,7 @@ from tianshou.highlevel.params.policy_params import PPOParams
|
|||||||
from tianshou.highlevel.params.policy_wrapper import (
|
from tianshou.highlevel.params.policy_wrapper import (
|
||||||
PolicyWrapperFactoryIntrinsicCuriosity,
|
PolicyWrapperFactoryIntrinsicCuriosity,
|
||||||
)
|
)
|
||||||
|
from tianshou.utils import logging
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -114,4 +115,4 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CLI(main)
|
logging.run_main(lambda: CLI(main))
|
||||||
|
@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import A2CParams
|
from tianshou.highlevel.params.policy_params import A2CParams
|
||||||
|
from tianshou.utils import logging
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -82,4 +83,4 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CLI(main)
|
logging.run_main(lambda: CLI(main))
|
||||||
|
@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
|
||||||
from tianshou.highlevel.params.policy_params import DDPGParams
|
from tianshou.highlevel.params.policy_params import DDPGParams
|
||||||
|
from tianshou.utils import logging
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -75,4 +76,4 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CLI(main)
|
logging.run_main(lambda: CLI(main))
|
||||||
|
@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
|
||||||
from tianshou.highlevel.params.policy_params import PPOParams
|
from tianshou.highlevel.params.policy_params import PPOParams
|
||||||
|
from tianshou.utils import logging
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -95,4 +96,4 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CLI(main)
|
logging.run_main(lambda: CLI(main))
|
||||||
|
@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
|
|||||||
)
|
)
|
||||||
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
|
||||||
from tianshou.highlevel.params.policy_params import SACParams
|
from tianshou.highlevel.params.policy_params import SACParams
|
||||||
|
from tianshou.utils import logging
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -81,4 +82,4 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CLI(main)
|
logging.run_main(lambda: CLI(main))
|
||||||
|
@ -17,6 +17,7 @@ from tianshou.highlevel.params.noise import (
|
|||||||
MaxActionScaledGaussian,
|
MaxActionScaledGaussian,
|
||||||
)
|
)
|
||||||
from tianshou.highlevel.params.policy_params import TD3Params
|
from tianshou.highlevel.params.policy_params import TD3Params
|
||||||
|
from tianshou.utils import logging
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@ -84,4 +85,4 @@ def main(
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
CLI(main)
|
logging.run_main(lambda: CLI(main))
|
||||||
|
@ -41,6 +41,7 @@ from tianshou.policy import (
|
|||||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import ActorCritic
|
from tianshou.utils.net.common import ActorCritic
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
CHECKPOINT_DICT_KEY_MODEL = "model"
|
CHECKPOINT_DICT_KEY_MODEL = "model"
|
||||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
|
||||||
@ -48,7 +49,7 @@ TParams = TypeVar("TParams", bound=Params)
|
|||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
|
|
||||||
|
|
||||||
class AgentFactory(ABC):
|
class AgentFactory(ABC, ToStringMixin):
|
||||||
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
|
def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
|
||||||
self.sampling_config = sampling_config
|
self.sampling_config = sampling_config
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -37,7 +38,9 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
|
|||||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
from tianshou.highlevel.persistence import PersistableConfigProtocol
|
||||||
from tianshou.policy import BasePolicy
|
from tianshou.policy import BasePolicy
|
||||||
from tianshou.trainer import BaseTrainer
|
from tianshou.trainer import BaseTrainer
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
|
||||||
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
|
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
|
||||||
|
|
||||||
@ -59,7 +62,7 @@ class RLExperimentConfig:
|
|||||||
watch_num_episodes = 10
|
watch_num_episodes = 10
|
||||||
|
|
||||||
|
|
||||||
class RLExperiment(Generic[TPolicy, TTrainer]):
|
class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: RLExperimentConfig,
|
config: RLExperimentConfig,
|
||||||
@ -204,13 +207,15 @@ class RLExperimentBuilder:
|
|||||||
agent_factory = self._create_agent_factory()
|
agent_factory = self._create_agent_factory()
|
||||||
if self._policy_wrapper_factory:
|
if self._policy_wrapper_factory:
|
||||||
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
|
||||||
return RLExperiment(
|
experiment = RLExperiment(
|
||||||
self._config,
|
self._config,
|
||||||
self._env_factory,
|
self._env_factory,
|
||||||
agent_factory,
|
agent_factory,
|
||||||
self._logger_factory,
|
self._logger_factory,
|
||||||
env_config=self._env_config,
|
env_config=self._env_config,
|
||||||
)
|
)
|
||||||
|
log.info(f"Created experiment:\n{experiment.pprints()}")
|
||||||
|
return experiment
|
||||||
|
|
||||||
|
|
||||||
class _BuilderMixinActorFactory:
|
class _BuilderMixinActorFactory:
|
||||||
|
@ -6,6 +6,7 @@ from typing import Literal, TypeAlias
|
|||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
from tianshou.utils import TensorboardLogger, WandbLogger
|
from tianshou.utils import TensorboardLogger, WandbLogger
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TLogger: TypeAlias = TensorboardLogger | WandbLogger
|
TLogger: TypeAlias = TensorboardLogger | WandbLogger
|
||||||
|
|
||||||
@ -16,7 +17,7 @@ class Logger:
|
|||||||
log_path: str
|
log_path: str
|
||||||
|
|
||||||
|
|
||||||
class LoggerFactory(ABC):
|
class LoggerFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
|
def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
|
||||||
pass
|
pass
|
||||||
|
@ -8,6 +8,7 @@ from tianshou.highlevel.env import Environments, EnvType
|
|||||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
||||||
from tianshou.utils.net import continuous, discrete
|
from tianshou.utils.net import continuous, discrete
|
||||||
from tianshou.utils.net.common import BaseActor, Net
|
from tianshou.utils.net.common import BaseActor, Net
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
class ContinuousActorType:
|
class ContinuousActorType:
|
||||||
@ -15,7 +16,7 @@ class ContinuousActorType:
|
|||||||
DETERMINISTIC = "deterministic"
|
DETERMINISTIC = "deterministic"
|
||||||
|
|
||||||
|
|
||||||
class ActorFactory(ABC):
|
class ActorFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
|
||||||
pass
|
pass
|
||||||
|
@ -7,9 +7,10 @@ from tianshou.highlevel.env import Environments, EnvType
|
|||||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
|
||||||
from tianshou.utils.net import continuous
|
from tianshou.utils.net import continuous
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
class CriticFactory(ABC):
|
class CriticFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
|
||||||
pass
|
pass
|
||||||
|
@ -8,6 +8,7 @@ from tianshou.highlevel.module.core import TDevice
|
|||||||
from tianshou.highlevel.module.critic import CriticFactory
|
from tianshou.highlevel.module.critic import CriticFactory
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.utils.net.common import ActorCritic
|
from tianshou.utils.net.common import ActorCritic
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -30,7 +31,7 @@ class ActorCriticModuleOpt:
|
|||||||
return self.actor_critic_module.critic
|
return self.actor_critic_module.critic
|
||||||
|
|
||||||
|
|
||||||
class ActorModuleOptFactory:
|
class ActorModuleOptFactory(ToStringMixin):
|
||||||
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
|
||||||
self.actor_factory = actor_factory
|
self.actor_factory = actor_factory
|
||||||
self.optim_factory = optim_factory
|
self.optim_factory = optim_factory
|
||||||
@ -41,7 +42,7 @@ class ActorModuleOptFactory:
|
|||||||
return ModuleOpt(actor, opt)
|
return ModuleOpt(actor, opt)
|
||||||
|
|
||||||
|
|
||||||
class CriticModuleOptFactory:
|
class CriticModuleOptFactory(ToStringMixin):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
critic_factory: CriticFactory,
|
critic_factory: CriticFactory,
|
||||||
|
@ -4,8 +4,10 @@ from typing import Any
|
|||||||
import torch
|
import torch
|
||||||
from torch.optim import Adam, RMSprop
|
from torch.optim import Adam, RMSprop
|
||||||
|
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
class OptimizerFactory(ABC):
|
|
||||||
|
class OptimizerFactory(ABC, ToStringMixin):
|
||||||
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
|
# TODO: Is it OK to assume that all optimizers have a learning rate argument?
|
||||||
# Right now, the learning rate is typically a configuration parameter.
|
# Right now, the learning rate is typically a configuration parameter.
|
||||||
# If we drop the assumption, we can't have that and will need to move the parameter
|
# If we drop the assumption, we can't have that and will need to move the parameter
|
||||||
|
Loading…
x
Reference in New Issue
Block a user