From 909dc786d1d6f1db96441f71a0fcd0283f705197 Mon Sep 17 00:00:00 2001 From: haoshengzou Date: Thu, 22 Nov 2018 08:03:03 +0800 Subject: [PATCH] advantage estimation functions all take my_feed_dict (so all examples are runnable); this requirement should be made part of the signature --- .gitignore | 1 + tianshou/data/advantage_estimation.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index d0168e0..4d6390b 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ data !tianshou/data .log go-* +*.egg-info/ \ No newline at end of file diff --git a/tianshou/data/advantage_estimation.py b/tianshou/data/advantage_estimation.py index 8026c7e..02dcd2e 100644 --- a/tianshou/data/advantage_estimation.py +++ b/tianshou/data/advantage_estimation.py @@ -15,7 +15,7 @@ DONE = 3 # TODO: add discount_factor... maybe make it to be a global config? -def full_return(buffer, indexes=None): +def full_return(buffer, indexes=None, my_feed_dict={}): """ Naively compute full undiscounted return on episodic data, :math:`G_t = \sum_{t=0}^T r_t`. This function will print a warning when some of the episodes @@ -80,7 +80,7 @@ class nstep_return: self.return_advantage = return_advantage self.discount_factor = discount_factor - def __call__(self, buffer, indexes=None): + def __call__(self, buffer, indexes=None, my_feed_dict={}): """ :param buffer: A :class:`tianshou.data.data_buffer`. :param indexes: Optional. Indexes of data points on which the specified return should be computed. @@ -153,7 +153,7 @@ class ddpg_return: self.use_target_network = use_target_network self.discount_factor = discount_factor - def __call__(self, buffer, indexes=None): + def __call__(self, buffer, indexes=None, my_feed_dict={}): """ :param buffer: A :class:`tianshou.data.data_buffer`. :param indexes: Optional. Indexes of data points on which the specified return should be computed.