Tianshou/tianshou/core/losses.py

import tensorflow as tf


def ppo_clip(policy, clip_param):
    r"""
    Builds the graph of the clipped surrogate loss :math:`L^{CLIP}` as in the
    `PPO paper <https://arxiv.org/pdf/1707.06347.pdf>`_, which is basically
    :math:`-\min(r_t(\theta)\hat{A}_t, \mathrm{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon)\hat{A}_t)`.
    We minimize the objective instead of maximizing it, hence the leading negative sign.

    :param policy: A :class:`tianshou.core.policy` to be optimized.
    :param clip_param: A float or Tensor of type float. The :math:`\epsilon` in the loss equation.

    :return: A scalar float Tensor of the loss.
    """
    action_ph = tf.placeholder(policy.action.dtype, shape=policy.action.shape,
                               name='ppo_clip_loss/action_placeholder')
    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
    policy.managed_placeholders['action'] = action_ph
    policy.managed_placeholders['advantage'] = advantage_ph

    # Probability ratio r_t(theta) = pi(a|s) / pi_old(a|s), computed in log space for stability.
    log_pi_act = policy.action_dist.log_prob(action_ph)
    log_pi_old_act = policy.action_dist_old.log_prob(action_ph)
    ratio = tf.exp(log_pi_act - log_pi_old_act)
    clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage_ph, clipped_ratio * advantage_ph))
    return ppo_clip_loss
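
# A minimal usage sketch (assumptions: `make_policy()` and the optimizer setup below are
# hypothetical and not part of this module; `policy` stands for any object exposing `action`,
# `action_dist`, `action_dist_old`, and `managed_placeholders` as used above):
#
#     policy = make_policy()
#     loss = ppo_clip(policy, clip_param=0.2)
#     train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)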


def REINFORCE(policy):
    r"""
    Builds the graph of the loss function as used in vanilla policy gradient algorithms, i.e., REINFORCE.
    The loss is basically :math:`-\log \pi(a|s) A_t`.
    We minimize the objective instead of maximizing it, hence the leading negative sign.

    :param policy: A :class:`tianshou.core.policy` to be optimized.

    :return: A scalar float Tensor of the loss.
    """
    action_ph = tf.placeholder(policy.action.dtype, shape=policy.action.shape,
                               name='REINFORCE/action_placeholder')
    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='REINFORCE/advantage_placeholder')
    policy.managed_placeholders['action'] = action_ph
    policy.managed_placeholders['advantage'] = advantage_ph

    log_pi_act = policy.action_dist.log_prob(action_ph)
    REINFORCE_loss = -tf.reduce_mean(advantage_ph * log_pi_act)
    return REINFORCE_loss
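
# Sketch of how the managed placeholders are meant to be fed at training time
# (assumptions: `sampled_actions` and `estimated_advantages` are hypothetical numpy arrays
# produced by a rollout/advantage-estimation step elsewhere; `train_op` and `sess` are set up
# by the caller):
#
#     loss = REINFORCE(policy)
#     feed_dict = {policy.managed_placeholders['action']: sampled_actions,
#                  policy.managed_placeholders['advantage']: estimated_advantages}
#     sess.run(train_op, feed_dict=feed_dict)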


def value_mse(value_function):
    """
    Builds the graph of the L2 loss on value functions for, e.g., training critics or DQN.

    :param value_function: A :class:`tianshou.core.value_function` to be optimized.

    :return: A scalar float Tensor of the loss.
    """
    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='value_mse/return_placeholder')
    value_function.managed_placeholders['return'] = target_value_ph
    state_value = value_function.value_tensor
    return tf.losses.mean_squared_error(target_value_ph, state_value)
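
# A minimal sketch of training a critic with this loss (assumptions: `critic` stands for a
# `tianshou.core.value_function` object exposing `value_tensor` and `managed_placeholders`;
# `observed_returns` is a hypothetical numpy array of empirical returns; `sess` is an existing
# tf.Session):
#
#     critic_loss = value_mse(critic)
#     critic_train_op = tf.train.AdamOptimizer(1e-3).minimize(critic_loss)
#     sess.run(critic_train_op,
#              feed_dict={critic.managed_placeholders['return']: observed_returns})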