Add a comment before SAC alpha loss (#565)
Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
This commit is contained in:
		
							parent
							
								
									ad2e1eaea0
								
							
						
					
					
						commit
						74f430ea36
					
				@@ -364,6 +364,7 @@ class AsyncCollector(Collector):
 | 
			
		||||
        exploration_noise: bool = False,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        # assert env.is_async
 | 
			
		||||
        warnings.warn("Using async setting may collect extra transitions into buffer.")
 | 
			
		||||
        super().__init__(policy, env, buffer, preprocess_fn, exploration_noise)
 | 
			
		||||
 | 
			
		||||
    def reset_env(self) -> None:
 | 
			
		||||
@@ -424,7 +425,6 @@ class AsyncCollector(Collector):
 | 
			
		||||
                "Please specify at least one (either n_step or n_episode) "
 | 
			
		||||
                "in AsyncCollector.collect()."
 | 
			
		||||
            )
 | 
			
		||||
        warnings.warn("Using async setting may collect extra transitions into buffer.")
 | 
			
		||||
 | 
			
		||||
        ready_env_ids = self._ready_env_ids
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -174,6 +174,7 @@ class SACPolicy(DDPGPolicy):
 | 
			
		||||
 | 
			
		||||
        if self._is_auto_alpha:
 | 
			
		||||
            log_prob = obs_result.log_prob.detach() + self._target_entropy
 | 
			
		||||
            # please take a look at issue #258 if you'd like to change this line
 | 
			
		||||
            alpha_loss = -(self._log_alpha * log_prob).mean()
 | 
			
		||||
            self._alpha_optim.zero_grad()
 | 
			
		||||
            alpha_loss.backward()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user