add some my_feed_dict in advantage_estimation and data_collector

This commit is contained in:
haoshengzou 2018-08-16 16:20:14 +08:00
parent a791916fc4
commit c937630bd3
2 changed files with 5 additions and 5 deletions

View File

@ -216,7 +216,7 @@ class nstep_q_return:
self.use_target_network = use_target_network
self.discount_factor = discount_factor
def __call__(self, buffer, indexes=None):
def __call__(self, buffer, indexes=None, my_feed_dict={}):
"""
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the full return should be computed.
@ -249,9 +249,9 @@ class nstep_q_return:
state = episode[last_frame_index + 1][STATE]
if self.use_target_network:
# [None] adds one dimension to the beginning
qpredict = self.action_value.eval_value_all_actions_old(state[None])
qpredict = self.action_value.eval_value_all_actions_old(state[None], my_feed_dict=my_feed_dict)
else:
qpredict = self.action_value.eval_value_all_actions(state[None])
qpredict = self.action_value.eval_value_all_actions(state[None], my_feed_dict=my_feed_dict)
target_q += current_discount_factor * max(qpredict[0])
episode_q.append(target_q)

View File

@ -118,7 +118,7 @@ class DataCollector(object):
return
def next_batch(self, batch_size, standardize_advantage=True):
def next_batch(self, batch_size, standardize_advantage=True, my_feed_dict={}):
"""
Constructs and returns the feed_dict of data to be used with ``sess.run``.
@ -133,7 +133,7 @@ class DataCollector(object):
sampled_index = self.data_buffer.sample(batch_size)
if self.process_mode == 'sample':
for processor in self.process_functions:
self.data_batch.update(processor(self.data_buffer, indexes=sampled_index))
self.data_batch.update(processor(self.data_buffer, indexes=sampled_index, my_feed_dict=my_feed_dict))
# flatten rank-2 list to numpy array, construct feed_dict
feed_dict = {}