import tensorflow as tf


def ppo_clip(policy, clip_param):
    """
    The clipped surrogate loss from the PPO paper, Eqn. (7).

    Placeholders for the actions sampled during interaction with the
    environment and for the estimated advantage values are created here and
    registered in ``policy.managed_placeholders``; the old policy entering the
    probability ratio is accessed through ``policy.log_prob_old``.

    :param policy: current `policy` to be optimized.
    :param clip_param: float or Tensor of type float, the clipping range epsilon.

    :return: scalar Tensor of the PPO clip loss.
    """
    action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape,
                               name='ppo_clip_loss/action_placeholder')
    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
    policy.managed_placeholders['action'] = action_ph
    policy.managed_placeholders['advantage'] = advantage_ph

    # probability ratio r_t = pi(a_t|s_t) / pi_old(a_t|s_t), computed in log space
    log_pi_act = policy.log_prob(action_ph)
    log_pi_old_act = policy.log_prob_old(action_ph)
    ratio = tf.exp(log_pi_act - log_pi_old_act)
    # clipped surrogate E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)], negated to form a loss
    clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
    ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage_ph, clipped_ratio * advantage_ph))
    return ppo_clip_loss
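
# Usage sketch (illustrative, not part of the library API): assumes a `policy`
# object with the attributes used above, plus a TF1 session; the optimizer and
# hyper-parameters are hypothetical.
#
#   loss = ppo_clip(policy, clip_param=0.2)
#   train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
#   sess.run(train_op, feed_dict={
#       policy.managed_placeholders['action']: sampled_actions,
#       policy.managed_placeholders['advantage']: estimated_advantages,
#   })  # the policy's own observation placeholder must be fed as well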


def REINFORCE(policy):
    """
    Vanilla policy gradient (REINFORCE) loss.

    Placeholders for the actions sampled during interaction with the
    environment and for the advantage values (e.g. the returns obtained by
    those actions, optionally with a baseline subtracted to reduce variance)
    are created here and registered in ``policy.managed_placeholders``.

    :param policy: current `policy` to be optimized.

    :return: scalar Tensor of the REINFORCE loss.
    """
    action_ph = tf.placeholder(policy.act_dtype, shape=(None,) + policy.action_shape,
                               name='REINFORCE/action_placeholder')
    advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='REINFORCE/advantage_placeholder')
    policy.managed_placeholders['action'] = action_ph
    policy.managed_placeholders['advantage'] = advantage_ph

    # surrogate loss -E[A_t * log pi(a_t|s_t)]; its gradient is the policy gradient
    log_pi_act = policy.log_prob(action_ph)
    REINFORCE_loss = -tf.reduce_mean(advantage_ph * log_pi_act)
    return REINFORCE_loss
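
# Usage sketch (illustrative): the advantage fed here is typically the
# discounted return of each sampled action, optionally with a baseline
# subtracted; `policy`, `sess` and the batch arrays are assumed to exist.
#
#   loss = REINFORCE(policy)
#   train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
#   sess.run(train_op, feed_dict={
#       policy.managed_placeholders['action']: sampled_actions,
#       policy.managed_placeholders['advantage']: discounted_returns,
#   })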


def state_value_mse(state_value_function):
    """
    L2 (mean squared error) loss of the state value.

    :param state_value_function: instance of StateValue.

    :return: scalar Tensor of the MSE loss.
    """
    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='state_value_mse/state_value_placeholder')
    state_value_function.managed_placeholders['return'] = target_value_ph

    state_value = state_value_function.value_tensor
    return tf.losses.mean_squared_error(target_value_ph, state_value)
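
# Usage sketch (illustrative): the 'return' placeholder is fed with the
# regression targets for V(s), e.g. Monte-Carlo or bootstrapped returns;
# `value_net` and `sess` are hypothetical names.
#
#   critic_loss = state_value_mse(value_net)
#   train_op = tf.train.AdamOptimizer(1e-3).minimize(critic_loss)
#   sess.run(train_op, feed_dict={
#       value_net.managed_placeholders['return']: target_returns,
#   })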


def qlearning(action_value_function):
    """
    Q-learning (DQN) regression loss: mean squared error between the predicted
    Q-values and the externally computed target values.

    :param action_value_function: current `action_value` to be optimized.

    :return: scalar Tensor of the MSE loss.
    """
    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='qlearning/action_value_placeholder')
    action_value_function.managed_placeholders['return'] = target_value_ph

    q_value = action_value_function.value_tensor
    return tf.losses.mean_squared_error(target_value_ph, q_value)
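
# Usage sketch (illustrative): the 'return' placeholder is fed with Bellman
# targets computed outside the graph, e.g. r + gamma * max_a' Q_target(s', a');
# `q_net`, `sess` and the batch arrays are hypothetical names.
#
#   dqn_loss = qlearning(q_net)
#   train_op = tf.train.AdamOptimizer(1e-3).minimize(dqn_loss)
#   targets = rewards + gamma * (1. - dones) * next_q_max
#   sess.run(train_op, feed_dict={q_net.managed_placeholders['return']: targets})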


def deterministic_policy_gradient(sampled_state, critic):
    """
    Deterministic policy gradient: the critic's value at the sampled states,
    averaged over the batch.

    :param sampled_state: placeholder of states sampled during the interaction
        with the environment.
    :param critic: current `value` function.

    :return: scalar Tensor, the mean critic value; ascending its gradient with
        respect to the policy parameters follows the deterministic policy gradient.
    """
    return tf.reduce_mean(critic.get_value(sampled_state))
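
# Usage sketch (illustrative): in an actor-critic setup such as DDPG, the actor
# is typically trained by maximizing this quantity, i.e. minimizing its
# negation; `state_ph`, `critic_q` and `actor_vars` are hypothetical names.
#
#   dpg_objective = deterministic_policy_gradient(state_ph, critic_q)
#   actor_train_op = tf.train.AdamOptimizer(1e-4).minimize(
#       -dpg_objective, var_list=actor_vars)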