Add documentation to parameters, improve factorisation
commit ff451f8373 (parent e63d8d4147)
@@ -256,9 +256,13 @@ class ParamsMixinLearningRateWithScheduler(GetParamTransformersProtocol):
 @dataclass
 class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
     actor_lr: float = 1e-3
+    """the learning rate to use for the actor network"""
     critic_lr: float = 1e-3
+    """the learning rate to use for the critic network"""
     actor_lr_scheduler_factory: LRSchedulerFactory | None = None
+    """factory for the creation of a learning rate scheduler to use for the actor network (if any)"""
     critic_lr_scheduler_factory: LRSchedulerFactory | None = None
+    """factory for the creation of a learning rate scheduler to use for the critic network (if any)"""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
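The fields documented above only carry configuration; to make their roles concrete, here is a minimal sketch (not part of this commit, plain PyTorch only) of what separate actor/critic learning rates and a scheduler factory ultimately set up. The make_scheduler helper is a hypothetical stand-in for an LRSchedulerFactory.

import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR

# Stand-in networks for the actor and the critic.
actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# actor_lr / critic_lr translate into separate optimizers, each with its own learning rate.
actor_optim = torch.optim.Adam(actor.parameters(), lr=1e-3)
critic_optim = torch.optim.Adam(critic.parameters(), lr=1e-3)


def make_scheduler(optim: torch.optim.Optimizer) -> LambdaLR:
    # Hypothetical stand-in for an LRSchedulerFactory: linear decay over 100 updates.
    return LambdaLR(optim, lr_lambda=lambda step: max(0.0, 1.0 - step / 100))


actor_scheduler = make_scheduler(actor_optim)    # actor_lr_scheduler_factory
critic_scheduler = make_scheduler(critic_optim)  # None would mean: keep the learning rate constant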
@@ -272,17 +276,60 @@ class ParamsMixinActorAndCritic(GetParamTransformersProtocol):
 
 
 @dataclass
-class PGParams(Params, ParamsMixinLearningRateWithScheduler):
-    discount_factor: float = 0.99
-    reward_normalization: bool = False
-    deterministic_eval: bool = False
+class ParamsMixinActionScaling(GetParamTransformersProtocol):
     action_scaling: bool | Literal["default"] = "default"
     """whether to apply action scaling; when set to "default", it will be enabled for continuous action spaces"""
     action_bound_method: Literal["clip", "tanh"] | None = "clip"
+    """
+    method to bound action to range [-1, 1]. Only used if the action_space is continuous.
+    """
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        return []
+
+
+@dataclass
+class ParamsMixinExplorationNoise(GetParamTransformersProtocol):
+    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
+    """
+    If not None, add noise to actions for exploration.
+    This is useful when solving "hard exploration" problems.
+    It can either be a distribution, a factory for the creation of a distribution or "default".
+    When set to "default", use Gaussian noise with standard deviation 0.1.
+    """
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        return [ParamTransformerNoiseFactory("exploration_noise")]
+
+
+@dataclass
+class PGParams(Params, ParamsMixinActionScaling, ParamsMixinLearningRateWithScheduler):
+    discount_factor: float = 0.99
+    """
+    discount factor (gamma) for future rewards; must be in [0, 1]
+    """
+    reward_normalization: bool = False
+    """
+    if True, will normalize the returns by subtracting the running mean and dividing by the running
+    standard deviation.
+    """
+    deterministic_eval: bool = False
+    """
+    whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
+    Does not affect training.
+    """
     dist_fn: TDistributionFunction | DistributionFunctionFactory | Literal["default"] = "default"
+    """
+    This can either be a function which maps the model output to a torch distribution or a
+    factory for the creation of such a function.
+    When set to "default", a factory which creates Gaussian distributions from mean and standard
+    deviation will be used for the continuous case and which creates categorical distributions
+    for the discrete case (see :class:`DistributionFunctionFactoryDefault`)
+    """
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
         transformers.extend(ParamsMixinLearningRateWithScheduler._get_param_transformers(self))
         transformers.append(ParamTransformerActionScaling("action_scaling"))
         transformers.append(ParamTransformerDistributionFunction("dist_fn"))
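The refactoring above follows a simple pattern: each mixin owns a group of fields and contributes its own ParamTransformers, and a concrete params class aggregates them explicitly in _get_param_transformers. A self-contained sketch of that pattern (illustrative stand-in names, not tianshou code):

from dataclasses import dataclass


class Transformer:
    """Stand-in for ParamTransformer: marks one keyword argument for later rewriting."""

    def __init__(self, key: str) -> None:
        self.key = key


@dataclass
class MixinA:
    action_scaling: bool = True

    def _get_param_transformers(self) -> list[Transformer]:
        return []


@dataclass
class MixinB:
    exploration_noise: float | None = None

    def _get_param_transformers(self) -> list[Transformer]:
        return [Transformer("exploration_noise")]


@dataclass
class MyParams(MixinA, MixinB):
    lr: float = 1e-3

    def _get_param_transformers(self) -> list[Transformer]:
        # Aggregate the transformers of each mixin explicitly, mirroring PGParams above.
        transformers: list[Transformer] = []
        transformers.extend(MixinA._get_param_transformers(self))
        transformers.extend(MixinB._get_param_transformers(self))
        return transformers


print([t.key for t in MyParams()._get_param_transformers()])  # ['exploration_noise']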
@@ -290,47 +337,125 @@ class PGParams(Params, ParamsMixinLearningRateWithScheduler):
 
 
 @dataclass
-class A2CParams(PGParams):
-    vf_coef: float = 0.5
-    ent_coef: float = 0.01
-    max_grad_norm: float | None = None
+class ParamsMixinGeneralAdvantageEstimation(GetParamTransformersProtocol):
     gae_lambda: float = 0.95
+    """
+    determines the blend between Monte Carlo and one-step temporal difference (TD) estimates of the advantage
+    function in general advantage estimation (GAE).
+    A value of 0 gives a fully TD-based estimate; lambda=1 gives a fully Monte Carlo estimate.
+    """
     max_batchsize: int = 256
+    """the maximum size of the batch when computing general advantage estimation (GAE)"""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        return []
+
+
+@dataclass
+class A2CParams(PGParams, ParamsMixinGeneralAdvantageEstimation):
+    vf_coef: float = 0.5
+    """weight (coefficient) of the value loss in the loss function"""
+    ent_coef: float = 0.01
+    """weight (coefficient) of the entropy loss in the loss function"""
+    max_grad_norm: float | None = None
+    """maximum norm for clipping gradients in backpropagation"""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self))
+        return transformers
 
 
 @dataclass
 class PPOParams(A2CParams):
     eps_clip: float = 0.2
+    """
+    determines the range of allowed change in the policy during a policy update:
+    The ratio between the probabilities indicated by the new and old policy is
+    constrained to stay in the interval [1 - eps_clip, 1 + eps_clip].
+    Small values thus force the new policy to stay close to the old policy.
+    Typical values range between 0.1 and 0.3.
+    The optimal epsilon depends on the environment; more stochastic environments may need larger epsilons.
+    """
     dual_clip: float | None = None
+    """
+    determines the lower bound clipping for the probability ratio
+    (corresponds to parameter c in arXiv:1912.09729, Equation 5).
+    If set to None, dual clipping is not used and the bounds described in parameter eps_clip apply.
+    If set to a float value c, the lower bound is changed from 1 - eps_clip to c,
+    where c < 1 - eps_clip.
+    Setting c > 0 reduces policy oscillation and further stabilizes training.
+    Typical values are between 0 and 0.5. Smaller values provide more stability.
+    Setting c = 0 yields PPO with only the upper bound.
+    """
     value_clip: bool = False
+    """
+    whether to apply clipping of the predicted value function during policy learning.
+    Value clipping discourages large changes in value predictions between updates.
+    Inaccurate value predictions can lead to bad policy updates, which can cause training instability.
+    Clipping values prevents sporadic large errors from skewing policy updates too much.
+    """
     advantage_normalization: bool = True
+    """whether to apply per mini-batch advantage normalization."""
     recompute_advantage: bool = False
+    """
+    whether to recompute advantage every update repeat as described in
+    https://arxiv.org/pdf/2006.05990.pdf, Sec. 3.5.
+    The original PPO implementation splits the data in each policy iteration
+    step into individual transitions and then randomly assigns them to minibatches.
+    This makes it impossible to compute advantages as the temporal structure is broken.
+    Therefore, the advantages are computed once at the beginning of each policy iteration step and
+    then used in minibatch policy and value function optimization.
+    This results in higher diversity of data in each minibatch at the cost of
+    using slightly stale advantage estimations.
+    Enabling this option will, as a remedy to this problem, recompute the advantages at the beginning
+    of each pass over the data instead of just once per iteration.
+    """
 
 
 @dataclass
-class NPGParams(PGParams):
+class NPGParams(PGParams, ParamsMixinGeneralAdvantageEstimation):
     optim_critic_iters: int = 5
+    """number of times to optimize critic network per update."""
     actor_step_size: float = 0.5
+    """step size for actor update in natural gradient direction"""
     advantage_normalization: bool = True
-    gae_lambda: float = 0.95
-    max_batchsize: int = 256
+    """whether to do per mini-batch advantage normalization."""
+
+    def _get_param_transformers(self) -> list[ParamTransformer]:
+        transformers = super()._get_param_transformers()
+        transformers.extend(ParamsMixinGeneralAdvantageEstimation._get_param_transformers(self))
+        return transformers
 
 
 @dataclass
 class TRPOParams(NPGParams):
     max_kl: float = 0.01
+    """
+    maximum KL divergence, used to constrain each actor network update.
+    """
     backtrack_coeff: float = 0.8
+    """
+    coefficient with which to reduce the step size when constraints are not met.
+    """
     max_backtracks: int = 10
+    """maximum number of times to backtrack in line search when the constraints are not met."""
 
 
 @dataclass
 class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
     actor_lr: float = 1e-3
+    """the learning rate to use for the actor network"""
     critic1_lr: float = 1e-3
+    """the learning rate to use for the first critic network"""
     critic2_lr: float = 1e-3
+    """the learning rate to use for the second critic network"""
     actor_lr_scheduler_factory: LRSchedulerFactory | None = None
+    """factory for the creation of a learning rate scheduler to use for the actor network (if any)"""
     critic1_lr_scheduler_factory: LRSchedulerFactory | None = None
+    """factory for the creation of a learning rate scheduler to use for the first critic network (if any)"""
     critic2_lr_scheduler_factory: LRSchedulerFactory | None = None
+    """factory for the creation of a learning rate scheduler to use for the second critic network (if any)"""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         return [
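The gae_lambda docstring describes a blend that is easy to check numerically. A small illustrative sketch (not the library implementation) of generalized advantage estimation over one trajectory segment: with gae_lambda=0 each advantage equals the one-step TD error, with gae_lambda=1 it equals the discounted Monte Carlo return minus the value baseline.

import numpy as np


def gae(rewards, values, last_value, gamma=0.99, gae_lambda=0.95):
    """Generalized advantage estimation over a single, non-terminating trajectory segment."""
    values = np.append(values, last_value)
    advantages = np.zeros(len(rewards))
    gae_acc = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]  # one-step TD error
        gae_acc = delta + gamma * gae_lambda * gae_acc
        advantages[t] = gae_acc
    return advantages


rewards = np.array([1.0, 0.0, 1.0])
values = np.array([0.5, 0.4, 0.3])
print(gae(rewards, values, last_value=0.2, gae_lambda=0.0))  # pure TD(0) errors
print(gae(rewards, values, last_value=0.2, gae_lambda=1.0))  # discounted MC return minus V(s_t)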
@@ -345,46 +470,69 @@ class ParamsMixinActorAndDualCritics(GetParamTransformersProtocol):
 
 
 @dataclass
-class SACParams(Params, ParamsMixinActorAndDualCritics):
+class _SACParams(Params, ParamsMixinActorAndDualCritics):
     tau: float = 0.005
+    """
+    controls the soft update of the target network.
+    It determines how slowly the target networks track the main networks.
+    Smaller tau means slower tracking and more stable learning.
+    """
     gamma: float = 0.99
-    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
+    """discount factor (gamma) for future rewards; must be in [0, 1]"""
+    alpha: float | AutoAlphaFactory = 0.2
+    """
+    controls the relative importance (coefficient) of the entropy term in the loss function.
+    This can be a constant or a factory for the creation of a representation that allows the
+    parameter to be automatically tuned;
+    use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard
+    auto-adjusted alpha.
+    """
     estimation_step: int = 1
-    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = None
-    deterministic_eval: bool = True
-    action_scaling: bool = True
-    action_bound_method: Literal["clip"] | None = "clip"
+    """the number of steps to look ahead"""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
         transformers.append(ParamTransformerAutoAlpha("alpha"))
-        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
         return transformers
 
 
 @dataclass
-class DiscreteSACParams(Params, ParamsMixinActorAndDualCritics):
-    tau: float = 0.005
-    gamma: float = 0.99
-    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
-    estimation_step: int = 1
+class SACParams(_SACParams, ParamsMixinExplorationNoise, ParamsMixinActionScaling):
+    deterministic_eval: bool = True
+    """
+    whether to use deterministic action (mean of Gaussian policy) in evaluation mode instead of stochastic
+    action sampled by the policy. Does not affect training."""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
-        transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
-        transformers.append(ParamTransformerAutoAlpha("alpha"))
+        transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self))
+        transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
         return transformers
 
 
+@dataclass
+class DiscreteSACParams(_SACParams):
+    pass
+
+
 @dataclass
 class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
     discount_factor: float = 0.99
+    """
+    discount factor (gamma) for future rewards; must be in [0, 1]
+    """
     estimation_step: int = 1
+    """the number of steps to look ahead"""
     target_update_freq: int = 0
+    """the target network update frequency (0 if no target network is to be used)"""
     reward_normalization: bool = False
+    """whether to normalize the returns to Normal(0, 1)"""
     is_double: bool = True
+    """whether to use double Q learning"""
     clip_loss_grad: bool = False
+    """whether to clip the gradient of the loss in accordance with nature14236; this amounts to using the Huber
+    loss instead of the MSE loss."""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
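For the DQN parameters, the effect of is_double can be made concrete: double Q-learning selects the next action with the online network but evaluates it with the target network, which reduces the over-estimation bias of the plain max target. An illustrative sketch (not tianshou internals):

import numpy as np

rng = np.random.default_rng(0)
q_online = rng.normal(size=(5, 3))   # Q-values of the online network for 5 next states
q_target = rng.normal(size=(5, 3))   # Q-values of the target network for the same states
rewards = rng.normal(size=5)
gamma = 0.99

# Standard DQN target: max over the target network's own estimates (tends to over-estimate).
td_target_plain = rewards + gamma * q_target.max(axis=1)

# Double DQN (is_double=True): the online network picks the action, the target network rates it.
best_actions = q_online.argmax(axis=1)
td_target_double = rewards + gamma * q_target[np.arange(5), best_actions]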
@@ -395,9 +543,13 @@ class DQNParams(Params, ParamsMixinLearningRateWithScheduler):
 @dataclass
 class IQNParams(DQNParams):
     sample_size: int = 32
+    """the number of samples for policy evaluation"""
     online_sample_size: int = 8
+    """the number of samples for online model in training"""
     target_sample_size: int = 8
+    """the number of samples for target model in training."""
     num_quantiles: int = 200
+    """the number of quantile midpoints in the inverse cumulative distribution function of the value"""
     hidden_sizes: Sequence[int] = ()
     """hidden dimensions to use in the IQN network"""
     num_cosines: int = 64
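The num_quantiles docstring refers to quantile midpoints of the return distribution; a value estimate is then simply the average of the quantile function evaluated at those midpoints. A minimal illustrative sketch (the quantile function used here is a made-up placeholder):

import numpy as np

num_quantiles = 200
edges = np.linspace(0.0, 1.0, num_quantiles + 1)
taus = (edges[:-1] + edges[1:]) / 2  # quantile midpoints in (0, 1)


def quantile_function(tau: np.ndarray) -> np.ndarray:
    # Placeholder inverse CDF of the return distribution for one state-action pair.
    return 1.0 + 0.5 * tau


q_value = quantile_function(taus).mean()  # expected return = average over the quantile midpoints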
@@ -410,29 +562,54 @@ class IQNParams(DQNParams):
 
 
 @dataclass
-class DDPGParams(Params, ParamsMixinActorAndCritic):
+class DDPGParams(
+    Params,
+    ParamsMixinActorAndCritic,
+    ParamsMixinExplorationNoise,
+    ParamsMixinActionScaling,
+):
     tau: float = 0.005
+    """
+    controls the soft update of the target network.
+    It determines how slowly the target networks track the main networks.
+    Smaller tau means slower tracking and more stable learning.
+    """
     gamma: float = 0.99
-    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
+    """discount factor (gamma) for future rewards; must be in [0, 1]"""
     estimation_step: int = 1
-    action_scaling: bool = True
-    action_bound_method: Literal["clip"] | None = "clip"
+    """the number of steps to look ahead."""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndCritic._get_param_transformers(self))
-        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
+        transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self))
+        transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
         return transformers
 
 
 @dataclass
 class REDQParams(DDPGParams):
     ensemble_size: int = 10
+    """the number of sub-networks in the critic ensemble"""
     subset_size: int = 2
-    alpha: float | tuple[float, torch.Tensor, torch.optim.Optimizer] | AutoAlphaFactory = 0.2
+    """the number of networks in the subset"""
+    alpha: float | AutoAlphaFactory = 0.2
+    """
+    controls the relative importance (coefficient) of the entropy term in the loss function.
+    This can be a constant or a factory for the creation of a representation that allows the
+    parameter to be automatically tuned;
+    use :class:`tianshou.highlevel.params.alpha.AutoAlphaFactoryDefault` for the standard
+    auto-adjusted alpha.
+    """
     estimation_step: int = 1
+    """the number of steps to look ahead"""
     actor_delay: int = 20
+    """the number of critic updates before an actor update"""
     deterministic_eval: bool = True
+    """
+    whether to use deterministic action (the dist's mode) instead of stochastic one during evaluation.
+    Does not affect training.
+    """
     target_mode: Literal["mean", "min"] = "min"
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
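The tau docstring describes the soft (Polyak) update of the target networks; an illustrative sketch of one such update step (plain PyTorch, not the library code):

import torch
from torch import nn

net = nn.Linear(4, 2)
target_net = nn.Linear(4, 2)
target_net.load_state_dict(net.state_dict())
tau = 0.005

# Soft update: the target parameters track the online parameters slowly;
# smaller tau means slower tracking.
with torch.no_grad():
    for p_target, p in zip(target_net.parameters(), net.parameters()):
        p_target.mul_(1.0 - tau).add_(tau * p)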
@@ -442,21 +619,34 @@ class REDQParams(DDPGParams):
 
 
 @dataclass
-class TD3Params(Params, ParamsMixinActorAndDualCritics):
+class TD3Params(
+    Params,
+    ParamsMixinActorAndDualCritics,
+    ParamsMixinExplorationNoise,
+    ParamsMixinActionScaling,
+):
     tau: float = 0.005
+    """
+    controls the soft update of the target network.
+    It determines how slowly the target networks track the main networks.
+    Smaller tau means slower tracking and more stable learning.
+    """
     gamma: float = 0.99
-    exploration_noise: BaseNoise | Literal["default"] | NoiseFactory | None = "default"
+    """discount factor (gamma) for future rewards; must be in [0, 1]"""
     policy_noise: float | FloatEnvValueFactory = 0.2
+    """the scale of the noise used in updating the policy network"""
     noise_clip: float | FloatEnvValueFactory = 0.5
+    """determines the clipping range of the noise used in updating the policy network as [-noise_clip, noise_clip]"""
     update_actor_freq: int = 2
+    """the update frequency of the actor network"""
     estimation_step: int = 1
-    action_scaling: bool = True
-    action_bound_method: Literal["clip"] | None = "clip"
+    """the number of steps to look ahead."""
 
     def _get_param_transformers(self) -> list[ParamTransformer]:
         transformers = super()._get_param_transformers()
         transformers.extend(ParamsMixinActorAndDualCritics._get_param_transformers(self))
-        transformers.append(ParamTransformerNoiseFactory("exploration_noise"))
+        transformers.extend(ParamsMixinExplorationNoise._get_param_transformers(self))
+        transformers.extend(ParamsMixinActionScaling._get_param_transformers(self))
         transformers.append(ParamTransformerFloatEnvParamFactory("policy_noise"))
         transformers.append(ParamTransformerFloatEnvParamFactory("noise_clip"))
         return transformers
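policy_noise and noise_clip parameterize TD3's target policy smoothing: clipped Gaussian noise is added to the target actor's actions before the target critics evaluate them, which smooths the value estimate over similar actions. An illustrative sketch (not the library code):

import torch

policy_noise, noise_clip, max_action = 0.2, 0.5, 1.0
next_target_actions = torch.tensor([[0.3], [-0.9]])  # actions proposed by the target actor

# Clipped Gaussian noise, then re-clip the smoothed actions to the valid action range.
noise = (torch.randn_like(next_target_actions) * policy_noise).clamp(-noise_clip, noise_clip)
smoothed = (next_target_actions + noise).clamp(-max_action, max_action)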