Move parameter transformation directly into parameter objects,
achieving greater separation of concerns and improved maintainability.
This commit is contained in:
		
							parent
							
								
									38cf982034
								
							
						
					
					
						commit
						d4e604b46e
					
				@ -1,7 +1,6 @@
 | 
			
		||||
import os
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from collections.abc import Callable
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
@ -19,12 +18,8 @@ from tianshou.highlevel.module import (
 | 
			
		||||
    TDevice,
 | 
			
		||||
)
 | 
			
		||||
from tianshou.highlevel.optim import OptimizerFactory
 | 
			
		||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
 | 
			
		||||
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
 | 
			
		||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 | 
			
		||||
from tianshou.highlevel.params.noise import NoiseFactory
 | 
			
		||||
from tianshou.highlevel.params.policy_params import (
 | 
			
		||||
    ParamTransformer,
 | 
			
		||||
    ParamTransformerData,
 | 
			
		||||
    PPOParams,
 | 
			
		||||
    SACParams,
 | 
			
		||||
    TD3Params,
 | 
			
		||||
@ -32,7 +27,6 @@ from tianshou.highlevel.params.policy_params import (
 | 
			
		||||
from tianshou.policy import BasePolicy, PPOPolicy, SACPolicy, TD3Policy
 | 
			
		||||
from tianshou.policy.modelfree.pg import TDistParams
 | 
			
		||||
from tianshou.trainer import BaseTrainer, OffpolicyTrainer, OnpolicyTrainer
 | 
			
		||||
from tianshou.utils import MultipleLRSchedulers
 | 
			
		||||
from tianshou.utils.net.common import ActorCritic
 | 
			
		||||
 | 
			
		||||
CHECKPOINT_DICT_KEY_MODEL = "model"
 | 
			
		||||
@ -145,26 +139,6 @@ class OffpolicyAgentFactory(AgentFactory, ABC):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerDrop(ParamTransformer):
    """Transformer that removes a fixed set of keys from the keyword-argument dictionary."""

    def __init__(self, *keys: str):
        """:param keys: the names of the entries to remove"""
        self.keys = keys

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Remove every configured key from `kwargs` in place.

        :param kwargs: the dictionary to modify
        :raises KeyError: if one of the configured keys is not present
        """
        for name in self.keys:
            # no default is given, so an absent key still raises KeyError
            kwargs.pop(name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerLRScheduler(ParamTransformer):
    """Replaces the entry `lr_scheduler_factory` by an entry `lr_scheduler` for a given optimizer."""

    def __init__(self, optim: torch.optim.Optimizer):
        """:param optim: the optimizer for which the scheduler is to be created"""
        self.optim = optim

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Store the scheduler created by the factory (or None if there is no factory).

        :param kwargs: the dictionary to modify in place
        """
        # retrieved with drop=True (presumably removing the factory entry; helper not shown here)
        scheduler_factory: LRSchedulerFactory | None = self.get(
            kwargs,
            "lr_scheduler_factory",
            drop=True,
        )
        if scheduler_factory is None:
            scheduler = None
        else:
            scheduler = scheduler_factory.create_scheduler(self.optim)
        kwargs["lr_scheduler"] = scheduler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _ActorMixin:
 | 
			
		||||
    def __init__(self, actor_factory: ActorFactory, optim_factory: OptimizerFactory):
 | 
			
		||||
        self.actor_module_opt_factory = ActorModuleOptFactory(actor_factory, optim_factory)
 | 
			
		||||
@ -269,8 +243,12 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
 | 
			
		||||
    def create_policy(self, envs: Environments, device: TDevice) -> PPOPolicy:
 | 
			
		||||
        actor_critic = self.create_actor_critic_module_opt(envs, device, self.params.lr)
 | 
			
		||||
        kwargs = self.params.create_kwargs(
 | 
			
		||||
            ParamTransformerDrop("lr"),
 | 
			
		||||
            ParamTransformerLRScheduler(actor_critic.optim),
 | 
			
		||||
            ParamTransformerData(
 | 
			
		||||
                envs=envs,
 | 
			
		||||
                device=device,
 | 
			
		||||
                optim_factory=self.optim_factory,
 | 
			
		||||
                optim=actor_critic.optim,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        return PPOPolicy(
 | 
			
		||||
            actor=actor_critic.actor,
 | 
			
		||||
@ -282,43 +260,6 @@ class PPOAgentFactory(OnpolicyAgentFactory, _ActorCriticMixin):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerAlpha(ParamTransformer):
    """Resolves an `AutoAlphaFactory` stored under the key `alpha` into the value it creates."""

    def __init__(self, envs: Environments, optim_factory: OptimizerFactory, device: TDevice):
        """:param envs: the environments passed on to the factory
        :param optim_factory: the optimizer factory passed on to the factory
        :param device: the device passed on to the factory
        """
        self.envs = envs
        self.optim_factory = optim_factory
        self.device = device

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Replace the `alpha` entry if (and only if) it is an `AutoAlphaFactory`.

        :param kwargs: the dictionary to modify in place
        """
        key = "alpha"
        current = self.get(kwargs, key)
        if not isinstance(current, AutoAlphaFactory):
            # any other kind of value is left untouched
            return
        kwargs[key] = current.create_auto_alpha(self.envs, self.optim_factory, self.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Combines several scheduler factories into a single entry `lr_scheduler`."""

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]]):
        """:param optim_key_list: a list of tuples (optimizer, key of the scheduler factory)"""
        self.optim_key_list = optim_key_list

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Create one scheduler per factory that is present and store the combined result.

        The stored value is None if no factory is set, the scheduler itself if exactly one
        is set, and a `MultipleLRSchedulers` instance otherwise.

        :param kwargs: the dictionary to modify in place
        """
        created = []
        for optimizer, factory_key in self.optim_key_list:
            factory: LRSchedulerFactory | None = self.get(kwargs, factory_key, drop=True)
            if factory is None:
                continue
            created.append(factory.create_scheduler(optimizer))

        if not created:
            combined = None
        elif len(created) == 1:
            combined = created[0]
        else:
            combined = MultipleLRSchedulers(*created)
        kwargs["lr_scheduler"] = combined
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
@ -346,15 +287,14 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
 | 
			
		||||
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
 | 
			
		||||
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
 | 
			
		||||
        kwargs = self.params.create_kwargs(
 | 
			
		||||
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
 | 
			
		||||
            ParamTransformerMultiLRScheduler(
 | 
			
		||||
                [
 | 
			
		||||
                    (actor.optim, "actor_lr_scheduler_factory"),
 | 
			
		||||
                    (critic1.optim, "critic1_lr_scheduler_factory"),
 | 
			
		||||
                    (critic2.optim, "critic2_lr_scheduler_factory"),
 | 
			
		||||
                ],
 | 
			
		||||
            ParamTransformerData(
 | 
			
		||||
                envs=envs,
 | 
			
		||||
                device=device,
 | 
			
		||||
                optim_factory=self.optim_factory,
 | 
			
		||||
                actor=actor,
 | 
			
		||||
                critic1=critic1,
 | 
			
		||||
                critic2=critic2,
 | 
			
		||||
            ),
 | 
			
		||||
            ParamTransformerAlpha(envs, optim_factory=self.optim_factory, device=device),
 | 
			
		||||
        )
 | 
			
		||||
        return SACPolicy(
 | 
			
		||||
            actor=actor.module,
 | 
			
		||||
@ -369,28 +309,6 @@ class SACAgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerNoiseFactory(ParamTransformer):
    """Resolves a `NoiseFactory` stored under a given key into the noise object it creates."""

    def __init__(self, key: str, envs: Environments):
        """:param key: the key whose value may be a noise factory
        :param envs: the environments passed on to the factory
        """
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Replace the value under the configured key if it is a `NoiseFactory`.

        :param kwargs: the dictionary to modify in place
        :raises KeyError: if the configured key is not present
        """
        current = kwargs[self.key]
        if not isinstance(current, NoiseFactory):
            # values that are not factories are kept as they are
            return
        kwargs[self.key] = current.create_noise(self.envs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    """Resolves a `FloatEnvParamFactory` stored under a given key into the value it creates."""

    def __init__(self, key: str, envs: Environments):
        """:param key: the key whose value may be a parameter factory
        :param envs: the environments passed on to the factory
        """
        self.key = key
        self.envs = envs

    def transform(self, kwargs: dict[str, Any]) -> None:
        """Replace the value under the configured key if it is a `FloatEnvParamFactory`.

        :param kwargs: the dictionary to modify in place
        :raises KeyError: if the configured key is not present
        """
        current = kwargs[self.key]
        if not isinstance(current, FloatEnvParamFactory):
            # values that are not factories are kept as they are
            return
        kwargs[self.key] = current.create_param(self.envs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
@ -418,17 +336,14 @@ class TD3AgentFactory(OffpolicyAgentFactory, _ActorAndDualCriticsMixin):
 | 
			
		||||
        critic1 = self.create_critic_module_opt(envs, device, self.params.critic1_lr)
 | 
			
		||||
        critic2 = self.create_critic2_module_opt(envs, device, self.params.critic2_lr)
 | 
			
		||||
        kwargs = self.params.create_kwargs(
 | 
			
		||||
            ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"),
 | 
			
		||||
            ParamTransformerMultiLRScheduler(
 | 
			
		||||
                [
 | 
			
		||||
                    (actor.optim, "actor_lr_scheduler_factory"),
 | 
			
		||||
                    (critic1.optim, "critic1_lr_scheduler_factory"),
 | 
			
		||||
                    (critic2.optim, "critic2_lr_scheduler_factory"),
 | 
			
		||||
                ],
 | 
			
		||||
            ParamTransformerData(
 | 
			
		||||
                envs=envs,
 | 
			
		||||
                device=device,
 | 
			
		||||
                optim_factory=self.optim_factory,
 | 
			
		||||
                actor=actor,
 | 
			
		||||
                critic1=critic1,
 | 
			
		||||
                critic2=critic2,
 | 
			
		||||
            ),
 | 
			
		||||
            ParamTransformerNoiseFactory("exploration_noise", envs),
 | 
			
		||||
            ParamTransformerFloatEnvParamFactory("policy_noise", envs),
 | 
			
		||||
            ParamTransformerFloatEnvParamFactory("noise_clip", envs),
 | 
			
		||||
        )
 | 
			
		||||
        return TD3Policy(
 | 
			
		||||
            actor=actor.module,
 | 
			
		||||
 | 
			
		||||
@ -1,19 +1,41 @@
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from dataclasses import asdict, dataclass
 | 
			
		||||
from dataclasses import asdict, dataclass, field
 | 
			
		||||
from typing import Any, Literal
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from tianshou.exploration import BaseNoise
 | 
			
		||||
from tianshou.highlevel.env import Environments
 | 
			
		||||
from tianshou.highlevel.module import ModuleOpt, TDevice
 | 
			
		||||
from tianshou.highlevel.optim import OptimizerFactory
 | 
			
		||||
from tianshou.highlevel.params.alpha import AutoAlphaFactory
 | 
			
		||||
from tianshou.highlevel.params.env_param import FloatEnvParamFactory
 | 
			
		||||
from tianshou.highlevel.params.lr_scheduler import LRSchedulerFactory
 | 
			
		||||
from tianshou.highlevel.params.noise import NoiseFactory
 | 
			
		||||
from tianshou.utils import MultipleLRSchedulers
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass(kw_only=True)
class ParamTransformerData:
    """Bundle of context objects handed to `ParamTransformer` instances when they transform parameters.

    The fields form the superset of all data items required by the different types of agent factories.
    An agent factory is expected to set only the attributes relevant to its parameters; the optional
    attributes it does not set remain None.
    """

    envs: Environments
    device: TDevice
    optim_factory: OptimizerFactory
    # the single optimizer, for the case where there is just one
    optim: torch.optim.Optimizer | None = None
    # optional module/optimizer holders for the actor and the two critics
    actor: ModuleOpt | None = None
    critic1: ModuleOpt | None = None
    critic2: ModuleOpt | None = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformer(ABC):
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def transform(self, kwargs: dict[str, Any]) -> None:
 | 
			
		||||
    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -24,13 +46,152 @@ class ParamTransformer(ABC):
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerDrop(ParamTransformer):
    """Transformer that removes a fixed set of keys from the parameter dictionary."""

    def __init__(self, *keys: str):
        """:param keys: the names of the entries to remove"""
        self.keys = keys

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        """Remove every configured key from `kwargs` in place.

        :param kwargs: the dictionary to modify
        :param data: the transformation context (not used by this transformer)
        :raises KeyError: if one of the configured keys is not present
        """
        for name in self.keys:
            # no default is given, so an absent key still raises KeyError
            kwargs.pop(name)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerLRScheduler(ParamTransformer):
    """Replaces an entry holding a learning rate scheduler factory (removed) by an entry holding
    the learning rate scheduler (added) which that factory creates for the data member `optim`.
    """

    def __init__(self, key_scheduler_factory: str, key_scheduler: str):
        """:param key_scheduler_factory: the key under which the scheduler factory is stored
        :param key_scheduler: the key under which to store the resulting scheduler
        """
        self.key_scheduler_factory = key_scheduler_factory
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        """Store the scheduler created by the factory (or None if there is no factory).

        :param params: the dictionary to modify in place
        :param data: the transformation context; its `optim` member must be set
        """
        assert data.optim is not None
        scheduler_factory: LRSchedulerFactory | None = self.get(
            params,
            self.key_scheduler_factory,
            drop=True,
        )
        if scheduler_factory is None:
            scheduler = None
        else:
            scheduler = scheduler_factory.create_scheduler(data.optim)
        params[self.key_scheduler] = scheduler
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerMultiLRScheduler(ParamTransformer):
    """Turns several scheduler factories into one scheduler, which is a MultipleLRSchedulers instance
    when more than one factory is actually set.
    """

    def __init__(self, optim_key_list: list[tuple[torch.optim.Optimizer, str]], key_scheduler: str):
        """:param optim_key_list: a list of tuples (optimizer, key of learning rate factory)
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.optim_key_list = optim_key_list
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        """Create one scheduler per factory that is present and store the combined result.

        The stored value is None if no factory is set, the scheduler itself if exactly one
        is set, and a `MultipleLRSchedulers` instance otherwise.

        :param params: the dictionary to modify in place
        :param data: the transformation context (not read here; the optimizers come from
            the list given at construction)
        """
        created = []
        for optimizer, factory_key in self.optim_key_list:
            factory: LRSchedulerFactory | None = self.get(params, factory_key, drop=True)
            if factory is None:
                continue
            created.append(factory.create_scheduler(optimizer))

        if not created:
            combined = None
        elif len(created) == 1:
            combined = created[0]
        else:
            combined = MultipleLRSchedulers(*created)
        params[self.key_scheduler] = combined
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerActorDualCriticsLRScheduler(ParamTransformer):
    """Transforms the scheduler factories of an actor and two critics into a single scheduler entry.

    The work is delegated to `ParamTransformerMultiLRScheduler`, using the optimizers of the
    data members `actor`, `critic1` and `critic2`.
    """

    def __init__(
        self,
        key_scheduler_factory_actor: str,
        key_scheduler_factory_critic1: str,
        key_scheduler_factory_critic2: str,
        key_scheduler: str,
    ):
        """:param key_scheduler_factory_actor: the key of the actor's scheduler factory
        :param key_scheduler_factory_critic1: the key of the first critic's scheduler factory
        :param key_scheduler_factory_critic2: the key of the second critic's scheduler factory
        :param key_scheduler: the key under which to store the resulting learning rate scheduler
        """
        self.key_factory_actor = key_scheduler_factory_actor
        self.key_factory_critic1 = key_scheduler_factory_critic1
        self.key_factory_critic2 = key_scheduler_factory_critic2
        self.key_scheduler = key_scheduler

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        """Create the combined scheduler entry for the actor and both critics.

        :param params: the dictionary to modify in place
        :param data: the transformation context; its `actor`, `critic1` and `critic2`
            members must all be set
        """
        # These members are optional in ParamTransformerData; check them explicitly (as is done
        # for `optim` in ParamTransformerLRScheduler) rather than failing with an opaque
        # AttributeError on None.
        assert data.actor is not None, "data.actor must be set"
        assert data.critic1 is not None, "data.critic1 must be set"
        assert data.critic2 is not None, "data.critic2 must be set"
        transformer = ParamTransformerMultiLRScheduler(
            [
                (data.actor.optim, self.key_factory_actor),
                (data.critic1.optim, self.key_factory_critic1),
                (data.critic2.optim, self.key_factory_critic2),
            ],
            self.key_scheduler,
        )
        transformer.transform(params, data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerAutoAlpha(ParamTransformer):
    """Resolves an `AutoAlphaFactory` stored under a given key into the value it creates."""

    def __init__(self, key: str):
        """:param key: the key whose value may be an auto-alpha factory"""
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        """Replace the value under the configured key if it is an `AutoAlphaFactory`.

        :param kwargs: the dictionary to modify in place
        :param data: the transformation context supplying `envs`, `optim_factory` and `device`
        """
        current = self.get(kwargs, self.key)
        if not isinstance(current, AutoAlphaFactory):
            # any other kind of value is left untouched
            return
        kwargs[self.key] = current.create_auto_alpha(data.envs, data.optim_factory, data.device)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerNoiseFactory(ParamTransformer):
    """Resolves a `NoiseFactory` stored under a given key into the noise object it creates."""

    def __init__(self, key: str):
        """:param key: the key whose value may be a noise factory"""
        self.key = key

    def transform(self, params: dict[str, Any], data: ParamTransformerData) -> None:
        """Replace the value under the configured key if it is a `NoiseFactory`.

        :param params: the dictionary to modify in place
        :param data: the transformation context supplying `envs`
        :raises KeyError: if the configured key is not present
        """
        current = params[self.key]
        if not isinstance(current, NoiseFactory):
            # values that are not factories are kept as they are
            return
        params[self.key] = current.create_noise(data.envs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ParamTransformerFloatEnvParamFactory(ParamTransformer):
    """Resolves a `FloatEnvParamFactory` stored under a given key into the value it creates."""

    def __init__(self, key: str):
        """:param key: the key whose value may be a parameter factory"""
        self.key = key

    def transform(self, kwargs: dict[str, Any], data: ParamTransformerData) -> None:
        """Replace the value under the configured key if it is a `FloatEnvParamFactory`.

        :param kwargs: the dictionary to modify in place
        :param data: the transformation context supplying `envs`
        :raises KeyError: if the configured key is not present
        """
        current = kwargs[self.key]
        if not isinstance(current, FloatEnvParamFactory):
            # values that are not factories are kept as they are
            return
        kwargs[self.key] = current.create_param(data.envs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ITransformableParams(ABC):
    """Interface of parameter objects to which parameter transformers can be added."""

    @abstractmethod
    def _add_transformer(self, transformer: ParamTransformer):
        """Register a transformer with this parameter object.

        :param transformer: the transformer to add
        """
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
class Params:
    """Base class of parameter dataclasses that can be converted to keyword arguments."""

    def create_kwargs(self, *transformers: ParamTransformer) -> dict[str, Any]:
        """Convert this object to a dictionary and apply the given transformers to it.

        :param transformers: the transformers to apply, in the order given; each one
            modifies the dictionary in place
        :return: the transformed dictionary
        """
        kwargs = asdict(self)
        for current in transformers:
            current.transform(kwargs)
        return kwargs
 | 
			
		||||
class Params(ITransformableParams):
    """Base class of parameter objects which hold the transformers to apply to themselves."""

    # not an init argument and hidden from the repr; every instance starts with its own empty list
    _param_transformers: list[ParamTransformer] = field(
        init=False,
        default_factory=list,
        repr=False,
    )

    def _add_transformer(self, transformer: ParamTransformer):
        """Append a transformer to the list applied by `create_kwargs`.

        :param transformer: the transformer to add
        """
        self._param_transformers.append(transformer)

    def create_kwargs(self, data: ParamTransformerData) -> dict[str, Any]:
        """Convert this object to a dictionary and apply the registered transformers to it.

        :param data: the context handed to every transformer
        :return: the transformed dictionary, without the internal transformer list
        """
        result = asdict(self)
        for current in self._param_transformers:
            current.transform(result, data)
        # the transformer list is bookkeeping only and must not appear in the result;
        # it is removed after the transformations, so it is still present while they run
        result.pop("_param_transformers")
        return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
class ParamsMixinLearningRateWithScheduler(ITransformableParams, ABC):
    """Mixin contributing a learning rate and an optional learning rate scheduler factory."""

    lr: float = 1e-3
    lr_scheduler_factory: LRSchedulerFactory | None = None

    def __post_init__(self):
        """Register the transformers for the fields of this mixin.

        `lr` is dropped, and `lr_scheduler_factory` is replaced by `lr_scheduler`.
        """
        # registration order is kept: the drop transformer first, then the scheduler transformer
        for transformer in (
            ParamTransformerDrop("lr"),
            ParamTransformerLRScheduler("lr_scheduler_factory", "lr_scheduler"),
        ):
            self._add_transformer(transformer)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
@ -54,7 +215,7 @@ class A2CParams(PGParams):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class PPOParams(A2CParams):
 | 
			
		||||
class PPOParams(A2CParams, ParamsMixinLearningRateWithScheduler):
 | 
			
		||||
    """PPO specific config."""
 | 
			
		||||
 | 
			
		||||
    eps_clip: float = 0.2
 | 
			
		||||
@ -62,12 +223,10 @@ class PPOParams(A2CParams):
 | 
			
		||||
    value_clip: bool = False
 | 
			
		||||
    advantage_normalization: bool = True
 | 
			
		||||
    recompute_advantage: bool = False
 | 
			
		||||
    lr: float = 1e-3
 | 
			
		||||
    lr_scheduler_factory: LRSchedulerFactory | None = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class ActorAndDualCriticsParams(Params):
 | 
			
		||||
class ParamsMixinActorAndDualCritics(ITransformableParams, ABC):
 | 
			
		||||
    actor_lr: float = 1e-3
 | 
			
		||||
    critic1_lr: float = 1e-3
 | 
			
		||||
    critic2_lr: float = 1e-3
 | 
			
		||||
@ -75,21 +234,37 @@ class ActorAndDualCriticsParams(Params):
 | 
			
		||||
    critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
 | 
			
		||||
    critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        self._add_transformer(ParamTransformerDrop("actor_lr", "critic1_lr", "critic2_lr"))
 | 
			
		||||
        self._add_transformer(
 | 
			
		||||
            ParamTransformerActorDualCriticsLRScheduler(
 | 
			
		||||
                "actor_lr_scheduler_factory",
 | 
			
		||||
                "critic1_lr_scheduler_factory",
 | 
			
		||||
                "critic2_lr_scheduler_factory",
 | 
			
		||||
                "lr_scheduler",
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class SACParams(ActorAndDualCriticsParams):
 | 
			
		||||
class SACParams(Params, ParamsMixinActorAndDualCritics):
 | 
			
		||||
    tau: float = 0.005
 | 
			
		||||
    gamma: float = 0.99
 | 
			
		||||
    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
 | 
			
		||||
    estimation_step: int = 1
 | 
			
		||||
    exploration_noise: BaseNoise | Literal["default"] | None = None
 | 
			
		||||
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
 | 
			
		||||
    deterministic_eval: bool = True
 | 
			
		||||
    action_scaling: bool = True
 | 
			
		||||
    action_bound_method: Literal["clip"] | None = "clip"
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        ParamsMixinActorAndDualCritics.__post_init__(self)
 | 
			
		||||
        self._add_transformer(ParamTransformerAutoAlpha("alpha"))
 | 
			
		||||
        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class TD3Params(ActorAndDualCriticsParams):
 | 
			
		||||
class TD3Params(Params, ParamsMixinActorAndDualCritics):
 | 
			
		||||
    tau: float = 0.005
 | 
			
		||||
    gamma: float = 0.99
 | 
			
		||||
    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
 | 
			
		||||
@ -99,3 +274,9 @@ class TD3Params(ActorAndDualCriticsParams):
 | 
			
		||||
    estimation_step: int = 1
 | 
			
		||||
    action_scaling: bool = True
 | 
			
		||||
    action_bound_method: Literal["clip"] | None = "clip"
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        ParamsMixinActorAndDualCritics.__post_init__(self)
 | 
			
		||||
        self._add_transformer(ParamTransformerNoiseFactory("exploration_noise"))
 | 
			
		||||
        self._add_transformer(ParamTransformerFloatEnvParamFactory("policy_noise"))
 | 
			
		||||
        self._add_transformer(ParamTransformerFloatEnvParamFactory("noise_clip"))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user