"""Factories for the generation of environment-dependent parameters."""

from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from tianshou.highlevel.env import ContinuousEnvironments, Environments
from tianshou.utils.string import ToStringMixin

TValue = TypeVar("TValue")
TEnvs = TypeVar("TEnvs", bound=Environments)


class EnvValueFactory(Generic[TValue, TEnvs], ToStringMixin, ABC):
    """Abstract factory that derives a value of type ``TValue`` from environments of type ``TEnvs``.

    Concrete subclasses must implement :meth:`create_value`.
    """

    @abstractmethod
    def create_value(self, envs: TEnvs) -> TValue:
        """Create the value for the given environments.

        :param envs: the environments on which the value depends
        :return: the created value
        """


class FloatEnvValueFactory(EnvValueFactory[float, TEnvs], Generic[TEnvs], ABC):
    """Serves as a type bound for float value factories."""


class FloatEnvValueFactoryMaxActionScaled(FloatEnvValueFactory[ContinuousEnvironments]):
    """Float factory returning the environments' ``max_action`` multiplied by a fixed factor.

    Only applicable to continuous environments (checked in :meth:`create_value`).
    """

    def __init__(self, value: float):
        """:param value: value with which to scale the max action value"""
        self.value = value

    def create_value(self, envs: ContinuousEnvironments) -> float:
        """Compute ``envs.max_action * self.value``.

        :param envs: continuous environments providing ``max_action``
        :return: the scaled maximum action value
        """
        # Verify the environment type first; presumably raises when it is not
        # continuous (this factory is passed along, likely for error reporting).
        env_type = envs.get_type()
        env_type.assert_continuous(self)
        max_action = envs.max_action
        return max_action * self.value


class MaxActionScaled(FloatEnvValueFactoryMaxActionScaled):
    """Shorter-named equivalent of :class:`FloatEnvValueFactoryMaxActionScaled`; adds no behavior."""