# Web-viewer header captured together with this file (not part of the program):
#   120 lines | 3.3 KiB | Python | Raw / Normal View / History
#   2017-12-10 14:53:57 +08:00
import numpy
import random
import sum_tree
from buffer import ReplayBuffer
class PropotionalExperience(ReplayBuffer):
    """Proportional prioritized experience replay buffer.

    Every sample is kept in a sum tree under the value ``priority ** alpha``.
    The class can store samples, pick samples with probability in proportion
    to their stored value, update each sample's priority and reset ``alpha``.

    See https://arxiv.org/pdf/1511.05952.pdf .
    """

    def __init__(self, conf):
        """Prioritized experience replay buffer initialization.

        Parameters
        ----------
        conf : dict
            Configuration with the following keys:

            ``'size'`` : int
                sample size to be stored (capacity handed to the sum tree)
            ``'batch_size'`` : int
                batch size to be selected by the `sample` method
            ``'alpha'`` : float
                exponent determining how much prioritization is used,
                Prob_i ~ priority_i**alpha / sum(priority**alpha)
        """
        memory_size = conf['size']
        batch_size = conf['batch_size']
        alpha = conf['alpha']
        self.tree = sum_tree.SumTree(memory_size)
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.alpha = alpha

    def add(self, data, priority):
        """Add a new sample.

        Parameters
        ----------
        data : object
            new sample
        priority : float
            sample's raw priority; it is stored as ``priority ** alpha``
        """
        self.tree.add(data, priority ** self.alpha)

    def sample(self, conf):
        """Return a batch of samples drawn at random from the sum tree.

        Parameters
        ----------
        conf : dict
            Must contain ``'beta'`` (float), the exponent applied to the
            importance-sampling weights.

        Returns
        -------
        out :
            list of samples
        weights :
            list of weights, divided by the largest weight of the batch
        indices :
            list of sample indices
            The indices indicate sample positions in a sum tree.

        ``(None, None, None)`` is returned while the tree holds fewer than
        ``batch_size`` samples.
        """
        beta = conf['beta']
        if self.tree.filled_size() < self.batch_size:
            return None, None, None

        out = []
        indices = []
        weights = []
        zeroed = []  # (index, stored value) of every entry zeroed below
        try:
            for _ in range(self.batch_size):
                # NOTE(review): assumes `find` returns the value held in the
                # tree, i.e. priority**alpha, the same quantity `get_val`
                # yields in `reset_alpha` -- confirm against sum_tree.
                data, stored, index = self.tree.find(random.random())
                zeroed.append((index, stored))
                # Zero the entry so it is not drawn twice in this batch.
                # Written straight to the tree: going through
                # `update_priority` would store 0**alpha, which is 1 when
                # alpha == 0.
                self.tree.val_update(index, 0)
                if stored > 1e-16:
                    weight = (1. / self.memory_size / stored) ** beta
                else:
                    weight = 0
                weights.append(weight)
                indices.append(index)
                out.append(data)
        finally:
            # Put the stored values back unchanged, even if drawing failed.
            # They already include the alpha exponent, so they must not pass
            # through `update_priority` again.  Reverse order makes the first
            # value recorded for an index the one that survives.
            for index, stored in reversed(zeroed):
                self.tree.val_update(index, stored)

        # Normalize for stability; skipped when every weight is zero so the
        # division cannot fail.
        max_weight = max(weights) if weights else 0
        if max_weight > 0:
            weights = [w / max_weight for w in weights]
        return out, weights, indices

    def update_priority(self, indices, priorities):
        """Update the priority of stored samples.

        Parameters
        ----------
        indices :
            list of sample indices
        priorities :
            list of raw priorities, one per index; each one is stored as
            ``priority ** alpha``
        """
        for i, p in zip(indices, priorities):
            self.tree.val_update(i, p ** self.alpha)

    def reset_alpha(self, alpha):
        """Reset the exponent alpha and re-store every priority with it.

        Parameters
        ----------
        alpha : float
            new exponent
        """
        old_alpha = self.alpha
        self.alpha = alpha
        # Stored values are priority**old_alpha, so the raw priority is
        # value**(1/old_alpha).  With old_alpha == 0 the raw priority cannot
        # be recovered; the stored value is reused as it is.
        inverse = 1. / old_alpha if old_alpha else 1.
        filled = self.tree.filled_size()
        priorities = [self.tree.get_val(i) ** inverse for i in range(filled)]
        self.update_priority(range(filled), priorities)