Use ToStringMixin in dataclasses to detect recurring objects in larger object trees
This commit is contained in:
parent
d84e936430
commit
e63d8d4147
@ -1,11 +1,13 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SamplingConfig:
|
class SamplingConfig(ToStringMixin):
|
||||||
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
"""Sampling, epochs, parallelization, buffers, collectors, and batching."""
|
||||||
|
|
||||||
# TODO: What are reasonable defaults?
|
# TODO: What are the most reasonable defaults?
|
||||||
num_epochs: int = 100
|
num_epochs: int = 100
|
||||||
step_per_epoch: int = 30000
|
step_per_epoch: int = 30000
|
||||||
batch_size: int = 64
|
batch_size: int = 64
|
||||||
|
@ -138,7 +138,7 @@ class DiscreteEnvironments(Environments):
|
|||||||
return EnvType.DISCRETE
|
return EnvType.DISCRETE
|
||||||
|
|
||||||
|
|
||||||
class EnvFactory(ABC):
|
class EnvFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
def create_envs(self, config: PersistableConfigProtocol | None = None) -> Environments:
|
||||||
pass
|
pass
|
||||||
|
@ -21,6 +21,7 @@ from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
|
|||||||
from tianshou.highlevel.params.noise import NoiseFactory
|
from tianshou.highlevel.params.noise import NoiseFactory
|
||||||
from tianshou.policy.modelfree.pg import TDistributionFunction
|
from tianshou.policy.modelfree.pg import TDistributionFunction
|
||||||
from tianshou.utils import MultipleLRSchedulers
|
from tianshou.utils import MultipleLRSchedulers
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
@dataclass(kw_only=True)
|
@dataclass(kw_only=True)
|
||||||
@ -227,7 +228,7 @@ class GetParamTransformersProtocol(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Params(GetParamTransformersProtocol):
|
class Params(GetParamTransformersProtocol, ToStringMixin):
|
||||||
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
|
def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
|
||||||
params = asdict(self)
|
params = asdict(self)
|
||||||
for transformer in self._get_param_transformers():
|
for transformer in self._get_param_transformers():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user