Fix AutoAlphaFactoryDefault using hard-coded Adam optimizer instead of passed factory

Dominik Jain 2024-02-14 20:43:38 +01:00
parent 76cbd7efc2
commit eeb2081ca6


@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass


-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr

@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim
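
For context, a minimal sketch of how the corrected factory might read as a whole. It assumes the surrounding high-level interfaces only as far as the diff shows them (an optimizer factory exposing create_optimizer_for_params and an environment collection exposing get_action_shape); the stub base class and the untyped parameters are illustrative, not the library's exact signatures.

from abc import ABC, abstractmethod

import numpy as np
import torch


class AutoAlphaFactory(ABC):
    """Stub of the abstract base class (the real one also mixes in ToStringMixin)."""

    @abstractmethod
    def create_auto_alpha(
        self, envs, optim_factory, device
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        pass


class AutoAlphaFactoryDefault(AutoAlphaFactory):
    def __init__(self, lr: float = 3e-4):
        self.lr = lr

    def create_auto_alpha(
        self, envs, optim_factory, device
    ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
        # SAC target entropy heuristic: negative product of the action shape
        target_entropy = float(-np.prod(envs.get_action_shape()))
        # log(alpha) is the single learnable parameter for automatic entropy tuning
        log_alpha = torch.zeros(1, requires_grad=True, device=device)
        # The fix: build the optimizer through the passed factory rather than
        # hard-coding torch.optim.Adam, so the user-configured optimizer is respected
        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
        return target_entropy, log_alpha, alpha_optim

The effect of the change is that whatever optimizer the user configures at the experiment level now also drives the alpha parameter, instead of being silently overridden by a hard-coded Adam instance.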