parent
49f43e9f1f
commit
506cc97ba5
@ -74,7 +74,8 @@ def test_pdqn(args=get_args()):
|
|||||||
# collector
|
# collector
|
||||||
if args.prioritized_replay > 0:
|
if args.prioritized_replay > 0:
|
||||||
buf = PrioritizedReplayBuffer(
|
buf = PrioritizedReplayBuffer(
|
||||||
args.buffer_size, alpha=args.alpha, beta=args.alpha)
|
args.buffer_size, alpha=args.alpha,
|
||||||
|
beta=args.alpha, repeat_sample=True)
|
||||||
else:
|
else:
|
||||||
buf = ReplayBuffer(args.buffer_size)
|
buf = ReplayBuffer(args.buffer_size)
|
||||||
train_collector = Collector(
|
train_collector = Collector(
|
||||||
|
@ -340,6 +340,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
:param float alpha: the prioritization exponent.
|
:param float alpha: the prioritization exponent.
|
||||||
:param float beta: the importance sample soft coefficient.
|
:param float beta: the importance sample soft coefficient.
|
||||||
:param str mode: defaults to ``weight``.
|
:param str mode: defaults to ``weight``.
|
||||||
|
:param bool replace: whether to sample with replacement
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
@ -348,7 +349,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, size: int, alpha: float, beta: float,
|
def __init__(self, size: int, alpha: float, beta: float,
|
||||||
mode: str = 'weight', **kwargs) -> None:
|
mode: str = 'weight',
|
||||||
|
replace: bool = False, **kwargs) -> None:
|
||||||
if mode != 'weight':
|
if mode != 'weight':
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
super().__init__(size, **kwargs)
|
super().__init__(size, **kwargs)
|
||||||
@ -358,6 +360,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
self.weight = np.zeros(size, dtype=np.float64)
|
self.weight = np.zeros(size, dtype=np.float64)
|
||||||
self._amortization_freq = 50
|
self._amortization_freq = 50
|
||||||
self._amortization_counter = 0
|
self._amortization_counter = 0
|
||||||
|
self._replace = replace
|
||||||
|
|
||||||
def add(self,
|
def add(self,
|
||||||
obs: Union[dict, np.ndarray],
|
obs: Union[dict, np.ndarray],
|
||||||
@ -377,6 +380,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
super().add(obs, act, rew, done, obs_next, info, policy)
|
super().add(obs, act, rew, done, obs_next, info, policy)
|
||||||
self._check_weight_sum()
|
self._check_weight_sum()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def replace(self):
|
||||||
|
return self._replace
|
||||||
|
|
||||||
|
@replace.setter
|
||||||
|
def replace(self, v: bool):
|
||||||
|
self._replace = v
|
||||||
|
|
||||||
def sample(self, batch_size: int,
|
def sample(self, batch_size: int,
|
||||||
importance_sample: bool = True
|
importance_sample: bool = True
|
||||||
) -> Tuple[Batch, np.ndarray]:
|
) -> Tuple[Batch, np.ndarray]:
|
||||||
@ -391,7 +402,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
indice = np.random.choice(
|
indice = np.random.choice(
|
||||||
self._size, batch_size,
|
self._size, batch_size,
|
||||||
p=(self.weight / self.weight.sum())[:self._size],
|
p=(self.weight / self.weight.sum())[:self._size],
|
||||||
replace=False)
|
replace=self._replace)
|
||||||
# self._weight_sum does not work here because of an accuracy issue
|
# self._weight_sum does not work here because of an accuracy issue
|
||||||
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
# p=(self.weight/self._weight_sum)[:self._size], replace=False)
|
||||||
elif batch_size == 0:
|
elif batch_size == 0:
|
||||||
@ -402,7 +413,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
else:
|
else:
|
||||||
# if batch_size is larger than len(self),
|
# if batch_size is larger than len(self),
|
||||||
# it will lead to a bug in update weight
|
# it will lead to a bug in update weight
|
||||||
raise ValueError("batch_size should be less than len(self)")
|
raise ValueError(
|
||||||
|
"batch_size should be less than len(self), \
|
||||||
|
or set replace=False")
|
||||||
batch = self[indice]
|
batch = self[indice]
|
||||||
if importance_sample:
|
if importance_sample:
|
||||||
impt_weight = Batch(
|
impt_weight = Batch(
|
||||||
@ -424,6 +437,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
|
|||||||
:param np.ndarray indice: indices whose weights you want to update
|
:param np.ndarray indice: indices whose weights you want to update
|
||||||
:param np.ndarray new_weight: new priority weight you want to update
|
:param np.ndarray new_weight: new priority weight you want to update
|
||||||
"""
|
"""
|
||||||
|
if self._replace:
|
||||||
|
if isinstance(indice, slice):
|
||||||
|
# convert slice to ndarray
|
||||||
|
indice = np.arange(indice.stop)[indice]
|
||||||
|
# remove the same values in indice
|
||||||
|
indice, unique_indice = np.unique(
|
||||||
|
indice, return_index=True)
|
||||||
|
new_weight = new_weight[unique_indice]
|
||||||
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
|
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
|
||||||
- self.weight[indice].sum()
|
- self.weight[indice].sum()
|
||||||
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user