Improve the documentation of compute_episodic_return in base policy. (#1130)

This commit is contained in:
bordeauxred 2024-04-30 14:40:16 +02:00 committed by GitHub
parent a65920fc68
commit 61426acf07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 8 deletions

View File

@ -266,3 +266,8 @@ postfix
backend
rliable
hl
v_s
v_s_
obs
obs_next

View File

@ -301,9 +301,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
:return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
* ``act`` an numpy.ndarray or a torch.Tensor, the action over \
* ``act`` a numpy.ndarray or a torch.Tensor, the action over \
given batch data.
* ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
* ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
internal state of the policy, ``None`` as default.
Other keys are user-defined. It depends on the algorithm. For example,
@ -556,6 +556,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
advantage + value, which is exactly equivalent to using :math:`TD(\lambda)`
for estimating returns.
Setting v_s_ and v_s to None (or all zeros) and gae_lambda to 1.0 calculates the
discounted return-to-go/ Monte-Carlo return.
:param batch: a data batch which contains several episodes of data in
sequential order. Mind that the end of each finished episode of batch
should be marked by done flag, unfinished (or collecting) episodes will be
@ -565,10 +568,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
to buffer[indices].
:param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
If None, it will be set to an array of 0.
:param v_s: the value function of all current states :math:`V(s)`.
:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param v_s: the value function of all current states :math:`V(s)`. If None,
it is set based upon v_s_ rolled by 1.
:param gamma: the discount factor, should be in [0, 1].
:param gae_lambda: the parameter for Generalized Advantage Estimation,
should be in [0, 1]. Default to 0.95.
should be in [0, 1].
:return: two numpy arrays (returns, advantage) with each shape (bsz, ).
"""
@ -612,10 +616,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
:param indices: tell batch's location in buffer
:param function target_q_fn: a function which compute target Q value
of "obs_next" given data buffer and wanted indices.
:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
:param gamma: the discount factor, should be in [0, 1].
:param n_step: the number of estimation step, should be an int greater
than 0. Default to 1.
:param rew_norm: normalize the reward to Normal(0, 1), Default to False.
than 0.
:param rew_norm: normalize the reward to Normal(0, 1).
TODO: passing True is not supported and will cause an error!
:return: a Batch. The result will be stored in batch.returns as a
torch.Tensor with the same shape as target_q_fn's return tensor.