diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index e3f74b2..bca77e0 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -1,11 +1,13 @@ from dataclasses import dataclass +from tianshou.utils.string import ToStringMixin + @dataclass -class SamplingConfig: +class SamplingConfig(ToStringMixin): """Sampling, epochs, parallelization, buffers, collectors, and batching.""" - # TODO: What are reasonable defaults? + # TODO: What are the most reasonable defaults? num_epochs: int = 100 step_per_epoch: int = 30000 batch_size: int = 64 diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 6fdc970..d11eddc 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -138,7 +138,7 @@ class DiscreteEnvironments(Environments): return EnvType.DISCRETE -class EnvFactory(ABC): +class EnvFactory(ToStringMixin, ABC): @abstractmethod def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: pass diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 9d8cd53..8d64537 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -21,6 +21,7 @@ from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.noise import NoiseFactory from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.string import ToStringMixin @dataclass(kw_only=True) @@ -227,7 +228,7 @@ class GetParamTransformersProtocol(Protocol): @dataclass -class Params(GetParamTransformersProtocol): +class Params(GetParamTransformersProtocol, ToStringMixin): def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: params = asdict(self) for transformer in self._get_param_transformers():