From f1a7fd9ee1a4b74caa484e2ba45eb1fe76a7422a Mon Sep 17 00:00:00 2001 From: songshshshsh <644240545@qq.com> Date: Sun, 10 Dec 2017 14:53:57 +0800 Subject: [PATCH] replay buffer initial commit --- tianshou/data/replay_buffer/__init__.py | 4 + tianshou/data/replay_buffer/binary_heap.py | 221 ++++++++++++++++++ tianshou/data/replay_buffer/buffer.py | 39 ++++ tianshou/data/replay_buffer/naive.py | 29 +++ tianshou/data/replay_buffer/proportional.py | 119 ++++++++++ tianshou/data/replay_buffer/rank_based.py | 184 +++++++++++++++ .../data/replay_buffer/replay_buffer_test.py | 129 ++++++++++ tianshou/data/replay_buffer/sum_tree.py | 64 +++++ tianshou/data/replay_buffer/utility.py | 13 ++ tianshou/data/replay_buffer/utils.py | 17 ++ 10 files changed, 819 insertions(+) create mode 100644 tianshou/data/replay_buffer/__init__.py create mode 100644 tianshou/data/replay_buffer/binary_heap.py create mode 100644 tianshou/data/replay_buffer/buffer.py create mode 100644 tianshou/data/replay_buffer/naive.py create mode 100644 tianshou/data/replay_buffer/proportional.py create mode 100644 tianshou/data/replay_buffer/rank_based.py create mode 100644 tianshou/data/replay_buffer/replay_buffer_test.py create mode 100755 tianshou/data/replay_buffer/sum_tree.py create mode 100644 tianshou/data/replay_buffer/utility.py create mode 100644 tianshou/data/replay_buffer/utils.py diff --git a/tianshou/data/replay_buffer/__init__.py b/tianshou/data/replay_buffer/__init__.py new file mode 100644 index 0000000..0deb77e --- /dev/null +++ b/tianshou/data/replay_buffer/__init__.py @@ -0,0 +1,4 @@ +from os.path import dirname, basename, isfile +import glob +modules = glob.glob(dirname(__file__)+"/*.py") +__all__ = [ basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')] diff --git a/tianshou/data/replay_buffer/binary_heap.py b/tianshou/data/replay_buffer/binary_heap.py new file mode 100644 index 0000000..e2b1474 --- /dev/null +++ b/tianshou/data/replay_buffer/binary_heap.py @@ -0,0 +1,221 @@ +#!/usr/bin/python +# -*- encoding=utf-8 -*- +# author: Ian +# e-mail: stmayue@gmail.com +# description: + +import sys +import math + +import utility + + +class BinaryHeap(object): + + def __init__(self, priority_size=100, priority_init=None, replace=True): + self.e2p = {} + self.p2e = {} + self.replace = replace + + if priority_init is None: + self.priority_queue = {} + self.size = 0 + self.max_size = priority_size + else: + # not yet test + self.priority_queue = priority_init + self.size = len(self.priority_queue) + self.max_size = None or self.size + + experience_list = list(map(lambda x: self.priority_queue[x], self.priority_queue)) + self.p2e = utility.list_to_dict(experience_list) + self.e2p = utility.exchange_key_value(self.p2e) + for i in range(int(self.size / 2), -1, -1): + self.down_heap(i) + + def __repr__(self): + """ + :return: string of the priority queue, with level info + """ + if self.size == 0: + return 'No element in heap!' 
+ to_string = '' + level = -1 + max_level = int(math.floor(math.log(self.size, 2))) + + for i in range(1, self.size + 1): + now_level = int(math.floor(math.log(i, 2))) + if level != now_level: + to_string = to_string + ('\n' if level != -1 else '') \ + + ' ' * (max_level - now_level) + level = now_level + + to_string = to_string + '%.2f ' % self.priority_queue[i][1] + ' ' * (max_level - now_level) + + return to_string + + def check_full(self): + return self.size > self.max_size + + def _insert(self, priority, e_id): + """ + insert new experience id with priority + (maybe don't need get_max_priority and implement it in this function) + :param priority: priority value + :param e_id: experience id + :return: bool + """ + self.size += 1 + + if self.check_full() and not self.replace: + sys.stderr.write('Error: no space left to add experience id %d with priority value %f\n' % (e_id, priority)) + return False + else: + self.size = min(self.size, self.max_size) + + self.priority_queue[self.size] = (priority, e_id) + self.p2e[self.size] = e_id + self.e2p[e_id] = self.size + + self.up_heap(self.size) + return True + + def update(self, priority, e_id): + """ + update priority value according its experience id + :param priority: new priority value + :param e_id: experience id + :return: bool + """ + if e_id in self.e2p: + p_id = self.e2p[e_id] + self.priority_queue[p_id] = (priority, e_id) + self.p2e[p_id] = e_id + + self.down_heap(p_id) + self.up_heap(p_id) + return True + else: + # this e id is new, do insert + return self._insert(priority, e_id) + + def get_max_priority(self): + """ + get max priority, if no experience, return 1 + :return: max priority if size > 0 else 1 + """ + if self.size > 0: + return self.priority_queue[1][0] + else: + return 1 + + def pop(self): + """ + pop out the max priority value with its experience id + :return: priority value & experience id + """ + if self.size == 0: + sys.stderr.write('Error: no value in heap, pop failed\n') + return False, False + + pop_priority, pop_e_id = self.priority_queue[1] + self.e2p[pop_e_id] = -1 + # replace first + last_priority, last_e_id = self.priority_queue[self.size] + self.priority_queue[1] = (last_priority, last_e_id) + self.size -= 1 + self.e2p[last_e_id] = 1 + self.p2e[1] = last_e_id + + self.down_heap(1) + + return pop_priority, pop_e_id + + def up_heap(self, i): + """ + upward balance + :param i: tree node i + :return: None + """ + if i > 1: + parent = math.floor(i / 2) + if self.priority_queue[parent][0] < self.priority_queue[i][0]: + tmp = self.priority_queue[i] + self.priority_queue[i] = self.priority_queue[parent] + self.priority_queue[parent] = tmp + # change e2p & p2e + self.e2p[self.priority_queue[i][1]] = i + self.e2p[self.priority_queue[parent][1]] = parent + self.p2e[i] = self.priority_queue[i][1] + self.p2e[parent] = self.priority_queue[parent][1] + # up heap parent + self.up_heap(parent) + + def down_heap(self, i): + """ + downward balance + :param i: tree node i + :return: None + """ + if i < self.size: + greatest = i + left, right = i * 2, i * 2 + 1 + if left < self.size and self.priority_queue[left][0] > self.priority_queue[greatest][0]: + greatest = left + if right < self.size and self.priority_queue[right][0] > self.priority_queue[greatest][0]: + greatest = right + + if greatest != i: + tmp = self.priority_queue[i] + self.priority_queue[i] = self.priority_queue[greatest] + self.priority_queue[greatest] = tmp + # change e2p & p2e + self.e2p[self.priority_queue[i][1]] = i + 
self.e2p[self.priority_queue[greatest][1]] = greatest + self.p2e[i] = self.priority_queue[i][1] + self.p2e[greatest] = self.priority_queue[greatest][1] + # down heap greatest + self.down_heap(greatest) + + def get_priority(self): + """ + get all priority value + :return: list of priority + """ + return list(map(lambda x: x[0], self.priority_queue.values()))[0:self.size] + + def get_e_id(self): + """ + get all experience id in priority queue + :return: list of experience ids order by their priority + """ + return list(map(lambda x: x[1], self.priority_queue.values()))[0:self.size] + + def balance_tree(self): + """ + rebalance priority queue + :return: None + """ + sort_array = sorted(self.priority_queue.values(), key=lambda x: x[0], reverse=True) + # reconstruct priority_queue + self.priority_queue.clear() + self.p2e.clear() + self.e2p.clear() + cnt = 1 + while cnt <= self.size: + priority, e_id = sort_array[cnt - 1] + self.priority_queue[cnt] = (priority, e_id) + self.p2e[cnt] = e_id + self.e2p[e_id] = cnt + cnt += 1 + # sort the heap + for i in range(int(math.floor(self.size / 2)), 1, -1): + self.down_heap(i) + + def priority_to_experience(self, priority_ids): + """ + retrieve experience ids by priority ids + :param priority_ids: list of priority id + :return: list of experience id + """ + return [self.p2e[i] for i in priority_ids] diff --git a/tianshou/data/replay_buffer/buffer.py b/tianshou/data/replay_buffer/buffer.py new file mode 100644 index 0000000..4b92cfc --- /dev/null +++ b/tianshou/data/replay_buffer/buffer.py @@ -0,0 +1,39 @@ +class ReplayBuffer(object): + def __init__(self, conf): + ''' + Initialize a replay buffer with parameters in conf. + ''' + pass + + def add(self, data, priority): + ''' + Add a data with priority = priority to replay buffer. + ''' + pass + + def update_priority(self, indices, priorities): + ''' + Update the data's priority whose indices = indices. + For proportional replay buffer, the priority is the priority. + For rank based replay buffer, the priorities parameter will be the delta used to update the priority. + ''' + pass + + def reset_alpha(self, alpha): + ''' + This function only works for proportional replay buffer. + This function resets alpha. + ''' + pass + + def sample(self, conf): + ''' + Sample from replay buffer with parameters in conf. + ''' + pass + + def rebalance(self): + ''' + This is for rank based priority replay buffer, which is used to rebalance the sum tree of the priority queue. 
+ ''' + pass \ No newline at end of file diff --git a/tianshou/data/replay_buffer/naive.py b/tianshou/data/replay_buffer/naive.py new file mode 100644 index 0000000..9436a39 --- /dev/null +++ b/tianshou/data/replay_buffer/naive.py @@ -0,0 +1,29 @@ +from buffer import ReplayBuffer +import numpy as np +from collections import deque + +class NaiveExperience(ReplayBuffer): + def __init__(self, conf): + self.max_size = conf['size'] + self.n_entries = 0 + self.memory = deque(maxlen = self.max_size) + + def add(self, data, priority = 0): + self.memory.append(data) + if self.n_entries < self.max_size: + self.n_entries += 1 + + def update_priority(self, indices, priorities = 0): + pass + + def reset_alpha(self, alpha): + pass + + def sample(self, conf): + batch_size = conf['batch_size'] + batch_size = min(len(self.memory), batch_size) + idxs = np.random.choice(len(self.memory), batch_size) + return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs + + def rebalance(self): + pass diff --git a/tianshou/data/replay_buffer/proportional.py b/tianshou/data/replay_buffer/proportional.py new file mode 100644 index 0000000..72d1457 --- /dev/null +++ b/tianshou/data/replay_buffer/proportional.py @@ -0,0 +1,119 @@ +import numpy +import random +import sum_tree +from buffer import ReplayBuffer + + +class PropotionalExperience(ReplayBuffer): + """ The class represents prioritized experience replay buffer. + + The class has functions: store samples, pick samples with + probability in proportion to sample's priority, update + each sample's priority, reset alpha. + + see https://arxiv.org/pdf/1511.05952.pdf . + + """ + + def __init__(self, conf): + """ Prioritized experience replay buffer initialization. + + Parameters + ---------- + memory_size : int + sample size to be stored + batch_size : int + batch size to be selected by `select` method + alpha: float + exponent determine how much prioritization. + Prob_i \sim priority_i**alpha/sum(priority**alpha) + """ + memory_size = conf['size'] + batch_size = conf['batch_size'] + alpha = conf['alpha'] + self.tree = sum_tree.SumTree(memory_size) + self.memory_size = memory_size + self.batch_size = batch_size + self.alpha = alpha + + def add(self, data, priority): + """ Add new sample. + + Parameters + ---------- + data : object + new sample + priority : float + sample's priority + """ + self.tree.add(data, priority**self.alpha) + + def sample(self, conf): + """ The method return samples randomly. + + Parameters + ---------- + beta : float + + Returns + ------- + out : + list of samples + weights: + list of weight + indices: + list of sample indices + The indices indicate sample positions in a sum tree. + """ + beta = conf['beta'] + if self.tree.filled_size() < self.batch_size: + return None, None, None + + out = [] + indices = [] + weights = [] + priorities = [] + for _ in range(self.batch_size): + r = random.random() + data, priority, index = self.tree.find(r) + priorities.append(priority) + weights.append((1./self.memory_size/priority)**beta if priority > 1e-16 else 0) + indices.append(index) + out.append(data) + self.update_priority([index], [0]) # To avoid duplicating + + + self.update_priority(indices, priorities) # Revert priorities + + max_weights = max(weights) + + weights[:] = [x / max_weights for x in weights] # Normalize for stability + + return out, weights, indices + + def update_priority(self, indices, priorities): + """ The methods update samples's priority. 
+ + Parameters + ---------- + indices : + list of sample indices + """ + for i, p in zip(indices, priorities): + self.tree.val_update(i, p**self.alpha) + + def reset_alpha(self, alpha): + """ Reset a exponent alpha. + + Parameters + ---------- + alpha : float + """ + self.alpha, old_alpha = alpha, self.alpha + priorities = [self.tree.get_val(i)**-old_alpha for i in range(self.tree.filled_size())] + self.update_priority(range(self.tree.filled_size()), priorities) + + + + + diff --git a/tianshou/data/replay_buffer/rank_based.py b/tianshou/data/replay_buffer/rank_based.py new file mode 100644 index 0000000..eb770af --- /dev/null +++ b/tianshou/data/replay_buffer/rank_based.py @@ -0,0 +1,184 @@ +#!/usr/bin/python +# -*- encoding=utf-8 -*- +# author: Ian +# e-mail: stmayue@gmail.com +# description: + +import sys +import math +import random +import numpy as np + +from binary_heap import BinaryHeap +from buffer import ReplayBuffer + +class RankBasedExperience(ReplayBuffer): + + def __init__(self, conf): + self.size = conf['size'] + self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True + self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size + + self.alpha = conf['alpha'] if 'alpha' in conf else 0.7 + self.beta_zero = conf['beta_zero'] if 'beta_zero' in conf else 0.5 + self.batch_size = conf['batch_size'] if 'batch_size' in conf else 32 + self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000 + self.total_steps = conf['steps'] if 'steps' in conf else 100000 + # partition number N, split total size to N part + self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100 + + self.index = 0 + self.record_size = 0 + self.isFull = False + + self._experience = {} + self.priority_queue = BinaryHeap(self.priority_size) + self.distributions = self.build_distributions() + + self.beta_grad = (1 - self.beta_zero) / (self.total_steps - self.learn_start) + + def build_distributions(self): + """ + preprocess pow of rank + (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha)) + :return: distributions, dict + """ + res = {} + n_partitions = self.partition_num + partition_num = 1 + # each part size + partition_size = int(math.floor(self.size / n_partitions)) + + for n in range(partition_size, self.size + 1, partition_size): + if self.learn_start <= n <= self.priority_size: + distribution = {} + # P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha)) + pdf = list( + map(lambda x: math.pow(x, -self.alpha), range(1, n + 1)) + ) + pdf_sum = math.fsum(pdf) + distribution['pdf'] = list(map(lambda x: x / pdf_sum, pdf)) + # split to k segment, and than uniform sample in each k + # set k = batch_size, each segment has total probability is 1 / batch_size + # strata_ends keep each segment start pos and end pos + cdf = np.cumsum(distribution['pdf']) + strata_ends = {1: 0, self.batch_size + 1: n} + step = 1. / self.batch_size + index = 1 + for s in range(2, self.batch_size + 1): + while cdf[index] < step: + index += 1 + strata_ends[s] = index + step += 1. 
/ self.batch_size + + distribution['strata_ends'] = strata_ends + + res[partition_num] = distribution + + partition_num += 1 + + return res + + def fix_index(self): + """ + get next insert index + :return: index, int + """ + if self.record_size <= self.size: + self.record_size += 1 + if self.index % self.size == 0: + self.isFull = True if len(self._experience) == self.size else False + if self.replace_flag: + self.index = 1 + return self.index + else: + sys.stderr.write('Experience replay buff is full and replace is set to FALSE!\n') + return -1 + else: + self.index += 1 + return self.index + + def add(self, data, priority = 0): + """ + store experience, suggest that experience is a tuple of (s1, a, r, s2, t) + so each experience is valid + :param experience: maybe a tuple, or list + :return: bool, indicate insert status + """ + insert_index = self.fix_index() + if insert_index > 0: + if insert_index in self._experience: + del self._experience[insert_index] + self._experience[insert_index] = data + # add to priority queue + priority = self.priority_queue.get_max_priority() + self.priority_queue.update(priority, insert_index) + return True + else: + sys.stderr.write('Insert failed\n') + return False + + def retrieve(self, indices): + """ + get experience from indices + :param indices: list of experience id + :return: experience replay sample + """ + return [self._experience[v] for v in indices] + + def rebalance(self): + """ + rebalance priority queue + :return: None + """ + self.priority_queue.balance_tree() + + def update_priority(self, indices, delta): + """ + update priority according indices and deltas + :param indices: list of experience id + :param delta: list of delta, order correspond to indices + :return: None + """ + for i in range(0, len(indices)): + self.priority_queue.update(math.fabs(delta[i]), indices[i]) + + def sample(self, conf): + """ + sample a mini batch from experience replay + :param global_step: now training step + :return: experience, list, samples + :return: w, list, weights + :return: rank_e_id, list, samples id, used for update priority + """ + global_step = conf['global_step'] + if self.record_size < self.learn_start: + sys.stderr.write('Record size less than learn start! 
Sample failed\n') + return False, False, False + + dist_index = math.floor(self.record_size / self.size * self.partition_num) + # issue 1 by @camigord + partition_size = math.floor(self.size / self.partition_num) + partition_max = dist_index * partition_size + distribution = self.distributions[dist_index] + rank_list = [] + # sample from k segments + for n in range(1, self.batch_size + 1): + index = random.randint(distribution['strata_ends'][n] + 1, + distribution['strata_ends'][n + 1]) + rank_list.append(index) + + # beta, increase by global_step, max 1 + beta = min(self.beta_zero + (global_step - self.learn_start - 1) * self.beta_grad, 1) + # find all alpha pow, notice that pdf is a list, start from 0 + alpha_pow = [distribution['pdf'][v - 1] for v in rank_list] + # w = (N * P(i)) ^ (-beta) / max w + w = np.power(np.array(alpha_pow) * partition_max, -beta) + w_max = max(w) + w = np.divide(w, w_max) + # rank list is priority id + # convert to experience id + rank_e_id = self.priority_queue.priority_to_experience(rank_list) + # get experience id according rank_e_id + experience = self.retrieve(rank_e_id) + return experience, w, rank_e_id diff --git a/tianshou/data/replay_buffer/replay_buffer_test.py b/tianshou/data/replay_buffer/replay_buffer_test.py new file mode 100644 index 0000000..9be659b --- /dev/null +++ b/tianshou/data/replay_buffer/replay_buffer_test.py @@ -0,0 +1,129 @@ +from utils import * +from functions import * + +def test_rank_based(): + conf = {'size': 50, + 'learn_start': 10, + 'partition_num': 5, + 'total_step': 100, + 'batch_size': 4} + experience = getReplayBuffer('rank_based', conf) + + # insert to experience + print 'test insert experience' + for i in range(1, 51): + # tuple, like(state_t, a, r, state_t_1, t) + to_insert = (i, 1, 1, i, 1) + experience.add(to_insert) + print experience.priority_queue + print experience._experience[1] + print experience._experience[2] + print 'test replace' + to_insert = (51, 1, 1, 51, 1) + experience.add(to_insert) + print experience.priority_queue + print experience._experience[1] + print experience._experience[2] + + # sample + print 'test sample' + global_step = {'global_step': 51} + sample, w, e_id = experience.sample(global_step) + print sample + print w + print e_id + + # update delta to priority + print 'test update delta' + delta = [v for v in range(1, 5)] + experience.update_priority(e_id, delta) + print experience.priority_queue + sample, w, e_id = experience.sample(global_step) + print sample + print w + print e_id + + # rebalance + print 'test rebalance' + experience.rebalance() + print experience.priority_queue + +def test_proportional(): + conf = {'size': 50, + 'alpha': 0.7, + 'batch_size': 4} + experience = getReplayBuffer('proportional', conf) + + # insert to experience + print 'test insert experience' + for i in range(1, 51): + # tuple, like(state_t, a, r, state_t_1, t) + to_insert = (i, 1, 1, i, 1) + experience.add(to_insert, i) + print experience.tree + print experience.tree.get_val(1) + print experience.tree.get_val(2) + print 'test replace' + to_insert = (51, 1, 1, 51, 1) + experience.add(to_insert, 51) + print experience.tree + print experience.tree.get_val(1) + print experience.tree.get_val(2) + + # sample + print 'test sample' + beta = {'beta': 0.005} + sample, w, e_id = experience.sample(beta) + print sample + print w + print e_id + + # update delta to priority + print 'test update delta' + delta = [v for v in range(1, 5)] + experience.update_priority(e_id, delta) + print experience.tree + sample, w, e_id = 
experience.sample(beta) + print sample + print w + print e_id + +def test_naive(): + conf = {'size': 50} + experience = getReplayBuffer('naive', conf) + + # insert to experience + print 'test insert experience' + for i in range(1, 51): + # tuple, like(state_t, a, r, state_t_1, t) + to_insert = (i, 1, 1, i, 1) + experience.add(to_insert) + print experience.memory + print 'test replace' + to_insert = (51, 1, 1, 51, 1) + experience.add(to_insert) + print experience.memory + + # sample + print 'test sample' + batch_size = {'batch_size': 5} + sample, w, e_id = experience.sample(batch_size) + print sample + print w + print e_id + + # update delta to priority + print 'test update delta' + delta = [v for v in range(1, 5)] + experience.update_priority(e_id, delta) + print experience.memory + sample, w, e_id = experience.sample(batch_size) + print sample + print w + print e_id + + +if __name__ == '__main__': + test_rank_based() + test_proportional() + test_naive() diff --git a/tianshou/data/replay_buffer/sum_tree.py b/tianshou/data/replay_buffer/sum_tree.py new file mode 100755 index 0000000..d7171d4 --- /dev/null +++ b/tianshou/data/replay_buffer/sum_tree.py @@ -0,0 +1,64 @@ +#! -*- coding:utf-8 -*- +from __future__ import print_function + +import sys +import os +import math + +class SumTree(object): + def __init__(self, max_size): + self.max_size = max_size + self.tree_level = int(math.ceil(math.log(max_size+1, 2))+1) + self.tree_size = 2**self.tree_level-1 + self.tree = [0 for i in range(self.tree_size)] + self.data = [None for i in range(self.max_size)] + self.size = 0 + self.cursor = 0 + + def add(self, contents, value): + index = self.cursor + self.cursor = (self.cursor+1)%self.max_size + self.size = min(self.size+1, self.max_size) + + self.data[index] = contents + self.val_update(index, value) + + def get_val(self, index): + tree_index = 2**(self.tree_level-1)-1+index + return self.tree[tree_index] + + def val_update(self, index, value): + tree_index = 2**(self.tree_level-1)-1+index + diff = value-self.tree[tree_index] + self.reconstruct(tree_index, diff) + + def reconstruct(self, tindex, diff): + self.tree[tindex] += diff + if not tindex == 0: + tindex = int((tindex-1)/2) + self.reconstruct(tindex, diff) + + def find(self, value, norm=True): + if norm: + value *= self.tree[0] + return self._find(value, 0) + + def _find(self, value, index): + if 2**(self.tree_level-1)-1 <= index: + return self.data[index-(2**(self.tree_level-1)-1)], self.tree[index], index-(2**(self.tree_level-1)-1) + + left = self.tree[2*index+1] + + if value <= left: + return self._find(value,2*index+1) + else: + return self._find(value-left,2*(index+1)) + + def print_tree(self): + for k in range(1, self.tree_level+1): + for j in range(2**(k-1)-1, 2**k-1): + print(self.tree[j], end=' ') + print() + + def filled_size(self): + return self.size diff --git a/tianshou/data/replay_buffer/utility.py b/tianshou/data/replay_buffer/utility.py new file mode 100644 index 0000000..e304c05 --- /dev/null +++ b/tianshou/data/replay_buffer/utility.py @@ -0,0 +1,13 @@ +#!/usr/bin/python +# -*- encoding=utf-8 -*- +# author: Ian +# e-mail: stmayue@gmail.com +# description: + + +def list_to_dict(in_list): + return dict((i, in_list[i]) for i in range(0, len(in_list))) + + +def exchange_key_value(in_dict): + return dict((in_dict[i], i) for i in in_dict) diff --git a/tianshou/data/replay_buffer/utils.py b/tianshou/data/replay_buffer/utils.py new file mode 100644 index 0000000..3bb9bfe --- /dev/null +++ b/tianshou/data/replay_buffer/utils.py @@ 
-0,0 +1,17 @@
+from rank_based import *
+from proportional import *
+from naive import *
+import sys
+
+def getReplayBuffer(name, conf):
+    '''
+    Return the replay buffer that matches the given name.
+    '''
+    if name == 'rank_based':
+        return RankBasedExperience(conf)
+    elif name == 'proportional':
+        return PropotionalExperience(conf)
+    elif name == 'naive':
+        return NaiveExperience(conf)
+    else:
+        sys.stderr.write('No such replay buffer: %s\n' % name)
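
For reference, a minimal usage sketch of the interface this patch adds: construct a buffer through getReplayBuffer, store transitions, sample a prioritized batch, and feed updated priorities back. The conf keys follow replay_buffer_test.py, and the plain imports assume the snippet runs from inside tianshou/data/replay_buffer, as the test file does; treat it as an illustration under those assumptions, not part of the patch itself.

    from utils import getReplayBuffer

    # proportional prioritized replay, configured as in replay_buffer_test.py
    conf = {'size': 50, 'alpha': 0.7, 'batch_size': 4}
    experience = getReplayBuffer('proportional', conf)

    # store transitions (s, a, r, s', done) with an initial priority
    for i in range(1, 51):
        experience.add((i, 1, 1, i + 1, 0), priority=i)

    # draw a prioritized batch; beta controls the importance-sampling correction
    samples, weights, indices = experience.sample({'beta': 0.5})

    # after computing TD errors, feed their magnitudes back as new priorities
    experience.update_priority(indices, [1.0] * len(indices))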