change PER update interface in BasePolicy (#217)

* fix #215
n+e 2020-09-16 17:43:19 +08:00 committed by GitHub
parent 623bf24f0c
commit eec0826fd3
2 changed files with 5 additions and 9 deletions


@@ -6,8 +6,7 @@ from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, List, Union, Mapping, Optional, Callable
 
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
-    to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
 
 class BasePolicy(ABC, nn.Module):
@@ -138,9 +137,7 @@ class BasePolicy(ABC, nn.Module):
         Typical usage is to update the sampling weight in prioritized
         experience replay. Used in :meth:`update`.
         """
-        if isinstance(buffer, PrioritizedReplayBuffer) and hasattr(
-            batch, "weight"
-        ):
+        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
             buffer.update_weight(indice, batch.weight)
 
     def update(
@@ -253,8 +250,7 @@ class BasePolicy(ABC, nn.Module):
                                  gamma, n_step, len(buffer), mean, std)
 
         batch.returns = to_torch_as(target_q, target_q_torch)
-        # prio buffer update
-        if isinstance(buffer, PrioritizedReplayBuffer):
+        if hasattr(batch, "weight"):  # prio buffer update
             batch.weight = to_torch_as(batch.weight, target_q_torch)
         return batch
 
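
With the isinstance checks gone, the post-processing hook and the n-step return helper duck-type the buffer: any buffer that exposes an update_weight method, paired with a batch carrying a weight field, gets its priorities pushed back, without having to subclass PrioritizedReplayBuffer. A minimal sketch of what the new check allows; MyWeightedBuffer and its weight bookkeeping are made up for illustration, only ReplayBuffer and Batch are real tianshou classes:

import numpy as np

from tianshou.data import Batch, ReplayBuffer


class MyWeightedBuffer(ReplayBuffer):
    """Illustrative buffer: keeps per-transition weights and exposes
    update_weight(), which is all the duck-typed check looks for."""

    def __init__(self, size: int) -> None:
        super().__init__(size)
        self._weights = np.ones(size, dtype=np.float64)

    def update_weight(self, indice: np.ndarray, new_weight: np.ndarray) -> None:
        # BasePolicy's post-processing now calls this on any buffer that has
        # the method, instead of requiring a PrioritizedReplayBuffer instance.
        self._weights[indice] = np.abs(np.asarray(new_weight)) + 1e-6


# the duck-typed check from the diff, applied to the custom buffer:
buffer = MyWeightedBuffer(size=16)
batch = Batch(weight=np.ones(4, dtype=np.float64))
indice = np.arange(4)
if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
    buffer.update_weight(indice, batch.weight)

The real PrioritizedReplayBuffer passes the same check because it already provides update_weight, so existing PER setups behave as before.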


@@ -127,8 +127,8 @@ class DiscreteSACPolicy(SACPolicy):
         self.actor_optim.step()
 
         if self._is_auto_alpha:
-            log_prob = entropy.detach() - self._target_entropy
-            alpha_loss = (self._log_alpha * log_prob).mean()
+            log_prob = -entropy.detach() + self._target_entropy
+            alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
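
For reference, the rewritten alpha update is algebraically identical to the removed lines: -(log_alpha * (-entropy + target_entropy)) equals (log_alpha * (entropy - target_entropy)). The change only rewrites the detached term so that -entropy plays the role of the action log-probability, matching how the SAC temperature loss is usually written. A small standalone check of the equivalence and of the update direction, using made-up numbers (nothing below comes from tianshou):

import torch

target_entropy = 0.98 * torch.log(torch.tensor(4.0))  # e.g. 98% of max entropy for 4 actions
log_alpha = torch.tensor([0.5], requires_grad=True)
entropy = torch.tensor([0.2, 0.3])                     # current policy entropy, below target

# rewritten form from the diff
log_prob = -entropy.detach() + target_entropy
alpha_loss = -(log_alpha * log_prob).mean()

# removed form, reproduced here only to confirm both give the same value
old_loss = (log_alpha * (entropy.detach() - target_entropy)).mean()
assert torch.allclose(alpha_loss, old_loss)

alpha_loss.backward()
print(log_alpha.grad)  # negative here, so gradient descent raises log_alpha

When the policy entropy sits below the target, the gradient on log_alpha is negative and gradient descent increases alpha, strengthening the entropy bonus; once the entropy exceeds the target, the sign flips and alpha decays.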