import random

import numpy

import sum_tree
from buffer import ReplayBuffer


class PropotionalExperience(ReplayBuffer):
    """Prioritized experience replay buffer (proportional variant).

    Each sample is stored in a sum tree with value ``priority**alpha``, so a
    sample is drawn with probability proportional to that value. The class
    can store samples, draw a batch with probability in proportion to each
    sample's priority, update priorities, and change the exponent ``alpha``.

    See https://arxiv.org/pdf/1511.05952.pdf .

    The class name keeps its historical spelling ("Propotional") because
    callers refer to it by that name.
    """

    def __init__(self, conf):
        """Initialize the prioritized experience replay buffer.

        Parameters
        ----------
        conf : dict
            Configuration with the keys:

            ``'size'`` : int
                Number of samples that can be stored.
            ``'batch_size'`` : int
                Number of samples returned by ``sample``.
            ``'alpha'`` : float
                Exponent that determines how much prioritization is used:
                P_i ~ priority_i**alpha / sum_k(priority_k**alpha).
        """
        # NOTE(review): ReplayBuffer.__init__ is not called here; its
        # signature is not visible from this file -- confirm this is intended.
        memory_size = conf['size']
        batch_size = conf['batch_size']
        alpha = conf['alpha']

        self.tree = sum_tree.SumTree(memory_size)
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.alpha = alpha

    def add(self, data, priority):
        """Add a new sample.

        Parameters
        ----------
        data : object
            New sample.
        priority : float
            Sample's raw priority; it is stored as ``priority**alpha``.
        """
        self.tree.add(data, priority ** self.alpha)

    def sample(self, conf):
        """Draw a batch of samples with probability proportional to priority.

        Parameters
        ----------
        conf : dict
            Must contain ``'beta'`` (float), the importance-sampling exponent.

        Returns
        -------
        out : list or None
            Sampled data, or ``None`` when fewer than ``batch_size`` samples
            are stored.
        weights : list of float or None
            Importance-sampling weights, normalized so the largest is 1
            (left as all zeros if every drawn sample had zero priority).
        indices : list of int or None
            Sample positions in the sum tree; pass them to
            ``update_priority``.
        """
        beta = conf['beta']
        if self.tree.filled_size() < self.batch_size:
            return None, None, None

        out = []
        indices = []
        weights = []
        # Raw leaf values read from the tree. They are already raised to
        # alpha, because that is what add/update_priority store.
        # NOTE(review): assumes tree.find returns the stored leaf value as
        # its second element -- confirm against sum_tree.
        stored = []
        for _ in range(self.batch_size):
            r = random.random()
            data, priority, index = self.tree.find(r)
            stored.append(priority)
            weights.append(
                (1. / self.memory_size / priority) ** beta
                if priority > 1e-16 else 0
            )
            indices.append(index)
            out.append(data)
            # Temporarily zero the leaf so it is not drawn again in this
            # batch. Write 0 directly: 0**alpha is 1 for alpha == 0.
            self.tree.val_update(index, 0)

        # Restore the original leaf values. Write them back as-is, not via
        # update_priority, which would raise them to alpha a second time.
        # Go in reverse: if an index was drawn more than once (e.g. when the
        # remaining total priority is zero), its later entries recorded 0.
        # Reversing makes the first, true value the last one written.
        for index, value in zip(reversed(indices), reversed(stored)):
            self.tree.val_update(index, value)

        max_weight = max(weights)
        if max_weight > 0:
            # Normalize for stability.
            weights = [w / max_weight for w in weights]
        return out, weights, indices

    def update_priority(self, indices, priorities):
        """Update the priorities of samples already in the buffer.

        Parameters
        ----------
        indices : iterable of int
            Sample positions in the sum tree, as returned by ``sample``.
        priorities : iterable of float
            New raw priorities; each is stored as ``priority**alpha``.
        """
        for i, p in zip(indices, priorities):
            self.tree.val_update(i, p ** self.alpha)

    def reset_alpha(self, alpha):
        """Change the exponent alpha and re-weight every stored priority.

        Stored values are ``priority**old_alpha``. Each one is converted back
        to its raw priority and re-stored as ``priority**alpha``.

        Parameters
        ----------
        alpha : float
            New prioritization exponent.
        """
        self.alpha, old_alpha = alpha, self.alpha
        size = self.tree.filled_size()
        if old_alpha == 0:
            # p**0 == 1 for every p, so the raw priorities cannot be
            # recovered; keep them at 1, which is what is stored.
            priorities = [1.0] * size
        else:
            # Invert v = p**old_alpha  =>  p = v**(1/old_alpha).
            inverse = 1.0 / old_alpha
            priorities = [self.tree.get_val(i) ** inverse for i in range(size)]
        self.update_priority(range(size), priorities)