2017-12-10 14:56:04 +08:00

185 lines
6.8 KiB
Python

#!/usr/bin/python
# -*- encoding=utf-8 -*-
# author: Ian
# e-mail: stmayue@gmail.com
# description:
import sys
import math
import random
import numpy as np
from binary_heap import BinaryHeap
from buffer import ReplayBuffer
class RankBasedExperience(ReplayBuffer):
def __init__(self, conf):
self.size = conf['size']
self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
self.alpha = conf['alpha'] if 'alpha' in conf else 0.7
self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5
self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32
self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
self.total_steps = conf['steps'] if 'steps' in conf else 100000
# partition number N, split total size to N part
self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100
self.index = 0
self.record_size = 0
self.isFull = False
self._experience = {}
self.priority_queue = BinaryHeap(self.priority_size)
self.distributions = self.build_distributions()
self.beta_grad = (1 - self.beta_zero) / (self.total_steps - self.learn_start)
def build_distributions(self):
"""
preprocess pow of rank
(rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
:return: distributions, dict
"""
res = {}
n_partitions = self.partition_num
partition_num = 1
# each part size
partition_size = int(math.floor(self.size / n_partitions))
for n in range(partition_size, self.size + 1, partition_size):
if self.learn_start <= n <= self.priority_size:
distribution = {}
# P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
pdf = list(
map(lambda x: math.pow(x, -self.alpha), range(1, n + 1))
)
pdf_sum = math.fsum(pdf)
distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf))
# split to k segment, and than uniform sample in each k
# set k = batch_size, each segment has total probability is 1 / batch_size
# strata_ends keep each segment start pos and end pos
cdf = np.cumsum(distribution['pdf'])
strata_ends = {1: 0, self.batch_size + 1: n}
step = 1. / self.batch_size
index = 1
for s in range(2, self.batch_size + 1):
while cdf[index] < step:
index += 1
strata_ends[s] = index
step += 1. / self.batch_size
distribution['strata_ends'] = strata_ends
res[partition_num] = distribution
partition_num += 1
return res
def fix_index(self):
"""
get next insert index
:return: index, int
"""
if self.record_size <= self.size:
self.record_size += 1
if self.index % self.size == 0:
self.isFull = True if len(self._experience) == self.size else False
if self.replace_flag:
self.index = 1
return self.index
else:
sys.stderr.write('Experience replay buff is full and replace is set to FALSE!\n')
return -1
else:
self.index += 1
return self.index
def add(self, data, priority = 0):
"""
store experience, suggest that experience is a tuple of (s1, a, r, s2, t)
so each experience is valid
:param experience: maybe a tuple, or list
:return: bool, indicate insert status
"""
insert_index = self.fix_index()
if insert_index > 0:
if insert_index in self._experience:
del self._experience[insert_index]
self._experience[insert_index] = data
# add to priority queue
priority = self.priority_queue.get_max_priority()
self.priority_queue.update(priority, insert_index)
return True
else:
sys.stderr.write('Insert failed\n')
return False
def retrieve(self, indices):
"""
get experience from indices
:param indices: list of experience id
:return: experience replay sample
"""
return [self._experience[v] for v in indices]
def rebalance(self):
"""
rebalance priority queue
:return: None
"""
self.priority_queue.balance_tree()
def update_priority(self, indices, delta):
"""
update priority according indices and deltas
:param indices: list of experience id
:param delta: list of delta, order correspond to indices
:return: None
"""
for i in range(0, len(indices)):
self.priority_queue.update(math.fabs(delta[i]), indices[i])
def sample(self, conf):
"""
sample a mini batch from experience replay
:param global_step: now training step
:return: experience, list, samples
:return: w, list, weights
:return: rank_e_id, list, samples id, used for update priority
"""
global_step = conf['global_step']
if self.record_size < self.learn_start:
sys.stderr.write('Record size less than learn start! Sample failed\n')
return False, False, False
dist_index = math.floor(self.record_size / self.size * self.partition_num)
# issue 1 by @camigord
partition_size = math.floor(self.size / self.partition_num)
partition_max = dist_index * partition_size
distribution = self.distributions[dist_index]
rank_list = []
# sample from k segments
for n in range(1, self.batch_size + 1):
index = random.randint(distribution['strata_ends'][n] + 1,
distribution['strata_ends'][n + 1])
rank_list.append(index)
# beta, increase by global_step, max 1
beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1)
# find all alpha pow, notice that pdf is a list, start from 0
alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
# w = (N * P(i)) ^ (-beta) / max w
w = np.power(np.array(alpha_pow) * partition_max, -beta)
w_max = max(w)
w = np.divide(w, w_max)
# rank list is priority id
# convert to experience id
rank_e_id = self.priority_queue.priority_to_experience(rank_list)
# get experience id according rank_e_id
experience = self.retrieve(rank_e_id)
return experience, w, rank_e_id