diff --git a/examples/dqn_example.py b/examples/dqn_example.py
index 4fbe466..7d20731 100644
--- a/examples/dqn_example.py
+++ b/examples/dqn_example.py
@@ -9,8 +9,7 @@
 import gym
 import sys
 sys.path.append('..')
 import tianshou.core.losses as losses
-from tianshou.data.replay import Replay
-import tianshou.data.advantage_estimation as advantage_estimation
+from tianshou.data.replay_buffer.utils import get_replay_buffer
 
 import tianshou.core.policy as policy
@@ -38,11 +37,10 @@ if __name__ == '__main__':
     action_dim = env.action_space.n
 
     # 1. build network with pure tf
-    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
+    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
 
     with tf.variable_scope('q_net'):
         q_values = policy_net(observation, action_dim)
-        train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
 
     with tf.variable_scope('target_net'):
         q_values_target = policy_net(observation, action_dim)
@@ -54,13 +52,15 @@ if __name__ == '__main__':
     target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
 
     dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
-
+    global_step = tf.Variable(0, name='global_step', trainable=False)
+    train_var_list = tf.get_collection(
+        tf.GraphKeys.TRAINABLE_VARIABLES)  # TODO: better management of TRAINABLE_VARIABLES
     total_loss = dqn_loss
     optimizer = tf.train.AdamOptimizer(1e-3)
-    train_op = optimizer.minimize(total_loss, var_list=train_var_list)
-
+    train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
     # 3. define data collection
-    training_data = Replay(env, q_net, advantage_estimation.qlearning_target(target_net)) #
+    replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
+                                      {'size': 1000, 'batch_size': 64, 'learn_start': 20})  # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
     # maybe a dict to manage the elements to be collected
@@ -70,14 +70,16 @@ if __name__ == '__main__':
     minibatch_count = 0
     collection_count = 0
+    collect_freq = 100
 
     while True: # until some stopping criterion met...
         # collect data
-        training_data.collect() # ShihongSong
-        collection_count += 1
-        print('Collected {} times.'.format(collection_count))
+        for i in range(0, collect_freq):
+            replay_memory.collect() # ShihongSong
+            collection_count += 1
+            print('Collected {} times.'.format(collection_count))
 
         # update network
-        data = training_data.next_batch(64) # YouQiaoben, ShihongSong
+        data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong
         # TODO: auto managing of the placeholders? or add this to params of data.Batch
         sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
         minibatch_count += 1
diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index d281df9..3461afb 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -32,7 +32,7 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
     """
     log_pi_act = pi.log_prob(sampled_action)
     vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act)
-    # TODO: Different baseline methods like REINFORCE, etc.
+    # TODO: Different baseline methods like REINFORCE, etc.
     return vanilla_policy_gradient_loss
 
 
 def dqn_loss(sampled_action, sampled_target, q_net):
@@ -44,8 +44,8 @@ def dqn_loss(sampled_action, sampled_target, q_net):
     :param q_net: current `policy` to be optimized
     :return:
     """
-    action_num = q_net.get_values().shape()[1]
-    sampled_q = tf.reduce_sum(q_net.get_values() * tf.one_hot(sampled_action, action_num), axis=1)
+    action_num = q_net.values_tensor().get_shape()[1]
+    sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
     return tf.reduce_mean(tf.square(sampled_target - sampled_q))
 
 
 def deterministic_policy_gradient(sampled_state, critic):
diff --git a/tianshou/core/policy/__init__.py b/tianshou/core/policy/__init__.py
index f67b3ba..ccde775 100644
--- a/tianshou/core/policy/__init__.py
+++ b/tianshou/core/policy/__init__.py
@@ -2,4 +2,5 @@
 # -*- coding: utf-8 -*-
 
 from .base import *
-from .stochastic import *
\ No newline at end of file
+from .stochastic import *
+from .dqn import *
\ No newline at end of file
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index b6d8d48..eecfc4f 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -12,23 +12,28 @@
 import tensorflow as tf
 
 __all__ = [
     'StochasticPolicy',
+    'QValuePolicy',
 ]
 
-#TODO: separate actor and critic, we should focus on it once we finish the basic module.
+# TODO: separate actor and critic, we should focus on it once we finish the basic module.
+
 
 class QValuePolicy(object):
     """
     The policy as in DQN
     """
     def __init__(self, observation_placeholder):
