add comments and todos
parent 3624cc9036
commit d220f7f2a8
@@ -37,6 +37,9 @@ if __name__ == '__main__':
     action_dim = env.action_space.n

     # 1. build network with pure tf
+    # TODO:
+    # pass the observation variable to the replay buffer, or find a more reasonable way to help the replay buffer
+    # access this observation variable.
     observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input

     with tf.variable_scope('q_net'):
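The TODO above is about wiring the observation placeholder into the replay buffer. A minimal sketch of one option, assuming the buffer simply accepts the placeholder at construction time; the class name, argument names, and example shape below are illustrative and not part of this repository:

import tensorflow as tf  # TF 1.x API, matching the code in this diff

class PlaceholderAwareBuffer(object):
    # hypothetical buffer that is handed the placeholder up front instead of
    # digging it out of the policy object later
    def __init__(self, env, observation_placeholder, conf):
        self._env = env
        self._observation_placeholder = observation_placeholder
        self._conf = conf

    def feed_dict_for(self, batch_observations):
        # the buffer can now build feed dicts without any global lookups
        return {self._observation_placeholder: batch_observations}

observation_dim = (84, 84, 4)  # example shape, for the sketch only
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation")
buffer = PlaceholderAwareBuffer(env=None, observation_placeholder=observation,
                                conf={'size': 1000, 'batch_size': 64, 'learn_start': 20})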
@@ -59,6 +62,7 @@ if __name__ == '__main__':
     optimizer = tf.train.AdamOptimizer(1e-3)
     train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
     # 3. define data collection
+    # configuration should be given as parameters; different replay buffers have different parameters.
     replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
                                       {'size': 1000, 'batch_size': 64, 'learn_start': 20})
     # ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
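The new comment notes that each replay buffer takes its own configuration dict. One plausible shape for the get_replay_buffer dispatcher implied by the call above, sketched under the assumption that every buffer class accepts the same positional arguments as the call site; the string keys other than 'rank_based' are guesses, and the real function in the repository may differ:

def get_replay_buffer_sketch(name, env, q_values, q_net, target_net, conf):
    # map the string identifier to a buffer class and forward the per-buffer
    # configuration dict ({'size': ..., 'batch_size': ..., 'learn_start': ...})
    buffers = {
        'naive': NaiveExperience,
        'proportional': PropotionalExperience,
        'rank_based': RankBasedExperience,
    }
    if name not in buffers:
        raise ValueError('unknown replay buffer: {}'.format(name))
    return buffers[name](env, q_values, q_net, target_net, conf)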
@@ -70,6 +74,7 @@ if __name__ == '__main__':

     minibatch_count = 0
     collection_count = 0
+    # need to first collect some data, then sample; collect_freq must be larger than batch_size
     collect_freq = 100
     while True: # until some stopping criterion met...
         # collect data
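A sketch of the collect-then-sample loop these counters and the collect_freq comment describe; the stopping criterion and the training call are placeholders, while the method names and the batch size of 64 follow the replay-buffer classes and configuration shown elsewhere in this diff:

minibatch_count = 0
collection_count = 0
collect_freq = 100                       # must exceed the sampling batch_size
while collection_count < 100000:         # example stopping criterion
    for _ in range(collect_freq):        # collect data first
        replay_memory.collect()
        collection_count += 1
    batch = replay_memory.next_batch(64)  # then sample a minibatch
    # sess.run(train_op, feed_dict=...)   # q-network update omitted in this sketch
    minibatch_count += 1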
@@ -8,6 +8,7 @@ class DQN(QValuePolicy):
     """

     def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
+        # TODO: this version only supports non-continuous action spaces; extend it to support continuous action spaces
         self._logits = tf.convert_to_tensor(logits)
         if dtype is None:
             dtype = tf.int32
@@ -15,6 +16,7 @@ class DQN(QValuePolicy):

         super(DQN, self).__init__(observation_placeholder)

+        # TODO: put the net definition outside of the class
         net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
         net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
         net = tf.layers.flatten(net)
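One way to act on the TODO above is to move the convolutional stack into a standalone builder function and call it from DQN.__init__; a sketch, where the function name, the dense layer size, and the final action_dim head are assumptions not taken from the repository:

import tensorflow as tf  # TF 1.x layers API, as used in this diff

def build_q_network(observation_placeholder, action_dim):
    # same conv stack as in the hunk above
    net = tf.layers.conv2d(observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
    net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
    net = tf.layers.flatten(net)
    # the dense size below is illustrative only
    net = tf.layers.dense(net, 256, activation=tf.nn.relu)
    return tf.layers.dense(net, action_dim)  # one q value per discrete action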
@@ -26,6 +28,7 @@ class DQN(QValuePolicy):
         return the action (int) to be executed.
         no exploration when exploration=None.
         """
+        # TODO: ensure thread safety
         sess = tf.get_default_session()
         sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
                                   feed_dict={self._observation_placeholder: observation[None]})
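The "ensure thread safety" TODO could be handled by serialising access to the shared session; a minimal sketch with a module-level lock (the lock and helper are illustrative, not part of the repository). Building the tf.multinomial op once, instead of on every act() call as above, would additionally keep the graph from growing:

import threading
import tensorflow as tf

_act_lock = threading.Lock()

def thread_safe_run(fetches, feed_dict):
    # serialise concurrent callers around the shared default session
    with _act_lock:
        sess = tf.get_default_session()
        return sess.run(fetches, feed_dict=feed_dict)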
@@ -33,10 +36,16 @@ class DQN(QValuePolicy):

     @property
     def logits(self):
+        """
+        :return: action values
+        """
         return self._logits

     @property
     def n_categories(self):
+        """
+        :return: dimension of action space if not continuous
+        """
         return self._n_categories

     def values(self, observation):
@@ -23,6 +23,10 @@ class NaiveExperience(ReplayBuffer):
         self.n_entries += 1

     def _begin_act(self):
+        """
+        if the previous interaction has ended, or the interaction hasn't started yet,
+        then begin acting from the state returned by env.reset()
+        """
         self.observation = self._env.reset()
         self.action = self._env.action_space.sample()
         done = False
@@ -33,6 +37,10 @@ class NaiveExperience(ReplayBuffer):
         self.observation, _, done, _ = self._env.step(self.action)

     def collect(self):
+        """
+        collect data for the replay memory and update the priorities according to the given data.
+        store the previous action, previous observation, reward, action, and observation in the replay memory.
+        """
         sess = tf.get_default_session()
         current_data = dict()
         current_data['previous_action'] = self.action
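A sketch of the transition record the new collect() docstring describes; apart from 'previous_action', which appears in this diff, the key names and the store callback are assumptions:

def collect_sketch(env, previous_observation, previous_action, choose_action, store):
    # pick the next action, step the environment, and store the full transition
    action = choose_action(previous_observation)
    observation, reward, done, _ = env.step(action)
    store({'previous_observation': previous_observation,
           'previous_action': previous_action,
           'reward': reward,
           'action': action,
           'observation': observation})
    return observation, action, done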
@@ -59,6 +67,13 @@ class NaiveExperience(ReplayBuffer):
         return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs

     def next_batch(self, batch_size):
+        """
+        collect a batch of data from the replay buffer, update the priorities, and calculate the statistics
+        needed for updating the q value network.
+        :param batch_size: int batch size.
+        :return: a batch of data, with target storing the target q values and wi, rewards storing the
+        coefficients for the gradient of the q value network.
+        """
         data = dict()
         observations = list()
         actions = list()
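The next_batch() docstring above refers to a target q value per sampled transition. Under the standard one-step q-learning rule that the q_net/target_net pair in this diff suggests, target_i = r_i + gamma * max_a Q_target(s'_i, a); a small sketch of that computation, with an illustrative gamma (for NaiveExperience, which samples uniformly and returns [1] * len(idxs) above, every importance weight wi is simply 1):

import numpy as np

def one_step_targets(rewards, next_q_values, done_flags, gamma=0.99):
    # next_q_values: [batch, action_dim] output of the target network on s'
    next_max = np.max(next_q_values, axis=1)
    next_max = next_max * (1.0 - np.asarray(done_flags, dtype=np.float32))  # no bootstrap on terminal states
    return np.asarray(rewards, dtype=np.float32) + gamma * next_max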
@@ -45,6 +45,10 @@ class PropotionalExperience(ReplayBuffer):
         self._begin_act()

     def _begin_act(self):
+        """
+        if the previous interaction has ended, or the interaction hasn't started yet,
+        then begin acting from the state returned by env.reset()
+        """
         self.observation = self._env.reset()
         self.action = self._env.action_space.sample()
         done = False
@@ -66,12 +70,6 @@ class PropotionalExperience(ReplayBuffer):
         """
         self.tree.add(data, priority**self.alpha)

-    def collect(self):
-        pass
-
-    def next_batch(self, batch_size):
-        pass
-
     def sample(self, conf):
         """ The method return samples randomly.

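For context on the priority**self.alpha line kept above: in proportional prioritised experience replay (Schaul et al., 2015), a transition is sampled with probability P(i) = p_i**alpha / sum_k p_k**alpha and its gradient is scaled by an importance-sampling weight w_i = (N * P(i))**(-beta), normalised by the largest weight in the batch. A small sketch of those two formulas; the alpha and beta defaults are illustrative, not values taken from this repository:

import numpy as np

def proportional_sample_weights(priorities, alpha=0.7, beta=0.5):
    p = np.asarray(priorities, dtype=np.float64) ** alpha
    probs = p / p.sum()                       # P(i)
    weights = (len(p) * probs) ** (-beta)     # importance-sampling correction
    return probs, weights / weights.max()     # normalise so the largest weight is 1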
@@ -117,6 +115,10 @@ class PropotionalExperience(ReplayBuffer):
         return out, weights, indices

     def collect(self):
+        """
+        collect data for the replay memory and update the priorities according to the given data.
+        store the previous action, previous observation, reward, action, and observation in the replay memory.
+        """
         sess = tf.get_default_session()
         current_data = dict()
         current_data['previous_action'] = self.action
@@ -134,6 +136,13 @@ class PropotionalExperience(ReplayBuffer):
         self._begin_act()

     def next_batch(self, batch_size):
+        """
+        collect a batch of data from the replay buffer, update the priorities, and calculate the statistics
+        needed for updating the q value network.
+        :param batch_size: int batch size.
+        :return: a batch of data, with target storing the target q values and wi, rewards storing the
+        coefficients for the gradient of the q value network.
+        """
         data = dict()
         observations = list()
         actions = list()
@@ -107,6 +107,10 @@ class RankBasedExperience(ReplayBuffer):
         return self.index

     def _begin_act(self):
+        """
+        if the previous interaction has ended, or the interaction hasn't started yet,
+        then begin acting from the state returned by env.reset()
+        """
         self.observation = self._env.reset()
         self.action = self._env.action_space.sample()
         done = False
@@ -117,6 +121,10 @@ class RankBasedExperience(ReplayBuffer):
         self.observation, _, done, _ = self._env.step(self.action)

     def collect(self):
+        """
+        collect data for the replay memory and update the priorities according to the given data.
+        store the previous action, previous observation, reward, action, and observation in the replay memory.
+        """
         sess = tf.get_default_session()
         current_data = dict()
         current_data['previous_action'] = self.action
@@ -131,6 +139,13 @@ class RankBasedExperience(ReplayBuffer):
         self._begin_act()

     def next_batch(self, batch_size):
+        """
+        collect a batch of data from the replay buffer, update the priorities, and calculate the statistics
+        needed for updating the q value network.
+        :param batch_size: int batch size.
+        :return: a batch of data, with target storing the target q values and wi, rewards storing the
+        coefficients for the gradient of the q value network.
+        """
         data = dict()
         observations = list()
         actions = list()
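A companion note for the rank-based buffer documented in the hunks above (and selected via 'rank_based' in the example script): in the rank-based variant of prioritised replay, a transition's priority comes from the rank of its TD error rather than its magnitude, p_i = 1 / rank(i), and the sampling probabilities and importance weights then follow the same formulas as in the proportional sketch earlier. A small illustration; alpha is again an illustrative value:

import numpy as np

def rank_based_probabilities(td_errors, alpha=0.7):
    # rank 1 = transition with the largest absolute TD error
    order = np.argsort(-np.abs(np.asarray(td_errors, dtype=np.float64)))
    ranks = np.empty_like(order)
    ranks[order] = np.arange(1, len(order) + 1)
    p = (1.0 / ranks) ** alpha
    return p / p.sum()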