diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 383429e..83de823 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -266,3 +266,8 @@ postfix backend rliable hl +v_s +v_s_ +obs +obs_next + diff --git a/tianshou/policy/base.py b/tianshou/policy/base.py index 77602a0..1462ff4 100644 --- a/tianshou/policy/base.py +++ b/tianshou/policy/base.py @@ -301,9 +301,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): :return: A :class:`~tianshou.data.Batch` which MUST have the following keys: - * ``act`` an numpy.ndarray or a torch.Tensor, the action over \ + * ``act`` a numpy.ndarray or a torch.Tensor, the action over \ given batch data. - * ``state`` a dict, an numpy.ndarray or a torch.Tensor, the \ + * ``state`` a dict, a numpy.ndarray or a torch.Tensor, the \ internal state of the policy, ``None`` as default. Other keys are user-defined. It depends on the algorithm. For example, @@ -556,6 +556,9 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): advantage + value, which is exactly equivalent to using :math:`TD(\lambda)` for estimating returns. + Setting v_s_ and v_s to None (or all zeros) and gae_lambda to 1.0 calculates the + discounted return-to-go/ Monte-Carlo return. + :param batch: a data batch which contains several episodes of data in sequential order. Mind that the end of each finished episode of batch should be marked by done flag, unfinished (or collecting) episodes will be @@ -565,10 +568,11 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): to buffer[indices]. :param np.ndarray v_s_: the value function of all next states :math:`V(s')`. If None, it will be set to an array of 0. - :param v_s: the value function of all current states :math:`V(s)`. - :param gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param v_s: the value function of all current states :math:`V(s)`. If None, + it is set based upon v_s_ rolled by 1. + :param gamma: the discount factor, should be in [0, 1]. :param gae_lambda: the parameter for Generalized Advantage Estimation, - should be in [0, 1]. Default to 0.95. + should be in [0, 1]. :return: two numpy arrays (returns, advantage) with each shape (bsz, ). """ @@ -612,10 +616,10 @@ class BasePolicy(nn.Module, Generic[TTrainingStats], ABC): :param indices: tell batch's location in buffer :param function target_q_fn: a function which compute target Q value of "obs_next" given data buffer and wanted indices. - :param gamma: the discount factor, should be in [0, 1]. Default to 0.99. + :param gamma: the discount factor, should be in [0, 1]. :param n_step: the number of estimation step, should be an int greater - than 0. Default to 1. - :param rew_norm: normalize the reward to Normal(0, 1), Default to False. + than 0. + :param rew_norm: normalize the reward to Normal(0, 1). TODO: passing True is not supported and will cause an error! :return: a Batch. The result will be stored in batch.returns as a torch.Tensor with the same shape as target_q_fn's return tensor.