Improve the documentation of compute_episodic_return in base policy. (#1130)
commit 61426acf07
parent a65920fc68
@@ -266,3 +266,8 @@ postfix
 backend
 rliable
 hl
+v_s
+v_s_
+obs
+obs_next
@@ -301,9 +301,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):

 :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:

-    * ``act`` an numpy.ndarray or a torch.Tensor, the action over \
+    * ``act`` a numpy.ndarray or a torch.Tensor, the action over \
       given batch data.
-    * ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
+    * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
       internal state of the policy, ``None`` as default.

 Other keys are user-defined. It depends on the algorithm. For example,
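The two corrected bullet points describe the contract of ``forward``. For context, a minimal sketch of a concrete ``forward`` that satisfies it, assuming a hypothetical discrete-action policy with an ``actor_net`` attribute (the names and the greedy action selection are illustrative, not part of this commit):

    import torch
    from tianshou.data import Batch

    def forward(self, batch: Batch, state=None, **kwargs) -> Batch:
        # "act" must be a numpy.ndarray or a torch.Tensor; "state" carries the
        # (possibly None) internal state of the policy, as documented above.
        logits = self.actor_net(torch.as_tensor(batch.obs, dtype=torch.float32))
        act = logits.argmax(dim=-1)  # greedy action over the given batch data
        return Batch(logits=logits, act=act, state=None)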
@@ -556,6 +556,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
 advantage + value, which is exactly equivalent to using :math:`TD(\lambda)`
 for estimating returns.

+Setting v_s_ and v_s to None (or all zeros) and gae_lambda to 1.0 calculates the
+discounted return-to-go/ Monte-Carlo return.
+
 :param batch: a data batch which contains several episodes of data in
 sequential order. Mind that the end of each finished episode of batch
 should be marked by done flag, unfinished (or collecting) episodes will be
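The sentence added here can be illustrated with a short usage sketch; ``batch``, ``buffer`` and ``indices`` are assumed to come from a ReplayBuffer that already holds finished episodes (placeholder names, not taken from this commit):

    from tianshou.policy import BasePolicy

    # With v_s_=None (treated as zeros) and gae_lambda=1.0, the GAE-based
    # estimate reduces to the plain discounted return-to-go (Monte-Carlo return).
    returns, advantages = BasePolicy.compute_episodic_return(
        batch, buffer, indices, v_s_=None, v_s=None, gamma=0.99, gae_lambda=1.0
    )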
@@ -565,10 +568,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
 to buffer[indices].
 :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
 If None, it will be set to an array of 0.
-:param v_s: the value function of all current states :math:`V(s)`.
-:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+:param v_s: the value function of all current states :math:`V(s)`. If None,
+it is set based upon v_s_ rolled by 1.
+:param gamma: the discount factor, should be in [0, 1].
 :param gae_lambda: the parameter for Generalized Advantage Estimation,
-should be in [0, 1]. Default to 0.95.
+should be in [0, 1].

 :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
 """
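"Rolled by 1" in the new v_s description means shifting the next-state values one step so that each position lines up with its current state; a small numpy illustration of that wording (a sketch of the idea, not a quote of the implementation):

    import numpy as np

    # v_s_[i] approximates V(s'_i); np.roll shifts every element one position
    # to the right (the last element wraps to the front), aligning the values
    # with the current states.
    v_s_ = np.array([1.0, 2.0, 3.0, 4.0])
    v_s = np.roll(v_s_, 1)  # -> array([4.0, 1.0, 2.0, 3.0])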
@@ -612,10 +616,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
 :param indices: tell batch's location in buffer
 :param function target_q_fn: a function which compute target Q value
 of "obs_next" given data buffer and wanted indices.
-:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+:param gamma: the discount factor, should be in [0, 1].
 :param n_step: the number of estimation step, should be an int greater
-than 0. Default to 1.
-:param rew_norm: normalize the reward to Normal(0, 1), Default to False.
+than 0.
+:param rew_norm: normalize the reward to Normal(0, 1).
 TODO: passing True is not supported and will cause an error!
 :return: a Batch. The result will be stored in batch.returns as a
 torch.Tensor with the same shape as target_q_fn's return tensor.
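For reference, the quantity stored in batch.returns is the usual n-step bootstrapped target; below is a self-contained per-episode sketch of that computation (a plain numpy loop under simplifying assumptions, not the vectorized buffer-based implementation):

    import numpy as np

    def nstep_return(rew: np.ndarray, bootstrap_q: np.ndarray,
                     gamma: float, n_step: int) -> np.ndarray:
        """Illustrative n-step target for one finished episode:
        G_t = rew[t] + gamma*rew[t+1] + ... + gamma**(n_step-1)*rew[t+n_step-1]
              + gamma**n_step * bootstrap_q[t+n_step]   (bootstrap only inside the episode).
        """
        T = len(rew)
        returns = np.zeros(T)
        for t in range(T):
            g, discount = 0.0, 1.0
            for k in range(t, min(t + n_step, T)):
                g += discount * rew[k]       # accumulate discounted rewards
                discount *= gamma
            if t + n_step < T:
                g += discount * bootstrap_q[t + n_step]  # bootstrap with target Q
            returns[t] = g
        return returns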