part of API doc
parent 03246f7ded
commit 2a3bc3ef35
@@ -68,7 +68,7 @@ if __name__ == '__main__':
     test_interval = 5000
     target_network_update_interval = 800
 
-    seed = 0
+    seed = 123
     np.random.seed(seed)
     tf.set_random_seed(seed)
 
@@ -6,14 +6,50 @@ import tensorflow as tf
 
 class PolicyBase(object):
     """
-    base class for policy. only provides `act` method with exploration
+    Base class for policy. Mandatory methods for a policy class are:
+
+    - :func:`act`. It's used interacting with the environment during training, \
+        so exploration noise should be added in this method.
+
+    - :func:`act_test`. Since RL usually adds additional exploration noise during training, a different method\
+        for testing the policy should be defined with different exploration specification.\
+        Generally, DQN uses different :math:`\epsilon` in :math:`\epsilon`-greedy and\
+        DDPG removes exploration noise during test.
+
+    - :func:`reset`. It's mainly to reset the states of the exploration random process, or if your policy has\
+        some internal states that should be reset at the beginning of each new episode. Otherwise, this method\
+        does nothing.
     """
     def act(self, observation, my_feed_dict):
+        """
+        Return action given observation, when interacting with the environment during training.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: A dict. Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array. Action given the single observation. Its "batch_size" is 1,
+            but should not be explicitly set.
+        """
         raise NotImplementedError()
 
+    def act_test(self, observation, my_feed_dict):
+        """
+        Return action given observation, when interacting with the environment during test.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: A dict. Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array. Action given the single observation. Its "batch_size" is 1,
+            but should not be explicitly set.
+        """
+        raise NotImplementedError
+
     def reset(self):
         """
-        for temporal correlated random process exploration, as in DDPG
-        :return:
+        Reset the internal states of the policy. Does nothing by default.
         """
         pass
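Editor's note: the docstrings added above fix the act / act_test / reset contract that every policy must satisfy. Below is a minimal, self-contained sketch of that contract; the UniformNoisePolicy class is purely illustrative and not part of tianshou.

    import numpy as np

    class UniformNoisePolicy(object):
        """Illustrative policy following the documented act/act_test/reset contract."""

        def __init__(self, action_dim, noise_scale=0.1, seed=0):
            self.action_dim = action_dim
            self.noise_scale = noise_scale
            self.rng = np.random.RandomState(seed)

        def act(self, observation, my_feed_dict={}):
            # Training-time action: exploration noise is added here,
            # as the base class docstring requires.
            return self.noise_scale * self.rng.randn(self.action_dim)

        def act_test(self, observation, my_feed_dict={}):
            # Test-time action: the same policy without exploration noise.
            return np.zeros(self.action_dim)

        def reset(self):
            # Nothing stateful here; a correlated-noise policy would reset
            # its random process at the start of each episode.
            pass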
@@ -8,7 +8,18 @@ from ..utils import identify_dependent_variables
 
 class Deterministic(PolicyBase):
     """
-    deterministic policy as used in deterministic policy gradient (DDPG) methods
+    Deterministic policy as used in deterministic policy gradient (DDPG) methods. It can only be used with
+    continuous action space. The output of the policy network is directly the action.
+
+    :param network_callable: A Python callable returning (action head, value head). When called it builds the tf graph and returns a Tensor
+        of the action on the action head.
+    :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder of the network graph.
+    :param has_old_net: A bool defaulting to ``False``. If true this class will create another graph with another
+        set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN
+        and DDPG, or just an old net to help optimization as in PPO.
+    :param random_process: Optional. A :class:`RandomProcess`. The additional random process for exploration.
+        Defaults to an :class:`OrnsteinUhlenbeckProcess` with :math:`\\theta=0.15` and :math:`\sigma=0.3` if not
+        set explicitly.
     """
     def __init__(self, network_callable, observation_placeholder, has_old_net=False, random_process=None):
         self.observation_placeholder = observation_placeholder
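Editor's note: a hedged sketch of wiring up this constructor from the parameters documented above. Only the constructor signature comes from this commit; the import path, placeholder shape, and the toy network below are assumptions for illustration.

    import tensorflow as tf
    from tianshou.core.policy.deterministic import Deterministic  # assumed import path

    observation_dim, action_dim = 3, 1
    observation_ph = tf.placeholder(tf.float32, shape=(None, observation_dim))

    def my_network():
        # network_callable: builds the graph and returns (action head, value head);
        # the value head is not needed by this policy, so None is returned for it.
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
        action = tf.layers.dense(net, action_dim, activation=tf.nn.tanh)
        return action, None

    pi = Deterministic(my_network, observation_placeholder=observation_ph,
                       has_old_net=True)  # default exploration: OU process, theta=0.15, sigma=0.3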
@@ -54,9 +65,25 @@ class Deterministic(PolicyBase):
 
     @property
     def trainable_variables(self):
+        """
+        The trainable variables of the policy in a Python **set**. It contains only the :class:`tf.Variable` s
+        that affect the action.
+        """
         return set(self._trainable_variables)
 
     def act(self, observation, my_feed_dict={}):
+        """
+        Return action given observation, adding the exploration noise sampled from ``self.random_process``.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array.
+            Action given the single observation. Its "batch_size" is 1, but should not be explicitly set.
+        """
         sess = tf.get_default_session()
 
         # observation[None] adds one dimension at the beginning
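Editor's note: the act docstring above, together with the `observation[None]` line kept as context, pins down the batching convention: a single, un-batched observation goes in and the method prepends the batch dimension itself. A quick numpy illustration:

    import numpy as np

    observation = np.array([0.1, -0.2, 0.3])   # rank 1: a single observation, no batch axis
    batched = observation[None]                # shape (1, 3): batch_size of 1 added in front
    assert batched.shape == (1,) + observation.shape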
@@ -69,9 +96,24 @@ class Deterministic(PolicyBase):
         return sampled_action
 
     def reset(self):
+        """
+        Reset the internal states of ``self.random_process``.
+        """
         self.random_process.reset_states()
 
     def act_test(self, observation, my_feed_dict={}):
+        """
+        Return action given observation, removing the exploration noise.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array.
+            Action given the single observation. Its "batch_size" is 1, but should not be explicitly set.
+        """
         sess = tf.get_default_session()
 
         # observation[None] adds one dimension at the beginning
@@ -85,18 +127,22 @@ class Deterministic(PolicyBase):
 
     def sync_weights(self):
         """
-        sync the weights of network_old. Direct copy the weights of network.
-        :return:
+        Sync the variables of the "old net" to be the same as the current network.
         """
         if self.sync_weights_ops is not None:
             sess = tf.get_default_session()
             sess.run(self.sync_weights_ops)
 
-    def eval_action(self, observation):
+    def eval_action(self, observation, my_feed_dict={}):
         """
-        evaluate action in minibatch
-        :param observation:
-        :return: 2-D numpy array
+        Evaluate action in minibatch using the current network.
+
+        :param observation: An array-like. Contrary to :func:`act` and :func:`act_test`, it has the dimension
+            of batch_size.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array with the batch_size dimension and same batch_size as ``observation``.
         """
         sess = tf.get_default_session()
 
@@ -105,11 +151,16 @@ class Deterministic(PolicyBase):
 
         return action
 
-    def eval_action_old(self, observation):
+    def eval_action_old(self, observation, my_feed_dict={}):
         """
-        evaluate action in minibatch
-        :param observation:
-        :return: 2-D numpy array
+        Evaluate action in minibatch using the old net.
+
+        :param observation: An array-like. Contrary to :func:`act` and :func:`act_test`, it has the dimension
+            of batch_size.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array with the batch_size dimension and same batch_size as ``observation``.
         """
         sess = tf.get_default_session()
 
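Editor's note: a hedged usage sketch for the Deterministic methods documented in the hunks above, continuing the `pi` and `observation_dim` from the constructor sketch earlier; the session handling and batch size are illustrative only.

    import numpy as np
    import tensorflow as tf

    with tf.Session() as sess:    # becomes the default session used by the policy
        sess.run(tf.global_variables_initializer())

        single_obs = np.zeros(observation_dim)
        a_train = pi.act(single_obs)        # single action, OU exploration noise added
        a_test = pi.act_test(single_obs)    # single action, no exploration noise

        batch_obs = np.zeros((64, observation_dim))
        actions = pi.eval_action(batch_obs)          # batched, current network
        actions_old = pi.eval_action_old(batch_obs)  # batched, "old net" / target network

        pi.sync_weights()   # copy current variables into the old net
        pi.reset()          # reset the OU process between episodes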
@@ -6,7 +6,15 @@ from ..utils import identify_dependent_variables
 
 class Distributional(PolicyBase):
     """
-    policy class where action is specified by a probability distribution
+    Policy class where action is specified by a probability distribution. Depending on the distribution,
+    it can be applied to both continuous and discrete action spaces.
+
+    :param network_callable: A Python callable returning (action head, value head). When called it builds the tf graph and returns a
+        :class:`tf.distributions.Distribution` on the action space on the action head.
+    :param observation_placeholder: A :class:`tf.placeholder`. The observation placeholder of the network graph.
+    :param has_old_net: A bool defaulting to ``False``. If true this class will create another graph with another
+        set of :class:`tf.Variable` s to be the "old net". The "old net" could be the target networks as in DQN
+        and DDPG, or just an old net to help optimization as in PPO.
     """
     def __init__(self, network_callable, observation_placeholder, has_old_net=False):
         self.observation_placeholder = observation_placeholder
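Editor's note: a hedged sketch of this constructor for a discrete action space, where the network_callable returns a tf.distributions.Distribution on the action head as documented above. The import path and shapes are assumptions.

    import tensorflow as tf
    from tianshou.core.policy.distributional import Distributional  # assumed import path

    num_actions = 2
    observation_ph = tf.placeholder(tf.float32, shape=(None, 4))

    def my_network():
        net = tf.layers.dense(observation_ph, 32, activation=tf.nn.relu)
        logits = tf.layers.dense(net, num_actions)
        # The action head is a distribution over actions, not a raw action tensor.
        return tf.distributions.Categorical(logits=logits), None

    pi = Distributional(my_network, observation_placeholder=observation_ph, has_old_net=False)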
@@ -50,9 +58,25 @@ class Distributional(PolicyBase):
 
     @property
     def trainable_variables(self):
+        """
+        The trainable variables of the policy in a Python **set**. It contains only the :class:`tf.Variable` s
+        that affect the action.
+        """
         return set(self._trainable_variables)
 
     def act(self, observation, my_feed_dict={}):
+        """
+        Return action given observation, directly sampling from the action distribution.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array.
+            Action given the single observation. Its "batch_size" is 1, but should not be explicitly set.
+        """
         sess = tf.get_default_session()
 
         # observation[None] adds one dimension at the beginning
@@ -64,12 +88,23 @@ class Distributional(PolicyBase):
         return sampled_action
 
     def act_test(self, observation, my_feed_dict={}):
+        """
+        Return action given observation, directly sampling from the action distribution.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array.
+            Action given the single observation. Its "batch_size" is 1, but should not be explicitly set.
+        """
         return self.act(observation, my_feed_dict)
 
     def sync_weights(self):
         """
-        sync the weights of network_old. Direct copy the weights of network.
-        :return:
+        Sync the variables of the "old net" to be the same as the current network.
         """
         if self.sync_weights_ops is not None:
             sess = tf.get_default_session()
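Editor's note: as the hunk shows, act_test simply forwards to act, so both return stochastic samples from the same distribution. A short usage sketch continuing the Distributional `pi` built in the sketch above:

    import numpy as np
    import tensorflow as tf

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        obs = np.zeros(4)                       # a single observation, no batch axis
        print(pi.act(obs), pi.act_test(obs))    # two independent samples from the distribution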
@@ -8,6 +8,16 @@ import numpy as np
 class DQN(PolicyBase):
     """
     use DQN from value_function as a member
+
+    Policy derived from a Deep-Q Network (DQN). It should be constructed from a :class:`tianshou.core.value_function.DQN`.
+    Action is the argmax of the Q-values (usually with further :math:`\epsilon`-greedy).
+    It can only be applied to discrete action spaces.
+
+    :param dqn: A :class:`tianshou.core.value_function.DQN`. The Q-value network to derive this policy.
+    :param epsilon_train: A float in range :math:`[0, 1]`. The :math:`\epsilon` used in :math:`\epsilon`-greedy
+        during training while interacting with the environment.
+    :param epsilon_test: A float in range :math:`[0, 1]`. The :math:`\epsilon` used in :math:`\epsilon`-greedy
+        during test while interacting with the environment.
     """
     def __init__(self, dqn, epsilon_train=0.1, epsilon_test=0.05):
         self.action_value = dqn
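Editor's note: the new class docstring describes this policy as argmax over Q-values with epsilon-greedy exploration (epsilon_train while training, epsilon_test while testing). A standalone numpy sketch of that selection rule, independent of the library:

    import numpy as np

    def epsilon_greedy(q_values, epsilon, rng=np.random):
        # With probability epsilon take a uniformly random action,
        # otherwise take the greedy (argmax) action.
        if rng.rand() < epsilon:
            return rng.randint(len(q_values))
        return int(np.argmax(q_values))

    q = np.array([0.1, 0.5, 0.3])
    train_action = epsilon_greedy(q, epsilon=0.1)    # exploratory, like act()
    test_action = epsilon_greedy(q, epsilon=0.05)    # near-greedy, like act_test()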
@@ -17,6 +27,18 @@ class DQN(PolicyBase):
         self.epsilon_test = epsilon_test
 
     def act(self, observation, my_feed_dict={}):
+        """
+        Return action given observation, with :math:`\epsilon`-greedy using ``self.epsilon_train``.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array.
+            Action given the single observation. Its "batch_size" is 1, but should not be explicitly set.
+        """
         sess = tf.get_default_session()
 
         feed_dict = {self.action_value.observation_placeholder: observation[None]}
@@ -30,6 +52,18 @@ class DQN(PolicyBase):
         return np.squeeze(action)
 
     def act_test(self, observation, my_feed_dict={}):
+        """
+        Return action given observation, with :math:`\epsilon`-greedy using ``self.epsilon_test``.
+
+        :param observation: An array-like with rank the same as a single observation of the environment.
+            Its "batch_size" is 1, but should not be explicitly set. This method will add the dimension
+            of "batch_size" to the first dimension.
+        :param my_feed_dict: Optional. A dict defaulting to empty.
+            Specifies placeholders such as dropout and batch_norm except observation.
+
+        :return: A numpy array.
+            Action given the single observation. Its "batch_size" is 1, but should not be explicitly set.
+        """
         sess = tf.get_default_session()
 
         feed_dict = {self.action_value.observation_placeholder: observation[None]}
@@ -44,18 +78,26 @@ class DQN(PolicyBase):
 
     @property
     def q_net(self):
+        """The DQN (:class:`tianshou.core.value_function.DQN`) this policy is based on."""
         return self.action_value
 
     def sync_weights(self):
         """
-        sync the weights of network_old. Direct copy the weights of network.
-        :return:
+        Sync the variables of the "old net" to be the same as the current network.
         """
         if self.action_value.sync_weights_ops is not None:
             self.action_value.sync_weights()
 
     def set_epsilon_train(self, epsilon):
+        """
+        Set the :math:`\epsilon` in :math:`\epsilon`-greedy during training.
+        :param epsilon: A float in range :math:`[0, 1]`.
+        """
         self.epsilon_train = epsilon
 
     def set_epsilon_test(self, epsilon):
+        """
+        Set the :math:`\epsilon` in :math:`\epsilon`-greedy during test.
+        :param epsilon: A float in range :math:`[0, 1]`.
+        """
         self.epsilon_test = epsilon
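Editor's note: the setters documented above are the hook for annealing exploration over training. A hedged sketch of a linear schedule follows; the schedule itself is illustrative and not part of the library. In a training loop one would call pi.set_epsilon_train(linear_epsilon(step)) each iteration.

    def linear_epsilon(step, start=1.0, end=0.05, decay_steps=10000):
        # Linearly anneal epsilon from `start` to `end` over `decay_steps` steps.
        frac = min(float(step) / decay_steps, 1.0)
        return start + frac * (end - start)

    for step in (0, 2500, 5000, 10000, 20000):
        print(step, round(linear_epsilon(step), 3))   # 1.0, 0.762, 0.525, 0.05, 0.05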
@@ -4,23 +4,25 @@ import tensorflow as tf
 
 class ValueFunctionBase(object):
     """
-    base class of value functions. Children include state values V(s) and action values Q(s, a)
+    Base class for value functions, including S-values and Q-values. The only
+    mandatory method for a value function class is:
+
+    :func:`eval_value`, which runs the graph and evaluates the corresponding value.
+
+    :param value_tensor: a Tensor. The tensor of V(s) or Q(s, a).
+    :param observation_placeholder: a :class:`tf.placeholder`. The observation placeholder of the network graph.
     """
     def __init__(self, value_tensor, observation_placeholder):
         self.observation_placeholder = observation_placeholder
-        self._value_tensor = tf.squeeze(value_tensor) # canonical values has shape (batchsize, )
+        self._value_tensor = tf.squeeze(value_tensor) # canonical value has shape (batchsize, )
 
     def eval_value(self, **kwargs):
         """
-        :return: batch of corresponding values in numpy array
+        Runs the graph and evaluates the corresponding value.
         """
         raise NotImplementedError()
 
     @property
     def value_tensor(self):
-        """
-
-        :return: tensor of the corresponding values
-        """
+        """Tensor of the corresponding value"""
        return self._value_tensor
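Editor's note: a hedged sketch of a concrete state-value function following the documented base-class contract. StateValue and the tiny network below are illustrative; only the (value_tensor, observation_placeholder) constructor, eval_value() and the value_tensor property come from this commit.

    import numpy as np
    import tensorflow as tf

    class ValueFunctionBase(object):
        # Base class restated from the diff above.
        def __init__(self, value_tensor, observation_placeholder):
            self.observation_placeholder = observation_placeholder
            self._value_tensor = tf.squeeze(value_tensor)  # canonical value has shape (batchsize, )

        def eval_value(self, **kwargs):
            raise NotImplementedError()

        @property
        def value_tensor(self):
            return self._value_tensor

    class StateValue(ValueFunctionBase):
        # Illustrative V(s) subclass; tianshou ships its own subclasses.
        def eval_value(self, observation):
            sess = tf.get_default_session()
            return sess.run(self.value_tensor,
                            feed_dict={self.observation_placeholder: observation})

    obs_ph = tf.placeholder(tf.float32, shape=(None, 4))
    v_head = tf.layers.dense(obs_ph, 1)      # V(s) head with shape (batchsize, 1)
    value_function = StateValue(v_head, obs_ph)

    with tf.Session().as_default():
        tf.get_default_session().run(tf.global_variables_initializer())
        print(value_function.eval_value(observation=np.zeros((8, 4))).shape)  # (8,)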