replay buffer initial commit

songshshshsh 2017-12-10 14:53:57 +08:00
parent a8a12f1083
commit f1a7fd9ee1
10 changed files with 819 additions and 0 deletions

View File

@@ -0,0 +1,4 @@
from os.path import dirname, basename, isfile
import glob
modules = glob.glob(dirname(__file__)+"/*.py")
__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
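For reference, this makes `from <package> import *` re-export every module in the directory. A minimal sketch of the effect, assuming the package directory holds the modules added in this commit (the package name itself is not shown in this diff):

# hypothetical package name; __all__ would then contain entries such as
# ['binary_heap', 'buffer', 'naive', 'proportional', 'rank_based', 'sum_tree', 'utility', 'utils']
from replay_buffer import *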

View File

@@ -0,0 +1,221 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:
import sys
import math
import utility
class BinaryHeap(object):
def __init__(self, priority_size=100, priority_init=None, replace=True):
self.e2p = {}
self.p2e = {}
self.replace = replace
if priority_init is None:
self.priority_queue = {}
self.size = 0
self.max_size = priority_size
        else:
            # not yet tested
            self.priority_queue = priority_init
            self.size = len(self.priority_queue)
            self.max_size = self.size
            # heap positions are 1-based: map position -> e_id and back
            self.p2e = dict((k, self.priority_queue[k][1]) for k in self.priority_queue)
            self.e2p = utility.exchange_key_value(self.p2e)
            for i in range(int(self.size / 2), 0, -1):
                self.down_heap(i)
def __repr__(self):
"""
:return: string of the priority queue, with level info
"""
if self.size == 0:
return 'No element in heap!'
to_string = ''
level = -1
max_level = int(math.floor(math.log(self.size, 2)))
for i in range(1, self.size + 1):
now_level = int(math.floor(math.log(i, 2)))
if level != now_level:
to_string = to_string + ('\n' if level != -1 else '') \
+ ' ' * (max_level - now_level)
level = now_level
to_string = to_string + '%.2f ' % self.priority_queue[i][1] + ' ' * (max_level - now_level)
return to_string
def check_full(self):
return self.size > self.max_size
def _insert(self, priority, e_id):
"""
insert new experience id with priority
(maybe don't need get_max_priority and implement it in this function)
:param priority: priority value
:param e_id: experience id
:return: bool
"""
        self.size += 1
        if self.check_full() and not self.replace:
            sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority))
            self.size = self.max_size
            return False
else:
self.size = min(self.size, self.max_size)
self.priority_queue[self.size] = (priority, e_id)
self.p2e[self.size] = e_id
self.e2p[e_id] = self.size
self.up_heap(self.size)
return True
def update(self, priority, e_id):
"""
        update the priority value for the given experience id
:param priority: new priority value
:param e_id: experience id
:return: bool
"""
if e_id in self.e2p:
p_id = self.e2p[e_id]
self.priority_queue[p_id] = (priority, e_id)
self.p2e[p_id] = e_id
self.down_heap(p_id)
self.up_heap(p_id)
return True
else:
# this e id is new, do insert
return self._insert(priority, e_id)
def get_max_priority(self):
"""
get max priority, if no experience, return 1
:return: max priority if size > 0 else 1
"""
if self.size > 0:
return self.priority_queue[1][0]
else:
return 1
def pop(self):
"""
pop out the max priority value with its experience id
:return: priority value & experience id
"""
if self.size == 0:
sys.stderr.write('Error: no value in heap, pop failed\n')
return False, False
        pop_priority, pop_e_id = self.priority_queue[1]
        self.e2p[pop_e_id] = -1
        # move the last element to the root, then sift it down
        last_priority, last_e_id = self.priority_queue[self.size]
        del self.priority_queue[self.size]
        self.size -= 1
        if self.size > 0:
            self.priority_queue[1] = (last_priority, last_e_id)
            self.e2p[last_e_id] = 1
            self.p2e[1] = last_e_id
            self.down_heap(1)
        return pop_priority, pop_e_id
def up_heap(self, i):
"""
upward balance
:param i: tree node i
:return: None
"""
if i > 1:
parent = math.floor(i / 2)
if self.priority_queue[parent][0] < self.priority_queue[i][0]:
tmp = self.priority_queue[i]
self.priority_queue[i] = self.priority_queue[parent]
self.priority_queue[parent] = tmp
# change e2p & p2e
self.e2p[self.priority_queue[i][1]] = i
self.e2p[self.priority_queue[parent][1]] = parent
self.p2e[i] = self.priority_queue[i][1]
self.p2e[parent] = self.priority_queue[parent][1]
# up heap parent
self.up_heap(parent)
def down_heap(self, i):
"""
downward balance
:param i: tree node i
:return: None
"""
if i < self.size:
greatest = i
left, right = i * 2, i * 2 + 1
            if left <= self.size and self.priority_queue[left][0] > self.priority_queue[greatest][0]:
                greatest = left
            if right <= self.size and self.priority_queue[right][0] > self.priority_queue[greatest][0]:
                greatest = right
if greatest != i:
tmp = self.priority_queue[i]
self.priority_queue[i] = self.priority_queue[greatest]
self.priority_queue[greatest] = tmp
# change e2p & p2e
self.e2p[self.priority_queue[i][1]] = i
self.e2p[self.priority_queue[greatest][1]] = greatest
self.p2e[i] = self.priority_queue[i][1]
self.p2e[greatest] = self.priority_queue[greatest][1]
# down heap greatest
self.down_heap(greatest)
def get_priority(self):
"""
get all priority value
:return: list of priority
"""
        return [self.priority_queue[i][0] for i in range(1, self.size + 1)]
def get_e_id(self):
"""
get all experience id in priority queue
:return: list of experience ids order by their priority
"""
return list(map(lambda x: x[1], self.priority_queue.values()))[0:self.size]
def balance_tree(self):
"""
rebalance priority queue
:return: None
"""
sort_array = sorted(self.priority_queue.values(), key=lambda x: x[0], reverse=True)
# reconstruct priority_queue
self.priority_queue.clear()
self.p2e.clear()
self.e2p.clear()
cnt = 1
while cnt <= self.size:
priority, e_id = sort_array[cnt - 1]
self.priority_queue[cnt] = (priority, e_id)
self.p2e[cnt] = e_id
self.e2p[e_id] = cnt
cnt += 1
        # restore the heap property (indices size/2 down to 1)
        for i in range(int(math.floor(self.size / 2)), 0, -1):
            self.down_heap(i)
def priority_to_experience(self, priority_ids):
"""
retrieve experience ids by priority ids
:param priority_ids: list of priority id
:return: list of experience id
"""
return [self.p2e[i] for i in priority_ids]
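A minimal usage sketch of the heap (illustrative, not part of this commit): `update` falls through to `_insert` for unseen experience ids, and `pop` returns the max-priority pair.

heap = BinaryHeap(priority_size=8)
heap.update(0.5, 0)              # unseen e_id: inserted via _insert
heap.update(2.0, 1)
heap.update(1.0, 2)
print(heap.get_max_priority())   # 2.0, the root of the max heap
priority, e_id = heap.pop()      # (2.0, 1)
print(heap.get_e_id())           # remaining experience ids in heap order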

View File

@@ -0,0 +1,39 @@
class ReplayBuffer(object):
def __init__(self, conf):
'''
Initialize a replay buffer with parameters in conf.
'''
pass
def add(self, data, priority):
'''
        Add a sample with the given priority to the replay buffer.
'''
pass
def update_priority(self, indices, priorities):
'''
        Update the priorities of the samples at the given indices.
        For the proportional replay buffer, priorities are the new priority values.
        For the rank-based replay buffer, priorities are the deltas (e.g. absolute TD errors) used to update the priorities.
'''
pass
def reset_alpha(self, alpha):
'''
        Reset the exponent alpha and rescale existing priorities.
        This only works for the proportional replay buffer.
'''
pass
def sample(self, conf):
'''
Sample from replay buffer with parameters in conf.
'''
pass
def rebalance(self):
'''
        Only meaningful for the rank-based replay buffer: rebalance the binary heap backing the priority queue.
'''
pass
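This interface is what a learner codes against, whichever prioritization scheme backs it. A hedged sketch of a typical loop, using the `getReplayBuffer` helper from the last file in this commit; `env_step` and `compute_td_errors` are hypothetical placeholders, not part of this commit:

buf = getReplayBuffer('proportional', {'size': 10000, 'batch_size': 32, 'alpha': 0.7})
for step in range(1, 100001):
    buf.add(env_step(), priority=1.0)        # env_step() -> (s, a, r, s2, t), hypothetical
    batch, weights, indices = buf.sample({'beta': 0.5, 'global_step': step})
    if batch is None:
        continue                             # buffer not warm enough to sample yet
    td_errors = compute_td_errors(batch, weights)
    buf.update_priority(indices, [abs(d) for d in td_errors])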

View File

@@ -0,0 +1,29 @@
from buffer import ReplayBuffer
import numpy as np
from collections import deque
class NaiveExperience(ReplayBuffer):
def __init__(self, conf):
self.max_size = conf['size']
self.n_entries = 0
self.memory = deque(maxlen = self.max_size)
def add(self, data, priority = 0):
self.memory.append(data)
if self.n_entries < self.max_size:
self.n_entries += 1
def update_priority(self, indices, priorities = 0):
pass
def reset_alpha(self, alpha):
pass
def sample(self, conf):
batch_size = conf['batch_size']
batch_size = min(len(self.memory), batch_size)
        # note: np.random.choice samples with replacement by default
        idxs = np.random.choice(len(self.memory), batch_size)
return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
def rebalance(self):
pass
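A quick sketch of the uniform buffer in isolation (illustrative only). All weights come back as 1, and because `np.random.choice` draws with replacement, the returned indices may repeat:

buf = NaiveExperience({'size': 100})
for i in range(5):
    buf.add((i, 0, 1.0, i + 1, False))
batch, weights, idxs = buf.sample({'batch_size': 3})
print(weights)   # [1, 1, 1]: uniform replay needs no importance correction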

View File

@@ -0,0 +1,119 @@
import numpy
import random
import sum_tree
from buffer import ReplayBuffer
class PropotionalExperience(ReplayBuffer):
""" The class represents prioritized experience replay buffer.
The class has functions: store samples, pick samples with
probability in proportion to sample's priority, update
each sample's priority, reset alpha.
see https://arxiv.org/pdf/1511.05952.pdf .
"""
def __init__(self, conf):
""" Prioritized experience replay buffer initialization.
Parameters
----------
memory_size : int
sample size to be stored
batch_size : int
batch size to be selected by `select` method
alpha: float
exponent determine how much prioritization.
Prob_i \sim priority_i**alpha/sum(priority**alpha)
"""
memory_size = conf['size']
batch_size = conf['batch_size']
alpha = conf['alpha']
self.tree = sum_tree.SumTree(memory_size)
self.memory_size = memory_size
self.batch_size = batch_size
self.alpha = alpha
def add(self, data, priority):
""" Add new sample.
Parameters
----------
data : object
new sample
priority : float
sample's priority
"""
self.tree.add(data, priority**self.alpha)
def sample(self, conf):
""" The method return samples randomly.
Parameters
----------
beta : float
Returns
-------
out :
list of samples
weights:
list of weight
indices:
list of sample indices
The indices indicate sample positions in a sum tree.
"""
beta = conf['beta']
if self.tree.filled_size() < self.batch_size:
return None, None, None
out = []
indices = []
weights = []
priorities = []
for _ in range(self.batch_size):
r = random.random()
data, priority, index = self.tree.find(r)
priorities.append(priority)
weights.append((1./self.memory_size/priority)**beta if priority > 1e-16 else 0)
indices.append(index)
out.append(data)
            self.update_priority([index], [0])  # zero out so the same sample is not drawn again in this batch
        # restore the original priorities; the tree stores priority**alpha, so invert the exponent
        self.update_priority(indices, [p ** (1. / self.alpha) for p in priorities])
max_weights = max(weights)
weights[:] = [x / max_weights for x in weights] # Normalize for stability
return out, weights, indices
def update_priority(self, indices, priorities):
""" The methods update samples's priority.
Parameters
----------
indices :
list of sample indices
"""
for i, p in zip(indices, priorities):
self.tree.val_update(i, p**self.alpha)
def reset_alpha(self, alpha):
""" Reset a exponent alpha.
Parameters
----------
alpha : float
"""
self.alpha, old_alpha = alpha, self.alpha
priorities = [self.tree.get_val(i)**-old_alpha for i in range(self.tree.filled_size())]
self.update_priority(range(self.tree.filled_size()), priorities)
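A short sketch of the proportional buffer on its own (illustrative, not part of this commit). With priorities 1..10 and alpha = 0.7, sample i is drawn with probability roughly i**0.7 / sum(j**0.7):

buf = PropotionalExperience({'size': 100, 'batch_size': 4, 'alpha': 0.7})
for i in range(1, 11):
    buf.add((i, 0, 1.0, i + 1, False), priority=float(i))
batch, weights, indices = buf.sample({'beta': 0.4})
# after computing TD errors, feed their magnitudes back as the new priorities
buf.update_priority(indices, [0.1, 0.5, 0.2, 0.9])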

View File

@@ -0,0 +1,184 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:
import sys
import math
import random
import numpy as np
from binary_heap import BinaryHeap
from buffer import ReplayBuffer
class RankBasedExperience(ReplayBuffer):
def __init__(self, conf):
self.size = conf['size']
self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
self.total_steps = conf['steps'] if 'steps' in conf else 100000
# partition number N, split total size to N part
self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100
self.index = 0
self.record_size = 0
self.isFull = False
self._experience = {}
self.priority_queue = BinaryHeap(self.priority_size)
self.distributions = self.build_distributions()
self.beta_grad = (1 - self.beta_zero) / (self.total_steps - self.learn_start)
def build_distributions(self):
"""
preprocess pow of rank
(rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
:return: distributions, dict
"""
res = {}
n_partitions = self.partition_num
partition_num = 1
# each part size
partition_size = int(math.floor(self.size / n_partitions))
for n in range(partition_size, self.size + 1, partition_size):
if self.learn_start <= n <= self.priority_size:
distribution = {}
# P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
pdf = list(
map(lambda x: math.pow(x, -self.alpha), range(1, n + 1))
)
pdf_sum = math.fsum(pdf)
distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf))
                # split into k segments, then sample uniformly within each one
                # with k = batch_size, each segment has total probability 1 / batch_size
                # strata_ends keeps each segment's start and end positions
cdf = np.cumsum(distribution['pdf'])
strata_ends = {1: 0, self.batch_size + 1: n}
step = 1. / self.batch_size
index = 1
for s in range(2, self.batch_size + 1):
while cdf[index] < step:
index += 1
strata_ends[s] = index
step += 1. / self.batch_size
distribution['strata_ends'] = strata_ends
res[partition_num] = distribution
partition_num += 1
return res
def fix_index(self):
"""
get next insert index
:return: index, int
"""
if self.record_size <= self.size:
self.record_size += 1
if self.index % self.size == 0:
self.isFull = True if len(self._experience) == self.size else False
if self.replace_flag:
self.index = 1
return self.index
else:
                sys.stderr.write('Experience replay buffer is full and replace is set to False!\n')
return -1
else:
self.index += 1
return self.index
def add(self, data, priority = 0):
"""
store experience, suggest that experience is a tuple of (s1, a, r, s2, t)
so each experience is valid
:param experience: maybe a tuple, or list
:return: bool, indicate insert status
"""
insert_index = self.fix_index()
if insert_index > 0:
if insert_index in self._experience:
del self._experience[insert_index]
self._experience[insert_index] = data
            # insert with the current max priority so new samples are replayed at least once
            priority = self.priority_queue.get_max_priority()
self.priority_queue.update(priority, insert_index)
return True
else:
sys.stderr.write('Insert failed\n')
return False
def retrieve(self, indices):
"""
get experience from indices
:param indices: list of experience id
:return: experience replay sample
"""
return [self._experience[v] for v in indices]
def rebalance(self):
"""
rebalance priority queue
:return: None
"""
self.priority_queue.balance_tree()
def update_priority(self, indices, delta):
"""
update priority according indices and deltas
:param indices: list of experience id
:param delta: list of delta, order correspond to indices
:return: None
"""
for i in range(0, len(indices)):
self.priority_queue.update(math.fabs(delta[i]), indices[i])
def sample(self, conf):
"""
sample a mini batch from experience replay
:param global_step: now training step
:return: experience, list, samples
:return: w, list, weights
:return: rank_e_id, list, samples id, used for update priority
"""
global_step = conf['global_step']
if self.record_size < self.learn_start:
sys.stderr.write('Record size less than learn start! Sample failed\n')
return False, False, False
        dist_index = int(math.floor(self.record_size / float(self.size) * self.partition_num))
# issue 1 by @camigord
partition_size = math.floor(self.size / self.partition_num)
partition_max = dist_index * partition_size
distribution = self.distributions[dist_index]
rank_list = []
# sample from k segments
for n in range(1, self.batch_size + 1):
index = random.randint(distribution['strata_ends'][n] + 1,
distribution['strata_ends'][n + 1])
rank_list.append(index)
# beta, increase by global_step, max 1
beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1)
        # look up P(i) for each sampled rank; note that pdf is a 0-indexed list
        alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
# w = (N * P(i)) ^ (-beta) / max w
w = np.power(np.array(alpha_pow) * partition_max, -beta)
w_max = max(w)
w = np.divide(w, w_max)
        # rank_list holds priority ids; convert them to experience ids
        rank_e_id = self.priority_queue.priority_to_experience(rank_list)
        # retrieve the experiences for rank_e_id
        experience = self.retrieve(rank_e_id)
experience = self.retrieve(rank_e_id)
return experience, w, rank_e_id
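A short sketch of driving the rank-based buffer (illustrative; the conf keys follow the defaults read in `__init__` above). Sampling only succeeds once `record_size` reaches `learn_start`, and `rebalance` should be called periodically because `update_priority` only repairs the heap locally:

conf = {'size': 50, 'learn_start': 10, 'partition_num': 5,
        'steps': 100, 'batch_size': 4}
buf = RankBasedExperience(conf)
for i in range(1, 21):
    buf.add((i, 0, 1.0, i + 1, False))             # new samples enter with max priority
batch, w, e_ids = buf.sample({'global_step': 20})
buf.update_priority(e_ids, [0.5, 0.1, 0.9, 0.3])   # one |TD error| per sampled id
buf.rebalance()                                    # periodic full re-sort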

View File

@@ -0,0 +1,129 @@
from utils import *
from functions import *
def test_rank_based():
    conf = {'size': 50,
            'learn_start': 10,
            'partition_num': 5,
            'steps': 100,
            'batch_size': 4}
experience = getReplayBuffer('rank_based', conf)
# insert to experience
    print('test insert experience')
    for i in range(1, 51):
        # tuple, like (state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert)
    print(experience.priority_queue)
    print(experience._experience[1])
    print(experience._experience[2])
    print('test replace')
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert)
    print(experience.priority_queue)
    print(experience._experience[1])
    print(experience._experience[2])
    # sample
    print('test sample')
    global_step = {'global_step': 51}
    sample, w, e_id = experience.sample(global_step)
    print(sample)
    print(w)
    print(e_id)
    # update delta to priority
    print('test update delta')
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print(experience.priority_queue)
    sample, w, e_id = experience.sample(global_step)
    print(sample)
    print(w)
    print(e_id)
    # rebalance
    print('test rebalance')
    experience.rebalance()
    print(experience.priority_queue)
def test_proportional():
conf = {'size': 50,
'alpha': 0.7,
'batch_size': 4}
experience = getReplayBuffer('proportional', conf)
# insert to experience
    print('test insert experience')
    for i in range(1, 51):
        # tuple, like (state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert, i)
    print(experience.tree)
    print(experience.tree.get_val(1))
    print(experience.tree.get_val(2))
    print('test replace')
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert, 51)
    print(experience.tree)
    print(experience.tree.get_val(1))
    print(experience.tree.get_val(2))
    # sample
    print('test sample')
    beta = {'beta': 0.005}
    sample, w, e_id = experience.sample(beta)
    print(sample)
    print(w)
    print(e_id)
    # update delta to priority
    print('test update delta')
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print(experience.tree)
    sample, w, e_id = experience.sample(beta)
    print(sample)
    print(w)
    print(e_id)
def test_naive():
conf = {'size': 50}
experience = getReplayBuffer('naive', conf)
# insert to experience
    print('test insert experience')
    for i in range(1, 51):
        # tuple, like (state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert)
    print(experience.memory)
    print('test replace')
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert)
    print(experience.memory)
    # sample
    print('test sample')
    batch_size = {'batch_size': 5}
    sample, w, e_id = experience.sample(batch_size)
    print(sample)
    print(w)
    print(e_id)
    # update delta to priority
    print('test update delta')
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print(experience.memory)
    sample, w, e_id = experience.sample(batch_size)
    print(sample)
    print(w)
    print(e_id)
if __name__ == '__main__':
test_rank_based()
test_proportional()
test_naive()

View File

@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
from __future__ import print_function
import sys
import os
import math
class SumTree(object):
    """Complete binary tree whose internal nodes store the sum of their children;
    leaves hold per-sample priorities, so sampling by cumulative mass is O(log n)."""
def __init__(self, max_size):
self.max_size = max_size
self.tree_level = int(math.ceil(math.log(max_size+1, 2))+1)
self.tree_size = 2**self.tree_level-1
self.tree = [0 for i in range(self.tree_size)]
self.data = [None for i in range(self.max_size)]
self.size = 0
self.cursor = 0
    def add(self, contents, value):
        # overwrite the oldest slot (ring buffer) and push the new value up the tree
index = self.cursor
self.cursor = (self.cursor+1)%self.max_size
self.size = min(self.size+1, self.max_size)
self.data[index] = contents
self.val_update(index, value)
def get_val(self, index):
tree_index = 2**(self.tree_level-1)-1+index
return self.tree[tree_index]
    def val_update(self, index, value):
        # set a leaf's value and propagate the difference up to the root
tree_index = 2**(self.tree_level-1)-1+index
diff = value-self.tree[tree_index]
self.reconstruct(tree_index, diff)
def reconstruct(self, tindex, diff):
self.tree[tindex] += diff
if not tindex == 0:
tindex = int((tindex-1)/2)
self.reconstruct(tindex, diff)
    def find(self, value, norm=True):
        # with norm=True, value is a fraction in [0, 1) of the total priority mass
if norm:
value *= self.tree[0]
return self._find(value, 0)
def _find(self, value, index):
        # leaf reached: translate the tree index back to a data slot
        if 2**(self.tree_level-1)-1 <= index:
            return self.data[index-(2**(self.tree_level-1)-1)], self.tree[index], index-(2**(self.tree_level-1)-1)
        left = self.tree[2*index+1]
        # descend left if the target mass lies in the left subtree, else subtract it and go right
        if value <= left:
return self._find(value,2*index+1)
else:
return self._find(value-left,2*(index+1))
def print_tree(self):
for k in range(1, self.tree_level+1):
for j in range(2**(k-1)-1, 2**k-1):
print(self.tree[j], end=' ')
print()
def filled_size(self):
return self.size
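A worked example of the tree (a sketch, not part of this commit). After storing priorities 1, 2 and 3, the root holds the total mass 6, and `find` maps a fraction of that mass back to a stored item in O(log n):

t = SumTree(4)
for i, v in enumerate([1.0, 2.0, 3.0]):
    t.add('item%d' % i, v)
print(t.tree[0])                  # 6.0: the root stores the total priority mass
data, value, index = t.find(0.9)  # target mass 0.9 * 6.0 = 5.4 falls in item2's range
print(data, value, index)         # ('item2', 3.0, 2)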

View File

@@ -0,0 +1,13 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:
def list_to_dict(in_list):
return dict((i, in_list[i]) for i in range(0, len(in_list)))
def exchange_key_value(in_dict):
return dict((in_dict[i], i) for i in in_dict)
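For reference, the two helpers compose into the position-to-id maps used by BinaryHeap; a tiny illustration:

d = list_to_dict(['a', 'b', 'c'])   # {0: 'a', 1: 'b', 2: 'c'}
print(exchange_key_value(d))        # {'a': 0, 'b': 1, 'c': 2}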

View File

@@ -0,0 +1,17 @@
from rank_based import *
from proportional import *
from naive import *
import sys
def getReplayBuffer(name, conf):
    '''
    Get a replay buffer according to the given name.
    '''
    if name == 'rank_based':
        return RankBasedExperience(conf)
    elif name == 'proportional':
        return PropotionalExperience(conf)
    elif name == 'naive':
        return NaiveExperience(conf)
    else:
        sys.stderr.write('No such replay buffer: %s\n' % name)
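A typical call, for reference (a sketch; the conf keys each buffer actually reads are documented in the corresponding files above):

buf = getReplayBuffer('rank_based',
                      {'size': 10000, 'learn_start': 1000,
                       'partition_num': 100, 'steps': 100000,
                       'batch_size': 32})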