Fix AutoAlphaFactoryDefault to create its optimizer via the passed optimizer factory instead of a hard-coded Adam optimizer
This commit is contained in:
parent
76cbd7efc2
commit
eeb2081ca6
@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name?
|
class AutoAlphaFactoryDefault(AutoAlphaFactory):
|
||||||
def __init__(self, lr: float = 3e-4):
|
def __init__(self, lr: float = 3e-4):
|
||||||
self.lr = lr
|
self.lr = lr
|
||||||
|
|
||||||
@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory): # TODO better name?
|
|||||||
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
|
) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
|
||||||
target_entropy = float(-np.prod(envs.get_action_shape()))
|
target_entropy = float(-np.prod(envs.get_action_shape()))
|
||||||
log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
log_alpha = torch.zeros(1, requires_grad=True, device=device)
|
||||||
alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
|
alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
|
||||||
return target_entropy, log_alpha, alpha_optim
|
return target_entropy, log_alpha, alpha_optim
|
||||||
|
Loading…
x
Reference in New Issue
Block a user