Add a my_feed_dict argument to advantage_estimation and data_collector
This commit is contained in:
parent
a791916fc4
commit
c937630bd3
@@ -216,7 +216,7 @@ class nstep_q_return:
|
|||||||
self.use_target_network = use_target_network
|
self.use_target_network = use_target_network
|
||||||
self.discount_factor = discount_factor
|
self.discount_factor = discount_factor
|
||||||
|
|
||||||
def __call__(self, buffer, indexes=None):
|
def __call__(self, buffer, indexes=None, my_feed_dict={}):
|
||||||
"""
|
"""
|
||||||
:param buffer: A :class:`tianshou.data.data_buffer`.
|
:param buffer: A :class:`tianshou.data.data_buffer`.
|
||||||
:param indexes: Optional. Indexes of data points on which the full return should be computed.
|
:param indexes: Optional. Indexes of data points on which the full return should be computed.
|
||||||
@@ -249,9 +249,9 @@ class nstep_q_return:
|
|||||||
state = episode[last_frame_index + 1][STATE]
|
state = episode[last_frame_index + 1][STATE]
|
||||||
if self.use_target_network:
|
if self.use_target_network:
|
||||||
# [None] adds one dimension to the beginning
|
# [None] adds one dimension to the beginning
|
||||||
qpredict = self.action_value.eval_value_all_actions_old(state[None])
|
qpredict = self.action_value.eval_value_all_actions_old(state[None], my_feed_dict=my_feed_dict)
|
||||||
else:
|
else:
|
||||||
qpredict = self.action_value.eval_value_all_actions(state[None])
|
qpredict = self.action_value.eval_value_all_actions(state[None], my_feed_dict=my_feed_dict)
|
||||||
target_q += current_discount_factor * max(qpredict[0])
|
target_q += current_discount_factor * max(qpredict[0])
|
||||||
episode_q.append(target_q)
|
episode_q.append(target_q)
|
||||||
|
|
||||||
|
@@ -118,7 +118,7 @@ class DataCollector(object):
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
def next_batch(self, batch_size, standardize_advantage=True):
|
def next_batch(self, batch_size, standardize_advantage=True, my_feed_dict={}):
|
||||||
"""
|
"""
|
||||||
Constructs and returns the feed_dict of data to be used with ``sess.run``.
|
Constructs and returns the feed_dict of data to be used with ``sess.run``.
|
||||||
|
|
||||||
@@ -133,7 +133,7 @@ class DataCollector(object):
|
|||||||
sampled_index = self.data_buffer.sample(batch_size)
|
sampled_index = self.data_buffer.sample(batch_size)
|
||||||
if self.process_mode == 'sample':
|
if self.process_mode == 'sample':
|
||||||
for processor in self.process_functions:
|
for processor in self.process_functions:
|
||||||
self.data_batch.update(processor(self.data_buffer, indexes=sampled_index))
|
self.data_batch.update(processor(self.data_buffer, indexes=sampled_index, my_feed_dict=my_feed_dict))
|
||||||
|
|
||||||
# flatten rank-2 list to numpy array, construct feed_dict
|
# flatten rank-2 list to numpy array, construct feed_dict
|
||||||
feed_dict = {}
|
feed_dict = {}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user