From 0cf2fd6c53b097d99fc7378901f6cd3e525078c2 Mon Sep 17 00:00:00 2001
From: Dong Yan
Date: Sat, 3 Mar 2018 21:25:29 +0800
Subject: [PATCH] an initial, untested version of the replay-memory Q-return

---
 tianshou/data/advantage_estimation.py | 47 ++++++++++++++++++++++++---
 1 file changed, 43 insertions(+), 4 deletions(-)

diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py
index 2f7f2ed..a870f02 100644
--- a/tianshou/data/advantage_estimation.py
+++ b/tianshou/data/advantage_estimation.py
@@ -1,5 +1,6 @@
 import logging
-
+import tensorflow as tf
+import numpy as np
 
 STATE = 0
 ACTION = 1
@@ -100,7 +101,7 @@ class ddpg_return:
         pass
 
 
-class nstep_q_return:
+class ReplayMemoryQReturn:
     """
     compute the n-step return for Q-learning targets
     """
@@ -109,11 +110,49 @@ class nstep_q_return:
         self.action_value = action_value
         self.use_target_network = use_target_network
 
-    def __call__(self, buffer, index=None):
+    # TODO: fold the tf -> numpy/python -> tf round trips into a single TensorFlow compute graph
+    def __call__(self, buffer, indexes=None):
         """
         :param buffer: buffer with property index and data. index determines the current content in `buffer`.
-        :param index: (sampled) index to be computed. Defaults to all the data in `buffer`. Not necessarily in order within each episode.
-        :return: dict with key 'return' and value the computed returns corresponding to `index`.
+        :param indexes: (sampled) indexes to be computed, one list per episode. Defaults to all the data in `buffer`. Not necessarily in order within each episode.
+        :return: dict with key 'return' and value the computed returns corresponding to `indexes`.
         """
-        pass
+        qvalue = self.action_value._value_tensor_all_actions
+        # TODO: use the target network's value tensor when self.use_target_network is True
+        indexes = indexes or buffer.index
+        episodes = buffer.data
+        discount_factor = 0.99  # TODO: make the discount factor configurable
+        returns = []
+
+        # reuse the session that already holds the trained Q-network; opening a new session
+        # and re-running the global variable initializer here would wipe its weights
+        sess = tf.get_default_session()
+        for episode_index in range(len(indexes)):
+            index = indexes[episode_index]
+            if index:
+                episode = episodes[episode_index]
+                episode_q = []
+                if not episode[-1][DONE]:
+                    logging.warning('Computing Q return on episode {} with no terminal state.'.format(episode_index))
+
+                for i in index:
+                    # accumulate the discounted rewards r_i, ..., r_{i+n-1}, truncated at the end of the episode
+                    discount = 1.
+                    target_q = 0.
+                    last_frame_index = i
+                    for lfi in range(i, min(len(episode), i + self.n)):
+                        target_q += discount * episode[lfi][REWARD]
+                        discount *= discount_factor
+                        last_frame_index = lfi
+                        if episode[lfi][DONE]:
+                            break
+                    # bootstrap with the discounted max Q-value of the frame after the last
+                    # accumulated reward, unless the episode terminated within the n-step window
+                    if not episode[last_frame_index][DONE] and last_frame_index + 1 < len(episode):
+                        target_q += discount * np.max(
+                            sess.run(qvalue, feed_dict={
+                                self.action_value.managed_placeholders['observation']:
+                                    episode[last_frame_index + 1][STATE]}))
+                    episode_q.append(target_q)
+
+                returns.append(episode_q)
+            else:
+                returns.append([])
+        return {'return': returns}
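
For reference, here is a minimal, numpy-only sketch (not part of the patch) of the per-sample n-step target that `__call__` is meant to compute, with the TensorFlow Q-network replaced by a plain callable `q_fn(state)` that returns a vector of action values. The frame layout `(STATE, ACTION, REWARD, DONE)` mirrors the constants at the top of advantage_estimation.py; the function name, `q_fn`, and the toy episode below are illustrative only.

import numpy as np

STATE, ACTION, REWARD, DONE = 0, 1, 2, 3

def n_step_q_target(episode, i, n, q_fn, gamma=0.99):
    """r_i + gamma*r_{i+1} + ... + gamma^(n-1)*r_{i+n-1} + gamma^n * max_a Q(s_{i+n}, a),
    with the bootstrap term dropped if the episode terminates within the n-step window."""
    target, discount, last = 0.0, 1.0, i
    for j in range(i, min(len(episode), i + n)):
        target += discount * episode[j][REWARD]   # discounted reward of frame j
        discount *= gamma
        last = j
        if episode[j][DONE]:
            return target                         # terminal frame reached: no bootstrap
    if last + 1 < len(episode):
        target += discount * np.max(q_fn(episode[last + 1][STATE]))
    return target

# toy usage: a 4-frame episode with reward 1.0 per step and a dummy Q-function
episode = [(s, 0, 1.0, s == 3) for s in range(4)]
print(n_step_q_target(episode, 0, n=2, q_fn=lambda s: np.array([1.0, 2.0])))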