parent
							
								
									49f43e9f1f
								
							
						
					
					
						commit
						506cc97ba5
					
				@ -74,7 +74,8 @@ def test_pdqn(args=get_args()):
 | 
			
		||||
    # collector
 | 
			
		||||
    if args.prioritized_replay > 0:
 | 
			
		||||
        buf = PrioritizedReplayBuffer(
 | 
			
		||||
            args.buffer_size, alpha=args.alpha, beta=args.alpha)
 | 
			
		||||
            args.buffer_size, alpha=args.alpha,
 | 
			
		||||
            beta=args.alpha, repeat_sample=True)
 | 
			
		||||
    else:
 | 
			
		||||
        buf = ReplayBuffer(args.buffer_size)
 | 
			
		||||
    train_collector = Collector(
 | 
			
		||||
 | 
			
		||||
@ -340,6 +340,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
    :param float alpha: the prioritization exponent.
 | 
			
		||||
    :param float beta: the importance sample soft coefficient.
 | 
			
		||||
    :param str mode: defaults to ``weight``.
 | 
			
		||||
    :param bool replace: whether to sample with replacement
 | 
			
		||||
 | 
			
		||||
    .. seealso::
 | 
			
		||||
 | 
			
		||||
@ -348,7 +349,8 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, size: int, alpha: float, beta: float,
 | 
			
		||||
                 mode: str = 'weight', **kwargs) -> None:
 | 
			
		||||
                 mode: str = 'weight',
 | 
			
		||||
                 replace: bool = False, **kwargs) -> None:
 | 
			
		||||
        if mode != 'weight':
 | 
			
		||||
            raise NotImplementedError
 | 
			
		||||
        super().__init__(size, **kwargs)
 | 
			
		||||
@ -358,6 +360,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
        self.weight = np.zeros(size, dtype=np.float64)
 | 
			
		||||
        self._amortization_freq = 50
 | 
			
		||||
        self._amortization_counter = 0
 | 
			
		||||
        self._replace = replace
 | 
			
		||||
 | 
			
		||||
    def add(self,
 | 
			
		||||
            obs: Union[dict, np.ndarray],
 | 
			
		||||
@ -377,6 +380,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
        super().add(obs, act, rew, done, obs_next, info, policy)
 | 
			
		||||
        self._check_weight_sum()
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def replace(self):
 | 
			
		||||
        return self._replace
 | 
			
		||||
 | 
			
		||||
    @replace.setter
 | 
			
		||||
    def replace(self, v: bool):
 | 
			
		||||
        self._replace = v
 | 
			
		||||
 | 
			
		||||
    def sample(self, batch_size: int,
 | 
			
		||||
               importance_sample: bool = True
 | 
			
		||||
               ) -> Tuple[Batch, np.ndarray]:
 | 
			
		||||
@ -391,7 +402,7 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
            indice = np.random.choice(
 | 
			
		||||
                self._size, batch_size,
 | 
			
		||||
                p=(self.weight / self.weight.sum())[:self._size],
 | 
			
		||||
                replace=False)
 | 
			
		||||
                replace=self._replace)
 | 
			
		||||
            # self._weight_sum is not work for the accuracy issue
 | 
			
		||||
            # p=(self.weight/self._weight_sum)[:self._size], replace=False)
 | 
			
		||||
        elif batch_size == 0:
 | 
			
		||||
@ -402,7 +413,9 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
        else:
 | 
			
		||||
            # if batch_size larger than len(self),
 | 
			
		||||
            # it will lead to a bug in update weight
 | 
			
		||||
            raise ValueError("batch_size should be less than len(self)")
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "batch_size should be less than len(self), \
 | 
			
		||||
                    or set replace=False")
 | 
			
		||||
        batch = self[indice]
 | 
			
		||||
        if importance_sample:
 | 
			
		||||
            impt_weight = Batch(
 | 
			
		||||
@ -424,6 +437,14 @@ class PrioritizedReplayBuffer(ReplayBuffer):
 | 
			
		||||
        :param np.ndarray indice: indice you want to update weight
 | 
			
		||||
        :param np.ndarray new_weight: new priority weight you wangt to update
 | 
			
		||||
        """
 | 
			
		||||
        if self._replace:
 | 
			
		||||
            if isinstance(indice, slice):
 | 
			
		||||
                # convert slice to ndarray
 | 
			
		||||
                indice = np.arange(indice.stop)[indice]
 | 
			
		||||
            # remove the same values in indice
 | 
			
		||||
            indice, unique_indice = np.unique(
 | 
			
		||||
                indice, return_index=True)
 | 
			
		||||
            new_weight = new_weight[unique_indice]
 | 
			
		||||
        self._weight_sum += np.power(np.abs(new_weight), self._alpha).sum() \
 | 
			
		||||
            - self.weight[indice].sum()
 | 
			
		||||
        self.weight[indice] = np.power(np.abs(new_weight), self._alpha)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user