Add ToStringMixin to further high-level parameter classes
This commit is contained in:
parent
8f67c2e9d9
commit
358978c65d
@ -8,6 +8,7 @@ import torch
|
||||
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.utils.net.common import Net
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TDevice: TypeAlias = str | int | torch.device
|
||||
|
||||
@ -29,7 +30,7 @@ class Module:
|
||||
output_dim: int
|
||||
|
||||
|
||||
class ModuleFactory(ABC):
|
||||
class ModuleFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_module(self, envs: Environments, device: TDevice) -> Module:
|
||||
pass
|
||||
|
@ -6,9 +6,10 @@ import torch
|
||||
from tianshou.highlevel.env import Environments
|
||||
from tianshou.highlevel.module.core import TDevice
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class AutoAlphaFactory(ABC):
|
||||
class AutoAlphaFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_auto_alpha(
|
||||
self,
|
||||
|
@ -6,11 +6,12 @@ import torch
|
||||
|
||||
from tianshou.highlevel.env import Environments, EnvType
|
||||
from tianshou.policy.modelfree.pg import TDistParams
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
|
||||
|
||||
|
||||
class DistributionFunctionFactory(ABC):
|
||||
class DistributionFunctionFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||
pass
|
||||
|
@ -3,12 +3,13 @@ from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TValue = TypeVar("TValue")
|
||||
TEnvs = TypeVar("TEnvs", bound=Environments)
|
||||
|
||||
|
||||
class EnvValueFactory(Generic[TValue, TEnvs], ABC):
|
||||
class EnvValueFactory(Generic[TValue, TEnvs], ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_value(self, envs: TEnvs) -> TValue:
|
||||
pass
|
||||
|
@ -5,9 +5,10 @@ import torch
|
||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||
|
||||
from tianshou.highlevel.config import RLSamplingConfig
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class LRSchedulerFactory(ABC):
|
||||
class LRSchedulerFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
||||
pass
|
||||
|
@ -2,9 +2,10 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
|
||||
class NoiseFactory(ABC):
|
||||
class NoiseFactory(ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_noise(self, envs: Environments) -> BaseNoise:
|
||||
pass
|
||||
|
@ -7,12 +7,13 @@ from tianshou.highlevel.module.core import ModuleFactory, TDevice
|
||||
from tianshou.highlevel.optim import OptimizerFactory
|
||||
from tianshou.policy import BasePolicy, ICMPolicy
|
||||
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
||||
from tianshou.utils.string import ToStringMixin
|
||||
|
||||
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
|
||||
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
|
||||
|
||||
|
||||
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
|
||||
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
|
||||
@abstractmethod
|
||||
def create_wrapped_policy(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user