Use ToStringMixin in dataclasses to detect recurring objects in larger object trees

This commit is contained in:
Dominik Jain 2023-10-17 12:05:36 +02:00
parent d84e936430
commit e63d8d4147
3 changed files with 7 additions and 4 deletions

View File

@ -1,11 +1,13 @@
from dataclasses import dataclass
from tianshou.utils.string import ToStringMixin
@dataclass
class SamplingConfig:
class SamplingConfig(ToStringMixin):
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
# TODO: What are reasonable defaults?
# TODO: What are the most reasonable defaults?
num_epochs: int = 100
step_per_epoch: int = 30000
batch_size: int = 64

View File

@ -138,7 +138,7 @@ class DiscreteEnvironments(Environments):
return EnvType.DISCRETE
class EnvFactory(ABC):
class EnvFactory(ToStringMixin, ABC):
@abstractmethod
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
pass

View File

@ -21,6 +21,7 @@ from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
from tianshou.highlevel.params.noise import NoiseFactory
from tianshou.policy.modelfree.pg import TDistributionFunction
from tianshou.utils import MultipleLRSchedulers
from tianshou.utils.string import ToStringMixin
@dataclass(kw_only=True)
@ -227,7 +228,7 @@ class GetParamTransformersProtocol(Protocol):
@dataclass
class Params(GetParamTransformersProtocol):
class Params(GetParamTransformersProtocol, ToStringMixin):
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
params = asdict(self)
for transformer in self._get_param_transformers():