-        self.observation_placeholder = observation_placeholder
+        self._observation_placeholder = observation_placeholder
 
     def act(self, observation, exploration=None):  # first implement no exploration
         """
        return the action (int) to be executed.
        no exploration when exploration=None.
         """
-        pass
+        return self._act(observation, exploration)
+
+    def _act(self, observation, exploration=None):
+        raise NotImplementedError()
 
     def values(self, observation):
         """
@@ -36,7 +41,7 @@ class QValuePolicy(object):
         """
         pass
 
-    def values_tensor(self, observation):
+    def values_tensor(self):
         """
        returns the tensor of the values for all actions a at observation s
         """
diff --git a/tianshou/core/policy/dqn.py b/tianshou/core/policy/dqn.py
index cfc6abf..81efc9b 100644
--- a/tianshou/core/policy/dqn.py
+++ b/tianshou/core/policy/dqn.py
@@ -1,7 +1,54 @@
-
-
-from .base import QValuePolicy
+from tianshou.core.policy.base import QValuePolicy
+import tensorflow as tf
 
 
 class DQN(QValuePolicy):
-    pass
\ No newline at end of file
+    """
+    The policy as in DQN
+    """
+
+    def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
+        self._logits = tf.convert_to_tensor(logits)
+        if dtype is None:
+            dtype = tf.int32
+        self._n_categories = self._logits.get_shape()[-1].value
+
+        super(DQN, self).__init__(observation_placeholder)
+
+        net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
+        net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
+        net = tf.layers.flatten(net)
+        net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
+        self._value = tf.layers.dense(net, self._n_categories)
+
+    def _act(self, observation, exploration=None):  # first implement no exploration
+        """
+        return the action (int) to be executed.
+        no exploration when exploration=None.
+ """ + sess = tf.get_default_session() + sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1), + feed_dict={self._observation_placeholder: observation[None]}) + return sampled_action + + @property + def logits(self): + return self._logits + + @property + def n_categories(self): + return self._n_categories + + def values(self, observation): + """ + returns the Q(s, a) values (float) for all actions a at observation s + """ + sess = tf.get_default_session() + value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]}) + return value + + def values_tensor(self): + """ + returns the tensor of the values for all actions a at observation s + """ + return self._value diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py index 6f5b8a6..3c2d644 100644 --- a/tianshou/data/advantage_estimation.py +++ b/tianshou/data/advantage_estimation.py @@ -19,7 +19,8 @@ def full_return(raw_data): returns = rewards.copy() episode_start_idx = 0 for i in range(1, num_timesteps): - if episode_start_flags[i] or (i == num_timesteps - 1): # found the start of next episode or the end of all episodes + if episode_start_flags[i] or ( + i == num_timesteps - 1): # found the start of next episode or the end of all episodes if i < rewards.shape[0] - 1: t = i - 1 else: @@ -34,4 +35,36 @@ def full_return(raw_data): data['returns'] = returns - return data \ No newline at end of file + return data + + +class QLearningTarget: + def __init__(self, policy, gamma): + self._policy = policy + self._gamma = gamma + + def __call__(self, raw_data): + data = dict() + observations = list() + actions = list() + rewards = list() + wi = list() + all_data, data_wi, data_index = raw_data + + for i in range(0, all_data.shape[0]): + current_data = all_data[i] + current_wi = data_wi[i] + current_index = data_index[i] + observations.append(current_data['observation']) + actions.append(current_data['action']) + next_max_qvalue = np.max(self._policy.values(current_data['observation'])) + current_qvalue = self._policy.values(current_data['previous_observation'])[current_data['previous_action']] + reward = current_data['reward'] + next_max_qvalue - current_qvalue + rewards.append(reward) + wi.append(current_wi) + + data['observations'] = np.array(observations) + data['actions'] = np.array(actions) + data['rewards'] = np.array(rewards) + + return data diff --git a/tianshou/data/replay_buffer/buffer.py b/tianshou/data/replay_buffer/buffer.py index 4b92cfc..6a44170 100644 --- a/tianshou/data/replay_buffer/buffer.py +++ b/tianshou/data/replay_buffer/buffer.py @@ -1,39 +1,51 @@ class ReplayBuffer(object): - def __init__(self, conf): - ''' + def __init__(self, env, policy, qnet, target_qnet, conf): + """ Initialize a replay buffer with parameters in conf. - ''' - pass + """ + pass - def add(self, data, priority): - ''' + def add(self, data, priority): + """ Add a data with priority = priority to replay buffer. - ''' - pass + """ + pass - def update_priority(self, indices, priorities): - ''' + def collect(self): + """ + Collect data from current environment and policy. + """ + pass + + def next_batch(self, batch_size): + """ + get batch of data from the replay buffer. + """ + pass + + def update_priority(self, indices, priorities): + """ Update the data's priority whose indices = indices. For proportional replay buffer, the priority is the priority. For rank based replay buffer, the priorities parameter will be the delta used to update the priority. 
-        '''
-        pass
+        """
+        pass
 
-    def reset_alpha(self, alpha):
-        '''
+    def reset_alpha(self, alpha):
+        """
         This function only works for proportional replay buffer.
         This function resets alpha.
-        '''
-        pass
+        """
+        pass
 
