an initial, untested version of the replay memory Q-return computation
parent 528c4be93c
commit 0cf2fd6c53
@@ -1,5 +1,6 @@
+import logging
 import tensorflow as tf
 import numpy as np
 
 STATE = 0
 ACTION = 1
@@ -100,7 +101,7 @@ class ddpg_return:
         pass
 
 
-class nstep_q_return:
+class ReplayMemoryQReturn:
     """
     compute the n-step return for Q-learning targets
     """
@@ -109,11 +110,49 @@ class nstep_q_return:
         self.action_value = action_value
         self.use_target_network = use_target_network
 
-    def __call__(self, buffer, index=None):
+    # TODO: fold the tf -> numpy/python -> tf round trip into a single monolithic tf compute graph
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with properties `index` and `data`; `index` determines the current contents of `buffer`.
         :param indexes: (sampled) indexes to be computed. Defaults to all the data in `buffer`. Not necessarily in
             order within each episode.
         :return: dict with key 'return' and value the computed returns corresponding to `indexes`.
         """
-        pass
+        qvalue = self.action_value._value_tensor_all_actions
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        discount_factor = 0.99
+        returns = []
+
+        config = tf.ConfigProto()
+        config.gpu_options.allow_growth = True
+        with tf.Session(config=config) as sess:
+            sess.run(tf.global_variables_initializer())
+            for episode_index in range(len(indexes)):
+                index = indexes[episode_index]
+                if index:
+                    episode = episodes[episode_index]
+                    episode_q = []
+                    if not episode[-1][DONE]:
+                        logging.warning('Computing Q return on episode {} with no terminal state.'.format(episode_index))
+
+                    for i in index:
+                        current_discount_factor = 1.
+                        last_frame_index = i
+                        target_q = 0.
+                        # accumulate up to n + 1 discounted rewards, stopping at a terminal frame
+                        for lfi in range(i, min(len(episode), i + self.n + 1)):
+                            if episode[lfi][DONE]:
+                                break
+                            target_q += current_discount_factor * episode[lfi][REWARD]
+                            current_discount_factor *= discount_factor
+                            last_frame_index = lfi
+                        if last_frame_index > i:
+                            # bootstrap the remaining return with the maximum Q-value
+                            # of the last accumulated frame's state
+                            target_q += current_discount_factor * \
+                                max(sess.run(qvalue, feed_dict={
+                                    self.action_value.managed_placeholders['observation']:
+                                        episode[last_frame_index][STATE]}))
+                        episode_q.append(target_q)
+
+                    returns.append(episode_q)
+                else:
+                    returns.append([])
+        return {'return': returns}
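For reference, the textbook n-step Q-learning target this routine is meant to approximate, with the discount factor gamma = 0.99 hard-coded above, is

    G_t^(n) = \sum_{k=0}^{n-1} \gamma^k r_{t+k} + \gamma^n \max_a Q(s_{t+n}, a)

that is, n discounted rewards followed by a bootstrap from the greedy action value at the n-th successor state. The exact indexing used by the loop above (which accumulates up to n + 1 rewards and bootstraps from the last accumulated frame's state) is a slight variation on this form; per the commit message the code is still untested.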
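A minimal smoke test of the new class could look like the sketch below. The constructor keyword arguments (n, action_value, use_target_network), the frame layout (state, action, reward, done), and both mock classes are assumptions inferred from the attributes this diff touches; the mock action-value wrapper only mimics the two members the __call__ body uses, _value_tensor_all_actions and managed_placeholders['observation'].

import numpy as np
import tensorflow as tf

class MockBuffer:
    """Hypothetical stand-in: `data` is a list of episodes, each a list of
    (state, action, reward, done) frames (assuming REWARD == 2 and DONE == 3);
    `index` lists the frame indexes to compute returns for, per episode."""
    def __init__(self, data, index):
        self.data = data
        self.index = index

class MockActionValue:
    """Hypothetical stand-in for the action-value network wrapper."""
    def __init__(self, state_dim, num_actions):
        observation = tf.placeholder(tf.float32, shape=(state_dim,), name='observation')
        # a fixed, parameter-free "network" so the sketch needs no training
        self._value_tensor_all_actions = tf.ones((num_actions,)) * tf.reduce_sum(observation)
        self.managed_placeholders = {'observation': observation}

episode = [
    (np.array([0.1, 0.2], dtype=np.float32), 0, 1.0, False),
    (np.array([0.3, 0.1], dtype=np.float32), 1, 0.5, False),
    (np.array([0.0, 0.0], dtype=np.float32), 0, 0.0, True),
]
buffer = MockBuffer(data=[episode], index=[[0, 1]])

# constructor signature is not shown in this hunk; these keyword arguments are guesses
q_return = ReplayMemoryQReturn(n=2,
                               action_value=MockActionValue(state_dim=2, num_actions=3),
                               use_target_network=False)
print(q_return(buffer))  # expected shape: {'return': [[target for frame 0, target for frame 1]]}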