Add ToStringMixin to further high-level parameter classes

This commit is contained in:
Dominik Jain 2023-10-05 13:15:24 +02:00
parent 8f67c2e9d9
commit 358978c65d
7 changed files with 14 additions and 7 deletions

View File

@ -8,6 +8,7 @@ import torch
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.utils.net.common import Net from tianshou.utils.net.common import Net
from tianshou.utils.string import ToStringMixin
TDevice: TypeAlias = str | int | torch.device TDevice: TypeAlias = str | int | torch.device
@ -29,7 +30,7 @@ class Module:
output_dim: int output_dim: int
class ModuleFactory(ABC): class ModuleFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_module(self, envs: Environments, device: TDevice) -> Module: def create_module(self, envs: Environments, device: TDevice) -> Module:
pass pass

View File

@ -6,9 +6,10 @@ import torch
from tianshou.highlevel.env import Environments from tianshou.highlevel.env import Environments
from tianshou.highlevel.module.core import TDevice from tianshou.highlevel.module.core import TDevice
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.utils.string import ToStringMixin
class AutoAlphaFactory(ABC): class AutoAlphaFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_auto_alpha( def create_auto_alpha(
self, self,

View File

@ -6,11 +6,12 @@ import torch
from tianshou.highlevel.env import Environments, EnvType from tianshou.highlevel.env import Environments, EnvType
from tianshou.policy.modelfree.pg import TDistParams from tianshou.policy.modelfree.pg import TDistParams
from tianshou.utils.string import ToStringMixin
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution] TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
class DistributionFunctionFactory(ABC): class DistributionFunctionFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_dist_fn(self, envs: Environments) -> TDistributionFunction: def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
pass pass

View File

@ -3,12 +3,13 @@ from abc import ABC, abstractmethod
from typing import Generic, TypeVar from typing import Generic, TypeVar
from tianshou.highlevel.env import ContinuousEnvironments, Environments from tianshou.highlevel.env import ContinuousEnvironments, Environments
from tianshou.utils.string import ToStringMixin
TValue = TypeVar("TValue") TValue = TypeVar("TValue")
TEnvs = TypeVar("TEnvs", bound=Environments) TEnvs = TypeVar("TEnvs", bound=Environments)
class EnvValueFactory(Generic[TValue, TEnvs], ABC): class EnvValueFactory(Generic[TValue, TEnvs], ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_value(self, envs: TEnvs) -> TValue: def create_value(self, envs: TEnvs) -> TValue:
pass pass

View File

@ -5,9 +5,10 @@ import torch
from torch.optim.lr_scheduler import LambdaLR, LRScheduler from torch.optim.lr_scheduler import LambdaLR, LRScheduler
from tianshou.highlevel.config import RLSamplingConfig from tianshou.highlevel.config import RLSamplingConfig
from tianshou.utils.string import ToStringMixin
class LRSchedulerFactory(ABC): class LRSchedulerFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler: def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
pass pass

View File

@ -2,9 +2,10 @@ from abc import ABC, abstractmethod
from tianshou.exploration import BaseNoise, GaussianNoise from tianshou.exploration import BaseNoise, GaussianNoise
from tianshou.highlevel.env import ContinuousEnvironments, Environments from tianshou.highlevel.env import ContinuousEnvironments, Environments
from tianshou.utils.string import ToStringMixin
class NoiseFactory(ABC): class NoiseFactory(ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_noise(self, envs: Environments) -> BaseNoise: def create_noise(self, envs: Environments) -> BaseNoise:
pass pass

View File

@ -7,12 +7,13 @@ from tianshou.highlevel.module.core import ModuleFactory, TDevice
from tianshou.highlevel.optim import OptimizerFactory from tianshou.highlevel.optim import OptimizerFactory
from tianshou.policy import BasePolicy, ICMPolicy from tianshou.policy import BasePolicy, ICMPolicy
from tianshou.utils.net.discrete import IntrinsicCuriosityModule from tianshou.utils.net.discrete import IntrinsicCuriosityModule
from tianshou.utils.string import ToStringMixin
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy) TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy) TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC): class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
@abstractmethod @abstractmethod
def create_wrapped_policy( def create_wrapped_policy(
self, self,