Fix AutoAlphaFactoryDefault using hard-coded Adam optimizer instead of passed factory
parent 76cbd7efc2
commit eeb2081ca6
@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass
 
 
-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr
 
@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim
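Context for the change: AutoAlphaFactoryDefault builds the (target_entropy, log_alpha, alpha_optim) triple used for SAC-style automatic entropy-coefficient tuning, and before this fix it constructed torch.optim.Adam directly, silently ignoring the optimizer factory supplied by the caller. The sketch below is a minimal, self-contained illustration of the pattern the fix restores; the OptimizerFactory protocol and the create_auto_alpha parameter names here are illustrative assumptions, not the library's exact definitions.

# Minimal sketch of the delegation pattern this commit restores.
# The OptimizerFactory protocol and create_auto_alpha signature are
# illustrative stand-ins, not the library's exact definitions.
from typing import Protocol

import numpy as np
import torch


class OptimizerFactory(Protocol):
    def create_optimizer_for_params(
        self, params: list[torch.Tensor], lr: float
    ) -> torch.optim.Optimizer:
        ...


class AdamOptimizerFactory:
    def create_optimizer_for_params(
        self, params: list[torch.Tensor], lr: float
    ) -> torch.optim.Optimizer:
        return torch.optim.Adam(params, lr=lr)


def create_auto_alpha(
    action_shape: tuple[int, ...],
    optim_factory: OptimizerFactory,
    device: str,
    lr: float = 3e-4,
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
    # Target entropy heuristic for SAC-style automatic alpha tuning.
    target_entropy = float(-np.prod(action_shape))
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    # Delegate optimizer construction to the factory instead of
    # hard-coding torch.optim.Adam, so a user-supplied factory
    # (e.g. one producing RMSprop) is actually honored.
    alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], lr)
    return target_entropy, log_alpha, alpha_optim

With this delegation, a caller that supplies a non-Adam factory gets that optimizer applied to log_alpha as well, which is the behavior the commit title describes.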