parent
49f43e9f1f
commit
506cc97ba5
@ -74,7 +74,8 @@ def test_pdqn(args=get_args()):
|
||||
# collector
|
||||
if args.prioritized_replay > 0:
|
||||
buf = PrioritizedReplayBuffer(
|
||||
args.buffer_size, alpha=args.alpha, beta=args.alpha)
|
||||
args.buffer_size, alpha=args.alpha,
|
||||
beta=args.alpha, repeat_sample=True)
|
||||
else:
|
||||
buf = ReplayBuffer(args.buffer_size)
|
||||
train_collector = Collector(
|
||||
|
@ -340,6 +340,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
:param float alpha: the prioritization exponent.
|
||||
:param float beta: the importance sample soft coefficient.
|
||||
:param str mode: defaults to ``weight``.
|
||||
:param bool replace: whether to sample with replacement
|
||||
|
||||
.. seealso::
|
||||
|
||||
@ -348,7 +349,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
"""
|
||||
|
||||
def __init__(self, size: int, alpha: float, beta: float,
|
||||
mode: str = 'weight', **kwargs) -> None:
|
||||
mode: str = 'weight',
|
||||
replace: bool = False, **kwargs) -> None:
|
||||
if mode != 'weight':
|
||||
raise NotImplementedError
|
||||
super().__init__(size, **kwargs)
|
||||
@ -358,6 +360,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
self.weight = np.zeros(size, dtype=np.float64)
|
||||
self._amortization_freq = 50
|
||||
self._amortization_counter = 0
|
||||
self._replace = replace
|
||||
|
||||
def add(self,
|
||||
obs: Union[dict, np.ndarray],
|
||||
@ -377,6 +380,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||
self._check_weight_sum()
|
||||
|
||||
@property
|
||||
def replace(self):
|
||||
return self._replace
|
||||
|
||||
@replace.setter
|
||||
def replace(self, v: bool):
|
||||
self._replace = v
|
||||
|
||||
def sample(self, batch_size: int,
|
||||
importance_sample: bool = True
|
||||
) -> Tuple[Batch, np.ndarray]:
|
||||
@ -391,7 +402,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
indice = np.random.choice(
|
||||
self._size, batch_size,
|
||||
p=(self.weight / self.weight.sum())[:self._size],
|
||||
replace=False)
|
||||
replace=self._replace)
|
||||
# self._weight_sum is not work for the accuracy issue
|
||||
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
||||
elif batch_size == 0:
|
||||
@ -402,7 +413,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
else:
|
||||
# if batch_size larger than len(self),
|
||||
# it will lead to a bug in update weight
|
||||
raise ValueError("batch_size should be less than len(self)")
|
||||
raise ValueError(
|
||||
"batch_size should be less than len(self), \
|
||||
or set replace=False")
|
||||
batch = self[indice]
|
||||
if importance_sample:
|
||||
impt_weight = Batch(
|
||||
@ -424,6 +437,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
||||
:param np.ndarray indice: indice you want to update weight
|
||||
:param np.ndarray new_weight: new priority weight you wangt to update
|
||||
"""
|
||||
if self._replace:
|
||||
if isinstance(indice, slice):
|
||||
# convert slice to ndarray
|
||||
indice = np.arange(indice.stop)[indice]
|
||||
# remove the same values in indice
|
||||
indice, unique_indice = np.unique(
|
||||
indice, return_index=True)
|
||||
new_weight = new_weight[unique_indice]
|
||||
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
|
||||
- self.weight[indice].sum()
|
||||
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||
|
Loading…
x
Reference in New Issue
Block a user