From e63d8d41471d561e82c90fcb36d377dd64d0a190 Mon Sep 17 00:00:00 2001 From: Dominik Jain Date: Tue, 17 Oct 2023 12:05:36 +0200 Subject: [PATCH] Use ToStringMixin in dataclasses to detect recurring objects in larger object trees --- tianshou/highlevel/config.py | 6 ++++-- tianshou/highlevel/env.py | 2 +- tianshou/highlevel/params/policy_params.py | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tianshou/highlevel/config.py b/tianshou/highlevel/config.py index e3f74b2..bca77e0 100644 --- a/tianshou/highlevel/config.py +++ b/tianshou/highlevel/config.py @@ -1,11 +1,13 @@ from dataclasses import dataclass +from tianshou.utils.string import ToStringMixin + @dataclass -class SamplingConfig: +class SamplingConfig(ToStringMixin): """Sampling, epochs, parallelization, buffers, collectors, and batching.""" - # TODO: What are reasonable defaults? + # TODO: What are the most reasonable defaults? num_epochs: int = 100 step_per_epoch: int = 30000 batch_size: int = 64 diff --git a/tianshou/highlevel/env.py b/tianshou/highlevel/env.py index 6fdc970..d11eddc 100644 --- a/tianshou/highlevel/env.py +++ b/tianshou/highlevel/env.py @@ -138,7 +138,7 @@ class DiscreteEnvironments(Environments): return EnvType.DISCRETE -class EnvFactory(ABC): +class EnvFactory(ToStringMixin, ABC): @abstractmethod def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: pass diff --git a/tianshou/highlevel/params/policy_params.py b/tianshou/highlevel/params/policy_params.py index 9d8cd53..8d64537 100644 --- a/tianshou/highlevel/params/policy_params.py +++ b/tianshou/highlevel/params/policy_params.py @@ -21,6 +21,7 @@ from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory from tianshou.highlevel.params.noise import NoiseFactory from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.utils import MultipleLRSchedulers +from tianshou.utils.string import ToStringMixin @dataclass(kw_only=True) @@ -227,7 +228,7 @@ class GetParamTransformersProtocol(Protocol): @dataclass -class Params(GetParamTransformersProtocol): +class Params(GetParamTransformersProtocol, ToStringMixin): def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: params = asdict(self) for transformer in self._get_param_transformers():