add comments and todos
This commit is contained in:
parent
62e2c6582d
commit
7693c38f44
@ -37,6 +37,9 @@ if __name__ == '__main__':
|
||||
action_dim = env.action_space.n
|
||||
|
||||
# 1. build network with pure tf
|
||||
# TODO:
|
||||
# pass the observation variable to the replay buffer or find a more reasonable way to help replay buffer
|
||||
# access this observation variable.
|
||||
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim, name="dqn_observation") # network input
|
||||
|
||||
with tf.variable_scope('q_net'):
|
||||
@ -59,6 +62,7 @@ if __name__ == '__main__':
|
||||
optimizer = tf.train.AdamOptimizer(1e-3)
|
||||
train_op = optimizer.minimize(total_loss, var_list=train_var_list, global_step=tf.train.get_global_step())
|
||||
# 3. define data collection
|
||||
# configuration should be given as parameters, different replay buffer has different parameters.
|
||||
replay_memory = get_replay_buffer('rank_based', env, q_values, q_net, target_net,
|
||||
{'size': 1000, 'batch_size': 64, 'learn_start': 20})
|
||||
# ShihongSong: Replay(env, q_net, advantage_estimation.qlearning_target(target_network)), use your ReplayMemory, interact as follows. Simplify your advantage_estimation.dqn to run before YongRen's DQN
|
||||
@ -70,6 +74,7 @@ if __name__ == '__main__':
|
||||
|
||||
minibatch_count = 0
|
||||
collection_count = 0
|
||||
# need to first collect some then sample, collect_freq must be larger than batch_size
|
||||
collect_freq = 100
|
||||
while True: # until some stopping criterion met...
|
||||
# collect data
|
||||
|
@ -8,6 +8,7 @@ class DQN(QValuePolicy):
|
||||
"""
|
||||
|
||||
def __init__(self, logits, observation_placeholder, dtype=None, **kwargs):
|
||||
# TODO: this version only support non-continuous action space, extend it to support continuous action space
|
||||
self._logits = tf.convert_to_tensor(logits)
|
||||
if dtype is None:
|
||||
dtype = tf.int32
|
||||
@ -15,6 +16,7 @@ class DQN(QValuePolicy):
|
||||
|
||||
super(DQN, self).__init__(observation_placeholder)
|
||||
|
||||
# TODO: put the net definition outside of the class
|
||||
net = tf.layers.conv2d(self._observation_placeholder, 16, 8, 4, 'valid', activation=tf.nn.relu)
|
||||
net = tf.layers.conv2d(net, 32, 4, 2, 'valid', activation=tf.nn.relu)
|
||||
net = tf.layers.flatten(net)
|
||||
@ -26,6 +28,7 @@ class DQN(QValuePolicy):
|
||||
return the action (int) to be executed.
|
||||
no exploration when exploration=None.
|
||||
"""
|
||||
# TODO: ensure thread safety
|
||||
sess = tf.get_default_session()
|
||||
sampled_action = sess.run(tf.multinomial(self.logits, num_samples=1),
|
||||
feed_dict={self._observation_placeholder: observation[None]})
|
||||
@ -33,10 +36,16 @@ class DQN(QValuePolicy):
|
||||
|
||||
@property
|
||||
def logits(self):
|
||||
"""
|
||||
:return: action values
|
||||
"""
|
||||
return self._logits
|
||||
|
||||
@property
|
||||
def n_categories(self):
|
||||
"""
|
||||
:return: dimension of action space if not continuous
|
||||
"""
|
||||
return self._n_categories
|
||||
|
||||
def values(self, observation):
|
||||
|
@ -23,6 +23,10 @@ class NaiveExperience(ReplayBuffer):
|
||||
self.n_entries += 1
|
||||
|
||||
def _begin_act(self):
|
||||
"""
|
||||
if the previous interaction is ended or the interaction hasn't started
|
||||
then begin act from the state of env.reset()
|
||||
"""
|
||||
self.observation = self._env.reset()
|
||||
self.action = self._env.action_space.sample()
|
||||
done = False
|
||||
@ -33,6 +37,10 @@ class NaiveExperience(ReplayBuffer):
|
||||
self.observation, _, done, _ = self._env.step(self.action)
|
||||
|
||||
def collect(self):
|
||||
"""
|
||||
collect data for replay memory and update the priority according to the given data.
|
||||
store the previous action, previous observation, reward, action, observation in the replay memory.
|
||||
"""
|
||||
sess = tf.get_default_session()
|
||||
current_data = dict()
|
||||
current_data['previous_action'] = self.action
|
||||
@ -59,6 +67,13 @@ class NaiveExperience(ReplayBuffer):
|
||||
return [self.memory[idx] for idx in idxs], [1] * len(idxs), idxs
|
||||
|
||||
def next_batch(self, batch_size):
|
||||
"""
|
||||
collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
|
||||
updating q value network.
|
||||
:param batch_size: int batch size.
|
||||
:return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
|
||||
for gradient of q value network.
|
||||
"""
|
||||
data = dict()
|
||||
observations = list()
|
||||
actions = list()
|
||||
|
@ -45,6 +45,10 @@ class PropotionalExperience(ReplayBuffer):
|
||||
self._begin_act()
|
||||
|
||||
def _begin_act(self):
|
||||
"""
|
||||
if the previous interaction is ended or the interaction hasn't started
|
||||
then begin act from the state of env.reset()
|
||||
"""
|
||||
self.observation = self._env.reset()
|
||||
self.action = self._env.action_space.sample()
|
||||
done = False
|
||||
@ -66,12 +70,6 @@ class PropotionalExperience(ReplayBuffer):
|
||||
"""
|
||||
self.tree.add(data, priority**self.alpha)
|
||||
|
||||
def collect(self):
|
||||
pass
|
||||
|
||||
def next_batch(self, batch_size):
|
||||
pass
|
||||
|
||||
def sample(self, conf):
|
||||
""" The method return samples randomly.
|
||||
|
||||
@ -117,6 +115,10 @@ class PropotionalExperience(ReplayBuffer):
|
||||
return out, weights, indices
|
||||
|
||||
def collect(self):
|
||||
"""
|
||||
collect data for replay memory and update the priority according to the given data.
|
||||
store the previous action, previous observation, reward, action, observation in the replay memory.
|
||||
"""
|
||||
sess = tf.get_default_session()
|
||||
current_data = dict()
|
||||
current_data['previous_action'] = self.action
|
||||
@ -134,6 +136,13 @@ class PropotionalExperience(ReplayBuffer):
|
||||
self._begin_act()
|
||||
|
||||
def next_batch(self, batch_size):
|
||||
"""
|
||||
collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
|
||||
updating q value network.
|
||||
:param batch_size: int batch size.
|
||||
:return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
|
||||
for gradient of q value network.
|
||||
"""
|
||||
data = dict()
|
||||
observations = list()
|
||||
actions = list()
|
||||
|
@ -107,6 +107,10 @@ class RankBasedExperience(ReplayBuffer):
|
||||
return self.index
|
||||
|
||||
def _begin_act(self):
|
||||
"""
|
||||
if the previous interaction is ended or the interaction hasn't started
|
||||
then begin act from the state of env.reset()
|
||||
"""
|
||||
self.observation = self._env.reset()
|
||||
self.action = self._env.action_space.sample()
|
||||
done = False
|
||||
@ -117,6 +121,10 @@ class RankBasedExperience(ReplayBuffer):
|
||||
self.observation, _, done, _ = self._env.step(self.action)
|
||||
|
||||
def collect(self):
|
||||
"""
|
||||
collect data for replay memory and update the priority according to the given data.
|
||||
store the previous action, previous observation, reward, action, observation in the replay memory.
|
||||
"""
|
||||
sess = tf.get_default_session()
|
||||
current_data = dict()
|
||||
current_data['previous_action'] = self.action
|
||||
@ -131,6 +139,13 @@ class RankBasedExperience(ReplayBuffer):
|
||||
self._begin_act()
|
||||
|
||||
def next_batch(self, batch_size):
|
||||
"""
|
||||
collect a batch of data from replay buffer, update the priority and calculate the necessary statistics for
|
||||
updating q value network.
|
||||
:param batch_size: int batch size.
|
||||
:return: a batch of data, with target storing the target q value and wi, rewards storing the coefficient
|
||||
for gradient of q value network.
|
||||
"""
|
||||
data = dict()
|
||||
observations = list()
|
||||
actions = list()
|
||||
|
Loading…
x
Reference in New Issue
Block a user