add comments and todos

This commit is contained in:
宋世虹 2017-12-17 13:28:21 +08:00
parent 62e2c6582d
commit 7693c38f44
5 changed files with 59 additions and 6 deletions


@@ -37,6 +37,9 @@ if __name__ == '__main__':
action_dim = env.action_space.n
# 1. build network with pure tf
# TODO:
# pass the observation variable to the replay buffer, or find a more reasonable way for the
# replay buffer to access this observation variable.
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
with tf.variable_scope('q_net'):
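One way the TODO above could be resolved, sketched under the assumption of a hypothetical ReplayBufferSketch class whose constructor and q_values_for method are purely illustrative (not the project's actual API), is to hand the placeholder to the buffer at construction time:

import tensorflow as tf


class ReplayBufferSketch(object):
    """Hypothetical buffer that receives the shared observation placeholder at construction."""

    def __init__(self, env, observation_placeholder, q_values):
        self._env = env
        self._observation_placeholder = observation_placeholder  # same tensor q_net was built on
        self._q_values = q_values

    def q_values_for(self, observation):
        # evaluate the network's q values by feeding the stored placeholder directly
        sess = tf.get_default_session()
        return sess.run(self._q_values,
                        feed_dict={self._observation_placeholder: observation[None]})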
@@ -59,6 +62,7 @@ if __name__ == '__main__':
optimizer = tf.train.AdamOptimizer(1e-3)
train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
# 3. define data collection
# the configuration should be passed in as parameters; different replay buffers take different parameters.
replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
{'size': 1000, 'batch_size': 64, 'learn_start': 20})
# ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)); use your ReplayMemory and interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN.
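A minimal sketch of the kind of factory the configuration comment suggests; NaiveExperience, PropotionalExperience and RankBasedExperience are the classes touched by this commit, but the constructor signatures used here are assumptions:

def get_replay_buffer_sketch(name, env, q_values, q_net, target_net, conf):
    """Illustrative factory: each buffer type reads only the keys it needs from conf.

    The constructor signatures below are assumptions, not the project's real API.
    """
    defaults = {'size': 1000, 'batch_size': 64, 'learn_start': 20}
    conf = dict(defaults, **conf)  # fall back to defaults for missing keys
    if name == 'naive':
        return NaiveExperience(env, q_values, q_net, target_net, conf)
    if name == 'proportional':
        return PropotionalExperience(env, q_values, q_net, target_net, conf)
    if name == 'rank_based':
        return RankBasedExperience(env, q_values, q_net, target_net, conf)
    raise ValueError('unknown replay buffer: {}'.format(name))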
@@ -70,6 +74,7 @@ if __name__ == '__main__':
minibatch_count = 0
collection_count = 0
# need to collect some data first and then sample; collect_freq must be larger than batch_size
collect_freq = 100
while True: # until some stopping criterion met...
# collect data
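For reference, the collect-then-sample loop described by the comments above might look like the following, assuming the replay_memory, minibatch_count and collection_count variables defined earlier in this script; the training step itself is left as a placeholder:

batch_size = 64
collect_freq = 100        # must exceed batch_size so sampling never underflows the buffer
while True:               # until some stopping criterion is met
    for _ in range(collect_freq):
        replay_memory.collect()                  # store one transition in the buffer
        collection_count += 1
    data = replay_memory.next_batch(batch_size)  # sample a batch and build the targets
    minibatch_count += 1
    # the sampled `data` would then be fed into train_op via its placeholders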


@@ -8,6 +8,7 @@ class DQN(QValuePolicy):
"""
def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
# TODO: this version only supports discrete (non-continuous) action spaces; extend it to support continuous action spaces
self._logits = tf.convert_to_tensor(logits)
if dtype is None:
dtype = tf.int32
@@ -15,6 +16,7 @@ class DQN(QValuePolicy):
super(DQN, self).__init__(observation_placeholder)
# TODO: put the net definition outside of the class
net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
net = tf.layers.flatten(net)
@@ -26,6 +28,7 @@ class DQN(QValuePolicy):
Return the action (int) to be executed.
No exploration is applied when exploration=None.
"""
# TODO: ensure thread safety
sess = tf.get_default_session()
sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
feed_dict={self._observation_placeholder: observation[None]})
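The docstring above leaves exploration to the caller when exploration=None. A common choice, shown only as an illustration and not as this project's decided interface, is epsilon-greedy on top of the sampled action:

import numpy as np

def epsilon_greedy(sampled_action, n_categories, epsilon=0.1):
    """With probability epsilon, replace the network's action by a uniformly random one."""
    if np.random.rand() < epsilon:
        return np.random.randint(n_categories)
    return sampled_action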
@@ -33,10 +36,16 @@ class DQN(QValuePolicy):
@property
def logits(self):
"""
:return: action values
"""
return self._logits
@property
def n_categories(self):
"""
:return: the dimension of the action space if it is not continuous
"""
return self._n_categories
def values(self, observation):


@@ -23,6 +23,10 @@ class NaiveExperience(ReplayBuffer):
self.n_entries += 1
def _begin_act(self):
"""
If the previous interaction has ended, or the interaction has not started yet,
begin acting from the state returned by env.reset().
"""
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
done = False
@@ -33,6 +37,10 @@ class NaiveExperience(ReplayBuffer):
self.observation, _, done, _ = self._env.step(self.action)
def collect(self):
"""
Collect data for the replay memory and update the priority according to the given data.
Store the previous action, previous observation, reward, action, and observation in the replay memory.
"""
sess = tf.get_default_session()
current_data = dict()
current_data['previous_action'] = self.action
@@ -59,6 +67,13 @@ class NaiveExperience(ReplayBuffer):
return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
def next_batch(self, batch_size):
"""
Collect a batch of data from the replay buffer, update the priorities, and compute the statistics
needed for updating the q value network.
:param batch_size: int, batch size.
:return: a batch of data, where target stores the target q value, and wi and rewards store the
coefficients for the gradient of the q value network.
"""
data = dict()
observations = list()
actions = list()
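The target mentioned in the next_batch docstring is the usual one-step bootstrap. A small NumPy sketch, assuming next_q_values holds the target network's Q(s', .) for each sampled transition and that the naive buffer uses uniform weights wi = 1:

import numpy as np

def compute_targets(rewards, dones, next_q_values, gamma=0.99):
    """One-step TD targets: r + gamma * max_a Q_target(s', a), cut off at episode end."""
    rewards = np.asarray(rewards, dtype=np.float32)
    dones = np.asarray(dones, dtype=np.float32)
    max_next_q = np.max(next_q_values, axis=1)          # max over actions
    targets = rewards + gamma * (1.0 - dones) * max_next_q
    weights = np.ones_like(targets)                     # naive buffer: uniform wi
    return targets, weights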


@@ -45,6 +45,10 @@ class PropotionalExperience(ReplayBuffer):
self._begin_act()
def _begin_act(self):
"""
If the previous interaction has ended, or the interaction has not started yet,
begin acting from the state returned by env.reset().
"""
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
done = False
@@ -66,12 +70,6 @@ class PropotionalExperience(ReplayBuffer):
"""
self.tree.add(data, priority**self.alpha)
def collect(self):
pass
def next_batch(self, batch_size):
pass
def sample(self, conf):
""" The method return samples randomly.
@@ -117,6 +115,10 @@ class PropotionalExperience(ReplayBuffer):
return out, weights, indices
def collect(self):
"""
Collect data for the replay memory and update the priority according to the given data.
Store the previous action, previous observation, reward, action, and observation in the replay memory.
"""
sess = tf.get_default_session()
current_data = dict()
current_data['previous_action'] = self.action
@@ -134,6 +136,13 @@ class PropotionalExperience(ReplayBuffer):
self._begin_act()
def next_batch(self, batch_size):
"""
Collect a batch of data from the replay buffer, update the priorities, and compute the statistics
needed for updating the q value network.
:param batch_size: int, batch size.
:return: a batch of data, where target stores the target q value, and wi and rewards store the
coefficients for the gradient of the q value network.
"""
data = dict()
observations = list()
actions = list()
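For the proportional buffer, the wi returned by next_batch are typically importance-sampling corrections w_i = (N * P(i))^(-beta), normalized by the largest weight. This is the standard scheme from prioritized experience replay and is assumed here rather than read from this class:

import numpy as np

def importance_weights(probabilities, buffer_size, beta=0.4):
    """w_i = (N * P(i))^-beta, scaled so the largest weight is 1 for stability."""
    probs = np.asarray(probabilities, dtype=np.float32)
    weights = (buffer_size * probs) ** (-beta)
    return weights / weights.max()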


@@ -107,6 +107,10 @@ class RankBasedExperience(ReplayBuffer):
return self.index
def _begin_act(self):
"""
If the previous interaction has ended, or the interaction has not started yet,
begin acting from the state returned by env.reset().
"""
self.observation = self._env.reset()
self.action = self._env.action_space.sample()
done = False
@@ -117,6 +121,10 @@ class RankBasedExperience(ReplayBuffer):
self.observation, _, done, _ = self._env.step(self.action)
def collect(self):
"""
Collect data for the replay memory and update the priority according to the given data.
Store the previous action, previous observation, reward, action, and observation in the replay memory.
"""
sess = tf.get_default_session()
current_data = dict()
current_data['previous_action'] = self.action
@@ -131,6 +139,13 @@ class RankBasedExperience(ReplayBuffer):
self._begin_act()
def next_batch(self, batch_size):
"""
Collect a batch of data from the replay buffer, update the priorities, and compute the statistics
needed for updating the q value network.
:param batch_size: int, batch size.
:return: a batch of data, where target stores the target q value, and wi and rewards store the
coefficients for the gradient of the q value network.
"""
data = dict()
observations = list()
actions = list()
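For the rank-based variant, priorities are usually derived from the rank of each transition's absolute TD error, p_i = 1 / rank(i), with sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha. The sketch below follows that standard scheme and is an assumption about this class, not a copy of its implementation:

import numpy as np

def rank_based_probabilities(td_errors, alpha=0.7):
    """p_i = 1 / rank(i) by |TD error| (rank 1 = largest), P(i) = p_i^alpha / sum_k p_k^alpha."""
    td_errors = np.asarray(td_errors, dtype=np.float32)
    order = np.argsort(-np.abs(td_errors))        # indices sorted by descending |TD error|
    ranks = np.empty_like(order)
    ranks[order] = np.arange(1, len(td_errors) + 1)
    priorities = 1.0 / ranks
    probs = priorities ** alpha
    return probs / probs.sum()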