This commit is contained in:
rocknamx 2020-06-25 07:02:59 +08:00 committed by GitHub
parent 49f43e9f1f
commit 506cc97ba5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 4 deletions

View File

@ -74,7 +74,8 @@ def test_pdqn(args=get_args()):
# collector # collector
if args.prioritized_replay > 0: if args.prioritized_replay > 0:
buf = PrioritizedReplayBuffer( buf = PrioritizedReplayBuffer(
args.buffer_size, alpha=args.alpha, beta=args.alpha) args.buffer_size, alpha=args.alpha,
beta=args.alpha, repeat_sample=True)
else: else:
buf = ReplayBuffer(args.buffer_size) buf = ReplayBuffer(args.buffer_size)
train_collector = Collector( train_collector = Collector(

View File

@ -340,6 +340,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
:param float alpha: the prioritization exponent. :param float alpha: the prioritization exponent.
:param float beta: the importance sample soft coefficient. :param float beta: the importance sample soft coefficient.
:param str mode: defaults to ``weight``. :param str mode: defaults to ``weight``.
:param bool replace: whether to sample with replacement
.. seealso:: .. seealso::
@ -348,7 +349,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
""" """
def __init__(self, size: int, alpha: float, beta: float, def __init__(self, size: int, alpha: float, beta: float,
mode: str = 'weight', **kwargs) -> None: mode: str = 'weight',
replace: bool = False, **kwargs) -> None:
if mode != 'weight': if mode != 'weight':
raise NotImplementedError raise NotImplementedError
super().__init__(size, **kwargs) super().__init__(size, **kwargs)
@ -358,6 +360,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
self.weight = np.zeros(size, dtype=np.float64) self.weight = np.zeros(size, dtype=np.float64)
self._amortization_freq = 50 self._amortization_freq = 50
self._amortization_counter = 0 self._amortization_counter = 0
self._replace = replace
def add(self, def add(self,
obs: Union[dict, np.ndarray], obs: Union[dict, np.ndarray],
@ -377,6 +380,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
super().add(obs, act, rew, done, obs_next, info, policy) super().add(obs, act, rew, done, obs_next, info, policy)
self._check_weight_sum() self._check_weight_sum()
@property
def replace(self) -> bool:
    """Whether sampling draws indices with replacement."""
    return self._replace

@replace.setter
def replace(self, v: bool) -> None:
    """Set the sampling-with-replacement flag to ``v``."""
    self._replace = v
def sample(self, batch_size: int, def sample(self, batch_size: int,
importance_sample: bool = True importance_sample: bool = True
) -> Tuple[Batch, np.ndarray]: ) -> Tuple[Batch, np.ndarray]:
@ -391,7 +402,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
indice = np.random.choice( indice = np.random.choice(
self._size, batch_size, self._size, batch_size,
p=(self.weight / self.weight.sum())[:self._size], p=(self.weight / self.weight.sum())[:self._size],
replace=False) replace=self._replace)
# self._weight_sum does not work here because of an accuracy issue # self._weight_sum does not work here because of an accuracy issue
# p=(self.weight/self._weight_sum)[:self._size], replace=False) # p=(self.weight/self._weight_sum)[:self._size], replace=False)
elif batch_size == 0: elif batch_size == 0:
@ -402,7 +413,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
else: else:
# if batch_size is larger than len(self), # if batch_size is larger than len(self),
# it will lead to a bug in update weight # it will lead to a bug in update weight
raise ValueError("batch_size should be less than len(self)") raise ValueError(
"batch_size should be less than len(self), \
or set replace=False")
batch = self[indice] batch = self[indice]
if importance_sample: if importance_sample:
impt_weight = Batch( impt_weight = Batch(
@ -424,6 +437,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
:param np.ndarray indice: indices whose weight you want to update :param np.ndarray indice: indices whose weight you want to update
:param np.ndarray new_weight: new priority weight you want to update :param np.ndarray new_weight: new priority weight you want to update
""" """
if self._replace:
if isinstance(indice, slice):
# convert slice to ndarray
indice = np.arange(indice.stop)[indice]
# remove the same values in indice
indice, unique_indice = np.unique(
indice, return_index=True)
new_weight = new_weight[unique_indice]
self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \ self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
- self.weight[indice].sum() - self.weight[indice].sum()
self.weight[indice] = np.power(np.abs(new_weight), self._alpha) self.weight[indice] = np.power(np.abs(new_weight), self._alpha)