add some my_feed_dict in advantage_estimation and data_collector
This commit is contained in:
parent
a791916fc4
commit
c937630bd3
@ -216,7 +216,7 @@ class nstep_q_return:
|
||||
self.use_target_network = use_target_network
|
||||
self.discount_factor = discount_factor
|
||||
|
||||
def __call__(self, buffer, indexes=None):
|
||||
def __call__(self, buffer, indexes=None, my_feed_dict={}):
|
||||
"""
|
||||
:param buffer: A :class:`tianshou.data.data_buffer`.
|
||||
:param indexes: Optional. Indexes of data points on which the full return should be computed.
|
||||
@ -249,9 +249,9 @@ class nstep_q_return:
|
||||
state = episode[last_frame_index + 1][STATE]
|
||||
if self.use_target_network:
|
||||
# [None] adds one dimension to the beginning
|
||||
qpredict = self.action_value.eval_value_all_actions_old(state[None])
|
||||
qpredict = self.action_value.eval_value_all_actions_old(state[None], my_feed_dict=my_feed_dict)
|
||||
else:
|
||||
qpredict = self.action_value.eval_value_all_actions(state[None])
|
||||
qpredict = self.action_value.eval_value_all_actions(state[None], my_feed_dict=my_feed_dict)
|
||||
target_q += current_discount_factor * max(qpredict[0])
|
||||
episode_q.append(target_q)
|
||||
|
||||
|
@ -118,7 +118,7 @@ class DataCollector(object):
|
||||
|
||||
return
|
||||
|
||||
def next_batch(self, batch_size, standardize_advantage=True):
|
||||
def next_batch(self, batch_size, standardize_advantage=True, my_feed_dict={}):
|
||||
"""
|
||||
Constructs and returns the feed_dict of data to be used with ``sess.run``.
|
||||
|
||||
@ -133,7 +133,7 @@ class DataCollector(object):
|
||||
sampled_index = self.data_buffer.sample(batch_size)
|
||||
if self.process_mode == 'sample':
|
||||
for processor in self.process_functions:
|
||||
self.data_batch.update(processor(self.data_buffer, indexes=sampled_index))
|
||||
self.data_batch.update(processor(self.data_buffer, indexes=sampled_index, my_feed_dict=my_feed_dict))
|
||||
|
||||
# flatten rank-2 list to numpy array, construct feed_dict
|
||||
feed_dict = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user