diff --git a/tianshou/core/losses.py b/tianshou/core/losses.py
index 5e127c2..f7d798b 100644
--- a/tianshou/core/losses.py
+++ b/tianshou/core/losses.py
@@ -20,5 +20,23 @@ def ppo_clip(sampled_action, advantage, clip_param, pi, pi_old):
     return ppo_clip_loss
 
-def vanilla_policy_gradient():
+def vanilla_policy_gradient(sampled_action, reward, pi, baseline=None):
+    """
+    Vanilla policy gradient loss.
+
+    :param sampled_action: placeholder for the actions sampled during interaction with the environment
+    :param reward: placeholder for the reward obtained by 'sampled_action'
+    :param pi: the current policy to be optimized
+    :param baseline: the baseline method used to reduce variance, defaults to None
+    :return: the scalar vanilla policy gradient loss
+    """
+    log_pi_act = pi.log_prob(sampled_action)
+    vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act)
+    # TODO: support different baseline methods (e.g. a value function baseline) to reduce variance
+    return vanilla_policy_gradient_loss
+
+def temporal_difference_loss():
     pass
+
+def deterministic_policy_gradient():
+    pass
\ No newline at end of file
diff --git a/tianshou/core/policy/base.py b/tianshou/core/policy/base.py
index a61661c..b0bf28a 100644
--- a/tianshou/core/policy/base.py
+++ b/tianshou/core/policy/base.py
@@ -148,20 +148,6 @@ class StochasticPolicy(object):
         """
         return self._act(observation)
 
-        if n_samples is None:
-            samples = self._sample(n_samples=1)
-            return tf.squeeze(samples, axis=0)
-        elif isinstance(n_samples, int):
-            return self._sample(n_samples)
-        else:
-            n_samples = tf.convert_to_tensor(n_samples, dtype=tf.int32)
-            _assert_rank_op = tf.assert_rank(
-                n_samples, 0,
-                message="n_samples should be a scalar (0-D Tensor).")
-            with tf.control_dependencies([_assert_rank_op]):
-                samples = self._sample(n_samples)
-            return samples
-
     def _act(self, observation):
         """
         Private method for subclasses to rewrite the :meth:`sample` method.
diff --git a/tianshou/core/policy/stochastic.py b/tianshou/core/policy/stochastic.py
index 822600a..37eb1be 100644
--- a/tianshou/core/policy/stochastic.py
+++ b/tianshou/core/policy/stochastic.py
@@ -76,7 +76,6 @@ class OnehotCategorical(StochasticPolicy):
         return sampled_action
 
     def _log_prob(self, sampled_action):
-        sampled_action_onehot = tf.one_hot(sampled_action, self.n_categories, dtype=self.act_dtype)
         return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits)
         # given = tf.cast(given, self.param_dtype)
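
Usage note (not part of the patch): a minimal sketch of wiring the new vanilla_policy_gradient loss into a TF1-style training graph. The only thing the loss assumes about `pi` is a `log_prob()` method, so a thin wrapper around tf.distributions.Categorical stands in for a tianshou policy here; the wrapper class, placeholder shapes, and network sizes are illustrative, not part of the library API.

# Hypothetical usage sketch for the new loss; assumes TensorFlow 1.x.
import tensorflow as tf
from tianshou.core.losses import vanilla_policy_gradient

class _CategoricalPi(object):
    """Illustrative stand-in exposing the log_prob() interface the loss expects."""
    def __init__(self, logits):
        self._dist = tf.distributions.Categorical(logits=logits)

    def log_prob(self, sampled_action):
        return self._dist.log_prob(sampled_action)

observation_ph = tf.placeholder(tf.float32, shape=[None, 4])   # example observation dim
sampled_action_ph = tf.placeholder(tf.int32, shape=[None])
reward_ph = tf.placeholder(tf.float32, shape=[None])

logits = tf.layers.dense(observation_ph, 2)                     # example 2-action policy head
pi = _CategoricalPi(logits)

# The function returns E[reward * log pi(a|s)], i.e. the objective to ascend,
# so one way to train is gradient ascent via minimizing its negation.
objective = vanilla_policy_gradient(sampled_action_ph, reward_ph, pi)
train_op = tf.train.AdamOptimizer(1e-3).minimize(-objective)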