diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py
index 878ae4b..4e8490d 100644
--- a/tianshou/highlevel/params/alpha.py
+++ b/tianshou/highlevel/params/alpha.py
@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass


-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
    def __init__(self, lr: float = 3e-4):
        self.lr = lr

@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim
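
The effect of the change is that the alpha optimizer is no longer hard-coded to `torch.optim.Adam` but is created by the `optim_factory` passed into `create`, so the optimizer class and its keyword arguments become configurable at the factory level. A minimal sketch of the pattern, using a hypothetical `AdamOptimizerFactory` for illustration (not tianshou's actual factory class):

```python
from collections.abc import Iterable
from typing import Any

import torch


class AdamOptimizerFactory:
    """Illustrative optimizer factory (hypothetical, not tianshou's implementation)."""

    def __init__(self, **kwargs: Any):
        # Extra optimizer kwargs (e.g. betas, weight_decay) are fixed here,
        # so call sites only need to supply parameters and a learning rate.
        self.kwargs = kwargs

    def create_optimizer_for_params(
        self, params: Iterable[torch.Tensor], lr: float
    ) -> torch.optim.Optimizer:
        return torch.optim.Adam(params, lr=lr, **self.kwargs)


# Usage mirroring the patched line above (sketch):
log_alpha = torch.zeros(1, requires_grad=True)
optim_factory = AdamOptimizerFactory(betas=(0.9, 0.999))
alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], 3e-4)
```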