advantage estimation function all take my_feed_dict (all examples runnable); such requirement should be made a signature

This commit is contained in:
haoshengzou 2018-11-22 08:03:03 +08:00
parent c937630bd3
commit 909dc786d1
2 changed files with 4 additions and 3 deletions

1
.gitignore vendored
View File

@ -11,3 +11,4 @@ data
!tianshou/data
.log
go-*
*.egg-info/

View File

@ -15,7 +15,7 @@ DONE = 3
# TODO: add discount_factor... maybe make it to be a global config?
def full_return(buffer, indexes=None):
def full_return(buffer, indexes=None, my_feed_dict={}):
"""
Naively compute full undiscounted return on episodic data, :math:`G_t = \sum_{t=0}^T r_t`.
This function will print a warning when some of the episodes
@ -80,7 +80,7 @@ class nstep_return:
self.return_advantage = return_advantage
self.discount_factor = discount_factor
def __call__(self, buffer, indexes=None):
def __call__(self, buffer, indexes=None, my_feed_dict={}):
"""
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the specified return should be computed.
@ -153,7 +153,7 @@ class ddpg_return:
self.use_target_network = use_target_network
self.discount_factor = discount_factor
def __call__(self, buffer, indexes=None):
def __call__(self, buffer, indexes=None, my_feed_dict={}):
"""
:param buffer: A :class:`tianshou.data.data_buffer`.
:param indexes: Optional. Indexes of data points on which the specified return should be computed.