more API docs

commit 9186dae6a3 (parent 2a3bc3ef35)
@@ -39,7 +39,7 @@ if __name__ == '__main__':
     dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, has_old_net=True)
     pi = policy.DQN(dqn)
 
-    dqn_loss = losses.qlearning(dqn)
+    dqn_loss = losses.value_mse(dqn)
 
     total_loss = dqn_loss
     optimizer = tf.train.AdamOptimizer(1e-4)
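For reference, the renamed loss plugs into the rest of the example unchanged; a minimal, illustrative continuation (the `train_op` name and the `var_list` wiring are assumptions, not part of this commit)::

    # build a training op on the variables that affect the Q-value (illustrative sketch)
    train_op = optimizer.minimize(total_loss, var_list=list(dqn.trainable_variables))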
@@ -3,13 +3,15 @@ import tensorflow as tf
 
 def ppo_clip(policy, clip_param):
     """
-    the clip loss in ppo paper
-
-    :param sampled_action: placeholder of sampled actions during interaction with the environment
-    :param advantage: placeholder of estimated advantage values.
-    :param clip param: float or Tensor of type float.
-    :param policy: current `policy` to be optimized
-    :param pi_old: old `policy` for computing the ppo loss as in Eqn. (7) in the paper
+    Builds the graph of the clipped loss :math:`L^{CLIP}` as in the
+    `PPO paper <https://arxiv.org/pdf/1707.06347.pdf>`_, which is basically
+    :math:`-\min(r_t(\\theta)A^t, clip(r_t(\\theta), 1 - \epsilon, 1 + \epsilon)A^t)`.
+    We minimize the objective instead of maximizing, hence the leading negative sign.
+
+    :param policy: A :class:`tianshou.core.policy` to be optimized.
+    :param clip_param: A float or Tensor of type float. The :math:`\epsilon` in the loss equation.
+
+    :return: A scalar float Tensor of the loss.
     """
     action_ph = tf.placeholder(policy.action.dtype, shape=policy.action.shape, name='ppo_clip_loss/action_placeholder')
     advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
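For reference, a hedged usage sketch of the documented signature; the policy object `pi` and the clip value 0.2 are assumptions, not taken from this commit::

    # ppo_clip(policy, clip_param) returns a scalar loss Tensor to be minimized
    ppo_loss = losses.ppo_clip(pi, clip_param=0.2)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(ppo_loss, var_list=list(pi.trainable_variables))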
@@ -26,13 +28,13 @@ def ppo_clip(policy, clip_param):
 
 def REINFORCE(policy):
     """
-    vanilla policy gradient
-
-    :param sampled_action: placeholder of sampled actions during interaction with the environment
-    :param reward: placeholder of reward the 'sampled_action' get
-    :param pi: current `policy` to be optimized
-    :param baseline: the baseline method used to reduce the variance, default is 'None'
-    :return:
+    Builds the graph of the loss function as used in vanilla policy gradient algorithms, i.e., REINFORCE.
+    The loss is basically :math:`-\log \pi(a|s) A^t`.
+    We minimize the objective instead of maximizing, hence the leading negative sign.
+
+    :param policy: A :class:`tianshou.core.policy` to be optimized.
+
+    :return: A scalar float Tensor of the loss.
     """
     action_ph = tf.placeholder(policy.action.dtype, shape=policy.action.shape,
                                name='REINFORCE/action_placeholder')
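Similarly, a hedged sketch of REINFORCE usage (`pi` is assumed to be a tianshou policy built elsewhere)::

    # REINFORCE(policy) returns the scalar loss to be minimized
    pg_loss = losses.REINFORCE(pi)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(pg_loss, var_list=list(pi.trainable_variables))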
@@ -45,27 +47,16 @@ def REINFORCE(policy):
     return REINFORCE_loss
 
 
-def value_mse(state_value_function):
+def value_mse(value_function):
     """
-    L2 loss of state value
-    :param state_value_function: instance of StateValue
-    :return: tensor of the mse loss
+    Builds the graph of L2 loss on value functions for, e.g., training critics or DQN.
+
+    :param value_function: A :class:`tianshou.core.value_function` to be optimized.
+
+    :return: A scalar float Tensor of the loss.
     """
     target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='value_mse/return_placeholder')
-    state_value_function.managed_placeholders['return'] = target_value_ph
+    value_function.managed_placeholders['return'] = target_value_ph
 
-    state_value = state_value_function.value_tensor
+    state_value = value_function.value_tensor
     return tf.losses.mean_squared_error(target_value_ph, state_value)
-
-
-def qlearning(action_value_function):
-    """
-    deep q-network
-    :param action_value_function: current `action_value` to be optimized
-    :return:
-    """
-    target_value_ph = tf.placeholder(tf.float32, shape=(None,), name='qlearning/action_value_placeholder')
-    action_value_function.managed_placeholders['return'] = target_value_ph
-
-    q_value = action_value_function.value_tensor
-    return tf.losses.mean_squared_error(target_value_ph, q_value)
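With `qlearning` removed, `value_mse` is the single L2 loss for both state-value critics and DQN-style action values. A hedged sketch with a state-value critic; `my_network` and `observation_ph` are assumed to be defined as in the example script::

    critic = value_function.StateValue(my_network, observation_placeholder=observation_ph)
    critic_loss = losses.value_mse(critic)   # MSE between the fed return and V(s)
    train_op = tf.train.AdamOptimizer(1e-4).minimize(critic_loss, var_list=list(critic.trainable_variables))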
@@ -3,10 +3,12 @@ import tensorflow as tf
 
 def DPG(policy, action_value):
     """
-    construct the gradient tensor of deterministic policy gradient
-    :param policy:
-    :param action_value:
-    :return: list of (gradient, variable) pairs
+    Constructs the gradient Tensor of the `deterministic policy gradient <https://arxiv.org/pdf/1509.02971.pdf>`_.
+
+    :param policy: A :class:`tianshou.core.policy.Deterministic` to be optimized.
+    :param action_value: A :class:`tianshou.core.value_function.ActionValue` to guide the optimization of `policy`.
+
+    :return: A list of (gradient, variable) pairs.
     """
     trainable_variables = list(policy.trainable_variables)
     critic_action_input = action_value.action_placeholder
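Unlike the other losses, DPG returns gradients rather than a scalar loss, so they are applied directly. A hedged sketch; `pi` and `critic` are assumed to be a Deterministic policy and an ActionValue, and DPG is assumed to be imported from its module::

    dpg_grads_and_vars = DPG(pi, critic)   # list of (gradient, variable) pairs
    train_op = tf.train.AdamOptimizer(1e-4).apply_gradients(dpg_grads_and_vars)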
@@ -8,7 +8,18 @@ from ..utils import identify_dependent_variables
 
 class ActionValue(ValueFunctionBase):
     """
-    class of action values Q(s, a).
+    Class for action values Q(s, a). The input of the value network is states and actions, and the output
+    of the value network is directly the Q-value of the input (state, action) pairs.
+
+    :param network_callable: A Python callable returning (action head, value head). When called it builds
+        the tf graph and returns a Tensor of the value on the value head.
+    :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, a)
+        in the network graph.
+    :param action_placeholder: A :class:`tf.placeholder`. The action placeholder for a in Q(s, a)
+        in the network graph.
+    :param has_old_net: A bool defaulting to ``False``. If true, this class will create another graph with another
+        set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN
+        and DDPG, or just an old net to help optimization as in PPO.
     """
     def __init__(self, network_callable, observation_placeholder, action_placeholder, has_old_net=False):
         self.observation_placeholder = observation_placeholder
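A hedged construction sketch for :class:`ActionValue`; the network callable `my_network` and the placeholder shapes are illustrative only::

    observation_ph = tf.placeholder(tf.float32, shape=(None, 8))   # illustrative shape
    action_ph = tf.placeholder(tf.float32, shape=(None, 2))        # illustrative shape
    critic = value_function.ActionValue(my_network,
                                        observation_placeholder=observation_ph,
                                        action_placeholder=action_ph,
                                        has_old_net=True)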
@@ -51,35 +62,45 @@ class ActionValue(ValueFunctionBase):
 
     @property
     def trainable_variables(self):
+        """
+        The trainable variables of the value network in a Python **set**. It contains only the :class:`tf.Variable` s
+        that affect the value.
+        """
         return set(self._trainable_variables)
 
-    def eval_value(self, observation, action):
+    def eval_value(self, observation, action, my_feed_dict={}):
         """
-        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
-        :param action: numpy array of actions, of shape (batchsize, action_dim)
-        # TODO: Atari discrete action should have dim 1. Super Mario may should have, say, dim 5, where each can be 0/1
-        :return: numpy array of state values, of shape (batchsize, )
-        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a)
+        Evaluate value in minibatch using the current network.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param action: An array-like, of shape (batch_size,) + action_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation and action.
+
+        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor, feed_dict=
-            {self.observation_placeholder: observation, self.action_placeholder: action})
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor, feed_dict=feed_dict)
 
-    def eval_value_old(self, observation, action):
+    def eval_value_old(self, observation, action, my_feed_dict={}):
         """
-        eval value using target network
-        :param observation: numpy array of obs
-        :param action: numpy array of action
-        :return: numpy array of action value
+        Evaluate value in minibatch using the old net.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param action: An array-like, of shape (batch_size,) + action_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation and action.
+
+        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
         feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
         return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     def sync_weights(self):
         """
-        sync the weights of network_old. Direct copy the weights of network.
-        :return:
+        Sync the variables of the "old net" to be the same as the current network.
         """
         if self.sync_weights_ops is not None:
             sess = tf.get_default_session()
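A hedged sketch of evaluating and syncing inside a default session; `critic`, `obs_batch`, `act_batch` and the dropout placeholder `keep_prob_ph` are assumed to exist::

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        q_now = critic.eval_value(obs_batch, act_batch)
        q_old = critic.eval_value_old(obs_batch, act_batch, my_feed_dict={keep_prob_ph: 1.0})
        critic.sync_weights()   # copy the current weights into the "old net"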
@@ -88,8 +109,18 @@ class ActionValue(ValueFunctionBase):
 
 class DQN(ValueFunctionBase):
     """
-    class of the very DQN architecture. Instead of feeding s and a to the network to get a value, DQN feed s to the
-    network and the last layer is Q(s, *) for all actions
+    Class for the special action value function DQN. Instead of feeding s and a to the network to get a value,
+    DQN feeds s to the network and gets at the last layer Q(s, *) for all actions under this state. Like
+    :class:`ActionValue`, this class still builds the Q(s, a) value Tensor. It can only be used with discrete
+    (and finite) action spaces.
+
+    :param network_callable: A Python callable returning (action head, value head). When called it builds
+        the tf graph and returns a Tensor of Q(s, *) on the value head.
+    :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in Q(s, *)
+        in the network graph.
+    :param has_old_net: A bool defaulting to ``False``. If true, this class will create another graph with another
+        set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN
+        and DDPG, or just an old net to help optimization as in PPO.
     """
     def __init__(self, network_callable, observation_placeholder, has_old_net=False):
         self.observation_placeholder = observation_placeholder
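A hedged sketch of the DQN value function in use, combining the constructor shown in the example script with greedy action selection from Q(s, *); `my_network`, `observation_ph` and `obs_batch` are assumed, and evaluation must run under a default session::

    import numpy as np

    dqn = value_function.DQN(my_network, observation_placeholder=observation_ph, has_old_net=True)
    q_all = dqn.eval_value_all_actions(obs_batch)   # shape (batch_size, num_actions)
    greedy_actions = np.argmax(q_all, axis=1)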
@@ -149,43 +180,76 @@ class DQN(ValueFunctionBase):
 
     @property
     def trainable_variables(self):
+        """
+        The trainable variables of the value network in a Python **set**. It contains only the :class:`tf.Variable` s
+        that affect the value.
+        """
         return set(self._trainable_variables)
 
-    def eval_value_all_actions(self, observation):
+    def eval_value(self, observation, action, my_feed_dict={}):
         """
-        :param observation:
-        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
+        Evaluate value Q(s, a) in minibatch using the current network.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param action: An array-like, of shape (batch_size,) + action_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation and action.
+
+        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self._value_tensor_all_actions, feed_dict={self.observation_placeholder: observation})
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor, feed_dict=feed_dict)
+
+    def eval_value_old(self, observation, action, my_feed_dict={}):
+        """
+        Evaluate value Q(s, a) in minibatch using the old net.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param action: An array-like, of shape (batch_size,) + action_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation and action.
+
+        :return: A numpy array of shape (batch_size,). The corresponding action value for each observation.
+        """
+        sess = tf.get_default_session()
+        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     @property
     def value_tensor_all_actions(self):
+        """The Tensor for Q(s, *)"""
         return self._value_tensor_all_actions
 
-    def eval_value_old(self, observation, action):
+    def eval_value_all_actions(self, observation, my_feed_dict={}):
         """
-        eval value using target network
-        :param observation: numpy array of obs
-        :param action: numpy array of action
-        :return: numpy array of action value
-        """
-        sess = tf.get_default_session()
-        feed_dict = {self.observation_placeholder: observation, self.action_placeholder: action}
-        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
-
-    def eval_value_all_actions_old(self, observation):
-        """
-        :param observation:
-        :return: numpy array of Q(s, *) given s, of shape (batchsize, num_actions)
+        Evaluate values Q(s, *) in minibatch using the current network.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
         """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor_all_actions_old, feed_dict={self.observation_placeholder: observation})
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self._value_tensor_all_actions, feed_dict=feed_dict)
+
+    def eval_value_all_actions_old(self, observation, my_feed_dict={}):
+        """
+        Evaluate values Q(s, *) in minibatch using the old net.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array of shape (batch_size, num_actions). The corresponding action values for each observation.
+        """
+        sess = tf.get_default_session()
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor_all_actions_old, feed_dict=feed_dict)
 
     def sync_weights(self):
         """
-        sync the weights of network_old. Direct copy the weights of network.
-        :return:
+        Sync the variables of the "old net" to be the same as the current network.
         """
         if self.sync_weights_ops is not None:
             sess = tf.get_default_session()
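A hedged training-loop fragment showing one way the "old net" could serve as the target network; every name here is illustrative::

    sync_interval = 1000    # assumed refresh period
    for step in range(num_steps):
        ...                 # sample a minibatch and run the training op built from losses.value_mse(dqn)
        if step % sync_interval == 0:
            dqn.sync_weights()   # copy the current weights into the target ("old") net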
@@ -9,7 +9,16 @@ from ..utils import identify_dependent_variables
 
 class StateValue(ValueFunctionBase):
     """
-    class of state values V(s).
+    Class for state value functions V(s). The input of the value network is states and the output
+    of the value network is directly the V-value of the input state.
+
+    :param network_callable: A Python callable returning (action head, value head). When called it builds
+        the tf graph and returns a Tensor of the value on the value head.
+    :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder for s in V(s)
+        in the network graph.
+    :param has_old_net: A bool defaulting to ``False``. If true, this class will create another graph with another
+        set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN
+        and DDPG, or just an old net to help optimization as in PPO.
     """
     def __init__(self, network_callable, observation_placeholder, has_old_net=False):
         self.observation_placeholder = observation_placeholder
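A hedged sketch of :class:`StateValue` with an old net enabled; `my_network`, `observation_ph` and `obs_batch` are assumed, and evaluation must run under a default session::

    baseline = value_function.StateValue(my_network, observation_placeholder=observation_ph, has_old_net=True)
    v_now = baseline.eval_value(obs_batch)       # shape (batch_size,)
    v_old = baseline.eval_value_old(obs_batch)
    baseline.sync_weights()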
@@ -53,19 +62,42 @@ class StateValue(ValueFunctionBase):
 
     @property
    def trainable_variables(self):
+        """
+        The trainable variables of the value network in a Python **set**. It contains only the :class:`tf.Variable` s
+        that affect the value.
+        """
         return set(self._trainable_variables)
 
-    def eval_value(self, observation):
+    def eval_value(self, observation, my_feed_dict={}):
         """
-        :param observation: numpy array of observations, of shape (batchsize, observation_dim).
-        :return: numpy array of state values, of shape (batchsize, )
-        # TODO: dealing with the last dim of 1 in V(s) and Q(s, a), this should rely on the action shape returned by env
+        Evaluate value in minibatch using the current network.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
        """
         sess = tf.get_default_session()
-        return sess.run(self.value_tensor, feed_dict={self.observation_placeholder: observation})
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor, feed_dict=feed_dict)
+
+    def eval_value_old(self, observation, my_feed_dict={}):
+        """
+        Evaluate value in minibatch using the old net.
+
+        :param observation: An array-like, of shape (batch_size,) + observation_shape.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array of shape (batch_size,). The corresponding state value for each observation.
+        """
+        sess = tf.get_default_session()
+        feed_dict = {self.observation_placeholder: observation}
+        feed_dict.update(my_feed_dict)
+        return sess.run(self.value_tensor_old, feed_dict=feed_dict)
 
     def sync_weights(self):
+        """
+        Sync the variables of the "old net" to be the same as the current network.
+        """
         if self.sync_weights_ops is not None:
             sess = tf.get_default_session()
             sess.run(self.sync_weights_ops)