-    def sample(self, conf):
-        '''
+    def sample(self, conf):
+        """
         Sample from replay buffer with parameters in conf.
-        '''
-        pass
+        """
+        pass
 
-    def rebalance(self):
-        '''
+    def rebalance(self):
+        """
         This is for rank based priority replay buffer, which is used to rebalance the sum tree of the priority queue.
-        '''
-        pass
\ No newline at end of file
+        """
+        pass
diff --git a/tianshou/data/replay_buffer/naive.py b/tianshou/data/replay_buffer/naive.py
index 9436a39..50ba1c3 100644
--- a/tianshou/data/replay_buffer/naive.py
+++ b/tianshou/data/replay_buffer/naive.py
@@ -1,29 +1,93 @@
-from buffer import ReplayBuffer
 import numpy as np
+import tensorflow as tf
 from collections import deque
+from math import fabs
+
+from tianshou.data.replay_buffer.buffer import ReplayBuffer
+
 
 class NaiveExperience(ReplayBuffer):
-    def __init__(self, conf):
-        self.max_size = conf['size']
-        self.n_entries = 0
-        self.memory = deque(maxlen = self.max_size)
+    def __init__(self, env, policy, qnet, target_qnet, conf):
+        self.max_size = conf['size']
+        self._env = env
+        self._policy = policy
+        self._qnet = qnet
+        self._target_qnet = target_qnet
+        self._begin_act()
+        self.n_entries = 0
+        self.memory = deque(maxlen=self.max_size)
 
-    def add(self, data, priority = 0):
-        self.memory.append(data)
-        if self.n_entries < self.max_size:
-            self.n_entries += 1
+    def add(self, data, priority=0):
+        self.memory.append(data)
+        if self.n_entries < self.max_size:
+            self.n_entries += 1
 
-    def update_priority(self, indices, priorities = 0):
-        pass
+    def _begin_act(self):
+        self.observation = self._env.reset()
+        self.action = self._env.action_space.sample()
+        done = False
+        while not done:
+            if done:
+                self.observation = self._env.reset()
+                self.action = self._env.action_space.sample()
+            self.observation, _, done, _ = self._env.step(self.action)
 
-    def reset_alpha(self, alpha):
-        pass
+    def collect(self):
+        sess = tf.get_default_session()
+        current_data = dict()
+        current_data['previous_action'] = self.action
+        current_data['previous_observation'] = self.observation
+        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
+        self.observation, reward, done, _ = self._env.step(self.action)
+        current_data['action'] = self.action
+        current_data['observation'] = self.observation
+        current_data['reward'] = reward
+        self.add(current_data)
+        if done:
+            self._begin_act()
 
-    def sample(self, conf):
-        batch_size = conf['batch_size']
-        batch_size = min(len(self.memory), batch_size)
-        idxs = np.random.choice(len(self.memory), batch_size)
-        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
+    def update_priority(self, indices, priorities=0):
+        pass
 
-    def rebalance(self):
-        pass
+    def reset_alpha(self, alpha):
+        pass
+
+    def sample(self, conf):
+        batch_size = conf['batch_size']
+        batch_size = min(len(self.memory), batch_size)
+        idxs = np.random.choice(len(self.memory), batch_size)
+        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
+
+    def next_batch(self, batch_size):
+        data = dict()
+        observations = list()
+        actions = list()
+        rewards = list()
+        wi = list()
+        target = list()
+
+        for i in range(0, batch_size):
+            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
+            current_data = current_datas[0]
+            current_wi = current_wis[0]
+            current_index = current_indexs[0]
+            observations.append(current_data['observation'])
+            actions.append(current_data['action'])
+            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
+            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            rewards.append(reward)
+            target.append(current_data['reward'] + next_max_qvalue)
+            self.update_priority(current_index, [fabs(reward)])
+            wi.append(current_wi)
+
+        data['observations'] = np.array(observations)
+        data['actions'] = np.array(actions)
+        data['rewards'] = np.array(rewards)
+        data['wi'] = np.array(wi)
+        data['target'] = np.array(target)
+
+        return data
+
+    def rebalance(self):
+        pass
diff --git a/tianshou/data/replay_buffer/proportional.py b/tianshou/data/replay_buffer/proportional.py
index 72d1457..63aab66 100644
--- a/tianshou/data/replay_buffer/proportional.py
+++ b/tianshou/data/replay_buffer/proportional.py
@@ -1,7 +1,10 @@
-import numpy
+import numpy as np
 import random
-import sum_tree
-from buffer import ReplayBuffer
+import tensorflow as tf
+import math
+
+from tianshou.data.replay_buffer import sum_tree
+from tianshou.data.replay_buffer.buffer import ReplayBuffer
 
 
 class PropotionalExperience(ReplayBuffer):
@@ -15,7 +18,7 @@ class PropotionalExperience(ReplayBuffer):
 
     """
 
-    def __init__(self, conf):
+    def __init__(self, env, policy, qnet, target_qnet, conf):
         """ Prioritized experience replay buffer initialization.
 
         Parameters
@@ -30,11 +33,26 @@ class PropotionalExperience(ReplayBuffer):
         """
         memory_size = conf['size']
         batch_size = conf['batch_size']
-        alpha = conf['alpha']
+        alpha = conf['alpha'] if 'alpha' in conf else 0.6
         self.tree = sum_tree.SumTree(memory_size)
         self.memory_size = memory_size
         self.batch_size = batch_size
         self.alpha = alpha
+        self._env = env
+        self._policy = policy
+        self._qnet = qnet
+        self._target_qnet = target_qnet
+        self._begin_act()
+
+    def _begin_act(self):
+        self.observation = self._env.reset()
+        self.action = self._env.action_space.sample()
+        done = False
+        while not done:
+            if done:
+                self.observation = self._env.reset()
+                self.action = self._env.action_space.sample()
+            self.observation, _, done, _ = self._env.step(self.action)
 
     def add(self, data, priority):
         """ Add new sample.
@@ -48,6 +66,12 @@ class PropotionalExperience(ReplayBuffer):
         """
         self.tree.add(data, priority**self.alpha)
 
+    def collect(self):
+        pass
+
+    def next_batch(self, batch_size):
+        pass
+
     def sample(self, conf):
         """ The method return samples randomly.
 
@@ -64,8 +88,9 @@ class PropotionalExperience(ReplayBuffer):
         indices: list of sample indices
             The indices indicate sample positions in a sum tree.
 
+        :param conf: giving beta
         """
-        beta = conf['beta']
+        beta = conf['beta'] if 'beta' in conf else 0.4
 
         if self.tree.filled_size() < self.batch_size:
             return None, None, None
@@ -91,6 +116,54 @@ class PropotionalExperience(ReplayBuffer):
 
         return out, weights, indices
 
+    def collect(self):
+        sess = tf.get_default_session()
+        current_data = dict()
+        current_data['previous_action'] = self.action
+        current_data['previous_observation'] = self.observation
+        # TODO: change the name of the feed_dict
+        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
+        self.observation, reward, done, _ = self._env.step(self.action)
+        current_data['action'] = self.action
+        current_data['observation'] = self.observation
+        current_data['reward'] = reward
+        priorities = np.array([self.tree.get_val(i) ** -self.alpha for i in range(self.tree.filled_size())])
+        priority = np.max(priorities) if len(priorities) > 0 else 1
+        self.add(current_data, priority)
+        if done:
+            self._begin_act()
+
+    def next_batch(self, batch_size):
+        data = dict()
+        observations = list()
+        actions = list()
+        rewards = list()
+        wi = list()
+        target = list()
+
+        for i in range(0, batch_size):
+            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
+            current_data = current_datas[0]
+            current_wi = current_wis[0]
+            current_index = current_indexs[0]
+            observations.append(current_data['observation'])
+            actions.append(current_data['action'])
+            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
+            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            rewards.append(reward)
+            target.append(current_data['reward'] + next_max_qvalue)
+            self.update_priority([current_index], [math.fabs(reward)])
+            wi.append(current_wi)
+
+        data['observations'] = np.array(observations)
+        data['actions'] = np.array(actions)
+        data['rewards'] = np.array(rewards)
+        data['wi'] = np.array(wi)
+        data['target'] = np.array(target)
+
+        return data
+
     def update_priority(self, indices, priorities):
         """ The methods update samples's priority.
diff --git a/tianshou/data/replay_buffer/rank_based.py b/tianshou/data/replay_buffer/rank_based.py
index eb770af..da56763 100644
--- a/tianshou/data/replay_buffer/rank_based.py
+++ b/tianshou/data/replay_buffer/rank_based.py
@@ -8,13 +8,15 @@
 import sys
 import math
 import random
 import numpy as np
+import tensorflow as tf
+
+from tianshou.data.replay_buffer.binary_heap import BinaryHeap
+from tianshou.data.replay_buffer.buffer import ReplayBuffer
 
-from binary_heap import BinaryHeap
-from buffer import ReplayBuffer
 
 class RankBasedExperience(ReplayBuffer):
-    def __init__(self, conf):
+    def __init__(self, env, policy, qnet, target_qnet, conf):
         self.size = conf['size']
         self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
         self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
@@ -25,12 +27,18 @@ class RankBasedExperience(ReplayBuffer):
         self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
         self.total_steps = conf['steps'] if 'steps' in conf else 100000
         # partition number N, split total size to N part
-        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100
+        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 10
 
         self.index = 0
         self.record_size = 0
         self.isFull = False
 
+        self._env = env
+        self._policy = policy
+        self._qnet = qnet
+        self._target_qnet = target_qnet
+        self._begin_act()
+
         self._experience = {}
         self.priority_queue = BinaryHeap(self.priority_size)
         self.distributions = self.build_distributions()
@@ -98,7 +106,64 @@ class RankBasedExperience(ReplayBuffer):
         self.index += 1
         return self.index
 
-    def add(self, data, priority = 0):
+    def _begin_act(self):
+        self.observation = self._env.reset()
+        self.action = self._env.action_space.sample()
+        done = False
+        while not done:
+            if done:
+                self.observation = self._env.reset()
+                self.action = self._env.action_space.sample()
+            self.observation, _, done, _ = self._env.step(self.action)
+
+    def collect(self):
+        sess = tf.get_default_session()
+        current_data = dict()
+        current_data['previous_action'] = self.action
+        current_data['previous_observation'] = self.observation
+        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
+        self.observation, reward, done, _ = self._env.step(self.action)
+        current_data['action'] = self.action
+        current_data['observation'] = self.observation
+        current_data['reward'] = reward
+        self.add(current_data)
+        if done:
+            self._begin_act()
+
+    def next_batch(self, batch_size):
+        data = dict()
+        observations = list()
+        actions = list()
+        rewards = list()
+        wi = list()
+        target = list()
+
+        sess = tf.get_default_session()
+        current_datas, current_wis, current_indexs = self.sample({'global_step': sess.run(tf.train.get_global_step())})
+
+        for i in range(0, batch_size):
+            current_data = current_datas[i]
+            current_wi = current_wis[i]
+            current_index = current_indexs[i]
+            observations.append(current_data['observation'])
+            actions.append(current_data['action'])
+            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
+            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            rewards.append(reward)
+            target.append(current_data['reward'] + next_max_qvalue)
+            self.update_priority([current_index], [math.fabs(reward)])
+            wi.append(current_wi)
+
+        data['observations'] = np.array(observations)
+        data['actions'] = np.array(actions)
+        data['rewards'] = np.array(rewards)
+        data['wi'] = np.array(wi)
+        data['target'] = np.array(target)
+
+        return data
+
+    def add(self, data, priority = 1):
         """
         store experience, suggest that experience is a tuple of (s1, a, r, s2, t)
         so each experience is valid
@@ -156,16 +221,16 @@ class RankBasedExperience(ReplayBuffer):
             sys.stderr.write('Record size less than learn start! Sample failed\n')
             return False, False, False
 
-        dist_index = math.floor(self.record_size / self.size * self.partition_num)
+        dist_index = math.floor(self.record_size * 1. / self.size * self.partition_num)
         # issue 1 by @camigord
-        partition_size = math.floor(self.size / self.partition_num)
+        partition_size = math.floor(self.size * 1. / self.partition_num)
         partition_max = dist_index * partition_size
         distribution = self.distributions[dist_index]
         rank_list = []
         # sample from k segments
         for n in range(1, self.batch_size + 1):
-            index = random.randint(distribution['strata_ends'][n] + 1,
-                                   distribution['strata_ends'][n + 1])
+            index = random.randint(distribution['strata_ends'][n],
+                                   distribution['strata_ends'][n + 1])
             rank_list.append(index)
 
         # beta, increase by global_step, max 1
diff --git a/tianshou/data/replay_buffer/replay_buffer_test.py b/tianshou/data/replay_buffer/replay_buffer_test.py
index 9be659b..46b25c8 100644
--- a/tianshou/data/replay_buffer/replay_buffer_test.py
+++ b/tianshou/data/replay_buffer/replay_buffer_test.py
@@ -1,13 +1,15 @@
-from utils import *
 from functions import *
 
+from tianshou.data.replay_buffer.utils import get_replay_buffer
+
+
 def test_rank_based():
     conf = {'size': 50,
             'learn_start': 10,
             'partition_num': 5,
             'total_step': 100,
             'batch_size': 4}
-    experience = getReplayBuffer('rank_based', conf)
+    experience = get_replay_buffer('rank_based', conf)
 
     # insert to experience
     print 'test insert experience'
@@ -52,7 +54,7 @@ def test_proportional():
     conf = {'size': 50,
             'alpha': 0.7,
             'batch_size': 4}
-    experience = getReplayBuffer('proportional', conf)
+    experience = get_replay_buffer('proportional', conf)
 
     # insert to experience
     print 'test insert experience'
@@ -90,7 +92,7 @@ def test_naive():
     conf = {'size': 50}
-    experience = getReplayBuffer('naive', conf)
+    experience = get_replay_buffer('naive', conf)
 
     # insert to experience
     print 'test insert experience'
diff --git a/tianshou/data/replay_buffer/utils.py b/tianshou/data/replay_buffer/utils.py
index 3bb9bfe..4480375 100644
--- a/tianshou/data/replay_buffer/utils.py
+++ b/tianshou/data/replay_buffer/utils.py
@@ -1,17 +1,20 @@
-from rank_based import *
-from proportional import *
-from naive import *
 import sys
 
-def getReplayBuffer(name, conf):
-    '''
-    Get replay buffer according to the given name.
-    '''
-    if (name == 'rank_based'):
-        return RankBasedExperience(conf)
-    elif (name == 'proportional'):
-        return PropotionalExperience(conf)
-    elif (name == 'naive'):
-        return NaiveExperience(conf)
-    else:
-        sys.stderr.write('no such replay buffer')
+from tianshou.data.replay_buffer.naive import NaiveExperience
+from tianshou.data.replay_buffer.proportional import PropotionalExperience
+from tianshou.data.replay_buffer.rank_based import RankBasedExperience
+
+
+def get_replay_buffer(name, env, policy, qnet, target_qnet, conf):
+    """
+    Get replay buffer according to the given name.
+ """ + + if name == 'rank_based': + return RankBasedExperience(env, policy, qnet, target_qnet, conf) + elif name == 'proportional': + return PropotionalExperience(env, policy, qnet, target_qnet, conf) + elif name == 'naive': + return NaiveExperience(env, policy, qnet, target_qnet, conf) + else: + sys.stderr.write('no such replay buffer')