finished very naive DQN: changed the replay buffer interface by adding collect() and next_batch() (still needs refactoring); added an implementation of dqn.py, whose interface still needs rethinking to make it more extensible; slightly refactored the code style across the codebase; more comments and TODOs will come in the next commit

宋世虹 2017-12-17 12:52:00 +08:00
parent 31199c7d0d
commit 3624cc9036
12 changed files with 411 additions and 104 deletions
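The core of the new replay buffer interface is two methods that every buffer in this commit (naive, proportional, rank-based) now exposes: collect(), which steps the wrapped environment with the current policy and stores one transition, and next_batch(batch_size), which samples stored transitions and returns a dict of numpy arrays ('observations', 'actions', 'rewards', 'wi', 'target'). A minimal sketch of that contract (illustrative stand-in, not code from this commit):

class ReplayBufferInterface(object):
    """Shape of the interface added in this commit (hypothetical stand-in class)."""

    def collect(self):
        # step the environment with the current policy and store one transition
        raise NotImplementedError()

    def next_batch(self, batch_size):
        # sample `batch_size` transitions; return a dict of numpy arrays:
        # 'observations', 'actions', 'rewards', 'wi', 'target'
        raise NotImplementedError()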

View File

@@ -9,8 +9,7 @@ import gym
 import sys
 sys.path.append('..')
 import tianshou.core.losses as losses
-from tianshou.data.replay import Replay
-import tianshou.data.advantage_estimation as advantage_estimation
+from tianshou.data.replay_buffer.utils import get_replay_buffer
 import tianshou.core.policy as policy
@@ -38,11 +37,10 @@ if __name__ == '__main__':
     action_dim = env.action_space.n
     # 1. build network with pure tf
-    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
+    observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
     with tf.variable_scope('q_net'):
         q_values = policy_net(observation, action_dim)
-    train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
     with tf.variable_scope('target_net'):
         q_values_target = policy_net(observation, action_dim)
@@ -54,13 +52,15 @@ if __name__ == '__main__':
     target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
     dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
+    global_step = tf.Variable(0, name='global_step', trainable=False)
+    train_var_list = tf.get_collection(
+        tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
     total_loss = dqn_loss
     optimizer = tf.train.AdamOptimizer(1e-3)
-    train_op = optimizer.minimize(total_loss, var_list=train_var_list)
+    train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
     # 3. define data collection
-    training_data = Replay(env, q_net, advantage_estimation.qlearning_target(target_net))
+    # replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
+                                        {'size': 1000, 'batch_size': 64, 'learn_start': 20})
     # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
     # maybe a dict to manage the elements to be collected
@@ -70,14 +70,16 @@ if __name__ == '__main__':
     minibatch_count = 0
     collection_count = 0
+    collect_freq = 100
     while True: # until some stopping criterion met...
         # collect data
-        training_data.collect() # ShihongSong
-        collection_count += 1
-        print('Collected {} times.'.format(collection_count))
+        for i in range(0, collect_freq):
+            replay_memory.collect() # ShihongSong
+            collection_count += 1
+            print('Collected {} times.'.format(collection_count))
         # update network
-        data = training_data.next_batch(64) # YouQiaoben, ShihongSong
+        data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong
         # TODO: auto managing of the placeholders? or add this to params of data.Batch
         sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
         minibatch_count += 1
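The hunks above build a 'target_net' variable scope but never refresh it from 'q_net'. For reference, a common TF 1.x pattern for that sync, assuming those two scope names and matching variable creation order (an assumption; this is not code from the commit):

import tensorflow as tf

def build_target_update_op(source_scope='q_net', target_scope='target_net'):
    # copy every variable of the online network into the target network
    source_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=source_scope)
    target_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=target_scope)
    return tf.group(*[tf.assign(t, s) for s, t in zip(source_vars, target_vars)])

Running the returned op with sess.run() every few hundred minibatches is the usual schedule; the training loop above would be the natural place to hook it in.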

View File

@@ -32,7 +32,7 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
     """
    log_pi_act = pi.log_prob(sampled_action)
    vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act)
-    # TODO Different baseline methods like REINFORCE, etc.
+    # TODO: Different baseline methods like REINFORCE, etc.
    return vanilla_policy_gradient_loss

 def dqn_loss(sampled_action, sampled_target, q_net):
@@ -44,8 +44,8 @@ def dqn_loss(sampled_action, sampled_target, q_net):
     :param q_net: current `policy` to be optimized
     :return:
     """
-    action_num = q_net.get_values().shape()[1]
-    sampled_q = tf.reduce_sum(q_net.get_values() * tf.one_hot(sampled_action, action_num), axis=1)
+    action_num = q_net.values_tensor().get_shape()[1]
+    sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
     return tf.reduce_mean(tf.square(sampled_target - sampled_q))

 def deterministic_policy_gradient(sampled_state, critic):
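As a quick sanity check of what dqn_loss computes, here is the same one-hot selection and mean-squared TD error in plain NumPy (a sketch, not the repo's code):

import numpy as np

q_values = np.array([[1.0, 2.0, 0.5],      # Q(s, a) for a batch of 2 states, 3 actions
                     [0.2, 0.1, 0.9]])
sampled_action = np.array([1, 2])          # actions actually taken
sampled_target = np.array([2.5, 1.0])      # e.g. r + gamma * max_a' Q_target(s', a')
one_hot = np.eye(q_values.shape[1])[sampled_action]
sampled_q = np.sum(q_values * one_hot, axis=1)      # [2.0, 0.9]
loss = np.mean((sampled_target - sampled_q) ** 2)   # mean squared TD error
print(loss)                                         # ~0.13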

View File

@@ -3,3 +3,4 @@
 from .base import *
 from .stochastic import *
+from .dqn import *

View File

@@ -12,23 +12,28 @@ import tensorflow as tf
 __all__ = [
     'StochasticPolicy',
+    'QValuePolicy',
 ]

-#TODO: separate actor and critic, we should focus on it once we finish the basic module.
+# TODO: separate actor and critic, we should focus on it once we finish the basic module.

 class QValuePolicy(object):
     """
     The policy as in DQN
     """
     def __init__(self, observation_placeholder):
-        self.observation_placeholder = observation_placeholder
+        self._observation_placeholder = observation_placeholder

     def act(self, observation, exploration=None): # first implement no exploration
         """
         return the action (int) to be executed.
         no exploration when exploration=None.
         """
-        pass
+        self._act(observation, exploration)
+
+    def _act(self, observation, exploration = None):
+        raise NotImplementedError()

     def values(self, observation):
         """
@@ -36,7 +41,7 @@ class QValuePolicy(object):
         """
         pass

-    def values_tensor(self, observation):
+    def values_tensor(self):
         """
         returns the tensor of the values for all actions a at observation s
         """

View File

@@ -1,7 +1,54 @@
-from tianshou.core.policy.base import QValuePolicy
+import tensorflow as tf
+
+from .base import QValuePolicy

 class DQN(QValuePolicy):
-    pass
+    """
+    The policy as in DQN
+    """
+    def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
+        self._logits = tf.convert_to_tensor(logits)
+        if dtype is None:
+            dtype = tf.int32
+        self._n_categories = self._logits.get_shape()[-1].value
+
+        super(DQN, self).__init__(observation_placeholder)
+
+        net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
+        net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
+        net = tf.layers.flatten(net)
+        net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
+        self._value = tf.layers.dense(net, self._n_categories)
+
+    def _act(self, observation, exploration=None): # first implement no exploration
+        """
+        return the action (int) to be executed.
+        no exploration when exploration=None.
+        """
+        sess = tf.get_default_session()
+        sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
+                                  feed_dict={self._observation_placeholder: observation[None]})
+        return sampled_action
+
+    @property
+    def logits(self):
+        return self._logits
+
+    @property
+    def n_categories(self):
+        return self._n_categories
+
+    def values(self, observation):
+        """
+        returns the Q(s, a) values (float) for all actions a at observation s
+        """
+        sess = tf.get_default_session()
+        value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]})
+        return value
+
+    def values_tensor(self):
+        """
+        returns the tensor of the values for all actions a at observation s
+        """
+        return self._value
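A rough usage sketch for this policy class, assuming TF 1.x, an importable tianshou package, an Atari-style 84x84x4 observation and 4 actions; the logits argument is only consulted for its last dimension here:

import numpy as np
import tensorflow as tf
from tianshou.core.policy.dqn import DQN

obs_ph = tf.placeholder(tf.float32, shape=(None, 84, 84, 4))
logits_ph = tf.placeholder(tf.float32, shape=(None, 4))   # last dim = number of actions
pi = DQN(logits_ph, obs_ph)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    frame = np.zeros((84, 84, 4), dtype=np.float32)
    print(pi.values(frame).shape)   # (1, 4): Q-values for each action
    print(pi.values_tensor())       # the tensor consumed by losses.dqn_loss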

View File

@@ -19,7 +19,8 @@ def full_return(raw_data):
     returns = rewards.copy()
     episode_start_idx = 0
     for i in range(1, num_timesteps):
-        if episode_start_flags[i] or (i == num_timesteps - 1): # found the start of next episode or the end of all episodes
+        if episode_start_flags[i] or (
+                i == num_timesteps - 1): # found the start of next episode or the end of all episodes
             if i < rewards.shape[0] - 1:
                 t = i - 1
             else:
@@ -35,3 +36,35 @@
     data['returns'] = returns
     return data
+
+
+class QLearningTarget:
+    def __init__(self, policy, gamma):
+        self._policy = policy
+        self._gamma = gamma
+
+    def __call__(self, raw_data):
+        data = dict()
+        observations = list()
+        actions = list()
+        rewards = list()
+        wi = list()
+        all_data, data_wi, data_index = raw_data
+
+        for i in range(0, all_data.shape[0]):
+            current_data = all_data[i]
+            current_wi = data_wi[i]
+            current_index = data_index[i]
+            observations.append(current_data['observation'])
+            actions.append(current_data['action'])
+            next_max_qvalue = np.max(self._policy.values(current_data['observation']))
+            current_qvalue = self._policy.values(current_data['previous_observation'])[current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            rewards.append(reward)
+            wi.append(current_wi)
+
+        data['observations'] = np.array(observations)
+        data['actions'] = np.array(actions)
+        data['rewards'] = np.array(rewards)
+
+        return data
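For reference, the quantity QLearningTarget (and the buffers' next_batch below) is assembling is the one-step Q-learning target r + gamma * max_a' Q_target(s', a'); a small NumPy illustration, with gamma written out explicitly even though the class above stores self._gamma without using it yet:

import numpy as np

gamma = 0.99
reward = 1.0
next_q_values = np.array([0.5, 2.0, 1.5])            # Q_target(s', a') for each action
td_target = reward + gamma * np.max(next_q_values)   # ~2.98
q_sa = 1.2                                           # current estimate Q(s, a)
td_error = td_target - q_sa                          # ~1.78, usable as a priority
print(td_target, td_error)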

View File

@@ -1,39 +1,51 @@
 class ReplayBuffer(object):
-    def __init__(self, conf):
-        '''
+    def __init__(self, env, policy, qnet, target_qnet, conf):
+        """
         Initialize a replay buffer with parameters in conf.
-        '''
+        """
         pass

     def add(self, data, priority):
-        '''
+        """
         Add a data with priority = priority to replay buffer.
-        '''
+        """
         pass

-    def update_priority(self, indices, priorities):
-        '''
+    def collect(self):
+        """
+        Collect data from current environment and policy.
+        """
+        pass
+
+    def next_batch(self, batch_size):
+        """
+        get batch of data from the replay buffer.
+        """
+        pass
+
+    def update_priority(self, indices, priorities):
+        """
         Update the data's priority whose indices = indices.
         For proportional replay buffer, the priority is the priority.
         For rank based replay buffer, the priorities parameter will be the delta used to update the priority.
-        '''
+        """
         pass

     def reset_alpha(self, alpha):
-        '''
+        """
         This function only works for proportional replay buffer.
         This function resets alpha.
-        '''
+        """
         pass

     def sample(self, conf):
-        '''
+        """
         Sample from replay buffer with parameters in conf.
-        '''
+        """
         pass

     def rebalance(self):
-        '''
+        """
         This is for rank based priority replay buffer, which is used to rebalance the sum tree of the priority queue.
-        '''
+        """
         pass

View File

@@ -1,29 +1,93 @@
-from buffer import ReplayBuffer
 import numpy as np
+import tensorflow as tf
 from collections import deque
+from math import fabs
+
+from tianshou.data.replay_buffer.buffer import ReplayBuffer

 class NaiveExperience(ReplayBuffer):
-    def __init__(self, conf):
+    def __init__(self, env, policy, qnet, target_qnet, conf):
         self.max_size = conf['size']
-        self.n_entries = 0
-        self.memory = deque(maxlen = self.max_size)
+        self._env = env
+        self._policy = policy
+        self._qnet = qnet
+        self._target_qnet = target_qnet
+        self._begin_act()
+        self.n_entries = 0
+        self.memory = deque(maxlen=self.max_size)

-    def add(self, data, priority = 0):
+    def add(self, data, priority=0):
         self.memory.append(data)
         if self.n_entries < self.max_size:
             self.n_entries += 1

-    def update_priority(self, indices, priorities = 0):
-        pass
+    def _begin_act(self):
+        self.observation = self._env.reset()
+        self.action = self._env.action_space.sample()
+        done = False
+        while not done:
+            if done:
+                self.observation = self._env.reset()
+                self.action = self._env.action_space.sample()
+            self.observation, _, done, _ = self._env.step(self.action)

-    def reset_alpha(self, alpha):
-        pass
+    def collect(self):
+        sess = tf.get_default_session()
+        current_data = dict()
+        current_data['previous_action'] = self.action
+        current_data['previous_observation'] = self.observation
+        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
+        self.observation, reward, done, _ = self._env.step(self.action)
+        current_data['action'] = self.action
+        current_data['observation'] = self.observation
+        current_data['reward'] = reward
+        self.add(current_data)
+        if done:
+            self._begin_act()

-    def sample(self, conf):
-        batch_size = conf['batch_size']
-        batch_size = min(len(self.memory), batch_size)
-        idxs = np.random.choice(len(self.memory), batch_size)
-        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
+    def update_priority(self, indices, priorities=0):
+        pass

-    def rebalance(self):
+    def reset_alpha(self, alpha):
         pass
+
+    def sample(self, conf):
+        batch_size = conf['batch_size']
+        batch_size = min(len(self.memory), batch_size)
+        idxs = np.random.choice(len(self.memory), batch_size)
+        return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
+
+    def next_batch(self, batch_size):
+        data = dict()
+        observations = list()
+        actions = list()
+        rewards = list()
+        wi = list()
+        target = list()
+
+        for i in range(0, batch_size):
+            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
+            current_data = current_datas[0]
+            current_wi = current_wis[0]
+            current_index = current_indexs[0]
+            observations.append(current_data['observation'])
+            actions.append(current_data['action'])
+            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
+            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            rewards.append(reward)
+            target.append(current_data['reward'] + next_max_qvalue)
+            self.update_priority(current_index, [fabs(reward)])
+            wi.append(current_wi)
+
+        data['observations'] = np.array(observations)
+        data['actions'] = np.array(actions)
+        data['rewards'] = np.array(rewards)
+        data['wi'] = np.array(wi)
+        data['target'] = np.array(target)
+
+        return data
+
+    def rebalance(self):
+        pass
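The naive buffer samples uniformly from a deque and gives every transition weight 1; the same mechanics in a self-contained sketch (not the repository's class):

import numpy as np
from collections import deque

memory = deque(maxlen=100)
for t in range(5):
    memory.append({'observation': t, 'action': t % 2, 'reward': float(t)})

batch_size = min(len(memory), 3)
idxs = np.random.choice(len(memory), batch_size)
batch = [memory[int(i)] for i in idxs]
weights = [1] * len(idxs)   # uniform replay: every sampled transition has weight 1
print(batch, weights, idxs)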

View File

@@ -1,7 +1,10 @@
-import numpy
+import numpy as np
 import random
-import sum_tree
-from buffer import ReplayBuffer
+import tensorflow as tf
+import math
+
+from tianshou.data.replay_buffer import sum_tree
+from tianshou.data.replay_buffer.buffer import ReplayBuffer

 class PropotionalExperience(ReplayBuffer):
@@ -15,7 +18,7 @@ class PropotionalExperience(ReplayBuffer):
     """

-    def __init__(self, conf):
+    def __init__(self, env, policy, qnet, target_qnet, conf):
         """ Prioritized experience replay buffer initialization.

         Parameters
@@ -30,11 +33,26 @@ class PropotionalExperience(ReplayBuffer):
         """
         memory_size = conf['size']
         batch_size = conf['batch_size']
-        alpha = conf['alpha']
+        alpha = conf['alpha'] if 'alpha' in conf else 0.6
         self.tree = sum_tree.SumTree(memory_size)
         self.memory_size = memory_size
         self.batch_size = batch_size
         self.alpha = alpha
+        self._env = env
+        self._policy = policy
+        self._qnet = qnet
+        self._target_qnet = target_qnet
+        self._begin_act()
+
+    def _begin_act(self):
+        self.observation = self._env.reset()
+        self.action = self._env.action_space.sample()
+        done = False
+        while not done:
+            if done:
+                self.observation = self._env.reset()
+                self.action = self._env.action_space.sample()
+            self.observation, _, done, _ = self._env.step(self.action)

     def add(self, data, priority):
         """ Add new sample.
@@ -48,6 +66,12 @@ class PropotionalExperience(ReplayBuffer):
         """
         self.tree.add(data, priority**self.alpha)

+    def collect(self):
+        pass
+
+    def next_batch(self, batch_size):
+        pass
+
     def sample(self, conf):
         """ The method return samples randomly.
@@ -64,8 +88,9 @@ class PropotionalExperience(ReplayBuffer):
         indices:
             list of sample indices
             The indices indicate sample positions in a sum tree.
+        :param conf: giving beta
         """
-        beta = conf['beta']
+        beta = conf['beta'] if 'beta' in conf else 0.4
         if self.tree.filled_size() < self.batch_size:
             return None, None, None
@@ -91,6 +116,54 @@ class PropotionalExperience(ReplayBuffer):
         return out, weights, indices

+    def collect(self):
+        sess = tf.get_default_session()
+        current_data = dict()
+        current_data['previous_action'] = self.action
+        current_data['previous_observation'] = self.observation
+        # TODO: change the name of the feed_dict
+        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
+        self.observation, reward, done, _ = self._env.step(self.action)
+        current_data['action'] = self.action
+        current_data['observation'] = self.observation
+        current_data['reward'] = reward
+        priorities = np.array([self.tree.get_val(i) ** -self.alpha for i in range(self.tree.filled_size())])
+        priority = np.max(priorities) if len(priorities) > 0 else 1
+        self.add(current_data, priority)
+        if done:
+            self._begin_act()
+
+    def next_batch(self, batch_size):
+        data = dict()
+        observations = list()
+        actions = list()
+        rewards = list()
+        wi = list()
+        target = list()
+
+        for i in range(0, batch_size):
+            current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
+            current_data = current_datas[0]
+            current_wi = current_wis[0]
+            current_index = current_indexs[0]
+            observations.append(current_data['observation'])
+            actions.append(current_data['action'])
+            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
+            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            rewards.append(reward)
+            target.append(current_data['reward'] + next_max_qvalue)
+            self.update_priority([current_index], [math.fabs(reward)])
+            wi.append(current_wi)
+
+        data['observations'] = np.array(observations)
+        data['actions'] = np.array(actions)
+        data['rewards'] = np.array(rewards)
+        data['wi'] = np.array(wi)
+        data['target'] = np.array(target)
+
+        return data
+
     def update_priority(self, indices, priorities):
         """ The methods update samples's priority.

View File

@@ -8,13 +8,15 @@ import sys
 import math
 import random
 import numpy as np
+import tensorflow as tf

-from binary_heap import BinaryHeap
-from buffer import ReplayBuffer
+from tianshou.data.replay_buffer.binary_heap import BinaryHeap
+from tianshou.data.replay_buffer.buffer import ReplayBuffer

 class RankBasedExperience(ReplayBuffer):

-    def __init__(self, conf):
+    def __init__(self, env, policy, qnet, target_qnet, conf):
         self.size = conf['size']
         self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
         self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
@@ -25,12 +27,18 @@ class RankBasedExperience(ReplayBuffer):
         self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
         self.total_steps = conf['steps'] if 'steps' in conf else 100000
         # partition number N, split total size to N part
-        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100
+        self.partition_num = conf['partition_num'] if 'partition_num' in conf else 10

         self.index = 0
         self.record_size = 0
         self.isFull = False

+        self._env = env
+        self._policy = policy
+        self._qnet = qnet
+        self._target_qnet = target_qnet
+        self._begin_act()
+
         self._experience = {}
         self.priority_queue = BinaryHeap(self.priority_size)
         self.distributions = self.build_distributions()
@@ -98,7 +106,64 @@ class RankBasedExperience(ReplayBuffer):
         self.index += 1
         return self.index

-    def add(self, data, priority = 0):
+    def _begin_act(self):
+        self.observation = self._env.reset()
+        self.action = self._env.action_space.sample()
+        done = False
+        while not done:
+            if done:
+                self.observation = self._env.reset()
+                self.action = self._env.action_space.sample()
+            self.observation, _, done, _ = self._env.step(self.action)
+
+    def collect(self):
+        sess = tf.get_default_session()
+        current_data = dict()
+        current_data['previous_action'] = self.action
+        current_data['previous_observation'] = self.observation
+        self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
+        self.observation, reward, done, _ = self._env.step(self.action)
+        current_data['action'] = self.action
+        current_data['observation'] = self.observation
+        current_data['reward'] = reward
+        self.add(current_data)
+        if done:
+            self._begin_act()
+
+    def next_batch(self, batch_size):
+        data = dict()
+        observations = list()
+        actions = list()
+        rewards = list()
+        wi = list()
+        target = list()
+
+        sess = tf.get_default_session()
+        current_datas, current_wis, current_indexs = self.sample({'global_step': sess.run(tf.train.get_global_step())})
+
+        for i in range(0, batch_size):
+            current_data = current_datas[i]
+            current_wi = current_wis[i]
+            current_index = current_indexs[i]
+            observations.append(current_data['observation'])
+            actions.append(current_data['action'])
+            next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
+            current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
+            reward = current_data['reward'] + next_max_qvalue - current_qvalue
+            rewards.append(reward)
+            target.append(current_data['reward'] + next_max_qvalue)
+            self.update_priority([current_index], [math.fabs(reward)])
+            wi.append(current_wi)
+
+        data['observations'] = np.array(observations)
+        data['actions'] = np.array(actions)
+        data['rewards'] = np.array(rewards)
+        data['wi'] = np.array(wi)
+        data['target'] = np.array(target)
+
+        return data
+
+    def add(self, data, priority = 1):
         """
         store experience, suggest that experience is a tuple of (s1, a, r, s2, t)
         so each experience is valid
@@ -156,16 +221,16 @@ class RankBasedExperience(ReplayBuffer):
             sys.stderr.write('Record size less than learn start! Sample failed\n')
             return False, False, False

-        dist_index = math.floor(self.record_size / self.size * self.partition_num)
+        dist_index = math.floor(self.record_size * 1. / self.size * self.partition_num)
         # issue 1 by @camigord
-        partition_size = math.floor(self.size / self.partition_num)
+        partition_size = math.floor(self.size * 1. / self.partition_num)
         partition_max = dist_index * partition_size
         distribution = self.distributions[dist_index]
         rank_list = []
         # sample from k segments
         for n in range(1, self.batch_size + 1):
-            index = random.randint(distribution['strata_ends'][n] + 1,
+            index = random.randint(distribution['strata_ends'][n],
                                    distribution['strata_ends'][n + 1])
             rank_list.append(index)

         # beta, increase by global_step, max 1
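For reference, rank-based prioritization gives sample i a probability proportional to 1/rank(i)**alpha and draws one sample from each of batch_size segments of roughly equal probability mass (the strata_ends used above); a small NumPy sketch of that distribution:

import numpy as np

alpha, n, batch_size = 0.7, 8, 4
ranks = np.arange(1, n + 1)                    # rank 1 = largest TD error
pdf = (1.0 / ranks) ** alpha
pdf /= pdf.sum()                               # P(i) over the sorted replay entries
cdf = np.cumsum(pdf)
# segment boundaries: one stratum per sample in the batch
boundaries = np.searchsorted(cdf, np.linspace(0, 1, batch_size, endpoint=False)[1:])
print(pdf, boundaries)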

View File

@@ -1,13 +1,15 @@
-from utils import *
 from functions import *
+from tianshou.data.replay_buffer.utils import get_replay_buffer

 def test_rank_based():
     conf = {'size': 50,
             'learn_start': 10,
             'partition_num': 5,
             'total_step': 100,
             'batch_size': 4}
-    experience = getReplayBuffer('rank_based', conf)
+    experience = get_replay_buffer('rank_based', conf)

     # insert to experience
     print 'test insert experience'
@@ -52,7 +54,7 @@ def test_proportional():
     conf = {'size': 50,
             'alpha': 0.7,
             'batch_size': 4}
-    experience = getReplayBuffer('proportional', conf)
+    experience = get_replay_buffer('proportional', conf)

     # insert to experience
     print 'test insert experience'
@@ -90,7 +92,7 @@ def test_proportional():
 def test_naive():
     conf = {'size': 50}
-    experience = getReplayBuffer('naive', conf)
+    experience = get_replay_buffer('naive', conf)

     # insert to experience
     print 'test insert experience'

View File

@@ -1,17 +1,20 @@
-from rank_based import *
-from proportional import *
-from naive import *
 import sys

-def getReplayBuffer(name, conf):
-    '''
-    Get replay buffer according to the given name.
-    '''
-    if (name == 'rank_based'):
-        return RankBasedExperience(conf)
-    elif (name == 'proportional'):
-        return PropotionalExperience(conf)
-    elif (name == 'naive'):
-        return NaiveExperience(conf)
-    else:
-        sys.stderr.write('no such replay buffer')
+from tianshou.data.replay_buffer.naive import NaiveExperience
+from tianshou.data.replay_buffer.proportional import PropotionalExperience
+from tianshou.data.replay_buffer.rank_based import RankBasedExperience
+
+
+def get_replay_buffer(name, env, policy, qnet, target_qnet, conf):
+    """
+    Get replay buffer according to the given name.
+    """
+    if name == 'rank_based':
+        return RankBasedExperience(env, policy, qnet, target_qnet, conf)
+    elif name == 'proportional':
+        return PropotionalExperience(env, policy, qnet, target_qnet, conf)
+    elif name == 'naive':
+        return NaiveExperience(env, policy, qnet, target_qnet, conf)
+    else:
+        sys.stderr.write('no such replay buffer')
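A quick smoke test of the new factory signature, assuming gym is installed; it relies on the naive buffer only touching env at construction time (policy, qnet and target_qnet are merely stored):

import gym
from tianshou.data.replay_buffer.utils import get_replay_buffer

env = gym.make('CartPole-v0')
buf = get_replay_buffer('naive', env, None, None, None,
                        {'size': 100, 'batch_size': 16})
print(type(buf).__name__)   # NaiveExperience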