Add ToStringMixin to further high-level parameter classes
This commit is contained in:
parent
8f67c2e9d9
commit
358978c65d
@ -8,6 +8,7 @@ import torch
|
|||||||
|
|
||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.utils.net.common import Net
|
from tianshou.utils.net.common import Net
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TDevice: TypeAlias = str | int | torch.device
|
TDevice: TypeAlias = str | int | torch.device
|
||||||
|
|
||||||
@ -29,7 +30,7 @@ class Module:
|
|||||||
output_dim: int
|
output_dim: int
|
||||||
|
|
||||||
|
|
||||||
class ModuleFactory(ABC):
|
class ModuleFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_module(self, envs: Environments, device: TDevice) -> Module:
|
def create_module(self, envs: Environments, device: TDevice) -> Module:
|
||||||
pass
|
pass
|
||||||
|
@ -6,9 +6,10 @@ import torch
|
|||||||
from tianshou.highlevel.env import Environments
|
from tianshou.highlevel.env import Environments
|
||||||
from tianshou.highlevel.module.core import TDevice
|
from tianshou.highlevel.module.core import TDevice
|
||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
class AutoAlphaFactory(ABC):
|
class AutoAlphaFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_auto_alpha(
|
def create_auto_alpha(
|
||||||
self,
|
self,
|
||||||
|
@ -6,11 +6,12 @@ import torch
|
|||||||
|
|
||||||
from tianshou.highlevel.env import Environments, EnvType
|
from tianshou.highlevel.env import Environments, EnvType
|
||||||
from tianshou.policy.modelfree.pg import TDistParams
|
from tianshou.policy.modelfree.pg import TDistParams
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
|
TDistributionFunction: TypeAlias = Callable[[TDistParams], torch.distributions.Distribution]
|
||||||
|
|
||||||
|
|
||||||
class DistributionFunctionFactory(ABC):
|
class DistributionFunctionFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
def create_dist_fn(self, envs: Environments) -> TDistributionFunction:
|
||||||
pass
|
pass
|
||||||
|
@ -3,12 +3,13 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import Generic, TypeVar
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TValue = TypeVar("TValue")
|
TValue = TypeVar("TValue")
|
||||||
TEnvs = TypeVar("TEnvs", bound=Environments)
|
TEnvs = TypeVar("TEnvs", bound=Environments)
|
||||||
|
|
||||||
|
|
||||||
class EnvValueFactory(Generic[TValue, TEnvs], ABC):
|
class EnvValueFactory(Generic[TValue, TEnvs], ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_value(self, envs: TEnvs) -> TValue:
|
def create_value(self, envs: TEnvs) -> TValue:
|
||||||
pass
|
pass
|
||||||
|
@ -5,9 +5,10 @@ import torch
|
|||||||
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
from torch.optim.lr_scheduler import LambdaLR, LRScheduler
|
||||||
|
|
||||||
from tianshou.highlevel.config import RLSamplingConfig
|
from tianshou.highlevel.config import RLSamplingConfig
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
class LRSchedulerFactory(ABC):
|
class LRSchedulerFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
def create_scheduler(self, optim: torch.optim.Optimizer) -> LRScheduler:
|
||||||
pass
|
pass
|
||||||
|
@ -2,9 +2,10 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
from tianshou.exploration import BaseNoise, GaussianNoise
|
from tianshou.exploration import BaseNoise, GaussianNoise
|
||||||
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
from tianshou.highlevel.env import ContinuousEnvironments, Environments
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
|
|
||||||
class NoiseFactory(ABC):
|
class NoiseFactory(ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_noise(self, envs: Environments) -> BaseNoise:
|
def create_noise(self, envs: Environments) -> BaseNoise:
|
||||||
pass
|
pass
|
||||||
|
@ -7,12 +7,13 @@ from tianshou.highlevel.module.core import ModuleFactory, TDevice
|
|||||||
from tianshou.highlevel.optim import OptimizerFactory
|
from tianshou.highlevel.optim import OptimizerFactory
|
||||||
from tianshou.policy import BasePolicy, ICMPolicy
|
from tianshou.policy import BasePolicy, ICMPolicy
|
||||||
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
from tianshou.utils.net.discrete import IntrinsicCuriosityModule
|
||||||
|
from tianshou.utils.string import ToStringMixin
|
||||||
|
|
||||||
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
|
TPolicyIn = TypeVar("TPolicyIn", bound=BasePolicy)
|
||||||
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
|
TPolicyOut = TypeVar("TPolicyOut", bound=BasePolicy)
|
||||||
|
|
||||||
|
|
||||||
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ABC):
|
class PolicyWrapperFactory(Generic[TPolicyIn, TPolicyOut], ToStringMixin, ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_wrapped_policy(
|
def create_wrapped_policy(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user