Add ToStringMixin to further high-level parameter classes

This commit is contained in:
Dominik Jain 2023-10-05 13:15:24 +02:00
parent 8f67c2e9d9
commit 358978c65d
7 changed files with 14 additions and 7 deletions

View File

@ -8,6 +8,7 @@ import torch
from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin
TDevice: TypeAlias = str | int | torch.device
@ -29,7 +30,7 @@ class Module:
output_dim: int
class ModuleFactory(ABC):
class ModuleFactory(ToStringMixin, ABC):
@abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> Module:
pass

View File

@ -6,9 +6,10 @@ import torch
from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.string import ToStringMixin
class AutoAlphaFactory(ABC):
class AutoAlphaFactory(ToStringMixin, ABC):
@abstractmethod
def create_auto_alpha(
self,

View File

@ -6,11 +6,12 @@ import torch
from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistParams
from tianshou.utils.string import ToStringMixin
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
class DistributionFunctionFactory(ABC):
class DistributionFunctionFactory(ToStringMixin, ABC):
@abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
pass

View File

@ -3,12 +3,13 @@ from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from tianshou.highlevel.env import ContinuousEnvironments, Environments
from tianshou.utils.string import ToStringMixin
TValue = TypeVar("TValue")
TEnvs = TypeVar("TEnvs", bound=Environments)
class EnvValueFactory(Generic[TValue, TEnvs], ABC):
class EnvValueFactory(Generic[TValue, TEnvs], ToStringMixin, ABC):
@abstractmethod
def create_value(self, envs: TEnvs) -> TValue:
pass

View File

@ -5,9 +5,10 @@ import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.highlevel.config import RLSamplingConfig
from tianshou.utils.string import ToStringMixin
class LRSchedulerFactory(ABC):
class LRSchedulerFactory(ToStringMixin, ABC):
@abstractmethod
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
pass

View File

@ -2,9 +2,10 @@ from abc import ABC, abstractmethod
from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.highlevel.env import ContinuousEnvironments, Environments
from tianshou.utils.string import ToStringMixin
class NoiseFactory(ABC):
class NoiseFactory(ToStringMixin, ABC):
@abstractmethod
def create_noise(self, envs: Environments) -> BaseNoise:
pass

View File

@ -7,12 +7,13 @@ from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
@abstractmethod
def create_wrapped_policy(
self,