part of API doc

haoshengzou 2018-04-12 21:10:50 +08:00
parent 03246f7ded
commit 2a3bc3ef35
6 changed files with 194 additions and 28 deletions

View File

@@ -68,7 +68,7 @@ if __name__ == '__main__':
test_interval = 5000
target_network_update_interval = 800
seed = 0
seed = 123
np.random.seed(seed)
tf.set_random_seed(seed)

View File

@@ -6,14 +6,50 @@ import tensorflow as tf
class PolicyBase(object):
"""
base class for policy. only provides `act` method with exploration
Base class for policies. Mandatory methods for a policy class are:
- :func:`act`. It's used when interacting with the environment during training, \
so exploration noise should be added in this method.
- :func:`act_test`. Since RL usually adds exploration noise during training, a separate method\
for testing the policy should be defined with a different exploration specification.\
For example, DQN uses a different :math:`\epsilon` in :math:`\epsilon`-greedy and\
DDPG removes the exploration noise during test.
- :func:`reset`. It mainly resets the states of the exploration random process, or any other internal\
states of the policy that should be reset at the beginning of each new episode. Otherwise, this method\
does nothing.
"""
def act(self, observation, my_feed_dict):
"""
Return the action given an observation, when interacting with the environment during training.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: A dict. Specifies placeholders, such as those for dropout and batch_norm,
other than the observation.
:return: A numpy array. The action given the single observation. Its "batch_size" is 1,
but is not explicitly included.
"""
raise NotImplementedError()
def act_test(self, observation, my_feed_dict):
"""
Return the action given an observation, when interacting with the environment during test.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: A dict. Specifies placeholders, such as those for dropout and batch_norm,
other than the observation.
:return: A numpy array. The action given the single observation. Its "batch_size" is 1,
but is not explicitly included.
"""
raise NotImplementedError()
def reset(self):
"""
for temporal correlated random process exploration, as in DDPG
:return:
Reset the internal states of the policy. Does nothing by default.
"""
pass
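
To make the required interface concrete, here is a rough, hypothetical sketch of a minimal subclass (the class name and the uniform action bounds are illustrative assumptions, not part of this commit):

import numpy as np

class UniformRandomPolicy(PolicyBase):
    """Hypothetical toy policy that samples uniformly from a box action space."""
    def __init__(self, action_low, action_high):
        self.action_low = np.asarray(action_low)
        self.action_high = np.asarray(action_high)

    def act(self, observation, my_feed_dict={}):
        # training-time action; for this toy policy the uniform sampling itself is the exploration
        return np.random.uniform(self.action_low, self.action_high)

    def act_test(self, observation, my_feed_dict={}):
        # test-time action; there is no extra exploration noise to remove here
        return self.act(observation, my_feed_dict)

    def reset(self):
        # no internal exploration state to reset
        pass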

View File

@@ -8,7 +8,18 @@ from ..utils import identify_dependent_variables
class Deterministic(PolicyBase):
"""
deterministic policy as used in deterministic policy gradient (DDPG) methods
Deterministic policy as used in deterministic policy gradient (DDPG) methods. It can only be used with
continuous action spaces. The output of the policy network is directly the action.
:param network_callable: A Python callable returning (action head, value head). When called, it builds the
tf graph and returns a Tensor of the action on the action head.
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder of the network graph.
:param has_old_net: A bool defaulting to ``False``. If ``True``, this class will create another graph with
another set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in
DQN and DDPG, or just an old net to help optimization as in PPO.
:param random_process: Optional. A :class:`RandomProcess`. The additional random process for exploration.
Defaults to an :class:`OrnsteinUhlenbeckProcess` with :math:`\\theta=0.15` and :math:`\sigma=0.3` if not
set explicitly.
"""
def __init__(self, network_callable, observation_placeholder, has_old_net=False, random_process=None):
self.observation_placeholder = observation_placeholder
@@ -54,9 +65,25 @@ class Deterministic(PolicyBase):
@property
def trainable_variables(self):
"""
The trainable variables of the policy in a Python **set**. It contains only the :class:`tf.Variable` s
that affect the action.
"""
return set(self._trainable_variables)
def act(self, observation, my_feed_dict={}):
"""
Return the action given an observation, adding exploration noise sampled from ``self.random_process``.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array.
The action given the single observation. Its "batch_size" is 1, but is not explicitly included.
"""
sess = tf.get_default_session()
# observation[None] adds one dimension at the beginning
@@ -69,9 +96,24 @@ class Deterministic(PolicyBase):
return sampled_action
def reset(self):
"""
Reset the internal states of ``self.random_process``.
"""
self.random_process.reset_states()
def act_test(self, observation, my_feed_dict={}):
"""
Return the action given an observation, without the exploration noise.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array.
The action given the single observation. Its "batch_size" is 1, but is not explicitly included.
"""
sess = tf.get_default_session()
# observation[None] adds one dimension at the beginning
@@ -85,18 +127,22 @@ class Deterministic(PolicyBase):
def sync_weights(self):
"""
sync the weights of network_old. Direct copy the weights of network.
:return:
Sync the variables of the "old net" to be the same as those of the current network.
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()
sess.run(self.sync_weights_ops)
def eval_action(self, observation):
def eval_action(self, observation, my_feed_dict={}):
"""
evaluate action in minibatch
:param observation:
:return: 2-D numpy array
Evaluate actions in a minibatch using the current network.
:param observation: An array-like. Unlike in :func:`act` and :func:`act_test`, it has the
batch_size dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array with the batch_size dimension, with the same batch_size as ``observation``.
"""
sess = tf.get_default_session()
@@ -105,11 +151,16 @@ class Deterministic(PolicyBase):
return action
def eval_action_old(self, observation):
def eval_action_old(self, observation, my_feed_dict={}):
"""
evaluate action in minibatch
:param observation:
:return: 2-D numpy array
Evaluate actions in a minibatch using the old net.
:param observation: An array-like. Unlike in :func:`act` and :func:`act_test`, it has the
batch_size dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array with the batch_size dimension, with the same batch_size as ``observation``.
"""
sess = tf.get_default_session()
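
A rough usage sketch of this class (the actor network, shapes, and variable names below are illustrative assumptions; returning ``None`` as the value head is also an assumption, not a documented contract):

import numpy as np
import tensorflow as tf

observation_ph = tf.placeholder(tf.float32, shape=(None, 3))

def my_actor_network():
    # hypothetical network_callable: builds the graph and returns (action head, value head)
    hidden = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
    action = tf.layers.dense(hidden, 1, activation=tf.nn.tanh)
    return action, None

policy = Deterministic(my_actor_network, observation_placeholder=observation_ph,
                       has_old_net=True)  # exploration defaults to an OrnsteinUhlenbeckProcess

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    obs = np.zeros(3)
    noisy_action = policy.act(obs)       # network output plus exploration noise
    clean_action = policy.act_test(obs)  # network output only
    policy.sync_weights()                # copy the current net into the "old net"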

View File

@@ -6,7 +6,15 @@ from ..utils import identify_dependent_variables
class Distributional(PolicyBase):
"""
policy class where action is specified by a probability distribution
Policy class where the action is specified by a probability distribution. Depending on the distribution,
it can be applied to both continuous and discrete action spaces.
:param network_callable: A Python callable returning (action head, value head). When called, it builds the
tf graph and returns a :class:`tf.distributions.Distribution` over the action space on the action head.
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder of the network graph.
:param has_old_net: A bool defaulting to ``False``. If ``True``, this class will create another graph with
another set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in
DQN and DDPG, or just an old net to help optimization as in PPO.
"""
def __init__(self, network_callable, observation_placeholder, has_old_net=False):
self.observation_placeholder = observation_placeholder
@@ -50,9 +58,25 @@ class Distributional(PolicyBase):
@property
def trainable_variables(self):
"""
The trainable variables of the policy in a Python **set**. It contains only the :class:`tf.Variable` s
that affect the action.
"""
return set(self._trainable_variables)
def act(self, observation, my_feed_dict={}):
"""
Return the action given an observation, sampled directly from the action distribution.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array.
The action given the single observation. Its "batch_size" is 1, but is not explicitly included.
"""
sess = tf.get_default_session()
# observation[None] adds one dimension at the beginning
@@ -64,12 +88,23 @@ class Distributional(PolicyBase):
return sampled_action
def act_test(self, observation, my_feed_dict={}):
"""
Return the action given an observation, sampled directly from the action distribution. For this policy
class it is identical to :func:`act`.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array.
The action given the single observation. Its "batch_size" is 1, but is not explicitly included.
"""
return self.act(observation, my_feed_dict)
def sync_weights(self):
"""
sync the weights of network_old. Direct copy the weights of network.
:return:
Sync the variables of the "old net" to be the same as those of the current network.
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()
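
A rough sketch of a ``network_callable`` for this class on a discrete action space (the network, shapes, and names are illustrative assumptions; returning ``None`` as the value head is also an assumption):

import tensorflow as tf

observation_ph = tf.placeholder(tf.float32, shape=(None, 4))

def categorical_network():
    # hypothetical network_callable: the action head is a distribution over the action space
    hidden = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
    logits = tf.layers.dense(hidden, 2)  # two discrete actions
    return tf.distributions.Categorical(logits=logits), None

policy = Distributional(categorical_network, observation_placeholder=observation_ph)
# policy.act(obs) then samples an action from this Categorical distribution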

View File

@@ -8,6 +8,16 @@ import numpy as np
class DQN(PolicyBase):
"""
use DQN from value_function as a member
Policy derived from a deep Q-network (DQN). It should be constructed from a :class:`tianshou.core.value_function.DQN`.
The action is the argmax of the Q-values, usually with additional :math:`\epsilon`-greedy exploration.
It can only be applied to discrete action spaces.
:param dqn: A :class:`tianshou.core.value_function.DQN`. The Q-value network from which this policy is derived.
:param epsilon_train: A float in range :math:`[0, 1]`. The :math:`\epsilon` used in :math:`\epsilon`-greedy
during training while interacting with the environment.
:param epsilon_test: A float in range :math:`[0, 1]`. The :math:`\epsilon` used in :math:`\epsilon`-greedy
during test while interacting with the environment.
"""
def __init__(self, dqn, epsilon_train=0.1, epsilon_test=0.05):
self.action_value = dqn
@@ -17,6 +27,18 @@ class DQN(PolicyBase):
self.epsilon_test = epsilon_test
def act(self, observation, my_feed_dict={}):
"""
Return the action given an observation, with :math:`\epsilon`-greedy exploration using ``self.epsilon_train``.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array.
The action given the single observation. Its "batch_size" is 1, but is not explicitly included.
"""
sess = tf.get_default_session()
feed_dict = {self.action_value.observation_placeholder: observation[None]}
@@ -30,6 +52,18 @@ class DQN(PolicyBase):
return np.squeeze(action)
def act_test(self, observation, my_feed_dict={}):
"""
Return the action given an observation, with :math:`\epsilon`-greedy exploration using ``self.epsilon_test``.
:param observation: An array-like with the same rank as a single observation of the environment.
Its "batch_size" is 1, but should not be set explicitly; this method itself adds the "batch_size"
dimension as the first dimension.
:param my_feed_dict: Optional. A dict defaulting to empty.
Specifies placeholders, such as those for dropout and batch_norm, other than the observation.
:return: A numpy array.
The action given the single observation. Its "batch_size" is 1, but is not explicitly included.
"""
sess = tf.get_default_session()
feed_dict = {self.action_value.observation_placeholder: observation[None]}
@@ -44,18 +78,26 @@ class DQN(PolicyBase):
@property
def q_net(self):
"""The DQN (:class:`tianshou.core.value_function.DQN`) this policy is based on."""
return self.action_value
def sync_weights(self):
"""
sync the weights of network_old. Direct copy the weights of network.
:return:
Sync the variables of the "old net" to be the same as those of the current network.
"""
if self.action_value.sync_weights_ops is not None:
self.action_value.sync_weights()
def set_epsilon_train(self, epsilon):
"""
Set the :math:`\epsilon` in :math:`\epsilon`-greedy during training.
:param epsilon: A float in range :math:`[0, 1]`.
"""
self.epsilon_train = epsilon
def set_epsilon_test(self, epsilon):
"""
Set the :math:`\epsilon` in :math:`\epsilon`-greedy during test.
:param epsilon: A float in range :math:`[0, 1]`.
"""
self.epsilon_test = epsilon
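
A rough usage sketch, assuming a Q-value network ``dqn`` (a :class:`tianshou.core.value_function.DQN`) and a single ``observation`` already exist:

policy = DQN(dqn, epsilon_train=0.1, epsilon_test=0.05)

action = policy.act(observation)       # epsilon-greedy with self.epsilon_train
action = policy.act_test(observation)  # epsilon-greedy with self.epsilon_test

policy.set_epsilon_train(0.01)  # e.g., anneal exploration as training progresses
policy.sync_weights()           # sync the "old net" of the underlying Q-value network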

View File

@@ -4,23 +4,25 @@ import tensorflow as tf
class ValueFunctionBase(object):
"""
base class of value functions. Children include state values V(s) and action values Q(s, a)
Base class for value functions, including state values V(s) and action values Q(s, a). The only
mandatory method for a value function class is
:func:`eval_value`, which runs the graph and evaluates the corresponding value.
:param value_tensor: A :class:`tf.Tensor`. The tensor of V(s) or Q(s, a).
:param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder of the network graph.
"""
def __init__(self, value_tensor, observation_placeholder):
self.observation_placeholder = observation_placeholder
self._value_tensor = tf.squeeze(value_tensor) # canonical values has shape (batchsize, )
self._value_tensor = tf.squeeze(value_tensor)  # the canonical value tensor has shape (batch_size,)
def eval_value(self, **kwargs):
"""
:return: batch of corresponding values in numpy array
Run the graph and return the batch of corresponding values as a numpy array.
"""
raise NotImplementedError()
@property
def value_tensor(self):
"""
:return: tensor of the corresponding values
"""
"""Tensor of the corresponding value"""
return self._value_tensor
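
As a rough illustration of the intended interface (the class name and the explicit ``observation`` argument are assumptions, not part of this commit), a derived value function might implement :func:`eval_value` as:

import tensorflow as tf

class StateValue(ValueFunctionBase):
    """Hypothetical minimal V(s) subclass, only to illustrate eval_value."""
    def eval_value(self, observation):
        # `observation` carries the batch_size dimension; the result has shape (batch_size,)
        sess = tf.get_default_session()
        return sess.run(self.value_tensor,
                        feed_dict={self.observation_placeholder: observation})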