advantage estimation function all take my_feed_dict (all examples runnable); such requirement should be made a signature
This commit is contained in:
parent
c937630bd3
commit
909dc786d1
1
.gitignore
vendored
1
.gitignore
vendored
@ -11,3 +11,4 @@ data
|
||||
!tianshou/data
|
||||
.log
|
||||
go-*
|
||||
*.egg-info/
|
@ -15,7 +15,7 @@ DONE = 3
|
||||
|
||||
|
||||
# TODO: add discount_factor... maybe make it to be a global config?
|
||||
def full_return(buffer, indexes=None):
|
||||
def full_return(buffer, indexes=None, my_feed_dict={}):
|
||||
"""
|
||||
Naively compute full undiscounted return on episodic data, :math:`G_t = \sum_{t=0}^T r_t`.
|
||||
This function will print a warning when some of the episodes
|
||||
@ -80,7 +80,7 @@ class nstep_return:
|
||||
self.return_advantage = return_advantage
|
||||
self.discount_factor = discount_factor
|
||||
|
||||
def __call__(self, buffer, indexes=None):
|
||||
def __call__(self, buffer, indexes=None, my_feed_dict={}):
|
||||
"""
|
||||
:param buffer: A :class:`tianshou.data.data_buffer`.
|
||||
:param indexes: Optional. Indexes of data points on which the specified return should be computed.
|
||||
@ -153,7 +153,7 @@ class ddpg_return:
|
||||
self.use_target_network = use_target_network
|
||||
self.discount_factor = discount_factor
|
||||
|
||||
def __call__(self, buffer, indexes=None):
|
||||
def __call__(self, buffer, indexes=None, my_feed_dict={}):
|
||||
"""
|
||||
:param buffer: A :class:`tianshou.data.data_buffer`.
|
||||
:param indexes: Optional. Indexes of data points on which the specified return should be computed.
|
||||
|
Loading…
x
Reference in New Issue
Block a user