From eec0826fd3a7c3066f68a5cc8b0b7ac145a6f1ac Mon Sep 17 00:00:00 2001
From: n+e <463003665@qq.com>
Date: Wed, 16 Sep 2020 17:43:19 +0800
Subject: [PATCH] change PER update interface in BasePolicy (#217)

* fix #215
---
 tianshou/policy/base.py                   | 10 +++-------
 tianshou/policy/modelfree/discrete_sac.py |  4 ++--
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py
index 8bd0fcb..7785b8f 100644
--- a/tianshou/policy/base.py
+++ b/tianshou/policy/base.py
@@ -6,8 +6,7 @@ from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, List, Union, Mapping, Optional, Callable
 
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
-    to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
 
 class BasePolicy(ABC, nn.Module):
@@ -138,9 +137,7 @@ class BasePolicy(ABC, nn.Module):
         Typical usage is to update the sampling weight in prioritized
         experience replay. Used in :meth:`update`.
         """
-        if isinstance(buffer, PrioritizedReplayBuffer) and hasattr(
-            batch, "weight"
-        ):
+        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
             buffer.update_weight(indice, batch.weight)
 
     def update(
@@ -253,8 +250,7 @@ class BasePolicy(ABC, nn.Module):
                                  gamma, n_step, len(buffer), mean, std)
 
         batch.returns = to_torch_as(target_q, target_q_torch)
-        # prio buffer update
-        if isinstance(buffer, PrioritizedReplayBuffer):
+        if hasattr(batch, "weight"):  # prio buffer update
             batch.weight = to_torch_as(batch.weight, target_q_torch)
         return batch
 
diff --git a/tianshou/policy/modelfree/discrete_sac.py b/tianshou/policy/modelfree/discrete_sac.py
index d9fa35e..f71b18b 100644
--- a/tianshou/policy/modelfree/discrete_sac.py
+++ b/tianshou/policy/modelfree/discrete_sac.py
@@ -127,8 +127,8 @@ class DiscreteSACPolicy(SACPolicy):
         self.actor_optim.step()
 
         if self._is_auto_alpha:
-            log_prob = entropy.detach() - self._target_entropy
-            alpha_loss = (self._log_alpha * log_prob).mean()
+            log_prob = -entropy.detach() + self._target_entropy
+            alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
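
Note on the new PER interface: after this patch, BasePolicy.post_process_fn relies on duck typing only. Any buffer that exposes an update_weight(indice, new_weight) method, paired with a batch that carries a weight field, gets its priorities refreshed after each learn step; inheriting from PrioritizedReplayBuffer is no longer required. A minimal sketch of a buffer satisfying that contract (MyWeightedBuffer and its internals are illustrative, not part of tianshou):

import numpy as np

class MyWeightedBuffer:
    """Hypothetical buffer; only the duck-typed pieces are shown."""

    def __init__(self, size: int) -> None:
        self.weight = np.ones(size)  # one priority per stored transition

    def update_weight(self, indice: np.ndarray,
                      new_weight: np.ndarray) -> None:
        # called by BasePolicy.post_process_fn after learn()
        self.weight[indice] = np.abs(new_weight) + 1e-6

buf = MyWeightedBuffer(8)
indice = np.array([0, 3, 5])
td_error = np.array([0.5, 0.1, 0.9])  # stand-in for batch.weight
if hasattr(buf, "update_weight"):     # same check as the patched code
    buf.update_weight(indice, td_error)
print(buf.weight)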
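
Note on the discrete SAC hunk: the two sign flips make the temperature update follow the usual SAC objective, roughly -E[log_alpha * (log pi(a|s) + H_target)]. Since the stored quantity here is the policy entropy (i.e. -E[log pi]), the surrogate log-probability is reconstructed as -entropy + target_entropy before being multiplied by -log_alpha. A small self-contained check of the corrected update direction (the numbers and the 0.98 * log(num_actions) target are illustrative, not taken from the patch):

import math
import torch

entropy = torch.tensor([0.2, 0.4])      # per-sample policy entropy
target_entropy = 0.98 * math.log(4)     # e.g. for a 4-action space
log_alpha = torch.zeros(1, requires_grad=True)

# corrected computation from the patch
log_prob = -entropy.detach() + target_entropy
alpha_loss = -(log_alpha * log_prob).mean()
alpha_loss.backward()

# the gradient is negative when entropy is below the target, so gradient
# descent raises log_alpha and pushes the policy toward more exploration
print(log_alpha.grad)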