Improve the documentation of compute_episodic_return in base policy. (#1130)

bordeauxred 2024-04-30 14:40:16 +02:00 committed by GitHub
parent a65920fc68
commit 61426acf07
2 changed files with 17 additions and 8 deletions

@@ -266,3 +266,8 @@ postfix
 backend
 rliable
 hl
+v_s
+v_s_
+obs
+obs_next

@@ -301,9 +301,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
 :return: A :class:`~tianshou.data.Batch` which MUST have the following keys:
-* ``act`` an numpy.ndarray or a torch.Tensor, the action over \
+* ``act`` a numpy.ndarray or a torch.Tensor, the action over \
     given batch data.
-* ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \
+* ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \
     internal state of the policy, ``None`` as default.
 Other keys are user-defined. It depends on the algorithm. For example,
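
For readers skimming the diff, a minimal, hypothetical sketch of the ``act``/``state`` contract described above (the values and the extra ``logits`` key are made up for illustration; only the required key names come from the docstring):

```python
import numpy as np
from tianshou.data import Batch

# Illustrative stand-in for what a policy's forward() is documented to return:
# "act" holds the actions for the given batch, "state" the internal policy
# state (None by default); further keys such as "logits" are algorithm-specific.
result = Batch(act=np.array([0, 2, 1]), state=None, logits=np.zeros((3, 4)))
print(result.act)    # [0 2 1]
print(result.state)  # None
```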
@@ -556,6 +556,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
 advantage + value, which is exactly equivalent to using :math:`TD(\lambda)`
 for estimating returns.
+Setting v_s_ and v_s to None (or all zeros) and gae_lambda to 1.0 calculates the
+discounted return-to-go / Monte-Carlo return.
 :param batch: a data batch which contains several episodes of data in
     sequential order. Mind that the end of each finished episode of batch
     should be marked by done flag, unfinished (or collecting) episodes will be
@@ -565,10 +568,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
     to buffer[indices].
 :param np.ndarray v_s_: the value function of all next states :math:`V(s')`.
     If None, it will be set to an array of 0.
-:param v_s: the value function of all current states :math:`V(s)`.
-:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+:param v_s: the value function of all current states :math:`V(s)`. If None,
+    it is set based upon v_s_ rolled by 1.
+:param gamma: the discount factor, should be in [0, 1].
 :param gae_lambda: the parameter for Generalized Advantage Estimation,
-    should be in [0, 1]. Default to 0.95.
+    should be in [0, 1].
 :return: two numpy arrays (returns, advantage) with each shape (bsz, ).
 """
@@ -612,10 +616,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC):
 :param indices: tell batch's location in buffer
 :param function target_q_fn: a function which compute target Q value
     of "obs_next" given data buffer and wanted indices.
-:param gamma: the discount factor, should be in [0, 1]. Default to 0.99.
+:param gamma: the discount factor, should be in [0, 1].
 :param n_step: the number of estimation step, should be an int greater
-    than 0. Default to 1.
-:param rew_norm: normalize the reward to Normal(0, 1), Default to False.
+    than 0.
+:param rew_norm: normalize the reward to Normal(0, 1).
     TODO: passing True is not supported and will cause an error!
 :return: a Batch. The result will be stored in batch.returns as a
     torch.Tensor with the same shape as target_q_fn's return tensor.
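
And a similar sketch for the n-step helper; the zero-valued dummy_target_q below is a made-up stand-in for a real target network, and the buffer contents are again invented:

```python
import numpy as np
import torch
from tianshou.data import Batch, ReplayBuffer
from tianshou.policy import BasePolicy

def dummy_target_q(buffer: ReplayBuffer, indices: np.ndarray) -> torch.Tensor:
    # Stand-in for a real target network: predicts Q = 0 for every "obs_next".
    return torch.zeros(len(indices))

buf = ReplayBuffer(size=10)
for i in range(5):
    buf.add(
        Batch(obs=i, act=0, rew=1.0, terminated=i == 4, truncated=False,
              obs_next=i + 1, info={})
    )
batch, indices = buf.sample(0)

# The n-step returns are written into batch.returns; rew_norm stays False,
# matching the TODO note above.
batch = BasePolicy.compute_nstep_return(
    batch, buf, indices, target_q_fn=dummy_target_q,
    gamma=0.99, n_step=3, rew_norm=False,
)
print(batch.returns.shape)  # one return value per sampled transition
```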