parent 623bf24f0c
commit eec0826fd3
@@ -6,8 +6,7 @@ from numba import njit
 from abc import ABC, abstractmethod
 from typing import Any, List, Union, Mapping, Optional, Callable
 
-from tianshou.data import Batch, ReplayBuffer, PrioritizedReplayBuffer, \
-    to_torch_as, to_numpy
+from tianshou.data import Batch, ReplayBuffer, to_torch_as, to_numpy
 
 
 class BasePolicy(ABC, nn.Module):
@@ -138,9 +137,7 @@ class BasePolicy(ABC, nn.Module):
         Typical usage is to update the sampling weight in prioritized
         experience replay. Used in :meth:`update`.
         """
-        if isinstance(buffer, PrioritizedReplayBuffer) and hasattr(
-            batch, "weight"
-        ):
+        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
             buffer.update_weight(indice, batch.weight)
 
     def update(
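Note on the post_process_fn hunk above: the isinstance check against PrioritizedReplayBuffer is replaced by duck typing, so any buffer exposing update_weight(indice, weight) gets its priorities refreshed, and the PrioritizedReplayBuffer import in the first hunk is no longer needed. A minimal, self-contained sketch of that pattern (PrioBuffer, PlainBuffer and FakeBatch are illustrative stand-ins, not tianshou classes):

    import numpy as np

    class PlainBuffer:
        pass  # no update_weight -> post-processing is a no-op

    class PrioBuffer:
        def __init__(self, size):
            self.prio = np.ones(size)
        def update_weight(self, indice, weight):
            # e.g. store new TD-error magnitudes as priorities
            self.prio[indice] = np.abs(weight)

    class FakeBatch:
        def __init__(self, weight=None):
            if weight is not None:
                self.weight = weight

    def post_process(buffer, batch, indice):
        # same duck-typed check as in the diff above
        if hasattr(buffer, "update_weight") and hasattr(batch, "weight"):
            buffer.update_weight(indice, batch.weight)

    buf = PrioBuffer(8)
    post_process(buf, FakeBatch(weight=np.array([0.5, 2.0])), np.array([1, 3]))
    post_process(PlainBuffer(), FakeBatch(), np.array([0]))  # silently skipped
    print(buf.prio)  # [1.  0.5 1.  2.  1.  1.  1.  1. ]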
@@ -253,8 +250,7 @@ class BasePolicy(ABC, nn.Module):
             gamma, n_step, len(buffer), mean, std)
 
         batch.returns = to_torch_as(target_q, target_q_torch)
-        # prio buffer update
-        if isinstance(buffer, PrioritizedReplayBuffer):
+        if hasattr(batch, "weight"):  # prio buffer update
             batch.weight = to_torch_as(batch.weight, target_q_torch)
         return batch
 
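For context on the batch.returns assignment above: compute_nstep_return fills in an n-step target of the usual form G_t = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * Q_target(s_{t+n}), truncated at episode ends. The sketch below is a simplified NumPy version of that recursion (single flat trajectory, no buffer wrap-around, no reward normalization via mean/std), not the numba-compiled helper the diff calls:

    import numpy as np

    def nstep_return(rew, done, target_q, gamma=0.99, n_step=3):
        """Toy n-step backup: accumulate up to n_step discounted rewards,
        then bootstrap from target_q[t + n_step] unless a terminal or the
        end of the data is reached first."""
        T = len(rew)
        returns = np.zeros(T)
        for t in range(T):
            g, discount = 0.0, 1.0
            for k in range(n_step):
                i = t + k
                if i >= T:
                    break
                g += discount * rew[i]
                discount *= gamma
                if done[i]:
                    discount = 0.0  # episode ended: no bootstrap
                    break
            boot = t + n_step
            if discount > 0 and boot < T:
                g += discount * target_q[boot]
            returns[t] = g
        return returns

    rew = np.array([1.0, 1.0, 1.0, 0.0])
    done = np.array([False, False, False, True])
    target_q = np.array([5.0, 5.0, 5.0, 5.0])
    print(nstep_return(rew, done, target_q))  # approx. [7.8216 1.99 1. 0.]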
|
@@ -127,8 +127,8 @@ class DiscreteSACPolicy(SACPolicy):
         self.actor_optim.step()
 
         if self._is_auto_alpha:
-            log_prob = entropy.detach() - self._target_entropy
-            alpha_loss = (self._log_alpha * log_prob).mean()
+            log_prob = -entropy.detach() + self._target_entropy
+            alpha_loss = -(self._log_alpha * log_prob).mean()
             self._alpha_optim.zero_grad()
             alpha_loss.backward()
             self._alpha_optim.step()
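On the DiscreteSACPolicy hunk above: the old and new expressions are algebraically the same objective, (log_alpha * (entropy - target_entropy)).mean(); the rewritten form appears to mirror how the continuous SACPolicy phrases its automatic temperature loss, with log_prob standing in for -entropy + target_entropy and the loss written as -(log_alpha * log_prob).mean(). A short PyTorch check (sizes and values are illustrative):

    import torch

    torch.manual_seed(0)
    entropy = torch.rand(32)                           # per-sample policy entropy
    target_entropy = 0.6                               # illustrative target value
    log_alpha = torch.tensor([-1.0], requires_grad=True)

    # Old form from the diff.
    old_loss = (log_alpha * (entropy.detach() - target_entropy)).mean()
    # New form from the diff.
    log_prob = -entropy.detach() + target_entropy
    new_loss = -(log_alpha * log_prob).mean()

    print(torch.allclose(old_loss, new_loss))          # True: same objective
    # d(new_loss)/d(log_alpha) = mean(entropy - target_entropy), so gradient
    # descent raises alpha whenever policy entropy drops below the target.
    new_loss.backward()
    print(log_alpha.grad)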
|