Use ToStringMixin in dataclasses to detect recurring objects in larger object trees

This commit is contained in:
Dominik Jain 2023-10-17 12:05:36 +02:00
parent d84e936430
commit e63d8d4147
3 changed files with 7 additions and 4 deletions

View File

@ -1,11 +1,13 @@
from dataclasses import dataclass from dataclasses import dataclass
from tianshou.utils.string import ToStringMixin
@dataclass @dataclass
class SamplingConfig: class SamplingConfig(ToStringMixin):
"""Sampling, epochs, parallelization, buffers, collectors, and batching.""" """Sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are reasonable defaults? # TODO: What are the most reasonable defaults?
num_epochs: int = 100 num_epochs: int = 100
step_per_epoch: int = 30000 step_per_epoch: int = 30000
batch_size: int = 64 batch_size: int = 64

View File

@ -138,7 +138,7 @@ class DiscreteEnvironments(Environments):
return EnvType.DISCRETE return EnvType.DISCRETE
class EnvFactory(ABC): class EnvFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments: def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
pass pass

View File

@ -21,6 +21,7 @@ from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils import MultipleLRSchedulers from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin
@dataclass(kw_only=True) @dataclass(kw_only=True)
@ -227,7 +228,7 @@ class GetParamTransformersProtocol(Protocol):
@dataclass @dataclass
class Params(GetParamTransformersProtocol): class Params(GetParamTransformersProtocol, ToStringMixin):
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]: def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
params = asdict(self) params = asdict(self)
for transformer in self._get_param_transformers(): for transformer in self._get_param_transformers():