parent
							
								
									623bf24f0c
								
							
						
					
					
						commit
						eec0826fd3
					
				| @ -6,8 +6,7 @@ from numba import njit | ||||
| from abc import ABC, abstractmethod | ||||
| from typing import Any, List, Union, Mapping, Optional, Callable | ||||
| 
 | ||||
| from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \ | ||||
|     to_torch_as, to_numpy | ||||
| from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy | ||||
| 
 | ||||
| 
 | ||||
| class BasePolicy(ABC, nn.Module): | ||||
| @ -138,9 +137,7 @@ class BasePolicy(ABC, nn.Module): | ||||
|         Typical usage is to update the sampling weight in prioritized | ||||
|         experience replay. Used in :meth:`update`. | ||||
|         """ | ||||
|         if isinstance(buffer, PrioritizedReplayBuffer) and hasattr( | ||||
|             batch, "weight" | ||||
|         ): | ||||
|         if hasattr(buffer, "update_weight") and hasattr(batch, "weight"): | ||||
|             buffer.update_weight(indice, batch.weight) | ||||
| 
 | ||||
|     def update( | ||||
| @ -253,8 +250,7 @@ class BasePolicy(ABC, nn.Module): | ||||
|                                  gamma, n_step, len(buffer), mean, std) | ||||
| 
 | ||||
|         batch.returns = to_torch_as(target_q, target_q_torch) | ||||
|         # prio buffer update | ||||
|         if isinstance(buffer, PrioritizedReplayBuffer): | ||||
|         if hasattr(batch, "weight"):  # prio buffer update | ||||
|             batch.weight = to_torch_as(batch.weight, target_q_torch) | ||||
|         return batch | ||||
| 
 | ||||
|  | ||||
| @ -127,8 +127,8 @@ class DiscreteSACPolicy(SACPolicy): | ||||
|         self.actor_optim.step() | ||||
| 
 | ||||
|         if self._is_auto_alpha: | ||||
|             log_prob = entropy.detach() - self._target_entropy | ||||
|             alpha_loss = (self._log_alpha * log_prob).mean() | ||||
|             log_prob = -entropy.detach() + self._target_entropy | ||||
|             alpha_loss = -(self._log_alpha * log_prob).mean() | ||||
|             self._alpha_optim.zero_grad() | ||||
|             alpha_loss.backward() | ||||
|             self._alpha_optim.step() | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user