more API docs

haoshengzou 2018-04-15 09:35:31 +08:00
parent 2a3bc3ef35
commit 9186dae6a3
5 changed files with 169 additions and 80 deletions

View File

@@ -39,7 +39,7 @@ if __name__ == '__main__':
dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, has_old_net=True)
pi = policy.DQN(dqn)
dqn_loss = losses.value_mse(dqn)
total_loss = dqn_loss
optimizer = tf.train.AdamOptimizer(1e-4)
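# A possible continuation of the snippet above (a sketch, not part of this commit):
# the loss is typically minimized over the value network's trainable variables,
# exposed by the `trainable_variables` property documented later in this commit.
train_op = optimizer.minimize(total_loss, var_list=list(dqn.trainable_variables))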

View File

@@ -3,13 +3,15 @@ import tensorflow as tf
def ppo_clip(policy, clip_param):
"""
Builds the graph of the clipped surrogate loss :math:`L^{CLIP}` as in the
`PPO paper <https://arxiv.org/pdf/1707.06347.pdf>`_, which is basically
:math:`-\min(r_t(\\theta)\hat{A}_t, \text{clip}(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)\hat{A}_t)`.
We minimize the objective instead of maximizing it, hence the leading negative sign.
:param policy: A :class:`tianshou.core.policy` to be optimized.
:param clip_param: A float or Tensor of type float. The :math:`\epsilon` in the loss equation.
:return: A scalar float Tensor of the loss.
"""
action_ph = tf.placeholder(policy.action.dtype, shape=policy.action.shape, name='ppo_clip_loss/action_placeholder')
advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
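# Usage sketch (an assumption, not from this commit): `actor` stands for a
# tianshou.core.policy instance; the placeholders created inside are meant to be
# fed with sampled actions and advantage estimates at training time.
ppo_loss = ppo_clip(actor, clip_param=0.2)  # scalar Tensor, already negated for minimization
train_op = tf.train.AdamOptimizer(1e-4).minimize(ppo_loss)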
@@ -26,13 +28,13 @@ def ppo_clip(policy, clip_param):
def REINFORCE(policy):
"""
Builds the graph of the loss function used in vanilla policy gradient algorithms, i.e., REINFORCE.
The loss is basically :math:`-\log \pi(a|s) A^t`.
We minimize the objective instead of maximizing it, hence the leading negative sign.
:param policy: A :class:`tianshou.core.policy` to be optimized.
:return: A scalar float Tensor of the loss.
"""
action_ph = tf.placeholder(policy.action.dtype, shape=policy.action.shape,
name='REINFORCE/action_placeholder')
@@ -45,27 +47,16 @@ def REINFORCE(policy):
return REINFORCE_loss
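# Usage sketch (an assumption, not from this commit): `actor` is a
# tianshou.core.policy instance; the action and advantage placeholders created
# inside REINFORCE are meant to be fed at training time.
pg_loss = REINFORCE(actor)  # scalar Tensor of -log pi(a|s) * A^t
train_op = tf.train.AdamOptimizer(1e-3).minimize(pg_loss)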
def value_mse(value_function):
"""
Builds the graph of the L2 loss on value functions, e.g., for training critics or DQN.
:param value_function: A :class:`tianshou.core.value_function` to be optimized.
:return: A scalar float Tensor of the loss.
"""
target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='value_mse/return_placeholder')
value_function.managed_placeholders['return'] = target_value_ph
state_value = value_function.value_tensor
return tf.losses.mean_squared_error(target_value_ph, state_value)
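# Usage sketch for a state-value critic (an assumption; `my_network` and
# `observation_ph` reuse the names from the DQN example at the top of this commit):
critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
critic_loss = value_mse(critic)  # MSE between the fed returns and V(s)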
def qlearning(action_value_function):
"""
Builds the graph of the L2 loss on action values, as used in deep Q-networks (DQN).
:param action_value_function: A :class:`tianshou.core.value_function.ActionValue` to be optimized.
:return: A scalar float Tensor of the loss.
"""
target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='qlearning/action_value_placeholder')
action_value_function.managed_placeholders['return'] = target_value_ph
q_value = action_value_function.value_tensor
return tf.losses.mean_squared_error(target_value_ph, q_value)

View File

@@ -3,10 +3,12 @@ import tensorflow as tf
def DPG(policy, action_value):
"""
Constructs the gradient Tensors of the `deterministic policy gradient <https://arxiv.org/pdf/1509.02971.pdf>`_.
:param policy: A :class:`tianshou.core.policy.Deterministic` to be optimized.
:param action_value: A :class:`tianshou.core.value_function.ActionValue` to guide the optimization of `policy`.
:return: A list of (gradient, variable) pairs.
"""
trainable_variables = list(policy.trainable_variables)
critic_action_input = action_value.action_placeholder
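# Usage sketch (an assumption, not from this commit): apply the DPG gradients with
# a standard TensorFlow optimizer; `actor` is a policy.Deterministic and `critic`
# a value_function.ActionValue.
grads_and_vars = DPG(actor, critic)  # list of (gradient, variable) pairs
train_op = tf.train.AdamOptimizer(1e-4).apply_gradients(grads_and_vars)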

View File

@@ -8,7 +8,18 @@ from ..utils import identify_dependent_variables
class ActionValue(ValueFunctionBase):
"""
Class for action values Q(s, a). The inputs of the value network are states and actions, and the
output of the value network is directly the Q-value of the input (state, action) pair.
:param network_callable: A Python callable returning (action head, value head). When called it builds
the tf graph and returns a Tensor of the value on the value head.
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, a)
in the network graph.
:param action_placeholder: A :class:`tf.placeholder`. The action placeholder for a in Q(s, a)
in the network graph.
:param has_old_net: A bool defaulting to ``False``. If ``True``, this class will create another graph with
another set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target network as in
DQN and DDPG, or just an old net to help optimization as in PPO.
"""
def __init__(self, network_callable, observation_placeholder, action_placeholder, has_old_net=False):
self.observation_placeholder = observation_placeholder
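# Construction sketch (an assumption, mirroring the DQN example at the top of this
# commit; `my_network`, `observation_ph` and `action_ph` are user-defined and must
# follow the contract documented above):
q_sa = ActionValue(my_network, observation_placeholder=observation_ph, action_placeholder=action_ph, has_old_net=True)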
@@ -51,35 +62,45 @@ class ActionValue(ValueFunctionBase):
@property
def trainable_variables(self):
"""
The trainable variables of the value network in a Python **set**. It contains only the :class:`tf.Variable` s
that affect the value.
"""
return set(self._trainable_variables)
def eval_value(self, observation, action, my_feed_dict={}):
"""
Evaluate the action values of a minibatch using the current network.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param action: An array-like, of shape (batch_size,) + action_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation and action.
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
feed_dict.update(my_feed_dict)  # dict.update() returns None, so update in place before sess.run
return sess.run(self.value_tensor, feed_dict=feed_dict)
def eval_value_old(self, observation, action, my_feed_dict={}):
"""
Evaluate the action values of a minibatch using the old net.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param action: An array-like, of shape (batch_size,) + action_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation and action.
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
feed_dict.update(my_feed_dict)  # update in place; dict.update() returns None
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
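# Usage sketch (an assumption): `q_sa` is an ActionValue built as above, a default
# session is active, and `obs_batch`, `act_batch` are NumPy arrays of matching batch size.
q_now = q_sa.eval_value(obs_batch, act_batch)
q_old = q_sa.eval_value_old(obs_batch, act_batch)  # values from the "old net"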
def sync_weights(self):
"""
Sync the variables of the "old net" to be the same as the current network.
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()
@@ -88,8 +109,18 @@ class ActionValue(ValueFunctionBase):
class DQN(ValueFunctionBase):
"""
Class for the special action value function DQN. Instead of feeding s and a to the network to get
a value, DQN feeds s to the network and gets, at the last layer, Q(s, *) for all actions under this
state. Still, like :class:`ActionValue`, this class builds the Q(s, a) value Tensor. It can only be
used with discrete (and finite) action spaces.
:param network_callable: A Python callable returning (action head, value head). When called it builds
the tf graph and returns a Tensor of Q(s, *) on the value head.
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, *)
in the network graph.
:param has_old_net: A bool defaulting to ``False``. If ``True``, this class will create another graph with
another set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target network as in
DQN and DDPG, or just an old net to help optimization as in PPO.
"""
def __init__(self, network_callable, observation_placeholder, has_old_net=False):
self.observation_placeholder = observation_placeholder
@@ -149,43 +180,76 @@ class DQN(ValueFunctionBase):
@property
def trainable_variables(self):
"""
The trainable variables of the value network in a Python **set**. It contains only the :class:`tf.Variable` s
that affect the value.
"""
return set(self._trainable_variables)
def eval_value(self, observation, action, my_feed_dict={}):
"""
Evaluate the values Q(s, a) of a minibatch using the current network.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param action: An array-like, of shape (batch_size,) + action_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation and action.
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
feed_dict.update(my_feed_dict)  # update in place; dict.update() returns None
return sess.run(self.value_tensor, feed_dict=feed_dict)
def eval_value_old(self, observation, action, my_feed_dict={}):
"""
Evaluate the values Q(s, a) of a minibatch using the old net.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param action: An array-like, of shape (batch_size,) + action_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation and action.
:return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
feed_dict.update(my_feed_dict)  # update in place; dict.update() returns None
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
@property
def value_tensor_all_actions(self):
"""The Tensor for Q(s, *)"""
return self._value_tensor_all_actions
def eval_value_all_actions(self, observation, my_feed_dict={}):
"""
Evaluate the values Q(s, *) of a minibatch using the current network.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation.
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)  # update in place; dict.update() returns None
return sess.run(self._value_tensor_all_actions, feed_dict=feed_dict)
def eval_value_all_actions_old(self, observation, my_feed_dict={}):
"""
Evaluate the values Q(s, *) of a minibatch using the old net.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation.
:return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)  # update in place; dict.update() returns None
return sess.run(self.value_tensor_all_actions_old, feed_dict=feed_dict)
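# Usage sketch (an assumption, not from this commit): greedy action selection from
# Q(s, *); `dqn` is a DQN instance, `obs_batch` a NumPy array, and a default session is active.
import numpy as np
q_all = dqn.eval_value_all_actions(obs_batch)  # shape (batch_size, num_actions)
greedy_actions = np.argmax(q_all, axis=1)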
def sync_weights(self):
"""
Sync the variables of the "old net" to be the same as the current network.
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()

View File

@@ -9,7 +9,16 @@ from ..utils import identify_dependent_variables
class StateValue(ValueFunctionBase):
"""
Class for state value functions V(s). The input of the value network is states and the output
of the value network is directly the V-value of the input state.
:param network_callable: A Python callable returning (action head, value head). When called it builds
the tf graph and returns a Tensor of the value on the value head.
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in V(s)
in the network graph.
:param has_old_net: A bool defaulting to ``False``. If ``True``, this class will create another graph with
another set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target network as in
DQN and DDPG, or just an old net to help optimization as in PPO.
"""
def __init__(self, network_callable, observation_placeholder, has_old_net=False):
self.observation_placeholder = observation_placeholder
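# Construction and evaluation sketch (an assumption; `my_network`, `observation_ph`
# and `obs_batch` are user-defined, and a default session is assumed to be active):
v_s = StateValue(my_network, observation_placeholder=observation_ph, has_old_net=True)
state_values = v_s.eval_value(obs_batch)  # NumPy array of shape (batch_size,)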
@@ -53,19 +62,42 @@ class StateValue(ValueFunctionBase):
@property
def trainable_variables(self):
"""
The trainable variables of the value network in a Python **set**. It contains only the :class:`tf.Variable` s
that affect the value.
"""
return set(self._trainable_variables)
def eval_value(self, observation, my_feed_dict={}):
"""
Evaluate the state values of a minibatch using the current network.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation.
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)  # dict.update() returns None, so update in place before sess.run
return sess.run(self.value_tensor, feed_dict=feed_dict)
def eval_value_old(self, observation, my_feed_dict={}):
"""
Evaluate the state values of a minibatch using the old net.
:param observation: An array-like, of shape (batch_size,) + observation_shape.
:param my_feed_dict: Optional. A dict defaulting to empty. Specifies extra placeholders,
such as those for dropout and batch_norm, besides observation.
:return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
"""
sess = tf.get_default_session()
feed_dict = {self.observation_placeholder: observation}
feed_dict.update(my_feed_dict)  # update in place; dict.update() returns None
return sess.run(self.value_tensor_old, feed_dict=feed_dict)
def sync_weights(self):
"""
Sync the variables of the "old net" to be the same as the current network.
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()
sess.run(self.sync_weights_ops)
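# Usage sketch of the old-net workflow (an assumption; `value_net` is any value
# function above built with has_old_net=True, and `sess`, `train_op`, `feed_dict`
# and `total_steps` are assumed to exist in the training script):
target_update_interval = 1000  # hypothetical hyperparameter
for step in range(total_steps):
    sess.run(train_op, feed_dict=feed_dict)
    if step % target_update_interval == 0:
        value_net.sync_weights()  # copy the current weights into the "old net"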