From eeb2081ca68f7894ba5a8ab31de32e360ab6ded5 Mon Sep 17 00:00:00 2001
From: Dominik Jain
Date: Wed, 14 Feb 2024 20:43:38 +0100
Subject: [PATCH] Fix AutoAlphaFactoryDefault using hard-coded Adam optimizer
 instead of passed factory

---
 tianshou/highlevel/params/alpha.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tianshou/highlevel/params/alpha.py b/tianshou/highlevel/params/alpha.py
index 878ae4b..4e8490d 100644
--- a/tianshou/highlevel/params/alpha.py
+++ b/tianshou/highlevel/params/alpha.py
@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass
 
 
-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr
 
@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim
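
Note (not part of the patch): the following is a minimal, self-contained sketch of why the hard-coded Adam call was a bug. It mirrors the patched body of AutoAlphaFactoryDefault with a stand-in optimizer factory; any name not appearing in the diff above (the stand-in class, the free function, the example action shape) is an assumption for illustration, not tianshou API.

    # Hedged sketch: shows that after the fix, the alpha optimizer is produced
    # by the passed factory instead of always being torch.optim.Adam.
    import numpy as np
    import torch


    class SGDOptimizerFactory:
        """Stand-in for an optimizer factory that does not produce Adam."""

        def create_optimizer_for_params(self, params, lr):
            return torch.optim.SGD(params, lr=lr)


    def create_auto_alpha(action_shape, optim_factory, lr=3e-4, device="cpu"):
        # Mirrors AutoAlphaFactoryDefault after the fix: the optimizer comes
        # from optim_factory rather than a hard-coded torch.optim.Adam.
        target_entropy = float(-np.prod(action_shape))
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], lr)
        return target_entropy, log_alpha, alpha_optim


    target_entropy, log_alpha, alpha_optim = create_auto_alpha((6,), SGDOptimizerFactory())
    print(type(alpha_optim).__name__, target_entropy)  # SGD -6.0

Before the fix, the final line would have printed "Adam" regardless of the factory passed in, silently ignoring any user-configured optimizer for the entropy coefficient.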