change PER update interface in BasePolicy (#217)

* fix #215
n+e 2020-09-16 17:43:19 +08:00 committed by GitHub
parent 623bf24f0c
commit eec0826fd3
2 changed files with 5 additions and 9 deletions

tianshou/policy/base.py

@@ -6,8 +6,7 @@ from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, List, Union, Mapping, Optional, Callable
 
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
-    to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
 
 class BasePolicy(ABC, nn.Module):
@@ -138,9 +137,7 @@ class BasePolicy(ABC, nn.Module):
         Typical usage is to update the sampling weight in prioritized
         experience replay. Used in :meth:`update`.
         """
-        if isinstance(buffer, PrioritizedReplayBuffer) and hasattr(
-            batch, "weight"
-        ):
+        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
             buffer.update_weight(indice, batch.weight)
 
     def update(
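The old isinstance check tied priority updates to PrioritizedReplayBuffer specifically; post_process_fn now duck-types on the buffer, so any buffer-like object that exposes update_weight(indice, new_weight) and returns batches carrying a weight field will have its priorities refreshed after each learn step. A minimal sketch of such a custom buffer, assuming the ReplayBuffer and to_numpy APIs of this tianshou release; the class name, the _prio array, and the priority formula are illustrative, not part of this commit:

import numpy as np

from tianshou.data import ReplayBuffer, to_numpy


class MyWeightedBuffer(ReplayBuffer):
    """Hypothetical buffer that keeps its own per-transition priorities."""

    def __init__(self, size: int, **kwargs) -> None:
        super().__init__(size, **kwargs)
        self._prio = np.ones(size)

    def update_weight(self, indice, new_weight) -> None:
        # BasePolicy.post_process_fn calls this with the (possibly torch)
        # TD-error based weights produced during the update.
        self._prio[indice] = np.abs(to_numpy(new_weight)) + 1e-6

    # For post_process_fn to fire, sample() would also need to attach a
    # `weight` field to the returned batch (omitted here for brevity).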
@@ -253,8 +250,7 @@ class BasePolicy(ABC, nn.Module):
             gamma, n_step, len(buffer), mean, std)
         batch.returns = to_torch_as(target_q, target_q_torch)
-        # prio buffer update
-        if isinstance(buffer, PrioritizedReplayBuffer):
+        if hasattr(batch, "weight"):  # prio buffer update
             batch.weight = to_torch_as(batch.weight, target_q_torch)
         return batch
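compute_nstep_return applies the same idea: it now keys off the presence of batch.weight instead of the buffer class, since only a prioritized-style buffer attaches importance-sampling weights to the batches it samples. A short usage sketch, assuming the ReplayBuffer / PrioritizedReplayBuffer constructors of this release; the buffer size and the alpha/beta values are illustrative:

from tianshou.data import ReplayBuffer, PrioritizedReplayBuffer

plain = ReplayBuffer(size=16)
prio = PrioritizedReplayBuffer(size=16, alpha=0.6, beta=0.4)

# After filling both buffers and sampling:
#   batch, indice = plain.sample(4)  # no `weight` attribute, branch is skipped
#   batch, indice = prio.sample(4)   # carries `weight`, converted to a torch tensor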

tianshou/policy/modelfree/discrete_sac.py

@@ -127,8 +127,8 @@ class DiscreteSACPolicy(SACPolicy):
         self.actor_optim.step()
 
         if self._is_auto_alpha:
-            log_prob = entropy.detach() - self._target_entropy
-            alpha_loss = (self._log_alpha * log_prob).mean()
+            log_prob = -entropy.detach() + self._target_entropy
+            alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
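The DiscreteSACPolicy edit is a sign-convention cleanup rather than a numerical change: (log_alpha * (entropy - target)) equals -(log_alpha * (-entropy + target)), so both forms give the same temperature loss. The new form mirrors the usual SAC objective alpha_loss = -E[log_alpha * (log_pi(a|s) + target_entropy)], where E[log_pi] = -entropy in the discrete case. A quick equivalence check (a sketch, not part of the commit; shapes and the target value are arbitrary):

import torch

entropy = torch.rand(8)                               # per-sample policy entropy
target_entropy = 0.98 * torch.log(torch.tensor(4.0))  # illustrative target
log_alpha = torch.randn(1, requires_grad=True)

old_loss = (log_alpha * (entropy.detach() - target_entropy)).mean()
new_loss = -(log_alpha * (-entropy.detach() + target_entropy)).mean()
assert torch.allclose(old_loss, new_loss)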