diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py index 3b6d930..8026c7e 100644 --- a/tianshou/data/advantage_estimation.py +++ b/tianshou/data/advantage_estimation.py @@ -216,7 +216,7 @@ class nstep_q_return: self.use_target_network = use_target_network self.discount_factor = discount_factor - def __call__(self, buffer, indexes=None): + def __call__(self, buffer, indexes=None, my_feed_dict={}): """ :param buffer: A :class:`tianshou.data.data_buffer`. :param indexes: Optional. Indexes of data points on which the full return should be computed. @@ -249,9 +249,9 @@ class nstep_q_return: state = episode[last_frame_index + 1][STATE] if self.use_target_network: # [None] adds one dimension to the beginning - qpredict = self.action_value.eval_value_all_actions_old(state[None]) + qpredict = self.action_value.eval_value_all_actions_old(state[None], my_feed_dict=my_feed_dict) else: - qpredict = self.action_value.eval_value_all_actions(state[None]) + qpredict = self.action_value.eval_value_all_actions(state[None], my_feed_dict=my_feed_dict) target_q += current_discount_factor * max(qpredict[0]) episode_q.append(target_q) diff --git a/tianshou/data/data_collector.py b/tianshou/data/data_collector.py index 67d346c..ea5db51 100644 --- a/tianshou/data/data_collector.py +++ b/tianshou/data/data_collector.py @@ -118,7 +118,7 @@ class DataCollector(object): return - def next_batch(self, batch_size, standardize_advantage=True): + def next_batch(self, batch_size, standardize_advantage=True, my_feed_dict={}): """ Constructs and returns the feed_dict of data to be used with ``sess.run``. @@ -133,7 +133,7 @@ class DataCollector(object): sampled_index = self.data_buffer.sample(batch_size) if self.process_mode == 'sample': for processor in self.process_functions: - self.data_batch.update(processor(self.data_buffer, indexes=sampled_index)) + self.data_batch.update(processor(self.data_buffer, indexes=sampled_index, my_feed_dict=my_feed_dict)) # flatten rank-2 list to numpy array, construct feed_dict feed_dict = {}