Fix AutoAlphaFactoryDefault using hard-coded Adam optimizer instead of passed factory
This commit is contained in:
parent 76cbd7efc2
commit eeb2081ca6
@@ -20,7 +20,7 @@ class AutoAlphaFactory(ToStringMixin, ABC):
         pass
 
 
-class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
+class AutoAlphaFactoryDefault(AutoAlphaFactory):
     def __init__(self, lr: float = 3e-4):
         self.lr = lr
 
@@ -32,5 +32,5 @@ class AutoAlphaFactoryDefault(AutoAlphaFactory):  # TODO better name?
     ) -> tuple[float, torch.Tensor, torch.optim.Optimizer]:
         target_entropy = float(-np.prod(envs.get_action_shape()))
         log_alpha = torch.zeros(1, requires_grad=True, device=device)
-        alpha_optim = torch.optim.Adam([log_alpha], lr=self.lr)
+        alpha_optim = optim_factory.create_optimizer_for_params([log_alpha], self.lr)
         return target_entropy, log_alpha, alpha_optim
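The change swaps the hard-coded torch.optim.Adam call for the optimizer factory that is already passed into the method, so the alpha optimizer follows whatever optimizer the user configured. A minimal sketch of the factory contract the fix relies on is shown below; RMSpropOptimizerFactory is a hypothetical name used only for illustration and is not part of this commit.

# Illustrative sketch, not code from this commit: any object exposing
# create_optimizer_for_params(params, lr) now also drives the alpha optimizer.
from collections.abc import Iterable

import torch


class RMSpropOptimizerFactory:  # hypothetical factory, for illustration only
    def create_optimizer_for_params(
        self, params: Iterable[torch.Tensor], lr: float
    ) -> torch.optim.Optimizer:
        # Build the user's chosen optimizer (here RMSprop) for the given parameters,
        # e.g. the single log_alpha tensor created by AutoAlphaFactoryDefault.
        return torch.optim.RMSprop(params, lr=lr)

With the fix, passing such a factory as optim_factory means log_alpha is trained with the configured optimizer instead of always Adam.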