auto target network. ppo_cartpole.py runs OK, but results differ from the previous version even with the same random seed; still needs debugging.

haoshengzou 2018-01-14 20:58:28 +08:00
parent 3b222f5edb
commit fed3bf2a12
5 changed files with 151 additions and 103 deletions

View File

@@ -17,24 +17,9 @@ from rllab.envs.box2d.cartpole_env import CartpoleEnv
from rllab.envs.normalized_env import normalize
def policy_net(observation, action_dim, scope=None):
"""
Constructs the policy network. NOT NEEDED IN THE LIBRARY! this is pure tf
# for tutorial purposes, placeholders are explicitly appended with the '_ph' suffix
:param observation: Placeholder for the observation. A tensor of shape (bs, x, y, channels)
:param action_dim: int. The number of actions.
:param scope: str. Specifying the scope of the variables.
"""
# with tf.variable_scope(scope):
net = tf.layers.dense(observation, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
act_mean = tf.layers.dense(net, action_dim, activation=None)
return act_mean
if __name__ == '__main__': # a clean version with only policy net, no value net
if __name__ == '__main__':
env = normalize(CartpoleEnv())
observation_dim = env.observation_space.shape
action_dim = env.action_space.flat_dim
@@ -47,44 +32,51 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
np.random.seed(seed)
tf.set_random_seed(seed)
# 1. build network with pure tf
observation = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
### 1. build network with pure tf
observation_ph = tf.placeholder(tf.float32, shape=(None,) + observation_dim) # network input
with tf.variable_scope('pi'):
action_mean = policy_net(observation, action_dim, 'pi')
action_logstd = tf.get_variable('action_logstd', shape=(action_dim,))
train_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) # TODO: better management of TRAINABLE_VARIABLES
with tf.variable_scope('pi_old'):
action_mean_old = policy_net(observation, action_dim, 'pi_old')
action_logstd_old = tf.get_variable('action_logstd_old', shape=(action_dim,))
pi_old_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'pi_old')
def my_policy():
net = tf.layers.dense(observation_ph, 32, activation=tf.nn.tanh)
net = tf.layers.dense(net, 32, activation=tf.nn.tanh)
action_mean = tf.layers.dense(net, action_dim, activation=None)
action_logstd = tf.get_variable('action_logstd', shape=(action_dim, ))
# 2. build losses, optimizers
pi = policy.Normal(action_mean, action_logstd, observation_placeholder=observation) # YongRen: policy.Gaussian (could reference the policy in TRPO paper, my code is adapted from zhusuan.distributions) policy.DQN etc.
# for a continuous action space, you may need to change the environment to run
pi_old = policy.Normal(action_mean_old, action_logstd_old, observation_placeholder=observation)
# value = tf.layers.dense(net, 1, activation=None)
action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
return action_mean, action_logstd, None # None value head
ppo_loss_clip = losses.ppo_clip(action, advantage, clip_param, pi, pi_old) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
# TODO: the current implementation of passing a function (or overriding one) has to return a value head
# to allow network sharing between the policy and value networks. This makes 'policy' and 'value_function'
# semantically imbalanced (though they are naturally imbalanced, since 'policy' is required to interact
# with the environment and 'value_function' is not). I have an idea for resolving this imbalance that is
# not based on passing or overriding functions.
### 2. build policy, losses, optimizers
pi = policy.Normal(my_policy, observation_placeholder=observation_ph, weight_update=0)
# action = tf.placeholder(dtype=tf.float32, shape=(None, action_dim)) # batch of integer actions
# advantage = tf.placeholder(dtype=tf.float32, shape=(None,)) # advantage values used in the Gradients
ppo_loss_clip = losses.ppo_clip(pi, clip_param) # TongzhengRen: losses.vpg ... management of placeholders and feed_dict
total_loss = ppo_loss_clip
optimizer = tf.train.AdamOptimizer(1e-4)
train_op = optimizer.minimize(total_loss, var_list=train_var_list)
train_op = optimizer.minimize(total_loss, var_list=pi.trainable_variables)
# 3. define data collection
### 3. define data collection
training_data = Batch(env, pi, advantage_estimation.full_return) # YouQiaoben: finish and polish Batch, advantage_estimation.gae_lambda as in PPO paper
# ShihongSong: Replay(), see dqn_example.py
# maybe a dict to manage the elements to be collected
# 4. start training
### 4. start training
# init = tf.global_variables_initializer()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess:
sess.run(tf.global_variables_initializer())
# sync pi and pi_old
sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
pi.sync_weights() # TODO: automate this for policies with target network
start_time = time.time()
for i in range(100): # until some stopping criterion met...
@@ -97,11 +89,10 @@ if __name__ == '__main__': # a clean version with only policy net, no value net
# update network
for _ in range(num_batches):
data = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong
# TODO: auto managing of the placeholders? or add this to params of data.Batch
sess.run(train_op, feed_dict={observation: data['observations'], action: data['actions'], advantage: data['returns']})
feed_dict = training_data.next_batch(batch_size) # YouQiaoben, ShihongSong
sess.run(train_op, feed_dict=feed_dict)
# assigning pi to pi_old
sess.run([tf.assign(theta_old, theta) for (theta_old, theta) in zip(pi_old_var_list, train_var_list)])
pi.update_weights()
print('Elapsed time: {:.1f} min'.format((time.time() - start_time) / 60))
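
For reference, the `pi.sync_weights()` and `pi.update_weights()` calls above wrap the same scope-to-scope assign pattern that the removed manual loop used. A minimal standalone sketch of that pattern (the scope names, layer sizes, and `build_net` helper are illustrative only, not the library's internals):

```python
import tensorflow as tf

def build_net(scope):
    # toy two-layer network; layer sizes are arbitrary for illustration
    with tf.variable_scope(scope):
        x = tf.placeholder(tf.float32, shape=(None, 4))
        h = tf.layers.dense(x, 32, activation=tf.nn.tanh)
        return tf.layers.dense(h, 2)

out_new = build_net('network')   # current policy network
out_old = build_net('net_old')   # old/target copy

new_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
old_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
assert len(new_vars) == len(old_vars)
sync_ops = [tf.assign(v_old, v_new) for v_old, v_new in zip(old_vars, new_vars)]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(sync_ops)  # copy the current weights into the old/target network
```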

View File

@@ -1,22 +1,26 @@
import tensorflow as tf
def ppo_clip(sampled_action, advantage, clip_param, pi, pi_old):
def ppo_clip(policy, clip_param):
"""
the clip loss as in the PPO paper
:param sampled_action: placeholder of sampled actions during interaction with the environment
:param advantage: placeholder of estimated advantage values.
:param clip_param: float or Tensor of type float.
:param pi: current `policy` to be optimized
:param policy: current `policy` to be optimized
:param pi_old: old `policy` for computing the ppo loss as in Eqn. (7) in the paper
"""
action_ph = tf.placeholder(policy.act_dtype, shape=(None, policy.action_dim), name='ppo_clip_loss/action_placeholder')
advantage_ph = tf.placeholder(tf.float32, shape=(None,), name='ppo_clip_loss/advantage_placeholder')
policy.managed_placeholders['action'] = action_ph
policy.managed_placeholders['processed_reward'] = advantage_ph
log_pi_act = pi.log_prob(sampled_action)
log_pi_old_act = pi_old.log_prob(sampled_action)
log_pi_act = policy.log_prob(action_ph)
log_pi_old_act = policy.log_prob_old(action_ph)
ratio = tf.exp(log_pi_act - log_pi_old_act)
clipped_ratio = tf.clip_by_value(ratio, 1. - clip_param, 1. + clip_param)
ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage, clipped_ratio * advantage))
ppo_clip_loss = -tf.reduce_mean(tf.minimum(ratio * advantage_ph, clipped_ratio * advantage_ph))
return ppo_clip_loss
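
As a quick sanity check of the clipped surrogate (Eqn. (7) in the PPO paper) that `ppo_clip` builds, a small NumPy sketch with made-up ratios and advantages shows how the clip caps the objective once the ratio leaves [1 - clip_param, 1 + clip_param]:

```python
import numpy as np

clip_param = 0.2
ratio = np.array([0.5, 0.9, 1.0, 1.2, 1.5])      # pi(a|s) / pi_old(a|s), made-up values
advantage = np.array([1.0, 1.0, 1.0, 1.0, 1.0])  # made-up advantage estimates

clipped_ratio = np.clip(ratio, 1.0 - clip_param, 1.0 + clip_param)
surrogate = np.minimum(ratio * advantage, clipped_ratio * advantage)
loss = -surrogate.mean()  # same sign convention as ppo_clip_loss above

print(surrogate)  # [0.5 0.9 1.  1.2 1.2]: with positive advantage, gains are capped at 1 + clip_param
```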

View File

@@ -16,44 +16,11 @@ class PolicyBase(object):
"""
base class for policy. only provides the `act` method with exploration
"""
def __init__(self, observation_placeholder):
self._observation_placeholder = observation_placeholder
def act(self, observation, exploration):
def act(self, observation):
raise NotImplementedError()
class QValuePolicy(object):
"""
The policy as in DQN
"""
def __init__(self, observation_placeholder):
self._observation_placeholder = observation_placeholder
def act(self, observation, exploration=None): # first implement no exploration
"""
return the action (int) to be executed.
no exploration when exploration=None.
"""
self._act(observation, exploration)
def _act(self, observation, exploration=None):
raise NotImplementedError()
def values(self, observation):
"""
returns the Q(s, a) values (float) for all actions a at observation s
"""
pass
def values_tensor(self):
"""
returns the tensor of the values for all actions a at observation s
"""
pass
class StochasticPolicy(object):
class StochasticPolicy(PolicyBase):
"""
The :class:`StochasticPolicy` class is the base class for various probabilistic
distributions which support batch inputs, generating batches of samples and
@@ -170,7 +137,7 @@ class StochasticPolicy(object):
return self._group_ndims
# @add_name_scope
def act(self, observation):
def act(self, observation, my_feed_dict={}):
"""
sample(n_samples=None)
@@ -184,9 +151,9 @@
samples to draw from the distribution.
:return: A Tensor of samples.
"""
return self._act(observation)
return self._act(observation, my_feed_dict)
def _act(self, observation):
def _act(self, observation, my_feed_dict):
"""
Private method for subclasses to rewrite the :meth:`act` method.
"""

View File

@@ -94,24 +94,64 @@ class Normal(StochasticPolicy):
:param observation_placeholder
"""
def __init__(self,
mean = 0.,
logstd = 1.,
group_ndims = 1,
observation_placeholder = None,
policy_callable,
observation_placeholder,
weight_update=1,
group_ndims=1,
**kwargs):
self.managed_placeholders = {'observation': observation_placeholder}
self.weight_update = weight_update
self.interaction_count = -1 # defaults to -1. only useful if weight_update > 1.
with tf.variable_scope('network'):
mean, logstd, value_head = policy_callable()
self._mean = tf.convert_to_tensor(mean, dtype = tf.float32)
self._logstd = tf.convert_to_tensor(logstd, dtype = tf.float32)
self._std = tf.exp(self._logstd)
shape = tf.broadcast_dynamic_shape(tf.shape(self._mean), tf.shape(self._std))
self._action = tf.random_normal(tf.concat([[1], shape], 0), dtype = tf.float32) * self._std + self._mean
# TODO: self._action should be exactly the action tensor to run, without [0, 0] in self._act
if value_head is not None:
pass
self.trainable_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')
if self.weight_update == 1:
self.weight_update_ops = None
self.sync_weights_ops = None
else: # then we need to build another tf graph as target network
with tf.variable_scope('net_old'):
mean, logstd, value_head = policy_callable()
self._mean_old = tf.convert_to_tensor(mean, dtype=tf.float32)
self._logstd_old = tf.convert_to_tensor(logstd, dtype=tf.float32)
if value_head is not None: # useful in DDPG
pass
network_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='network')
self.network_old_weights = network_old_weights = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net_old')
assert len(network_weights) == len(network_old_weights)
self.sync_weights_ops = [tf.assign(variable_old, variable)
for (variable_old, variable) in zip(network_old_weights, network_weights)]
if weight_update == 0:
self.weight_update_ops = self.sync_weights_ops
elif 0 < weight_update < 1:
pass
else:
self.interaction_count = 0
import math
self.weight_update = math.ceil(weight_update)
super(Normal, self).__init__(
act_dtype = tf.float32,
param_dtype = tf.float32,
is_continuous = True,
observation_placeholder = observation_placeholder,
act_dtype=tf.float32,
param_dtype=tf.float32,
is_continuous=True,
observation_placeholder=observation_placeholder,
group_ndims = group_ndims,
**kwargs)
@@ -127,13 +167,22 @@ class Normal(StochasticPolicy):
def logstd(self):
return self._logstd
def _act(self, observation):
@property
def action(self):
return self._action
@property
def action_dim(self):
return self.mean.shape.as_list()[-1]
def _act(self, observation, my_feed_dict):
# TODO: getting the session like this may be ugly. it could also be a big problem when running in parallel
sess = tf.get_default_session()
# observation[None] adds one dimension at the beginning
sampled_action = sess.run(self._action,
feed_dict={self._observation_placeholder: observation[None]})
feed_dict = {self._observation_placeholder: observation[None]}
feed_dict.update(my_feed_dict)
sampled_action = sess.run(self._action, feed_dict=feed_dict)
sampled_action = sampled_action[0, 0]
return sampled_action
@@ -146,3 +195,35 @@ class Normal(StochasticPolicy):
def _prob(self, sampled_action):
return tf.exp(self._log_prob(sampled_action))
def log_prob_old(self, sampled_action):
"""
return the log_prob of the old policy when constructing tf graphs. Raises an error when there is no old policy.
:param sampled_action: the placeholder for sampled actions during interaction with the environment.
:return: tensor of the log_prob of the old policy
"""
if self.weight_update == 1:
raise AttributeError('Policy has no policy_old since it\'s initialized with weight_update=1!')
mean, logstd = self._mean_old, self._logstd_old
c = -0.5 * np.log(2 * np.pi)
precision = tf.exp(-2 * logstd)
return c - logstd - 0.5 * precision * tf.square(sampled_action - mean)
def update_weights(self):
"""
updates the weights of policy_old.
:return:
"""
if self.weight_update_ops is not None:
sess = tf.get_default_session()
sess.run(self.weight_update_ops)
def sync_weights(self):
"""
sync the weights of network_old, directly copying the weights of network.
:return:
"""
if self.sync_weights_ops is not None:
sess = tf.get_default_session()
sess.run(self.sync_weights_ops)
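
The `log_prob_old` expression above is the standard diagonal-Gaussian log-density. A quick NumPy/SciPy check of that formula with made-up numbers (purely illustrative, assuming SciPy is available):

```python
import numpy as np
from scipy.stats import norm

mean, logstd, sampled_action = 0.3, -0.5, 1.2  # made-up scalars

c = -0.5 * np.log(2 * np.pi)
precision = np.exp(-2 * logstd)
manual_log_prob = c - logstd - 0.5 * precision * (sampled_action - mean) ** 2

reference = norm.logpdf(sampled_action, loc=mean, scale=np.exp(logstd))
print(np.isclose(manual_log_prob, reference))  # True
```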

View File

@@ -135,7 +135,12 @@ class Batch(object):
advantage_std = np.std(current_batch['returns'])
current_batch['returns'] = (current_batch['returns'] - advantage_mean) / advantage_std
return current_batch
feed_dict = {}
feed_dict[self._pi.managed_placeholders['observation']] = current_batch['observations']
feed_dict[self._pi.managed_placeholders['action']] = current_batch['actions']
feed_dict[self._pi.managed_placeholders['processed_reward']] = current_batch['returns']
return feed_dict
# TODO: this will definitely be refactored with a proper logger
def statistics(self):
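
Taken together, the losses and Batch changes above establish a small contract: the loss registers the placeholders it needs under fixed keys in `policy.managed_placeholders`, and `Batch.next_batch` fills exactly those keys. A stripped-down sketch of that contract (`TinyPolicy` and the shapes are hypothetical stand-ins, not the library code):

```python
import numpy as np
import tensorflow as tf

class TinyPolicy:
    """Hypothetical stand-in that only carries the managed placeholders."""
    def __init__(self, observation_ph):
        self.managed_placeholders = {'observation': observation_ph}

observation_ph = tf.placeholder(tf.float32, shape=(None, 4))
pi = TinyPolicy(observation_ph)

# loss side: register required placeholders under fixed keys (as ppo_clip does)
pi.managed_placeholders['action'] = tf.placeholder(tf.float32, shape=(None, 2))
pi.managed_placeholders['processed_reward'] = tf.placeholder(tf.float32, shape=(None,))

# data side: fill exactly those keys when building a batch (as Batch.next_batch does)
current_batch = {'observations': np.zeros((8, 4), np.float32),
                 'actions': np.zeros((8, 2), np.float32),
                 'returns': np.zeros(8, np.float32)}
feed_dict = {pi.managed_placeholders['observation']: current_batch['observations'],
             pi.managed_placeholders['action']: current_batch['actions'],
             pi.managed_placeholders['processed_reward']: current_batch['returns']}
# sess.run(train_op, feed_dict=feed_dict) then needs no hand-written placeholder mapping
```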