vanilla policy gradient

commit 0c4a83f3eb (parent a00b930c2c)
@@ -20,5 +20,23 @@ def ppo_clip(sampled_action, advantage, clip_param, pi, pi_old):
     return ppo_clip_loss


-def vanilla_policy_gradient():
-    pass
+def vanilla_policy_gradient(sampled_action, reward, pi, baseline="None"):
+    """
+    vanilla policy gradient
+
+    :param sampled_action: placeholder of actions sampled while interacting with the environment
+    :param reward: placeholder of the rewards received for the 'sampled_action'
+    :param pi: the current 'policy' to be optimized
+    :param baseline: the baseline method used to reduce variance, default is 'None'
+    :return: the vanilla policy gradient loss
+    """
+    log_pi_act = pi.log_prob(sampled_action)
+    vanilla_policy_gradient_loss = tf.reduce_mean(reward * log_pi_act)
+    # TODO: support different baseline methods (e.g. REINFORCE with a learned baseline)
+    return vanilla_policy_gradient_loss
+
+
+def temporal_difference_loss():
+    pass
+
+
+def deterministic_policy_gradient():
+    pass
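The added loss is the score-function form of the policy gradient: it averages reward * log pi(a|s) over the sampled batch, so differentiating it reproduces E[R * grad log pi(a|s)]. A minimal usage sketch in TensorFlow 1.x graph mode, with tf.distributions.Categorical standing in for the repository's OnehotCategorical policy; the placeholder shapes, the two-action network, and the mean-subtracted baseline are illustrative assumptions, not part of this commit:

    import tensorflow as tf

    observation_ph = tf.placeholder(tf.float32, [None, 4])   # rollout observations
    sampled_action_ph = tf.placeholder(tf.int32, [None])     # actions taken in the rollout
    reward_ph = tf.placeholder(tf.float32, [None])           # (discounted) returns for those actions

    logits = tf.layers.dense(observation_ph, 2)               # hypothetical 2-action policy network
    pi = tf.distributions.Categorical(logits=logits)          # stand-in for OnehotCategorical

    # The TODO mentions baselines: subtracting a constant such as the batch-mean
    # return reduces variance without biasing the gradient estimate.
    centered_reward = reward_ph - tf.reduce_mean(reward_ph)

    pg_objective = vanilla_policy_gradient(sampled_action_ph, centered_reward, pi)
    # The function returns the objective E[R * log pi(a)], so a minimizing
    # optimizer should minimize its negation.
    train_op = tf.train.AdamOptimizer(1e-3).minimize(-pg_objective)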
@@ -148,20 +148,6 @@ class StochasticPolicy(object):
         """
         return self._act(observation)

-        if n_samples is None:
-            samples = self._sample(n_samples=1)
-            return tf.squeeze(samples, axis=0)
-        elif isinstance(n_samples, int):
-            return self._sample(n_samples)
-        else:
-            n_samples = tf.convert_to_tensor(n_samples, dtype=tf.int32)
-            _assert_rank_op = tf.assert_rank(
-                n_samples, 0,
-                message="n_samples should be a scalar (0-D Tensor).")
-            with tf.control_dependencies([_assert_rank_op]):
-                samples = self._sample(n_samples)
-                return samples
-
     def _act(self, observation):
         """
         Private method for subclasses to rewrite the :meth:`sample` method.
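The removed branch relied on a common TensorFlow 1.x validation idiom: convert n_samples to a tensor, assert that it is a scalar with tf.assert_rank, and gate the sampling op on that assertion through tf.control_dependencies. A standalone sketch of the idiom; tf.random_uniform stands in for the policy's _sample and the values are arbitrary:

    import tensorflow as tf

    n_samples = tf.convert_to_tensor(5, dtype=tf.int32)
    assert_scalar = tf.assert_rank(
        n_samples, 0,
        message="n_samples should be a scalar (0-D Tensor).")

    # The sampling op only runs once the rank assertion has passed.
    with tf.control_dependencies([assert_scalar]):
        samples = tf.random_uniform(tf.stack([n_samples]))   # stand-in for self._sample(n_samples)

    with tf.Session() as sess:
        print(sess.run(samples).shape)   # (5,)

After this change, sample() simply delegates to the subclass's _act(observation) and no longer accepts n_samples.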
@@ -76,7 +76,6 @@ class OnehotCategorical(StochasticPolicy):
         return sampled_action

     def _log_prob(self, sampled_action):
-        sampled_action_onehot = tf.one_hot(sampled_action, self.n_categories, dtype=self.act_dtype)
         return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=sampled_action, logits=self.logits)

         # given = tf.cast(given, self.param_dtype)