parent 623bf24f0c
commit eec0826fd3
@@ -6,8 +6,7 @@ from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, List, Union, Mapping, Optional, Callable
 
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
-    to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
 
 class BasePolicy(ABC, nn.Module):
@@ -138,9 +137,7 @@ class BasePolicy(ABC, nn.Module):
         Typical usage is to update the sampling weight in prioritized
         experience replay. Used in :meth:`update`.
         """
-        if isinstance(buffer, PrioritizedReplayBuffer) and hasattr(
-            batch, "weight"
-        ):
+        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
             buffer.update_weight(indice, batch.weight)
 
     def update(
@@ -253,8 +250,7 @@ class BasePolicy(ABC, nn.Module):
             gamma, n_step, len(buffer), mean, std)
 
         batch.returns = to_torch_as(target_q, target_q_torch)
-        # prio buffer update
-        if isinstance(buffer, PrioritizedReplayBuffer):
+        if hasattr(batch, "weight"):  # prio buffer update
             batch.weight = to_torch_as(batch.weight, target_q_torch)
         return batch
 
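The two hunks above swap isinstance checks against PrioritizedReplayBuffer for duck typing: the weight update now fires for any buffer that exposes an update_weight method, and the weight cast fires whenever the sampled batch actually carries a weight field. A minimal sketch of the pattern the new checks enable (the CustomBuffer class below is hypothetical, not Tianshou code):

import numpy as np

class CustomBuffer:
    """Hypothetical buffer: not a PrioritizedReplayBuffer subclass, yet the
    hasattr(buffer, "update_weight") check still routes weight updates to it."""

    def __init__(self, size: int) -> None:
        self.priority = np.ones(size)

    def update_weight(self, index: np.ndarray, new_weight: np.ndarray) -> None:
        # refresh the priorities of the sampled transitions
        self.priority[index] = np.abs(new_weight) + 1e-6

buf = CustomBuffer(8)
if hasattr(buf, "update_weight"):  # same duck-typed check as in the hunk above
    buf.update_weight(np.array([0, 3]), np.array([0.5, 2.0]))
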
@@ -127,8 +127,8 @@ class DiscreteSACPolicy(SACPolicy):
         self.actor_optim.step()
 
         if self._is_auto_alpha:
-            log_prob = entropy.detach() - self._target_entropy
-            alpha_loss = (self._log_alpha * log_prob).mean()
+            log_prob = -entropy.detach() + self._target_entropy
+            alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
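For reference, the rewritten alpha loss is algebraically identical to the one it replaces: distributing the leading minus sign gives -(log_alpha * (-entropy + target)) = log_alpha * (entropy - target), so the change only makes the discrete variant mirror the continuous SACPolicy. A quick numerical check with made-up tensors (illustration only, not Tianshou code):

import torch

entropy = torch.tensor([0.7, 1.2, 0.3])             # per-sample policy entropy
log_alpha = torch.tensor(0.05, requires_grad=True)  # learnable log(alpha)
target_entropy = 0.98                               # entropy target

# old formulation
old_loss = (log_alpha * (entropy.detach() - target_entropy)).mean()
# new formulation, matching the hunk above
log_prob = -entropy.detach() + target_entropy
new_loss = -(log_alpha * log_prob).mean()

assert torch.allclose(old_loss, new_loss)  # identical value (and gradient)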