Log full experiment configuration, adding string representations to relevant classes
This commit is contained in:
		
							parent
							
								
									58bd20f882
								
							
						
					
					
						commit
						9f0a410bb1
					
				@ -21,6 +21,7 @@ from tianshou.highlevel.params.policy_params import PPOParams
 | 
			
		||||
from tianshou.highlevel.params.policy_wrapper import (
 | 
			
		||||
    PolicyWrapperFactoryIntrinsicCuriosity,
 | 
			
		||||
)
 | 
			
		||||
from tianshou.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(
 | 
			
		||||
@ -114,4 +115,4 @@ def main(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    CLI(main)
 | 
			
		||||
    logging.run_main(lambda: CLI(main))
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
 | 
			
		||||
from tianshou.highlevel.optim import OptimizerFactoryRMSprop
 | 
			
		||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 | 
			
		||||
from tianshou.highlevel.params.policy_params import A2CParams
 | 
			
		||||
from tianshou.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(
 | 
			
		||||
@ -82,4 +83,4 @@ def main(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    CLI(main)
 | 
			
		||||
    logging.run_main(lambda: CLI(main))
 | 
			
		||||
 | 
			
		||||
@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
 | 
			
		||||
)
 | 
			
		||||
from tianshou.highlevel.params.noise import MaxActionScaledGaussian
 | 
			
		||||
from tianshou.highlevel.params.policy_params import DDPGParams
 | 
			
		||||
from tianshou.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(
 | 
			
		||||
@ -75,4 +76,4 @@ def main(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    CLI(main)
 | 
			
		||||
    logging.run_main(lambda: CLI(main))
 | 
			
		||||
 | 
			
		||||
@ -16,6 +16,7 @@ from tianshou.highlevel.experiment import (
 | 
			
		||||
)
 | 
			
		||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactoryLinear
 | 
			
		||||
from tianshou.highlevel.params.policy_params import PPOParams
 | 
			
		||||
from tianshou.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(
 | 
			
		||||
@ -95,4 +96,4 @@ def main(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    CLI(main)
 | 
			
		||||
    logging.run_main(lambda: CLI(main))
 | 
			
		||||
 | 
			
		||||
@ -14,6 +14,7 @@ from tianshou.highlevel.experiment import (
 | 
			
		||||
)
 | 
			
		||||
from tianshou.highlevel.params.alpha import AutoAlphaFactoryDefault
 | 
			
		||||
from tianshou.highlevel.params.policy_params import SACParams
 | 
			
		||||
from tianshou.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(
 | 
			
		||||
@ -81,4 +82,4 @@ def main(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    CLI(main)
 | 
			
		||||
    logging.run_main(lambda: CLI(main))
 | 
			
		||||
 | 
			
		||||
@ -17,6 +17,7 @@ from tianshou.highlevel.params.noise import (
 | 
			
		||||
    MaxActionScaledGaussian,
 | 
			
		||||
)
 | 
			
		||||
from tianshou.highlevel.params.policy_params import TD3Params
 | 
			
		||||
from tianshou.utils import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main(
 | 
			
		||||
@ -84,4 +85,4 @@ def main(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    CLI(main)
 | 
			
		||||
    logging.run_main(lambda: CLI(main))
 | 
			
		||||
 | 
			
		||||
@ -41,6 +41,7 @@ from tianshou.policy import (
 | 
			
		||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
 | 
			
		||||
from tianshou.utils.net import continuous, discrete
 | 
			
		||||
from tianshou.utils.net.common import ActorCritic
 | 
			
		||||
from tianshou.utils.string import ToStringMixin
 | 
			
		||||
 | 
			
		||||
CHECKPOINT_DICT_KEY_MODEL = "model"
 | 
			
		||||
CHECKPOINT_DICT_KEY_OBS_RMS = "obs_rms"
 | 
			
		||||
@ -48,7 +49,7 @@ TParams = TypeVar("TParams", bound=Params)
 | 
			
		||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AgentFactory(ABC):
 | 
			
		||||
class AgentFactory(ABC, ToStringMixin):
 | 
			
		||||
    def __init__(self, sampling_config: RLSamplingConfig, optim_factory: OptimizerFactory):
 | 
			
		||||
        self.sampling_config = sampling_config
 | 
			
		||||
        self.optim_factory = optim_factory
 | 
			
		||||
 | 
			
		||||
@ -1,3 +1,4 @@
 | 
			
		||||
import logging
 | 
			
		||||
from abc import abstractmethod
 | 
			
		||||
from collections.abc import Callable, Sequence
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
@ -37,7 +38,9 @@ from tianshou.highlevel.params.policy_wrapper import PolicyWrapperFactory
 | 
			
		||||
from tianshou.highlevel.persistence import PersistableConfigProtocol
 | 
			
		||||
from tianshou.policy import BasePolicy
 | 
			
		||||
from tianshou.trainer import BaseTrainer
 | 
			
		||||
from tianshou.utils.string import ToStringMixin
 | 
			
		||||
 | 
			
		||||
log = logging.getLogger(__name__)
 | 
			
		||||
TPolicy = TypeVar("TPolicy", bound=BasePolicy)
 | 
			
		||||
TTrainer = TypeVar("TTrainer", bound=BaseTrainer)
 | 
			
		||||
 | 
			
		||||
@ -59,7 +62,7 @@ class RLExperimentConfig:
 | 
			
		||||
    watch_num_episodes = 10
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class RLExperiment(Generic[TPolicy, TTrainer]):
 | 
			
		||||
class RLExperiment(Generic[TPolicy, TTrainer], ToStringMixin):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        config: RLExperimentConfig,
 | 
			
		||||
@ -204,13 +207,15 @@ class RLExperimentBuilder:
 | 
			
		||||
        agent_factory = self._create_agent_factory()
 | 
			
		||||
        if self._policy_wrapper_factory:
 | 
			
		||||
            agent_factory.set_policy_wrapper_factory(self._policy_wrapper_factory)
 | 
			
		||||
        return RLExperiment(
 | 
			
		||||
        experiment = RLExperiment(
 | 
			
		||||
            self._config,
 | 
			
		||||
            self._env_factory,
 | 
			
		||||
            agent_factory,
 | 
			
		||||
            self._logger_factory,
 | 
			
		||||
            env_config=self._env_config,
 | 
			
		||||
        )
 | 
			
		||||
        log.info(f"Created experiment:\n{experiment.pprints()}")
 | 
			
		||||
        return experiment
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _BuilderMixinActorFactory:
 | 
			
		||||
 | 
			
		||||
@ -6,6 +6,7 @@ from typing import Literal, TypeAlias
 | 
			
		||||
from torch.utils.tensorboard import SummaryWriter
 | 
			
		||||
 | 
			
		||||
from tianshou.utils import TensorboardLogger, WandbLogger
 | 
			
		||||
from tianshou.utils.string import ToStringMixin
 | 
			
		||||
 | 
			
		||||
TLogger: TypeAlias = TensorboardLogger | WandbLogger
 | 
			
		||||
 | 
			
		||||
@ -16,7 +17,7 @@ class Logger:
 | 
			
		||||
    log_path: str
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LoggerFactory(ABC):
 | 
			
		||||
class LoggerFactory(ToStringMixin, ABC):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def create_logger(self, log_name: str, run_id: int | None, config_dict: dict) -> Logger:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@ from tianshou.highlevel.env import Environments, EnvType
 | 
			
		||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
 | 
			
		||||
from tianshou.utils.net import continuous, discrete
 | 
			
		||||
from tianshou.utils.net.common import BaseActor, Net
 | 
			
		||||
from tianshou.utils.string import ToStringMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ContinuousActorType:
 | 
			
		||||
@ -15,7 +16,7 @@ class ContinuousActorType:
 | 
			
		||||
    DETERMINISTIC = "deterministic"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActorFactory(ABC):
 | 
			
		||||
class ActorFactory(ToStringMixin, ABC):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def create_module(self, envs: Environments, device: TDevice) -> BaseActor:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
@ -7,9 +7,10 @@ from tianshou.highlevel.env import Environments, EnvType
 | 
			
		||||
from tianshou.highlevel.module.core import TDevice, init_linear_orthogonal
 | 
			
		||||
from tianshou.utils.net import continuous
 | 
			
		||||
from tianshou.utils.net.common import Net
 | 
			
		||||
from tianshou.utils.string import ToStringMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CriticFactory(ABC):
 | 
			
		||||
class CriticFactory(ToStringMixin, ABC):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def create_module(self, envs: Environments, device: TDevice, use_action: bool) -> nn.Module:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@ from tianshou.highlevel.module.core import TDevice
 | 
			
		||||
from tianshou.highlevel.module.critic import CriticFactory
 | 
			
		||||
from tianshou.highlevel.optim import OptimizerFactory
 | 
			
		||||
from tianshou.utils.net.common import ActorCritic
 | 
			
		||||
from tianshou.utils.string import ToStringMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -30,7 +31,7 @@ class ActorCriticModuleOpt:
 | 
			
		||||
        return self.actor_critic_module.critic
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ActorModuleOptFactory:
 | 
			
		||||
class ActorModuleOptFactory(ToStringMixin):
 | 
			
		||||
    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
 | 
			
		||||
        self.actor_factory = actor_factory
 | 
			
		||||
        self.optim_factory = optim_factory
 | 
			
		||||
@ -41,7 +42,7 @@ class ActorModuleOptFactory:
 | 
			
		||||
        return ModuleOpt(actor, opt)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CriticModuleOptFactory:
 | 
			
		||||
class CriticModuleOptFactory(ToStringMixin):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        critic_factory: CriticFactory,
 | 
			
		||||
 | 
			
		||||
@ -4,8 +4,10 @@ from typing import Any
 | 
			
		||||
import torch
 | 
			
		||||
from torch.optim import Adam, RMSprop
 | 
			
		||||
 | 
			
		||||
from tianshou.utils.string import ToStringMixin
 | 
			
		||||
 | 
			
		||||
class OptimizerFactory(ABC):
 | 
			
		||||
 | 
			
		||||
class OptimizerFactory(ABC, ToStringMixin):
 | 
			
		||||
    # TODO: Is it OK to assume that all optimizers have a learning rate argument?
 | 
			
		||||
    # Right now, the learning rate is typically a configuration parameter.
 | 
			
		||||
    # If we drop the assumption, we can't have that and will need to move the parameter
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user