diff --git a/.gitignore b/.gitignore index d0168e0..4d6390b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ data !tianshou/data .log go-* +*.egg-info/ \ No newline at end of file diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py index 8026c7e..02dcd2e 100644 --- a/tianshou/data/advantage_estimation.py +++ b/tianshou/data/advantage_estimation.py @@ -15,7 +15,7 @@ DONE = 3 # TODO: add discount_factor... maybe make it to be a global config? -def full_return(buffer, indexes=None): +def full_return(buffer, indexes=None, my_feed_dict={}): """ Naively compute full undiscounted return on episodic data, :math:`G_t = \sum_{t=0}^T r_t`. This function will print a warning when some of the episodes @@ -80,7 +80,7 @@ class nstep_return: self.return_advantage = return_advantage self.discount_factor = discount_factor - def __call__(self, buffer, indexes=None): + def __call__(self, buffer, indexes=None, my_feed_dict={}): """ :param buffer: A :class:`tianshou.data.data_buffer`. :param indexes: Optional. Indexes of data points on which the specified return should be computed. @@ -153,7 +153,7 @@ class ddpg_return: self.use_target_network = use_target_network self.discount_factor = discount_factor - def __call__(self, buffer, indexes=None): + def __call__(self, buffer, indexes=None, my_feed_dict={}): """ :param buffer: A :class:`tianshou.data.data_buffer`. :param indexes: Optional. Indexes of data points on which the specified return should be computed.