an initial, untested version of ReplayMemoryQReturn

Dong Yan 2018-03-03 21:25:29 +08:00
parent 528c4be93c
commit 0cf2fd6c53


@@ -1,5 +1,6 @@
+import logging
 import tensorflow as tf
 import numpy as np
 
 STATE = 0
 ACTION = 1
@@ -100,7 +101,7 @@ class ddpg_return:
         pass
 
 
-class nstep_q_return:
+class ReplayMemoryQReturn:
"""
compute the n-step return for Q-learning targets
"""
@@ -109,11 +110,49 @@ class nstep_q_return:
         self.action_value = action_value
         self.use_target_network = use_target_network
 
-    def __call__(self, buffer, index=None):
-        """
-        :param buffer: buffer with property index and data. index determines the current content in `buffer`.
-        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within
-        each episode.
-        :return: dict with key 'return' and value the computed returns corresponding to `index`.
-        """
+    # TODO: fold the tf -> numpy/python -> tf round trips into a single monolithic tf compute graph
+    def __call__(self, buffer, indexes=None):
+        """
+        :param buffer: buffer with properties `index` and `data`. `index` determines the current
+            content of `buffer`.
+        :param indexes: (sampled) indexes of the frames whose returns are computed. Defaults to all
+            the data in `buffer`. The indexes within each episode are not necessarily in order.
+        :return: dict with key 'return' and value the computed returns corresponding to `indexes`.
+        """
-        pass
+        qvalue = self.action_value._value_tensor_all_actions
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        discount_factor = 0.99  # TODO: make the discount factor configurable
+        returns = []
+        config = tf.ConfigProto()
+        config.gpu_options.allow_growth = True
+        with tf.Session(config=config) as sess:
+            sess.run(tf.global_variables_initializer())
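+            # one list of sampled frame indexes per episode in the buffer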
+            for episode_index in range(len(indexes)):
+                index = indexes[episode_index]
+                if index:
+                    episode = episodes[episode_index]
+                    episode_q = []
+                    if not episode[-1][DONE]:
+                        logging.warning('Computing Q return on episode {} with no terminal state.'
+                                        .format(episode_index))
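+                    # n-step target: r_i + γ·r_{i+1} + ... + γ^{n-1}·r_{i+n-1} + γ^n·max_a Q(s_{i+n}, a)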
+                    for i in index:
+                        current_discount_factor = 1.
+                        target_q = 0.
+                        last_frame_index = i
+                        # accumulate up to n discounted rewards, stopping at a terminal state;
+                        # starting from 0 avoids counting the reward of frame i twice
+                        for lfi in range(i, min(len(episode), i + self.n)):
+                            if episode[lfi][DONE]:
+                                break
+                            target_q += current_discount_factor * episode[lfi][REWARD]
+                            current_discount_factor *= discount_factor
+                            last_frame_index = lfi + 1
+                        # bootstrap on the first state after the reward window, unless the
+                        # episode terminated inside the window
+                        if last_frame_index < len(episode) and not episode[last_frame_index][DONE]:
+                            q_next = sess.run(qvalue, feed_dict={
+                                self.action_value.managed_placeholders['observation']:
+                                    episode[last_frame_index][STATE]})
+                            target_q += current_discount_factor * np.max(q_next)
+                        episode_q.append(target_q)
+                    returns.append(episode_q)
+                else:
+                    returns.append([])
+        return {'return': returns}
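
For reference, a minimal sketch of how the new class might be driven. The StubActionValue and StubBuffer classes, the constructor keywords, the REWARD/DONE index values, and the direct assignment of `n` are assumptions made for illustration, inferred only from the attributes `__call__` touches (`_value_tensor_all_actions`, `managed_placeholders['observation']`, `buffer.index`, `buffer.data`); they are not the repository's actual interfaces.

import numpy as np
import tensorflow as tf

# STATE and ACTION come from the diff; REWARD and DONE are assumed to continue the sequence
STATE, ACTION, REWARD, DONE = 0, 1, 2, 3


class StubActionValue:
    """Hypothetical stand-in exposing only the two attributes __call__ reads."""
    def __init__(self, state_dim=4, num_actions=3):
        obs = tf.placeholder(tf.float32, shape=(state_dim,), name='observation')
        weights = tf.constant(np.random.randn(state_dim, num_actions).astype(np.float32))
        # Q(s, a) for all actions at once, shape (num_actions,)
        self._value_tensor_all_actions = tf.tensordot(obs, weights, axes=1)
        self.managed_placeholders = {'observation': obs}


class StubBuffer:
    """Hypothetical buffer: `data` holds episodes of (state, action, reward, done)
    frames; `index` lists the valid frame indexes of each episode."""
    def __init__(self, episodes):
        self.data = episodes
        self.index = [list(range(len(ep))) for ep in episodes]


episode = [
    (np.ones(4, np.float32), 0, 1.0, False),
    (np.ones(4, np.float32), 1, 0.5, False),
    (np.ones(4, np.float32), 0, 0.2, False),
    (np.zeros(4, np.float32), 0, 0.0, True),  # terminal state
]

qreturn = ReplayMemoryQReturn(action_value=StubActionValue(), use_target_network=False)
qreturn.n = 2  # n-step horizon; set directly because the full __init__ is not in this hunk
print(qreturn(StubBuffer([episode])))

With discount_factor = 0.99 and n = 2, the target for frame 0 works out to 1.0 + 0.99 * 0.5 + 0.99^2 * max_a Q(s_2, a), i.e. the n-step formula noted in the comment above; frames that hit the terminal state inside their window skip the bootstrap term.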