replay buffer initial commit
parent a8a12f1083
commit f1a7fd9ee1
tianshou/data/replay_buffer/__init__.py (new file)
@@ -0,0 +1,4 @@
from os.path import dirname, basename, isfile
import glob

# collect every .py module in this package and export its name via __all__
modules = glob.glob(dirname(__file__) + "/*.py")
__all__ = [basename(f)[:-3] for f in modules
           if isfile(f) and not f.endswith('__init__.py')]
tianshou/data/replay_buffer/binary_heap.py (new file)
@@ -0,0 +1,221 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description: max-heap keyed by priority, with bidirectional maps between
#     experience ids (e2p) and heap positions (p2e)

import sys
import math

import utility


class BinaryHeap(object):

    def __init__(self, priority_size=100, priority_init=None, replace=True):
        self.e2p = {}
        self.p2e = {}
        self.replace = replace

        if priority_init is None:
            self.priority_queue = {}
            self.size = 0
            self.max_size = priority_size
        else:
            # not yet tested
            self.priority_queue = priority_init
            self.size = len(self.priority_queue)
            self.max_size = self.size

            experience_list = list(map(lambda x: self.priority_queue[x], self.priority_queue))
            self.p2e = utility.list_to_dict(experience_list)
            self.e2p = utility.exchange_key_value(self.p2e)
            for i in range(int(self.size / 2), -1, -1):
                self.down_heap(i)

    def __repr__(self):
        """
        :return: string of the priority queue, with level info
        """
        if self.size == 0:
            return 'No element in heap!'
        to_string = ''
        level = -1
        max_level = int(math.floor(math.log(self.size, 2)))

        for i in range(1, self.size + 1):
            now_level = int(math.floor(math.log(i, 2)))
            if level != now_level:
                to_string = to_string + ('\n' if level != -1 else '') \
                            + ' ' * (max_level - now_level)
                level = now_level

            to_string = to_string + '%.2f ' % self.priority_queue[i][1] + ' ' * (max_level - now_level)

        return to_string

    def check_full(self):
        return self.size > self.max_size

    def _insert(self, priority, e_id):
        """
        insert a new experience id with its priority
        (maybe get_max_priority is unnecessary and could be folded into this function)
        :param priority: priority value
        :param e_id: experience id
        :return: bool
        """
        self.size += 1

        if self.check_full() and not self.replace:
            sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority))
            return False
        else:
            self.size = min(self.size, self.max_size)

        self.priority_queue[self.size] = (priority, e_id)
        self.p2e[self.size] = e_id
        self.e2p[e_id] = self.size

        self.up_heap(self.size)
        return True

    def update(self, priority, e_id):
        """
        update the priority value of an experience id
        :param priority: new priority value
        :param e_id: experience id
        :return: bool
        """
        if e_id in self.e2p:
            p_id = self.e2p[e_id]
            self.priority_queue[p_id] = (priority, e_id)
            self.p2e[p_id] = e_id

            self.down_heap(p_id)
            self.up_heap(p_id)
            return True
        else:
            # this e_id is new, do insert
            return self._insert(priority, e_id)

    def get_max_priority(self):
        """
        get the max priority; if the heap is empty, return 1
        :return: max priority if size > 0 else 1
        """
        if self.size > 0:
            return self.priority_queue[1][0]
        else:
            return 1

    def pop(self):
        """
        pop the max priority value with its experience id
        :return: priority value & experience id
        """
        if self.size == 0:
            sys.stderr.write('Error: no value in heap, pop failed\n')
            return False, False

        pop_priority, pop_e_id = self.priority_queue[1]
        self.e2p[pop_e_id] = -1
        # move the last element to the root first
        last_priority, last_e_id = self.priority_queue[self.size]
        self.priority_queue[1] = (last_priority, last_e_id)
        self.size -= 1
        self.e2p[last_e_id] = 1
        self.p2e[1] = last_e_id

        self.down_heap(1)

        return pop_priority, pop_e_id

    def up_heap(self, i):
        """
        upward balance
        :param i: tree node i
        :return: None
        """
        if i > 1:
            parent = i // 2
            if self.priority_queue[parent][0] < self.priority_queue[i][0]:
                tmp = self.priority_queue[i]
                self.priority_queue[i] = self.priority_queue[parent]
                self.priority_queue[parent] = tmp
                # update e2p & p2e
                self.e2p[self.priority_queue[i][1]] = i
                self.e2p[self.priority_queue[parent][1]] = parent
                self.p2e[i] = self.priority_queue[i][1]
                self.p2e[parent] = self.priority_queue[parent][1]
                # continue up-heaping from the parent
                self.up_heap(parent)

    def down_heap(self, i):
        """
        downward balance
        :param i: tree node i
        :return: None
        """
        if i < self.size:
            greatest = i
            left, right = i * 2, i * 2 + 1
            # the heap is 1-indexed, so a child at position == size is still valid
            if left <= self.size and self.priority_queue[left][0] > self.priority_queue[greatest][0]:
                greatest = left
            if right <= self.size and self.priority_queue[right][0] > self.priority_queue[greatest][0]:
                greatest = right

            if greatest != i:
                tmp = self.priority_queue[i]
                self.priority_queue[i] = self.priority_queue[greatest]
                self.priority_queue[greatest] = tmp
                # update e2p & p2e
                self.e2p[self.priority_queue[i][1]] = i
                self.e2p[self.priority_queue[greatest][1]] = greatest
                self.p2e[i] = self.priority_queue[i][1]
                self.p2e[greatest] = self.priority_queue[greatest][1]
                # continue down-heaping from the child
                self.down_heap(greatest)

    def get_priority(self):
        """
        get all priority values
        :return: list of priority
        """
        return list(map(lambda x: x[0], self.priority_queue.values()))[0:self.size]

    def get_e_id(self):
        """
        get all experience ids in the priority queue
        :return: list of experience ids ordered by their priority
        """
        return list(map(lambda x: x[1], self.priority_queue.values()))[0:self.size]

    def balance_tree(self):
        """
        rebalance the priority queue
        :return: None
        """
        sort_array = sorted(self.priority_queue.values(), key=lambda x: x[0], reverse=True)
        # reconstruct priority_queue
        self.priority_queue.clear()
        self.p2e.clear()
        self.e2p.clear()
        cnt = 1
        while cnt <= self.size:
            priority, e_id = sort_array[cnt - 1]
            self.priority_queue[cnt] = (priority, e_id)
            self.p2e[cnt] = e_id
            self.e2p[e_id] = cnt
            cnt += 1
        # restore the heap property, down to and including the root at index 1
        for i in range(int(math.floor(self.size / 2)), 0, -1):
            self.down_heap(i)

    def priority_to_experience(self, priority_ids):
        """
        retrieve experience ids by priority ids
        :param priority_ids: list of priority id
        :return: list of experience id
        """
        return [self.p2e[i] for i in priority_ids]
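For reference, a minimal usage sketch of BinaryHeap as committed above (illustrative only, not part of the commit; the ids and priorities are made up):

from binary_heap import BinaryHeap

heap = BinaryHeap(priority_size=8)
# experience ids 0..3 with arbitrary priorities; update() inserts unseen ids
for e_id, priority in [(0, 0.5), (1, 2.0), (2, 1.3), (3, 0.1)]:
    heap.update(priority, e_id)

print(heap.get_max_priority())   # 2.0 -- the root of the max-heap
priority, e_id = heap.pop()      # (2.0, 1)
heap.balance_tree()              # rebuild the heap after many updates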
tianshou/data/replay_buffer/buffer.py (new file)
@@ -0,0 +1,39 @@
class ReplayBuffer(object):
    def __init__(self, conf):
        '''
        Initialize a replay buffer with the parameters in conf.
        '''
        pass

    def add(self, data, priority):
        '''
        Add a data item with the given priority to the replay buffer.
        '''
        pass

    def update_priority(self, indices, priorities):
        '''
        Update the priorities of the data items at the given indices.
        For the proportional replay buffer, each entry of priorities is the new priority itself.
        For the rank-based replay buffer, each entry is the delta used to update the priority.
        '''
        pass

    def reset_alpha(self, alpha):
        '''
        Reset alpha. This function only works for the proportional replay buffer.
        '''
        pass

    def sample(self, conf):
        '''
        Sample from the replay buffer with the parameters in conf.
        '''
        pass

    def rebalance(self):
        '''
        Rebalance the priority queue. This is only meaningful for the
        rank-based replay buffer.
        '''
        pass
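The six methods above define the contract that the concrete buffers below implement. A hedged sketch of the expected call pattern (illustrative; SomeBuffer stands for any of the subclasses in this commit, and the conf keys depend on which one is used):

buf = SomeBuffer({'size': 1000, 'batch_size': 32})   # e.g. NaiveExperience
buf.add(transition, priority=1.0)                    # store one transition
samples, weights, indices = buf.sample(sample_conf)  # conf keys vary by subclass
buf.update_priority(indices, new_priorities)         # feed TD errors back
buf.rebalance()                                      # rank-based only; no-op elsewhere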
tianshou/data/replay_buffer/naive.py (new file)
@@ -0,0 +1,29 @@
from buffer import ReplayBuffer
import numpy as np
from collections import deque


class NaiveExperience(ReplayBuffer):
    def __init__(self, conf):
        self.max_size = conf['size']
        self.n_entries = 0
        self.memory = deque(maxlen=self.max_size)

    def add(self, data, priority=0):
        self.memory.append(data)
        if self.n_entries < self.max_size:
            self.n_entries += 1

    def update_priority(self, indices, priorities=0):
        pass

    def reset_alpha(self, alpha):
        pass

    def sample(self, conf):
        batch_size = conf['batch_size']
        batch_size = min(len(self.memory), batch_size)
        # note: np.random.choice samples with replacement by default
        idxs = np.random.choice(len(self.memory), batch_size)
        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs

    def rebalance(self):
        pass
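A quick sketch of the naive buffer in use (illustrative only; the transition tuples are made up):

buf = NaiveExperience({'size': 100})
for t in range(10):
    buf.add((t, 0, 1.0, t + 1, False))           # (s, a, r, s', done)
batch, weights, idxs = buf.sample({'batch_size': 4})
# weights are all 1: uniform sampling, no prioritization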
tianshou/data/replay_buffer/proportional.py (new file)
@@ -0,0 +1,119 @@
import random
import sum_tree
from buffer import ReplayBuffer


class ProportionalExperience(ReplayBuffer):
    """ Prioritized experience replay buffer with proportional prioritization.

    The class can store samples, pick samples with probability
    proportional to each sample's priority, update each sample's
    priority, and reset alpha.

    See https://arxiv.org/pdf/1511.05952.pdf .
    """

    def __init__(self, conf):
        """ Prioritized experience replay buffer initialization.

        Parameters (keys of conf)
        -------------------------
        size : int
            number of samples that can be stored
        batch_size : int
            batch size selected by the `sample` method
        alpha : float
            exponent determining how much prioritization is used:
            Prob_i ~ priority_i**alpha / sum(priority**alpha)
        """
        memory_size = conf['size']
        batch_size = conf['batch_size']
        alpha = conf['alpha']
        self.tree = sum_tree.SumTree(memory_size)
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.alpha = alpha

    def add(self, data, priority):
        """ Add a new sample.

        Parameters
        ----------
        data : object
            new sample
        priority : float
            sample's priority
        """
        self.tree.add(data, priority**self.alpha)

    def sample(self, conf):
        """ Draw samples with probability proportional to their priority.

        Parameters
        ----------
        conf : dict
            must contain 'beta', the importance-sampling exponent

        Returns
        -------
        out :
            list of samples
        weights :
            list of importance-sampling weights
        indices :
            list of sample indices
            The indices indicate sample positions in the sum tree.
        """
        beta = conf['beta']
        if self.tree.filled_size() < self.batch_size:
            return None, None, None

        out = []
        indices = []
        weights = []
        priorities = []
        for _ in range(self.batch_size):
            r = random.random()
            data, priority, index = self.tree.find(r)
            priorities.append(priority)
            weights.append((1. / self.memory_size / priority)**beta if priority > 1e-16 else 0)
            indices.append(index)
            out.append(data)
            self.update_priority([index], [0])  # zero out so the same index is not drawn twice

        # restore the original tree values directly: priorities already carry
        # the alpha exponent, so going through update_priority would apply it twice
        for index, priority in zip(indices, priorities):
            self.tree.val_update(index, priority)

        max_weights = max(weights)

        weights[:] = [x / max_weights for x in weights]  # normalize for stability

        return out, weights, indices

    def update_priority(self, indices, priorities):
        """ Update the priorities of the given samples.

        Parameters
        ----------
        indices :
            list of sample indices
        priorities :
            list of new priority values, in the same order as indices
        """
        for i, p in zip(indices, priorities):
            self.tree.val_update(i, p**self.alpha)

    def reset_alpha(self, alpha):
        """ Reset the exponent alpha and re-weight stored priorities.

        Parameters
        ----------
        alpha : float
        """
        self.alpha, old_alpha = alpha, self.alpha
        # stored values are priority**old_alpha; invert that exponent before
        # update_priority re-applies the new alpha
        priorities = [self.tree.get_val(i)**(1. / old_alpha) for i in range(self.tree.filled_size())]
        self.update_priority(range(self.tree.filled_size()), priorities)
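A hedged end-to-end sketch of the proportional buffer (illustrative only; the TD-error values are made up):

buf = ProportionalExperience({'size': 64, 'batch_size': 4, 'alpha': 0.7})
for i in range(64):
    buf.add((i, 0, 1.0, i + 1, False), priority=1.0)   # start uniform

batch, weights, indices = buf.sample({'beta': 0.5})
td_errors = [0.3, 1.2, 0.05, 0.7]                      # e.g. |TD error| per sample
buf.update_priority(indices, td_errors)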
tianshou/data/replay_buffer/rank_based.py (new file)
@@ -0,0 +1,184 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description: rank-based prioritized experience replay

import sys
import math
import random
import numpy as np

from binary_heap import BinaryHeap
from buffer import ReplayBuffer


class RankBasedExperience(ReplayBuffer):

    def __init__(self, conf):
        self.size = conf['size']
        self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
        self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size

        self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
        self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
        self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
        self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
        self.total_steps = conf['steps'] if 'steps' in conf else 100000
        # partition number N, split total size into N parts
        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100

        self.index = 0
        self.record_size = 0
        self.isFull = False

        self._experience = {}
        self.priority_queue = BinaryHeap(self.priority_size)
        self.distributions = self.build_distributions()

        self.beta_grad = (1 - self.beta_zero) / (self.total_steps - self.learn_start)

    def build_distributions(self):
        """
        precompute the power-law distribution over ranks:
        P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
        :return: distributions, dict
        """
        res = {}
        n_partitions = self.partition_num
        partition_num = 1
        # size of each partition
        partition_size = int(math.floor(self.size / n_partitions))

        for n in range(partition_size, self.size + 1, partition_size):
            if self.learn_start <= n <= self.priority_size:
                distribution = {}
                # P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
                pdf = list(
                    map(lambda x: math.pow(x, -self.alpha), range(1, n + 1))
                )
                pdf_sum = math.fsum(pdf)
                distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf))
                # split into k segments, then sample uniformly within each segment;
                # with k = batch_size, each segment has total probability 1 / batch_size
                # strata_ends keeps each segment's start and end positions
                cdf = np.cumsum(distribution['pdf'])
                strata_ends = {1: 0, self.batch_size + 1: n}
                step = 1. / self.batch_size
                index = 1
                for s in range(2, self.batch_size + 1):
                    while cdf[index] < step:
                        index += 1
                    strata_ends[s] = index
                    step += 1. / self.batch_size

                distribution['strata_ends'] = strata_ends

                res[partition_num] = distribution

            partition_num += 1

        return res

    def fix_index(self):
        """
        get the next insert index
        :return: index, int
        """
        if self.record_size <= self.size:
            self.record_size += 1
        if self.index % self.size == 0:
            self.isFull = True if len(self._experience) == self.size else False
            if self.replace_flag:
                self.index = 1
                return self.index
            else:
                sys.stderr.write('Experience replay buffer is full and replace is set to FALSE!\n')
                return -1
        else:
            self.index += 1
            return self.index

    def add(self, data, priority=0):
        """
        store experience; the suggested format is a tuple (s1, a, r, s2, t)
        so that each experience is valid
        :param data: maybe a tuple, or list
        :return: bool, indicating insert status
        """
        insert_index = self.fix_index()
        if insert_index > 0:
            if insert_index in self._experience:
                del self._experience[insert_index]
            self._experience[insert_index] = data
            # add to priority queue with the current max priority
            priority = self.priority_queue.get_max_priority()
            self.priority_queue.update(priority, insert_index)
            return True
        else:
            sys.stderr.write('Insert failed\n')
            return False

    def retrieve(self, indices):
        """
        get experiences by indices
        :param indices: list of experience ids
        :return: experience replay sample
        """
        return [self._experience[v] for v in indices]

    def rebalance(self):
        """
        rebalance the priority queue
        :return: None
        """
        self.priority_queue.balance_tree()

    def update_priority(self, indices, delta):
        """
        update priorities according to indices and deltas
        :param indices: list of experience ids
        :param delta: list of deltas, in the same order as indices
        :return: None
        """
        for i in range(0, len(indices)):
            self.priority_queue.update(math.fabs(delta[i]), indices[i])

    def sample(self, conf):
        """
        sample a mini-batch from experience replay
        :param conf: dict with 'global_step', the current training step
        :return: experience, list, samples
        :return: w, list, weights
        :return: rank_e_id, list, sample ids, used to update priorities
        """
        global_step = conf['global_step']
        if self.record_size < self.learn_start:
            sys.stderr.write('Record size less than learn start! Sample failed\n')
            return False, False, False

        dist_index = math.floor(self.record_size / self.size * self.partition_num)
        # issue 1 by @camigord
        partition_size = math.floor(self.size / self.partition_num)
        partition_max = dist_index * partition_size
        distribution = self.distributions[dist_index]
        rank_list = []
        # sample one rank from each of the k segments
        for n in range(1, self.batch_size + 1):
            index = random.randint(distribution['strata_ends'][n] + 1,
                                   distribution['strata_ends'][n + 1])
            rank_list.append(index)

        # beta, annealed toward 1 as global_step increases
        beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1)
        # look up P(i) for each sampled rank; note pdf is a list starting from 0
        alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
        # w = (N * P(i)) ^ (-beta) / max w
        w = np.power(np.array(alpha_pow) * partition_max, -beta)
        w_max = max(w)
        w = np.divide(w, w_max)
        # rank_list contains priority ids; convert to experience ids
        rank_e_id = self.priority_queue.priority_to_experience(rank_list)
        # get experiences according to rank_e_id
        experience = self.retrieve(rank_e_id)
        return experience, w, rank_e_id
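To make the stratified sampling in build_distributions concrete, here is a small standalone sketch (illustrative only; n = 8 ranks and batch_size = 4 are made up):

import math
import numpy as np

alpha, n, batch_size = 0.7, 8, 4
pdf = [math.pow(rank, -alpha) for rank in range(1, n + 1)]
pdf_sum = math.fsum(pdf)
cdf = np.cumsum([p / pdf_sum for p in pdf])

# walk the cdf to find segment boundaries of probability 1/batch_size each
strata_ends, step, index = {1: 0, batch_size + 1: n}, 1. / batch_size, 1
for s in range(2, batch_size + 1):
    while cdf[index] < step:
        index += 1
    strata_ends[s] = index
    step += 1. / batch_size
print(strata_ends)  # low ranks (high priority) get the narrowest segments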
tianshou/data/replay_buffer/replay_buffer_test.py (new file)
@@ -0,0 +1,129 @@
from utils import getReplayBuffer


def test_rank_based():
    conf = {'size': 50,
            'learn_start': 10,
            'partition_num': 5,
            'steps': 100,
            'batch_size': 4}
    experience = getReplayBuffer('rank_based', conf)

    # insert into experience
    print('test insert experience')
    for i in range(1, 51):
        # tuple, like (state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert)
    print(experience.priority_queue)
    print(experience._experience[1])
    print(experience._experience[2])
    print('test replace')
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert)
    print(experience.priority_queue)
    print(experience._experience[1])
    print(experience._experience[2])

    # sample
    print('test sample')
    global_step = {'global_step': 51}
    sample, w, e_id = experience.sample(global_step)
    print(sample)
    print(w)
    print(e_id)

    # update priorities with deltas
    print('test update delta')
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print(experience.priority_queue)
    sample, w, e_id = experience.sample(global_step)
    print(sample)
    print(w)
    print(e_id)

    # rebalance
    print('test rebalance')
    experience.rebalance()
    print(experience.priority_queue)


def test_proportional():
    conf = {'size': 50,
            'alpha': 0.7,
            'batch_size': 4}
    experience = getReplayBuffer('proportional', conf)

    # insert into experience
    print('test insert experience')
    for i in range(1, 51):
        # tuple, like (state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert, i)
    print(experience.tree)
    print(experience.tree.get_val(1))
    print(experience.tree.get_val(2))
    print('test replace')
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert, 51)
    print(experience.tree)
    print(experience.tree.get_val(1))
    print(experience.tree.get_val(2))

    # sample
    print('test sample')
    beta = {'beta': 0.005}
    sample, w, e_id = experience.sample(beta)
    print(sample)
    print(w)
    print(e_id)

    # update priorities with deltas
    print('test update delta')
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print(experience.tree)
    sample, w, e_id = experience.sample(beta)
    print(sample)
    print(w)
    print(e_id)


def test_naive():
    conf = {'size': 50}
    experience = getReplayBuffer('naive', conf)

    # insert into experience
    print('test insert experience')
    for i in range(1, 51):
        # tuple, like (state_t, a, r, state_t_1, t)
        to_insert = (i, 1, 1, i, 1)
        experience.add(to_insert)
    print(experience.memory)
    print('test replace')
    to_insert = (51, 1, 1, 51, 1)
    experience.add(to_insert)
    print(experience.memory)

    # sample
    print('test sample')
    batch_size = {'batch_size': 5}
    sample, w, e_id = experience.sample(batch_size)
    print(sample)
    print(w)
    print(e_id)

    # update priority is a no-op for the naive buffer
    print('test update delta')
    delta = [v for v in range(1, 5)]
    experience.update_priority(e_id, delta)
    print(experience.memory)
    sample, w, e_id = experience.sample(batch_size)
    print(sample)
    print(w)
    print(e_id)


if __name__ == '__main__':
    test_rank_based()
    test_proportional()
    test_naive()
tianshou/data/replay_buffer/sum_tree.py (new executable file)
@@ -0,0 +1,64 @@
#! -*- coding:utf-8 -*-
from __future__ import print_function

import math


class SumTree(object):
    def __init__(self, max_size):
        self.max_size = max_size
        self.tree_level = int(math.ceil(math.log(max_size + 1, 2)) + 1)
        self.tree_size = 2**self.tree_level - 1
        self.tree = [0 for i in range(self.tree_size)]
        self.data = [None for i in range(self.max_size)]
        self.size = 0
        self.cursor = 0

    def add(self, contents, value):
        # overwrite the oldest slot once the buffer is full
        index = self.cursor
        self.cursor = (self.cursor + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

        self.data[index] = contents
        self.val_update(index, value)

    def get_val(self, index):
        tree_index = 2**(self.tree_level - 1) - 1 + index
        return self.tree[tree_index]

    def val_update(self, index, value):
        tree_index = 2**(self.tree_level - 1) - 1 + index
        diff = value - self.tree[tree_index]
        self.reconstruct(tree_index, diff)

    def reconstruct(self, tindex, diff):
        # propagate the change up to the root
        self.tree[tindex] += diff
        if not tindex == 0:
            tindex = int((tindex - 1) / 2)
            self.reconstruct(tindex, diff)

    def find(self, value, norm=True):
        # with norm=True, value in [0, 1) is scaled by the total sum at the root
        if norm:
            value *= self.tree[0]
        return self._find(value, 0)

    def _find(self, value, index):
        # leaf level reached: translate the tree index back to a data index
        if 2**(self.tree_level - 1) - 1 <= index:
            return self.data[index - (2**(self.tree_level - 1) - 1)], self.tree[index], index - (2**(self.tree_level - 1) - 1)

        left = self.tree[2 * index + 1]

        if value <= left:
            return self._find(value, 2 * index + 1)
        else:
            return self._find(value - left, 2 * (index + 1))

    def print_tree(self):
        for k in range(1, self.tree_level + 1):
            for j in range(2**(k - 1) - 1, 2**k - 1):
                print(self.tree[j], end=' ')
            print()

    def filled_size(self):
        return self.size
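A small sketch of how the sum tree supports proportional sampling (illustrative only; the stored strings and values are made up):

tree = SumTree(max_size=4)
for value in [1.0, 2.0, 3.0, 4.0]:
    tree.add('data-%.0f' % value, value)

# find() with norm=True maps r in [0, 1) onto the cumulative sums,
# so each leaf is hit with probability value / total (here total = 10)
data, value, index = tree.find(0.95)   # lands in the last leaf (value 4.0)
print(data, value, index)              # data-4 4.0 3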
tianshou/data/replay_buffer/utility.py (new file)
@@ -0,0 +1,13 @@
#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description: small dict helpers used by binary_heap.py


def list_to_dict(in_list):
    # [a, b, c] -> {0: a, 1: b, 2: c}
    return dict((i, in_list[i]) for i in range(0, len(in_list)))


def exchange_key_value(in_dict):
    # invert a dict: {k: v} -> {v: k}
    return dict((in_dict[i], i) for i in in_dict)
tianshou/data/replay_buffer/utils.py (new file)
@@ -0,0 +1,17 @@
from rank_based import *
from proportional import *
from naive import *
import sys


def getReplayBuffer(name, conf):
    '''
    Get a replay buffer according to the given name.
    '''
    if name == 'rank_based':
        return RankBasedExperience(conf)
    elif name == 'proportional':
        return ProportionalExperience(conf)
    elif name == 'naive':
        return NaiveExperience(conf)
    else:
        sys.stderr.write('no such replay buffer\n')
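Taken together, the factory above is the intended entry point. A minimal sketch (illustrative only; the transition values are placeholders):

buf = getReplayBuffer('proportional', {'size': 1000, 'batch_size': 32, 'alpha': 0.7})
s, a, r, s_next, done = 0, 0, 1.0, 1, False          # placeholder transition
buf.add((s, a, r, s_next, done), priority=1.0)
batch, weights, indices = buf.sample({'beta': 0.4})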