Finished a very naive DQN: changed the replay buffer interface by adding collect and next_batch (still needs refactoring); added an implementation of dqn.py (its interface still needs work to be more extensible); slightly refactored the code style across the codebase. More comments and TODOs will follow in the next commit.

This commit is contained in:
宋世虹 2017-12-17 12:52:00 +08:00
parent e10acf5130
commit 62e2c6582d
12 changed files with 411 additions and 104 deletions
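At a glance, the data flow introduced by this commit, as exercised in the example file below: build a buffer with get_replay_buffer, call collect() to step the environment with the current Q-network and store transitions, then call next_batch() to draw a minibatch whose 'target' entries already hold r + max_a' Q_target(s', a'). A minimal sketch of the resulting training loop (not part of the diff), reusing names defined in the example file below (env, q_values, q_net, target_net, observation, action, target, train_op, sess):

replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
                                  {'size': 1000, 'batch_size': 64, 'learn_start': 20})
while True:
    for _ in range(100):                     # collect transitions with the current policy
        replay_memory.collect()
    data = replay_memory.next_batch(10)      # minibatch with precomputed DQN targets
    sess.run(train_op, feed_dict={observation: data['observations'],
                                  action: data['actions'],
                                  target: data['target']})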

View File

@@ -9,8 +9,7 @@ import gym
import sys
sys.path.append('..')
import tianshou.core.losses as losses
from tianshou.data.replay import Replay
import tianshou.data.advantage_estimation as advantage_estimation
from tianshou.data.replay_buffer.utils import get_replay_buffer
import tianshou.core.policy as policy
@@ -38,11 +37,10 @@ if __name__ == '__main__':
action_dim = env.action_space.n
# 1. build network with pure tf
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
with tf.variable_scope('q_net'):
q_values = policy_net(observation, action_dim)
train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
with tf.variable_scope('target_net'):
q_values_target = policy_net(observation, action_dim)
@@ -54,13 +52,15 @@ if __name__ == '__main__':
target = tf.placeholder(dtype=tf.float32, shape=[None]) # target value for DQN
dqn_loss = losses.dqn_loss(action, target, q_net) # TongzhengRen
global_step = tf.Variable(0, name='global_step', trainable=False)
train_var_list = tf.get_collection(
tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
total_loss = dqn_loss
optimizer = tf.train.AdamOptimizer(1e-3)
train_op = optimizer.minimize(total_loss, var_list=train_var_list)
train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
# 3. define data collection
training_data = Replay(env, q_net, advantage_estimation.qlearning_target(target_net)) #
replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
{'size': 1000, 'batch_size': 64, 'learn_start': 20})
# ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
# maybe a dict to manage the elements to be collected
@@ -70,14 +70,16 @@ if __name__ == '__main__':
minibatch_count = 0
collection_count = 0
collect_freq = 100
while True: # until some stopping criterion met...
# collect data
training_data.collect() # ShihongSong
collection_count += 1
print('Collected {} times.'.format(collection_count))
for i in range(0, collect_freq):
replay_memory.collect() # ShihongSong
collection_count += 1
print('Collected {} times.'.format(collection_count))
# update network
data = training_data.next_batch(64) # YouQiaoben, ShihongSong
data = replay_memory.next_batch(10) # YouQiaoben, ShihongSong
# TODO: auto managing of the placeholders? or add this to params of data.Batch
sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], target: data['target']})
minibatch_count += 1

View File

@@ -32,7 +32,7 @@ def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
"""
log_pi_act = pi.log_prob(sampled_action)
vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act)
# TODO Different baseline methods like REINFORCE, etc.
# TODO: Different baseline methods like REINFORCE, etc.
return vanilla_policy_gradient_loss
def dqn_loss(sampled_action, sampled_target, q_net):
@@ -44,8 +44,8 @@ def dqn_loss(sampled_action, sampled_target, q_net):
:param q_net: current `policy` to be optimized
:return:
"""
action_num = q_net.get_values().shape()[1]
sampled_q = tf.reduce_sum(q_net.get_values() * tf.one_hot(sampled_action, action_num), axis=1)
action_num = q_net.values_tensor().get_shape()[1]
sampled_q = tf.reduce_sum(q_net.values_tensor() * tf.one_hot(sampled_action, action_num), axis=1)
return tf.reduce_mean(tf.square(sampled_target - sampled_q))
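As an aside, a small NumPy check (not part of the commit) of what dqn_loss computes: the mean squared error between the supplied targets and the Q-values of the sampled actions.

import numpy as np
q_values = np.array([[1.0, 2.0],        # Q(s, .) for a batch of two states, two actions
                     [0.5, 0.0]])
sampled_action = np.array([1, 0])       # actions actually taken
sampled_target = np.array([2.5, 1.0])   # e.g. r + max_a' Q_target(s', a')
sampled_q = q_values[np.arange(2), sampled_action]   # [2.0, 0.5], same as the one_hot/reduce_sum above
loss = np.mean((sampled_target - sampled_q) ** 2)    # = 0.25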
def deterministic_policy_gradient(sampled_state, critic):

View File

@@ -3,3 +3,4 @@
from .base import *
from .stochastic import *
from .dqn import *

View File

@@ -12,23 +12,28 @@ import tensorflow as tf
__all__ = [
'StochasticPolicy',
'QValuePolicy',
]
#TODO: separate actor and critic, we should focus on it once we finish the basic module.
# TODO: separate actor and critic, we should focus on it once we finish the basic module.
class QValuePolicy(object):
"""
The policy as in DQN
"""
def __init__(self, observation_placeholder):
self.observation_placeholder = observation_placeholder
self._observation_placeholder = observation_placeholder
def act(self, observation, exploration=None): # first implement no exploration
"""
return the action (int) to be executed.
no exploration when exploration=None.
"""
pass
return self._act(observation, exploration)
def _act(self, observation, exploration=None):
raise NotImplementedError()
def values(self, observation):
"""
@@ -36,7 +41,7 @@ class QValuePolicy(object):
"""
pass
def values_tensor(self, observation):
def values_tensor(self):
"""
returns the tensor of the values for all actions a at observation s
"""

View File

@@ -1,7 +1,54 @@
from .base import QValuePolicy
from tianshou.core.policy.base import QValuePolicy
import tensorflow as tf
class DQN(QValuePolicy):
pass
"""
The policy as in DQN
"""
def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
self._logits = tf.convert_to_tensor(logits)
if dtype is None:
dtype = tf.int32
self._n_categories = self._logits.get_shape()[-1].value
super(DQN, self).__init__(observation_placeholder)
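# Q-network on top of the observation placeholder (assumes image observations, NHWC):
# conv 16@8x8 stride 4 -> conv 32@4x4 stride 2 -> flatten -> dense 256 (ReLU) -> linear layer with one Q-value per action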
net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
net = tf.layers.flatten(net)
net = tf.layers.dense(net, 256, activation=tf.nn.relu, use_bias=True)
self._value = tf.layers.dense(net, self._n_categories)
def _act(self, observation, exploration=None): # first implement no exploration
"""
return the action (int) to be executed.
no exploration when exploration=None.
"""
sess = tf.get_default_session()
sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
feed_dict={self._observation_placeholder: observation[None]})
return sampled_action
@property
def logits(self):
return self._logits
@property
def n_categories(self):
return self._n_categories
def values(self, observation):
"""
returns the Q(s, a) values (float) for all actions a at observation s
"""
sess = tf.get_default_session()
value = sess.run(self._value, feed_dict={self._observation_placeholder: observation[None]})
return value
def values_tensor(self):
"""
returns the tensor of the values for all actions a at observation s
"""
return self._value

View File

@@ -19,7 +19,8 @@ def full_return(raw_data):
returns = rewards.copy()
episode_start_idx = 0
for i in range(1, num_timesteps):
if episode_start_flags[i] or (i == num_timesteps - 1): # found the start of next episode or the end of all episodes
if episode_start_flags[i] or (
i == num_timesteps - 1): # found the start of next episode or the end of all episodes
if i < rewards.shape[0] - 1:
t = i - 1
else:
@@ -35,3 +36,35 @@ def full_return(raw_data):
data['returns'] = returns
return data
class QLearningTarget:
def __init__(self, policy, gamma):
self._policy = policy
self._gamma = gamma
def __call__(self, raw_data):
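# raw_data is expected to be the (samples, importance_weights, indices) triple returned by a
# prioritized replay buffer's sample(); the 'rewards' collected below are undiscounted TD errors
# r + max_a Q(s', a) - Q(s, a) computed with self._policy (self._gamma is not applied yet)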
data = dict()
observations = list()
actions = list()
rewards = list()
wi = list()
all_data, data_wi, data_index = raw_data
for i in range(0, all_data.shape[0]):
current_data = all_data[i]
current_wi = data_wi[i]
current_index = data_index[i]
observations.append(current_data['observation'])
actions.append(current_data['action'])
next_max_qvalue = np.max(self._policy.values(current_data['observation']))
current_qvalue = self._policy.values(current_data['previous_observation'])[current_data['previous_action']]
reward = current_data['reward'] + next_max_qvalue - current_qvalue
rewards.append(reward)
wi.append(current_wi)
data['observations'] = np.array(observations)
data['actions'] = np.array(actions)
data['rewards'] = np.array(rewards)
return data

View File

@@ -1,39 +1,51 @@
class ReplayBuffer(object):
def __init__(self, conf):
'''
def __init__(self, env, policy, qnet, target_qnet, conf):
"""
Initialize a replay buffer with parameters in conf.
'''
pass
"""
pass
def add(self, data, priority):
'''
def add(self, data, priority):
"""
Add a data with priority = priority to replay buffer.
'''
pass
"""
pass
def update_priority(self, indices, priorities):
'''
def collect(self):
"""
Collect data from current environment and policy.
"""
pass
def next_batch(self, batch_size):
"""
get batch of data from the replay buffer.
"""
pass
def update_priority(self, indices, priorities):
"""
Update the data's priority whose indices = indices.
For proportional replay buffer, the priority is the priority.
For rank based replay buffer, the priorities parameter will be the delta used to update the priority.
'''
pass
"""
pass
def reset_alpha(self, alpha):
'''
def reset_alpha(self, alpha):
"""
This function only works for proportional replay buffer.
This function resets alpha.
'''
pass
"""
pass
def sample(self, conf):
'''
def sample(self, conf):
"""
Sample from replay buffer with parameters in conf.
'''
pass
"""
pass
def rebalance(self):
'''
def rebalance(self):
"""
This is for rank based priority replay buffer, which is used to rebalance the sum tree of the priority queue.
'''
pass
"""
pass

View File

@@ -1,29 +1,93 @@
from buffer import ReplayBuffer
import numpy as np
import tensorflow as tf
from collections import deque
from math import fabs
from tianshou.data.replay_buffer.buffer import ReplayBuffer
class NaiveExperience(ReplayBuffer):
def __init__(self, conf):
self.max_size = conf['size']
self.n_entries = 0
self.memory = deque(maxlen = self.max_size)
def __init__(self, env, policy, qnet, target_qnet, conf):
self.max_size = conf['size']
self._env = env
self._policy = policy
self._qnet = qnet
self._target_qnet = target_qnet
self._begin_act()
self.n_entries = 0
self.memory = deque(maxlen=self.max_size)
def add(self, data, priority = 0):
self.memory.append(data)
if self.n_entries < self.max_size:
self.n_entries += 1
def add(self, data, priority=0):
self.memory.append(data)
if self.n_entries < self.max_size:
self.n_entries += 1
def update_priority(self, indices, priorities = 0):
pass
def _begin_act(self):
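# warm-up: repeat one randomly sampled action until an episode terminates, so that
# self.observation / self.action hold valid values before collect() is first called
# (the inner `if done` branch can never fire inside `while not done`)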
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
done = False
while not done:
if done:
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
self.observation, _, done, _ = self._env.step(self.action)
def reset_alpha(self, alpha):
pass
def collect(self):
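# one environment step: evaluate the policy tensor (Q-values) on the current observation via the
# placeholder named 'dqn_observation:0', act greedily with argmax, and store the transition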
sess = tf.get_default_session()
current_data = dict()
current_data['previous_action'] = self.action
current_data['previous_observation'] = self.observation
self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
self.observation, reward, done, _ = self._env.step(self.action)
current_data['action'] = self.action
current_data['observation'] = self.observation
current_data['reward'] = reward
self.add(current_data)
if done:
self._begin_act()
def sample(self, conf):
batch_size = conf['batch_size']
batch_size = min(len(self.memory), batch_size)
idxs = np.random.choice(len(self.memory), batch_size)
return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
def update_priority(self, indices, priorities=0):
pass
def rebalance(self):
pass
def reset_alpha(self, alpha):
pass
def sample(self, conf):
batch_size = conf['batch_size']
batch_size = min(len(self.memory), batch_size)
idxs = np.random.choice(len(self.memory), batch_size)
return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
def next_batch(self, batch_size):
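# assemble the batch one sample at a time: 'target' entries are r + max_a' Q_target(s', a') for the
# DQN loss, while 'rewards' entries are the corresponding TD errors whose absolute values are passed
# to update_priority (a no-op for this naive buffer)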
data = dict()
observations = list()
actions = list()
rewards = list()
wi = list()
target = list()
for i in range(0, batch_size):
current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
current_data = current_datas[0]
current_wi = current_wis[0]
current_index = current_indexs[0]
observations.append(current_data['observation'])
actions.append(current_data['action'])
next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
reward = current_data['reward'] + next_max_qvalue - current_qvalue
rewards.append(reward)
target.append(current_data['reward'] + next_max_qvalue)
self.update_priority(current_index, [fabs(reward)])
wi.append(current_wi)
data['observations'] = np.array(observations)
data['actions'] = np.array(actions)
data['rewards'] = np.array(rewards)
data['wi'] = np.array(wi)
data['target'] = np.array(target)
return data
def rebalance(self):
pass

View File

@ -1,7 +1,10 @@
import numpy
import numpy as np
import random
import sum_tree
from buffer import ReplayBuffer
import tensorflow as tf
import math
from tianshou.data.replay_buffer import sum_tree
from tianshou.data.replay_buffer.buffer import ReplayBuffer
class PropotionalExperience(ReplayBuffer):
@@ -15,7 +18,7 @@ class PropotionalExperience(ReplayBuffer):
"""
def __init__(self, conf):
def __init__(self, env, policy, qnet, target_qnet, conf):
""" Prioritized experience replay buffer initialization.
Parameters
@@ -30,11 +33,26 @@ class PropotionalExperience(ReplayBuffer):
"""
memory_size = conf['size']
batch_size = conf['batch_size']
alpha = conf['alpha']
alpha = conf['alpha'] if 'alpha' in conf else 0.6
self.tree = sum_tree.SumTree(memory_size)
self.memory_size = memory_size
self.batch_size = batch_size
self.alpha = alpha
self._env = env
self._policy = policy
self._qnet = qnet
self._target_qnet = target_qnet
self._begin_act()
def _begin_act(self):
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
done = False
while not done:
if done:
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
self.observation, _, done, _ = self._env.step(self.action)
def add(self, data, priority):
""" Add new sample.
@@ -48,6 +66,12 @@ class PropotionalExperience(ReplayBuffer):
"""
self.tree.add(data, priority**self.alpha)
def collect(self):
pass
def next_batch(self, batch_size):
pass
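# note: these placeholder stubs appear to be superseded by the full collect() / next_batch()
# implementations added further down in this file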
def sample(self, conf):
""" The method return samples randomly.
@@ -64,8 +88,9 @@ class PropotionalExperience(ReplayBuffer):
indices:
list of sample indices
The indices indicate sample positions in a sum tree.
:param conf: giving beta
"""
beta = conf['beta']
beta = conf['beta'] if 'beta' in conf else 0.4
if self.tree.filled_size() < self.batch_size:
return None, None, None
@@ -91,6 +116,54 @@ class PropotionalExperience(ReplayBuffer):
return out, weights, indices
def collect(self):
sess = tf.get_default_session()
current_data = dict()
current_data['previous_action'] = self.action
current_data['previous_observation'] = self.observation
# TODO: change the name of the feed_dict
self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
self.observation, reward, done, _ = self._env.step(self.action)
current_data['action'] = self.action
current_data['observation'] = self.observation
current_data['reward'] = reward
priorities = np.array([self.tree.get_val(i) ** -self.alpha for i in range(self.tree.filled_size())])
priority = np.max(priorities) if len(priorities) > 0 else 1
self.add(current_data, priority)
if done:
self._begin_act()
def next_batch(self, batch_size):
data = dict()
observations = list()
actions = list()
rewards = list()
wi = list()
target = list()
for i in range(0, batch_size):
current_datas, current_wis, current_indexs = self.sample({'batch_size': 1})
current_data = current_datas[0]
current_wi = current_wis[0]
current_index = current_indexs[0]
observations.append(current_data['observation'])
actions.append(current_data['action'])
next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
reward = current_data['reward'] + next_max_qvalue - current_qvalue
rewards.append(reward)
target.append(current_data['reward'] + next_max_qvalue)
self.update_priority([current_index], [math.fabs(reward)])
wi.append(current_wi)
data['observations'] = np.array(observations)
data['actions'] = np.array(actions)
data['rewards'] = np.array(rewards)
data['wi'] = np.array(wi)
data['target'] = np.array(target)
return data
def update_priority(self, indices, priorities):
""" The methods update samples's priority.

View File

@@ -8,13 +8,15 @@ import sys
import math
import random
import numpy as np
import tensorflow as tf
from tianshou.data.replay_buffer.binary_heap import BinaryHeap
from tianshou.data.replay_buffer.buffer import ReplayBuffer
from binary_heap import BinaryHeap
from buffer import ReplayBuffer
class RankBasedExperience(ReplayBuffer):
def __init__(self, conf):
def __init__(self, env, policy, qnet, target_qnet, conf):
self.size = conf['size']
self.replace_flag = conf['replace_old'] if 'replace_old' in conf else True
self.priority_size = conf['priority_size'] if 'priority_size' in conf else self.size
@@ -25,12 +27,18 @@ class RankBasedExperience(ReplayBuffer):
self.learn_start = conf['learn_start'] if 'learn_start' in conf else 1000
self.total_steps = conf['steps'] if 'steps' in conf else 100000
# partition number N, split total size to N part
self.partition_num = conf['partition_num'] if 'partition_num' in conf else 100
self.partition_num = conf['partition_num'] if 'partition_num' in conf else 10
self.index = 0
self.record_size = 0
self.isFull = False
self._env = env
self._policy = policy
self._qnet = qnet
self._target_qnet = target_qnet
self._begin_act()
self._experience = {}
self.priority_queue = BinaryHeap(self.priority_size)
self.distributions = self.build_distributions()
@@ -98,7 +106,64 @@ class RankBasedExperience(ReplayBuffer):
self.index += 1
return self.index
def add(self, data, priority = 0):
def _begin_act(self):
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
done = False
while not done:
if done:
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
self.observation, _, done, _ = self._env.step(self.action)
def collect(self):
sess = tf.get_default_session()
current_data = dict()
current_data['previous_action'] = self.action
current_data['previous_observation'] = self.observation
self.action = np.argmax(sess.run(self._policy, feed_dict={"dqn_observation:0": self.observation.reshape((1,) + self.observation.shape)}))
self.observation, reward, done, _ = self._env.step(self.action)
current_data['action'] = self.action
current_data['observation'] = self.observation
current_data['reward'] = reward
self.add(current_data)
if done:
self._begin_act()
def next_batch(self, batch_size):
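# unlike the naive/proportional buffers, the whole batch is drawn by a single sample() call; the
# current TF global step is passed in so the importance-sampling beta can be annealed
# (see the 'beta, increase by global_step' comment below)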
data = dict()
observations = list()
actions = list()
rewards = list()
wi = list()
target = list()
sess = tf.get_default_session()
current_datas, current_wis, current_indexs = self.sample({'global_step': sess.run(tf.train.get_global_step())})
for i in range(0, batch_size):
current_data = current_datas[i]
current_wi = current_wis[i]
current_index = current_indexs[i]
observations.append(current_data['observation'])
actions.append(current_data['action'])
next_max_qvalue = np.max(self._target_qnet.values(current_data['observation']))
current_qvalue = self._qnet.values(current_data['previous_observation'])[0, current_data['previous_action']]
reward = current_data['reward'] + next_max_qvalue - current_qvalue
rewards.append(reward)
target.append(current_data['reward'] + next_max_qvalue)
self.update_priority([current_index], [math.fabs(reward)])
wi.append(current_wi)
data['observations'] = np.array(observations)
data['actions'] = np.array(actions)
data['rewards'] = np.array(rewards)
data['wi'] = np.array(wi)
data['target'] = np.array(target)
return data
def add(self, data, priority = 1):
"""
store experience, suggest that experience is a tuple of (s1, a, r, s2, t)
so each experience is valid
@@ -156,16 +221,16 @@ class RankBasedExperience(ReplayBuffer):
sys.stderr.write('Record size less than learn start! Sample failed\n')
return False, False, False
dist_index = math.floor(self.record_size / self.size * self.partition_num)
dist_index = math.floor(self.record_size * 1. / self.size * self.partition_num)
# issue 1 by @camigord
partition_size = math.floor(self.size / self.partition_num)
partition_size = math.floor(self.size * 1. / self.partition_num)
partition_max = dist_index * partition_size
distribution = self.distributions[dist_index]
rank_list = []
# sample from k segments
for n in range(1, self.batch_size + 1):
index = random.randint(distribution['strata_ends'][n] + 1,
distribution['strata_ends'][n + 1])
index = random.randint(distribution['strata_ends'][n],
distribution['strata_ends'][n + 1])
rank_list.append(index)
# beta, increase by global_step, max 1

View File

@@ -1,13 +1,15 @@
from utils import *
from functions import *
from tianshou.data.replay_buffer.utils import get_replay_buffer
def test_rank_based():
conf = {'size': 50,
'learn_start': 10,
'partition_num': 5,
'total_step': 100,
'batch_size': 4}
experience = getReplayBuffer('rank_based', conf)
experience = get_replay_buffer('rank_based', conf)
# insert to experience
print 'test insert experience'
@@ -52,7 +54,7 @@ def test_proportional():
conf = {'size': 50,
'alpha': 0.7,
'batch_size': 4}
experience = getReplayBuffer('proportional', conf)
experience = get_replay_buffer('proportional', conf)
# insert to experience
print 'test insert experience'
@@ -90,7 +92,7 @@ def test_proportional():
def test_naive():
conf = {'size': 50}
experience = getReplayBuffer('naive', conf)
experience = get_replay_buffer('naive', conf)
# insert to experience
print 'test insert experience'

View File

@@ -1,17 +1,20 @@
from rank_based import *
from proportional import *
from naive import *
import sys
def getReplayBuffer(name, conf):
'''
Get replay buffer according to the given name.
'''
if (name == 'rank_based'):
return RankBasedExperience(conf)
elif (name == 'proportional'):
return PropotionalExperience(conf)
elif (name == 'naive'):
return NaiveExperience(conf)
else:
sys.stderr.write('no such replay buffer')
from tianshou.data.replay_buffer.naive import NaiveExperience
from tianshou.data.replay_buffer.proportional import PropotionalExperience
from tianshou.data.replay_buffer.rank_based import RankBasedExperience
def get_replay_buffer(name, env, policy, qnet, target_qnet, conf):
"""
Get replay buffer according to the given name.
"""
if name == 'rank_based':
return RankBasedExperience(env, policy, qnet, target_qnet, conf)
elif name == 'proportional':
return PropotionalExperience(env, policy, qnet, target_qnet, conf)
elif name == 'naive':
return NaiveExperience(env, policy, qnet, target_qnet, conf)
else:
sys.stderr.write('no such replay buffer')