an initial version of untested replaymemory qreturn
parent 528c4be93c
commit 0cf2fd6c53
@@ -1,5 +1,6 @@
 import logging
+import tensorflow as tf
+import numpy as np
 
 STATE = 0
 ACTION = 1
@@ -100,7 +101,7 @@ class ddpg_return:
         pass
 
 
-class nstep_q_return:
+class ReplayMemoryQReturn:
     """
    compute the n-step return for Q-learning targets
    """
@@ -109,11 +110,49 @@ class nstep_q_return:
         self.action_value = action_value
         self.use_target_network = use_target_network
 
-    def __call__(self, buffer, index=None):
+    # TODO : we should transfer the tf -> numpy/python -> tf into a monolithic compute graph in tf
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
         :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
             each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `index`.
         """
-        pass
+        qvalue = self.action_value._value_tensor_all_actions
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        discount_factor = 0.99
+        returns = []
+
+        config = tf.ConfigProto()
+        config.gpu_options.allow_growth = True
+        with tf.Session(config=config) as sess:
+            sess.run(tf.global_variables_initializer())
+            for episode_index in range(len(indexes)):
+                index = indexes[episode_index]
+                if index:
+                    episode = episodes[episode_index]
+                    episode_q = []
+                    if not episode[-1][DONE]:
+                        logging.warning('Computing Q return on episode {} with no terminal state.'.format(episode_index))
+
+                    for i in index:
+                        current_discount_factor = 1
+                        last_frame_index = i
+                        target_q = episode[i][REWARD]
+                        for lfi in range(i, min(len(episode), i + self.n + 1)):
+                            if episode[lfi][DONE]:
+                                break
+                            target_q += current_discount_factor * episode[lfi][REWARD]
+                            current_discount_factor *= discount_factor
+                            last_frame_index = lfi
+                        if last_frame_index > i:
+                            target_q += current_discount_factor * \
+                                max(sess.run(qvalue, feed_dict={self.action_value.managed_placeholders['observation']:
+                                                                episode[last_frame_index][STATE]}))
+                        episode_q.append(target_q)
+
+                    returns.append(episode_q)
+                else:
+                    returns.append([])
+        return {'TD-lambda': returns}
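For reference, below is a minimal standalone sketch of the n-step Q-learning target that this commit introduces, written against plain Python lists rather than the project's buffer and action_value objects. The helper name nstep_q_target, the q_values callable, and the REWARD/DONE field positions are assumptions for illustration, not the repository's API. The sketch also sidesteps two details that look off in the committed version: the reward at the starting index appears to be added twice (once when target_q is initialised and again on the first loop iteration), and the result is returned under the key 'TD-lambda' although the docstring says 'return'.

import numpy as np

# Field positions inside a transition tuple; STATE and ACTION mirror the
# constants in the diff, REWARD and DONE are assumed to follow them.
STATE, ACTION, REWARD, DONE = 0, 1, 2, 3


def nstep_q_target(episode, i, n, q_values, gamma=0.99):
    """Hypothetical helper: n-step return target for the transition at index i.

    episode  -- list of (state, action, reward, done) tuples for one episode
    q_values -- callable mapping a state to a vector of per-action Q-values
                (stand-in for the sess.run(qvalue, ...) call in the commit)

    Computes sum_{k=0}^{n-1} gamma^k * r_{i+k} + gamma^n * max_a Q(s_{i+n}, a),
    dropping the bootstrap term when the episode terminates or is truncated.
    """
    target = 0.0
    discount = 1.0
    for k in range(i, min(len(episode), i + n)):
        target += discount * episode[k][REWARD]
        discount *= gamma
        if episode[k][DONE] or k + 1 == len(episode):
            # Terminal or truncated episode: nothing left to bootstrap from.
            return target
    # All n rewards were accumulated and episode[i + n] exists: bootstrap from it.
    return target + discount * np.max(q_values(episode[i + n][STATE]))


# Tiny usage example with a dummy Q-function.
if __name__ == '__main__':
    episode = [
        (np.zeros(4), 0, 1.0, False),
        (np.ones(4), 1, 0.5, False),
        (np.full(4, 2.0), 0, 2.0, True),
    ]
    dummy_q = lambda state: np.array([0.3, 0.7])
    # expected: 1.0 + 0.99 * 0.5 + 0.99**2 * 0.7
    print(nstep_q_target(episode, 0, n=2, q_values=dummy_q))

Looping this helper over the sampled indexes of each episode and collecting the results into a dict keyed 'return' would reproduce what ReplayMemoryQReturn.__call__ appears to aim for, without opening a new TensorFlow session on every call.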